...

Source file src/github.com/grpc-ecosystem/grpc-gateway/runtime/context.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/runtime

     1  package runtime
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/textproto"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/metadata"
    17  	"google.golang.org/grpc/status"
    18  )
    19  
    20  // MetadataHeaderPrefix is the http prefix that represents custom metadata
    21  // parameters to or from a gRPC call.
    22  const MetadataHeaderPrefix = "Grpc-Metadata-"
    23  
    24  // MetadataPrefix is prepended to permanent HTTP header keys (as specified
    25  // by the IANA) when added to the gRPC context.
    26  const MetadataPrefix = "grpcgateway-"
    27  
    28  // MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
    29  // HTTP headers in a response handled by grpc-gateway
    30  const MetadataTrailerPrefix = "Grpc-Trailer-"
    31  
    32  const metadataGrpcTimeout = "Grpc-Timeout"
    33  const metadataHeaderBinarySuffix = "-Bin"
    34  
    35  const xForwardedFor = "X-Forwarded-For"
    36  const xForwardedHost = "X-Forwarded-Host"
    37  
    38  var (
    39  	// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
    40  	// header isn't present. If the value is 0 the sent `context` will not have a timeout.
    41  	DefaultContextTimeout = 0 * time.Second
    42  )
    43  
    44  func decodeBinHeader(v string) ([]byte, error) {
    45  	if len(v)%4 == 0 {
    46  		// Input was padded, or padding was not necessary.
    47  		return base64.StdEncoding.DecodeString(v)
    48  	}
    49  	return base64.RawStdEncoding.DecodeString(v)
    50  }
    51  
    52  /*
    53  AnnotateContext adds context information such as metadata from the request.
    54  
    55  At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
    56  except that the forwarded destination is not another HTTP service but rather
    57  a gRPC service.
    58  */
    59  func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, error) {
    60  	ctx, md, err := annotateContext(ctx, mux, req)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	if md == nil {
    65  		return ctx, nil
    66  	}
    67  
    68  	return metadata.NewOutgoingContext(ctx, md), nil
    69  }
    70  
    71  // AnnotateIncomingContext adds context information such as metadata from the request.
    72  // Attach metadata as incoming context.
    73  func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, error) {
    74  	ctx, md, err := annotateContext(ctx, mux, req)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	if md == nil {
    79  		return ctx, nil
    80  	}
    81  
    82  	return metadata.NewIncomingContext(ctx, md), nil
    83  }
    84  
    85  func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, metadata.MD, error) {
    86  	var pairs []string
    87  	timeout := DefaultContextTimeout
    88  	if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
    89  		var err error
    90  		timeout, err = timeoutDecode(tm)
    91  		if err != nil {
    92  			return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
    93  		}
    94  	}
    95  
    96  	for key, vals := range req.Header {
    97  		key = textproto.CanonicalMIMEHeaderKey(key)
    98  		for _, val := range vals {
    99  			// For backwards-compatibility, pass through 'authorization' header with no prefix.
   100  			if key == "Authorization" {
   101  				pairs = append(pairs, "authorization", val)
   102  			}
   103  			if h, ok := mux.incomingHeaderMatcher(key); ok {
   104  				// Handles "-bin" metadata in grpc, since grpc will do another base64
   105  				// encode before sending to server, we need to decode it first.
   106  				if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
   107  					b, err := decodeBinHeader(val)
   108  					if err != nil {
   109  						return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
   110  					}
   111  
   112  					val = string(b)
   113  				}
   114  				pairs = append(pairs, h, val)
   115  			}
   116  		}
   117  	}
   118  	if host := req.Header.Get(xForwardedHost); host != "" {
   119  		pairs = append(pairs, strings.ToLower(xForwardedHost), host)
   120  	} else if req.Host != "" {
   121  		pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
   122  	}
   123  
   124  	if addr := req.RemoteAddr; addr != "" {
   125  		if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
   126  			if fwd := req.Header.Get(xForwardedFor); fwd == "" {
   127  				pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
   128  			} else {
   129  				pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
   130  			}
   131  		}
   132  	}
   133  
   134  	if timeout != 0 {
   135  		ctx, _ = context.WithTimeout(ctx, timeout)
   136  	}
   137  	if len(pairs) == 0 {
   138  		return ctx, nil, nil
   139  	}
   140  	md := metadata.Pairs(pairs...)
   141  	for _, mda := range mux.metadataAnnotators {
   142  		md = metadata.Join(md, mda(ctx, req))
   143  	}
   144  	return ctx, md, nil
   145  }
   146  
   147  // ServerMetadata consists of metadata sent from gRPC server.
   148  type ServerMetadata struct {
   149  	HeaderMD  metadata.MD
   150  	TrailerMD metadata.MD
   151  }
   152  
   153  type serverMetadataKey struct{}
   154  
   155  // NewServerMetadataContext creates a new context with ServerMetadata
   156  func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
   157  	return context.WithValue(ctx, serverMetadataKey{}, md)
   158  }
   159  
   160  // ServerMetadataFromContext returns the ServerMetadata in ctx
   161  func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
   162  	md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
   163  	return
   164  }
   165  
   166  // ServerTransportStream implements grpc.ServerTransportStream.
   167  // It should only be used by the generated files to support grpc.SendHeader
   168  // outside of gRPC server use.
   169  type ServerTransportStream struct {
   170  	mu      sync.Mutex
   171  	header  metadata.MD
   172  	trailer metadata.MD
   173  }
   174  
   175  // Method returns the method for the stream.
   176  func (s *ServerTransportStream) Method() string {
   177  	return ""
   178  }
   179  
   180  // Header returns the header metadata of the stream.
   181  func (s *ServerTransportStream) Header() metadata.MD {
   182  	s.mu.Lock()
   183  	defer s.mu.Unlock()
   184  	return s.header.Copy()
   185  }
   186  
   187  // SetHeader sets the header metadata.
   188  func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
   189  	if md.Len() == 0 {
   190  		return nil
   191  	}
   192  
   193  	s.mu.Lock()
   194  	s.header = metadata.Join(s.header, md)
   195  	s.mu.Unlock()
   196  	return nil
   197  }
   198  
   199  // SendHeader sets the header metadata.
   200  func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
   201  	return s.SetHeader(md)
   202  }
   203  
   204  // Trailer returns the cached trailer metadata.
   205  func (s *ServerTransportStream) Trailer() metadata.MD {
   206  	s.mu.Lock()
   207  	defer s.mu.Unlock()
   208  	return s.trailer.Copy()
   209  }
   210  
   211  // SetTrailer sets the trailer metadata.
   212  func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
   213  	if md.Len() == 0 {
   214  		return nil
   215  	}
   216  
   217  	s.mu.Lock()
   218  	s.trailer = metadata.Join(s.trailer, md)
   219  	s.mu.Unlock()
   220  	return nil
   221  }
   222  
   223  func timeoutDecode(s string) (time.Duration, error) {
   224  	size := len(s)
   225  	if size < 2 {
   226  		return 0, fmt.Errorf("timeout string is too short: %q", s)
   227  	}
   228  	d, ok := timeoutUnitToDuration(s[size-1])
   229  	if !ok {
   230  		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
   231  	}
   232  	t, err := strconv.ParseInt(s[:size-1], 10, 64)
   233  	if err != nil {
   234  		return 0, err
   235  	}
   236  	return d * time.Duration(t), nil
   237  }
   238  
   239  func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
   240  	switch u {
   241  	case 'H':
   242  		return time.Hour, true
   243  	case 'M':
   244  		return time.Minute, true
   245  	case 'S':
   246  		return time.Second, true
   247  	case 'm':
   248  		return time.Millisecond, true
   249  	case 'u':
   250  		return time.Microsecond, true
   251  	case 'n':
   252  		return time.Nanosecond, true
   253  	default:
   254  	}
   255  	return
   256  }
   257  
   258  // isPermanentHTTPHeader checks whether hdr belongs to the list of
   259  // permanent request headers maintained by IANA.
   260  // http://www.iana.org/assignments/message-headers/message-headers.xml
   261  func isPermanentHTTPHeader(hdr string) bool {
   262  	switch hdr {
   263  	case
   264  		"Accept",
   265  		"Accept-Charset",
   266  		"Accept-Language",
   267  		"Accept-Ranges",
   268  		"Authorization",
   269  		"Cache-Control",
   270  		"Content-Type",
   271  		"Cookie",
   272  		"Date",
   273  		"Expect",
   274  		"From",
   275  		"Host",
   276  		"If-Match",
   277  		"If-Modified-Since",
   278  		"If-None-Match",
   279  		"If-Schedule-Tag-Match",
   280  		"If-Unmodified-Since",
   281  		"Max-Forwards",
   282  		"Origin",
   283  		"Pragma",
   284  		"Referer",
   285  		"User-Agent",
   286  		"Via",
   287  		"Warning":
   288  		return true
   289  	}
   290  	return false
   291  }
   292  

View as plain text