1 package runtime
2
3 import (
4 "context"
5 "encoding/base64"
6 "fmt"
7 "net"
8 "net/http"
9 "net/textproto"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "google.golang.org/grpc/codes"
16 "google.golang.org/grpc/grpclog"
17 "google.golang.org/grpc/metadata"
18 "google.golang.org/grpc/status"
19 )
20
21
22
// Header-name conventions for carrying gRPC metadata over HTTP. These exported
// values are not referenced in this file; the descriptions below follow their
// names and should be confirmed against the code that uses them.
const (
	// MetadataHeaderPrefix prefixes HTTP headers that carry gRPC metadata.
	MetadataHeaderPrefix = "Grpc-Metadata-"

	// MetadataPrefix prefixes gRPC metadata keys derived from HTTP headers.
	MetadataPrefix = "grpcgateway-"

	// MetadataTrailerPrefix prefixes HTTP headers that carry gRPC trailer metadata.
	MetadataTrailerPrefix = "Grpc-Trailer-"
)

const (
	// metadataGrpcTimeout is the request header annotateContext reads to
	// obtain a per-call timeout (parsed by timeoutDecode).
	metadataGrpcTimeout = "Grpc-Timeout"

	// metadataHeaderBinarySuffix marks headers whose values are base64-encoded
	// binary data; annotateContext decodes them before forwarding.
	metadataHeaderBinarySuffix = "-Bin"
)

// Proxy headers that annotateContext populates (lowercased) in outgoing metadata.
const (
	xForwardedFor  = "X-Forwarded-For"
	xForwardedHost = "X-Forwarded-Host"
)

// DefaultContextTimeout is the call timeout annotateContext applies when the
// request carries no Grpc-Timeout header. The zero value means no timeout.
var DefaultContextTimeout time.Duration

// malformedHTTPHeaders is the set of lowercase header names that
// isMalformedHTTPHeader reports as malformed.
var malformedHTTPHeaders = map[string]struct{}{
	"connection": {},
}
48
// rpcMethodKey is the unexported context key under which withRPCMethod stores
// the RPC method name. Being a distinct unexported type, it cannot collide
// with context keys defined in other packages.
type rpcMethodKey struct{}

// httpPathPatternKey is the unexported context key under which
// withHTTPPathPattern stores the HTTP path pattern.
type httpPathPatternKey struct{}

