1 package runtime
2
3 import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/textproto"
8 "strings"
9
10 "github.com/golang/protobuf/proto"
11 "google.golang.org/grpc/codes"
12 "google.golang.org/grpc/metadata"
13 "google.golang.org/grpc/status"
14 )
15
16
// HandlerFunc handles a request that ServeMux has routed to it.
// pathParams holds the values returned by the matched Pattern's Match for
// the request path (see ServeMux.ServeHTTP).
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
18
19
20
21
22
23
24
// ErrUnknownURI is the error ServeMux.ServeHTTP hands to a configured
// ProtoErrorHandlerFunc when no handler serves the request path: either
// nothing matches at all, or the path only matches under a different HTTP
// method. It carries codes.Unimplemented with the text of HTTP 501.
var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
26
27
28
// ServeMux routes HTTP requests to HandlerFuncs registered per HTTP method
// and path Pattern (see Handle and ServeHTTP). Create one with NewServeMux.
type ServeMux struct {
	// handlers maps an HTTP method to its registered handlers; ServeHTTP
	// tries them in slice order and the first matching Pattern wins.
	handlers map[string][]handler
	// forwardResponseOptions are hooks added by WithForwardResponseOption and
	// exposed through GetForwardResponseOptions.
	forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
	// marshalers is initialised by makeMarshalerMIMERegistry in NewServeMux.
	marshalers marshalerRegistry
	// incomingHeaderMatcher defaults to DefaultHeaderMatcher (see NewServeMux).
	incomingHeaderMatcher HeaderMatcherFunc
	// outgoingHeaderMatcher defaults to prefixing keys with MetadataHeaderPrefix.
	outgoingHeaderMatcher HeaderMatcherFunc
	// metadataAnnotators are added by WithMetadata.
	metadataAnnotators []func(context.Context, *http.Request) metadata.MD
	// streamErrorHandler defaults to DefaultHTTPStreamErrorHandler.
	streamErrorHandler StreamErrorHandlerFunc
	// protoErrorHandler, when non-nil, receives ServeHTTP's routing errors
	// instead of OtherErrorHandler.
	protoErrorHandler ProtoErrorHandlerFunc
	// disablePathLengthFallback turns off the form-POST fallback checked by
	// isPathLengthFallback.
	disablePathLengthFallback bool
	// lastMatchWins makes Handle prepend instead of append, so later
	// registrations are tried first.
	lastMatchWins bool
}
42
43
// ServeMuxOption configures a ServeMux. NewServeMux applies the options it
// receives in the order given.
type ServeMuxOption func(*ServeMux)
45
46
47
48
49
50
51
52 func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
53 return func(serveMux *ServeMux) {
54 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
55 }
56 }
57
58
59
60
61 func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
62 return func(serveMux *ServeMux) {
63 currentQueryParser = queryParameterParser
64 }
65 }
66
67
// HeaderMatcherFunc maps a header/metadata key to the key it should be
// forwarded under; the bool reports whether the key is accepted at all
// (DefaultHeaderMatcher returns "", false for rejected keys).
type HeaderMatcherFunc func(string) (string, bool)
69
70
71
72
73 func DefaultHeaderMatcher(key string) (string, bool) {
74 key = textproto.CanonicalMIMEHeaderKey(key)
75 if isPermanentHTTPHeader(key) {
76 return MetadataPrefix + key, true
77 } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
78 return key[len(MetadataHeaderPrefix):], true
79 }
80 return "", false
81 }
82
83
84
85
86
87 func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
88 return func(mux *ServeMux) {
89 mux.incomingHeaderMatcher = fn
90 }
91 }
92
93
94
95
96
97
98 func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
99 return func(mux *ServeMux) {
100 mux.outgoingHeaderMatcher = fn
101 }
102 }
103
104
105
106
107
108 func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
109 return func(serveMux *ServeMux) {
110 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
111 }
112 }
113
114
115
116
117
118
119 func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
120 return func(serveMux *ServeMux) {
121 serveMux.protoErrorHandler = fn
122 }
123 }
124
125
126 func WithDisablePathLengthFallback() ServeMuxOption {
127 return func(serveMux *ServeMux) {
128 serveMux.disablePathLengthFallback = true
129 }
130 }
131
132
133
134
135
136
137
138
139
140 func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
141 return func(serveMux *ServeMux) {
142 serveMux.streamErrorHandler = fn
143 }
144 }
145
146
147
148
149 func WithLastMatchWins() ServeMuxOption {
150 return func(serveMux *ServeMux) {
151 serveMux.lastMatchWins = true
152 }
153 }
154
155
156 func NewServeMux(opts ...ServeMuxOption) *ServeMux {
157 serveMux := &ServeMux{
158 handlers: make(map[string][]handler),
159 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
160 marshalers: makeMarshalerMIMERegistry(),
161 streamErrorHandler: DefaultHTTPStreamErrorHandler,
162 }
163
164 for _, opt := range opts {
165 opt(serveMux)
166 }
167
168 if serveMux.incomingHeaderMatcher == nil {
169 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
170 }
171
172 if serveMux.outgoingHeaderMatcher == nil {
173 serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
174 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
175 }
176 }
177
178 return serveMux
179 }
180
181
182 func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
183 if s.lastMatchWins {
184 s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...)
185 } else {
186 s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
187 }
188 }
189
190
191 func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
192 ctx := r.Context()
193
194 path := r.URL.Path
195 if !strings.HasPrefix(path, "/") {
196 if s.protoErrorHandler != nil {
197 _, outboundMarshaler := MarshalerForRequest(s, r)
198 sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
199 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
200 } else {
201 OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
202 }
203 return
204 }
205
206 components := strings.Split(path[1:], "/")
207 l := len(components)
208 var verb string
209 if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
210 if s.protoErrorHandler != nil {
211 _, outboundMarshaler := MarshalerForRequest(s, r)
212 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
213 } else {
214 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
215 }
216 return
217 } else if idx > 0 {
218 c := components[l-1]
219 components[l-1], verb = c[:idx], c[idx+1:]
220 }
221
222 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
223 r.Method = strings.ToUpper(override)
224 if err := r.ParseForm(); err != nil {
225 if s.protoErrorHandler != nil {
226 _, outboundMarshaler := MarshalerForRequest(s, r)
227 sterr := status.Error(codes.InvalidArgument, err.Error())
228 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
229 } else {
230 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
231 }
232 return
233 }
234 }
235 for _, h := range s.handlers[r.Method] {
236 pathParams, err := h.pat.Match(components, verb)
237 if err != nil {
238 continue
239 }
240 h.h(w, r, pathParams)
241 return
242 }
243
244
245
246 for m, handlers := range s.handlers {
247 if m == r.Method {
248 continue
249 }
250 for _, h := range handlers {
251 pathParams, err := h.pat.Match(components, verb)
252 if err != nil {
253 continue
254 }
255
256 if s.isPathLengthFallback(r) {
257 if err := r.ParseForm(); err != nil {
258 if s.protoErrorHandler != nil {
259 _, outboundMarshaler := MarshalerForRequest(s, r)
260 sterr := status.Error(codes.InvalidArgument, err.Error())
261 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
262 } else {
263 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
264 }
265 return
266 }
267 h.h(w, r, pathParams)
268 return
269 }
270 if s.protoErrorHandler != nil {
271 _, outboundMarshaler := MarshalerForRequest(s, r)
272 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
273 } else {
274 OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
275 }
276 return
277 }
278 }
279
280 if s.protoErrorHandler != nil {
281 _, outboundMarshaler := MarshalerForRequest(s, r)
282 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
283 } else {
284 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
285 }
286 }
287
288
// GetForwardResponseOptions returns the forward-response hooks registered via
// WithForwardResponseOption. The returned slice is the mux's own backing
// slice, not a copy.
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
	return s.forwardResponseOptions
}
292
293 func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
294 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
295 }
296
// handler pairs a path Pattern with the HandlerFunc invoked when the Pattern
// matches (see ServeMux.Handle and ServeMux.ServeHTTP).
type handler struct {
	// pat is matched against the request's path components and verb.
	pat Pattern
	// h is called with the path parameters returned by pat.Match.
	h HandlerFunc
}
301
View as plain text