1 package handlers
2
3 import (
4 "net/http"
5 "net/http/httptest"
6 "strings"
7 "testing"
8 )
9
10 func TestDefaultCORSHandlerReturnsOk(t *testing.T) {
11 r := newRequest("GET", "http://www.example.com/")
12 rr := httptest.NewRecorder()
13
14 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
15
16 CORS()(testHandler).ServeHTTP(rr, r)
17
18 if got, want := rr.Code, http.StatusOK; got != want {
19 t.Fatalf("bad status: got %v want %v", got, want)
20 }
21 }
22
23 func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) {
24 r := newRequest("GET", "http://www.example.com/")
25 r.Header.Set("Origin", r.URL.String())
26
27 rr := httptest.NewRecorder()
28
29 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
30
31 CORS()(testHandler).ServeHTTP(rr, r)
32
33 if got, want := rr.Code, http.StatusOK; got != want {
34 t.Fatalf("bad status: got %v want %v", got, want)
35 }
36 }
37
38 func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) {
39 r := newRequest("OPTIONS", "http://www.example.com/")
40 r.Header.Set("Origin", r.URL.String())
41
42 rr := httptest.NewRecorder()
43
44 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
45 w.WriteHeader(http.StatusTeapot)
46 })
47
48 CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r)
49
50 if got, want := rr.Code, http.StatusTeapot; got != want {
51 t.Fatalf("bad status: got %v want %v", got, want)
52 }
53 }
54
55 func TestCORSHandlerSetsExposedHeaders(t *testing.T) {
56
57 r := newRequest("GET", "http://www.example.com/")
58 r.Header.Set("Origin", r.URL.String())
59
60 rr := httptest.NewRecorder()
61
62 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
63
64 CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r)
65
66 if got, want := rr.Code, http.StatusOK; got != want {
67 t.Fatalf("bad status: got %v want %v", got, want)
68 }
69
70 header := rr.HeaderMap.Get(corsExposeHeadersHeader)
71 if got, want := header, "X-Cors-Test"; got != want {
72 t.Fatalf("bad header: expected %q header, got empty header for method.", want)
73 }
74 }
75
76 func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) {
77 r := newRequest("OPTIONS", "http://www.example.com/")
78 r.Header.Set("Origin", r.URL.String())
79
80 rr := httptest.NewRecorder()
81
82 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
83
84 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
85
86 if got, want := rr.Code, http.StatusBadRequest; got != want {
87 t.Fatalf("bad status: got %v want %v", got, want)
88 }
89 }
90
91 func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) {
92 r := newRequest("OPTIONS", "http://www.example.com/")
93 r.Header.Set("Origin", r.URL.String())
94 r.Header.Set(corsRequestMethodHeader, "DELETE")
95
96 rr := httptest.NewRecorder()
97
98 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
99
100 CORS()(testHandler).ServeHTTP(rr, r)
101
102 if got, want := rr.Code, http.StatusMethodNotAllowed; got != want {
103 t.Fatalf("bad status: got %v want %v", got, want)
104 }
105 }
106
107 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
108 r := newRequest("OPTIONS", "http://www.example.com/")
109 r.Header.Set("Origin", r.URL.String())
110 r.Header.Set(corsRequestMethodHeader, "GET")
111
112 rr := httptest.NewRecorder()
113
114 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115 t.Fatal("Options request must not be passed to next handler")
116 })
117
118 CORS()(testHandler).ServeHTTP(rr, r)
119
120 if got, want := rr.Code, http.StatusOK; got != want {
121 t.Fatalf("bad status: got %v want %v", got, want)
122 }
123 }
124
125 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
126 statusCode := http.StatusNoContent
127 r := newRequest("OPTIONS", "http://www.example.com/")
128 r.Header.Set("Origin", r.URL.String())
129 r.Header.Set(corsRequestMethodHeader, "GET")
130
131 rr := httptest.NewRecorder()
132
133 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
134 t.Fatal("Options request must not be passed to next handler")
135 })
136
137 CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)
138
139 if got, want := rr.Code, statusCode; got != want {
140 t.Fatalf("bad status: got %v want %v", got, want)
141 }
142 }
143
144 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
145 r := newRequest("OPTIONS", "http://www.example.com/")
146 r.Header.Set("Origin", r.URL.String())
147 r.Header.Set(corsRequestMethodHeader, "GET")
148
149 rr := httptest.NewRecorder()
150
151 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152 t.Fatal("Options request must not be passed to next handler")
153 })
154
155 CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r)
156
157 if got, want := rr.Code, http.StatusOK; got != want {
158 t.Fatalf("bad status: got %v want %v", got, want)
159 }
160 }
161
162 func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) {
163 r := newRequest("OPTIONS", "http://www.example.com/")
164 r.Header.Set("Origin", r.URL.String())
165 r.Header.Set(corsRequestMethodHeader, "DELETE")
166
167 rr := httptest.NewRecorder()
168
169 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
170
171 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
172
173 if got, want := rr.Code, http.StatusOK; got != want {
174 t.Fatalf("bad status: got %v want %v", got, want)
175 }
176
177 header := rr.HeaderMap.Get(corsAllowMethodsHeader)
178 if got, want := header, "DELETE"; got != want {
179 t.Fatalf("bad header: expected %q method header, got %q header.", want, got)
180 }
181 }
182
183 func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) {
184 for _, method := range defaultCorsMethods {
185 r := newRequest("OPTIONS", "http://www.example.com/")
186 r.Header.Set("Origin", r.URL.String())
187 r.Header.Set(corsRequestMethodHeader, method)
188
189 rr := httptest.NewRecorder()
190
191 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
192
193 CORS()(testHandler).ServeHTTP(rr, r)
194
195 if got, want := rr.Code, http.StatusOK; got != want {
196 t.Fatalf("bad status: got %v want %v", got, want)
197 }
198
199 header := rr.HeaderMap.Get(corsAllowMethodsHeader)
200 if got, want := header, ""; got != want {
201 t.Fatalf("bad header: expected %q method header, got %q.", want, got)
202 }
203 }
204 }
205
206 func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) {
207 for _, simpleHeader := range defaultCorsHeaders {
208 r := newRequest("OPTIONS", "http://www.example.com/")
209 r.Header.Set("Origin", r.URL.String())
210 r.Header.Set(corsRequestMethodHeader, "GET")
211 r.Header.Set(corsRequestHeadersHeader, simpleHeader)
212
213 rr := httptest.NewRecorder()
214
215 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
216
217 CORS()(testHandler).ServeHTTP(rr, r)
218
219 if got, want := rr.Code, http.StatusOK; got != want {
220 t.Fatalf("bad status: got %v want %v", got, want)
221 }
222
223 header := rr.HeaderMap.Get(corsAllowHeadersHeader)
224 if got, want := header, ""; got != want {
225 t.Fatalf("bad header: expected %q header, got %q.", want, got)
226 }
227 }
228 }
229
230 func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) {
231 r := newRequest("OPTIONS", "http://www.example.com/")
232 r.Header.Set("Origin", r.URL.String())
233 r.Header.Set(corsRequestMethodHeader, "POST")
234 r.Header.Set(corsRequestHeadersHeader, "Content-Type")
235
236 rr := httptest.NewRecorder()
237
238 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
239
240 CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r)
241
242 if got, want := rr.Code, http.StatusOK; got != want {
243 t.Fatalf("bad status: got %v want %v", got, want)
244 }
245
246 header := rr.HeaderMap.Get(corsAllowHeadersHeader)
247 if got, want := header, "Content-Type"; got != want {
248 t.Fatalf("bad header: expected %q header, got %q header.", want, got)
249 }
250 }
251
252 func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) {
253 r := newRequest("OPTIONS", "http://www.example.com/")
254 r.Header.Set("Origin", r.URL.String())
255 r.Header.Set(corsRequestMethodHeader, "POST")
256 r.Header.Set(corsRequestHeadersHeader, "Content-Type")
257
258 rr := httptest.NewRecorder()
259
260 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
261
262 CORS()(testHandler).ServeHTTP(rr, r)
263
264 if got, want := rr.Code, http.StatusForbidden; got != want {
265 t.Fatalf("bad status: got %v want %v", got, want)
266 }
267 }
268
269 func TestCORSHandlerMaxAgeForPreflight(t *testing.T) {
270 r := newRequest("OPTIONS", "http://www.example.com/")
271 r.Header.Set("Origin", r.URL.String())
272 r.Header.Set(corsRequestMethodHeader, "POST")
273
274 rr := httptest.NewRecorder()
275
276 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
277
278 CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r)
279
280 if got, want := rr.Code, http.StatusOK; got != want {
281 t.Fatalf("bad status: got %v want %v", got, want)
282 }
283
284 header := rr.HeaderMap.Get(corsMaxAgeHeader)
285 if got, want := header, "600"; got != want {
286 t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got)
287 }
288 }
289
290 func TestCORSHandlerAllowedCredentials(t *testing.T) {
291 r := newRequest("GET", "http://www.example.com/")
292 r.Header.Set("Origin", r.URL.String())
293
294 rr := httptest.NewRecorder()
295
296 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
297
298 CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r)
299
300 if status := rr.Code; status != http.StatusOK {
301 t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
302 }
303
304 header := rr.HeaderMap.Get(corsAllowCredentialsHeader)
305 if got, want := header, "true"; got != want {
306 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got)
307 }
308 }
309
310 func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) {
311 r := newRequest("GET", "http://www.example.com/")
312 r.Header.Set("Origin", r.URL.String())
313
314 rr := httptest.NewRecorder()
315
316 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
317
318 CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r)
319
320 if status := rr.Code; status != http.StatusOK {
321 t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
322 }
323
324 header := rr.HeaderMap.Get(corsVaryHeader)
325 if got, want := header, corsOriginHeader; got != want {
326 t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got)
327 }
328 }
329
330 func TestCORSWithMultipleHandlers(t *testing.T) {
331 var lastHandledBy string
332 corsMiddleware := CORS()
333
334 testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
335 lastHandledBy = "testHandler1"
336 })
337 testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
338 lastHandledBy = "testHandler2"
339 })
340
341 r1 := newRequest("GET", "http://www.example.com/")
342 rr1 := httptest.NewRecorder()
343 handler1 := corsMiddleware(testHandler1)
344
345 corsMiddleware(testHandler2)
346
347 handler1.ServeHTTP(rr1, r1)
348 if lastHandledBy != "testHandler1" {
349 t.Fatalf("bad CORS() registration: Handler served should be Handler registered")
350 }
351 }
352
353 func TestCORSOriginValidatorWithImplicitStar(t *testing.T) {
354 r := newRequest("GET", "http://a.example.com")
355 r.Header.Set("Origin", r.URL.String())
356 rr := httptest.NewRecorder()
357
358 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
359
360 originValidator := func(origin string) bool {
361 if strings.HasSuffix(origin, ".example.com") {
362 return true
363 }
364 return false
365 }
366
367 CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r)
368 header := rr.HeaderMap.Get(corsAllowOriginHeader)
369 if got, want := header, r.URL.String(); got != want {
370 t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got)
371 }
372 }
373
374 func TestCORSOriginValidatorWithExplicitStar(t *testing.T) {
375 r := newRequest("GET", "http://a.example.com")
376 r.Header.Set("Origin", r.URL.String())
377 rr := httptest.NewRecorder()
378
379 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
380
381 originValidator := func(origin string) bool {
382 if strings.HasSuffix(origin, ".example.com") {
383 return true
384 }
385 return false
386 }
387
388 CORS(
389 AllowedOriginValidator(originValidator),
390 AllowedOrigins([]string{"*"}),
391 )(testHandler).ServeHTTP(rr, r)
392 header := rr.HeaderMap.Get(corsAllowOriginHeader)
393 if got, want := header, "*"; got != want {
394 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got)
395 }
396 }
397
398 func TestCORSAllowStar(t *testing.T) {
399 r := newRequest("GET", "http://a.example.com")
400 r.Header.Set("Origin", r.URL.String())
401 rr := httptest.NewRecorder()
402
403 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
404
405 CORS()(testHandler).ServeHTTP(rr, r)
406 header := rr.HeaderMap.Get(corsAllowOriginHeader)
407 if got, want := header, "*"; got != want {
408 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got)
409 }
410 }
411
View as plain text