// AnnotateContextOption transforms the context built by AnnotateContext and
// AnnotateIncomingContext. Options run in the order given, after the RPC
// method name has been recorded.
type AnnotateContextOption func(ctx context.Context) context.Context
55
56 func WithHTTPPathPattern(pattern string) AnnotateContextOption {
57 return func(ctx context.Context) context.Context {
58 return withHTTPPathPattern(ctx, pattern)
59 }
60 }
61
// decodeBinHeader base64-decodes the value of a binary ("-Bin") header.
// A value whose length is a multiple of four is decoded as padded standard
// base64 (this also covers input that needs no padding); any other length is
// decoded as unpadded (raw) standard base64.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
69
70
77 func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
78 ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
79 if err != nil {
80 return nil, err
81 }
82 if md == nil {
83 return ctx, nil
84 }
85
86 return metadata.NewOutgoingContext(ctx, md), nil
87 }
88
89
90
91 func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
92 ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
93 if err != nil {
94 return nil, err
95 }
96 if md == nil {
97 return ctx, nil
98 }
99
100 return metadata.NewIncomingContext(ctx, md), nil
101 }
102
// isValidGRPCMetadataKey reports whether key can be used as a gRPC metadata key.
//
// A gRPC "Header-Name" is one or more of 0-9, a-z, '_', '-' and '.'. Only
// lowercase letters are valid on the wire, but metadata.Pairs lowercases keys,
// so uppercase ASCII letters are accepted here as well. The empty string is
// rejected: it is not a legal header name, and without this check it would be
// forwarded instead of being logged and skipped by annotateContext.
func isValidGRPCMetadataKey(key string) bool {
	if key == "" {
		return false
	}
	// gRPC validates on the byte level, not per Unicode rune, so any
	// non-ASCII byte makes the key invalid.
	for i := 0; i < len(key); i++ {
		ch := key[i]
		switch {
		case 'a' <= ch && ch <= 'z',
			'A' <= ch && ch <= 'Z',
			'0' <= ch && ch <= '9',
			ch == '.', ch == '-', ch == '_':
			// allowed
		default:
			return false
		}
	}
	return true
}
121
// isValidGRPCMetadataTextValue reports whether textValue consists only of
// printable ASCII bytes, 0x20 (space) through 0x7E ('~'), which is what a
// non-binary gRPC metadata value may contain. The check is byte-wise, so any
// multi-byte UTF-8 sequence fails it. The empty string is accepted.
func isValidGRPCMetadataTextValue(textValue string) bool {
	for i := 0; i < len(textValue); i++ {
		if c := textValue[i]; c < ' ' || c > '~' {
			return false
		}
	}
	return true
}
134
135 func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
136 ctx = withRPCMethod(ctx, rpcMethodName)
137 for _, o := range options {
138 ctx = o(ctx)
139 }
140 timeout := DefaultContextTimeout
141 if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
142 var err error
143 timeout, err = timeoutDecode(tm)
144 if err != nil {
145 return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
146 }
147 }
148 var pairs []string
149 for key, vals := range req.Header {
150 key = textproto.CanonicalMIMEHeaderKey(key)
151 for _, val := range vals {
152
153 if key == "Authorization" {
154 pairs = append(pairs, "authorization", val)
155 }
156 if h, ok := mux.incomingHeaderMatcher(key); ok {
157 if !isValidGRPCMetadataKey(h) {
158 grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h)
159 continue
160 }
161
162
163 if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
164 b, err := decodeBinHeader(val)
165 if err != nil {
166 return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
167 }
168
169 val = string(b)
170 } else if !isValidGRPCMetadataTextValue(val) {
171 grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h)
172 continue
173 }
174 pairs = append(pairs, h, val)
175 }
176 }
177 }
178 if host := req.Header.Get(xForwardedHost); host != "" {
179 pairs = append(pairs, strings.ToLower(xForwardedHost), host)
180 } else if req.Host != "" {
181 pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
182 }
183
184 if addr := req.RemoteAddr; addr != "" {
185 if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
186 if fwd := req.Header.Get(xForwardedFor); fwd == "" {
187 pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
188 } else {
189 pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
190 }
191 }
192 }
193
194 if timeout != 0 {
195
196 ctx, _ = context.WithTimeout(ctx, timeout)
197 }
198 if len(pairs) == 0 {
199 return ctx, nil, nil
200 }
201 md := metadata.Pairs(pairs...)
202 for _, mda := range mux.metadataAnnotators {
203 md = metadata.Join(md, mda(ctx, req))
204 }
205 return ctx, md, nil
206 }
207
208
209 type ServerMetadata struct {
210 HeaderMD metadata.MD
211 TrailerMD metadata.MD
212 }
213
214 type serverMetadataKey struct{}
215
216
217 func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
218 if ctx == nil {
219 ctx = context.Background()
220 }
221 return context.WithValue(ctx, serverMetadataKey{}, md)
222 }
223
224
225 func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
226 if ctx == nil {
227 return md, false
228 }
229 md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
230 return
231 }
232
233
234
235
236 type ServerTransportStream struct {
237 mu sync.Mutex
238 header metadata.MD
239 trailer metadata.MD
240 }
241
242
243 func (s *ServerTransportStream) Method() string {
244 return ""
245 }
246
247
248 func (s *ServerTransportStream) Header() metadata.MD {
249 s.mu.Lock()
250 defer s.mu.Unlock()
251 return s.header.Copy()
252 }
253
254
255 func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
256 if md.Len() == 0 {
257 return nil
258 }
259
260 s.mu.Lock()
261 s.header = metadata.Join(s.header, md)
262 s.mu.Unlock()
263 return nil
264 }
265
266
267 func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
268 return s.SetHeader(md)
269 }
270
271
272 func (s *ServerTransportStream) Trailer() metadata.MD {
273 s.mu.Lock()
274 defer s.mu.Unlock()
275 return s.trailer.Copy()
276 }
277
278
279 func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
280 if md.Len() == 0 {
281 return nil
282 }
283
284 s.mu.Lock()
285 s.trailer = metadata.Join(s.trailer, md)
286 s.mu.Unlock()
287 return nil
288 }
289
// timeoutDecode parses a Grpc-Timeout header value. The value is an unsigned
// decimal integer followed by one unit character understood by
// timeoutUnitToDuration; for example "250m" is 250 milliseconds.
//
// A sign is rejected: the gRPC spec defines the value as a positive integer,
// and accepting "-" would let a client request an already-expired deadline.
// A value too large for a time.Duration is clamped to the maximum Duration
// (as grpc-go's transport does) instead of silently wrapping around.
func timeoutDecode(s string) (time.Duration, error) {
	const maxDuration = time.Duration(1<<63 - 1)

	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("timeout string is too short: %q", s)
	}
	d, ok := timeoutUnitToDuration(s[size-1])
	if !ok {
		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
	}
	t, err := strconv.ParseUint(s[:size-1], 10, 64)
	if err != nil {
		return 0, err
	}
	// d >= 1ns, so maxDuration/d is the largest count that cannot overflow.
	if t > uint64(maxDuration/d) {
		return maxDuration, nil
	}
	return d * time.Duration(t), nil
}

