1
2
3
4
5 package mux
6
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"path"
	"regexp"
)
15
var (
	// ErrMethodMismatch is the RouteMatch.MatchErr value that Router.Match
	// and ServeHTTP treat as "method not allowed": Match then installs
	// MethodNotAllowedHandler (when set) and ServeHTTP otherwise falls back
	// to a bare 405 response.
	// NOTE(review): it is presumably set by Route.Match, which is not in this file.
	ErrMethodMismatch = errors.New("method is not allowed")
	// ErrNotFound is stored in RouteMatch.MatchErr by Router.Match when no
	// route matched the request.
	ErrNotFound = errors.New("no matching route was found")
)
23
24
25 func NewRouter() *Router {
26 return &Router{namedRoutes: make(map[string]*Route)}
27 }
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
// Router registers routes to be matched and dispatches a handler.
//
// It has a ServeHTTP(http.ResponseWriter, *http.Request) method, so it
// satisfies http.Handler. Routes are tried in registration order.
type Router struct {
	// NotFoundHandler, when non-nil, is installed by Match for requests that
	// no route matched (unless matching ended with ErrMethodMismatch).
	NotFoundHandler http.Handler

	// MethodNotAllowedHandler, when non-nil, is installed by Match when
	// matching ended with ErrMethodMismatch. When nil, ServeHTTP falls back
	// to a bare 405 response.
	MethodNotAllowedHandler http.Handler

	// routes holds every route created through NewRoute, in registration
	// order; Match tries them in that order.
	routes []*Route

	// namedRoutes maps route names to routes. NewRoute shares this same map
	// with every route it creates; Get and GetRoute read from it.
	namedRoutes map[string]*Route

	// KeepContext is not read anywhere in this file.
	// NOTE(review): confirm whether it has any effect elsewhere in the package.
	KeepContext bool

	// middlewares are wrapped around a cleanly matched handler by Match;
	// the first element ends up outermost.
	middlewares []middleware

	// routeConf holds the configuration that NewRoute copies into each new route.
	routeConf
}
73
74
// routeConf holds settings shared by a Router and the routes it creates.
// NewRoute gives each new route its own copy via copyRouteConf.
type routeConf struct {
	// useEncodedPath makes ServeHTTP clean req.URL.EscapedPath() instead of
	// req.URL.Path. It is set by UseEncodedPath.
	useEncodedPath bool

	// strictSlash is set by StrictSlash.
	// NOTE(review): the trailing-slash matching it controls lives outside
	// this file — confirm its exact semantics there.
	strictSlash bool

	// skipClean disables ServeHTTP's path cleaning and redirect. It is set
	// by SkipClean.
	skipClean bool

	// regexp holds the path, host and query regexps; copyRouteConf
	// duplicates each of them for the copy.
	regexp routeRegexpGroup

	// matchers is the list of matchers; copyRouteConf gives the copy its
	// own slice (the matcher values themselves are shared).
	matchers []matcher

	// buildScheme is presumably the scheme used when building URLs — TODO confirm.
	buildScheme string

	// buildVarsFunc is presumably applied to variables when building URLs — TODO confirm.
	buildVarsFunc BuildVarsFunc
}
98
99
100 func copyRouteConf(r routeConf) routeConf {
101 c := r
102
103 if r.regexp.path != nil {
104 c.regexp.path = copyRouteRegexp(r.regexp.path)
105 }
106
107 if r.regexp.host != nil {
108 c.regexp.host = copyRouteRegexp(r.regexp.host)
109 }
110
111 c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries))
112 for _, q := range r.regexp.queries {
113 c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q))
114 }
115
116 c.matchers = make([]matcher, len(r.matchers))
117 copy(c.matchers, r.matchers)
118
119 return c
120 }
121
122 func copyRouteRegexp(r *routeRegexp) *routeRegexp {
123 c := *r
124 return &c
125 }
126
127
128
129
130
131
132
133
134
135
136
137
138 func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
139 for _, route := range r.routes {
140 if route.Match(req, match) {
141
142 if match.MatchErr == nil {
143 for i := len(r.middlewares) - 1; i >= 0; i-- {
144 match.Handler = r.middlewares[i].Middleware(match.Handler)
145 }
146 }
147 return true
148 }
149 }
150
151 if match.MatchErr == ErrMethodMismatch {
152 if r.MethodNotAllowedHandler != nil {
153 match.Handler = r.MethodNotAllowedHandler
154 return true
155 }
156
157 return false
158 }
159
160
161 if r.NotFoundHandler != nil {
162 match.Handler = r.NotFoundHandler
163 match.MatchErr = ErrNotFound
164 return true
165 }
166
167 match.MatchErr = ErrNotFound
168 return false
169 }
170
171
172
173
174
175 func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
176 if !r.skipClean {
177 path := req.URL.Path
178 if r.useEncodedPath {
179 path = req.URL.EscapedPath()
180 }
181
182 if p := cleanPath(path); p != path {
183
184
185
186
187 url := *req.URL
188 url.Path = p
189 p = url.String()
190
191 w.Header().Set("Location", p)
192 w.WriteHeader(http.StatusMovedPermanently)
193 return
194 }
195 }
196 var match RouteMatch
197 var handler http.Handler
198 if r.Match(req, &match) {
199 handler = match.Handler
200 req = requestWithVars(req, match.Vars)
201 req = requestWithRoute(req, match.Route)
202 }
203
204 if handler == nil && match.MatchErr == ErrMethodMismatch {
205 handler = methodNotAllowedHandler()
206 }
207
208 if handler == nil {
209 handler = http.NotFoundHandler()
210 }
211
212 handler.ServeHTTP(w, req)
213 }
214
215
216 func (r *Router) Get(name string) *Route {
217 return r.namedRoutes[name]
218 }
219
220
221
222 func (r *Router) GetRoute(name string) *Route {
223 return r.namedRoutes[name]
224 }
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245 func (r *Router) StrictSlash(value bool) *Router {
246 r.strictSlash = value
247 return r
248 }
249
250
251
252
253
254
255
256
257
258 func (r *Router) SkipClean(value bool) *Router {
259 r.skipClean = value
260 return r
261 }
262
263
264
265
266
267
268
269 func (r *Router) UseEncodedPath() *Router {
270 r.useEncodedPath = true
271 return r
272 }
273
274
275
276
277
278
279 func (r *Router) NewRoute() *Route {
280
281 route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
282 r.routes = append(r.routes, route)
283 return route
284 }
285
286
287
288 func (r *Router) Name(name string) *Route {
289 return r.NewRoute().Name(name)
290 }
291
292
293
294 func (r *Router) Handle(path string, handler http.Handler) *Route {
295 return r.NewRoute().Path(path).Handler(handler)
296 }
297
298
299
300 func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
301 *http.Request)) *Route {
302 return r.NewRoute().Path(path).HandlerFunc(f)
303 }
304
305
306
307 func (r *Router) Headers(pairs ...string) *Route {
308 return r.NewRoute().Headers(pairs...)
309 }
310
311
312
313 func (r *Router) Host(tpl string) *Route {
314 return r.NewRoute().Host(tpl)
315 }
316
317
318
319 func (r *Router) MatcherFunc(f MatcherFunc) *Route {
320 return r.NewRoute().MatcherFunc(f)
321 }
322
323
324
325 func (r *Router) Methods(methods ...string) *Route {
326 return r.NewRoute().Methods(methods...)
327 }
328
329
330
331 func (r *Router) Path(tpl string) *Route {
332 return r.NewRoute().Path(tpl)
333 }
334
335
336
337 func (r *Router) PathPrefix(tpl string) *Route {
338 return r.NewRoute().PathPrefix(tpl)
339 }
340
341
342
343 func (r *Router) Queries(pairs ...string) *Route {
344 return r.NewRoute().Queries(pairs...)
345 }
346
347
348
349 func (r *Router) Schemes(schemes ...string) *Route {
350 return r.NewRoute().Schemes(schemes...)
351 }
352
353
354
355 func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
356 return r.NewRoute().BuildVarsFunc(f)
357 }
358
359
360
361
362 func (r *Router) Walk(walkFn WalkFunc) error {
363 return r.walk(walkFn, []*Route{})
364 }
365
366
367
// SkipRouter can be returned by a WalkFunc to stop the walk from descending
// into the sub-routers of the route it was called with; walking continues
// with the next route. It is compared with ==, so it must be returned
// unwrapped.
var SkipRouter = errors.New("skip this router")
369
370
371
372
// WalkFunc is the callback invoked by Router.Walk for each visited route.
// It receives the route, the router that owns it, and the routes leading to
// that router (outermost first). Returning SkipRouter skips the route's
// sub-routers; any other non-nil error aborts the walk and is returned by
// Walk. The ancestors slice is reused while walking, so copy it if it must
// be retained.
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
374
375 func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
376 for _, t := range r.routes {
377 err := walkFn(t, r, ancestors)
378 if err == SkipRouter {
379 continue
380 }
381 if err != nil {
382 return err
383 }
384 for _, sr := range t.matchers {
385 if h, ok := sr.(*Router); ok {
386 ancestors = append(ancestors, t)
387 err := h.walk(walkFn, ancestors)
388 if err != nil {
389 return err
390 }
391 ancestors = ancestors[:len(ancestors)-1]
392 }
393 }
394 if h, ok := t.handler.(*Router); ok {
395 ancestors = append(ancestors, t)
396 err := h.walk(walkFn, ancestors)
397 if err != nil {
398 return err
399 }
400 ancestors = ancestors[:len(ancestors)-1]
401 }
402 }
403 return nil
404 }
405
406
407
408
409
410
// RouteMatch carries the result of matching a request; Router.Match fills it in.
type RouteMatch struct {
	// Route is the matched route. ServeHTTP stores it in the request
	// context, where CurrentRoute returns it.
	Route *Route
	// Handler is the handler to run. Router.Match may wrap it in the
	// router's middlewares or replace it with NotFoundHandler or
	// MethodNotAllowedHandler.
	Handler http.Handler
	// Vars holds the route variables; ServeHTTP stores them in the request
	// context, where Vars returns them.
	Vars map[string]string

	// MatchErr records why matching did not fully succeed. In this file it
	// is set to or compared against ErrNotFound and ErrMethodMismatch.
	MatchErr error
}
421
// contextKey is an unexported type for this package's request-context keys,
// so they cannot collide with context keys defined in other packages.
type contextKey int

