...

Source file src/github.com/go-chi/chi/middleware/route_headers.go

Documentation: github.com/go-chi/chi/middleware

     1  package middleware
     2  
     3  import (
     4  	"net/http"
     5  	"strings"
     6  )
     7  
     8  // RouteHeaders is a neat little header-based router that allows you to direct
     9  // the flow of a request through a middleware stack based on a request header.
    10  //
    11  // For example, lets say you'd like to setup multiple routers depending on the
    12  // request Host header, you could then do something as so:
    13  //
    14  // r := chi.NewRouter()
    15  // rSubdomain := chi.NewRouter()
    16  //
    17  // r.Use(middleware.RouteHeaders().
    18  //   Route("Host", "example.com", middleware.New(r)).
    19  //   Route("Host", "*.example.com", middleware.New(rSubdomain)).
    20  //   Handler)
    21  //
    22  // r.Get("/", h)
    23  // rSubdomain.Get("/", h2)
    24  //
    25  //
    26  // Another example, imagine you want to setup multiple CORS handlers, where for
    27  // your origin servers you allow authorized requests, but for third-party public
    28  // requests, authorization is disabled.
    29  //
    30  // r := chi.NewRouter()
    31  //
    32  // r.Use(middleware.RouteHeaders().
    33  //   Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{
    34  // 	   AllowedOrigins:   []string{"https://api.skyweaver.net"},
    35  // 	   AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
    36  // 	   AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
    37  // 	   AllowCredentials: true, // <----------<<< allow credentials
    38  //   })).
    39  //   Route("Origin", "*", cors.Handler(cors.Options{
    40  // 	   AllowedOrigins:   []string{"*"},
    41  // 	   AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
    42  // 	   AllowedHeaders:   []string{"Accept", "Content-Type"},
    43  // 	   AllowCredentials: false, // <----------<<< do not allow credentials
    44  //   })).
    45  //   Handler)
    46  //
    47  func RouteHeaders() HeaderRouter {
    48  	return HeaderRouter{}
    49  }
    50  
    51  type HeaderRouter map[string][]HeaderRoute
    52  
    53  func (hr HeaderRouter) Route(header string, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
    54  	header = strings.ToLower(header)
    55  	k := hr[header]
    56  	if k == nil {
    57  		hr[header] = []HeaderRoute{}
    58  	}
    59  	hr[header] = append(hr[header], HeaderRoute{MatchOne: NewPattern(match), Middleware: middlewareHandler})
    60  	return hr
    61  }
    62  
    63  func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
    64  	header = strings.ToLower(header)
    65  	k := hr[header]
    66  	if k == nil {
    67  		hr[header] = []HeaderRoute{}
    68  	}
    69  	patterns := []Pattern{}
    70  	for _, m := range match {
    71  		patterns = append(patterns, NewPattern(m))
    72  	}
    73  	hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler})
    74  	return hr
    75  }
    76  
    77  func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
    78  	hr["*"] = []HeaderRoute{{Middleware: handler}}
    79  	return hr
    80  }
    81  
    82  func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
    83  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    84  		if len(hr) == 0 {
    85  			// skip if no routes set
    86  			next.ServeHTTP(w, r)
    87  		}
    88  
    89  		// find first matching header route, and continue
    90  		for header, matchers := range hr {
    91  			headerValue := r.Header.Get(header)
    92  			if headerValue == "" {
    93  				continue
    94  			}
    95  			headerValue = strings.ToLower(headerValue)
    96  			for _, matcher := range matchers {
    97  				if matcher.IsMatch(headerValue) {
    98  					matcher.Middleware(next).ServeHTTP(w, r)
    99  					return
   100  				}
   101  			}
   102  		}
   103  
   104  		// if no match, check for "*" default route
   105  		matcher, ok := hr["*"]
   106  		if !ok || matcher[0].Middleware == nil {
   107  			next.ServeHTTP(w, r)
   108  			return
   109  		}
   110  		matcher[0].Middleware(next).ServeHTTP(w, r)
   111  	})
   112  }
   113  
   114  type HeaderRoute struct {
   115  	MatchAny   []Pattern
   116  	MatchOne   Pattern
   117  	Middleware func(next http.Handler) http.Handler
   118  }
   119  
   120  func (r HeaderRoute) IsMatch(value string) bool {
   121  	if len(r.MatchAny) > 0 {
   122  		for _, m := range r.MatchAny {
   123  			if m.Match(value) {
   124  				return true
   125  			}
   126  		}
   127  	} else if r.MatchOne.Match(value) {
   128  		return true
   129  	}
   130  	return false
   131  }
   132  
   133  type Pattern struct {
   134  	prefix   string
   135  	suffix   string
   136  	wildcard bool
   137  }
   138  
   139  func NewPattern(value string) Pattern {
   140  	p := Pattern{}
   141  	if i := strings.IndexByte(value, '*'); i >= 0 {
   142  		p.wildcard = true
   143  		p.prefix = value[0:i]
   144  		p.suffix = value[i+1:]
   145  	} else {
   146  		p.prefix = value
   147  	}
   148  	return p
   149  }
   150  
   151  func (p Pattern) Match(v string) bool {
   152  	if !p.wildcard {
   153  		if p.prefix == v {
   154  			return true
   155  		} else {
   156  			return false
   157  		}
   158  	}
   159  	return len(v) >= len(p.prefix+p.suffix) && strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix)
   160  }
   161  

View as plain text