...

Source file src/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/mux.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/v2/runtime

     1  package runtime
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"net/textproto"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
    13  	"google.golang.org/grpc/codes"
    14  	"google.golang.org/grpc/grpclog"
    15  	"google.golang.org/grpc/health/grpc_health_v1"
    16  	"google.golang.org/grpc/metadata"
    17  	"google.golang.org/grpc/status"
    18  	"google.golang.org/protobuf/proto"
    19  )
    20  
// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int

const (
	// UnescapingModeLegacy is the default V2 behavior, which routes on the
	// request's already-unescaped path (r.URL.Path): the entire path string is
	// unescaped before any routing is done.
	UnescapingModeLegacy UnescapingMode = iota

	// UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570
	// reserved characters.
	UnescapingModeAllExceptReserved

	// UnescapingModeAllExceptSlash unescapes URL path parameters except path
	// separators, which will be left as "%2F".
	UnescapingModeAllExceptSlash

	// UnescapingModeAllCharacters unescapes all URL path parameters. In this mode
	// ServeHTTP also treats an encoded slash in the raw path as a segment
	// separator.
	UnescapingModeAllCharacters

	// UnescapingModeDefault is the default escaping type.
	// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
	// reference implementation
	UnescapingModeDefault = UnescapingModeLegacy
)
    45  
    46  var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
    47  
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
//
// pathParams holds the values captured by the matched pattern's variables, as
// produced by the pattern's MatchAndEscape during ServeHTTP.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
    50  
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
//
// Create one with NewServeMux. The zero value is not usable: Handle writes to
// the handlers map, which only NewServeMux allocates.
type ServeMux struct {
	// handlers maps HTTP method to a list of handlers. Handle prepends, so the
	// most recently registered handler for a method is tried first.
	handlers                  map[string][]handler
	// forwardResponseOptions are appended by WithForwardResponseOption.
	forwardResponseOptions    []func(context.Context, http.ResponseWriter, proto.Message) error
	marshalers                marshalerRegistry
	// The three header matchers are set to package defaults by NewServeMux
	// when no option provided them.
	incomingHeaderMatcher     HeaderMatcherFunc
	outgoingHeaderMatcher     HeaderMatcherFunc
	outgoingTrailerMatcher    HeaderMatcherFunc
	// metadataAnnotators are appended by WithMetadata.
	metadataAnnotators        []func(context.Context, *http.Request) metadata.MD
	errorHandler              ErrorHandlerFunc
	streamErrorHandler        StreamErrorHandlerFunc
	// routingErrorHandler receives the 400/404/405 statuses raised by ServeHTTP.
	routingErrorHandler       RoutingErrorHandlerFunc
	// disablePathLengthFallback turns off both X-HTTP-Method-Override handling
	// and the POST -> GET fallback in ServeHTTP.
	disablePathLengthFallback bool
	// unescapingMode selects which path ServeHTTP routes on and how path
	// parameters are unescaped.
	unescapingMode            UnescapingMode
}
    68  
