1 package chi
2
3
4
5
6
7 import (
8 "fmt"
9 "math"
10 "net/http"
11 "regexp"
12 "sort"
13 "strconv"
14 "strings"
15 )
16
// methodTyp is a bit flag identifying a single HTTP method. Flags can be
// OR-ed together to describe a set of methods.
type methodTyp int

// Bit 0 is mSTUB; the built-in HTTP methods occupy bits 1 through 9.
const (
	mSTUB methodTyp = 1 << iota
	mCONNECT
	mDELETE
	mGET
	mHEAD
	mOPTIONS
	mPATCH
	mPOST
	mPUT
	mTRACE
)

// mALL is the union of every method flag in methodMap. It does not include
// mSTUB and grows whenever RegisterMethod adds a method.
var mALL = mCONNECT | mDELETE | mGET | mHEAD |
	mOPTIONS | mPATCH | mPOST | mPUT | mTRACE

// methodMap maps an HTTP method name to its flag.
var methodMap = map[string]methodTyp{
	http.MethodConnect: mCONNECT,
	http.MethodDelete:  mDELETE,
	http.MethodGet:     mGET,
	http.MethodHead:    mHEAD,
	http.MethodOptions: mOPTIONS,
	http.MethodPatch:   mPATCH,
	http.MethodPost:    mPOST,
	http.MethodPut:     mPUT,
	http.MethodTrace:   mTRACE,
}

// RegisterMethod makes an additional HTTP method name routable by assigning
// it a fresh flag bit and adding that bit to mALL.
//
// The name is upper-cased before use. Empty names and names that are already
// present in methodMap are ignored. It panics when no flag bit is left.
func RegisterMethod(method string) {
	if method == "" {
		return
	}
	method = strings.ToUpper(method)
	if _, ok := methodMap[method]; ok {
		return
	}

	// Every entry of methodMap owns exactly one bit, and bit 0 belongs to
	// mSTUB, so the entries occupy bits 1..len(methodMap). The next free bit
	// is therefore len(methodMap)+1; using len(methodMap) itself would hand
	// out the bit already owned by the last built-in method.
	bit := len(methodMap) + 1

	// Stay below the sign bit: up to this bound 2^bit is exactly
	// representable as a float64 and fits in an int, so the conversion
	// below cannot overflow.
	if bit > strconv.IntSize-2 {
		panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize))
	}

	mt := methodTyp(math.Exp2(float64(bit)))
	methodMap[method] = mt
	mALL |= mt
}
65
// nodeTyp classifies a tree node by the kind of pattern segment it matches.
// The numeric order matters: node.children is indexed by nodeTyp and route
// lookup walks that array from the lowest value to the highest.
type nodeTyp uint8

const (
	ntStatic   nodeTyp = iota // literal text
	ntRegexp                  // "{name:regexp}" segment
	ntParam                   // "{name}" segment
	ntCatchAll                // trailing "*" segment
)
74
// node is one vertex of the routing tree.
type node struct {
	// typ is the kind of segment this node matches.
	typ nodeTyp

	// label is the first byte of the pattern text this node was created
	// for; child lists are sorted and searched by it.
	label byte

	// tail is the byte that follows a param segment in the pattern
	// ('/' when the param ends the pattern); zero for other nodes.
	tail byte

	// prefix is the literal text of a static node, or the anchored regexp
	// source of a regexp node.
	prefix string

	// rex is the compiled expression of a regexp node, nil otherwise.
	rex *regexp.Regexp

	// endpoints holds the handlers registered on this node, keyed by
	// method flag; nil when no route ends here.
	endpoints endpoints

	// subroutes is handed to the callback of walk together with endpoints.
	// It is not assigned anywhere in this file.
	subroutes Routes

	// children holds the child nodes grouped by node type; the array index
	// is the nodeTyp of the children in that group.
	children [ntCatchAll + 1]nodes
}
101
102
103
// endpoints maps a method flag to the endpoint registered for it on a node.
type endpoints map[methodTyp]*endpoint

