1 package chi
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "io/ioutil"
9 "net"
10 "net/http"
11 "net/http/httptest"
12 "os"
13 "sync"
14 "testing"
15 "time"
16 )
17
18 func TestMuxBasic(t *testing.T) {
19 var count uint64
20 countermw := func(next http.Handler) http.Handler {
21 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22 count++
23 next.ServeHTTP(w, r)
24 })
25 }
26
27 usermw := func(next http.Handler) http.Handler {
28 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29 ctx := r.Context()
30 ctx = context.WithValue(ctx, ctxKey{"user"}, "peter")
31 r = r.WithContext(ctx)
32 next.ServeHTTP(w, r)
33 })
34 }
35
36 exmw := func(next http.Handler) http.Handler {
37 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38 ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a")
39 r = r.WithContext(ctx)
40 next.ServeHTTP(w, r)
41 })
42 }
43
44 logbuf := bytes.NewBufferString("")
45 logmsg := "logmw test"
46 logmw := func(next http.Handler) http.Handler {
47 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
48 logbuf.WriteString(logmsg)
49 next.ServeHTTP(w, r)
50 })
51 }
52
53 cxindex := func(w http.ResponseWriter, r *http.Request) {
54 ctx := r.Context()
55 user := ctx.Value(ctxKey{"user"}).(string)
56 w.WriteHeader(200)
57 w.Write([]byte(fmt.Sprintf("hi %s", user)))
58 }
59
60 ping := func(w http.ResponseWriter, r *http.Request) {
61 w.WriteHeader(200)
62 w.Write([]byte("."))
63 }
64
65 headPing := func(w http.ResponseWriter, r *http.Request) {
66 w.Header().Set("X-Ping", "1")
67 w.WriteHeader(200)
68 }
69
70 createPing := func(w http.ResponseWriter, r *http.Request) {
71
72 w.WriteHeader(201)
73 }
74
75 pingAll := func(w http.ResponseWriter, r *http.Request) {
76 w.WriteHeader(200)
77 w.Write([]byte("ping all"))
78 }
79
80 pingAll2 := func(w http.ResponseWriter, r *http.Request) {
81 w.WriteHeader(200)
82 w.Write([]byte("ping all2"))
83 }
84
85 pingOne := func(w http.ResponseWriter, r *http.Request) {
86 idParam := URLParam(r, "id")
87 w.WriteHeader(200)
88 w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam)))
89 }
90
91 pingWoop := func(w http.ResponseWriter, r *http.Request) {
92 w.WriteHeader(200)
93 w.Write([]byte("woop." + URLParam(r, "iidd")))
94 }
95
96 catchAll := func(w http.ResponseWriter, r *http.Request) {
97 w.WriteHeader(200)
98 w.Write([]byte("catchall"))
99 }
100
101 m := NewRouter()
102 m.Use(countermw)
103 m.Use(usermw)
104 m.Use(exmw)
105 m.Use(logmw)
106 m.Get("/", cxindex)
107 m.Method("GET", "/ping", http.HandlerFunc(ping))
108 m.MethodFunc("GET", "/pingall", pingAll)
109 m.MethodFunc("get", "/ping/all", pingAll)
110 m.Get("/ping/all2", pingAll2)
111
112 m.Head("/ping", headPing)
113 m.Post("/ping", createPing)
114 m.Get("/ping/{id}", pingWoop)
115 m.Get("/ping/{id}", pingOne)
116 m.Get("/ping/{iidd}/woop", pingWoop)
117 m.HandleFunc("/admin/*", catchAll)
118
119
120 ts := httptest.NewServer(m)
121 defer ts.Close()
122
123
124 if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" {
125 t.Fatalf(body)
126 }
127 tlogmsg, _ := logbuf.ReadString(0)
128 if tlogmsg != logmsg {
129 t.Error("expecting log message from middleware:", logmsg)
130 }
131
132
133 if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "." {
134 t.Fatalf(body)
135 }
136
137
138 if _, body := testRequest(t, ts, "GET", "/pingall", nil); body != "ping all" {
139 t.Fatalf(body)
140 }
141
142
143 if _, body := testRequest(t, ts, "GET", "/ping/all", nil); body != "ping all" {
144 t.Fatalf(body)
145 }
146
147
148 if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" {
149 t.Fatalf(body)
150 }
151
152
153 if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" {
154 t.Fatalf(body)
155 }
156
157
158 if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" {
159 t.Fatalf(body)
160 }
161
162
163 if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" {
164 t.Fatalf(body)
165 }
166
167
168 resp, err := http.Head(ts.URL + "/ping")
169 if err != nil {
170 t.Fatal(err)
171 }
172 if resp.StatusCode != 200 {
173 t.Error("head failed, should be 200")
174 }
175 if resp.Header.Get("X-Ping") == "" {
176 t.Error("expecting X-Ping header")
177 }
178
179
180 if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" {
181 t.Fatalf(body)
182 }
183
184
185 resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{}))
186 if err != nil {
187 t.Fatal(err)
188 }
189
190 body, err := ioutil.ReadAll(resp.Body)
191 if err != nil {
192 t.Fatal(err)
193 }
194 defer resp.Body.Close()
195
196 if resp.StatusCode != 200 {
197 t.Error("POST failed, should be 200")
198 }
199
200 if string(body) != "catchall" {
201 t.Error("expecting response body: 'catchall'")
202 }
203
204
205 if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "" || resp.StatusCode != 405 {
206 t.Fatalf(fmt.Sprintf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body))
207 }
208 }
209
210 func TestMuxMounts(t *testing.T) {
211 r := NewRouter()
212
213 r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) {
214 v := URLParam(r, "hash")
215 w.Write([]byte(fmt.Sprintf("/%s", v)))
216 })
217
218 r.Route("/{hash}/share", func(r Router) {
219 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
220 v := URLParam(r, "hash")
221 w.Write([]byte(fmt.Sprintf("/%s/share", v)))
222 })
223 r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) {
224 v := URLParam(r, "hash")
225 n := URLParam(r, "network")
226 w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n)))
227 })
228 })
229
230 m := NewRouter()
231 m.Mount("/sharing", r)
232
233 ts := httptest.NewServer(m)
234 defer ts.Close()
235
236 if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" {
237 t.Fatalf(body)
238 }
239 if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" {
240 t.Fatalf(body)
241 }
242 if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" {
243 t.Fatalf(body)
244 }
245 }
246
247 func TestMuxPlain(t *testing.T) {
248 r := NewRouter()
249 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
250 w.Write([]byte("bye"))
251 })
252 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
253 w.WriteHeader(404)
254 w.Write([]byte("nothing here"))
255 })
256
257 ts := httptest.NewServer(r)
258 defer ts.Close()
259
260 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
261 t.Fatalf(body)
262 }
263 if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
264 t.Fatalf(body)
265 }
266 }
267
268 func TestMuxEmptyRoutes(t *testing.T) {
269 mux := NewRouter()
270
271 apiRouter := NewRouter()
272
273
274 mux.Handle("/api*", apiRouter)
275
276 if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" {
277 t.Fatalf(body)
278 }
279
280 if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" {
281 t.Fatalf(body)
282 }
283 }
284
285
286
287 func TestMuxTrailingSlash(t *testing.T) {
288 r := NewRouter()
289 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
290 w.WriteHeader(404)
291 w.Write([]byte("nothing here"))
292 })
293
294 subRoutes := NewRouter()
295 indexHandler := func(w http.ResponseWriter, r *http.Request) {
296 accountID := URLParam(r, "accountID")
297 w.Write([]byte(accountID))
298 }
299 subRoutes.Get("/", indexHandler)
300
301 r.Mount("/accounts/{accountID}", subRoutes)
302 r.Get("/accounts/{accountID}/", indexHandler)
303
304 ts := httptest.NewServer(r)
305 defer ts.Close()
306
307 if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" {
308 t.Fatalf(body)
309 }
310 if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" {
311 t.Fatalf(body)
312 }
313 if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
314 t.Fatalf(body)
315 }
316 }
317
318 func TestMuxNestedNotFound(t *testing.T) {
319 r := NewRouter()
320
321 r.Use(func(next http.Handler) http.Handler {
322 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
323 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw"}, "mw"))
324 next.ServeHTTP(w, r)
325 })
326 })
327
328 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
329 w.Write([]byte("bye"))
330 })
331
332 r.With(func(next http.Handler) http.Handler {
333 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
334 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"with"}, "with"))
335 next.ServeHTTP(w, r)
336 })
337 }).NotFound(func(w http.ResponseWriter, r *http.Request) {
338 chkMw := r.Context().Value(ctxKey{"mw"}).(string)
339 chkWith := r.Context().Value(ctxKey{"with"}).(string)
340 w.WriteHeader(404)
341 w.Write([]byte(fmt.Sprintf("root 404 %s %s", chkMw, chkWith)))
342 })
343
344 sr1 := NewRouter()
345
346 sr1.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
347 w.Write([]byte("sub"))
348 })
349 sr1.Group(func(sr1 Router) {
350 sr1.Use(func(next http.Handler) http.Handler {
351 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
352 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw2"}, "mw2"))
353 next.ServeHTTP(w, r)
354 })
355 })
356 sr1.NotFound(func(w http.ResponseWriter, r *http.Request) {
357 chkMw2 := r.Context().Value(ctxKey{"mw2"}).(string)
358 w.WriteHeader(404)
359 w.Write([]byte(fmt.Sprintf("sub 404 %s", chkMw2)))
360 })
361 })
362
363 sr2 := NewRouter()
364 sr2.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
365 w.Write([]byte("sub2"))
366 })
367
368 r.Mount("/admin1", sr1)
369 r.Mount("/admin2", sr2)
370
371 ts := httptest.NewServer(r)
372 defer ts.Close()
373
374 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
375 t.Fatalf(body)
376 }
377 if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "root 404 mw with" {
378 t.Fatalf(body)
379 }
380 if _, body := testRequest(t, ts, "GET", "/admin1/sub", nil); body != "sub" {
381 t.Fatalf(body)
382 }
383 if _, body := testRequest(t, ts, "GET", "/admin1/nope", nil); body != "sub 404 mw2" {
384 t.Fatalf(body)
385 }
386 if _, body := testRequest(t, ts, "GET", "/admin2/sub", nil); body != "sub2" {
387 t.Fatalf(body)
388 }
389
390
391 if _, body := testRequest(t, ts, "GET", "/admin2/nope", nil); body != "root 404 mw with" {
392 t.Fatalf(body)
393 }
394 }
395
396 func TestMuxNestedMethodNotAllowed(t *testing.T) {
397 r := NewRouter()
398 r.Get("/root", func(w http.ResponseWriter, r *http.Request) {
399 w.Write([]byte("root"))
400 })
401 r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
402 w.WriteHeader(405)
403 w.Write([]byte("root 405"))
404 })
405
406 sr1 := NewRouter()
407 sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) {
408 w.Write([]byte("sub1"))
409 })
410 sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
411 w.WriteHeader(405)
412 w.Write([]byte("sub1 405"))
413 })
414
415 sr2 := NewRouter()
416 sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) {
417 w.Write([]byte("sub2"))
418 })
419
420 pathVar := NewRouter()
421 pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) {
422 w.Write([]byte("pv"))
423 })
424 pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
425 w.WriteHeader(405)
426 w.Write([]byte("pv 405"))
427 })
428
429 r.Mount("/prefix1", sr1)
430 r.Mount("/prefix2", sr2)
431 r.Mount("/pathVar", pathVar)
432
433 ts := httptest.NewServer(r)
434 defer ts.Close()
435
436 if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" {
437 t.Fatalf(body)
438 }
439 if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" {
440 t.Fatalf(body)
441 }
442 if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" {
443 t.Fatalf(body)
444 }
445 if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" {
446 t.Fatalf(body)
447 }
448 if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" {
449 t.Fatalf(body)
450 }
451 if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" {
452 t.Fatalf(body)
453 }
454 if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" {
455 t.Fatalf(body)
456 }
457 if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" {
458 t.Fatalf(body)
459 }
460 }
461
462 func TestMuxComplicatedNotFound(t *testing.T) {
463 decorateRouter := func(r *Mux) {
464
465 r.Get("/auth", func(w http.ResponseWriter, r *http.Request) {
466 w.Write([]byte("auth get"))
467 })
468 r.Route("/public", func(r Router) {
469 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
470 w.Write([]byte("public get"))
471 })
472 })
473
474
475 sub0 := NewRouter()
476 sub0.Route("/resource", func(r Router) {
477 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
478 w.Write([]byte("private get"))
479 })
480 })
481 r.Mount("/private", sub0)
482
483
484 sub1 := NewRouter()
485 sub1.Route("/resource", func(r Router) {
486 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
487 w.Write([]byte("private get"))
488 })
489 })
490 r.With(func(next http.Handler) http.Handler { return next }).Mount("/private_mw", sub1)
491 }
492
493 testNotFound := func(t *testing.T, r *Mux) {
494 ts := httptest.NewServer(r)
495 defer ts.Close()
496
497
498 if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" {
499 t.Fatalf(body)
500 }
501 if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" {
502 t.Fatalf(body)
503 }
504 if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" {
505 t.Fatalf(body)
506 }
507 if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" {
508 t.Fatalf(body)
509 }
510
511 if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" {
512 t.Fatalf(body)
513 }
514 if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" {
515 t.Fatalf(body)
516 }
517 if _, body := testRequest(t, ts, "GET", "/private/nope", nil); body != "custom not-found" {
518 t.Fatalf(body)
519 }
520 if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" {
521 t.Fatalf(body)
522 }
523 if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" {
524 t.Fatalf(body)
525 }
526 if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" {
527 t.Fatalf(body)
528 }
529
530 if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" {
531 t.Fatalf(body)
532 }
533 }
534
535 t.Run("pre", func(t *testing.T) {
536 r := NewRouter()
537 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
538 w.Write([]byte("custom not-found"))
539 })
540 decorateRouter(r)
541 testNotFound(t, r)
542 })
543
544 t.Run("post", func(t *testing.T) {
545 r := NewRouter()
546 decorateRouter(r)
547 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
548 w.Write([]byte("custom not-found"))
549 })
550 testNotFound(t, r)
551 })
552 }
553
554 func TestMuxWith(t *testing.T) {
555 var cmwInit1, cmwHandler1 uint64
556 var cmwInit2, cmwHandler2 uint64
557 mw1 := func(next http.Handler) http.Handler {
558 cmwInit1++
559 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
560 cmwHandler1++
561 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline1"}, "yes"))
562 next.ServeHTTP(w, r)
563 })
564 }
565 mw2 := func(next http.Handler) http.Handler {
566 cmwInit2++
567 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
568 cmwHandler2++
569 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline2"}, "yes"))
570 next.ServeHTTP(w, r)
571 })
572 }
573
574 r := NewRouter()
575 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
576 w.Write([]byte("bye"))
577 })
578 r.With(mw1).With(mw2).Get("/inline", func(w http.ResponseWriter, r *http.Request) {
579 v1 := r.Context().Value(ctxKey{"inline1"}).(string)
580 v2 := r.Context().Value(ctxKey{"inline2"}).(string)
581 w.Write([]byte(fmt.Sprintf("inline %s %s", v1, v2)))
582 })
583
584 ts := httptest.NewServer(r)
585 defer ts.Close()
586
587 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
588 t.Fatalf(body)
589 }
590 if _, body := testRequest(t, ts, "GET", "/inline", nil); body != "inline yes yes" {
591 t.Fatalf(body)
592 }
593 if cmwInit1 != 1 {
594 t.Fatalf("expecting cmwInit1 to be 1, got %d", cmwInit1)
595 }
596 if cmwHandler1 != 1 {
597 t.Fatalf("expecting cmwHandler1 to be 1, got %d", cmwHandler1)
598 }
599 if cmwInit2 != 1 {
600 t.Fatalf("expecting cmwInit2 to be 1, got %d", cmwInit2)
601 }
602 if cmwHandler2 != 1 {
603 t.Fatalf("expecting cmwHandler2 to be 1, got %d", cmwHandler2)
604 }
605 }
606
607 func TestRouterFromMuxWith(t *testing.T) {
608 t.Parallel()
609
610 r := NewRouter()
611
612 with := r.With(func(next http.Handler) http.Handler {
613 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
614 next.ServeHTTP(w, r)
615 })
616 })
617
618 with.Get("/with_middleware", func(w http.ResponseWriter, r *http.Request) {})
619
620 ts := httptest.NewServer(with)
621 defer ts.Close()
622
623
624 testRequest(t, ts, http.MethodGet, "/with_middleware", nil)
625 }
626
627 func TestMuxMiddlewareStack(t *testing.T) {
628 var stdmwInit, stdmwHandler uint64
629 stdmw := func(next http.Handler) http.Handler {
630 stdmwInit++
631 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
632 stdmwHandler++
633 next.ServeHTTP(w, r)
634 })
635 }
636 _ = stdmw
637
638 var ctxmwInit, ctxmwHandler uint64
639 ctxmw := func(next http.Handler) http.Handler {
640 ctxmwInit++
641 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
642 ctxmwHandler++
643 ctx := r.Context()
644 ctx = context.WithValue(ctx, ctxKey{"count.ctxmwHandler"}, ctxmwHandler)
645 r = r.WithContext(ctx)
646 next.ServeHTTP(w, r)
647 })
648 }
649
650 var inCtxmwInit, inCtxmwHandler uint64
651 inCtxmw := func(next http.Handler) http.Handler {
652 inCtxmwInit++
653 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
654 inCtxmwHandler++
655 next.ServeHTTP(w, r)
656 })
657 }
658
659 r := NewRouter()
660 r.Use(stdmw)
661 r.Use(ctxmw)
662 r.Use(func(next http.Handler) http.Handler {
663 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
664 if r.URL.Path == "/ping" {
665 w.Write([]byte("pong"))
666 return
667 }
668 next.ServeHTTP(w, r)
669 })
670 })
671
672 var handlerCount uint64
673
674 r.With(inCtxmw).Get("/", func(w http.ResponseWriter, r *http.Request) {
675 handlerCount++
676 ctx := r.Context()
677 ctxmwHandlerCount := ctx.Value(ctxKey{"count.ctxmwHandler"}).(uint64)
678 w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, ctxmwHandlerCount)))
679 })
680
681 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
682 w.Write([]byte("wooot"))
683 })
684
685 ts := httptest.NewServer(r)
686 defer ts.Close()
687
688 testRequest(t, ts, "GET", "/", nil)
689 testRequest(t, ts, "GET", "/", nil)
690 var body string
691 _, body = testRequest(t, ts, "GET", "/", nil)
692 if body != "inits:1 reqs:3 ctxValue:3" {
693 t.Fatalf("got: '%s'", body)
694 }
695
696 _, body = testRequest(t, ts, "GET", "/ping", nil)
697 if body != "pong" {
698 t.Fatalf("got: '%s'", body)
699 }
700 }
701
702 func TestMuxRouteGroups(t *testing.T) {
703 var stdmwInit, stdmwHandler uint64
704
705 stdmw := func(next http.Handler) http.Handler {
706 stdmwInit++
707 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
708 stdmwHandler++
709 next.ServeHTTP(w, r)
710 })
711 }
712
713 var stdmwInit2, stdmwHandler2 uint64
714 stdmw2 := func(next http.Handler) http.Handler {
715 stdmwInit2++
716 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
717 stdmwHandler2++
718 next.ServeHTTP(w, r)
719 })
720 }
721
722 r := NewRouter()
723 r.Group(func(r Router) {
724 r.Use(stdmw)
725 r.Get("/group", func(w http.ResponseWriter, r *http.Request) {
726 w.Write([]byte("root group"))
727 })
728 })
729 r.Group(func(r Router) {
730 r.Use(stdmw2)
731 r.Get("/group2", func(w http.ResponseWriter, r *http.Request) {
732 w.Write([]byte("root group2"))
733 })
734 })
735
736 ts := httptest.NewServer(r)
737 defer ts.Close()
738
739
740 _, body := testRequest(t, ts, "GET", "/group", nil)
741 if body != "root group" {
742 t.Fatalf("got: '%s'", body)
743 }
744 if stdmwInit != 1 || stdmwHandler != 1 {
745 t.Logf("stdmw counters failed, should be 1:1, got %d:%d", stdmwInit, stdmwHandler)
746 }
747
748
749 _, body = testRequest(t, ts, "GET", "/group2", nil)
750 if body != "root group2" {
751 t.Fatalf("got: '%s'", body)
752 }
753 if stdmwInit2 != 1 || stdmwHandler2 != 1 {
754 t.Fatalf("stdmw2 counters failed, should be 1:1, got %d:%d", stdmwInit2, stdmwHandler2)
755 }
756 }
757
758 func TestMuxBig(t *testing.T) {
759 r := bigMux()
760
761 ts := httptest.NewServer(r)
762 defer ts.Close()
763
764 var body, expected string
765
766 _, body = testRequest(t, ts, "GET", "/favicon.ico", nil)
767 if body != "fav" {
768 t.Fatalf("got '%s'", body)
769 }
770 _, body = testRequest(t, ts, "GET", "/hubs/4/view", nil)
771 if body != "/hubs/4/view reqid:1 session:anonymous" {
772 t.Fatalf("got '%v'", body)
773 }
774 _, body = testRequest(t, ts, "GET", "/hubs/4/view/index.html", nil)
775 if body != "/hubs/4/view/index.html reqid:1 session:anonymous" {
776 t.Fatalf("got '%s'", body)
777 }
778 _, body = testRequest(t, ts, "POST", "/hubs/ethereumhub/view/index.html", nil)
779 if body != "/hubs/ethereumhub/view/index.html reqid:1 session:anonymous" {
780 t.Fatalf("got '%s'", body)
781 }
782 _, body = testRequest(t, ts, "GET", "/", nil)
783 if body != "/ reqid:1 session:elvis" {
784 t.Fatalf("got '%s'", body)
785 }
786 _, body = testRequest(t, ts, "GET", "/suggestions", nil)
787 if body != "/suggestions reqid:1 session:elvis" {
788 t.Fatalf("got '%s'", body)
789 }
790 _, body = testRequest(t, ts, "GET", "/woot/444/hiiii", nil)
791 if body != "/woot/444/hiiii" {
792 t.Fatalf("got '%s'", body)
793 }
794 _, body = testRequest(t, ts, "GET", "/hubs/123", nil)
795 expected = "/hubs/123 reqid:1 session:elvis"
796 if body != expected {
797 t.Fatalf("expected:%s got:%s", expected, body)
798 }
799 _, body = testRequest(t, ts, "GET", "/hubs/123/touch", nil)
800 if body != "/hubs/123/touch reqid:1 session:elvis" {
801 t.Fatalf("got '%s'", body)
802 }
803 _, body = testRequest(t, ts, "GET", "/hubs/123/webhooks", nil)
804 if body != "/hubs/123/webhooks reqid:1 session:elvis" {
805 t.Fatalf("got '%s'", body)
806 }
807 _, body = testRequest(t, ts, "GET", "/hubs/123/posts", nil)
808 if body != "/hubs/123/posts reqid:1 session:elvis" {
809 t.Fatalf("got '%s'", body)
810 }
811 _, body = testRequest(t, ts, "GET", "/folders", nil)
812 if body != "404 page not found\n" {
813 t.Fatalf("got '%s'", body)
814 }
815 _, body = testRequest(t, ts, "GET", "/folders/", nil)
816 if body != "/folders/ reqid:1 session:elvis" {
817 t.Fatalf("got '%s'", body)
818 }
819 _, body = testRequest(t, ts, "GET", "/folders/public", nil)
820 if body != "/folders/public reqid:1 session:elvis" {
821 t.Fatalf("got '%s'", body)
822 }
823 _, body = testRequest(t, ts, "GET", "/folders/nothing", nil)
824 if body != "404 page not found\n" {
825 t.Fatalf("got '%s'", body)
826 }
827 }
828
829 func bigMux() Router {
830 var r *Mux
831 var sr3 *Mux
832
833 r = NewRouter()
834 r.Use(func(next http.Handler) http.Handler {
835 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
836 ctx := context.WithValue(r.Context(), ctxKey{"requestID"}, "1")
837 next.ServeHTTP(w, r.WithContext(ctx))
838 })
839 })
840 r.Use(func(next http.Handler) http.Handler {
841 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
842 next.ServeHTTP(w, r)
843 })
844 })
845 r.Group(func(r Router) {
846 r.Use(func(next http.Handler) http.Handler {
847 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
848 ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "anonymous")
849 next.ServeHTTP(w, r.WithContext(ctx))
850 })
851 })
852 r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
853 w.Write([]byte("fav"))
854 })
855 r.Get("/hubs/{hubID}/view", func(w http.ResponseWriter, r *http.Request) {
856 ctx := r.Context()
857 s := fmt.Sprintf("/hubs/%s/view reqid:%s session:%s", URLParam(r, "hubID"),
858 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
859 w.Write([]byte(s))
860 })
861 r.Get("/hubs/{hubID}/view/*", func(w http.ResponseWriter, r *http.Request) {
862 ctx := r.Context()
863 s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubID"),
864 URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
865 w.Write([]byte(s))
866 })
867 r.Post("/hubs/{hubSlug}/view/*", func(w http.ResponseWriter, r *http.Request) {
868 ctx := r.Context()
869 s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubSlug"),
870 URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
871 w.Write([]byte(s))
872 })
873 })
874 r.Group(func(r Router) {
875 r.Use(func(next http.Handler) http.Handler {
876 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
877 ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "elvis")
878 next.ServeHTTP(w, r.WithContext(ctx))
879 })
880 })
881 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
882 ctx := r.Context()
883 s := fmt.Sprintf("/ reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
884 w.Write([]byte(s))
885 })
886 r.Get("/suggestions", func(w http.ResponseWriter, r *http.Request) {
887 ctx := r.Context()
888 s := fmt.Sprintf("/suggestions reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
889 w.Write([]byte(s))
890 })
891
892 r.Get("/woot/{wootID}/*", func(w http.ResponseWriter, r *http.Request) {
893 s := fmt.Sprintf("/woot/%s/%s", URLParam(r, "wootID"), URLParam(r, "*"))
894 w.Write([]byte(s))
895 })
896
897 r.Route("/hubs", func(r Router) {
898 _ = r.(*Mux)
899 r.Route("/{hubID}", func(r Router) {
900 _ = r.(*Mux)
901 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
902 ctx := r.Context()
903 s := fmt.Sprintf("/hubs/%s reqid:%s session:%s",
904 URLParam(r, "hubID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
905 w.Write([]byte(s))
906 })
907 r.Get("/touch", func(w http.ResponseWriter, r *http.Request) {
908 ctx := r.Context()
909 s := fmt.Sprintf("/hubs/%s/touch reqid:%s session:%s", URLParam(r, "hubID"),
910 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
911 w.Write([]byte(s))
912 })
913
914 sr3 = NewRouter()
915 sr3.Get("/", func(w http.ResponseWriter, r *http.Request) {
916 ctx := r.Context()
917 s := fmt.Sprintf("/hubs/%s/webhooks reqid:%s session:%s", URLParam(r, "hubID"),
918 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
919 w.Write([]byte(s))
920 })
921 sr3.Route("/{webhookID}", func(r Router) {
922 _ = r.(*Mux)
923 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
924 ctx := r.Context()
925 s := fmt.Sprintf("/hubs/%s/webhooks/%s reqid:%s session:%s", URLParam(r, "hubID"),
926 URLParam(r, "webhookID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
927 w.Write([]byte(s))
928 })
929 })
930
931 r.Mount("/webhooks", Chain(func(next http.Handler) http.Handler {
932 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
933 next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{"hook"}, true)))
934 })
935 }).Handler(sr3))
936
937 r.Route("/posts", func(r Router) {
938 _ = r.(*Mux)
939 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
940 ctx := r.Context()
941 s := fmt.Sprintf("/hubs/%s/posts reqid:%s session:%s", URLParam(r, "hubID"),
942 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
943 w.Write([]byte(s))
944 })
945 })
946 })
947 })
948
949 r.Route("/folders/", func(r Router) {
950 _ = r.(*Mux)
951 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
952 ctx := r.Context()
953 s := fmt.Sprintf("/folders/ reqid:%s session:%s",
954 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
955 w.Write([]byte(s))
956 })
957 r.Get("/public", func(w http.ResponseWriter, r *http.Request) {
958 ctx := r.Context()
959 s := fmt.Sprintf("/folders/public reqid:%s session:%s",
960 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
961 w.Write([]byte(s))
962 })
963 })
964 })
965
966 return r
967 }
968
969 func TestMuxSubroutesBasic(t *testing.T) {
970 hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
971 w.Write([]byte("index"))
972 })
973 hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
974 w.Write([]byte("articles-list"))
975 })
976 hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
977 w.Write([]byte("search-articles"))
978 })
979 hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
980 w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id"))))
981 })
982 hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
983 w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id"))))
984 })
985
986 r := NewRouter()
987
988 r.Get("/", hIndex)
989 r.Route("/articles", func(r Router) {
990
991 r.Get("/", hArticlesList)
992 r.Get("/search", hSearchArticles)
993 r.Route("/{id}", func(r Router) {
994
995 r.Get("/", hGetArticle)
996 r.Get("/sync", hSyncArticle)
997 })
998 })
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018 ts := httptest.NewServer(r)
1019 defer ts.Close()
1020
1021 var body, expected string
1022
1023 _, body = testRequest(t, ts, "GET", "/", nil)
1024 expected = "index"
1025 if body != expected {
1026 t.Fatalf("expected:%s got:%s", expected, body)
1027 }
1028 _, body = testRequest(t, ts, "GET", "/articles", nil)
1029 expected = "articles-list"
1030 if body != expected {
1031 t.Fatalf("expected:%s got:%s", expected, body)
1032 }
1033 _, body = testRequest(t, ts, "GET", "/articles/search", nil)
1034 expected = "search-articles"
1035 if body != expected {
1036 t.Fatalf("expected:%s got:%s", expected, body)
1037 }
1038 _, body = testRequest(t, ts, "GET", "/articles/123", nil)
1039 expected = "get-article:123"
1040 if body != expected {
1041 t.Fatalf("expected:%s got:%s", expected, body)
1042 }
1043 _, body = testRequest(t, ts, "GET", "/articles/123/sync", nil)
1044 expected = "sync-article:123"
1045 if body != expected {
1046 t.Fatalf("expected:%s got:%s", expected, body)
1047 }
1048 }
1049
1050 func TestMuxSubroutes(t *testing.T) {
1051 hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1052 w.Write([]byte("hub1"))
1053 })
1054 hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1055 w.Write([]byte("hub2"))
1056 })
1057 hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1058 w.Write([]byte("hub3"))
1059 })
1060 hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1061 w.Write([]byte("account1"))
1062 })
1063 hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1064 w.Write([]byte("account2"))
1065 })
1066
1067 r := NewRouter()
1068 r.Get("/hubs/{hubID}/view", hHubView1)
1069 r.Get("/hubs/{hubID}/view/*", hHubView2)
1070
1071 sr := NewRouter()
1072 sr.Get("/", hHubView3)
1073 r.Mount("/hubs/{hubID}/users", sr)
1074 r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) {
1075 w.Write([]byte("hub3 override"))
1076 })
1077
1078 sr3 := NewRouter()
1079 sr3.Get("/", hAccountView1)
1080 sr3.Get("/hi", hAccountView2)
1081
1082
1083 r.Route("/accounts/{accountID}", func(r Router) {
1084 _ = r.(*Mux)
1085
1086 r.Mount("/", sr3)
1087 })
1088
1089
1090
1091
1092
1093
1094 ts := httptest.NewServer(r)
1095 defer ts.Close()
1096
1097 var body, expected string
1098
1099 _, body = testRequest(t, ts, "GET", "/hubs/123/view", nil)
1100 expected = "hub1"
1101 if body != expected {
1102 t.Fatalf("expected:%s got:%s", expected, body)
1103 }
1104 _, body = testRequest(t, ts, "GET", "/hubs/123/view/index.html", nil)
1105 expected = "hub2"
1106 if body != expected {
1107 t.Fatalf("expected:%s got:%s", expected, body)
1108 }
1109 _, body = testRequest(t, ts, "GET", "/hubs/123/users", nil)
1110 expected = "hub3"
1111 if body != expected {
1112 t.Fatalf("expected:%s got:%s", expected, body)
1113 }
1114 _, body = testRequest(t, ts, "GET", "/hubs/123/users/", nil)
1115 expected = "hub3 override"
1116 if body != expected {
1117 t.Fatalf("expected:%s got:%s", expected, body)
1118 }
1119 _, body = testRequest(t, ts, "GET", "/accounts/44", nil)
1120 expected = "account1"
1121 if body != expected {
1122 t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body)
1123 }
1124 _, body = testRequest(t, ts, "GET", "/accounts/44/hi", nil)
1125 expected = "account2"
1126 if body != expected {
1127 t.Fatalf("expected:%s got:%s", expected, body)
1128 }
1129
1130
1131 router := r
1132 req, _ := http.NewRequest("GET", "/accounts/44/hi", nil)
1133
1134 rctx := NewRouteContext()
1135 req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))
1136
1137 w := httptest.NewRecorder()
1138 router.ServeHTTP(w, req)
1139
1140 body = w.Body.String()
1141 expected = "account2"
1142 if body != expected {
1143 t.Fatalf("expected:%s got:%s", expected, body)
1144 }
1145
1146 routePatterns := rctx.RoutePatterns
1147 if len(rctx.RoutePatterns) != 3 {
1148 t.Fatalf("expected 3 routing patterns, got:%d", len(rctx.RoutePatterns))
1149 }
1150 expected = "/accounts/{accountID}/*"
1151 if routePatterns[0] != expected {
1152 t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[0])
1153 }
1154 expected = "/*"
1155 if routePatterns[1] != expected {
1156 t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[1])
1157 }
1158 expected = "/hi"
1159 if routePatterns[2] != expected {
1160 t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[2])
1161 }
1162
1163 }
1164
1165 func TestSingleHandler(t *testing.T) {
1166 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1167 name := URLParam(r, "name")
1168 w.Write([]byte("hi " + name))
1169 })
1170
1171 r, _ := http.NewRequest("GET", "/", nil)
1172 rctx := NewRouteContext()
1173 r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
1174 rctx.URLParams.Add("name", "joe")
1175
1176 w := httptest.NewRecorder()
1177 h.ServeHTTP(w, r)
1178
1179 body := w.Body.String()
1180 expected := "hi joe"
1181 if body != expected {
1182 t.Fatalf("expected:%s got:%s", expected, body)
1183 }
1184 }
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209 func TestServeHTTPExistingContext(t *testing.T) {
1210 r := NewRouter()
1211 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1212 s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
1213 w.Write([]byte(s))
1214 })
1215 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
1216 s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
1217 w.WriteHeader(404)
1218 w.Write([]byte(s))
1219 })
1220
1221 testcases := []struct {
1222 Method string
1223 Path string
1224 Ctx context.Context
1225 ExpectedStatus int
1226 ExpectedBody string
1227 }{
1228 {
1229 Method: "GET",
1230 Path: "/hi",
1231 Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"),
1232 ExpectedStatus: 200,
1233 ExpectedBody: "hi ctx",
1234 },
1235 {
1236 Method: "GET",
1237 Path: "/hello",
1238 Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"),
1239 ExpectedStatus: 404,
1240 ExpectedBody: "nothing here ctx",
1241 },
1242 }
1243
1244 for _, tc := range testcases {
1245 resp := httptest.NewRecorder()
1246 req, err := http.NewRequest(tc.Method, tc.Path, nil)
1247 if err != nil {
1248 t.Fatalf("%v", err)
1249 }
1250 req = req.WithContext(tc.Ctx)
1251 r.ServeHTTP(resp, req)
1252 b, err := ioutil.ReadAll(resp.Body)
1253 if err != nil {
1254 t.Fatalf("%v", err)
1255 }
1256 if resp.Code != tc.ExpectedStatus {
1257 t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
1258 }
1259 if string(b) != tc.ExpectedBody {
1260 t.Fatalf("%s != %s", tc.ExpectedBody, b)
1261 }
1262 }
1263 }
1264
1265 func TestNestedGroups(t *testing.T) {
1266 handlerPrintCounter := func(w http.ResponseWriter, r *http.Request) {
1267 counter, _ := r.Context().Value(ctxKey{"counter"}).(int)
1268 w.Write([]byte(fmt.Sprintf("%v", counter)))
1269 }
1270
1271 mwIncreaseCounter := func(next http.Handler) http.Handler {
1272 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1273 ctx := r.Context()
1274 counter, _ := ctx.Value(ctxKey{"counter"}).(int)
1275 counter++
1276 ctx = context.WithValue(ctx, ctxKey{"counter"}, counter)
1277 next.ServeHTTP(w, r.WithContext(ctx))
1278 })
1279 }
1280
1281
1282 r := NewRouter()
1283 r.Get("/0", handlerPrintCounter)
1284 r.Group(func(r Router) {
1285 r.Use(mwIncreaseCounter)
1286 r.Get("/1", handlerPrintCounter)
1287
1288
1289 r.With(mwIncreaseCounter).Get("/2", handlerPrintCounter)
1290
1291 r.Group(func(r Router) {
1292 r.Use(mwIncreaseCounter, mwIncreaseCounter)
1293 r.Get("/3", handlerPrintCounter)
1294 })
1295 r.Route("/", func(r Router) {
1296 r.Use(mwIncreaseCounter, mwIncreaseCounter)
1297
1298
1299 r.With(mwIncreaseCounter).Get("/4", handlerPrintCounter)
1300
1301 r.Group(func(r Router) {
1302 r.Use(mwIncreaseCounter, mwIncreaseCounter)
1303 r.Get("/5", handlerPrintCounter)
1304
1305 r.With(mwIncreaseCounter).Get("/6", handlerPrintCounter)
1306
1307 })
1308 })
1309 })
1310
1311 ts := httptest.NewServer(r)
1312 defer ts.Close()
1313
1314 for _, route := range []string{"0", "1", "2", "3", "4", "5", "6"} {
1315 if _, body := testRequest(t, ts, "GET", "/"+route, nil); body != route {
1316 t.Errorf("expected %v, got %v", route, body)
1317 }
1318 }
1319 }
1320
1321 func TestMiddlewarePanicOnLateUse(t *testing.T) {
1322 handler := func(w http.ResponseWriter, r *http.Request) {
1323 w.Write([]byte("hello\n"))
1324 }
1325
1326 mw := func(next http.Handler) http.Handler {
1327 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1328 next.ServeHTTP(w, r)
1329 })
1330 }
1331
1332 defer func() {
1333 if recover() == nil {
1334 t.Error("expected panic()")
1335 }
1336 }()
1337
1338 r := NewRouter()
1339 r.Get("/", handler)
1340 r.Use(mw)
1341 }
1342
1343 func TestMountingExistingPath(t *testing.T) {
1344 handler := func(w http.ResponseWriter, r *http.Request) {}
1345
1346 defer func() {
1347 if recover() == nil {
1348 t.Error("expected panic()")
1349 }
1350 }()
1351
1352 r := NewRouter()
1353 r.Get("/", handler)
1354 r.Mount("/hi", http.HandlerFunc(handler))
1355 r.Mount("/hi", http.HandlerFunc(handler))
1356 }
1357
1358 func TestMountingSimilarPattern(t *testing.T) {
1359 r := NewRouter()
1360 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1361 w.Write([]byte("bye"))
1362 })
1363
1364 r2 := NewRouter()
1365 r2.Get("/", func(w http.ResponseWriter, r *http.Request) {
1366 w.Write([]byte("foobar"))
1367 })
1368
1369 r3 := NewRouter()
1370 r3.Get("/", func(w http.ResponseWriter, r *http.Request) {
1371 w.Write([]byte("foo"))
1372 })
1373
1374 r.Mount("/foobar", r2)
1375 r.Mount("/foo", r3)
1376
1377 ts := httptest.NewServer(r)
1378 defer ts.Close()
1379
1380 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
1381 t.Fatalf(body)
1382 }
1383 }
1384
1385 func TestMuxEmptyParams(t *testing.T) {
1386 r := NewRouter()
1387 r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) {
1388 x := URLParam(r, "x")
1389 y := URLParam(r, "y")
1390 z := URLParam(r, "z")
1391 w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z)))
1392 })
1393
1394 ts := httptest.NewServer(r)
1395 defer ts.Close()
1396
1397 if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" {
1398 t.Fatalf(body)
1399 }
1400 if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" {
1401 t.Fatalf(body)
1402 }
1403 }
1404
1405 func TestMuxMissingParams(t *testing.T) {
1406 r := NewRouter()
1407 r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) {
1408 userID := URLParam(r, "userId")
1409 w.Write([]byte(fmt.Sprintf("userId = '%s'", userID)))
1410 })
1411 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
1412 w.WriteHeader(404)
1413 w.Write([]byte("nothing here"))
1414 })
1415
1416 ts := httptest.NewServer(r)
1417 defer ts.Close()
1418
1419 if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" {
1420 t.Fatalf(body)
1421 }
1422 if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" {
1423 t.Fatalf(body)
1424 }
1425 }
1426
1427 func TestMuxWildcardRoute(t *testing.T) {
1428 handler := func(w http.ResponseWriter, r *http.Request) {}
1429
1430 defer func() {
1431 if recover() == nil {
1432 t.Error("expected panic()")
1433 }
1434 }()
1435
1436 r := NewRouter()
1437 r.Get("/*/wildcard/must/be/at/end", handler)
1438 }
1439
1440 func TestMuxWildcardRouteCheckTwo(t *testing.T) {
1441 handler := func(w http.ResponseWriter, r *http.Request) {}
1442
1443 defer func() {
1444 if recover() == nil {
1445 t.Error("expected panic()")
1446 }
1447 }()
1448
1449 r := NewRouter()
1450 r.Get("/*/wildcard/{must}/be/at/end", handler)
1451 }
1452
1453 func TestMuxRegexp(t *testing.T) {
1454 r := NewRouter()
1455 r.Route("/{param:[0-9]+}/test", func(r Router) {
1456 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
1457 w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param"))))
1458 })
1459 })
1460
1461 ts := httptest.NewServer(r)
1462 defer ts.Close()
1463
1464 if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " {
1465 t.Fatalf(body)
1466 }
1467 }
1468
1469 func TestMuxRegexp2(t *testing.T) {
1470 r := NewRouter()
1471 r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) {
1472 w.Write([]byte(URLParam(r, "suffix")))
1473 })
1474 ts := httptest.NewServer(r)
1475 defer ts.Close()
1476
1477 if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" {
1478 t.Fatalf(body)
1479 }
1480 if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" {
1481 t.Fatalf(body)
1482 }
1483 }
1484
1485 func TestMuxRegexp3(t *testing.T) {
1486 r := NewRouter()
1487 r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) {
1488 w.Write([]byte("first"))
1489 })
1490 r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
1491 w.Write([]byte("second"))
1492 })
1493 r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
1494 w.Write([]byte("third"))
1495 })
1496
1497 r.Route("/one", func(r Router) {
1498 r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) {
1499 writer.Write([]byte("_"))
1500 })
1501 r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) {
1502 writer.Write([]byte("_"))
1503 })
1504 r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) {
1505 writer.Write([]byte("forth"))
1506 })
1507 })
1508
1509 ts := httptest.NewServer(r)
1510 defer ts.Close()
1511
1512 if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" {
1513 t.Fatalf(body)
1514 }
1515 if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" {
1516 t.Fatalf(body)
1517 }
1518 if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" {
1519 t.Fatalf(body)
1520 }
1521 if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" {
1522 t.Fatalf(body)
1523 }
1524 }
1525
1526 func TestMuxContextIsThreadSafe(t *testing.T) {
1527 router := NewRouter()
1528 router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
1529 ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond)
1530 defer cancel()
1531
1532 <-ctx.Done()
1533 })
1534
1535 wg := sync.WaitGroup{}
1536
1537 for i := 0; i < 100; i++ {
1538 wg.Add(1)
1539 go func() {
1540 defer wg.Done()
1541 for j := 0; j < 10000; j++ {
1542 w := httptest.NewRecorder()
1543 r, err := http.NewRequest("GET", "/ok", nil)
1544 if err != nil {
1545 t.Fatal(err)
1546 }
1547
1548 ctx, cancel := context.WithCancel(r.Context())
1549 r = r.WithContext(ctx)
1550
1551 go func() {
1552 cancel()
1553 }()
1554 router.ServeHTTP(w, r)
1555 }
1556 }()
1557 }
1558 wg.Wait()
1559 }
1560
1561 func TestEscapedURLParams(t *testing.T) {
1562 m := NewRouter()
1563 m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) {
1564 w.WriteHeader(200)
1565 rctx := RouteContext(r.Context())
1566 if rctx == nil {
1567 t.Error("no context")
1568 return
1569 }
1570 identifier := URLParam(r, "identifier")
1571 if identifier != "http:%2f%2fexample.com%2fimage.png" {
1572 t.Errorf("identifier path parameter incorrect %s", identifier)
1573 return
1574 }
1575 region := URLParam(r, "region")
1576 if region != "full" {
1577 t.Errorf("region path parameter incorrect %s", region)
1578 return
1579 }
1580 size := URLParam(r, "size")
1581 if size != "max" {
1582 t.Errorf("size path parameter incorrect %s", size)
1583 return
1584 }
1585 rotation := URLParam(r, "rotation")
1586 if rotation != "0" {
1587 t.Errorf("rotation path parameter incorrect %s", rotation)
1588 return
1589 }
1590 w.Write([]byte("success"))
1591 })
1592
1593 ts := httptest.NewServer(m)
1594 defer ts.Close()
1595
1596 if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" {
1597 t.Fatalf(body)
1598 }
1599 }
1600
1601 func TestMuxMatch(t *testing.T) {
1602 r := NewRouter()
1603 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1604 w.Header().Set("X-Test", "yes")
1605 w.Write([]byte("bye"))
1606 })
1607 r.Route("/articles", func(r Router) {
1608 r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
1609 id := URLParam(r, "id")
1610 w.Header().Set("X-Article", id)
1611 w.Write([]byte("article:" + id))
1612 })
1613 })
1614 r.Route("/users", func(r Router) {
1615 r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
1616 w.Header().Set("X-User", "-")
1617 w.Write([]byte("user"))
1618 })
1619 r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
1620 id := URLParam(r, "id")
1621 w.Header().Set("X-User", id)
1622 w.Write([]byte("user:" + id))
1623 })
1624 })
1625
1626 tctx := NewRouteContext()
1627
1628 tctx.Reset()
1629 if r.Match(tctx, "GET", "/users/1") == false {
1630 t.Fatal("expecting to find match for route:", "GET", "/users/1")
1631 }
1632
1633 tctx.Reset()
1634 if r.Match(tctx, "HEAD", "/articles/10") == true {
1635 t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10")
1636 }
1637 }
1638
1639 func TestServerBaseContext(t *testing.T) {
1640 r := NewRouter()
1641 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
1642 baseYes := r.Context().Value(ctxKey{"base"}).(string)
1643 if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok {
1644 panic("missing server context")
1645 }
1646 if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok {
1647 panic("missing local addr context")
1648 }
1649 w.Write([]byte(baseYes))
1650 })
1651
1652
1653 ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes")
1654 ts := httptest.NewServer(ServerBaseContext(ctx, r))
1655 defer ts.Close()
1656
1657 if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" {
1658 t.Fatalf(body)
1659 }
1660 }
1661
// testRequest issues a request with the given method and body to ts.URL+path
// using http.DefaultClient and returns the response together with its fully
// read body as a string. Any failure building the request, performing it, or
// reading the body aborts the calling test via t.Fatal.
//
// The response body is closed before testRequest returns, so callers may
// inspect the returned response's status and headers but not re-read Body.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails and t.Fatal unwinds this goroutine.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
1684
// testHandler invokes h directly, without a network round trip, using a
// request built from method, path and body, and returns the recorded
// response together with the recorded body as a string. A request that
// cannot be constructed aborts the calling test via t.Fatal instead of
// handing a nil request to the handler.
func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) {
	r, err := http.NewRequest(method, path, body)
	if err != nil {
		t.Fatal(err)
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	return w.Result(), w.Body.String()
}
1691
// testFileSystem is a test double whose Open method (the method required by
// http.FileSystem) delegates to the function stored in its open field.
type testFileSystem struct {
	open func(name string) (http.File, error)
}

// Open forwards name to the configured open function and returns its result
// unchanged.
func (tfs *testFileSystem) Open(name string) (http.File, error) {
	file, err := tfs.open(name)
	return file, err
}
1699
// testFile is an in-memory stand-in for a file: a name plus fixed contents.
// Its method set (Close, Read, Seek, Readdir, Stat) is the one required by
// http.File.
type testFile struct {
	name     string
	contents []byte
}

// Close is a no-op that always succeeds.
func (f *testFile) Close() error {
	return nil
}

// Read copies the start of the contents into p.
//
// NOTE: the implementation is stateless. It always reports len(p) bytes read,
// even when fewer bytes were copied, and never returns io.EOF, so it is only
// suitable for callers that bound the amount they read themselves.
func (f *testFile) Read(p []byte) (n int, err error) {
	n = len(p)
	copy(p, f.contents)
	return n, nil
}

// Seek ignores its arguments and always reports position 0.
func (f *testFile) Seek(offset int64, whence int) (int64, error) {
	return 0, nil
}

// Readdir ignores count and returns a single entry describing the file itself.
func (f *testFile) Readdir(count int) ([]os.FileInfo, error) {
	self, _ := f.Stat() // Stat never returns an error
	return []os.FileInfo{self}, nil
}

// Stat returns a testFileInfo carrying the file's name and content length.
func (f *testFile) Stat() (os.FileInfo, error) {
	info := &testFileInfo{
		name: f.name,
		size: int64(len(f.contents)),
	}
	return info, nil
}

// testFileInfo is a minimal os.FileInfo implementation holding a name and a
// size.
type testFileInfo struct {
	name string
	size int64
}

// Name returns the stored name.
func (fi *testFileInfo) Name() string { return fi.name }

// Size returns the stored size.
func (fi *testFileInfo) Size() int64 { return fi.size }

// Mode always reports permission bits 0755.
func (fi *testFileInfo) Mode() os.FileMode { return os.FileMode(0755) }

// ModTime reports the current time on every call.
func (fi *testFileInfo) ModTime() time.Time { return time.Now() }

// IsDir always reports false.
func (fi *testFileInfo) IsDir() bool { return false }

// Sys always returns nil.
func (fi *testFileInfo) Sys() interface{} { return nil }
1738
// ctxKey is the key type these tests use for context values. Keys are
// distinguished by name, and two keys with the same name compare equal.
type ctxKey struct {
	name string
}

// String renders the key as "context value <name>".
func (key ctxKey) String() string {
	const prefix = "context value "
	return prefix + key.name
}
1746
1747 func BenchmarkMux(b *testing.B) {
1748 h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1749 h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1750 h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1751 h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1752 h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1753 h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1754
1755 mx := NewRouter()
1756 mx.Get("/", h1)
1757 mx.Get("/hi", h2)
1758 mx.Get("/sup/{id}/and/{this}", h3)
1759
1760 mx.Route("/sharing/{x}/{hash}", func(mx Router) {
1761 mx.Get("/", h4)
1762 mx.Get("/{network}", h5)
1763 mx.Get("/twitter", h5)
1764 mx.Route("/direct", func(mx Router) {
1765 mx.Get("/", h6)
1766 mx.Get("/download", h6)
1767 })
1768 })
1769
1770 routes := []string{
1771 "/",
1772 "/hi",
1773 "/sup/123/and/this",
1774 "/sharing/z/aBc",
1775 "/sharing/z/aBc/twitter",
1776 "/sharing/z/aBc/direct",
1777 "/sharing/z/aBc/direct/download",
1778 }
1779
1780 for _, path := range routes {
1781 b.Run("route:"+path, func(b *testing.B) {
1782 w := httptest.NewRecorder()
1783 r, _ := http.NewRequest("GET", path, nil)
1784
1785 b.ReportAllocs()
1786 b.ResetTimer()
1787
1788 for i := 0; i < b.N; i++ {
1789 mx.ServeHTTP(w, r)
1790 }
1791 })
1792 }
1793 }
1794
View as plain text