// ServeMuxOption is an option that can be given to a ServeMux on construction.
//
// NewServeMux applies options in the order given, after its defaults are set,
// so for single-valued settings a later option overrides an earlier one.
type ServeMuxOption func(*ServeMux)
    71  
    72  // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
    73  //
    74  // forwardResponseOption is an option that will be called on the relevant context.Context,
    75  // http.ResponseWriter, and proto.Message before every forwarded response.
    76  //
    77  // The message may be nil in the case where just a header is being sent.
    78  func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
    79  	return func(serveMux *ServeMux) {
    80  		serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
    81  	}
    82  }
    83  
    84  // WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode
    85  // for more information.
    86  func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
    87  	return func(serveMux *ServeMux) {
    88  		serveMux.unescapingMode = mode
    89  	}
    90  }
    91  
    92  // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
    93  // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
    94  // done with careful consideration.
    95  func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
    96  	return func(serveMux *ServeMux) {
    97  		currentQueryParser = queryParameterParser
    98  	}
    99  }
   100  
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
//
// It receives a header key and returns the (possibly transformed) key to use
// together with true, or false when the header must not be forwarded.
type HeaderMatcherFunc func(string) (string, bool)
   103  
   104  // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
   105  // keys (as specified by the IANA, e.g: Accept, Cookie, Host) to the gRPC metadata with the grpcgateway- prefix. If you want to know which headers are considered permanent, you can view the isPermanentHTTPHeader function.
   106  // HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after removing the prefix 'Grpc-Metadata-'.
   107  // Other headers are not added to the gRPC metadata.
   108  func DefaultHeaderMatcher(key string) (string, bool) {
   109  	switch key = textproto.CanonicalMIMEHeaderKey(key); {
   110  	case isPermanentHTTPHeader(key):
   111  		return MetadataPrefix + key, true
   112  	case strings.HasPrefix(key, MetadataHeaderPrefix):
   113  		return key[len(MetadataHeaderPrefix):], true
   114  	}
   115  	return "", false
   116  }
   117  
   118  func defaultOutgoingHeaderMatcher(key string) (string, bool) {
   119  	return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
   120  }
   121  
   122  func defaultOutgoingTrailerMatcher(key string) (string, bool) {
   123  	return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true
   124  }
   125  
   126  // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
   127  //
   128  // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
   129  // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header.
   130  func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
   131  	for _, header := range fn.matchedMalformedHeaders() {
   132  		grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
   133  	}
   134  
   135  	return func(mux *ServeMux) {
   136  		mux.incomingHeaderMatcher = fn
   137  	}
   138  }
   139  
   140  // matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server.
   141  func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
   142  	if fn == nil {
   143  		return nil
   144  	}
   145  	headers := make([]string, 0)
   146  	for header := range malformedHTTPHeaders {
   147  		out, accept := fn(header)
   148  		if accept && isMalformedHTTPHeader(out) {
   149  			headers = append(headers, out)
   150  		}
   151  	}
   152  	return headers
   153  }
   154  
   155  // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
   156  //
   157  // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
   158  // passed to http response returned from gateway. To transform the header before passing to response,
   159  // matcher should return the modified header.
   160  func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
   161  	return func(mux *ServeMux) {
   162  		mux.outgoingHeaderMatcher = fn
   163  	}
   164  }
   165  
   166  // WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
   167  //
   168  // This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be
   169  // passed to http response returned from gateway. To transform the header before passing to response,
   170  // matcher should return the modified header.
   171  func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
   172  	return func(mux *ServeMux) {
   173  		mux.outgoingTrailerMatcher = fn
   174  	}
   175  }
   176  
   177  // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
   178  //
   179  // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
   180  // is reading token from cookie and adding it in gRPC context.
   181  func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
   182  	return func(serveMux *ServeMux) {
   183  		serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
   184  	}
   185  }
   186  
   187  // WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
   188  //
   189  // This can be used to configure a custom error response.
   190  func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
   191  	return func(serveMux *ServeMux) {
   192  		serveMux.errorHandler = fn
   193  	}
   194  }
   195  
   196  // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
   197  // error handler, which allows for customizing the error trailer for server-streaming
   198  // calls.
   199  //
   200  // For stream errors that occur before any response has been written, the mux's
   201  // ErrorHandler will be invoked. However, once data has been written, the errors must
   202  // be handled differently: they must be included in the response body. The response body's
   203  // final message will include the error details returned by the stream error handler.
   204  func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
   205  	return func(serveMux *ServeMux) {
   206  		serveMux.streamErrorHandler = fn
   207  	}
   208  }
   209  
   210  // WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to  handle http routing errors.
   211  //
   212  // Method called for errors which can happen before gRPC route selected or executed.
   213  // The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
   214  func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
   215  	return func(serveMux *ServeMux) {
   216  		serveMux.routingErrorHandler = fn
   217  	}
   218  }
   219  
   220  // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
   221  func WithDisablePathLengthFallback() ServeMuxOption {
   222  	return func(serveMux *ServeMux) {
   223  		serveMux.disablePathLengthFallback = true
   224  	}
   225  }
   226  
   227  // WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
   228  // When called the handler will forward the request to the upstream grpc service health check (defined in the
   229  // gRPC Health Checking Protocol).
   230  //
   231  // See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how
   232  // to setup the protocol in the grpc server.
   233  //
   234  // If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest.
   235  func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
   236  	return func(s *ServeMux) {
   237  		// error can be ignored since pattern is definitely valid
   238  		_ = s.HandlePath(
   239  			http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string,
   240  			) {
   241  				_, outboundMarshaler := MarshalerForRequest(s, r)
   242  
   243  				resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
   244  					Service: r.URL.Query().Get("service"),
   245  				})
   246  				if err != nil {
   247  					s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
   248  					return
   249  				}
   250  
   251  				w.Header().Set("Content-Type", "application/json")
   252  
   253  				if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING {
   254  					switch resp.GetStatus() {
   255  					case grpc_health_v1.HealthCheckResponse_NOT_SERVING, grpc_health_v1.HealthCheckResponse_UNKNOWN:
   256  						err = status.Error(codes.Unavailable, resp.String())
   257  					case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
   258  						err = status.Error(codes.NotFound, resp.String())
   259  					}
   260  
   261  					s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
   262  					return
   263  				}
   264  
   265  				_ = outboundMarshaler.NewEncoder(w).Encode(resp)
   266  			})
   267  	}
   268  }
   269  
   270  // WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux.
   271  //
   272  // See WithHealthEndpointAt for the general implementation.
   273  func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
   274  	return WithHealthEndpointAt(healthCheckClient, "/healthz")
   275  }
   276  
   277  // NewServeMux returns a new ServeMux whose internal mapping is empty.
   278  func NewServeMux(opts ...ServeMuxOption) *ServeMux {
   279  	serveMux := &ServeMux{
   280  		handlers:               make(map[string][]handler),
   281  		forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
   282  		marshalers:             makeMarshalerMIMERegistry(),
   283  		errorHandler:           DefaultHTTPErrorHandler,
   284  		streamErrorHandler:     DefaultStreamErrorHandler,
   285  		routingErrorHandler:    DefaultRoutingErrorHandler,
   286  		unescapingMode:         UnescapingModeDefault,
   287  	}
   288  
   289  	for _, opt := range opts {
   290  		opt(serveMux)
   291  	}
   292  
   293  	if serveMux.incomingHeaderMatcher == nil {
   294  		serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
   295  	}
   296  	if serveMux.outgoingHeaderMatcher == nil {
   297  		serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
   298  	}
   299  	if serveMux.outgoingTrailerMatcher == nil {
   300  		serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
   301  	}
   302  
   303  	return serveMux
   304  }
   305  
   306  // Handle associates "h" to the pair of HTTP method and path pattern.
   307  func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
   308  	s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
   309  }
   310  
   311  // HandlePath allows users to configure custom path handlers.
   312  // refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
   313  func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
   314  	compiler, err := httprule.Parse(pathPattern)
   315  	if err != nil {
   316  		return fmt.Errorf("parsing path pattern: %w", err)
   317  	}
   318  	tp := compiler.Compile()
   319  	pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
   320  	if err != nil {
   321  		return fmt.Errorf("creating new pattern: %w", err)
   322  	}
   323  	s.Handle(meth, pattern, h)
   324  	return nil
   325  }
   326  
   327  // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path.
   328  func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   329  	ctx := r.Context()
   330  
   331  	path := r.URL.Path
   332  	if !strings.HasPrefix(path, "/") {
   333  		_, outboundMarshaler := MarshalerForRequest(s, r)
   334  		s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
   335  		return
   336  	}
   337  
   338  	// TODO(v3): remove UnescapingModeLegacy
   339  	if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
   340  		path = r.URL.RawPath
   341  	}
   342  
   343  	if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
   344  		if err := r.ParseForm(); err != nil {
   345  			_, outboundMarshaler := MarshalerForRequest(s, r)
   346  			sterr := status.Error(codes.InvalidArgument, err.Error())
   347  			s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
   348  			return
   349  		}
   350  		r.Method = strings.ToUpper(override)
   351  	}
   352  
   353  	var pathComponents []string
   354  	// since in UnescapeModeLegacy, the URL will already have been fully unescaped, if we also split on "%2F"
   355  	// in this escaping mode we would be double unescaping but in UnescapingModeAllCharacters, we still do as the
   356  	// path is the RawPath (i.e. unescaped). That does mean that the behavior of this function will change its default
   357  	// behavior when the UnescapingModeDefault gets changed from UnescapingModeLegacy to UnescapingModeAllExceptReserved
   358  	if s.unescapingMode == UnescapingModeAllCharacters {
   359  		pathComponents = encodedPathSplitter.Split(path[1:], -1)
   360  	} else {
   361  		pathComponents = strings.Split(path[1:], "/")
   362  	}
   363  
   364  	lastPathComponent := pathComponents[len(pathComponents)-1]
   365  
   366  	for _, h := range s.handlers[r.Method] {
   367  		// If the pattern has a verb, explicitly look for a suffix in the last
   368  		// component that matches a colon plus the verb. This allows us to
   369  		// handle some cases that otherwise can't be correctly handled by the
   370  		// former LastIndex case, such as when the verb literal itself contains
   371  		// a colon. This should work for all cases that have run through the
   372  		// parser because we know what verb we're looking for, however, there
   373  		// are still some cases that the parser itself cannot disambiguate. See
   374  		// the comment there if interested.
   375  
   376  		var verb string
   377  		patVerb := h.pat.Verb()
   378  
   379  		idx := -1
   380  		if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
   381  			idx = len(lastPathComponent) - len(patVerb) - 1
   382  		}
   383  		if idx == 0 {
   384  			_, outboundMarshaler := MarshalerForRequest(s, r)
   385  			s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
   386  			return
   387  		}
   388  
   389  		comps := make([]string, len(pathComponents))
   390  		copy(comps, pathComponents)
   391  
   392  		if idx > 0 {
   393  			comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
   394  		}
   395  
   396  		pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
   397  		if err != nil {
   398  			var mse MalformedSequenceError
   399  			if ok := errors.As(err, &mse); ok {
   400  				_, outboundMarshaler := MarshalerForRequest(s, r)
   401  				s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
   402  					HTTPStatus: http.StatusBadRequest,
   403  					Err:        mse,
   404  				})
   405  			}
   406  			continue
   407  		}
   408  		h.h(w, r, pathParams)
   409  		return
   410  	}
   411  
   412  	// if no handler has found for the request, lookup for other methods
   413  	// to handle POST -> GET fallback if the request is subject to path
   414  	// length fallback.
   415  	// Note we are not eagerly checking the request here as we want to return the
   416  	// right HTTP status code, and we need to process the fallback candidates in
   417  	// order to do that.
   418  	for m, handlers := range s.handlers {
   419  		if m == r.Method {
   420  			continue
   421  		}
   422  		for _, h := range handlers {
   423  			var verb string
   424  			patVerb := h.pat.Verb()
   425  
   426  			idx := -1
   427  			if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
   428  				idx = len(lastPathComponent) - len(patVerb) - 1
   429  			}
   430  
   431  			comps := make([]string, len(pathComponents))
   432  			copy(comps, pathComponents)
   433  
   434  			if idx > 0 {
   435  				comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
   436  			}
   437  
   438  			pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
   439  			if err != nil {
   440  				var mse MalformedSequenceError
   441  				if ok := errors.As(err, &mse); ok {
   442  					_, outboundMarshaler := MarshalerForRequest(s, r)
   443  					s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
   444  						HTTPStatus: http.StatusBadRequest,
   445  						Err:        mse,
   446  					})
   447  				}
   448  				continue
   449  			}
   450  
   451  			// X-HTTP-Method-Override is optional. Always allow fallback to POST.
   452  			// Also, only consider POST -> GET fallbacks, and avoid falling back to
   453  			// potentially dangerous operations like DELETE.
   454  			if s.isPathLengthFallback(r) && m == http.MethodGet {
   455  				if err := r.ParseForm(); err != nil {
   456  					_, outboundMarshaler := MarshalerForRequest(s, r)
   457  					sterr := status.Error(codes.InvalidArgument, err.Error())
   458  					s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
   459  					return
   460  				}
   461  				h.h(w, r, pathParams)
   462  				return
   463  			}
   464  			_, outboundMarshaler := MarshalerForRequest(s, r)
   465  			s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
   466  			return
   467  		}
   468  	}
   469  
   470  	_, outboundMarshaler := MarshalerForRequest(s, r)
   471  	s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
   472  }
   473  
   474  // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
   475  func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
   476  	return s.forwardResponseOptions
   477  }
   478  
   479  func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
   480  	return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
   481  }
   482  
// handler pairs a compiled path pattern with the function that serves it.
// ServeMux keeps one ordered list of these per HTTP method.
type handler struct {
	pat Pattern
	h   HandlerFunc
}
   487  

View as plain text