1 package imds
2
3 import (
4 "net/http"
5 "strconv"
6 "sync"
7 "sync/atomic"
8 "testing"
9 "time"
10 )
11
// APIHandlers supplies the handlers used to build a mock IMDS server: one for
// the session-token endpoint and one for all other API requests.
type APIHandlers interface {
	// GetAPIHandler returns the handler serving non-token API requests.
	GetAPIHandler() http.Handler

	// GetAPITokenHandler returns the handler serving session-token requests.
	GetAPITokenHandler() http.Handler
}
16
17 func newTestServeMux(t *testing.T, handlers APIHandlers) *http.ServeMux {
18 mux := http.NewServeMux()
19
20 mux.Handle(getTokenPath, validateAPITokenRequest(t, handlers.GetAPITokenHandler()))
21 mux.Handle("/latest/", handlers.GetAPIHandler())
22
23 return mux
24 }
25
26 func validateAPITokenRequest(t *testing.T, handler http.Handler) http.Handler {
27 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28 if e, a := "PUT", r.Method; e != a {
29 t.Errorf("expect %v, http method got %v", e, a)
30 http.Error(w, http.StatusText(400), 400)
31 return
32 }
33 if len(r.Header.Get(tokenTTLHeader)) == 0 {
34 t.Errorf("expect token TTL header to be present in the request headers, got none")
35 http.Error(w, http.StatusText(400), 400)
36 return
37 }
38
39 handler.ServeHTTP(w, r)
40 })
41 }
42
43 type secureAPIHandler struct {
44 t *testing.T
45
46 tokens []string
47 tokenTTL time.Duration
48 apiHandler http.Handler
49
50 activeToken atomic.Value
51 }
52
53 func newSecureAPIHandler(t *testing.T, tokens []string, tokenTTL time.Duration, apiHandler http.Handler) *secureAPIHandler {
54 return &secureAPIHandler{
55 t: t,
56 tokens: tokens,
57 tokenTTL: tokenTTL,
58 apiHandler: apiHandler,
59 }
60 }
61
62 func (h *secureAPIHandler) GetAPITokenHandler() http.Handler {
63 return http.HandlerFunc(h.handleAPIToken)
64 }
65
66 func (h *secureAPIHandler) handleAPIToken(w http.ResponseWriter, r *http.Request) {
67 token := h.tokens[0]
68
69
70 h.storeActiveToken(token)
71
72
73 if len(h.tokens) > 1 {
74 h.tokens = h.tokens[1:]
75 }
76
77 var tokenTTLHeaderVal string
78 if h.tokenTTL == 0 {
79 tokenTTLHeaderVal = r.Header.Get(tokenTTLHeader)
80 } else {
81 tokenTTLHeaderVal = strconv.Itoa(int(h.tokenTTL / time.Second))
82 }
83
84
85 w.Header().Set(tokenTTLHeader, tokenTTLHeaderVal)
86 activeToken := h.getActiveToken()
87
88 w.Write([]byte(activeToken))
89 }
90
91 func (h *secureAPIHandler) GetAPIHandler() http.Handler {
92 return http.HandlerFunc(h.handleAPI)
93 }
94
95 func (h *secureAPIHandler) handleAPI(w http.ResponseWriter, r *http.Request) {
96 token := h.getActiveToken()
97 if len(token) == 0 {
98 h.t.Errorf("expect token to have been requested, was not")
99 http.Error(w, http.StatusText(401), 401)
100 return
101 }
102
103 if e, a := token, r.Header.Get(tokenHeader); e != a {
104 h.t.Errorf("expect %v token, got %v", e, a)
105 http.Error(w, http.StatusText(401), 401)
106 return
107 }
108
109
110 h.apiHandler.ServeHTTP(w, r)
111 }
112
113 func (h *secureAPIHandler) storeActiveToken(t string) {
114 h.activeToken.Store(t)
115 }
116
117 func (h *secureAPIHandler) getActiveToken() string {
118 activeToken := h.activeToken.Load()
119 v, ok := activeToken.(string)
120 if !ok {
121 h.t.Errorf("expect valid active token string, got %T, %v", v, v)
122 }
123
124 return v
125 }
126
127 type insecureAPIHandler struct {
128 t *testing.T
129 apiTokenErrCode int
130 apiHandler http.Handler
131 }
132
133 func newInsecureAPIHandler(t *testing.T, apiTokenErrCode int, apiHandler http.Handler) *insecureAPIHandler {
134 return &insecureAPIHandler{
135 t: t,
136 apiTokenErrCode: apiTokenErrCode,
137 apiHandler: apiHandler,
138 }
139 }
140
141 func (h *insecureAPIHandler) GetAPITokenHandler() http.Handler {
142 return http.HandlerFunc(h.handleAPIToken)
143 }
144
145 func (h *insecureAPIHandler) handleAPIToken(w http.ResponseWriter, r *http.Request) {
146 http.Error(w, http.StatusText(h.apiTokenErrCode), h.apiTokenErrCode)
147 }
148
149 func (h *insecureAPIHandler) GetAPIHandler() http.Handler {
150 return http.HandlerFunc(h.handleAPI)
151 }
152
153 func (h *insecureAPIHandler) handleAPI(w http.ResponseWriter, r *http.Request) {
154 if len(r.Header.Get(tokenHeader)) != 0 {
155 h.t.Errorf("request token found, expected none")
156 http.Error(w, http.StatusText(400), 400)
157 return
158 }
159
160
161 h.apiHandler.ServeHTTP(w, r)
162 }
163
164 type unauthorizedAPIHandler struct {
165 t *testing.T
166
167 enabled bool
168 secureAPIHandler *secureAPIHandler
169 }
170
171 func newUnauthorizedAPIHandler(t *testing.T, secureHandler *secureAPIHandler) *unauthorizedAPIHandler {
172 return &unauthorizedAPIHandler{
173 t: t,
174 secureAPIHandler: secureHandler,
175 }
176 }
177
178 func (h *unauthorizedAPIHandler) GetAPITokenHandler() http.Handler {
179 return http.HandlerFunc(h.handleAPIToken)
180 }
181
182 func (h *unauthorizedAPIHandler) handleAPIToken(w http.ResponseWriter, r *http.Request) {
183
184 if !h.enabled {
185 http.Error(w, http.StatusText(404), 404)
186 return
187 }
188
189 h.secureAPIHandler.GetAPITokenHandler().ServeHTTP(w, r)
190 }
191
192 func (h *unauthorizedAPIHandler) GetAPIHandler() http.Handler {
193 return http.HandlerFunc(h.handleAPI)
194 }
195
196 func (h *unauthorizedAPIHandler) handleAPI(w http.ResponseWriter, r *http.Request) {
197
198
199 if !h.enabled {
200 h.enabled = true
201 http.Error(w, http.StatusText(401), 401)
202 return
203 }
204
205 h.secureAPIHandler.GetAPIHandler().ServeHTTP(w, r)
206 }
207
// requestTrace records, in arrival order, the URL path of each request that
// passes through a handler wrapped by WrapHandler. Appends are serialized by
// mu.
type requestTrace struct {
	requests []string
	mu       sync.Mutex
}

// newRequestTrace returns an empty requestTrace.
func newRequestTrace() *requestTrace {
	return new(requestTrace)
}

// WrapHandler returns a handler that records r.URL.Path and then calls handler.
func (t *requestTrace) WrapHandler(handler http.Handler) http.Handler {
	record := func(path string) {
		t.mu.Lock()
		defer t.mu.Unlock()
		t.requests = append(t.requests, path)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		record(r.URL.Path)
		handler.ServeHTTP(w, r)
	})
}
226
View as plain text