1
18
19 package grpc
20
21 import (
22 "bytes"
23 "compress/gzip"
24 "context"
25 "encoding/binary"
26 "fmt"
27 "io"
28 "math"
29 "strings"
30 "sync"
31 "time"
32
33 "google.golang.org/grpc/codes"
34 "google.golang.org/grpc/credentials"
35 "google.golang.org/grpc/encoding"
36 "google.golang.org/grpc/encoding/proto"
37 "google.golang.org/grpc/internal/transport"
38 "google.golang.org/grpc/metadata"
39 "google.golang.org/grpc/peer"
40 "google.golang.org/grpc/stats"
41 "google.golang.org/grpc/status"
42 )
43
44
45
46
// Compressor defines the interface used to compress a message payload.
type Compressor interface {
	// Do compresses p and writes the compressed bytes into w.
	Do(w io.Writer, p []byte) error
	// Type returns the name of the compression algorithm the Compressor
	// implements (the gzip implementation in this file returns "gzip").
	Type() string
}
53
// gzipCompressor is the gzip-backed implementation of Compressor.
type gzipCompressor struct {
	// pool holds *gzip.Writer values that Do borrows and returns, so a
	// writer can be reused across calls instead of being rebuilt each time.
	pool sync.Pool
}
57
58
59
60
61 func NewGZIPCompressor() Compressor {
62 c, _ := NewGZIPCompressorWithLevel(gzip.DefaultCompression)
63 return c
64 }
65
66
67
68
69
70
71
72 func NewGZIPCompressorWithLevel(level int) (Compressor, error) {
73 if level < gzip.DefaultCompression || level > gzip.BestCompression {
74 return nil, fmt.Errorf("grpc: invalid compression level: %d", level)
75 }
76 return &gzipCompressor{
77 pool: sync.Pool{
78 New: func() any {
79 w, err := gzip.NewWriterLevel(io.Discard, level)
80 if err != nil {
81 panic(err)
82 }
83 return w
84 },
85 },
86 }, nil
87 }
88
89 func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
90 z := c.pool.Get().(*gzip.Writer)
91 defer c.pool.Put(z)
92 z.Reset(w)
93 if _, err := z.Write(p); err != nil {
94 return err
95 }
96 return z.Close()
97 }
98
99 func (c *gzipCompressor) Type() string {
100 return "gzip"
101 }
102
103
104
105
// Decompressor defines the interface used to decompress a message payload.
type Decompressor interface {
	// Do reads compressed data from r and returns the decompressed bytes.
	Do(r io.Reader) ([]byte, error)
	// Type returns the name of the compression algorithm the Decompressor
	// implements (the gzip implementation in this file returns "gzip").
	Type() string
}
112
// gzipDecompressor is the gzip-backed implementation of Decompressor.
type gzipDecompressor struct {
	// pool holds *gzip.Reader values that Do returns after use. It has no
	// New function, so Get yields nil until a reader has been Put.
	pool sync.Pool
}
116
117
118
119
120 func NewGZIPDecompressor() Decompressor {
121 return &gzipDecompressor{}
122 }
123
124 func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
125 var z *gzip.Reader
126 switch maybeZ := d.pool.Get().(type) {
127 case nil:
128 newZ, err := gzip.NewReader(r)
129 if err != nil {
130 return nil, err
131 }
132 z = newZ
133 case *gzip.Reader:
134 z = maybeZ
135 if err := z.Reset(r); err != nil {
136 d.pool.Put(z)
137 return nil, err
138 }
139 }
140
141 defer func() {
142 z.Close()
143 d.pool.Put(z)
144 }()
145 return io.ReadAll(z)
146 }
147
148 func (d *gzipDecompressor) Type() string {
149 return "gzip"
150 }
151
152
// callInfo holds the per-call settings that CallOptions write to in their
// before hooks.
type callInfo struct {
	// compressorType is set by CompressorCallOption.
	compressorType string
	// failFast is set by FailFastCallOption; defaultCallInfo starts it at true.
	failFast bool
	// maxReceiveMessageSize is set by MaxRecvMsgSizeCallOption; nil until then.
	maxReceiveMessageSize *int
	// maxSendMessageSize is set by MaxSendMsgSizeCallOption; nil until then.
	maxSendMessageSize *int
	// creds is set by PerRPCCredsCallOption.
	creds credentials.PerRPCCredentials
	// contentSubtype is set by ContentSubtypeCallOption, or derived from the
	// codec's name by setCallInfoCodec when left empty.
	contentSubtype string
	// codec is set by ForceCodecCallOption or CustomCodecCallOption, or
	// resolved from contentSubtype by setCallInfoCodec.
	codec baseCodec
	// maxRetryRPCBufferSize is set by MaxRetryRPCBufferSizeCallOption;
	// defaultCallInfo starts it at 256 KiB.
	maxRetryRPCBufferSize int
	// onFinish collects the callbacks appended by OnFinishCallOption.
	onFinish []func(err error)
}
164
165 func defaultCallInfo() *callInfo {
166 return &callInfo{
167 failFast: true,
168 maxRetryRPCBufferSize: 256 * 1024,
169 }
170 }
171
172
173
// CallOption configures a call or extracts information from it. Both
// methods are unexported, so only types in this package (or types embedding
// one, such as EmptyCallOption) can satisfy the interface.
type CallOption interface {
	// before lets the option adjust the callInfo. The implementations in
	// this file either set a callInfo field or do nothing. A non-nil error
	// reports that the option could not be applied.
	// NOTE(review): presumably invoked before the RPC starts; the call
	// sites are outside this file.
	before(*callInfo) error

	// after hands the option the callInfo and the csAttempt. The
	// implementations in this file use it to copy header, trailer or peer
	// data out of the attempt's stream.
	// NOTE(review): presumably invoked once the call has completed; the
	// call sites are outside this file.
	after(*callInfo, *csAttempt)
}
183
184
185
186
187 type EmptyCallOption struct{}
188
189 func (EmptyCallOption) before(*callInfo) error { return nil }
190 func (EmptyCallOption) after(*callInfo, *csAttempt) {}
191
192
193
194
195
196 func StaticMethod() CallOption {
197 return StaticMethodCallOption{}
198 }
199
200
201
// StaticMethodCallOption is the CallOption returned by StaticMethod. It
// carries no data and gets its no-op before/after hooks from the embedded
// EmptyCallOption.
type StaticMethodCallOption struct {
	EmptyCallOption
}
205
206
207
208 func Header(md *metadata.MD) CallOption {
209 return HeaderCallOption{HeaderAddr: md}
210 }
211
212
213
214
215
216
217
218
219 type HeaderCallOption struct {
220 HeaderAddr *metadata.MD
221 }
222
223 func (o HeaderCallOption) before(c *callInfo) error { return nil }
224 func (o HeaderCallOption) after(c *callInfo, attempt *csAttempt) {
225 *o.HeaderAddr, _ = attempt.s.Header()
226 }
227
228
229
230 func Trailer(md *metadata.MD) CallOption {
231 return TrailerCallOption{TrailerAddr: md}
232 }
233
234
235
236
237
238
239
240
241 type TrailerCallOption struct {
242 TrailerAddr *metadata.MD
243 }
244
245 func (o TrailerCallOption) before(c *callInfo) error { return nil }
246 func (o TrailerCallOption) after(c *callInfo, attempt *csAttempt) {
247 *o.TrailerAddr = attempt.s.Trailer()
248 }
249
250
251
252 func Peer(p *peer.Peer) CallOption {
253 return PeerCallOption{PeerAddr: p}
254 }
255
256
257
258
259
260
261
262
263 type PeerCallOption struct {
264 PeerAddr *peer.Peer
265 }
266
267 func (o PeerCallOption) before(c *callInfo) error { return nil }
268 func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
269 if x, ok := peer.FromContext(attempt.s.Context()); ok {
270 *o.PeerAddr = *x
271 }
272 }
273
274
275
276
277
278
279
280
281
282
283
284
285 func WaitForReady(waitForReady bool) CallOption {
286 return FailFastCallOption{FailFast: !waitForReady}
287 }
288
289
290
291
292 func FailFast(failFast bool) CallOption {
293 return FailFastCallOption{FailFast: failFast}
294 }
295
296
297
298
299
300
301
302
303 type FailFastCallOption struct {
304 FailFast bool
305 }
306
307 func (o FailFastCallOption) before(c *callInfo) error {
308 c.failFast = o.FailFast
309 return nil
310 }
311 func (o FailFastCallOption) after(c *callInfo, attempt *csAttempt) {}
312
313
314
315
316
317
318
319
320
321
322
323
324 func OnFinish(onFinish func(err error)) CallOption {
325 return OnFinishCallOption{
326 OnFinish: onFinish,
327 }
328 }
329
330
331
332
333
334
335
336
337 type OnFinishCallOption struct {
338 OnFinish func(error)
339 }
340
341 func (o OnFinishCallOption) before(c *callInfo) error {
342 c.onFinish = append(c.onFinish, o.OnFinish)
343 return nil
344 }
345
346 func (o OnFinishCallOption) after(c *callInfo, attempt *csAttempt) {}
347
348
349
350
351 func MaxCallRecvMsgSize(bytes int) CallOption {
352 return MaxRecvMsgSizeCallOption{MaxRecvMsgSize: bytes}
353 }
354
355
356
357
358
359
360
361
362 type MaxRecvMsgSizeCallOption struct {
363 MaxRecvMsgSize int
364 }
365
366 func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
367 c.maxReceiveMessageSize = &o.MaxRecvMsgSize
368 return nil
369 }
370 func (o MaxRecvMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
371
372
373
374
375 func MaxCallSendMsgSize(bytes int) CallOption {
376 return MaxSendMsgSizeCallOption{MaxSendMsgSize: bytes}
377 }
378
379
380
381
382
383
384
385
386 type MaxSendMsgSizeCallOption struct {
387 MaxSendMsgSize int
388 }
389
390 func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
391 c.maxSendMessageSize = &o.MaxSendMsgSize
392 return nil
393 }
394 func (o MaxSendMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
395
396
397
398 func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
399 return PerRPCCredsCallOption{Creds: creds}
400 }
401
402
403
404
405
406
407
408
409 type PerRPCCredsCallOption struct {
410 Creds credentials.PerRPCCredentials
411 }
412
413 func (o PerRPCCredsCallOption) before(c *callInfo) error {
414 c.creds = o.Creds
415 return nil
416 }
417 func (o PerRPCCredsCallOption) after(c *callInfo, attempt *csAttempt) {}
418
419
420
421
422
423
424
425
426
427 func UseCompressor(name string) CallOption {
428 return CompressorCallOption{CompressorType: name}
429 }
430
431
432
433
434
435
436
437 type CompressorCallOption struct {
438 CompressorType string
439 }
440
441 func (o CompressorCallOption) before(c *callInfo) error {
442 c.compressorType = o.CompressorType
443 return nil
444 }
445 func (o CompressorCallOption) after(c *callInfo, attempt *csAttempt) {}
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463 func CallContentSubtype(contentSubtype string) CallOption {
464 return ContentSubtypeCallOption{ContentSubtype: strings.ToLower(contentSubtype)}
465 }
466
467
468
469
470
471
472
473
474 type ContentSubtypeCallOption struct {
475 ContentSubtype string
476 }
477
478 func (o ContentSubtypeCallOption) before(c *callInfo) error {
479 c.contentSubtype = o.ContentSubtype
480 return nil
481 }
482 func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502 func ForceCodec(codec encoding.Codec) CallOption {
503 return ForceCodecCallOption{Codec: codec}
504 }
505
506
507
508
509
510
511
512
513 type ForceCodecCallOption struct {
514 Codec encoding.Codec
515 }
516
517 func (o ForceCodecCallOption) before(c *callInfo) error {
518 c.codec = o.Codec
519 return nil
520 }
521 func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
522
523
524
525
526
527 func CallCustomCodec(codec Codec) CallOption {
528 return CustomCodecCallOption{Codec: codec}
529 }
530
531
532
533
534
535
536
537
538 type CustomCodecCallOption struct {
539 Codec Codec
540 }
541
542 func (o CustomCodecCallOption) before(c *callInfo) error {
543 c.codec = o.Codec
544 return nil
545 }
546 func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
547
548
549
550
551
552
553
554
555 func MaxRetryRPCBufferSize(bytes int) CallOption {
556 return MaxRetryRPCBufferSizeCallOption{bytes}
557 }
558
559
560
561
562
563
564
565
566 type MaxRetryRPCBufferSizeCallOption struct {
567 MaxRetryRPCBufferSize int
568 }
569
570 func (o MaxRetryRPCBufferSizeCallOption) before(c *callInfo) error {
571 c.maxRetryRPCBufferSize = o.MaxRetryRPCBufferSize
572 return nil
573 }
574 func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
575
576
// payloadFormat is the value of the first byte of a message header. It
// tells whether the message body that follows is compressed.
type payloadFormat uint8