const (
	// varsKey stores the route variables (map[string]string); see Vars.
	varsKey contextKey = iota
	// routeKey stores the matched *Route; see CurrentRoute.
	routeKey
)
428
429
430 func Vars(r *http.Request) map[string]string {
431 if rv := r.Context().Value(varsKey); rv != nil {
432 return rv.(map[string]string)
433 }
434 return nil
435 }
436
437
438
439
440
441 func CurrentRoute(r *http.Request) *Route {
442 if rv := r.Context().Value(routeKey); rv != nil {
443 return rv.(*Route)
444 }
445 return nil
446 }
447
448 func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
449 ctx := context.WithValue(r.Context(), varsKey, vars)
450 return r.WithContext(ctx)
451 }
452
453 func requestWithRoute(r *http.Request, route *Route) *http.Request {
454 ctx := context.WithValue(r.Context(), routeKey, route)
455 return r.WithContext(ctx)
456 }
457
458
459
460
461
462
463
// cleanPath returns the canonical form of the URL path p:
//   - "" becomes "/";
//   - a leading "/" is added when missing;
//   - path.Clean removes "." and ".." elements and repeated slashes.
//
// A trailing slash on the input is kept (path.Clean would drop it) unless
// the result is just "/".
func cleanPath(p string) string {
	if len(p) == 0 {
		return "/"
	}

	rooted := p
	if rooted[0] != '/' {
		rooted = "/" + rooted
	}

	cleaned := path.Clean(rooted)
	if cleaned == "/" || rooted[len(rooted)-1] != '/' {
		return cleaned
	}
	return cleaned + "/"
}
480
481
// uniqueVars returns an error naming the first variable of s1 (in s1's
// order) that also appears in s2, or nil when the two share no names.
func uniqueVars(s1, s2 []string) error {
	for _, name := range s1 {
		for _, other := range s2 {
			if name != other {
				continue
			}
			return fmt.Errorf("mux: duplicated route variable %q", other)
		}
	}
	return nil
}
492
493
494
// checkPairs returns len(pairs) and, when that length is odd, an error that
// lists the arguments. The length is returned in both cases.
func checkPairs(pairs ...string) (int, error) {
	n := len(pairs)
	var err error
	if n%2 == 1 {
		err = fmt.Errorf(
			"mux: number of parameters must be multiple of 2, got %v", pairs)
	}
	return n, err
}
503
504
505
506 func mapFromPairsToString(pairs ...string) (map[string]string, error) {
507 length, err := checkPairs(pairs...)
508 if err != nil {
509 return nil, err
510 }
511 m := make(map[string]string, length/2)
512 for i := 0; i < length; i += 2 {
513 m[pairs[i]] = pairs[i+1]
514 }
515 return m, nil
516 }
517
518
519
520 func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
521 length, err := checkPairs(pairs...)
522 if err != nil {
523 return nil, err
524 }
525 m := make(map[string]*regexp.Regexp, length/2)
526 for i := 0; i < length; i += 2 {
527 regex, err := regexp.Compile(pairs[i+1])
528 if err != nil {
529 return nil, err
530 }
531 m[pairs[i]] = regex
532 }
533 return m, nil
534 }
535
536
// matchInArray reports whether value is an element of arr (exact,
// case-sensitive comparison).
func matchInArray(arr []string, value string) bool {
	for i := range arr {
		if arr[i] == value {
			return true
		}
	}
	return false
}
545
546
// matchMapWithString reports whether every key of toCheck is present in
// toMatch and, for keys whose expected value is non-empty, whether that
// value is one of toMatch's values for the key.
//
// An empty expected value only requires the key to map to a non-nil slice.
// When canonicalKey is true, toCheck's keys are passed through
// http.CanonicalHeaderKey before the lookup; toMatch's keys are used as-is.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
	for key, want := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}

		got := toMatch[key]
		switch {
		case got == nil:
			return false
		case want == "":
			// Presence of the key is all that is required.
		default:
			matched := false
			for _, v := range got {
				if v == want {
					matched = true
					break
				}
			}
			if !matched {
				return false
			}
		}
	}
	return true
}
572
573
574
// matchMapWithRegex reports whether every key of toCheck is present in
// toMatch and, for keys with a non-nil regexp, whether at least one of
// toMatch's values for the key matches it.
//
// A nil regexp only requires the key to map to a non-nil slice. When
// canonicalKey is true, toCheck's keys are passed through
// http.CanonicalHeaderKey before the lookup; toMatch's keys are used as-is.
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
	for key, re := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}

		got := toMatch[key]
		switch {
		case got == nil:
			return false
		case re == nil:
			// Presence of the key is all that is required.
		default:
			matched := false
			for _, v := range got {
				if re.MatchString(v) {
					matched = true
					break
				}
			}
			if !matched {
				return false
			}
		}
	}
	return true
}
600
601
// methodNotAllowed replies with status 405 Method Not Allowed and an empty
// body; the request itself is not inspected.
func methodNotAllowed(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusMethodNotAllowed)
}
605
606
607
608 func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) }
609
View as plain text