1 package mux
2
3 import (
4 "bytes"
5 "net/http"
6 "testing"
7 )
8
// testMiddleware counts how many requests pass through the handlers it
// produces. The counter is a plain, unsynchronized field, so it is meant for
// tests that serve requests one at a time.
type testMiddleware struct {
	// timesCalled is incremented once per request served by a handler
	// returned from Middleware.
	timesCalled uint
}

// Middleware returns a handler that bumps timesCalled and then hands the
// request, unchanged, to next.
func (m *testMiddleware) Middleware(next http.Handler) http.Handler {
	count := func(w http.ResponseWriter, r *http.Request) {
		m.timesCalled++
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(count)
}
19
// dummyHandler is a no-op route handler: it neither inspects the request nor
// writes anything to the response.
func dummyHandler(http.ResponseWriter, *http.Request) {}
21
22 func TestMiddlewareAdd(t *testing.T) {
23 router := NewRouter()
24 router.HandleFunc("/", dummyHandler).Methods("GET")
25
26 mw := &testMiddleware{}
27
28 router.useInterface(mw)
29 if len(router.middlewares) != 1 || router.middlewares[0] != mw {
30 t.Fatal("Middleware interface was not added correctly")
31 }
32
33 router.Use(mw.Middleware)
34 if len(router.middlewares) != 2 {
35 t.Fatal("Middleware method was not added correctly")
36 }
37
38 banalMw := func(handler http.Handler) http.Handler {
39 return handler
40 }
41 router.Use(banalMw)
42 if len(router.middlewares) != 3 {
43 t.Fatal("Middleware function was not added correctly")
44 }
45 }
46
47 func TestMiddleware(t *testing.T) {
48 router := NewRouter()
49 router.HandleFunc("/", dummyHandler).Methods("GET")
50
51 mw := &testMiddleware{}
52 router.useInterface(mw)
53
54 rw := NewRecorder()
55 req := newRequest("GET", "/")
56
57 t.Run("regular middleware call", func(t *testing.T) {
58 router.ServeHTTP(rw, req)
59 if mw.timesCalled != 1 {
60 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
61 }
62 })
63
64 t.Run("not called for 404", func(t *testing.T) {
65 req = newRequest("GET", "/not/found")
66 router.ServeHTTP(rw, req)
67 if mw.timesCalled != 1 {
68 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
69 }
70 })
71
72 t.Run("not called for method mismatch", func(t *testing.T) {
73 req = newRequest("POST", "/")
74 router.ServeHTTP(rw, req)
75 if mw.timesCalled != 1 {
76 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
77 }
78 })
79
80 t.Run("regular call using function middleware", func(t *testing.T) {
81 router.Use(mw.Middleware)
82 req = newRequest("GET", "/")
83 router.ServeHTTP(rw, req)
84 if mw.timesCalled != 3 {
85 t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
86 }
87 })
88 }
89
90 func TestMiddlewareSubrouter(t *testing.T) {
91 router := NewRouter()
92 router.HandleFunc("/", dummyHandler).Methods("GET")
93
94 subrouter := router.PathPrefix("/sub").Subrouter()
95 subrouter.HandleFunc("/x", dummyHandler).Methods("GET")
96
97 mw := &testMiddleware{}
98 subrouter.useInterface(mw)
99
100 rw := NewRecorder()
101 req := newRequest("GET", "/")
102
103 t.Run("not called for route outside subrouter", func(t *testing.T) {
104 router.ServeHTTP(rw, req)
105 if mw.timesCalled != 0 {
106 t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
107 }
108 })
109
110 t.Run("not called for subrouter root 404", func(t *testing.T) {
111 req = newRequest("GET", "/sub/")
112 router.ServeHTTP(rw, req)
113 if mw.timesCalled != 0 {
114 t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
115 }
116 })
117
118 t.Run("called once for route inside subrouter", func(t *testing.T) {
119 req = newRequest("GET", "/sub/x")
120 router.ServeHTTP(rw, req)
121 if mw.timesCalled != 1 {
122 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
123 }
124 })
125
126 t.Run("not called for 404 inside subrouter", func(t *testing.T) {
127 req = newRequest("GET", "/sub/not/found")
128 router.ServeHTTP(rw, req)
129 if mw.timesCalled != 1 {
130 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
131 }
132 })
133
134 t.Run("middleware added to router", func(t *testing.T) {
135 router.useInterface(mw)
136
137 t.Run("called once for route outside subrouter", func(t *testing.T) {
138 req = newRequest("GET", "/")
139 router.ServeHTTP(rw, req)
140 if mw.timesCalled != 2 {
141 t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
142 }
143 })
144
145 t.Run("called twice for route inside subrouter", func(t *testing.T) {
146 req = newRequest("GET", "/sub/x")
147 router.ServeHTTP(rw, req)
148 if mw.timesCalled != 4 {
149 t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
150 }
151 })
152 })
153 }
154
155 func TestMiddlewareExecution(t *testing.T) {
156 mwStr := []byte("Middleware\n")
157 handlerStr := []byte("Logic\n")
158
159 router := NewRouter()
160 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
161 _, err := w.Write(handlerStr)
162 if err != nil {
163 t.Fatalf("Failed writing HTTP response: %v", err)
164 }
165 })
166
167 t.Run("responds normally without middleware", func(t *testing.T) {
168 rw := NewRecorder()
169 req := newRequest("GET", "/")
170
171 router.ServeHTTP(rw, req)
172
173 if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
174 t.Fatal("Handler response is not what it should be")
175 }
176 })
177
178 t.Run("responds with handler and middleware response", func(t *testing.T) {
179 rw := NewRecorder()
180 req := newRequest("GET", "/")
181
182 router.Use(func(h http.Handler) http.Handler {
183 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184 _, err := w.Write(mwStr)
185 if err != nil {
186 t.Fatalf("Failed writing HTTP response: %v", err)
187 }
188 h.ServeHTTP(w, r)
189 })
190 })
191
192 router.ServeHTTP(rw, req)
193 if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
194 t.Fatal("Middleware + handler response is not what it should be")
195 }
196 })
197 }
198
199 func TestMiddlewareNotFound(t *testing.T) {
200 mwStr := []byte("Middleware\n")
201 handlerStr := []byte("Logic\n")
202
203 router := NewRouter()
204 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
205 _, err := w.Write(handlerStr)
206 if err != nil {
207 t.Fatalf("Failed writing HTTP response: %v", err)
208 }
209 })
210 router.Use(func(h http.Handler) http.Handler {
211 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
212 _, err := w.Write(mwStr)
213 if err != nil {
214 t.Fatalf("Failed writing HTTP response: %v", err)
215 }
216 h.ServeHTTP(w, r)
217 })
218 })
219
220
221 t.Run("not called", func(t *testing.T) {
222 rw := NewRecorder()
223 req := newRequest("GET", "/notfound")
224
225 router.ServeHTTP(rw, req)
226 if bytes.Contains(rw.Body.Bytes(), mwStr) {
227 t.Fatal("Middleware was called for a 404")
228 }
229 })
230
231 t.Run("not called with custom not found handler", func(t *testing.T) {
232 rw := NewRecorder()
233 req := newRequest("GET", "/notfound")
234
235 router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
236 _, err := rw.Write([]byte("Custom 404 handler"))
237 if err != nil {
238 t.Fatalf("Failed writing HTTP response: %v", err)
239 }
240 })
241 router.ServeHTTP(rw, req)
242
243 if bytes.Contains(rw.Body.Bytes(), mwStr) {
244 t.Fatal("Middleware was called for a custom 404")
245 }
246 })
247 }
248
249 func TestMiddlewareMethodMismatch(t *testing.T) {
250 mwStr := []byte("Middleware\n")
251 handlerStr := []byte("Logic\n")
252
253 router := NewRouter()
254 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
255 _, err := w.Write(handlerStr)
256 if err != nil {
257 t.Fatalf("Failed writing HTTP response: %v", err)
258 }
259 }).Methods("GET")
260
261 router.Use(func(h http.Handler) http.Handler {
262 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
263 _, err := w.Write(mwStr)
264 if err != nil {
265 t.Fatalf("Failed writing HTTP response: %v", err)
266 }
267 h.ServeHTTP(w, r)
268 })
269 })
270
271 t.Run("not called", func(t *testing.T) {
272 rw := NewRecorder()
273 req := newRequest("POST", "/")
274
275 router.ServeHTTP(rw, req)
276 if bytes.Contains(rw.Body.Bytes(), mwStr) {
277 t.Fatal("Middleware was called for a method mismatch")
278 }
279 })
280
281 t.Run("not called with custom method not allowed handler", func(t *testing.T) {
282 rw := NewRecorder()
283 req := newRequest("POST", "/")
284
285 router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
286 _, err := rw.Write([]byte("Method not allowed"))
287 if err != nil {
288 t.Fatalf("Failed writing HTTP response: %v", err)
289 }
290 })
291 router.ServeHTTP(rw, req)
292
293 if bytes.Contains(rw.Body.Bytes(), mwStr) {
294 t.Fatal("Middleware was called for a method mismatch")
295 }
296 })
297 }
298
299 func TestMiddlewareNotFoundSubrouter(t *testing.T) {
300 mwStr := []byte("Middleware\n")
301 handlerStr := []byte("Logic\n")
302
303 router := NewRouter()
304 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
305 _, err := w.Write(handlerStr)
306 if err != nil {
307 t.Fatalf("Failed writing HTTP response: %v", err)
308 }
309 })
310
311 subrouter := router.PathPrefix("/sub/").Subrouter()
312 subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
313 _, err := w.Write(handlerStr)
314 if err != nil {
315 t.Fatalf("Failed writing HTTP response: %v", err)
316 }
317 })
318
319 router.Use(func(h http.Handler) http.Handler {
320 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
321 _, err := w.Write(mwStr)
322 if err != nil {
323 t.Fatalf("Failed writing HTTP response: %v", err)
324 }
325 h.ServeHTTP(w, r)
326 })
327 })
328
329 t.Run("not called", func(t *testing.T) {
330 rw := NewRecorder()
331 req := newRequest("GET", "/sub/notfound")
332
333 router.ServeHTTP(rw, req)
334 if bytes.Contains(rw.Body.Bytes(), mwStr) {
335 t.Fatal("Middleware was called for a 404")
336 }
337 })
338
339 t.Run("not called with custom not found handler", func(t *testing.T) {
340 rw := NewRecorder()
341 req := newRequest("GET", "/sub/notfound")
342
343 subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
344 _, err := rw.Write([]byte("Custom 404 handler"))
345 if err != nil {
346 t.Fatalf("Failed writing HTTP response: %v", err)
347 }
348 })
349 router.ServeHTTP(rw, req)
350
351 if bytes.Contains(rw.Body.Bytes(), mwStr) {
352 t.Fatal("Middleware was called for a custom 404")
353 }
354 })
355 }
356
357 func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
358 mwStr := []byte("Middleware\n")
359 handlerStr := []byte("Logic\n")
360
361 router := NewRouter()
362 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
363 _, err := w.Write(handlerStr)
364 if err != nil {
365 t.Fatalf("Failed writing HTTP response: %v", err)
366 }
367 })
368
369 subrouter := router.PathPrefix("/sub/").Subrouter()
370 subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
371 _, err := w.Write(handlerStr)
372 if err != nil {
373 t.Fatalf("Failed writing HTTP response: %v", err)
374 }
375 }).Methods("GET")
376
377 router.Use(func(h http.Handler) http.Handler {
378 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
379 _, err := w.Write(mwStr)
380 if err != nil {
381 t.Fatalf("Failed writing HTTP response: %v", err)
382 }
383 h.ServeHTTP(w, r)
384 })
385 })
386
387 t.Run("not called", func(t *testing.T) {
388 rw := NewRecorder()
389 req := newRequest("POST", "/sub/")
390
391 router.ServeHTTP(rw, req)
392 if bytes.Contains(rw.Body.Bytes(), mwStr) {
393 t.Fatal("Middleware was called for a method mismatch")
394 }
395 })
396
397 t.Run("not called with custom method not allowed handler", func(t *testing.T) {
398 rw := NewRecorder()
399 req := newRequest("POST", "/sub/")
400
401 router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
402 _, err := rw.Write([]byte("Method not allowed"))
403 if err != nil {
404 t.Fatalf("Failed writing HTTP response: %v", err)
405 }
406 })
407 router.ServeHTTP(rw, req)
408
409 if bytes.Contains(rw.Body.Bytes(), mwStr) {
410 t.Fatal("Middleware was called for a method mismatch")
411 }
412 })
413 }
414
415 func TestCORSMethodMiddleware(t *testing.T) {
416 testCases := []struct {
417 name string
418 registerRoutes func(r *Router)
419 requestHeader http.Header
420 requestMethod string
421 requestPath string
422 expectedAccessControlAllowMethodsHeader string
423 expectedResponse string
424 }{
425 {
426 name: "does not set without OPTIONS matcher",
427 registerRoutes: func(r *Router) {
428 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
429 },
430 requestMethod: "GET",
431 requestPath: "/foo",
432 expectedAccessControlAllowMethodsHeader: "",
433 expectedResponse: "a",
434 },
435 {
436 name: "sets on non OPTIONS",
437 registerRoutes: func(r *Router) {
438 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
439 r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
440 },
441 requestMethod: "GET",
442 requestPath: "/foo",
443 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
444 expectedResponse: "a",
445 },
446 {
447 name: "sets without preflight headers",
448 registerRoutes: func(r *Router) {
449 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
450 r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
451 },
452 requestMethod: "OPTIONS",
453 requestPath: "/foo",
454 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
455 expectedResponse: "b",
456 },
457 {
458 name: "does not set on error",
459 registerRoutes: func(r *Router) {
460 r.HandleFunc("/foo", stringHandler("a"))
461 },
462 requestMethod: "OPTIONS",
463 requestPath: "/foo",
464 expectedAccessControlAllowMethodsHeader: "",
465 expectedResponse: "a",
466 },
467 {
468 name: "sets header on valid preflight",
469 registerRoutes: func(r *Router) {
470 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
471 r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
472 },
473 requestMethod: "OPTIONS",
474 requestPath: "/foo",
475 requestHeader: http.Header{
476 "Access-Control-Request-Method": []string{"GET"},
477 "Access-Control-Request-Headers": []string{"Authorization"},
478 "Origin": []string{"http://example.com"},
479 },
480 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
481 expectedResponse: "b",
482 },
483 {
484 name: "does not set methods from unmatching routes",
485 registerRoutes: func(r *Router) {
486 r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
487 r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
488 r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
489 },
490 requestMethod: "OPTIONS",
491 requestPath: "/foo/bar",
492 requestHeader: http.Header{
493 "Access-Control-Request-Method": []string{"GET"},
494 "Access-Control-Request-Headers": []string{"Authorization"},
495 "Origin": []string{"http://example.com"},
496 },
497 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
498 expectedResponse: "b",
499 },
500 }
501
502 for _, tt := range testCases {
503 t.Run(tt.name, func(t *testing.T) {
504 router := NewRouter()
505
506 tt.registerRoutes(router)
507
508 router.Use(CORSMethodMiddleware(router))
509
510 rw := NewRecorder()
511 req := newRequest(tt.requestMethod, tt.requestPath)
512 req.Header = tt.requestHeader
513
514 router.ServeHTTP(rw, req)
515
516 actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
517 if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
518 t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
519 }
520
521 actualResponse := rw.Body.String()
522 if actualResponse != tt.expectedResponse {
523 t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
524 }
525 })
526 }
527 }
528
529 func TestCORSMethodMiddlewareSubrouter(t *testing.T) {
530 router := NewRouter().StrictSlash(true)
531
532 subrouter := router.PathPrefix("/test").Subrouter()
533 subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost)
534 subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions)
535
536 subrouter.Use(CORSMethodMiddleware(subrouter))
537
538 rw := NewRecorder()
539 req := newRequest("GET", "/test/hello/asdf")
540 router.ServeHTTP(rw, req)
541
542 actualMethods := rw.Header().Get("Access-Control-Allow-Methods")
543 expectedMethods := "GET,OPTIONS"
544 if actualMethods != expectedMethods {
545 t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods)
546 }
547 }
548
549 func TestMiddlewareOnMultiSubrouter(t *testing.T) {
550 first := "first"
551 second := "second"
552 notFound := "404 not found"
553
554 router := NewRouter()
555 firstSubRouter := router.PathPrefix("/").Subrouter()
556 secondSubRouter := router.PathPrefix("/").Subrouter()
557
558 router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
559 _, err := rw.Write([]byte(notFound))
560 if err != nil {
561 t.Fatalf("Failed writing HTTP response: %v", err)
562 }
563 })
564
565 firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) {
566
567 })
568
569 secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) {
570
571 })
572
573 firstSubRouter.Use(func(h http.Handler) http.Handler {
574 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
575 _, err := w.Write([]byte(first))
576 if err != nil {
577 t.Fatalf("Failed writing HTTP response: %v", err)
578 }
579 h.ServeHTTP(w, r)
580 })
581 })
582
583 secondSubRouter.Use(func(h http.Handler) http.Handler {
584 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
585 _, err := w.Write([]byte(second))
586 if err != nil {
587 t.Fatalf("Failed writing HTTP response: %v", err)
588 }
589 h.ServeHTTP(w, r)
590 })
591 })
592
593 t.Run("/first uses first middleware", func(t *testing.T) {
594 rw := NewRecorder()
595 req := newRequest("GET", "/first")
596
597 router.ServeHTTP(rw, req)
598 if rw.Body.String() != first {
599 t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String())
600 }
601 })
602
603 t.Run("/second uses second middleware", func(t *testing.T) {
604 rw := NewRecorder()
605 req := newRequest("GET", "/second")
606
607 router.ServeHTTP(rw, req)
608 if rw.Body.String() != second {
609 t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String())
610 }
611 })
612
613 t.Run("uses not found handler", func(t *testing.T) {
614 rw := NewRecorder()
615 req := newRequest("GET", "/second/not-exist")
616
617 router.ServeHTTP(rw, req)
618 if rw.Body.String() != notFound {
619 t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String())
620 }
621 })
622 }
623
View as plain text