...

Source file src/github.com/grpc-ecosystem/grpc-gateway/runtime/mux.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/runtime

     1  package runtime
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"net/textproto"
     8  	"strings"
     9  
    10  	"github.com/golang/protobuf/proto"
    11  	"google.golang.org/grpc/codes"
    12  	"google.golang.org/grpc/metadata"
    13  	"google.golang.org/grpc/status"
    14  )
    15  
    16  // A HandlerFunc handles a specific pair of path pattern and HTTP method.
        //
        // ServeMux.ServeHTTP invokes it with the original response writer and request,
        // plus pathParams, the map returned by the matching Pattern's Match call.
    17  type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
    18  
    19  // ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when
    20  // a request is received with a URI path that does not match any registered
    21  // service method.
    22  //
    23  // Since gRPC servers return an "Unimplemented" code for requests with an
    24  // unrecognized URI path, this error also has a gRPC "Unimplemented" code.
        //
        // ServeMux.ServeHTTP also passes this error to the ProtoErrorHandlerFunc when
        // the last path component starts with ":" and when the path only matches a
        // pattern registered under a different HTTP method.
    25  var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
    26  
    27  // ServeMux is a request multiplexer for grpc-gateway.
    28  // It matches http requests to patterns and invokes the corresponding handler.
        //
        // Construct it with NewServeMux, which allocates the handler map and fills
        // in the default header matchers and stream error handler.
    29  type ServeMux struct {
    30  	// handlers maps HTTP method to a list of handlers.
        	// Handle appends to the list, or prepends when lastMatchWins is set.
    31  	handlers                  map[string][]handler
        	// forwardResponseOptions collects the callbacks registered through
        	// WithForwardResponseOption, in registration order.
    32  	forwardResponseOptions    []func(context.Context, http.ResponseWriter, proto.Message) error
        	// marshalers is initialised by NewServeMux via makeMarshalerMIMERegistry.
    33  	marshalers                marshalerRegistry
        	// incomingHeaderMatcher is set by WithIncomingHeaderMatcher; NewServeMux
        	// falls back to DefaultHeaderMatcher when it is left nil.
    34  	incomingHeaderMatcher     HeaderMatcherFunc
        	// outgoingHeaderMatcher is set by WithOutgoingHeaderMatcher; NewServeMux
        	// falls back to a matcher that prepends MetadataHeaderPrefix to every key.
    35  	outgoingHeaderMatcher     HeaderMatcherFunc
        	// metadataAnnotators collects the annotators registered through WithMetadata.
    36  	metadataAnnotators        []func(context.Context, *http.Request) metadata.MD
        	// streamErrorHandler defaults to DefaultHTTPStreamErrorHandler and can be
        	// replaced with WithStreamErrorHandler.
    37  	streamErrorHandler        StreamErrorHandlerFunc
        	// protoErrorHandler, when non-nil, receives the routing errors raised by
        	// ServeHTTP instead of OtherErrorHandler. Set by WithProtoErrorHandler.
    38  	protoErrorHandler         ProtoErrorHandlerFunc
        	// disablePathLengthFallback makes isPathLengthFallback always report false.
    39  	disablePathLengthFallback bool
        	// lastMatchWins makes Handle insert new handlers at the front of the list.
    40  	lastMatchWins             bool
    41  }
    42  
    43  // ServeMuxOption is an option that can be given to a ServeMux on construction.
        //
        // NewServeMux applies the options in the order they are passed, before it
        // fills in defaults for the header matchers that are still unset.
    44  type ServeMuxOption func(*ServeMux)
    45  
    46  // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
    47  //
    48  // forwardResponseOption is an option that will be called on the relevant context.Context,
    49  // http.ResponseWriter, and proto.Message before every forwarded response.
    50  //
    51  // The message may be nil in the case where just a header is being sent.
    52  func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
    53  	return func(serveMux *ServeMux) {
    54  		serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
    55  	}
    56  }
    57  
    58  // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
    59  // Configuring this will mean the generated swagger output is no longer correct, and it should be
    60  // done with careful consideration.
    61  func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
    62  	return func(serveMux *ServeMux) {
    63  		currentQueryParser = queryParameterParser
    64  	}
    65  }
    66  
    67  // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
        //
        // It receives the header key and returns the key to forward under (which may
        // differ from the input) and whether the header should be forwarded at all.
    68  type HeaderMatcherFunc func(string) (string, bool)
    69  
    70  // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
    71  // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
    72  // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
    73  func DefaultHeaderMatcher(key string) (string, bool) {
    74  	key = textproto.CanonicalMIMEHeaderKey(key)
    75  	if isPermanentHTTPHeader(key) {
    76  		return MetadataPrefix + key, true
    77  	} else if strings.HasPrefix(key, MetadataHeaderPrefix) {
    78  		return key[len(MetadataHeaderPrefix):], true
    79  	}
    80  	return "", false
    81  }
    82  
    83  // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
    84  //
    85  // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
    86  // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
    87  func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
    88  	return func(mux *ServeMux) {
    89  		mux.incomingHeaderMatcher = fn
    90  	}
    91  }
    92  
    93  // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
    94  //
    95  // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
    96  // passed to http response returned from gateway. To transform the header before passing to response,
    97  // matcher should return modified header.
    98  func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
    99  	return func(mux *ServeMux) {
   100  		mux.outgoingHeaderMatcher = fn
   101  	}
   102  }
   103  
   104  // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
   105  //
   106  // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
   107  // is reading token from cookie and adding it in gRPC context.
   108  func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
   109  	return func(serveMux *ServeMux) {
   110  		serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
   111  	}
   112  }
   113  
   114  // WithProtoErrorHandler returns a ServeMuxOption for configuring a custom error handler.
   115  //
   116  // This can be used to handle an error as general proto message defined by gRPC.
   117  // When this option is used, the mux uses the configured error handler instead of HTTPError and
   118  // OtherErrorHandler.
   119  func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
   120  	return func(serveMux *ServeMux) {
   121  		serveMux.protoErrorHandler = fn
   122  	}
   123  }
   124  
   125  // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
   126  func WithDisablePathLengthFallback() ServeMuxOption {
   127  	return func(serveMux *ServeMux) {
   128  		serveMux.disablePathLengthFallback = true
   129  	}
   130  }
   131  
   132  // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
   133  // error handler, which allows for customizing the error trailer for server-streaming
   134  // calls.
   135  //
   136  // For stream errors that occur before any response has been written, the mux's
   137  // ProtoErrorHandler will be invoked. However, once data has been written, the errors must
   138  // be handled differently: they must be included in the response body. The response body's
   139  // final message will include the error details returned by the stream error handler.
   140  func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
   141  	return func(serveMux *ServeMux) {
   142  		serveMux.streamErrorHandler = fn
   143  	}
   144  }
   145  
   146  // WithLastMatchWins returns a ServeMuxOption that will enable "last
   147  // match wins" behavior, where if multiple path patterns match a
   148  // request path, the last one defined in the .proto file will be used.
   149  func WithLastMatchWins() ServeMuxOption {
   150  	return func(serveMux *ServeMux) {
   151  		serveMux.lastMatchWins = true
   152  	}
   153  }
   154  
   155  // NewServeMux returns a new ServeMux whose internal mapping is empty.
   156  func NewServeMux(opts ...ServeMuxOption) *ServeMux {
   157  	serveMux := &ServeMux{
   158  		handlers:               make(map[string][]handler),
   159  		forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
   160  		marshalers:             makeMarshalerMIMERegistry(),
   161  		streamErrorHandler:     DefaultHTTPStreamErrorHandler,
   162  	}
   163  
   164  	for _, opt := range opts {
   165  		opt(serveMux)
   166  	}
   167  
   168  	if serveMux.incomingHeaderMatcher == nil {
   169  		serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
   170  	}
   171  
   172  	if serveMux.outgoingHeaderMatcher == nil {
   173  		serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
   174  			return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
   175  		}
   176  	}
   177  
   178  	return serveMux
   179  }
   180  
   181  // Handle associates "h" to the pair of HTTP method and path pattern.
   182  func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
   183  	if s.lastMatchWins {
   184  		s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...)
   185  	} else {
   186  		s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
   187  	}
   188  }
   189  
   190  // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
   191  func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   192  	ctx := r.Context()
   193  
   194  	path := r.URL.Path
   195  	if !strings.HasPrefix(path, "/") {
   196  		if s.protoErrorHandler != nil {
   197  			_, outboundMarshaler := MarshalerForRequest(s, r)
   198  			sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
   199  			s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
   200  		} else {
   201  			OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   202  		}
   203  		return
   204  	}
   205  
   206  	components := strings.Split(path[1:], "/")
   207  	l := len(components)
   208  	var verb string
   209  	if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
   210  		if s.protoErrorHandler != nil {
   211  			_, outboundMarshaler := MarshalerForRequest(s, r)
   212  			s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
   213  		} else {
   214  			OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
   215  		}
   216  		return
   217  	} else if idx > 0 {
   218  		c := components[l-1]
   219  		components[l-1], verb = c[:idx], c[idx+1:]
   220  	}
   221  
   222  	if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
   223  		r.Method = strings.ToUpper(override)
   224  		if err := r.ParseForm(); err != nil {
   225  			if s.protoErrorHandler != nil {
   226  				_, outboundMarshaler := MarshalerForRequest(s, r)
   227  				sterr := status.Error(codes.InvalidArgument, err.Error())
   228  				s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
   229  			} else {
   230  				OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
   231  			}
   232  			return
   233  		}
   234  	}
   235  	for _, h := range s.handlers[r.Method] {
   236  		pathParams, err := h.pat.Match(components, verb)
   237  		if err != nil {
   238  			continue
   239  		}
   240  		h.h(w, r, pathParams)
   241  		return
   242  	}
   243  
   244  	// lookup other methods to handle fallback from GET to POST and
   245  	// to determine if it is MethodNotAllowed or NotFound.
   246  	for m, handlers := range s.handlers {
   247  		if m == r.Method {
   248  			continue
   249  		}
   250  		for _, h := range handlers {
   251  			pathParams, err := h.pat.Match(components, verb)
   252  			if err != nil {
   253  				continue
   254  			}
   255  			// X-HTTP-Method-Override is optional. Always allow fallback to POST.
   256  			if s.isPathLengthFallback(r) {
   257  				if err := r.ParseForm(); err != nil {
   258  					if s.protoErrorHandler != nil {
   259  						_, outboundMarshaler := MarshalerForRequest(s, r)
   260  						sterr := status.Error(codes.InvalidArgument, err.Error())
   261  						s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
   262  					} else {
   263  						OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
   264  					}
   265  					return
   266  				}
   267  				h.h(w, r, pathParams)
   268  				return
   269  			}
   270  			if s.protoErrorHandler != nil {
   271  				_, outboundMarshaler := MarshalerForRequest(s, r)
   272  				s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
   273  			} else {
   274  				OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
   275  			}
   276  			return
   277  		}
   278  	}
   279  
   280  	if s.protoErrorHandler != nil {
   281  		_, outboundMarshaler := MarshalerForRequest(s, r)
   282  		s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
   283  	} else {
   284  		OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
   285  	}
   286  }
   287  
   288  // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
   289  func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
   290  	return s.forwardResponseOptions
   291  }
   292  
   293  func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
   294  	return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
   295  }
   296  
        // handler is one routing entry of a ServeMux: a path pattern together with
        // the function to invoke when a request path matches it.
   297  type handler struct {
        	// pat is matched against the request's path components and verb.
   298  	pat Pattern
        	// h is called with the path parameters produced by a successful match.
   299  	h   HandlerFunc
   300  }
   301  

View as plain text