1 package runtime
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "net/http"
8 "net/textproto"
9 "regexp"
10 "strings"
11
12 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
13 "google.golang.org/grpc/codes"
14 "google.golang.org/grpc/grpclog"
15 "google.golang.org/grpc/health/grpc_health_v1"
16 "google.golang.org/grpc/metadata"
17 "google.golang.org/grpc/status"
18 "google.golang.org/protobuf/proto"
19 )
20
21
// UnescapingMode selects how a ServeMux treats percent-encoded characters in
// the request path while routing. ServeHTTP consults it when choosing which
// form of the path to split, and passes it on to Pattern.MatchAndEscape for
// every candidate handler.
type UnescapingMode int
23
const (
	// UnescapingModeLegacy is the zero value. In this mode ServeHTTP always
	// routes on the decoded r.URL.Path and never switches to r.URL.RawPath.
	UnescapingModeLegacy UnescapingMode = iota

	// UnescapingModeAllExceptReserved presumably unescapes every character
	// except reserved ones. The unescaping itself is done by
	// Pattern.MatchAndEscape, which is not visible here — TODO confirm there.
	UnescapingModeAllExceptReserved

	// UnescapingModeAllExceptSlash presumably unescapes every character except
	// an encoded path separator. The unescaping itself is done by
	// Pattern.MatchAndEscape, which is not visible here — TODO confirm there.
	UnescapingModeAllExceptSlash

	// UnescapingModeAllCharacters additionally makes ServeHTTP split the path
	// with encodedPathSplitter, i.e. on an encoded separator as well as on "/".
	UnescapingModeAllCharacters

	// UnescapingModeDefault is the mode NewServeMux installs when no
	// WithUnescapingMode option overrides it. It is currently an alias of
	// UnescapingModeLegacy.
	UnescapingModeDefault = UnescapingModeLegacy
)
45
// encodedPathSplitter matches a path separator in either literal ("/") or
// percent-encoded form. ServeHTTP uses it to split the request path when the
// mux runs in UnescapingModeAllCharacters.
//
// Percent-encoding hex digits are case-insensitive, so both "%2F" and "%2f"
// are accepted; matching only the upper-case spelling would route the two
// equivalent encodings of the same path differently.
var encodedPathSplitter = regexp.MustCompile("(/|%2[Ff])")
47
48
// HandlerFunc handles a request whose path matched a registered Pattern.
// pathParams carries the values Pattern.MatchAndEscape returned for that
// request.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
50
51
52
// ServeMux routes each HTTP request to the HandlerFunc registered for its
// method and path Pattern (see ServeHTTP). Build one with NewServeMux and
// configure it through ServeMuxOption values.
type ServeMux struct {
	// handlers maps an HTTP method to the handlers registered for it. Handle
	// prepends, so the most recently registered handler is tried first.
	handlers map[string][]handler
	// forwardResponseOptions is appended to by WithForwardResponseOption and
	// exposed through GetForwardResponseOptions.
	forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
	// marshalers is initialised by makeMarshalerMIMERegistry in NewServeMux.
	marshalers marshalerRegistry
	// incomingHeaderMatcher defaults to DefaultHeaderMatcher.
	incomingHeaderMatcher HeaderMatcherFunc
	// outgoingHeaderMatcher defaults to defaultOutgoingHeaderMatcher.
	outgoingHeaderMatcher HeaderMatcherFunc
	// outgoingTrailerMatcher defaults to defaultOutgoingTrailerMatcher.
	outgoingTrailerMatcher HeaderMatcherFunc
	// metadataAnnotators is appended to by WithMetadata.
	metadataAnnotators []func(context.Context, *http.Request) metadata.MD
	// errorHandler defaults to DefaultHTTPErrorHandler; ServeHTTP uses it for
	// form-parsing failures and malformed escape sequences.
	errorHandler ErrorHandlerFunc
	// streamErrorHandler defaults to DefaultStreamErrorHandler.
	streamErrorHandler StreamErrorHandlerFunc
	// routingErrorHandler defaults to DefaultRoutingErrorHandler; ServeHTTP
	// uses it for the 400, 404 and 405 routing outcomes.
	routingErrorHandler RoutingErrorHandlerFunc
	// disablePathLengthFallback, when true, makes isPathLengthFallback always
	// report false.
	disablePathLengthFallback bool
	// unescapingMode defaults to UnescapingModeDefault.
	unescapingMode UnescapingMode
}
68
69
// ServeMuxOption configures a ServeMux. NewServeMux applies the options it is
// given in order, before filling in defaults for any header matcher left nil.
type ServeMuxOption func(*ServeMux)
71
72
73
74
75
76
77
78 func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
79 return func(serveMux *ServeMux) {
80 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
81 }
82 }
83
84
85
86 func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
87 return func(serveMux *ServeMux) {
88 serveMux.unescapingMode = mode
89 }
90 }
91
92
93
94
95 func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
96 return func(serveMux *ServeMux) {
97 currentQueryParser = queryParameterParser
98 }
99 }
100
101
// HeaderMatcherFunc decides whether a header key is accepted and under which
// name: it returns the (possibly rewritten) key together with true to accept
// the header, or false to reject it.
type HeaderMatcherFunc func(string) (string, bool)
103
104
105
106
107
108 func DefaultHeaderMatcher(key string) (string, bool) {
109 switch key = textproto.CanonicalMIMEHeaderKey(key); {
110 case isPermanentHTTPHeader(key):
111 return MetadataPrefix + key, true
112 case strings.HasPrefix(key, MetadataHeaderPrefix):
113 return key[len(MetadataHeaderPrefix):], true
114 }
115 return "", false
116 }
117
118 func defaultOutgoingHeaderMatcher(key string) (string, bool) {
119 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
120 }
121
122 func defaultOutgoingTrailerMatcher(key string) (string, bool) {
123 return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true
124 }
125
126
127
128
129
130 func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
131 for _, header := range fn.matchedMalformedHeaders() {
132 grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
133 }
134
135 return func(mux *ServeMux) {
136 mux.incomingHeaderMatcher = fn
137 }
138 }
139
140
141 func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
142 if fn == nil {
143 return nil
144 }
145 headers := make([]string, 0)
146 for header := range malformedHTTPHeaders {
147 out, accept := fn(header)
148 if accept && isMalformedHTTPHeader(out) {
149 headers = append(headers, out)
150 }
151 }
152 return headers
153 }
154
155
156
157
158
159
160 func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
161 return func(mux *ServeMux) {
162 mux.outgoingHeaderMatcher = fn
163 }
164 }
165
166
167
168
169
170
171 func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
172 return func(mux *ServeMux) {
173 mux.outgoingTrailerMatcher = fn
174 }
175 }
176
177
178
179
180
181 func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
182 return func(serveMux *ServeMux) {
183 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
184 }
185 }
186
187
188
189
190 func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
191 return func(serveMux *ServeMux) {
192 serveMux.errorHandler = fn
193 }
194 }
195
196
197
198
199
200
201
202
203
204 func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
205 return func(serveMux *ServeMux) {
206 serveMux.streamErrorHandler = fn
207 }
208 }
209
210
211
212
213
214 func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
215 return func(serveMux *ServeMux) {
216 serveMux.routingErrorHandler = fn
217 }
218 }
219
220
221 func WithDisablePathLengthFallback() ServeMuxOption {
222 return func(serveMux *ServeMux) {
223 serveMux.disablePathLengthFallback = true
224 }
225 }
226
227
228
229
230
231
232
233
234
235 func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
236 return func(s *ServeMux) {
237
238 _ = s.HandlePath(
239 http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string,
240 ) {
241 _, outboundMarshaler := MarshalerForRequest(s, r)
242
243 resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
244 Service: r.URL.Query().Get("service"),
245 })
246 if err != nil {
247 s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
248 return
249 }
250
251 w.Header().Set("Content-Type", "application/json")
252
253 if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING {
254 switch resp.GetStatus() {
255 case grpc_health_v1.HealthCheckResponse_NOT_SERVING, grpc_health_v1.HealthCheckResponse_UNKNOWN:
256 err = status.Error(codes.Unavailable, resp.String())
257 case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
258 err = status.Error(codes.NotFound, resp.String())
259 }
260
261 s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
262 return
263 }
264
265 _ = outboundMarshaler.NewEncoder(w).Encode(resp)
266 })
267 }
268 }
269
270
271
272
273 func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
274 return WithHealthEndpointAt(healthCheckClient, "/healthz")
275 }
276
277
278 func NewServeMux(opts ...ServeMuxOption) *ServeMux {
279 serveMux := &ServeMux{
280 handlers: make(map[string][]handler),
281 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
282 marshalers: makeMarshalerMIMERegistry(),
283 errorHandler: DefaultHTTPErrorHandler,
284 streamErrorHandler: DefaultStreamErrorHandler,
285 routingErrorHandler: DefaultRoutingErrorHandler,
286 unescapingMode: UnescapingModeDefault,
287 }
288
289 for _, opt := range opts {
290 opt(serveMux)
291 }
292
293 if serveMux.incomingHeaderMatcher == nil {
294 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
295 }
296 if serveMux.outgoingHeaderMatcher == nil {
297 serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
298 }
299 if serveMux.outgoingTrailerMatcher == nil {
300 serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
301 }
302
303 return serveMux
304 }
305
306
307 func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
308 s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
309 }
310
311
312
313 func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
314 compiler, err := httprule.Parse(pathPattern)
315 if err != nil {
316 return fmt.Errorf("parsing path pattern: %w", err)
317 }
318 tp := compiler.Compile()
319 pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
320 if err != nil {
321 return fmt.Errorf("creating new pattern: %w", err)
322 }
323 s.Handle(meth, pattern, h)
324 return nil
325 }
326
327
328 func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
329 ctx := r.Context()
330
331 path := r.URL.Path
332 if !strings.HasPrefix(path, "/") {
333 _, outboundMarshaler := MarshalerForRequest(s, r)
334 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
335 return
336 }
337
338
339 if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
340 path = r.URL.RawPath
341 }
342
343 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
344 if err := r.ParseForm(); err != nil {
345 _, outboundMarshaler := MarshalerForRequest(s, r)
346 sterr := status.Error(codes.InvalidArgument, err.Error())
347 s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
348 return
349 }
350 r.Method = strings.ToUpper(override)
351 }
352
353 var pathComponents []string
354
355
356
357
358 if s.unescapingMode == UnescapingModeAllCharacters {
359 pathComponents = encodedPathSplitter.Split(path[1:], -1)
360 } else {
361 pathComponents = strings.Split(path[1:], "/")
362 }
363
364 lastPathComponent := pathComponents[len(pathComponents)-1]
365
366 for _, h := range s.handlers[r.Method] {
367
368
369
370
371
372
373
374
375
376 var verb string
377 patVerb := h.pat.Verb()
378
379 idx := -1
380 if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
381 idx = len(lastPathComponent) - len(patVerb) - 1
382 }
383 if idx == 0 {
384 _, outboundMarshaler := MarshalerForRequest(s, r)
385 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
386 return
387 }
388
389 comps := make([]string, len(pathComponents))
390 copy(comps, pathComponents)
391
392 if idx > 0 {
393 comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
394 }
395
396 pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
397 if err != nil {
398 var mse MalformedSequenceError
399 if ok := errors.As(err, &mse); ok {
400 _, outboundMarshaler := MarshalerForRequest(s, r)
401 s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
402 HTTPStatus: http.StatusBadRequest,
403 Err: mse,
404 })
405 }
406 continue
407 }
408 h.h(w, r, pathParams)
409 return
410 }
411
412
413
414
415
416
417
418 for m, handlers := range s.handlers {
419 if m == r.Method {
420 continue
421 }
422 for _, h := range handlers {
423 var verb string
424 patVerb := h.pat.Verb()
425
426 idx := -1
427 if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
428 idx = len(lastPathComponent) - len(patVerb) - 1
429 }
430
431 comps := make([]string, len(pathComponents))
432 copy(comps, pathComponents)
433
434 if idx > 0 {
435 comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
436 }
437
438 pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
439 if err != nil {
440 var mse MalformedSequenceError
441 if ok := errors.As(err, &mse); ok {
442 _, outboundMarshaler := MarshalerForRequest(s, r)
443 s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
444 HTTPStatus: http.StatusBadRequest,
445 Err: mse,
446 })
447 }
448 continue
449 }
450
451
452
453
454 if s.isPathLengthFallback(r) && m == http.MethodGet {
455 if err := r.ParseForm(); err != nil {
456 _, outboundMarshaler := MarshalerForRequest(s, r)
457 sterr := status.Error(codes.InvalidArgument, err.Error())
458 s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
459 return
460 }
461 h.h(w, r, pathParams)
462 return
463 }
464 _, outboundMarshaler := MarshalerForRequest(s, r)
465 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
466 return
467 }
468 }
469
470 _, outboundMarshaler := MarshalerForRequest(s, r)
471 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
472 }
473
474
475 func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
476 return s.forwardResponseOptions
477 }
478
479 func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
480 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
481 }
482
// handler pairs a path Pattern with the HandlerFunc that ServeHTTP invokes
// when a request path matches that pattern.
type handler struct {
	// pat is the pattern the request path is matched against.
	pat Pattern
	// h is called with the path parameters extracted by pat.
	h HandlerFunc
}
487
View as plain text