...

Source file src/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/context.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/v2/runtime

     1  package runtime
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/textproto"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/grpclog"
    17  	"google.golang.org/grpc/metadata"
    18  	"google.golang.org/grpc/status"
    19  )
    20  
    21  // MetadataHeaderPrefix is the http prefix that represents custom metadata
    22  // parameters to or from a gRPC call.
    23  const MetadataHeaderPrefix = "Grpc-Metadata-"
    24  
    25  // MetadataPrefix is prepended to permanent HTTP header keys (as specified
    26  // by the IANA) when added to the gRPC context.
    27  const MetadataPrefix = "grpcgateway-"
    28  
    29  // MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
    30  // HTTP headers in a response handled by grpc-gateway
    31  const MetadataTrailerPrefix = "Grpc-Trailer-"
    32  
    33  const metadataGrpcTimeout = "Grpc-Timeout"
    34  const metadataHeaderBinarySuffix = "-Bin"
    35  
    36  const xForwardedFor = "X-Forwarded-For"
    37  const xForwardedHost = "X-Forwarded-Host"
    38  
    39  // DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
    40  // header isn't present. If the value is 0 the sent `context` will not have a timeout.
    41  var DefaultContextTimeout = 0 * time.Second
    42  
    43  // malformedHTTPHeaders lists the headers that the gRPC server may reject outright as malformed.
    44  // See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more context.
    45  var malformedHTTPHeaders = map[string]struct{}{
    46  	"connection": {},
    47  }
    48  
    49  type (
    50  	rpcMethodKey       struct{}
    51  	httpPathPatternKey struct{}
    52  
    53  	AnnotateContextOption func(ctx context.Context) context.Context
    54  )
    55  
    56  func WithHTTPPathPattern(pattern string) AnnotateContextOption {
    57  	return func(ctx context.Context) context.Context {
    58  		return withHTTPPathPattern(ctx, pattern)
    59  	}
    60  }
    61  
    62  func decodeBinHeader(v string) ([]byte, error) {
    63  	if len(v)%4 == 0 {
    64  		// Input was padded, or padding was not necessary.
    65  		return base64.StdEncoding.DecodeString(v)
    66  	}
    67  	return base64.RawStdEncoding.DecodeString(v)
    68  }
    69  
    70  /*
    71  AnnotateContext adds context information such as metadata from the request.
    72  
    73  At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
    74  except that the forwarded destination is not another HTTP service but rather
    75  a gRPC service.
    76  */
    77  func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
    78  	ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  	if md == nil {
    83  		return ctx, nil
    84  	}
    85  
    86  	return metadata.NewOutgoingContext(ctx, md), nil
    87  }
    88  
    89  // AnnotateIncomingContext adds context information such as metadata from the request.
    90  // Attach metadata as incoming context.
    91  func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
    92  	ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	if md == nil {
    97  		return ctx, nil
    98  	}
    99  
   100  	return metadata.NewIncomingContext(ctx, md), nil
   101  }
   102  
   103  func isValidGRPCMetadataKey(key string) bool {
   104  	// Must be a valid gRPC "Header-Name" as defined here:
   105  	//   https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
   106  	// This means 0-9 a-z _ - .
   107  	// Only lowercase letters are valid in the wire protocol, but the client library will normalize
   108  	// uppercase ASCII to lowercase, so uppercase ASCII is also acceptable.
   109  	bytes := []byte(key) // gRPC validates strings on the byte level, not Unicode.
   110  	for _, ch := range bytes {
   111  		validLowercaseLetter := ch >= 'a' && ch <= 'z'
   112  		validUppercaseLetter := ch >= 'A' && ch <= 'Z'
   113  		validDigit := ch >= '0' && ch <= '9'
   114  		validOther := ch == '.' || ch == '-' || ch == '_'
   115  		if !validLowercaseLetter && !validUppercaseLetter && !validDigit && !validOther {
   116  			return false
   117  		}
   118  	}
   119  	return true
   120  }
   121  
   122  func isValidGRPCMetadataTextValue(textValue string) bool {
   123  	// Must be a valid gRPC "ASCII-Value" as defined here:
   124  	//   https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
   125  	// This means printable ASCII (including/plus spaces); 0x20 to 0x7E inclusive.
   126  	bytes := []byte(textValue) // gRPC validates strings on the byte level, not Unicode.
   127  	for _, ch := range bytes {
   128  		if ch < 0x20 || ch > 0x7E {
   129  			return false
   130  		}
   131  	}
   132  	return true
   133  }
   134  
   135  func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
   136  	ctx = withRPCMethod(ctx, rpcMethodName)
   137  	for _, o := range options {
   138  		ctx = o(ctx)
   139  	}
   140  	timeout := DefaultContextTimeout
   141  	if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
   142  		var err error
   143  		timeout, err = timeoutDecode(tm)
   144  		if err != nil {
   145  			return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
   146  		}
   147  	}
   148  	var pairs []string
   149  	for key, vals := range req.Header {
   150  		key = textproto.CanonicalMIMEHeaderKey(key)
   151  		for _, val := range vals {
   152  			// For backwards-compatibility, pass through 'authorization' header with no prefix.
   153  			if key == "Authorization" {
   154  				pairs = append(pairs, "authorization", val)
   155  			}
   156  			if h, ok := mux.incomingHeaderMatcher(key); ok {
   157  				if !isValidGRPCMetadataKey(h) {
   158  					grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h)
   159  					continue
   160  				}
   161  				// Handles "-bin" metadata in grpc, since grpc will do another base64
   162  				// encode before sending to server, we need to decode it first.
   163  				if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
   164  					b, err := decodeBinHeader(val)
   165  					if err != nil {
   166  						return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
   167  					}
   168  
   169  					val = string(b)
   170  				} else if !isValidGRPCMetadataTextValue(val) {
   171  					grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h)
   172  					continue
   173  				}
   174  				pairs = append(pairs, h, val)
   175  			}
   176  		}
   177  	}
   178  	if host := req.Header.Get(xForwardedHost); host != "" {
   179  		pairs = append(pairs, strings.ToLower(xForwardedHost), host)
   180  	} else if req.Host != "" {
   181  		pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
   182  	}
   183  
   184  	if addr := req.RemoteAddr; addr != "" {
   185  		if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
   186  			if fwd := req.Header.Get(xForwardedFor); fwd == "" {
   187  				pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
   188  			} else {
   189  				pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
   190  			}
   191  		}
   192  	}
   193  
   194  	if timeout != 0 {
   195  		//nolint:govet  // The context outlives this function
   196  		ctx, _ = context.WithTimeout(ctx, timeout)
   197  	}
   198  	if len(pairs) == 0 {
   199  		return ctx, nil, nil
   200  	}
   201  	md := metadata.Pairs(pairs...)
   202  	for _, mda := range mux.metadataAnnotators {
   203  		md = metadata.Join(md, mda(ctx, req))
   204  	}
   205  	return ctx, md, nil
   206  }
   207  
   208  // ServerMetadata consists of metadata sent from gRPC server.
   209  type ServerMetadata struct {
   210  	HeaderMD  metadata.MD
   211  	TrailerMD metadata.MD
   212  }
   213  
   214  type serverMetadataKey struct{}
   215  
   216  // NewServerMetadataContext creates a new context with ServerMetadata
   217  func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
   218  	if ctx == nil {
   219  		ctx = context.Background()
   220  	}
   221  	return context.WithValue(ctx, serverMetadataKey{}, md)
   222  }
   223  
   224  // ServerMetadataFromContext returns the ServerMetadata in ctx
   225  func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
   226  	if ctx == nil {
   227  		return md, false
   228  	}
   229  	md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
   230  	return
   231  }
   232  
   233  // ServerTransportStream implements grpc.ServerTransportStream.
   234  // It should only be used by the generated files to support grpc.SendHeader
   235  // outside of gRPC server use.
   236  type ServerTransportStream struct {
   237  	mu      sync.Mutex
   238  	header  metadata.MD
   239  	trailer metadata.MD
   240  }
   241  
   242  // Method returns the method for the stream.
   243  func (s *ServerTransportStream) Method() string {
   244  	return ""
   245  }
   246  
   247  // Header returns the header metadata of the stream.
   248  func (s *ServerTransportStream) Header() metadata.MD {
   249  	s.mu.Lock()
   250  	defer s.mu.Unlock()
   251  	return s.header.Copy()
   252  }
   253  
   254  // SetHeader sets the header metadata.
   255  func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
   256  	if md.Len() == 0 {
   257  		return nil
   258  	}
   259  
   260  	s.mu.Lock()
   261  	s.header = metadata.Join(s.header, md)
   262  	s.mu.Unlock()
   263  	return nil
   264  }
   265  
   266  // SendHeader sets the header metadata.
   267  func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
   268  	return s.SetHeader(md)
   269  }
   270  
   271  // Trailer returns the cached trailer metadata.
   272  func (s *ServerTransportStream) Trailer() metadata.MD {
   273  	s.mu.Lock()
   274  	defer s.mu.Unlock()
   275  	return s.trailer.Copy()
   276  }
   277  
   278  // SetTrailer sets the trailer metadata.
   279  func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
   280  	if md.Len() == 0 {
   281  		return nil
   282  	}
   283  
   284  	s.mu.Lock()
   285  	s.trailer = metadata.Join(s.trailer, md)
   286  	s.mu.Unlock()
   287  	return nil
   288  }
   289  
   290  func timeoutDecode(s string) (time.Duration, error) {
   291  	size := len(s)
   292  	if size < 2 {
   293  		return 0, fmt.Errorf("timeout string is too short: %q", s)
   294  	}
   295  	d, ok := timeoutUnitToDuration(s[size-1])
   296  	if !ok {
   297  		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
   298  	}
   299  	t, err := strconv.ParseInt(s[:size-1], 10, 64)
   300  	if err != nil {
   301  		return 0, err
   302  	}
   303  	return d * time.Duration(t), nil
   304  }
   305  
   306  func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
   307  	switch u {
   308  	case 'H':
   309  		return time.Hour, true
   310  	case 'M':
   311  		return time.Minute, true
   312  	case 'S':
   313  		return time.Second, true
   314  	case 'm':
   315  		return time.Millisecond, true
   316  	case 'u':
   317  		return time.Microsecond, true
   318  	case 'n':
   319  		return time.Nanosecond, true
   320  	default:
   321  		return
   322  	}
   323  }
   324  
   325  // isPermanentHTTPHeader checks whether hdr belongs to the list of
   326  // permanent request headers maintained by IANA.
   327  // http://www.iana.org/assignments/message-headers/message-headers.xml
   328  func isPermanentHTTPHeader(hdr string) bool {
   329  	switch hdr {
   330  	case
   331  		"Accept",
   332  		"Accept-Charset",
   333  		"Accept-Language",
   334  		"Accept-Ranges",
   335  		"Authorization",
   336  		"Cache-Control",
   337  		"Content-Type",
   338  		"Cookie",
   339  		"Date",
   340  		"Expect",
   341  		"From",
   342  		"Host",
   343  		"If-Match",
   344  		"If-Modified-Since",
   345  		"If-None-Match",
   346  		"If-Schedule-Tag-Match",
   347  		"If-Unmodified-Since",
   348  		"Max-Forwards",
   349  		"Origin",
   350  		"Pragma",
   351  		"Referer",
   352  		"User-Agent",
   353  		"Via",
   354  		"Warning":
   355  		return true
   356  	}
   357  	return false
   358  }
   359  
   360  // isMalformedHTTPHeader checks whether header belongs to the list of
   361  // "malformed headers" and would be rejected by the gRPC server.
   362  func isMalformedHTTPHeader(header string) bool {
   363  	_, isMalformed := malformedHTTPHeaders[strings.ToLower(header)]
   364  	return isMalformed
   365  }
   366  
   367  // RPCMethod returns the method string for the server context. The returned
   368  // string is in the format of "/package.service/method".
   369  func RPCMethod(ctx context.Context) (string, bool) {
   370  	m := ctx.Value(rpcMethodKey{})
   371  	if m == nil {
   372  		return "", false
   373  	}
   374  	ms, ok := m.(string)
   375  	if !ok {
   376  		return "", false
   377  	}
   378  	return ms, true
   379  }
   380  
   381  func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
   382  	return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
   383  }
   384  
   385  // HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists.
   386  // The format of the returned string is defined by the google.api.http path template type.
   387  func HTTPPathPattern(ctx context.Context) (string, bool) {
   388  	m := ctx.Value(httpPathPatternKey{})
   389  	if m == nil {
   390  		return "", false
   391  	}
   392  	ms, ok := m.(string)
   393  	if !ok {
   394  		return "", false
   395  	}
   396  	return ms, true
   397  }
   398  
   399  func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
   400  	return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
   401  }
   402  

View as plain text