...
1 package mux
2
3 import (
4 "net/http"
5 "strings"
6 )
7
8
9
10
// MiddlewareFunc is a function that receives an http.Handler and returns
// another http.Handler. Typically the returned handler is a closure that does
// something with the http.ResponseWriter and *http.Request passed to it and
// then calls the handler it was given.
type MiddlewareFunc func(next http.Handler) http.Handler
12
13
// middleware is implemented by anything that can wrap an http.Handler in
// another http.Handler. MiddlewareFunc satisfies it, and the router stores
// its middleware chain as values of this interface.
type middleware interface {
	Middleware(next http.Handler) http.Handler
}
17
18
19 func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
20 return mw(handler)
21 }
22
23
24 func (r *Router) Use(mwf ...MiddlewareFunc) {
25 for _, fn := range mwf {
26 r.middlewares = append(r.middlewares, fn)
27 }
28 }
29
30
31 func (r *Router) useInterface(mw middleware) {
32 r.middlewares = append(r.middlewares, mw)
33 }
34
35
36
37
38
39 func CORSMethodMiddleware(r *Router) MiddlewareFunc {
40 return func(next http.Handler) http.Handler {
41 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
42 allMethods, err := getAllMethodsForRoute(r, req)
43 if err == nil {
44 for _, v := range allMethods {
45 if v == http.MethodOptions {
46 w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
47 }
48 }
49 }
50
51 next.ServeHTTP(w, req)
52 })
53 }
54 }
55
56
57
58 func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
59 var allMethods []string
60
61 for _, route := range r.routes {
62 var match RouteMatch
63 if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch {
64 methods, err := route.GetMethods()
65 if err != nil {
66 return nil, err
67 }
68
69 allMethods = append(allMethods, methods...)
70 }
71 }
72
73 return allMethods, nil
74 }
75
View as plain text