const (
	// compressionNone marks an uncompressed message body.
	compressionNone payloadFormat = 0
	// compressionMade marks a compressed message body.
	compressionMade payloadFormat = 1
)
583
584
// parser reads complete messages (header plus body) from an underlying
// reader.
type parser struct {
	// r is the reader recvMsg pulls the header and the body from.
	// NOTE(review): recvMsg issues exactly one Read for the header and one
	// for the body and does not loop on short reads, so r appears to be
	// expected to fill the whole slice per call — confirm against the
	// reader implementation supplied by callers.
	r io.Reader

	// header is scratch space for the 5-byte message header: one byte of
	// payload format followed by a 4-byte big-endian body length.
	header [5]byte

	// recvBufferPool supplies the buffers that message bodies are read into.
	recvBufferPool SharedBufferPool
}
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613 func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
614 if _, err := p.r.Read(p.header[:]); err != nil {
615 return 0, nil, err
616 }
617
618 pf = payloadFormat(p.header[0])
619 length := binary.BigEndian.Uint32(p.header[1:])
620
621 if length == 0 {
622 return pf, nil, nil
623 }
624 if int64(length) > int64(maxInt) {
625 return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
626 }
627 if int(length) > maxReceiveMessageSize {
628 return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
629 }
630 msg = p.recvBufferPool.Get(int(length))
631 if _, err := p.r.Read(msg); err != nil {
632 if err == io.EOF {
633 err = io.ErrUnexpectedEOF
634 }
635 return 0, nil, err
636 }
637 return pf, msg, nil
638 }
639
640
641
642
643 func encode(c baseCodec, msg any) ([]byte, error) {
644 if msg == nil {
645 return nil, nil
646 }
647 b, err := c.Marshal(msg)
648 if err != nil {
649 return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
650 }
651 if uint(len(b)) > math.MaxUint32 {
652 return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
653 }
654 return b, nil
655 }
656
657
658
659
660
661
662 func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
663 if compressor == nil && cp == nil {
664 return nil, nil
665 }
666 if len(in) == 0 {
667 return nil, nil
668 }
669 wrapErr := func(err error) error {
670 return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
671 }
672 cbuf := &bytes.Buffer{}
673 if compressor != nil {
674 z, err := compressor.Compress(cbuf)
675 if err != nil {
676 return nil, wrapErr(err)
677 }
678 if _, err := z.Write(in); err != nil {
679 return nil, wrapErr(err)
680 }
681 if err := z.Close(); err != nil {
682 return nil, wrapErr(err)
683 }
684 } else {
685 if err := cp.Do(cbuf, in); err != nil {
686 return nil, wrapErr(err)
687 }
688 }
689 return cbuf.Bytes(), nil
690 }
691
// Sizes, in bytes, of the parts of a message header.
const (
	// payloadLen is the size of the payload-format flag.
	payloadLen = 1
	// sizeLen is the size of the big-endian body length field.
	sizeLen = 4
	// headerLen is the total header size.
	headerLen = payloadLen + sizeLen
)
697
698
699
700 func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
701 hdr = make([]byte, headerLen)
702 if compData != nil {
703 hdr[0] = byte(compressionMade)
704 data = compData
705 } else {
706 hdr[0] = byte(compressionNone)
707 }
708
709
710 binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
711 return hdr, data
712 }
713
714 func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.OutPayload {
715 return &stats.OutPayload{
716 Client: client,
717 Payload: msg,
718 Data: data,
719 Length: len(data),
720 WireLength: len(payload) + headerLen,
721 CompressedLength: len(payload),
722 SentTime: t,
723 }
724 }
725
726 func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
727 switch pf {
728 case compressionNone:
729 case compressionMade:
730 if recvCompress == "" || recvCompress == encoding.Identity {
731 return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
732 }
733 if !haveCompressor {
734 return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
735 }
736 default:
737 return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
738 }
739 return nil
740 }
741
// payloadInfo carries details about a received message out of
// recvAndDecompress.
type payloadInfo struct {
	// compressedLength is the length of the message body as received,
	// before any decompression.
	compressedLength int
	// uncompressedBytes is the message body after decompression, or the
	// received body itself when the message was not compressed.
	uncompressedBytes []byte
}
746
747
748
749
750
751 func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
752 ) (uncompressedBuf []byte, cancel func(), err error) {
753 pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
754 if err != nil {
755 return nil, nil, err
756 }
757
758 if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
759 return nil, nil, st.Err()
760 }
761
762 var size int
763 if pf == compressionMade {
764
765
766 if dc != nil {
767 uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
768 size = len(uncompressedBuf)
769 } else {
770 uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
771 }
772 if err != nil {
773 return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
774 }
775 if size > maxReceiveMessageSize {
776
777
778 return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
779 }
780 } else {
781 uncompressedBuf = compressedBuf
782 }
783
784 if payInfo != nil {
785 payInfo.compressedLength = len(compressedBuf)
786 payInfo.uncompressedBytes = uncompressedBuf
787
788 cancel = func() {}
789 } else {
790 cancel = func() {
791 p.recvBufferPool.Put(&compressedBuf)
792 }
793 }
794
795 return uncompressedBuf, cancel, nil
796 }
797
798
799
800 func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) {
801 dcReader, err := compressor.Decompress(bytes.NewReader(d))
802 if err != nil {
803 return nil, 0, err
804 }
805 if sizer, ok := compressor.(interface {
806 DecompressedSize(compressedBytes []byte) int
807 }); ok {
808 if size := sizer.DecompressedSize(d); size >= 0 {
809 if size > maxReceiveMessageSize {
810 return nil, size, nil
811 }
812
813
814
815
816
817
818 buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
819 bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
820 return buf.Bytes(), int(bytesRead), err
821 }
822 }
823
824
825 d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
826 return d, len(d), err
827 }
828
829
830
831
832 func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
833 buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
834 if err != nil {
835 return err
836 }
837 defer cancel()
838
839 if err := c.Unmarshal(buf, m); err != nil {
840 return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
841 }
842 return nil
843 }
844
845
// rpcInfo is the per-RPC information stored in a context by
// newContextWithRPCInfo and retrieved with rpcInfoFromContext.
type rpcInfo struct {
	// failfast is the fail-fast flag passed to newContextWithRPCInfo.
	failfast bool
	// preloaderInfo bundles the codec and compressors passed to
	// newContextWithRPCInfo.
	// NOTE(review): the name suggests it is consumed when pre-encoding
	// messages; the consumer is outside this file.
	preloaderInfo *compressorInfo
}
850
851
852
853
854
855
// compressorInfo groups the codec and the two kinds of compressor that
// newContextWithRPCInfo stores in the context.
type compressorInfo struct {
	// codec is the codec passed to newContextWithRPCInfo.
	codec baseCodec
	// cp is the Compressor passed to newContextWithRPCInfo.
	cp Compressor
	// comp is the encoding.Compressor passed to newContextWithRPCInfo.
	comp encoding.Compressor
}
861
// rpcInfoContextKey is the context key under which newContextWithRPCInfo
// stores the *rpcInfo. Being an unexported type, it cannot collide with
// keys defined by other packages.
type rpcInfoContextKey struct{}
863
864 func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context {
865 return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{
866 failfast: failfast,
867 preloaderInfo: &compressorInfo{
868 codec: codec,
869 cp: cp,
870 comp: comp,
871 },
872 })
873 }
874
875 func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
876 s, ok = ctx.Value(rpcInfoContextKey{}).(*rpcInfo)
877 return
878 }
879
880
881
882
883
884 func Code(err error) codes.Code {
885 return status.Code(err)
886 }
887
888
889
890
891
892 func ErrorDesc(err error) string {
893 return status.Convert(err).Message()
894 }
895
896
897
898
899
900 func Errorf(c codes.Code, format string, a ...any) error {
901 return status.Errorf(c, format, a...)
902 }
903
// errContextCanceled is the status error (codes.Canceled) that toRPCErr
// returns for context.Canceled.
var errContextCanceled = status.Error(codes.Canceled, context.Canceled.Error())

