...

Source file src/google.golang.org/grpc/internal/transport/http_util.go

Documentation: google.golang.org/grpc/internal/transport

     1  /*
     2   *
     3   * Copyright 2014 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package transport
    20  
    21  import (
    22  	"bufio"
    23  	"encoding/base64"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"math"
    28  	"net"
    29  	"net/http"
    30  	"net/url"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  	"unicode/utf8"
    36  
    37  	"golang.org/x/net/http2"
    38  	"golang.org/x/net/http2/hpack"
    39  	"google.golang.org/grpc/codes"
    40  )
    41  
const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame. newFramer
	// installs it as the framer's maximum read frame size.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the table size newFramer hands to the HPACK
	// decoder used for reading headers. For the protocol setting, see
	// https://httpwg.org/specs/rfc7540.html#SettingValues
	http2InitHeaderTableSize = 4096
)
    48  
var (
	// clientPreface is http2.ClientPreface as a byte slice.
	clientPreface = []byte(http2.ClientPreface)
	// http2ErrConvTab is the HTTP/2 error code to gRPC status code conversion
	// table.
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.Internal,
	}
	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
	// Status codes that are absent from the table have no entry here; how
	// callers treat them is not visible in this file.
	HTTPStatusConvTab = map[int]codes.Code{
		// 400 Bad Request - INTERNAL.
		http.StatusBadRequest: codes.Internal,
		// 401 Unauthorized  - UNAUTHENTICATED.
		http.StatusUnauthorized: codes.Unauthenticated,
		// 403 Forbidden - PERMISSION_DENIED.
		http.StatusForbidden: codes.PermissionDenied,
		// 404 Not Found - UNIMPLEMENTED.
		http.StatusNotFound: codes.Unimplemented,
		// 429 Too Many Requests - UNAVAILABLE.
		http.StatusTooManyRequests: codes.Unavailable,
		// 502 Bad Gateway - UNAVAILABLE.
		http.StatusBadGateway: codes.Unavailable,
		// 503 Service Unavailable - UNAVAILABLE.
		http.StatusServiceUnavailable: codes.Unavailable,
		// 504 Gateway timeout - UNAVAILABLE.
		http.StatusGatewayTimeout: codes.Unavailable,
	}
)
    87  
// grpcStatusDetailsBinHeader is the header key "grpc-status-details-bin".
// Because the key ends in "-bin", encodeMetadataHeader and
// decodeMetadataHeader treat its value as base64-encoded binary.
var grpcStatusDetailsBinHeader = "grpc-status-details-bin"
    89  
    90  // isReservedHeader checks whether hdr belongs to HTTP2 headers
    91  // reserved by gRPC protocol. Any other headers are classified as the
    92  // user-specified metadata.
    93  func isReservedHeader(hdr string) bool {
    94  	if hdr != "" && hdr[0] == ':' {
    95  		return true
    96  	}
    97  	switch hdr {
    98  	case "content-type",
    99  		"user-agent",
   100  		"grpc-message-type",
   101  		"grpc-encoding",
   102  		"grpc-message",
   103  		"grpc-status",
   104  		"grpc-timeout",
   105  		// Intentionally exclude grpc-previous-rpc-attempts and
   106  		// grpc-retry-pushback-ms, which are "reserved", but their API
   107  		// intentionally works via metadata.
   108  		"te":
   109  		return true
   110  	default:
   111  		return false
   112  	}
   113  }
   114  
   115  // isWhitelistedHeader checks whether hdr should be propagated into metadata
   116  // visible to users, even though it is classified as "reserved", above.
   117  func isWhitelistedHeader(hdr string) bool {
   118  	switch hdr {
   119  	case ":authority", "user-agent":
   120  		return true
   121  	default:
   122  		return false
   123  	}
   124  }
   125  
// binHdrSuffix is the key suffix that marks a metadata header as binary.
// encodeMetadataHeader and decodeMetadataHeader base64-encode and decode the
// values of keys that end with it.
const binHdrSuffix = "-bin"
   127  
   128  func encodeBinHeader(v []byte) string {
   129  	return base64.RawStdEncoding.EncodeToString(v)
   130  }
   131  
   132  func decodeBinHeader(v string) ([]byte, error) {
   133  	if len(v)%4 == 0 {
   134  		// Input was padded, or padding was not necessary.
   135  		return base64.StdEncoding.DecodeString(v)
   136  	}
   137  	return base64.RawStdEncoding.DecodeString(v)
   138  }
   139  
   140  func encodeMetadataHeader(k, v string) string {
   141  	if strings.HasSuffix(k, binHdrSuffix) {
   142  		return encodeBinHeader(([]byte)(v))
   143  	}
   144  	return v
   145  }
   146  
   147  func decodeMetadataHeader(k, v string) (string, error) {
   148  	if strings.HasSuffix(k, binHdrSuffix) {
   149  		b, err := decodeBinHeader(v)
   150  		return string(b), err
   151  	}
   152  	return v, nil
   153  }
   154  
   155  type timeoutUnit uint8
   156  
   157  const (
   158  	hour        timeoutUnit = 'H'
   159  	minute      timeoutUnit = 'M'
   160  	second      timeoutUnit = 'S'
   161  	millisecond timeoutUnit = 'm'
   162  	microsecond timeoutUnit = 'u'
   163  	nanosecond  timeoutUnit = 'n'
   164  )
   165  
   166  func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
   167  	switch u {
   168  	case hour:
   169  		return time.Hour, true
   170  	case minute:
   171  		return time.Minute, true
   172  	case second:
   173  		return time.Second, true
   174  	case millisecond:
   175  		return time.Millisecond, true
   176  	case microsecond:
   177  		return time.Microsecond, true
   178  	case nanosecond:
   179  		return time.Nanosecond, true
   180  	default:
   181  	}
   182  	return
   183  }
   184  
   185  func decodeTimeout(s string) (time.Duration, error) {
   186  	size := len(s)
   187  	if size < 2 {
   188  		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
   189  	}
   190  	if size > 9 {
   191  		// Spec allows for 8 digits plus the unit.
   192  		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
   193  	}
   194  	unit := timeoutUnit(s[size-1])
   195  	d, ok := timeoutUnitToDuration(unit)
   196  	if !ok {
   197  		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
   198  	}
   199  	t, err := strconv.ParseInt(s[:size-1], 10, 64)
   200  	if err != nil {
   201  		return 0, err
   202  	}
   203  	const maxHours = math.MaxInt64 / int64(time.Hour)
   204  	if d == time.Hour && t > maxHours {
   205  		// This timeout would overflow math.MaxInt64; clamp it.
   206  		return time.Duration(math.MaxInt64), nil
   207  	}
   208  	return d * time.Duration(t), nil
   209  }
   210  
// Byte values used by the grpc-message percent-encoding below. Bytes in the
// range [spaceByte, tildeByte] other than percentByte are passed through
// unescaped; percentByte introduces a two-hex-digit escape.
const (
	spaceByte   = ' '
	tildeByte   = '~'
	percentByte = '%'
)
   216  
   217  // encodeGrpcMessage is used to encode status code in header field
   218  // "grpc-message". It does percent encoding and also replaces invalid utf-8
   219  // characters with Unicode replacement character.
   220  //
   221  // It checks to see if each individual byte in msg is an allowable byte, and
   222  // then either percent encoding or passing it through. When percent encoding,
   223  // the byte is converted into hexadecimal notation with a '%' prepended.
   224  func encodeGrpcMessage(msg string) string {
   225  	if msg == "" {
   226  		return ""
   227  	}
   228  	lenMsg := len(msg)
   229  	for i := 0; i < lenMsg; i++ {
   230  		c := msg[i]
   231  		if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
   232  			return encodeGrpcMessageUnchecked(msg)
   233  		}
   234  	}
   235  	return msg
   236  }
   237  
   238  func encodeGrpcMessageUnchecked(msg string) string {
   239  	var sb strings.Builder
   240  	for len(msg) > 0 {
   241  		r, size := utf8.DecodeRuneInString(msg)
   242  		for _, b := range []byte(string(r)) {
   243  			if size > 1 {
   244  				// If size > 1, r is not ascii. Always do percent encoding.
   245  				fmt.Fprintf(&sb, "%%%02X", b)
   246  				continue
   247  			}
   248  
   249  			// The for loop is necessary even if size == 1. r could be
   250  			// utf8.RuneError.
   251  			//
   252  			// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
   253  			if b >= spaceByte && b <= tildeByte && b != percentByte {
   254  				sb.WriteByte(b)
   255  			} else {
   256  				fmt.Fprintf(&sb, "%%%02X", b)
   257  			}
   258  		}
   259  		msg = msg[size:]
   260  	}
   261  	return sb.String()
   262  }
   263  
   264  // decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
   265  func decodeGrpcMessage(msg string) string {
   266  	if msg == "" {
   267  		return ""
   268  	}
   269  	lenMsg := len(msg)
   270  	for i := 0; i < lenMsg; i++ {
   271  		if msg[i] == percentByte && i+2 < lenMsg {
   272  			return decodeGrpcMessageUnchecked(msg)
   273  		}
   274  	}
   275  	return msg
   276  }
   277  
   278  func decodeGrpcMessageUnchecked(msg string) string {
   279  	var sb strings.Builder
   280  	lenMsg := len(msg)
   281  	for i := 0; i < lenMsg; i++ {
   282  		c := msg[i]
   283  		if c == percentByte && i+2 < lenMsg {
   284  			parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
   285  			if err != nil {
   286  				sb.WriteByte(c)
   287  			} else {
   288  				sb.WriteByte(byte(parsed))
   289  				i += 2
   290  			}
   291  		} else {
   292  			sb.WriteByte(c)
   293  		}
   294  	}
   295  	return sb.String()
   296  }
   297  
   298  type bufWriter struct {
   299  	pool      *sync.Pool
   300  	buf       []byte
   301  	offset    int
   302  	batchSize int
   303  	conn      net.Conn
   304  	err       error
   305  }
   306  
   307  func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
   308  	w := &bufWriter{
   309  		batchSize: batchSize,
   310  		conn:      conn,
   311  		pool:      pool,
   312  	}
   313  	// this indicates that we should use non shared buf
   314  	if pool == nil {
   315  		w.buf = make([]byte, batchSize)
   316  	}
   317  	return w
   318  }
   319  
   320  func (w *bufWriter) Write(b []byte) (n int, err error) {
   321  	if w.err != nil {
   322  		return 0, w.err
   323  	}
   324  	if w.batchSize == 0 { // Buffer has been disabled.
   325  		n, err = w.conn.Write(b)
   326  		return n, toIOError(err)
   327  	}
   328  	if w.buf == nil {
   329  		b := w.pool.Get().(*[]byte)
   330  		w.buf = *b
   331  	}
   332  	for len(b) > 0 {
   333  		nn := copy(w.buf[w.offset:], b)
   334  		b = b[nn:]
   335  		w.offset += nn
   336  		n += nn
   337  		if w.offset >= w.batchSize {
   338  			err = w.flushKeepBuffer()
   339  		}
   340  	}
   341  	return n, err
   342  }
   343  
   344  func (w *bufWriter) Flush() error {
   345  	err := w.flushKeepBuffer()
   346  	// Only release the buffer if we are in a "shared" mode
   347  	if w.buf != nil && w.pool != nil {
   348  		b := w.buf
   349  		w.pool.Put(&b)
   350  		w.buf = nil
   351  	}
   352  	return err
   353  }
   354  
   355  func (w *bufWriter) flushKeepBuffer() error {
   356  	if w.err != nil {
   357  		return w.err
   358  	}
   359  	if w.offset == 0 {
   360  		return nil
   361  	}
   362  	_, w.err = w.conn.Write(w.buf[:w.offset])
   363  	w.err = toIOError(w.err)
   364  	w.offset = 0
   365  	return w.err
   366  }
   367  
   368  type ioError struct {
   369  	error
   370  }
   371  
   372  func (i ioError) Unwrap() error {
   373  	return i.error
   374  }
   375  
   376  func isIOError(err error) bool {
   377  	return errors.As(err, &ioError{})
   378  }
   379  
   380  func toIOError(err error) error {
   381  	if err == nil {
   382  		return nil
   383  	}
   384  	return ioError{error: err}
   385  }
   386  
// framer bundles an http2.Framer with the bufWriter that the framer writes
// through; newFramer builds both.
type framer struct {
	writer *bufWriter    // The writer handed to http2.NewFramer as its output.
	fr     *http2.Framer // HTTP/2 framer configured by newFramer.
}
   391  
   392  var writeBufferPoolMap map[int]*sync.Pool = make(map[int]*sync.Pool)
   393  var writeBufferMutex sync.Mutex
   394  
   395  func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer {
   396  	if writeBufferSize < 0 {
   397  		writeBufferSize = 0
   398  	}
   399  	var r io.Reader = conn
   400  	if readBufferSize > 0 {
   401  		r = bufio.NewReaderSize(r, readBufferSize)
   402  	}
   403  	var pool *sync.Pool
   404  	if sharedWriteBuffer {
   405  		pool = getWriteBufferPool(writeBufferSize)
   406  	}
   407  	w := newBufWriter(conn, writeBufferSize, pool)
   408  	f := &framer{
   409  		writer: w,
   410  		fr:     http2.NewFramer(w, r),
   411  	}
   412  	f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
   413  	// Opt-in to Frame reuse API on framer to reduce garbage.
   414  	// Frames aren't safe to read from after a subsequent call to ReadFrame.
   415  	f.fr.SetReuseFrames()
   416  	f.fr.MaxHeaderListSize = maxHeaderListSize
   417  	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
   418  	return f
   419  }
   420  
   421  func getWriteBufferPool(size int) *sync.Pool {
   422  	writeBufferMutex.Lock()
   423  	defer writeBufferMutex.Unlock()
   424  	pool, ok := writeBufferPoolMap[size]
   425  	if ok {
   426  		return pool
   427  	}
   428  	pool = &sync.Pool{
   429  		New: func() any {
   430  			b := make([]byte, size)
   431  			return &b
   432  		},
   433  	}
   434  	writeBufferPoolMap[size] = pool
   435  	return pool
   436  }
   437  
   438  // parseDialTarget returns the network and address to pass to dialer.
   439  func parseDialTarget(target string) (string, string) {
   440  	net := "tcp"
   441  	m1 := strings.Index(target, ":")
   442  	m2 := strings.Index(target, ":/")
   443  	// handle unix:addr which will fail with url.Parse
   444  	if m1 >= 0 && m2 < 0 {
   445  		if n := target[0:m1]; n == "unix" {
   446  			return n, target[m1+1:]
   447  		}
   448  	}
   449  	if m2 >= 0 {
   450  		t, err := url.Parse(target)
   451  		if err != nil {
   452  			return net, target
   453  		}
   454  		scheme := t.Scheme
   455  		addr := t.Path
   456  		if scheme == "unix" {
   457  			if addr == "" {
   458  				addr = t.Host
   459  			}
   460  			return scheme, addr
   461  		}
   462  	}
   463  	return net, target
   464  }
   465  

View as plain text