// timeoutUnitToDuration maps a Grpc-Timeout unit character to its duration:
// 'H' hour, 'M' minute, 'S' second, 'm' millisecond, 'u' microsecond and
// 'n' nanosecond. ok is false (and d is zero) for any other byte.
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
	switch u {
	case 'H':
		return time.Hour, true
	case 'M':
		return time.Minute, true
	case 'S':
		return time.Second, true
	case 'm':
		return time.Millisecond, true
	case 'u':
		return time.Microsecond, true
	case 'n':
		return time.Nanosecond, true
	default:
		return
	}
}
324
325
326
327
// isPermanentHTTPHeader reports whether hdr is one of a fixed set of standard
// HTTP header names. The comparison is case-sensitive, so hdr must already be
// in canonical MIME form (for example "Content-Type", not "content-type").
func isPermanentHTTPHeader(hdr string) bool {
	switch hdr {
	case "Accept", "Accept-Charset", "Accept-Language", "Accept-Ranges":
		// Accept-* family.
		return true
	case "If-Match", "If-Modified-Since", "If-None-Match",
		"If-Schedule-Tag-Match", "If-Unmodified-Since":
		// Conditional-request headers.
		return true
	case "Authorization", "Cache-Control", "Content-Type", "Cookie",
		"Date", "Expect", "From", "Host", "Max-Forwards", "Origin",
		"Pragma", "Referer", "User-Agent", "Via", "Warning":
		return true
	default:
		return false
	}
}
359
360
361
362 func isMalformedHTTPHeader(header string) bool {
363 _, isMalformed := malformedHTTPHeaders[strings.ToLower(header)]
364 return isMalformed
365 }
366
367
368
369 func RPCMethod(ctx context.Context) (string, bool) {
370 m := ctx.Value(rpcMethodKey{})
371 if m == nil {
372 return "", false
373 }
374 ms, ok := m.(string)
375 if !ok {
376 return "", false
377 }
378 return ms, true
379 }
380
381 func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
382 return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
383 }
384
385
386
387 func HTTPPathPattern(ctx context.Context) (string, bool) {
388 m := ctx.Value(httpPathPatternKey{})
389 if m == nil {
390 return "", false
391 }
392 ms, ok := m.(string)
393 if !ok {
394 return "", false
395 }
396 return ms, true
397 }
398
399 func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
400 return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
401 }
402
View as plain text