1
18
19 package transport
20
21 import (
22 "bufio"
23 "encoding/base64"
24 "errors"
25 "fmt"
26 "io"
27 "math"
28 "net"
29 "net/http"
30 "net/url"
31 "strconv"
32 "strings"
33 "sync"
34 "time"
35 "unicode/utf8"
36
37 "golang.org/x/net/http2"
38 "golang.org/x/net/http2/hpack"
39 "google.golang.org/grpc/codes"
40 )
41
const (
	// http2MaxFrameLen is the maximum frame size the framer built by
	// newFramer will read (it is passed to SetMaxReadFrameSize).
	// 1<<14 = 16384 is the initial SETTINGS_MAX_FRAME_SIZE from RFC 7540 §6.5.2.
	http2MaxFrameLen = 1 << 14
	// http2InitHeaderTableSize is the HPACK dynamic table size given to the
	// header decoder created in newFramer. 1<<12 = 4096 is the initial
	// SETTINGS_HEADER_TABLE_SIZE from RFC 7540 §6.5.2.
	http2InitHeaderTableSize = 1 << 12
)
48
49 var (
50 clientPreface = []byte(http2.ClientPreface)
51 http2ErrConvTab = map[http2.ErrCode]codes.Code{
52 http2.ErrCodeNo: codes.Internal,
53 http2.ErrCodeProtocol: codes.Internal,
54 http2.ErrCodeInternal: codes.Internal,
55 http2.ErrCodeFlowControl: codes.ResourceExhausted,
56 http2.ErrCodeSettingsTimeout: codes.Internal,
57 http2.ErrCodeStreamClosed: codes.Internal,
58 http2.ErrCodeFrameSize: codes.Internal,
59 http2.ErrCodeRefusedStream: codes.Unavailable,
60 http2.ErrCodeCancel: codes.Canceled,
61 http2.ErrCodeCompression: codes.Internal,
62 http2.ErrCodeConnect: codes.Internal,
63 http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
64 http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
65 http2.ErrCodeHTTP11Required: codes.Internal,
66 }
67
68 HTTPStatusConvTab = map[int]codes.Code{
69
70 http.StatusBadRequest: codes.Internal,
71
72 http.StatusUnauthorized: codes.Unauthenticated,
73
74 http.StatusForbidden: codes.PermissionDenied,
75
76 http.StatusNotFound: codes.Unimplemented,
77
78 http.StatusTooManyRequests: codes.Unavailable,
79
80 http.StatusBadGateway: codes.Unavailable,
81
82 http.StatusServiceUnavailable: codes.Unavailable,
83
84 http.StatusGatewayTimeout: codes.Unavailable,
85 }
86 )
87
// grpcStatusDetailsBinHeader is the metadata key that presumably carries
// serialized status details (confirm against callers). Because it ends in
// "-bin", encodeMetadataHeader/decodeMetadataHeader base64-encode its value.
var (
	grpcStatusDetailsBinHeader = "grpc-status-details-bin"
)
89
90
91
92
// isReservedHeader reports whether hdr is a header name reserved by the
// transport rather than user metadata: any HTTP/2 pseudo-header (a name
// starting with ':') plus a fixed list of gRPC/HTTP header names. The
// comparison is case-sensitive.
func isReservedHeader(hdr string) bool {
	if strings.HasPrefix(hdr, ":") {
		return true
	}
	switch hdr {
	case "content-type", "user-agent", "te",
		"grpc-message-type", "grpc-encoding", "grpc-message",
		"grpc-status", "grpc-timeout":
		return true
	}
	return false
}
114
115
116
// isWhitelistedHeader reports whether hdr is exactly ":authority" or
// "user-agent". Both names are also reserved per isReservedHeader; this
// narrower list presumably marks the reserved headers still surfaced to
// users (confirm against callers).
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
125
// binHdrSuffix marks metadata keys whose values are binary: such values are
// base64-encoded by encodeMetadataHeader and decoded by decodeMetadataHeader.
const (
	binHdrSuffix = "-bin"
)
127
// encodeBinHeader returns v encoded as standard base64 without '=' padding.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	out := make([]byte, enc.EncodedLen(len(v)))
	enc.Encode(out, v)
	return string(out)
}
131
// decodeBinHeader decodes standard base64 with or without '=' padding.
// Input whose length is a multiple of four is either padded or needs no
// padding, so it goes through the padded decoder. Anything else must be
// unpadded and goes through the raw decoder.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
139
140 func encodeMetadataHeader(k, v string) string {
141 if strings.HasSuffix(k, binHdrSuffix) {
142 return encodeBinHeader(([]byte)(v))
143 }
144 return v
145 }
146
147 func decodeMetadataHeader(k, v string) (string, error) {
148 if strings.HasSuffix(k, binHdrSuffix) {
149 b, err := decodeBinHeader(v)
150 return string(b), err
151 }
152 return v, nil
153 }
154
// timeoutUnit is the single trailing byte of a timeout string that selects
// its unit (see decodeTimeout and timeoutUnitToDuration).
type timeoutUnit uint8

