1 package chi
2
3 import (
4 "context"
5 "fmt"
6 "net/http"
7 "strings"
8 "sync"
9 )
10
11 var _ Router = &Mux{}
12
13
14
15
16
17
18
19
20
21 type Mux struct {
22
23 tree *node
24
25
26 middlewares []func(http.Handler) http.Handler
27
28
29
30 inline bool
31 parent *Mux
32
33
34
35 handler http.Handler
36
37
38 pool *sync.Pool
39
40
41 notFoundHandler http.HandlerFunc
42
43
44 methodNotAllowedHandler http.HandlerFunc
45 }
46
47
48
49 func NewMux() *Mux {
50 mux := &Mux{tree: &node{}, pool: &sync.Pool{}}
51 mux.pool.New = func() interface{} {
52 return NewRouteContext()
53 }
54 return mux
55 }
56
57
58
59
60 func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
61
62 if mx.handler == nil {
63 mx.NotFoundHandler().ServeHTTP(w, r)
64 return
65 }
66
67
68 rctx, _ := r.Context().Value(RouteCtxKey).(*Context)
69 if rctx != nil {
70 mx.handler.ServeHTTP(w, r)
71 return
72 }
73
74
75
76
77
78 rctx = mx.pool.Get().(*Context)
79 rctx.Reset()
80 rctx.Routes = mx
81
82
83 r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
84
85
86 mx.handler.ServeHTTP(w, r)
87 mx.pool.Put(rctx)
88 }
89
90
91
92
93
94
95
96 func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) {
97 if mx.handler != nil {
98 panic("chi: all middlewares must be defined before routes on a mux")
99 }
100 mx.middlewares = append(mx.middlewares, middlewares...)
101 }
102
103
104
105 func (mx *Mux) Handle(pattern string, handler http.Handler) {
106 mx.handle(mALL, pattern, handler)
107 }
108
109
110
111 func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) {
112 mx.handle(mALL, pattern, handlerFn)
113 }
114
115
116
117 func (mx *Mux) Method(method, pattern string, handler http.Handler) {
118 m, ok := methodMap[strings.ToUpper(method)]
119 if !ok {
120 panic(fmt.Sprintf("chi: '%s' http method is not supported.", method))
121 }
122 mx.handle(m, pattern, handler)
123 }
124
125
126
127 func (mx *Mux) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) {
128 mx.Method(method, pattern, handlerFn)
129 }
130
131
132
133 func (mx *Mux) Connect(pattern string, handlerFn http.HandlerFunc) {
134 mx.handle(mCONNECT, pattern, handlerFn)
135 }
136
137
138
139 func (mx *Mux) Delete(pattern string, handlerFn http.HandlerFunc) {
140 mx.handle(mDELETE, pattern, handlerFn)
141 }
142
143
144
145 func (mx *Mux) Get(pattern string, handlerFn http.HandlerFunc) {
146 mx.handle(mGET, pattern, handlerFn)
147 }
148
149
150
151 func (mx *Mux) Head(pattern string, handlerFn http.HandlerFunc) {
152 mx.handle(mHEAD, pattern, handlerFn)
153 }
154
155
156
157 func (mx *Mux) Options(pattern string, handlerFn http.HandlerFunc) {
158 mx.handle(mOPTIONS, pattern, handlerFn)
159 }
160
161
162
163 func (mx *Mux) Patch(pattern string, handlerFn http.HandlerFunc) {
164 mx.handle(mPATCH, pattern, handlerFn)
165 }
166
167
168
169 func (mx *Mux) Post(pattern string, handlerFn http.HandlerFunc) {
170 mx.handle(mPOST, pattern, handlerFn)
171 }
172
173
174
175 func (mx *Mux) Put(pattern string, handlerFn http.HandlerFunc) {
176 mx.handle(mPUT, pattern, handlerFn)
177 }
178
179
180
181 func (mx *Mux) Trace(pattern string, handlerFn http.HandlerFunc) {
182 mx.handle(mTRACE, pattern, handlerFn)
183 }
184
185
186
187 func (mx *Mux) NotFound(handlerFn http.HandlerFunc) {
188
189 m := mx
190 hFn := handlerFn
191 if mx.inline && mx.parent != nil {
192 m = mx.parent
193 hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP
194 }
195
196
197 m.notFoundHandler = hFn
198 m.updateSubRoutes(func(subMux *Mux) {
199 if subMux.notFoundHandler == nil {
200 subMux.NotFound(hFn)
201 }
202 })
203 }
204
205
206
207 func (mx *Mux) MethodNotAllowed(handlerFn http.HandlerFunc) {
208
209 m := mx
210 hFn := handlerFn
211 if mx.inline && mx.parent != nil {
212 m = mx.parent
213 hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP
214 }
215
216
217 m.methodNotAllowedHandler = hFn
218 m.updateSubRoutes(func(subMux *Mux) {
219 if subMux.methodNotAllowedHandler == nil {
220 subMux.MethodNotAllowed(hFn)
221 }
222 })
223 }
224
225
226 func (mx *Mux) With(middlewares ...func(http.Handler) http.Handler) Router {
227
228
229 if !mx.inline && mx.handler == nil {
230 mx.buildRouteHandler()
231 }
232
233
234 var mws Middlewares
235 if mx.inline {
236 mws = make(Middlewares, len(mx.middlewares))
237 copy(mws, mx.middlewares)
238 }
239 mws = append(mws, middlewares...)
240
241 im := &Mux{
242 pool: mx.pool, inline: true, parent: mx, tree: mx.tree, middlewares: mws,
243 notFoundHandler: mx.notFoundHandler, methodNotAllowedHandler: mx.methodNotAllowedHandler,
244 }
245
246 return im
247 }
248
249
250
251
252 func (mx *Mux) Group(fn func(r Router)) Router {
253 im := mx.With().(*Mux)
254 if fn != nil {
255 fn(im)
256 }
257 return im
258 }
259
260
261
262
263 func (mx *Mux) Route(pattern string, fn func(r Router)) Router {
264 subRouter := NewRouter()
265 if fn != nil {
266 fn(subRouter)
267 }
268 mx.Mount(pattern, subRouter)
269 return subRouter
270 }
271
272
273
274
275
276
277
278
279 func (mx *Mux) Mount(pattern string, handler http.Handler) {
280
281
282 if mx.tree.findPattern(pattern+"*") || mx.tree.findPattern(pattern+"/*") {
283 panic(fmt.Sprintf("chi: attempting to Mount() a handler on an existing path, '%s'", pattern))
284 }
285
286
287 subr, ok := handler.(*Mux)
288 if ok && subr.notFoundHandler == nil && mx.notFoundHandler != nil {
289 subr.NotFound(mx.notFoundHandler)
290 }
291 if ok && subr.methodNotAllowedHandler == nil && mx.methodNotAllowedHandler != nil {
292 subr.MethodNotAllowed(mx.methodNotAllowedHandler)
293 }
294
295 mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
296 rctx := RouteContext(r.Context())
297 rctx.RoutePath = mx.nextRoutePath(rctx)
298 handler.ServeHTTP(w, r)
299 })
300
301 if pattern == "" || pattern[len(pattern)-1] != '/' {
302 mx.handle(mALL|mSTUB, pattern, mountHandler)
303 mx.handle(mALL|mSTUB, pattern+"/", mountHandler)
304 pattern += "/"
305 }
306
307 method := mALL
308 subroutes, _ := handler.(Routes)
309 if subroutes != nil {
310 method |= mSTUB
311 }
312 n := mx.handle(method, pattern+"*", mountHandler)
313
314 if subroutes != nil {
315 n.subroutes = subroutes
316 }
317 }
318
319
320
321 func (mx *Mux) Routes() []Route {
322 return mx.tree.routes()
323 }
324
325
326 func (mx *Mux) Middlewares() Middlewares {
327 return mx.middlewares
328 }
329
330
331
332
333
334
335
336 func (mx *Mux) Match(rctx *Context, method, path string) bool {
337 m, ok := methodMap[method]
338 if !ok {
339 return false
340 }
341
342 node, _, h := mx.tree.FindRoute(rctx, m, path)
343
344 if node != nil && node.subroutes != nil {
345 rctx.RoutePath = mx.nextRoutePath(rctx)
346 return node.subroutes.Match(rctx, method, rctx.RoutePath)
347 }
348
349 return h != nil
350 }
351
352
353
354 func (mx *Mux) NotFoundHandler() http.HandlerFunc {
355 if mx.notFoundHandler != nil {
356 return mx.notFoundHandler
357 }
358 return http.NotFound
359 }
360
361
362
363 func (mx *Mux) MethodNotAllowedHandler() http.HandlerFunc {
364 if mx.methodNotAllowedHandler != nil {
365 return mx.methodNotAllowedHandler
366 }
367 return methodNotAllowedHandler
368 }
369
370
371
372
373
374 func (mx *Mux) buildRouteHandler() {
375 mx.handler = chain(mx.middlewares, http.HandlerFunc(mx.routeHTTP))
376 }
377
378
379
380 func (mx *Mux) handle(method methodTyp, pattern string, handler http.Handler) *node {
381 if len(pattern) == 0 || pattern[0] != '/' {
382 panic(fmt.Sprintf("chi: routing pattern must begin with '/' in '%s'", pattern))
383 }
384
385
386 if !mx.inline && mx.handler == nil {
387 mx.buildRouteHandler()
388 }
389
390
391 var h http.Handler
392 if mx.inline {
393 mx.handler = http.HandlerFunc(mx.routeHTTP)
394 h = Chain(mx.middlewares...).Handler(handler)
395 } else {
396 h = handler
397 }
398
399
400 return mx.tree.InsertRoute(method, pattern, h)
401 }
402
403
404
405 func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) {
406
407 rctx := r.Context().Value(RouteCtxKey).(*Context)
408
409
410 routePath := rctx.RoutePath
411 if routePath == "" {
412 if r.URL.RawPath != "" {
413 routePath = r.URL.RawPath
414 } else {
415 routePath = r.URL.Path
416 }
417 }
418
419
420 if rctx.RouteMethod == "" {
421 rctx.RouteMethod = r.Method
422 }
423 method, ok := methodMap[rctx.RouteMethod]
424 if !ok {
425 mx.MethodNotAllowedHandler().ServeHTTP(w, r)
426 return
427 }
428
429
430 if _, _, h := mx.tree.FindRoute(rctx, method, routePath); h != nil {
431 h.ServeHTTP(w, r)
432 return
433 }
434 if rctx.methodNotAllowed {
435 mx.MethodNotAllowedHandler().ServeHTTP(w, r)
436 } else {
437 mx.NotFoundHandler().ServeHTTP(w, r)
438 }
439 }
440
441 func (mx *Mux) nextRoutePath(rctx *Context) string {
442 routePath := "/"
443 nx := len(rctx.routeParams.Keys) - 1
444 if nx >= 0 && rctx.routeParams.Keys[nx] == "*" && len(rctx.routeParams.Values) > nx {
445 routePath = "/" + rctx.routeParams.Values[nx]
446 }
447 return routePath
448 }
449
450
451 func (mx *Mux) updateSubRoutes(fn func(subMux *Mux)) {
452 for _, r := range mx.tree.routes() {
453 subMux, ok := r.SubRoutes.(*Mux)
454 if !ok {
455 continue
456 }
457 fn(subMux)
458 }
459 }
460
461
462
// methodNotAllowedHandler is the default method-not-allowed responder: it
// writes a 405 status and an empty body.
func methodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusMethodNotAllowed)
	w.Write(nil)
}
467
View as plain text