// errContextDeadline is the status error (codes.DeadlineExceeded) that
// toRPCErr returns for context.DeadlineExceeded.
var errContextDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
906
907
908 func toRPCErr(err error) error {
909 switch err {
910 case nil, io.EOF:
911 return err
912 case context.DeadlineExceeded:
913 return errContextDeadline
914 case context.Canceled:
915 return errContextCanceled
916 case io.ErrUnexpectedEOF:
917 return status.Error(codes.Internal, err.Error())
918 }
919
920 switch e := err.(type) {
921 case transport.ConnectionError:
922 return status.Error(codes.Unavailable, e.Desc)
923 case *transport.NewStreamError:
924 return toRPCErr(e.Err)
925 }
926
927 if _, ok := status.FromError(err); ok {
928 return err
929 }
930
931 return status.Error(codes.Unknown, err.Error())
932 }
933
934
935 func setCallInfoCodec(c *callInfo) error {
936 if c.codec != nil {
937
938
939 if c.contentSubtype == "" {
940
941
942
943
944 if ec, ok := c.codec.(encoding.Codec); ok {
945 c.contentSubtype = strings.ToLower(ec.Name())
946 }
947 }
948 return nil
949 }
950
951 if c.contentSubtype == "" {
952
953 c.codec = encoding.GetCodec(proto.Name)
954 return nil
955 }
956
957
958 c.codec = encoding.GetCodec(c.contentSubtype)
959 if c.codec == nil {
960 return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
961 }
962 return nil
963 }
964
965
966
967
968
969
970
971
// The SupportPackageIsVersion constants are exported markers that are all
// true.
// NOTE(review): presumably referenced by generated code so that it fails to
// compile against a grpc package lacking the required support level —
// confirm; nothing in this file uses them.
const (
	SupportPackageIsVersion3 = true
	SupportPackageIsVersion4 = true
	SupportPackageIsVersion5 = true
	SupportPackageIsVersion6 = true
	SupportPackageIsVersion7 = true
	SupportPackageIsVersion8 = true
	SupportPackageIsVersion9 = true
)
981
// grpcUA is the string "grpc-go/" followed by the package Version.
// NOTE(review): the name suggests it is used as the user-agent token; its
// users are outside this file.
const grpcUA = "grpc-go/" + Version
983
View as plain text