// The recognized timeoutUnit values.
const (
	hour        = timeoutUnit('H')
	minute      = timeoutUnit('M')
	second      = timeoutUnit('S')
	millisecond = timeoutUnit('m')
	microsecond = timeoutUnit('u')
	nanosecond  = timeoutUnit('n')
)
165
166 func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
167 switch u {
168 case hour:
169 return time.Hour, true
170 case minute:
171 return time.Minute, true
172 case second:
173 return time.Second, true
174 case millisecond:
175 return time.Millisecond, true
176 case microsecond:
177 return time.Microsecond, true
178 case nanosecond:
179 return time.Nanosecond, true
180 default:
181 }
182 return
183 }
184
185 func decodeTimeout(s string) (time.Duration, error) {
186 size := len(s)
187 if size < 2 {
188 return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
189 }
190 if size > 9 {
191
192 return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
193 }
194 unit := timeoutUnit(s[size-1])
195 d, ok := timeoutUnitToDuration(unit)
196 if !ok {
197 return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
198 }
199 t, err := strconv.ParseInt(s[:size-1], 10, 64)
200 if err != nil {
201 return 0, err
202 }
203 const maxHours = math.MaxInt64 / int64(time.Hour)
204 if d == time.Hour && t > maxHours {
205
206 return time.Duration(math.MaxInt64), nil
207 }
208 return d * time.Duration(t), nil
209 }
210
// Byte bounds for grpc-message percent-encoding: bytes from spaceByte through
// tildeByte pass through verbatim, except percentByte, which introduces an
// escape sequence.
const (
	percentByte = '%'
	spaceByte   = ' '
	tildeByte   = '~'
)
216
217
218
219
220
221
222
223
224 func encodeGrpcMessage(msg string) string {
225 if msg == "" {
226 return ""
227 }
228 lenMsg := len(msg)
229 for i := 0; i < lenMsg; i++ {
230 c := msg[i]
231 if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
232 return encodeGrpcMessageUnchecked(msg)
233 }
234 }
235 return msg
236 }
237
238 func encodeGrpcMessageUnchecked(msg string) string {
239 var sb strings.Builder
240 for len(msg) > 0 {
241 r, size := utf8.DecodeRuneInString(msg)
242 for _, b := range []byte(string(r)) {
243 if size > 1 {
244
245 fmt.Fprintf(&sb, "%%%02X", b)
246 continue
247 }
248
249
250
251
252
253 if b >= spaceByte && b <= tildeByte && b != percentByte {
254 sb.WriteByte(b)
255 } else {
256 fmt.Fprintf(&sb, "%%%02X", b)
257 }
258 }
259 msg = msg[size:]
260 }
261 return sb.String()
262 }
263
264
265 func decodeGrpcMessage(msg string) string {
266 if msg == "" {
267 return ""
268 }
269 lenMsg := len(msg)
270 for i := 0; i < lenMsg; i++ {
271 if msg[i] == percentByte && i+2 < lenMsg {
272 return decodeGrpcMessageUnchecked(msg)
273 }
274 }
275 return msg
276 }
277
278 func decodeGrpcMessageUnchecked(msg string) string {
279 var sb strings.Builder
280 lenMsg := len(msg)
281 for i := 0; i < lenMsg; i++ {
282 c := msg[i]
283 if c == percentByte && i+2 < lenMsg {
284 parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
285 if err != nil {
286 sb.WriteByte(c)
287 } else {
288 sb.WriteByte(byte(parsed))
289 i += 2
290 }
291 } else {
292 sb.WriteByte(c)
293 }
294 }
295 return sb.String()
296 }
297
// bufWriter batches writes to conn. Data passed to Write is copied into buf
// and sent to conn whenever batchSize bytes are pending, or when Flush is
// called. After a flush to conn fails, the error is kept in err and returned
// by every later Write and flush.
type bufWriter struct {
	// pool, when non-nil, supplies buf; Flush hands buf back to it.
	pool *sync.Pool
	// buf holds pending bytes. It is nil when pool is set and no buffer is
	// currently checked out.
	buf []byte
	// offset is the number of pending bytes in buf. batchSize is the flush
	// threshold; 0 disables buffering so writes go straight to conn.
	offset, batchSize int
	// conn receives flushed bytes.
	conn net.Conn
	// err is the sticky error recorded by flushKeepBuffer.
	err error
}
306
307 func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
308 w := &bufWriter{
309 batchSize: batchSize,
310 conn: conn,
311 pool: pool,
312 }
313
314 if pool == nil {
315 w.buf = make([]byte, batchSize)
316 }
317 return w
318 }
319
320 func (w *bufWriter) Write(b []byte) (n int, err error) {
321 if w.err != nil {
322 return 0, w.err
323 }
324 if w.batchSize == 0 {
325 n, err = w.conn.Write(b)
326 return n, toIOError(err)
327 }
328 if w.buf == nil {
329 b := w.pool.Get().(*[]byte)
330 w.buf = *b
331 }
332 for len(b) > 0 {
333 nn := copy(w.buf[w.offset:], b)
334 b = b[nn:]
335 w.offset += nn
336 n += nn
337 if w.offset >= w.batchSize {
338 err = w.flushKeepBuffer()
339 }
340 }
341 return n, err
342 }
343
344 func (w *bufWriter) Flush() error {
345 err := w.flushKeepBuffer()
346
347 if w.buf != nil && w.pool != nil {
348 b := w.buf
349 w.pool.Put(&b)
350 w.buf = nil
351 }
352 return err
353 }
354
355 func (w *bufWriter) flushKeepBuffer() error {
356 if w.err != nil {
357 return w.err
358 }
359 if w.offset == 0 {
360 return nil
361 }
362 _, w.err = w.conn.Write(w.buf[:w.offset])
363 w.err = toIOError(w.err)
364 w.offset = 0
365 return w.err
366 }
367
// ioError wraps an error returned by I/O on the underlying connection
// (bufWriter wraps conn.Write failures in it). The original error stays
// reachable through Unwrap.
type ioError struct {
	error
}

