...

Source file src/github.com/go-openapi/runtime/middleware/router.go

Documentation: github.com/go-openapi/runtime/middleware

     1  // Copyright 2015 go-swagger maintainers
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package middleware
    16  
import (
	"fmt"
	"net/http"
	"net/url"
	fpath "path"
	"regexp"
	"sort"
	"strings"

	"github.com/go-openapi/runtime/logger"
	"github.com/go-openapi/runtime/security"
	"github.com/go-openapi/swag"

	"github.com/go-openapi/analysis"
	"github.com/go-openapi/errors"
	"github.com/go-openapi/loads"
	"github.com/go-openapi/spec"
	"github.com/go-openapi/strfmt"

	"github.com/go-openapi/runtime"
	"github.com/go-openapi/runtime/middleware/denco"
)
    38  
    39  // RouteParam is a object to capture route params in a framework agnostic way.
    40  // implementations of the muxer should use these route params to communicate with the
    41  // swagger framework
    42  type RouteParam struct {
    43  	Name  string
    44  	Value string
    45  }
    46  
    47  // RouteParams the collection of route params
    48  type RouteParams []RouteParam
    49  
    50  // Get gets the value for the route param for the specified key
    51  func (r RouteParams) Get(name string) string {
    52  	vv, _, _ := r.GetOK(name)
    53  	if len(vv) > 0 {
    54  		return vv[len(vv)-1]
    55  	}
    56  	return ""
    57  }
    58  
    59  // GetOK gets the value but also returns booleans to indicate if a key or value
    60  // is present. This aids in validation and satisfies an interface in use there
    61  //
    62  // The returned values are: data, has key, has value
    63  func (r RouteParams) GetOK(name string) ([]string, bool, bool) {
    64  	for _, p := range r {
    65  		if p.Name == name {
    66  			return []string{p.Value}, true, p.Value != ""
    67  		}
    68  	}
    69  	return nil, false, false
    70  }
    71  
    72  // NewRouter creates a new context-aware router middleware
    73  func NewRouter(ctx *Context, next http.Handler) http.Handler {
    74  	if ctx.router == nil {
    75  		ctx.router = DefaultRouter(ctx.spec, ctx.api, WithDefaultRouterLoggerFunc(ctx.debugLogf))
    76  	}
    77  
    78  	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
    79  		if _, rCtx, ok := ctx.RouteInfo(r); ok {
    80  			next.ServeHTTP(rw, rCtx)
    81  			return
    82  		}
    83  
    84  		// Not found, check if it exists in the other methods first
    85  		if others := ctx.AllowedMethods(r); len(others) > 0 {
    86  			ctx.Respond(rw, r, ctx.analyzer.RequiredProduces(), nil, errors.MethodNotAllowed(r.Method, others))
    87  			return
    88  		}
    89  
    90  		ctx.Respond(rw, r, ctx.analyzer.RequiredProduces(), nil, errors.NotFound("path %s was not found", r.URL.EscapedPath()))
    91  	})
    92  }
    93  
    94  // RoutableAPI represents an interface for things that can serve
    95  // as a provider of implementations for the swagger router
    96  type RoutableAPI interface {
    97  	HandlerFor(string, string) (http.Handler, bool)
    98  	ServeErrorFor(string) func(http.ResponseWriter, *http.Request, error)
    99  	ConsumersFor([]string) map[string]runtime.Consumer
   100  	ProducersFor([]string) map[string]runtime.Producer
   101  	AuthenticatorsFor(map[string]spec.SecurityScheme) map[string]runtime.Authenticator
   102  	Authorizer() runtime.Authorizer
   103  	Formats() strfmt.Registry
   104  	DefaultProduces() string
   105  	DefaultConsumes() string
   106  }
   107  
   108  // Router represents a swagger-aware router
   109  type Router interface {
   110  	Lookup(method, path string) (*MatchedRoute, bool)
   111  	OtherMethods(method, path string) []string
   112  }
   113  
   114  type defaultRouteBuilder struct {
   115  	spec      *loads.Document
   116  	analyzer  *analysis.Spec
   117  	api       RoutableAPI
   118  	records   map[string][]denco.Record
   119  	debugLogf func(string, ...any) // a logging function to debug context and all components using it
   120  }
   121  
   122  type defaultRouter struct {
   123  	spec      *loads.Document
   124  	routers   map[string]*denco.Router
   125  	debugLogf func(string, ...any) // a logging function to debug context and all components using it
   126  }
   127  
   128  func newDefaultRouteBuilder(spec *loads.Document, api RoutableAPI, opts ...DefaultRouterOpt) *defaultRouteBuilder {
   129  	var o defaultRouterOpts
   130  	for _, apply := range opts {
   131  		apply(&o)
   132  	}
   133  	if o.debugLogf == nil {
   134  		o.debugLogf = debugLogfFunc(nil) // defaults to standard logger
   135  	}
   136  
   137  	return &defaultRouteBuilder{
   138  		spec:      spec,
   139  		analyzer:  analysis.New(spec.Spec()),
   140  		api:       api,
   141  		records:   make(map[string][]denco.Record),
   142  		debugLogf: o.debugLogf,
   143  	}
   144  }
   145  
   146  // DefaultRouterOpt allows to inject optional behavior to the default router.
   147  type DefaultRouterOpt func(*defaultRouterOpts)
   148  
   149  type defaultRouterOpts struct {
   150  	debugLogf func(string, ...any)
   151  }
   152  
   153  // WithDefaultRouterLogger sets the debug logger for the default router.
   154  //
   155  // This is enabled only in DEBUG mode.
   156  func WithDefaultRouterLogger(lg logger.Logger) DefaultRouterOpt {
   157  	return func(o *defaultRouterOpts) {
   158  		o.debugLogf = debugLogfFunc(lg)
   159  	}
   160  }
   161  
   162  // WithDefaultRouterLoggerFunc sets a logging debug method for the default router.
   163  func WithDefaultRouterLoggerFunc(fn func(string, ...any)) DefaultRouterOpt {
   164  	return func(o *defaultRouterOpts) {
   165  		o.debugLogf = fn
   166  	}
   167  }
   168  
   169  // DefaultRouter creates a default implementation of the router
   170  func DefaultRouter(spec *loads.Document, api RoutableAPI, opts ...DefaultRouterOpt) Router {
   171  	builder := newDefaultRouteBuilder(spec, api, opts...)
   172  	if spec != nil {
   173  		for method, paths := range builder.analyzer.Operations() {
   174  			for path, operation := range paths {
   175  				fp := fpath.Join(spec.BasePath(), path)
   176  				builder.debugLogf("adding route %s %s %q", method, fp, operation.ID)
   177  				builder.AddRoute(method, fp, operation)
   178  			}
   179  		}
   180  	}
   181  	return builder.Build()
   182  }
   183  
   184  // RouteAuthenticator is an authenticator that can compose several authenticators together.
   185  // It also knows when it contains an authenticator that allows for anonymous pass through.
   186  // Contains a group of 1 or more authenticators that have a logical AND relationship
   187  type RouteAuthenticator struct {
   188  	Authenticator  map[string]runtime.Authenticator
   189  	Schemes        []string
   190  	Scopes         map[string][]string
   191  	allScopes      []string
   192  	commonScopes   []string
   193  	allowAnonymous bool
   194  }
   195  
   196  func (ra *RouteAuthenticator) AllowsAnonymous() bool {
   197  	return ra.allowAnonymous
   198  }
   199  
   200  // AllScopes returns a list of unique scopes that is the combination
   201  // of all the scopes in the requirements
   202  func (ra *RouteAuthenticator) AllScopes() []string {
   203  	return ra.allScopes
   204  }
   205  
   206  // CommonScopes returns a list of unique scopes that are common in all the
   207  // scopes in the requirements
   208  func (ra *RouteAuthenticator) CommonScopes() []string {
   209  	return ra.commonScopes
   210  }
   211  
   212  // Authenticate Authenticator interface implementation
   213  func (ra *RouteAuthenticator) Authenticate(req *http.Request, route *MatchedRoute) (bool, interface{}, error) {
   214  	if ra.allowAnonymous {
   215  		route.Authenticator = ra
   216  		return true, nil, nil
   217  	}
   218  	// iterate in proper order
   219  	var lastResult interface{}
   220  	for _, scheme := range ra.Schemes {
   221  		if authenticator, ok := ra.Authenticator[scheme]; ok {
   222  			applies, princ, err := authenticator.Authenticate(&security.ScopedAuthRequest{
   223  				Request:        req,
   224  				RequiredScopes: ra.Scopes[scheme],
   225  			})
   226  			if !applies {
   227  				return false, nil, nil
   228  			}
   229  			if err != nil {
   230  				route.Authenticator = ra
   231  				return true, nil, err
   232  			}
   233  			lastResult = princ
   234  		}
   235  	}
   236  	route.Authenticator = ra
   237  	return true, lastResult, nil
   238  }
   239  
   240  func stringSliceUnion(slices ...[]string) []string {
   241  	unique := make(map[string]struct{})
   242  	var result []string
   243  	for _, slice := range slices {
   244  		for _, entry := range slice {
   245  			if _, ok := unique[entry]; ok {
   246  				continue
   247  			}
   248  			unique[entry] = struct{}{}
   249  			result = append(result, entry)
   250  		}
   251  	}
   252  	return result
   253  }
   254  
   255  func stringSliceIntersection(slices ...[]string) []string {
   256  	unique := make(map[string]int)
   257  	var intersection []string
   258  
   259  	total := len(slices)
   260  	var emptyCnt int
   261  	for _, slice := range slices {
   262  		if len(slice) == 0 {
   263  			emptyCnt++
   264  			continue
   265  		}
   266  
   267  		for _, entry := range slice {
   268  			unique[entry]++
   269  			if unique[entry] == total-emptyCnt { // this entry appeared in all the non-empty slices
   270  				intersection = append(intersection, entry)
   271  			}
   272  		}
   273  	}
   274  
   275  	return intersection
   276  }
   277  
   278  // RouteAuthenticators represents a group of authenticators that represent a logical OR
   279  type RouteAuthenticators []RouteAuthenticator
   280  
   281  // AllowsAnonymous returns true when there is an authenticator that means optional auth
   282  func (ras RouteAuthenticators) AllowsAnonymous() bool {
   283  	for _, ra := range ras {
   284  		if ra.AllowsAnonymous() {
   285  			return true
   286  		}
   287  	}
   288  	return false
   289  }
   290  
   291  // Authenticate method implemention so this collection can be used as authenticator
   292  func (ras RouteAuthenticators) Authenticate(req *http.Request, route *MatchedRoute) (bool, interface{}, error) {
   293  	var lastError error
   294  	var allowsAnon bool
   295  	var anonAuth RouteAuthenticator
   296  
   297  	for _, ra := range ras {
   298  		if ra.AllowsAnonymous() {
   299  			anonAuth = ra
   300  			allowsAnon = true
   301  			continue
   302  		}
   303  		applies, usr, err := ra.Authenticate(req, route)
   304  		if !applies || err != nil || usr == nil {
   305  			if err != nil {
   306  				lastError = err
   307  			}
   308  			continue
   309  		}
   310  		return applies, usr, nil
   311  	}
   312  
   313  	if allowsAnon && lastError == nil {
   314  		route.Authenticator = &anonAuth
   315  		return true, nil, lastError
   316  	}
   317  	return lastError != nil, nil, lastError
   318  }
   319  
   320  type routeEntry struct {
   321  	PathPattern    string
   322  	BasePath       string
   323  	Operation      *spec.Operation
   324  	Consumes       []string
   325  	Consumers      map[string]runtime.Consumer
   326  	Produces       []string
   327  	Producers      map[string]runtime.Producer
   328  	Parameters     map[string]spec.Parameter
   329  	Handler        http.Handler
   330  	Formats        strfmt.Registry
   331  	Binder         *UntypedRequestBinder
   332  	Authenticators RouteAuthenticators
   333  	Authorizer     runtime.Authorizer
   334  }
   335  
   336  // MatchedRoute represents the route that was matched in this request
   337  type MatchedRoute struct {
   338  	routeEntry
   339  	Params        RouteParams
   340  	Consumer      runtime.Consumer
   341  	Producer      runtime.Producer
   342  	Authenticator *RouteAuthenticator
   343  }
   344  
   345  // HasAuth returns true when the route has a security requirement defined
   346  func (m *MatchedRoute) HasAuth() bool {
   347  	return len(m.Authenticators) > 0
   348  }
   349  
   350  // NeedsAuth returns true when the request still
   351  // needs to perform authentication
   352  func (m *MatchedRoute) NeedsAuth() bool {
   353  	return m.HasAuth() && m.Authenticator == nil
   354  }
   355  
   356  func (d *defaultRouter) Lookup(method, path string) (*MatchedRoute, bool) {
   357  	mth := strings.ToUpper(method)
   358  	d.debugLogf("looking up route for %s %s", method, path)
   359  	if Debug {
   360  		if len(d.routers) == 0 {
   361  			d.debugLogf("there are no known routers")
   362  		}
   363  		for meth := range d.routers {
   364  			d.debugLogf("got a router for %s", meth)
   365  		}
   366  	}
   367  	if router, ok := d.routers[mth]; ok {
   368  		if m, rp, ok := router.Lookup(fpath.Clean(path)); ok && m != nil {
   369  			if entry, ok := m.(*routeEntry); ok {
   370  				d.debugLogf("found a route for %s %s with %d parameters", method, path, len(entry.Parameters))
   371  				var params RouteParams
   372  				for _, p := range rp {
   373  					v, err := url.PathUnescape(p.Value)
   374  					if err != nil {
   375  						d.debugLogf("failed to escape %q: %v", p.Value, err)
   376  						v = p.Value
   377  					}
   378  					// a workaround to handle fragment/composing parameters until they are supported in denco router
   379  					// check if this parameter is a fragment within a path segment
   380  					if xpos := strings.Index(entry.PathPattern, fmt.Sprintf("{%s}", p.Name)) + len(p.Name) + 2; xpos < len(entry.PathPattern) && entry.PathPattern[xpos] != '/' {
   381  						// extract fragment parameters
   382  						ep := strings.Split(entry.PathPattern[xpos:], "/")[0]
   383  						pnames, pvalues := decodeCompositParams(p.Name, v, ep, nil, nil)
   384  						for i, pname := range pnames {
   385  							params = append(params, RouteParam{Name: pname, Value: pvalues[i]})
   386  						}
   387  					} else {
   388  						// use the parameter directly
   389  						params = append(params, RouteParam{Name: p.Name, Value: v})
   390  					}
   391  				}
   392  				return &MatchedRoute{routeEntry: *entry, Params: params}, true
   393  			}
   394  		} else {
   395  			d.debugLogf("couldn't find a route by path for %s %s", method, path)
   396  		}
   397  	} else {
   398  		d.debugLogf("couldn't find a route by method for %s %s", method, path)
   399  	}
   400  	return nil, false
   401  }
   402  
   403  func (d *defaultRouter) OtherMethods(method, path string) []string {
   404  	mn := strings.ToUpper(method)
   405  	var methods []string
   406  	for k, v := range d.routers {
   407  		if k != mn {
   408  			if _, _, ok := v.Lookup(fpath.Clean(path)); ok {
   409  				methods = append(methods, k)
   410  				continue
   411  			}
   412  		}
   413  	}
   414  	return methods
   415  }
   416  
   417  func (d *defaultRouter) SetLogger(lg logger.Logger) {
   418  	d.debugLogf = debugLogfFunc(lg)
   419  }
   420  
   421  // convert swagger parameters per path segment into a denco parameter as multiple parameters per segment are not supported in denco
   422  var pathConverter = regexp.MustCompile(`{(.+?)}([^/]*)`)
   423  
   424  func decodeCompositParams(name string, value string, pattern string, names []string, values []string) ([]string, []string) {
   425  	pleft := strings.Index(pattern, "{")
   426  	names = append(names, name)
   427  	if pleft < 0 {
   428  		if strings.HasSuffix(value, pattern) {
   429  			values = append(values, value[:len(value)-len(pattern)])
   430  		} else {
   431  			values = append(values, "")
   432  		}
   433  	} else {
   434  		toskip := pattern[:pleft]
   435  		pright := strings.Index(pattern, "}")
   436  		vright := strings.Index(value, toskip)
   437  		if vright >= 0 {
   438  			values = append(values, value[:vright])
   439  		} else {
   440  			values = append(values, "")
   441  			value = ""
   442  		}
   443  		return decodeCompositParams(pattern[pleft+1:pright], value[vright+len(toskip):], pattern[pright+1:], names, values)
   444  	}
   445  	return names, values
   446  }
   447  
   448  func (d *defaultRouteBuilder) AddRoute(method, path string, operation *spec.Operation) {
   449  	mn := strings.ToUpper(method)
   450  
   451  	bp := fpath.Clean(d.spec.BasePath())
   452  	if len(bp) > 0 && bp[len(bp)-1] == '/' {
   453  		bp = bp[:len(bp)-1]
   454  	}
   455  
   456  	d.debugLogf("operation: %#v", *operation)
   457  	if handler, ok := d.api.HandlerFor(method, strings.TrimPrefix(path, bp)); ok {
   458  		consumes := d.analyzer.ConsumesFor(operation)
   459  		produces := d.analyzer.ProducesFor(operation)
   460  		parameters := d.analyzer.ParamsFor(method, strings.TrimPrefix(path, bp))
   461  
   462  		// add API defaults if not part of the spec
   463  		if defConsumes := d.api.DefaultConsumes(); defConsumes != "" && !swag.ContainsStringsCI(consumes, defConsumes) {
   464  			consumes = append(consumes, defConsumes)
   465  		}
   466  
   467  		if defProduces := d.api.DefaultProduces(); defProduces != "" && !swag.ContainsStringsCI(produces, defProduces) {
   468  			produces = append(produces, defProduces)
   469  		}
   470  
   471  		requestBinder := NewUntypedRequestBinder(parameters, d.spec.Spec(), d.api.Formats())
   472  		requestBinder.setDebugLogf(d.debugLogf)
   473  		record := denco.NewRecord(pathConverter.ReplaceAllString(path, ":$1"), &routeEntry{
   474  			BasePath:       bp,
   475  			PathPattern:    path,
   476  			Operation:      operation,
   477  			Handler:        handler,
   478  			Consumes:       consumes,
   479  			Produces:       produces,
   480  			Consumers:      d.api.ConsumersFor(normalizeOffers(consumes)),
   481  			Producers:      d.api.ProducersFor(normalizeOffers(produces)),
   482  			Parameters:     parameters,
   483  			Formats:        d.api.Formats(),
   484  			Binder:         requestBinder,
   485  			Authenticators: d.buildAuthenticators(operation),
   486  			Authorizer:     d.api.Authorizer(),
   487  		})
   488  		d.records[mn] = append(d.records[mn], record)
   489  	}
   490  }
   491  
   492  func (d *defaultRouteBuilder) buildAuthenticators(operation *spec.Operation) RouteAuthenticators {
   493  	requirements := d.analyzer.SecurityRequirementsFor(operation)
   494  	auths := make([]RouteAuthenticator, 0, len(requirements))
   495  	for _, reqs := range requirements {
   496  		schemes := make([]string, 0, len(reqs))
   497  		scopes := make(map[string][]string, len(reqs))
   498  		scopeSlices := make([][]string, 0, len(reqs))
   499  		for _, req := range reqs {
   500  			schemes = append(schemes, req.Name)
   501  			scopes[req.Name] = req.Scopes
   502  			scopeSlices = append(scopeSlices, req.Scopes)
   503  		}
   504  
   505  		definitions := d.analyzer.SecurityDefinitionsForRequirements(reqs)
   506  		authenticators := d.api.AuthenticatorsFor(definitions)
   507  		auths = append(auths, RouteAuthenticator{
   508  			Authenticator:  authenticators,
   509  			Schemes:        schemes,
   510  			Scopes:         scopes,
   511  			allScopes:      stringSliceUnion(scopeSlices...),
   512  			commonScopes:   stringSliceIntersection(scopeSlices...),
   513  			allowAnonymous: len(reqs) == 1 && reqs[0].Name == "",
   514  		})
   515  	}
   516  	return auths
   517  }
   518  
   519  func (d *defaultRouteBuilder) Build() *defaultRouter {
   520  	routers := make(map[string]*denco.Router)
   521  	for method, records := range d.records {
   522  		router := denco.New()
   523  		_ = router.Build(records)
   524  		routers[method] = router
   525  	}
   526  	return &defaultRouter{
   527  		spec:      d.spec,
   528  		routers:   routers,
   529  		debugLogf: d.debugLogf,
   530  	}
   531  }
   532  

View as plain text