// endpoint is the routing target stored for one method on one node.
type endpoint struct {
	// handler is invoked when the route matches.
	handler http.Handler

	// pattern is the full routing pattern the handler was registered with.
	pattern string

	// paramKeys are the parameter names found in pattern, in order.
	paramKeys []string
}
116
117 func (s endpoints) Value(method methodTyp) *endpoint {
118 mh, ok := s[method]
119 if !ok {
120 mh = &endpoint{}
121 s[method] = mh
122 }
123 return mh
124 }
125
// InsertRoute registers handler for method and pattern in the tree rooted at
// n and returns the node that now carries the endpoint.
//
// The pattern is consumed from left to right: existing edges are followed
// as far as they match, a static node is split when the pattern diverges in
// the middle of its prefix, and whatever is left is added as new nodes.
func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node {
	var parent *node
	search := pattern

	for {
		// The whole pattern has been consumed: the current node is the target.
		if len(search) == 0 {
			// Insert or update the endpoint on this node.
			n.setEndpoint(method, handler, pattern)
			return n
		}

		// When the remaining pattern starts with a param or wildcard, parse
		// that segment so the matching edge can be looked up by type and tail.
		// Otherwise the zero values select a static edge.
		var label = search[0]
		var segTail byte
		var segEndIdx int
		var segTyp nodeTyp
		var segRexpat string
		if label == '{' || label == '*' {
			segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
		}

		// Regexp edges are additionally distinguished by their expression.
		var prefix string
		if segTyp == ntRegexp {
			prefix = segRexpat
		}

		// Look for an existing edge to follow; remember where we came from.
		parent = n
		n = n.getEdge(segTyp, label, segTail, prefix)

		// No edge: hang the rest of the pattern below parent and finish.
		if n == nil {
			child := &node{label: label, tail: segTail, prefix: search}
			hn := parent.addChild(child, search)
			hn.setEndpoint(method, handler, pattern)

			return hn
		}

		// An edge was found.

		if n.typ > ntStatic {
			// Param/wildcard node: drop the segment that was just parsed
			// from the search string and continue below this node.
			search = search[segEndIdx:]
			continue
		}

		// Static node: compare its prefix with the remaining pattern.
		commonPrefix := longestPrefix(search, n.prefix)
		if commonPrefix == len(n.prefix) {
			// The node's prefix is fully contained in the pattern;
			// consume it and keep descending.
			search = search[commonPrefix:]
			continue
		}

		// The pattern diverges inside the node's prefix: split the node.
		// child takes over the shared part and replaces n under parent.
		child := &node{
			typ:    ntStatic,
			prefix: search[:commonPrefix],
		}
		parent.replaceChild(search[0], segTail, child)

		// Re-attach the existing node below child with the unshared rest
		// of its prefix.
		n.label = n.prefix[commonPrefix]
		n.prefix = n.prefix[commonPrefix:]
		child.addChild(n, n.prefix)

		// If the pattern ends exactly at the split point, the endpoint
		// belongs on the new shared node.
		search = search[commonPrefix:]
		if len(search) == 0 {
			child.setEndpoint(method, handler, pattern)
			return child
		}

		// Otherwise add a sibling below child for the rest of the pattern.
		subchild := &node{
			typ:    ntStatic,
			label:  search[0],
			prefix: search,
		}
		hn := child.addChild(subchild, search)
		hn.setEndpoint(method, handler, pattern)
		return hn
	}
}
217
218
219
220
221
// addChild attaches child below n. prefix is the pattern text child stands
// for; when it contains param or wildcard segments it is broken up into a
// chain of nodes, one per segment, created recursively below child.
//
// It returns the last node of that chain (child itself when no further
// nodes were needed), which is where the caller sets the endpoint.
// It panics when a regexp segment does not compile.
func (n *node) addChild(child *node, prefix string) *node {
	search := prefix

	// hn is the node handed back to the caller; it is replaced below when
	// nested nodes are created.
	hn := child

	// Parse the next param or wildcard segment of the pattern.
	segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)

	// Add the child according to the kind of the next segment.
	switch segTyp {

	case ntStatic:
		// Purely static text: child is attached as given.

	default:
		// The pattern contains a param or wildcard segment.

		if segTyp == ntRegexp {
			rex, err := regexp.Compile(segRexpat)
			if err != nil {
				panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat))
			}
			child.prefix = segRexpat
			child.rex = rex
		}

		if segStartIdx == 0 {
			// The pattern starts with the param/wildcard segment, so
			// child itself becomes that node.
			child.typ = segTyp

			// A catch-all consumes the rest of the pattern; any other
			// segment ends at segEndIdx.
			if segTyp == ntCatchAll {
				segStartIdx = -1
			} else {
				segStartIdx = segEndIdx
			}
			if segStartIdx < 0 {
				segStartIdx = len(search)
			}
			child.tail = segTail

			if segStartIdx != len(search) {
				// Text remains after the segment: add a node for it
				// below child. It starts out static and is re-typed by
				// the recursive call if it begins with another segment.
				search = search[segStartIdx:]

				nn := &node{
					typ:    ntStatic,
					label:  search[0],
					prefix: search,
				}
				hn = child.addChild(nn, search)
			}

		} else if segStartIdx > 0 {
			// Static text precedes the segment: child keeps only that
			// text and loses any regexp set above.
			child.typ = ntStatic
			child.prefix = search[:segStartIdx]
			child.rex = nil

			// The segment itself goes into a new node below child.
			search = search[segStartIdx:]

			nn := &node{
				typ:   segTyp,
				label: search[0],
				tail:  segTail,
			}
			hn = child.addChild(nn, search)

		}
	}

	// Register child in the group for its final type and restore the order.
	n.children[child.typ] = append(n.children[child.typ], child)
	n.children[child.typ].Sort()
	return hn
}
305
306 func (n *node) replaceChild(label, tail byte, child *node) {
307 for i := 0; i < len(n.children[child.typ]); i++ {
308 if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail {
309 n.children[child.typ][i] = child
310 n.children[child.typ][i].label = label
311 n.children[child.typ][i].tail = tail
312 return
313 }
314 }
315 panic("chi: replacing missing child")
316 }
317
318 func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
319 nds := n.children[ntyp]
320 for i := 0; i < len(nds); i++ {
321 if nds[i].label == label && nds[i].tail == tail {
322 if ntyp == ntRegexp && nds[i].prefix != prefix {
323 continue
324 }
325 return nds[i]
326 }
327 }
328 return nil
329 }
330
331 func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) {
332
333 if n.endpoints == nil {
334 n.endpoints = make(endpoints)
335 }
336
337 paramKeys := patParamKeys(pattern)
338
339 if method&mSTUB == mSTUB {
340 n.endpoints.Value(mSTUB).handler = handler
341 }
342 if method&mALL == mALL {
343 h := n.endpoints.Value(mALL)
344 h.handler = handler
345 h.pattern = pattern
346 h.paramKeys = paramKeys
347 for _, m := range methodMap {
348 h := n.endpoints.Value(m)
349 h.handler = handler
350 h.pattern = pattern
351 h.paramKeys = paramKeys
352 }
353 } else {
354 h := n.endpoints.Value(method)
355 h.handler = handler
356 h.pattern = pattern
357 h.paramKeys = paramKeys
358 }
359 }
360
361 func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) {
362
363 rctx.routePattern = ""
364 rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
365 rctx.routeParams.Values = rctx.routeParams.Values[:0]
366
367
368 rn := n.findRoute(rctx, method, path)
369 if rn == nil {
370 return nil, nil, nil
371 }
372
373
374 rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...)
375 rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...)
376
377
378 if rn.endpoints[method].pattern != "" {
379 rctx.routePattern = rn.endpoints[method].pattern
380 rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern)
381 }
382
383 return rn, rn.endpoints, rn.endpoints[method].handler
384 }
385
386
387
388 func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
389 nn := n
390 search := path
391
392 for t, nds := range nn.children {
393 ntyp := nodeTyp(t)
394 if len(nds) == 0 {
395 continue
396 }
397
398 var xn *node
399 xsearch := search
400
401 var label byte
402 if search != "" {
403 label = search[0]
404 }
405
406 switch ntyp {
407 case ntStatic:
408 xn = nds.findEdge(label)
409 if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) {
410 continue
411 }
412 xsearch = xsearch[len(xn.prefix):]
413
414 case ntParam, ntRegexp:
415
416 if xsearch == "" {
417 continue
418 }
419
420
421 for idx := 0; idx < len(nds); idx++ {
422 xn = nds[idx]
423
424
425 p := strings.IndexByte(xsearch, xn.tail)
426
427 if p < 0 {
428 if xn.tail == '/' {
429 p = len(xsearch)
430 } else {
431 continue
432 }
433 }
434
435 if ntyp == ntRegexp && xn.rex != nil {
436 if !xn.rex.Match([]byte(xsearch[:p])) {
437 continue
438 }
439 } else if strings.IndexByte(xsearch[:p], '/') != -1 {
440
441 continue
442 }
443
444 prevlen := len(rctx.routeParams.Values)
445 rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p])
446 xsearch = xsearch[p:]
447
448 if len(xsearch) == 0 {
449 if xn.isLeaf() {
450 h := xn.endpoints[method]
451 if h != nil && h.handler != nil {
452 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
453 return xn
454 }
455
456
457
458 rctx.methodNotAllowed = true
459 }
460 }
461
462
463 fin := xn.findRoute(rctx, method, xsearch)
464 if fin != nil {
465 return fin
466 }
467
468
469 rctx.routeParams.Values = rctx.routeParams.Values[:prevlen]
470 xsearch = search
471 }
472
473 rctx.routeParams.Values = append(rctx.routeParams.Values, "")
474
475 default:
476
477 rctx.routeParams.Values = append(rctx.routeParams.Values, search)
478 xn = nds[0]
479 xsearch = ""
480 }
481
482 if xn == nil {
483 continue
484 }
485
486
487 if len(xsearch) == 0 {
488 if xn.isLeaf() {
489 h := xn.endpoints[method]
490 if h != nil && h.handler != nil {
491 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
492 return xn
493 }
494
495
496
497 rctx.methodNotAllowed = true
498 }
499 }
500
501
502 fin := xn.findRoute(rctx, method, xsearch)
503 if fin != nil {
504 return fin
505 }
506
507
508 if xn.typ > ntStatic {
509 if len(rctx.routeParams.Values) > 0 {
510 rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1]
511 }
512 }
513
514 }
515
516 return nil
517 }
518
519 func (n *node) findEdge(ntyp nodeTyp, label byte) *node {
520 nds := n.children[ntyp]
521 num := len(nds)
522 idx := 0
523
524 switch ntyp {
525 case ntStatic, ntParam, ntRegexp:
526 i, j := 0, num-1
527 for i <= j {
528 idx = i + (j-i)/2
529 if label > nds[idx].label {
530 i = idx + 1
531 } else if label < nds[idx].label {
532 j = idx - 1
533 } else {
534 i = num
535 }
536 }
537 if nds[idx].label != label {
538 return nil
539 }
540 return nds[idx]
541
542 default:
543 return nds[idx]
544 }
545 }
546
547 func (n *node) isLeaf() bool {
548 return n.endpoints != nil
549 }
550
551 func (n *node) findPattern(pattern string) bool {
552 nn := n
553 for _, nds := range nn.children {
554 if len(nds) == 0 {
555 continue
556 }
557
558 n = nn.findEdge(nds[0].typ, pattern[0])
559 if n == nil {
560 continue
561 }
562
563 var idx int
564 var xpattern string
565
566 switch n.typ {
567 case ntStatic:
568 idx = longestPrefix(pattern, n.prefix)
569 if idx < len(n.prefix) {
570 continue
571 }
572
573 case ntParam, ntRegexp:
574 idx = strings.IndexByte(pattern, '}') + 1
575
576 case ntCatchAll:
577 idx = longestPrefix(pattern, "*")
578
579 default:
580 panic("chi: unknown node type")
581 }
582
583 xpattern = pattern[idx:]
584 if len(xpattern) == 0 {
585 return true
586 }
587
588 return n.findPattern(xpattern)
589 }
590 return false
591 }
592
593 func (n *node) routes() []Route {
594 rts := []Route{}
595
596 n.walk(func(eps endpoints, subroutes Routes) bool {
597 if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil {
598 return false
599 }
600
601
602 pats := make(map[string]endpoints)
603
604 for mt, h := range eps {
605 if h.pattern == "" {
606 continue
607 }
608 p, ok := pats[h.pattern]
609 if !ok {
610 p = endpoints{}
611 pats[h.pattern] = p
612 }
613 p[mt] = h
614 }
615
616 for p, mh := range pats {
617 hs := make(map[string]http.Handler)
618 if mh[mALL] != nil && mh[mALL].handler != nil {
619 hs["*"] = mh[mALL].handler
620 }
621
622 for mt, h := range mh {
623 if h.handler == nil {
624 continue
625 }
626 m := methodTypString(mt)
627 if m == "" {
628 continue
629 }
630 hs[m] = h.handler
631 }
632
633 rt := Route{p, hs, subroutes}
634 rts = append(rts, rt)
635 }
636
637 return false
638 })
639
640 return rts
641 }
642
643 func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool {
644
645 if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) {
646 return true
647 }
648
649
650 for _, ns := range n.children {
651 for _, cn := range ns {
652 if cn.walk(fn) {
653 return true
654 }
655 }
656 }
657 return false
658 }
659
660
661
662 func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) {
663 ps := strings.Index(pattern, "{")
664 ws := strings.Index(pattern, "*")
665
666 if ps < 0 && ws < 0 {
667 return ntStatic, "", "", 0, 0, len(pattern)
668 }
669
670
671 if ps >= 0 && ws >= 0 && ws < ps {
672 panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'")
673 }
674
675 var tail byte = '/'
676
677 if ps >= 0 {
678
679 nt := ntParam
680
681
682 cc := 0
683 pe := ps
684 for i, c := range pattern[ps:] {
685 if c == '{' {
686 cc++
687 } else if c == '}' {
688 cc--
689 if cc == 0 {
690 pe = ps + i
691 break
692 }
693 }
694 }
695 if pe == ps {
696 panic("chi: route param closing delimiter '}' is missing")
697 }
698
699 key := pattern[ps+1 : pe]
700 pe++
701
702 if pe < len(pattern) {
703 tail = pattern[pe]
704 }
705
706 var rexpat string
707 if idx := strings.Index(key, ":"); idx >= 0 {
708 nt = ntRegexp
709 rexpat = key[idx+1:]
710 key = key[:idx]
711 }
712
713 if len(rexpat) > 0 {
714 if rexpat[0] != '^' {
715 rexpat = "^" + rexpat
716 }
717 if rexpat[len(rexpat)-1] != '$' {
718 rexpat += "$"
719 }
720 }
721
722 return nt, key, rexpat, tail, ps, pe
723 }
724
725
726 if ws < len(pattern)-1 {
727 panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead")
728 }
729 return ntCatchAll, "*", "", 0, ws, len(pattern)
730 }
731
732 func patParamKeys(pattern string) []string {
733 pat := pattern
734 paramKeys := []string{}
735 for {
736 ptyp, paramKey, _, _, _, e := patNextSegment(pat)
737 if ptyp == ntStatic {
738 return paramKeys
739 }
740 for i := 0; i < len(paramKeys); i++ {
741 if paramKeys[i] == paramKey {
742 panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey))
743 }
744 }
745 paramKeys = append(paramKeys, paramKey)
746 pat = pat[e:]
747 }
748 }
749
750
751
// longestPrefix returns the length in bytes of the longest common prefix of
// k1 and k2.
func longestPrefix(k1, k2 string) int {
	limit := len(k1)
	if len(k2) < limit {
		limit = len(k2)
	}
	for i := 0; i < limit; i++ {
		if k1[i] != k2[i] {
			return i
		}
	}
	return limit
}
765
766 func methodTypString(method methodTyp) string {
767 for s, t := range methodMap {
768 if method == t {
769 return s
770 }
771 }
772 return ""
773 }
774
775 type nodes []*node
776
777
778 func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() }
779 func (ns nodes) Len() int { return len(ns) }
780 func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] }
781 func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label }
782
783
784
785 func (ns nodes) tailSort() {
786 for i := len(ns) - 1; i >= 0; i-- {
787 if ns[i].typ > ntStatic && ns[i].tail == '/' {
788 ns.Swap(i, len(ns)-1)
789 return
790 }
791 }
792 }
793
794 func (ns nodes) findEdge(label byte) *node {
795 num := len(ns)
796 idx := 0
797 i, j := 0, num-1
798 for i <= j {
799 idx = i + (j-i)/2
800 if label > ns[idx].label {
801 i = idx + 1
802 } else if label < ns[idx].label {
803 j = idx - 1
804 } else {
805 i = num
806 }
807 }
808 if ns[idx].label != label {
809 return nil
810 }
811 return ns[idx]
812 }
813
814
815
// Route describes one routing pattern together with the handlers
// registered for it.
type Route struct {
	// Pattern is the routing pattern.
	Pattern string
	// Handlers maps an HTTP method name to its handler; the key "*" holds
	// the handler registered for all methods.
	Handlers map[string]http.Handler
	// SubRoutes is the nested router attached at this pattern, if any.
	SubRoutes Routes
}
821
822
// WalkFunc is the callback Walk invokes for every method and route it
// visits. Returning a non-nil error stops the walk; Walk returns that error.
type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error
824
825
826 func Walk(r Routes, walkFn WalkFunc) error {
827 return walk(r, walkFn, "")
828 }
829
830 func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error {
831 for _, route := range r.Routes() {
832 mws := make([]func(http.Handler) http.Handler, len(parentMw))
833 copy(mws, parentMw)
834 mws = append(mws, r.Middlewares()...)
835
836 if route.SubRoutes != nil {
837 if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil {
838 return err
839 }
840 continue
841 }
842
843 for method, handler := range route.Handlers {
844 if method == "*" {
845
846 continue
847 }
848
849 fullRoute := parentRoute + route.Pattern
850 fullRoute = strings.Replace(fullRoute, "/*/", "/", -1)
851
852 if chain, ok := handler.(*ChainHandler); ok {
853 if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil {
854 return err
855 }
856 } else {
857 if err := walkFn(method, fullRoute, handler, mws...); err != nil {
858 return err
859 }
860 }
861 }
862 }
863
864 return nil
865 }
866
View as plain text