// Unwrap returns the wrapped error.
func (e ioError) Unwrap() error {
	return e.error
}

// isIOError reports whether err or any error in its Unwrap chain is an ioError.
func isIOError(err error) bool {
	var target ioError
	return errors.As(err, &target)
}

// toIOError wraps err in an ioError. A nil err is returned as nil so callers
// can wrap a possibly-nil error unconditionally.
func toIOError(err error) error {
	if err != nil {
		return ioError{error: err}
	}
	return nil
}
386
387 type framer struct {
388 writer *bufWriter
389 fr *http2.Framer
390 }
391
392 var writeBufferPoolMap map[int]*sync.Pool = make(map[int]*sync.Pool)
393 var writeBufferMutex sync.Mutex
394
395 func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer {
396 if writeBufferSize < 0 {
397 writeBufferSize = 0
398 }
399 var r io.Reader = conn
400 if readBufferSize > 0 {
401 r = bufio.NewReaderSize(r, readBufferSize)
402 }
403 var pool *sync.Pool
404 if sharedWriteBuffer {
405 pool = getWriteBufferPool(writeBufferSize)
406 }
407 w := newBufWriter(conn, writeBufferSize, pool)
408 f := &framer{
409 writer: w,
410 fr: http2.NewFramer(w, r),
411 }
412 f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
413
414
415 f.fr.SetReuseFrames()
416 f.fr.MaxHeaderListSize = maxHeaderListSize
417 f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
418 return f
419 }
420
421 func getWriteBufferPool(size int) *sync.Pool {
422 writeBufferMutex.Lock()
423 defer writeBufferMutex.Unlock()
424 pool, ok := writeBufferPoolMap[size]
425 if ok {
426 return pool
427 }
428 pool = &sync.Pool{
429 New: func() any {
430 b := make([]byte, size)
431 return &b
432 },
433 }
434 writeBufferPoolMap[size] = pool
435 return pool
436 }
437
438
// parseDialTarget returns the network and address to dial for target.
//
// Unix-socket targets are recognised in three forms:
//   - "unix:addr" with no ":/" anywhere in target: split by hand, because
//     url.Parse would treat addr as opaque;
//   - "unix:/path" and "unix:///path": the URL path is used;
//   - "unix://host": the URL host is used when the path is empty.
//
// Everything else, including targets url.Parse rejects, yields ("tcp", target).
func parseDialTarget(target string) (string, string) {
	const defaultNetwork = "tcp"
	if !strings.Contains(target, ":/") {
		if scheme, rest, found := strings.Cut(target, ":"); found && scheme == "unix" {
			return scheme, rest
		}
		return defaultNetwork, target
	}
	u, err := url.Parse(target)
	if err != nil || u.Scheme != "unix" {
		return defaultNetwork, target
	}
	if u.Path != "" {
		return u.Scheme, u.Path
	}
	return u.Scheme, u.Host
}
465
View as plain text