1
18
19 package grpc
20
21 import (
22 "context"
23 "errors"
24 "io"
25 "math"
26 "strconv"
27 "sync"
28 "time"
29
30 "google.golang.org/grpc/balancer"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/encoding"
33 "google.golang.org/grpc/internal"
34 "google.golang.org/grpc/internal/balancerload"
35 "google.golang.org/grpc/internal/binarylog"
36 "google.golang.org/grpc/internal/channelz"
37 "google.golang.org/grpc/internal/grpcrand"
38 "google.golang.org/grpc/internal/grpcutil"
39 imetadata "google.golang.org/grpc/internal/metadata"
40 iresolver "google.golang.org/grpc/internal/resolver"
41 "google.golang.org/grpc/internal/serviceconfig"
42 istatus "google.golang.org/grpc/internal/status"
43 "google.golang.org/grpc/internal/transport"
44 "google.golang.org/grpc/metadata"
45 "google.golang.org/grpc/peer"
46 "google.golang.org/grpc/stats"
47 "google.golang.org/grpc/status"
48 )
49
// metadataFromOutgoingContextRaw extracts the outgoing metadata from a
// context along with the raw key/value pairs appended via
// AppendToOutgoingContext.  It is installed by the internal package (type
// asserted here) to avoid an import cycle.
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
51
52
53
54
55
56
57
58
// StreamHandler defines the handler called by a gRPC server to complete the
// execution of a streaming RPC.
//
// If a StreamHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of
// the RPC.
type StreamHandler func(srv any, stream ServerStream) error
60
61
62
63
// StreamDesc represents a streaming RPC service's method specification.  Used
// on the server when registering services and on the client when initiating
// new streams.
type StreamDesc struct {
	// StreamName and Handler are only used when registering handlers on a
	// server.
	StreamName string        // the name of the method excluding the service
	Handler    StreamHandler // the handler called for the method

	// ServerStreams and ClientStreams are used for registering handlers on a
	// server as well as defining RPC behavior when passed to NewClientStream
	// and ClientConn.NewStream.
	ServerStreams bool // indicates the server can perform streaming sends
	ClientStreams bool // indicates the client can perform streaming sends
}
76
77
78
79
// Stream defines the common interface a client or server stream has to
// satisfy.  See ClientStream and ServerStream for the side-specific
// contracts.
type Stream interface {
	// Context returns the context for this stream.
	Context() context.Context
	// SendMsg sends a message on the stream.
	SendMsg(m any) error
	// RecvMsg blocks until it receives a message into m or the stream is done.
	RecvMsg(m any) error
}
88
89
90
91
92
// ClientStream defines the client-side behavior of a streaming RPC.
//
// All errors returned from ClientStream methods are compatible with the
// status package.
type ClientStream interface {
	// Header returns the header metadata received from the server if there
	// is any. It blocks if the metadata is not ready to read. If the metadata
	// is nil and the error is also nil, then the stream was terminated without
	// headers, and the status can be discovered by calling RecvMsg.
	Header() (metadata.MD, error)
	// Trailer returns the trailer metadata from the server, if there is any.
	// It must only be called after stream.CloseAndRecv has returned, or
	// stream.Recv has returned a non-nil error (including io.EOF).
	Trailer() metadata.MD
	// CloseSend closes the send direction of the stream. It is not safe to
	// call CloseSend concurrently with SendMsg.
	CloseSend() error
	// Context returns the context for this stream.
	//
	// It should not be called until after Header or RecvMsg has returned. Once
	// called, subsequent client-side retries are disabled (see the
	// clientStream implementation, which commits the attempt).
	Context() context.Context
	// SendMsg is generally called by generated code. On error, SendMsg aborts
	// the stream. If the error was generated by the client, the status is
	// returned directly; otherwise, io.EOF is returned and the status of
	// the stream may be discovered using RecvMsg.
	//
	// SendMsg blocks until:
	//   - There is sufficient flow control to schedule m with the transport, or
	//   - The stream is done, or
	//   - The stream breaks.
	//
	// SendMsg does not wait until the message is received by the server. An
	// untimely stream closure may result in lost messages. To ensure delivery,
	// users should ensure the RPC completed successfully using RecvMsg.
	//
	// It is safe to have a goroutine calling SendMsg and another goroutine
	// calling RecvMsg on the same stream at the same time, but it is not safe
	// to call SendMsg on the same stream in different goroutines. It is also
	// not safe to call CloseSend concurrently with SendMsg.
	//
	// It is not safe to modify the message after calling SendMsg. Tracing
	// libraries and stats handlers may use the message lazily.
	SendMsg(m any) error
	// RecvMsg blocks until it receives a message into m or the stream is
	// done. It returns io.EOF when the stream completes successfully. On
	// any other error, the stream is aborted and the error contains the RPC
	// status.
	//
	// It is safe to have a goroutine calling SendMsg and another goroutine
	// calling RecvMsg on the same stream at the same time, but it is not
	// safe to call RecvMsg on the same stream in different goroutines.
	RecvMsg(m any) error
}
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161 func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
162
163
164 opts = combine(cc.dopts.callOptions, opts)
165
166 if cc.dopts.streamInt != nil {
167 return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
168 }
169 return newClientStream(ctx, desc, cc, method, opts...)
170 }
171
172
// NewClientStream is a wrapper for ClientConn.NewStream.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
	return cc.NewStream(ctx, desc, method, opts...)
}
176
// newClientStream begins a client stream: it exits the channel's idle mode,
// validates outgoing metadata, consults the config selector for a method
// config and optional interceptor, and finally delegates stream creation to
// newClientStreamWithParams.
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
	// Start tracking the RPC for idleness purposes; this also exits idle mode
	// if the channel was in it.
	if err := cc.idlenessMgr.OnCallBegin(); err != nil {
		return nil, err
	}
	// Prepend a CallOption that decrements the active-call count when the RPC
	// completes, regardless of outcome.
	opts = append([]CallOption{OnFinish(func(error) { cc.idlenessMgr.OnCallEnd() })}, opts...)

	if md, added, ok := metadataFromOutgoingContextRaw(ctx); ok {
		// Validate the metadata already attached to the outgoing context.
		if err := imetadata.Validate(md); err != nil {
			return nil, status.Error(codes.Internal, err.Error())
		}
		// "added" holds pairs appended via AppendToOutgoingContext as flat
		// [k1, v1, k2, v2, ...] slices; validate each pair individually.
		for _, kvs := range added {
			for i := 0; i < len(kvs); i += 2 {
				if err := imetadata.ValidatePair(kvs[i], kvs[i+1]); err != nil {
					return nil, status.Error(codes.Internal, err.Error())
				}
			}
		}
	}
	if channelz.IsOn() {
		cc.incrCallsStarted()
		// Record a failed call if this function returns an error.
		defer func() {
			if err != nil {
				cc.incrCallsFailed()
			}
		}()
	}

	// Block until name resolution has produced addresses (or ctx is done).
	if err := cc.waitForResolvedAddrs(ctx); err != nil {
		return nil, err
	}

	var mc serviceconfig.MethodConfig
	var onCommit func()
	// newStream captures mc/onCommit by reference so the values filled in by
	// the config selector below are seen when the stream is actually created.
	var newStream = func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
		return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, opts...)
	}

	rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method}
	rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo)
	if err != nil {
		if st, ok := status.FromError(err); ok {
			// Restrict control-plane errors to the allowed status codes;
			// anything else is rewritten as an Internal error.
			if istatus.IsRestrictedControlPlaneCode(st) {
				err = status.Errorf(codes.Internal, "config selector returned illegal status: %v", err)
			}
			return nil, err
		}
		return nil, toRPCErr(err)
	}

	if rpcConfig != nil {
		if rpcConfig.Context != nil {
			ctx = rpcConfig.Context
		}
		mc = rpcConfig.MethodConfig
		onCommit = rpcConfig.OnCommitted
		if rpcConfig.Interceptor != nil {
			rpcInfo.Context = nil
			ns := newStream
			// Wrap stream creation with the config selector's interceptor.
			newStream = func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
				cs, err := rpcConfig.Interceptor.NewStream(ctx, rpcInfo, done, ns)
				if err != nil {
					return nil, toRPCErr(err)
				}
				return cs, nil
			}
		}
	}

	return newStream(ctx, func() {})
}
256
// newClientStreamWithParams constructs the clientStream, performs the first
// attempt (transport pick + stream creation, with retry), binary-logs the
// client header, and spawns the goroutine that finishes the stream when the
// connection or the stream's context is done.  mc is the effective method
// config; onCommit, if non-nil, runs when the RPC commits to an attempt;
// doneFunc is passed through to the transport's CallHdr.
func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc serviceconfig.MethodConfig, onCommit, doneFunc func(), opts ...CallOption) (_ iresolver.ClientStream, err error) {
	c := defaultCallInfo()
	if mc.WaitForReady != nil {
		c.failFast = !*mc.WaitForReady
	}

	// Apply the method-config timeout (if any) to the context.  The cancel
	// function is stored on the stream below and called by finish(); the
	// deferred cleanup invokes it only on an error return from this function.
	var cancel context.CancelFunc
	if mc.Timeout != nil && *mc.Timeout >= 0 {
		ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	// Run the "before" hooks of all call options.
	for _, o := range opts {
		if err := o.before(c); err != nil {
			return nil, toRPCErr(err)
		}
	}
	c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
	c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
	if err := setCallInfoCodec(c); err != nil {
		return nil, err
	}

	callHdr := &transport.CallHdr{
		Host:           cc.authority,
		Method:         method,
		ContentSubtype: c.contentSubtype,
		DoneFunc:       doneFunc,
	}

	// Set outgoing compression according to the UseCompressor CallOption, if
	// set, looking up the compressor in the encoding registry.  Otherwise use
	// the compressor configured by the WithCompressor DialOption, if any.
	var cp Compressor
	var comp encoding.Compressor
	if ct := c.compressorType; ct != "" {
		callHdr.SendCompress = ct
		if ct != encoding.Identity {
			comp = encoding.GetCompressor(ct)
			if comp == nil {
				return nil, status.Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct)
			}
		}
	} else if cc.dopts.cp != nil {
		callHdr.SendCompress = cc.dopts.cp.Type()
		cp = cc.dopts.cp
	}
	if c.creds != nil {
		callHdr.Creds = c.creds
	}

	cs := &clientStream{
		callHdr:      callHdr,
		ctx:          ctx,
		methodConfig: &mc,
		opts:         opts,
		callInfo:     c,
		cc:           cc,
		desc:         desc,
		codec:        c.codec,
		cp:           cp,
		comp:         comp,
		cancel:       cancel,
		firstAttempt: true,
		onCommit:     onCommit,
	}
	if !cc.dopts.disableRetry {
		cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler)
	}
	// Collect binary-log method loggers from the global registry and, if
	// configured, the dial-option binary logger.
	if ml := binarylog.GetMethodLogger(method); ml != nil {
		cs.binlogs = append(cs.binlogs, ml)
	}
	if cc.dopts.binaryLogger != nil {
		if ml := cc.dopts.binaryLogger.GetMethodLogger(method); ml != nil {
			cs.binlogs = append(cs.binlogs, ml)
		}
	}

	// op picks a transport and creates the stream on it; it is buffered via
	// bufferForRetryLocked so it can be replayed on a retry attempt.
	op := func(a *csAttempt) error {
		if err := a.getTransport(); err != nil {
			return err
		}
		if err := a.newStream(); err != nil {
			return err
		}
		// This op runs either here (while creating the clientStream) or by
		// the retry code while locked when replaying, so writing cs.attempt
		// directly is safe.
		cs.attempt = a
		return nil
	}
	if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil {
		return nil, err
	}

	if len(cs.binlogs) != 0 {
		md, _ := metadata.FromOutgoingContext(ctx)
		logEntry := &binarylog.ClientHeader{
			OnClientSide: true,
			Header:       md,
			MethodName:   method,
			Authority:    cs.cc.authority,
		}
		if deadline, ok := ctx.Deadline(); ok {
			logEntry.Timeout = time.Until(deadline)
			if logEntry.Timeout < 0 {
				logEntry.Timeout = 0
			}
		}
		for _, binlog := range cs.binlogs {
			binlog.Log(cs.ctx, logEntry)
		}
	}

	if desc != unaryStreamDesc {
		// Listen on cc and stream contexts to clean up when the user closes
		// the ClientConn or cancels the stream context.  In all other cases,
		// an error should already be injected into the recv buffer by the
		// transport, which the client will eventually receive, after which
		// clientStream.finish cancels the stream's context.
		go func() {
			select {
			case <-cc.ctx.Done():
				cs.finish(ErrClientConnClosing)
			case <-ctx.Done():
				cs.finish(toRPCErr(ctx.Err()))
			}
		}()
	}
	return cs, nil
}
402
403
// newAttemptLocked creates a new csAttempt without a transport or stream.
// isTransparent indicates the attempt is a transparent retry (the previous
// attempt was never seen by the server application).  Requires cs.mu to be
// held by the caller.
func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) {
	if err := cs.ctx.Err(); err != nil {
		return nil, toRPCErr(err)
	}
	// A dead ClientConn is always reported as ErrClientConnClosing, not the
	// underlying context error.
	if err := cs.cc.ctx.Err(); err != nil {
		return nil, ErrClientConnClosing
	}

	ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp)
	method := cs.callHdr.Method
	var beginTime time.Time
	shs := cs.cc.dopts.copts.StatsHandlers
	// Tag the RPC and emit a Begin event on every configured stats handler.
	for _, sh := range shs {
		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast})
		beginTime = time.Now()
		begin := &stats.Begin{
			Client:                    true,
			BeginTime:                 beginTime,
			FailFast:                  cs.callInfo.failFast,
			IsClientStream:            cs.desc.ClientStreams,
			IsServerStream:            cs.desc.ServerStreams,
			IsTransparentRetryAttempt: isTransparent,
		}
		sh.HandleRPC(ctx, begin)
	}

	var trInfo *traceInfo
	if EnableTracing {
		trInfo = &traceInfo{
			tr: newTrace("grpc.Sent."+methodFamily(method), method),
			firstLine: firstLine{
				client: true,
			},
		}
		if deadline, ok := ctx.Deadline(); ok {
			trInfo.firstLine.deadline = time.Until(deadline)
		}
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
		ctx = newTraceContext(ctx, trInfo.tr)
	}

	if cs.cc.parsedTarget.URL.Scheme == internal.GRPCResolverSchemeExtraMetadata {
		// Add extra metadata (metadata that will be added by transport) to the
		// context so the balancer can see it.
		ctx = grpcutil.WithExtraMetadata(ctx, metadata.Pairs(
			"content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype),
		))
	}

	return &csAttempt{
		ctx:           ctx,
		beginTime:     beginTime,
		cs:            cs,
		dc:            cs.cc.dopts.dc,
		statsHandlers: shs,
		trInfo:        trInfo,
	}, nil
}
462
463 func (a *csAttempt) getTransport() error {
464 cs := a.cs
465
466 var err error
467 a.t, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
468 if err != nil {
469 if de, ok := err.(dropError); ok {
470 err = de.error
471 a.drop = true
472 }
473 return err
474 }
475 if a.trInfo != nil {
476 a.trInfo.firstLine.SetRemoteAddr(a.t.RemoteAddr())
477 }
478 return nil
479 }
480
// newStream creates the transport stream for this attempt on the transport
// previously selected by getTransport, merging any picker-supplied metadata
// into the outgoing context and recording whether a transparent retry is
// permitted when stream creation fails.
func (a *csAttempt) newStream() error {
	cs := a.cs
	cs.callHdr.PreviousAttempts = cs.numRetries

	// Merge metadata from the pick result, if any, into the outgoing context.
	// This is per-attempt because the picker (and hence its metadata) may be
	// different for each attempt.
	if a.pickResult.Metadata != nil {
		md, _ := metadata.FromOutgoingContext(a.ctx)
		md = metadata.Join(md, a.pickResult.Metadata)
		a.ctx = metadata.NewOutgoingContext(a.ctx, md)
	}

	s, err := a.t.NewStream(a.ctx, cs.callHdr)
	if err != nil {
		nse, ok := err.(*transport.NewStreamError)
		if !ok {
			// Unexpected error type; return it unmodified.
			return err
		}

		if nse.AllowTransparentRetry {
			a.allowTransparentRetry = true
		}

		// Unwrap and convert the error to an RPC status error.
		return toRPCErr(nse.Err)
	}
	a.s = s
	// Use the transport stream's context from here on.
	a.ctx = s.Context()
	a.p = &parser{r: s, recvBufferPool: a.cs.cc.dopts.recvBufferPool}
	return nil
}
523
524
// clientStream implements a client side Stream with retry support.
type clientStream struct {
	callHdr  *transport.CallHdr
	opts     []CallOption
	callInfo *callInfo
	cc       *ClientConn
	desc     *StreamDesc

	codec baseCodec
	cp    Compressor
	comp  encoding.Compressor

	cancel context.CancelFunc // cancels all attempts; called by finish

	sentLast bool // whether the end-of-stream (last message) has been sent

	methodConfig *MethodConfig

	ctx context.Context // the application's context (post method-config timeout)

	retryThrottler *retryThrottler // the throttler active when the RPC began

	binlogs []binarylog.MethodLogger
	// serverHeaderBinlogged records whether the server header has been
	// binary-logged.  The header is logged the first time either Header() or
	// RecvMsg() observes it; it is only read/written by those two methods, so
	// no synchronization is needed.
	serverHeaderBinlogged bool

	mu                      sync.Mutex
	firstAttempt            bool // if true, a transparent retry is still valid
	numRetries              int  // exclusive of transparent retry attempt(s)
	numRetriesSincePushback int  // retries since pushback; used to reset backoff
	finished                bool // set once by finish(); guards double-finish
	// attempt is the active client stream attempt.  It is only written by
	// newAttemptLocked (which never writes nil), so it can be nil only before
	// the first attempt is created inside stream creation; finish() is the
	// only place that needs a nil check.
	attempt *csAttempt
	// committed/buffer implement the retry replay mechanism: while not
	// committed, successful operations are appended to buffer so that a new
	// attempt can replay them; committing drops the buffer and disables retry.
	committed  bool
	onCommit   func()
	buffer     []func(a *csAttempt) error // operations to replay on retry
	bufferSize int                        // current size of buffer, bytes
}
574
575
576
// csAttempt implements a single transport stream attempt within a
// clientStream.
type csAttempt struct {
	ctx        context.Context
	cs         *clientStream
	t          transport.ClientTransport
	s          *transport.Stream
	p          *parser
	pickResult balancer.PickResult

	finished  bool
	dc        Decompressor
	decomp    encoding.Compressor
	decompSet bool // whether dc/decomp have been initialized from recv headers

	mu sync.Mutex // guards trInfo.tr
	// trInfo may be nil (if EnableTracing is false).  trInfo.tr is set when
	// the attempt is created (if EnableTracing is true) and cleared when the
	// finish method is called.
	trInfo *traceInfo

	statsHandlers []stats.Handler
	beginTime     time.Time

	// allowTransparentRetry is set for newStream errors that may be
	// transparently retried.
	allowTransparentRetry bool
	// drop is set when the picker dropped the RPC (see getTransport).
	drop bool
}
604
605 func (cs *clientStream) commitAttemptLocked() {
606 if !cs.committed && cs.onCommit != nil {
607 cs.onCommit()
608 }
609 cs.committed = true
610 cs.buffer = nil
611 }
612
613 func (cs *clientStream) commitAttempt() {
614 cs.mu.Lock()
615 cs.commitAttemptLocked()
616 cs.mu.Unlock()
617 }
618
619
620
621
// shouldRetry returns a nil error if the RPC should be retried; otherwise it
// returns the error that should be returned by the operation.  When a retry
// is indicated, the bool reports whether it is a transparent retry.  On a
// non-transparent retry this function also sleeps for the backoff/pushback
// duration before returning.
func (a *csAttempt) shouldRetry(err error) (bool, error) {
	cs := a.cs

	if cs.finished || cs.committed || a.drop {
		// RPC is finished or committed, or was dropped by the picker: cannot retry.
		return false, err
	}
	if a.s == nil && a.allowTransparentRetry {
		return true, nil
	}
	// Wait for the stream to terminate so trailers are available.
	unprocessed := false
	if a.s != nil {
		<-a.s.Done()
		unprocessed = a.s.Unprocessed()
	}
	if cs.firstAttempt && unprocessed {
		// First attempt, stream unprocessed by the server: transparently retry.
		return true, nil
	}
	if cs.cc.dopts.disableRetry {
		return false, err
	}

	// Parse the server's retry pushback from the trailers, if present.  A
	// malformed or negative value, or multiple values, aborts retries and
	// counts against the throttler.
	pushback := 0
	hasPushback := false
	if a.s != nil {
		if !a.s.TrailersOnly() {
			return false, err
		}

		sps := a.s.Trailer()["grpc-retry-pushback-ms"]
		if len(sps) == 1 {
			var e error
			if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 {
				channelz.Infof(logger, cs.cc.channelz, "Server retry pushback specified to abort (%q).", sps[0])
				cs.retryThrottler.throttle()
				return false, err
			}
			hasPushback = true
		} else if len(sps) > 1 {
			channelz.Warningf(logger, cs.cc.channelz, "Server retry pushback specified multiple values (%q); not retrying.", sps)
			cs.retryThrottler.throttle()
			return false, err
		}
	}

	// Determine the status code of this attempt: from the stream status if a
	// stream exists, else from the error itself.
	var code codes.Code
	if a.s != nil {
		code = a.s.Status().Code()
	} else {
		code = status.Code(err)
	}

	rp := cs.methodConfig.RetryPolicy
	if rp == nil || !rp.RetryableStatusCodes[code] {
		return false, err
	}

	// Note: the ordering here is important; we count this as a failure for
	// throttling purposes only if the code matched a retryable code.
	if cs.retryThrottler.throttle() {
		return false, err
	}
	if cs.numRetries+1 >= rp.MaxAttempts {
		return false, err
	}

	// Compute the delay: either server pushback (which also resets the
	// backoff sequence) or jittered exponential backoff per the retry policy.
	var dur time.Duration
	if hasPushback {
		dur = time.Millisecond * time.Duration(pushback)
		cs.numRetriesSincePushback = 0
	} else {
		fact := math.Pow(rp.BackoffMultiplier, float64(cs.numRetriesSincePushback))
		cur := float64(rp.InitialBackoff) * fact
		if max := float64(rp.MaxBackoff); cur > max {
			cur = max
		}
		dur = time.Duration(grpcrand.Int63n(int64(cur)))
		cs.numRetriesSincePushback++
	}

	// Sleep for the computed duration, but abort if the stream's context is
	// canceled in the meantime.
	t := time.NewTimer(dur)
	select {
	case <-t.C:
		cs.numRetries++
		return false, nil
	case <-cs.ctx.Done():
		t.Stop()
		return false, status.FromContextError(cs.ctx.Err()).Err()
	}
}
718
719
// retryLocked is the retry loop: it finishes the failed attempt and, for as
// long as shouldRetry allows, creates a new attempt and replays the buffered
// operations on it.  lastErr is the error the previous attempt failed with.
//
// Requires cs.mu to be held by the caller.
func (cs *clientStream) retryLocked(attempt *csAttempt, lastErr error) error {
	for {
		attempt.finish(toRPCErr(lastErr))
		isTransparent, err := attempt.shouldRetry(lastErr)
		if err != nil {
			cs.commitAttemptLocked()
			return err
		}
		cs.firstAttempt = false
		attempt, err = cs.newAttemptLocked(isTransparent)
		if err != nil {
			// Only returns an error if the ClientConn is closed or the
			// context of the stream is canceled.
			return err
		}
		// Note that the first op in the buffer always sets cs.attempt if it
		// is able to pick a transport and create a stream.
		if lastErr = cs.replayBufferLocked(attempt); lastErr == nil {
			return nil
		}
	}
}
742
743 func (cs *clientStream) Context() context.Context {
744 cs.commitAttempt()
745
746
747 if cs.attempt.s != nil {
748 return cs.attempt.s.Context()
749 }
750 return cs.ctx
751 }
752
// withRetry executes op on the current attempt, creating attempts and
// retrying (via retryLocked) as needed.  onSuccess is invoked under cs.mu
// when op succeeds — typically to buffer the op for replay on a later retry,
// or to commit the attempt.
func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func()) error {
	cs.mu.Lock()
	for {
		if cs.committed {
			cs.mu.Unlock()
			// toRPCErr is used in case the error from the attempt comes from
			// NewClientStream due to a buffer overflow, in which case parsing
			// the real error would be a no-op.
			return toRPCErr(op(cs.attempt))
		}
		if len(cs.buffer) == 0 {
			// For the first op, which controls creation of the stream and
			// assigns cs.attempt, create a new attempt inline before
			// executing it.  On subsequent ops, the attempt is created by
			// retryLocked immediately before replaying the buffer.
			var err error
			if cs.attempt, err = cs.newAttemptLocked(false /* isTransparent */); err != nil {
				cs.mu.Unlock()
				cs.finish(err)
				return err
			}
		}
		a := cs.attempt
		cs.mu.Unlock()
		// Run op without holding the lock; it may block on the transport.
		err := op(a)
		cs.mu.Lock()
		if a != cs.attempt {
			// We started another attempt already; rerun op on it.
			continue
		}
		if err == io.EOF {
			// Wait for the stream's status before deciding success/failure.
			<-a.s.Done()
		}
		if err == nil || (err == io.EOF && a.s.Status().Code() == codes.OK) {
			onSuccess()
			cs.mu.Unlock()
			return err
		}
		if err := cs.retryLocked(a, err); err != nil {
			cs.mu.Unlock()
			return err
		}
	}
}
798
// Header returns the header metadata of the stream, blocking until it is
// available.  On failure the stream is finished but the error is NOT
// returned here — the user is expected to discover it via RecvMsg.
func (cs *clientStream) Header() (metadata.MD, error) {
	var m metadata.MD
	err := cs.withRetry(func(a *csAttempt) error {
		var err error
		m, err = a.s.Header()
		return toRPCErr(err)
	}, cs.commitAttemptLocked)

	if m == nil && err == nil {
		// The stream ended with success (no headers).  Finish the stream.
		err = io.EOF
	}

	if err != nil {
		cs.finish(err)
		// Do not return the error.  The user should get it by calling Recv().
		return nil, nil
	}

	if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && m != nil {
		// Only log if binary logging is on, the header has not already been
		// logged, and there actually are headers to log.
		logEntry := &binarylog.ServerHeader{
			OnClientSide: true,
			Header:       m,
			PeerAddr:     nil,
		}
		if peer, ok := peer.FromContext(cs.Context()); ok {
			logEntry.PeerAddr = peer.Addr
		}
		cs.serverHeaderBinlogged = true
		for _, binlog := range cs.binlogs {
			binlog.Log(cs.ctx, logEntry)
		}
	}

	return m, nil
}
837
838 func (cs *clientStream) Trailer() metadata.MD {
839
840
841
842
843
844
845
846 cs.commitAttempt()
847 if cs.attempt.s == nil {
848 return nil
849 }
850 return cs.attempt.s.Trailer()
851 }
852
853 func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error {
854 for _, f := range cs.buffer {
855 if err := f(attempt); err != nil {
856 return err
857 }
858 }
859 return nil
860 }
861
862 func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error) {
863
864 if cs.committed {
865 return
866 }
867 cs.bufferSize += sz
868 if cs.bufferSize > cs.callInfo.maxRetryRPCBufferSize {
869 cs.commitAttemptLocked()
870 return
871 }
872 cs.buffer = append(cs.buffer, op)
873 }
874
// SendMsg serializes and sends m on the stream.  For non-client-streaming
// RPCs this is also the last message, so the send direction is closed.
func (cs *clientStream) SendMsg(m any) (err error) {
	defer func() {
		if err != nil && err != io.EOF {
			// Call finish on the client stream for errors generated by this
			// SendMsg call, as these indicate problems created by this client.
			// (Transport errors are converted to io.EOF in csAttempt.sendMsg;
			// in that case the real error will be returned from RecvMsg
			// eventually, or the RPC will be retried.)
			cs.finish(err)
		}
	}()
	if cs.sentLast {
		return status.Errorf(codes.Internal, "SendMsg called after CloseSend")
	}
	if !cs.desc.ClientStreams {
		cs.sentLast = true
	}

	// Serialize (and possibly compress) m into header + payload frames.
	hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
	if err != nil {
		return err
	}

	// Enforce the max send size against the (possibly compressed) payload.
	if len(payload) > *cs.callInfo.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
	}
	op := func(a *csAttempt) error {
		return a.sendMsg(m, hdr, payload, data)
	}
	err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
	if len(cs.binlogs) != 0 && err == nil {
		cm := &binarylog.ClientMessage{
			OnClientSide: true,
			Message:      data,
		}
		for _, binlog := range cs.binlogs {
			binlog.Log(cs.ctx, cm)
		}
	}
	return err
}
918
// RecvMsg blocks until it receives a message into m or the stream is done.
// For non-server-streaming RPCs (or on any error) it also finishes the
// stream.
func (cs *clientStream) RecvMsg(m any) error {
	if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged {
		// Call Header() so the server header gets binary-logged first.
		cs.Header()
	}
	var recvInfo *payloadInfo
	if len(cs.binlogs) != 0 {
		recvInfo = &payloadInfo{}
	}
	err := cs.withRetry(func(a *csAttempt) error {
		return a.recvMsg(m, recvInfo)
	}, cs.commitAttemptLocked)
	if len(cs.binlogs) != 0 && err == nil {
		sm := &binarylog.ServerMessage{
			OnClientSide: true,
			Message:      recvInfo.uncompressedBytes,
		}
		for _, binlog := range cs.binlogs {
			binlog.Log(cs.ctx, sm)
		}
	}
	if err != nil || !cs.desc.ServerStreams {
		// err != nil or a non-server-streaming RPC indicates end of stream.
		cs.finish(err)
	}
	return err
}
946
// CloseSend closes the send direction of the stream.  It always returns nil;
// any transport error will surface through a later RecvMsg call.
func (cs *clientStream) CloseSend() error {
	if cs.sentLast {
		// Last message already sent (via SendMsg or a prior CloseSend): no-op.
		return nil
	}
	cs.sentLast = true
	op := func(a *csAttempt) error {
		a.t.Write(a.s, nil, nil, &transport.Options{Last: true})
		// Always return nil; io.EOF is the only error that might make sense
		// instead, but there is no need to signal the client to call RecvMsg
		// as the only use left for the stream after CloseSend is to call
		// RecvMsg.
		return nil
	}
	cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) })
	if len(cs.binlogs) != 0 {
		chc := &binarylog.ClientHalfClose{
			OnClientSide: true,
		}
		for _, binlog := range cs.binlogs {
			binlog.Log(cs.ctx, chc)
		}
	}
	// This function never returns an error (see op above).
	return nil
}
973
// finish terminates the RPC with err (io.EOF and nil both mean success).  It
// runs at most once: it commits the attempt, fires OnFinish callbacks and
// CallOption "after" hooks, emits binary-log and channelz records, and
// cancels the stream's context.
func (cs *clientStream) finish(err error) {
	if err == io.EOF {
		// Ending a stream with EOF indicates a success.
		err = nil
	}
	cs.mu.Lock()
	if cs.finished {
		cs.mu.Unlock()
		return
	}
	cs.finished = true
	for _, onFinish := range cs.callInfo.onFinish {
		onFinish(err)
	}
	cs.commitAttemptLocked()
	// cs.attempt can only be nil if the first newAttemptLocked failed (see
	// the attempt field's documentation).
	if cs.attempt != nil {
		cs.attempt.finish(err)
		// "after" hooks all rely upon having a transport stream.
		if cs.attempt.s != nil {
			for _, o := range cs.opts {
				o.after(cs.callInfo, cs.attempt)
			}
		}
	}

	cs.mu.Unlock()
	// Binary-log either a Cancel event (for context/conn teardown errors) or
	// the server trailer — never both.
	if len(cs.binlogs) != 0 {
		switch err {
		case errContextCanceled, errContextDeadline, ErrClientConnClosing:
			c := &binarylog.Cancel{
				OnClientSide: true,
			}
			for _, binlog := range cs.binlogs {
				binlog.Log(cs.ctx, c)
			}
		default:
			logEntry := &binarylog.ServerTrailer{
				OnClientSide: true,
				Trailer:      cs.Trailer(),
				Err:          err,
			}
			if peer, ok := peer.FromContext(cs.Context()); ok {
				logEntry.PeerAddr = peer.Addr
			}
			for _, binlog := range cs.binlogs {
				binlog.Log(cs.ctx, logEntry)
			}
		}
	}
	if err == nil {
		cs.retryThrottler.successfulRPC()
	}
	if channelz.IsOn() {
		if err != nil {
			cs.cc.incrCallsFailed()
		} else {
			cs.cc.incrCallsSucceeded()
		}
	}
	cs.cancel()
}
1036
// sendMsg writes one serialized message (hdr + payld frames) on the
// attempt's transport stream and notifies tracing, stats, and channelz.
// data is the uncompressed message bytes, used only for stats reporting.
func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error {
	cs := a.cs
	if a.trInfo != nil {
		a.mu.Lock()
		if a.trInfo.tr != nil {
			a.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
		}
		a.mu.Unlock()
	}
	if err := a.t.Write(a.s, hdr, payld, &transport.Options{Last: !cs.desc.ClientStreams}); err != nil {
		if !cs.desc.ClientStreams {
			// For non-client-streaming RPCs, return nil instead of EOF on a
			// write error.  finish is not called; RecvMsg() will call it with
			// the stream's status independently.
			return nil
		}
		return io.EOF
	}
	for _, sh := range a.statsHandlers {
		sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now()))
	}
	if channelz.IsOn() {
		a.t.IncrMsgSent()
	}
	return nil
}
1063
// recvMsg reads and decodes one message into m.  For non-server-streaming
// RPCs it performs an extra recv to verify the server sent exactly one
// message followed by EOF.  payInfo, if non-nil, is populated for binary
// logging and stats.
func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
	cs := a.cs
	if len(a.statsHandlers) != 0 && payInfo == nil {
		payInfo = &payloadInfo{}
	}

	if !a.decompSet {
		// Block until we receive headers containing the received-message encoding.
		if ct := a.s.RecvCompress(); ct != "" && ct != encoding.Identity {
			if a.dc == nil || a.dc.Type() != ct {
				// No configured decompressor, or it does not match the
				// incoming encoding; look up a registered compressor instead.
				a.dc = nil
				a.decomp = encoding.GetCompressor(ct)
			}
		} else {
			// No compression is used; disable our decompressor.
			a.dc = nil
		}
		// Only initialize this state once per stream.
		a.decompSet = true
	}
	err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp)
	if err != nil {
		if err == io.EOF {
			if statusErr := a.s.Status().Err(); statusErr != nil {
				return statusErr
			}
			return io.EOF // indicates successful end of stream.
		}

		return toRPCErr(err)
	}
	if a.trInfo != nil {
		a.mu.Lock()
		if a.trInfo.tr != nil {
			a.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
		}
		a.mu.Unlock()
	}
	for _, sh := range a.statsHandlers {
		sh.HandleRPC(a.ctx, &stats.InPayload{
			Client:   true,
			RecvTime: time.Now(),
			Payload:  m,
			// TODO truncate large payload.
			Data:             payInfo.uncompressedBytes,
			WireLength:       payInfo.compressedLength + headerLen,
			CompressedLength: payInfo.compressedLength,
			Length:           len(payInfo.uncompressedBytes),
		})
	}
	if channelz.IsOn() {
		a.t.IncrMsgRecv()
	}
	if cs.desc.ServerStreams {
		// Subsequent messages should be received by subsequent RecvMsg calls.
		return nil
	}
	// Special handling for non-server-streaming RPCs: this recv expects EOF
	// or errors, so we don't collect inPayload.
	err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp)
	if err == nil {
		return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
	}
	if err == io.EOF {
		return a.s.Status().Err() // non-server-streaming recv returns nil on success
	}
	return toRPCErr(err)
}
1134
// finish closes the attempt's transport stream (if any) and reports the
// outcome to the picker's Done callback, stats handlers, and tracing.  It is
// idempotent.
func (a *csAttempt) finish(err error) {
	a.mu.Lock()
	if a.finished {
		a.mu.Unlock()
		return
	}
	a.finished = true
	if err == io.EOF {
		// Ending a stream with EOF indicates a success.
		err = nil
	}
	var tr metadata.MD
	if a.s != nil {
		a.t.CloseStream(a.s, err)
		tr = a.s.Trailer()
	}

	if a.pickResult.Done != nil {
		br := false
		if a.s != nil {
			br = a.s.BytesReceived()
		}
		a.pickResult.Done(balancer.DoneInfo{
			Err:           err,
			Trailer:       tr,
			BytesSent:     a.s != nil,
			BytesReceived: br,
			ServerLoad:    balancerload.Parse(tr),
		})
	}
	for _, sh := range a.statsHandlers {
		end := &stats.End{
			Client:    true,
			BeginTime: a.beginTime,
			EndTime:   time.Now(),
			Trailer:   tr,
			Error:     err,
		}
		sh.HandleRPC(a.ctx, end)
	}
	if a.trInfo != nil && a.trInfo.tr != nil {
		if err == nil {
			a.trInfo.tr.LazyPrintf("RPC: [OK]")
		} else {
			a.trInfo.tr.LazyPrintf("RPC: [%v]", err)
			a.trInfo.tr.SetError()
		}
		a.trInfo.tr.Finish()
		// Clear tr so no further trace events are recorded for this attempt.
		a.trInfo.tr = nil
	}
	a.mu.Unlock()
}
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
// newNonRetryClientStream creates a ClientStream on the specified transport,
// tied to the given addrConn.  Unlike ClientConn.NewStream there is no retry,
// no service config, and no tracing or stats.  The transport is passed
// explicitly (instead of read from ac) to avoid racing with transport
// replacement on the addrConn.
func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method string, t transport.ClientTransport, ac *addrConn, opts ...CallOption) (_ ClientStream, err error) {
	if t == nil {
		// TODO: return an RPC error here instead?
		return nil, errors.New("transport provided is nil")
	}
	// defaultCallInfo contains unnecessary info (e.g. failFast), so use an
	// empty callInfo.
	c := &callInfo{}

	// Derive a cancelable context for the stream; on any error return below,
	// cancel it so resources are released.
	ctx, cancel := context.WithCancel(ctx)
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	for _, o := range opts {
		if err := o.before(c); err != nil {
			return nil, toRPCErr(err)
		}
	}
	c.maxReceiveMessageSize = getMaxSize(nil, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
	// NOTE(review): the fallback here is the *server* send-size default while
	// the receive fallback uses the client constant; the two defaults are
	// presumably equal — confirm against this package's default constants.
	c.maxSendMessageSize = getMaxSize(nil, c.maxSendMessageSize, defaultServerMaxSendMessageSize)
	if err := setCallInfoCodec(c); err != nil {
		return nil, err
	}

	callHdr := &transport.CallHdr{
		Host:           ac.cc.authority,
		Method:         method,
		ContentSubtype: c.contentSubtype,
	}

	// Set outgoing compression from the UseCompressor CallOption if present;
	// otherwise fall back to the dial-option compressor.
	var cp Compressor
	var comp encoding.Compressor
	if ct := c.compressorType; ct != "" {
		callHdr.SendCompress = ct
		if ct != encoding.Identity {
			comp = encoding.GetCompressor(ct)
			if comp == nil {
				return nil, status.Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct)
			}
		}
	} else if ac.cc.dopts.cp != nil {
		callHdr.SendCompress = ac.cc.dopts.cp.Type()
		cp = ac.cc.dopts.cp
	}
	if c.creds != nil {
		callHdr.Creds = c.creds
	}

	// Use an addrConnStream, which bypasses the retry machinery entirely.
	as := &addrConnStream{
		callHdr:  callHdr,
		ac:       ac,
		ctx:      ctx,
		cancel:   cancel,
		opts:     opts,
		callInfo: c,
		desc:     desc,
		codec:    c.codec,
		cp:       cp,
		comp:     comp,
		t:        t,
	}

	s, err := as.t.NewStream(as.ctx, as.callHdr)
	if err != nil {
		err = toRPCErr(err)
		return nil, err
	}
	as.s = s
	as.p = &parser{r: s, recvBufferPool: ac.dopts.recvBufferPool}
	ac.incrCallsStarted()
	if desc != unaryStreamDesc {
		// Listen on the stream context to clean up when it is canceled, and
		// on the addrConn's context in case the addrConn is closed or
		// reconnects.  In all other cases an error is injected into the recv
		// buffer by the transport, which the client eventually receives, and
		// then addrConnStream.finish cancels the stream's context.
		go func() {
			ac.mu.Lock()
			acCtx := ac.ctx
			ac.mu.Unlock()
			select {
			case <-acCtx.Done():
				as.finish(status.Error(codes.Canceled, "grpc: the SubConn is closing"))
			case <-ctx.Done():
				as.finish(toRPCErr(ctx.Err()))
			}
		}()
	}
	return as, nil
}
1303
// addrConnStream is a non-retrying ClientStream bound directly to a single
// addrConn and transport; it is created by newNonRetryClientStream.
type addrConnStream struct {
	s         *transport.Stream
	ac        *addrConn
	callHdr   *transport.CallHdr
	cancel    context.CancelFunc
	opts      []CallOption
	callInfo  *callInfo
	t         transport.ClientTransport
	ctx       context.Context
	sentLast  bool // whether the last (end-of-stream) message has been sent
	desc      *StreamDesc
	codec     baseCodec
	cp        Compressor
	dc        Decompressor
	comp      encoding.Compressor
	decompSet bool // whether dc/decomp have been initialized from recv headers
	decomp    encoding.Compressor
	p         *parser
	mu        sync.Mutex
	finished  bool // set once by finish(); guards double-finish
}
1325
// Header returns the header metadata from the transport stream, finishing
// the stream if retrieval fails.
func (as *addrConnStream) Header() (metadata.MD, error) {
	m, err := as.s.Header()
	if err != nil {
		as.finish(toRPCErr(err))
	}
	return m, err
}
1333
// Trailer returns the trailer metadata from the underlying transport stream.
func (as *addrConnStream) Trailer() metadata.MD {
	return as.s.Trailer()
}
1337
// CloseSend closes the send direction of the stream.  It always returns nil;
// write errors will surface via RecvMsg.
func (as *addrConnStream) CloseSend() error {
	if as.sentLast {
		// Last message already sent: no-op.
		return nil
	}
	as.sentLast = true

	as.t.Write(as.s, nil, nil, &transport.Options{Last: true})
	// Always return nil; io.EOF is the only error that might make sense
	// instead, but there is no need to signal the client to call RecvMsg as
	// the only use left for the stream after CloseSend is to call RecvMsg.
	return nil
}
1352
// Context returns the context of the underlying transport stream.
func (as *addrConnStream) Context() context.Context {
	return as.s.Context()
}
1356
// SendMsg serializes and writes m to the stream.  For non-client-streaming
// RPCs this is also the last message, so the write is marked Last.
func (as *addrConnStream) SendMsg(m any) (err error) {
	defer func() {
		if err != nil && err != io.EOF {
			// Call finish for errors generated by this SendMsg call, as these
			// indicate problems created by this client.  (Transport errors
			// are converted to io.EOF below; the real error will be returned
			// from RecvMsg eventually in that case.)
			as.finish(err)
		}
	}()
	if as.sentLast {
		return status.Errorf(codes.Internal, "SendMsg called after CloseSend")
	}
	if !as.desc.ClientStreams {
		as.sentLast = true
	}

	// Serialize (and possibly compress) m into header + payload frames.
	hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp)
	if err != nil {
		return err
	}

	// Enforce the max send size against the (possibly compressed) payload.
	if len(payld) > *as.callInfo.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize)
	}

	if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams}); err != nil {
		if !as.desc.ClientStreams {
			// For non-client-streaming RPCs, return nil instead of EOF on a
			// write error.  finish is not called; RecvMsg() will call it with
			// the stream's status independently.
			return nil
		}
		return io.EOF
	}

	if channelz.IsOn() {
		as.t.IncrMsgSent()
	}
	return nil
}
1401
// RecvMsg blocks until it receives a message into m or the stream is done.
// For non-server-streaming RPCs it performs an extra recv to verify the
// server sent exactly one message followed by EOF.  The stream is finished
// on any error or once a non-server-streaming response is consumed.
func (as *addrConnStream) RecvMsg(m any) (err error) {
	defer func() {
		if err != nil || !as.desc.ServerStreams {
			// err != nil or a non-server-streaming RPC indicates end of stream.
			as.finish(err)
		}
	}()

	if !as.decompSet {
		// Block until we receive headers containing the received-message encoding.
		if ct := as.s.RecvCompress(); ct != "" && ct != encoding.Identity {
			if as.dc == nil || as.dc.Type() != ct {
				// No configured decompressor, or it does not match the
				// incoming encoding; look up a registered compressor instead.
				as.dc = nil
				as.decomp = encoding.GetCompressor(ct)
			}
		} else {
			// No compression is used; disable our decompressor.
			as.dc = nil
		}
		// Only initialize this state once per stream.
		as.decompSet = true
	}
	err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
	if err != nil {
		if err == io.EOF {
			if statusErr := as.s.Status().Err(); statusErr != nil {
				return statusErr
			}
			return io.EOF // indicates successful end of stream.
		}
		return toRPCErr(err)
	}

	if channelz.IsOn() {
		as.t.IncrMsgRecv()
	}
	if as.desc.ServerStreams {
		// Subsequent messages should be received by subsequent RecvMsg calls.
		return nil
	}

	// Special handling for non-server-streaming RPCs: this recv expects EOF
	// or errors, so we don't collect inPayload.
	err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
	if err == nil {
		return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
	}
	if err == io.EOF {
		return as.s.Status().Err() // non-server-streaming recv returns nil on success
	}
	return toRPCErr(err)
}
1456
// finish terminates the stream exactly once: it closes the transport stream,
// updates the addrConn's channelz call counters, and cancels the stream's
// context.  io.EOF counts as success.
func (as *addrConnStream) finish(err error) {
	as.mu.Lock()
	if as.finished {
		as.mu.Unlock()
		return
	}
	as.finished = true
	if err == io.EOF {
		// Ending a stream with EOF indicates a success.
		err = nil
	}
	if as.s != nil {
		as.t.CloseStream(as.s, err)
	}

	if err != nil {
		as.ac.incrCallsFailed()
	} else {
		as.ac.incrCallsSucceeded()
	}
	as.cancel()
	as.mu.Unlock()
}
1480
1481
1482
1483
1484
1485
1486
// ServerStream defines the server-side behavior of a streaming RPC.
//
// Errors returned from ServerStream methods are compatible with the status
// package.
type ServerStream interface {
	// SetHeader sets the header metadata. It may be called multiple times;
	// all the provided metadata will be merged. The metadata is sent out when
	// one of the following happens:
	//  - ServerStream.SendHeader() is called;
	//  - The first response is sent out;
	//  - An RPC status is sent out (error or success).
	SetHeader(metadata.MD) error
	// SendHeader sends the header metadata.
	// The provided md and headers set by SetHeader() will be sent.
	SendHeader(metadata.MD) error
	// SetTrailer sets the trailer metadata which will be sent with the RPC
	// status.  When called more than once, all the provided metadata will be
	// merged.
	SetTrailer(metadata.MD)
	// Context returns the context for this stream.
	Context() context.Context
	// SendMsg sends a message. On error, SendMsg aborts the stream and the
	// error is returned directly.
	//
	// SendMsg blocks until:
	//  - There is sufficient flow control to schedule m with the transport, or
	//  - The stream is done, or
	//  - The stream breaks.
	//
	// SendMsg does not wait until the message is received by the client. An
	// untimely stream closure may result in lost messages.
	//
	// It is safe to have a goroutine calling SendMsg and another goroutine
	// calling RecvMsg on the same stream at the same time, but it is not safe
	// to call SendMsg on the same stream in different goroutines.
	//
	// It is not safe to modify the message after calling SendMsg. Tracing
	// libraries and stats handlers may use the message lazily.
	SendMsg(m any) error
	// RecvMsg blocks until it receives a message into m or the stream is
	// done. It returns io.EOF when the client has performed a CloseSend. On
	// any non-EOF error, the stream is aborted and the error contains the
	// RPC status.
	//
	// It is safe to have a goroutine calling SendMsg and another goroutine
	// calling RecvMsg on the same stream at the same time, but it is not
	// safe to call RecvMsg on the same stream in different goroutines.
	RecvMsg(m any) error
}
1532
1533
// serverStream implements a server side Stream.
type serverStream struct {
	ctx   context.Context
	t     transport.ServerTransport
	s     *transport.Stream
	p     *parser
	codec baseCodec

	cp     Compressor
	dc     Decompressor
	comp   encoding.Compressor
	decomp encoding.Compressor

	// sendCompressorName is the name of the compressor currently used for
	// sends; kept in sync with the stream's negotiated SendCompress value
	// (see SendMsg).
	sendCompressorName string

	maxReceiveMessageSize int
	maxSendMessageSize    int
	trInfo                *traceInfo

	statsHandler []stats.Handler

	binlogs []binarylog.MethodLogger
	// serverHeaderBinlogged records whether the server header has been
	// binary-logged; this happens on SendHeader() or the first Send.  It is
	// only touched on the send path, so no synchronization is needed.
	serverHeaderBinlogged bool

	mu sync.Mutex // protects trInfo.tr after the service handler runs
}
1565
// Context returns the context for this stream.
func (ss *serverStream) Context() context.Context {
	return ss.ctx
}
1569
1570 func (ss *serverStream) SetHeader(md metadata.MD) error {
1571 if md.Len() == 0 {
1572 return nil
1573 }
1574 err := imetadata.Validate(md)
1575 if err != nil {
1576 return status.Error(codes.Internal, err.Error())
1577 }
1578 return ss.s.SetHeader(md)
1579 }
1580
1581 func (ss *serverStream) SendHeader(md metadata.MD) error {
1582 err := imetadata.Validate(md)
1583 if err != nil {
1584 return status.Error(codes.Internal, err.Error())
1585 }
1586
1587 err = ss.t.WriteHeader(ss.s, md)
1588 if len(ss.binlogs) != 0 && !ss.serverHeaderBinlogged {
1589 h, _ := ss.s.Header()
1590 sh := &binarylog.ServerHeader{
1591 Header: h,
1592 }
1593 ss.serverHeaderBinlogged = true
1594 for _, binlog := range ss.binlogs {
1595 binlog.Log(ss.ctx, sh)
1596 }
1597 }
1598 return err
1599 }
1600
1601 func (ss *serverStream) SetTrailer(md metadata.MD) {
1602 if md.Len() == 0 {
1603 return
1604 }
1605 if err := imetadata.Validate(md); err != nil {
1606 logger.Errorf("stream: failed to validate md when setting trailer, err: %v", err)
1607 }
1608 ss.s.SetTrailer(md)
1609 }
1610
// SendMsg serializes, optionally compresses, and writes m to the client.
// On a non-nil, non-EOF error the RPC is terminated: a status derived from
// the error is written to the transport in the deferred handler below.
func (ss *serverStream) SendMsg(m any) (err error) {
	defer func() {
		// Record the send outcome in the trace, if tracing is enabled.
		// ss.mu guards trInfo.tr.
		if ss.trInfo != nil {
			ss.mu.Lock()
			if ss.trInfo.tr != nil {
				if err == nil {
					ss.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
				} else {
					ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
					ss.trInfo.tr.SetError()
				}
			}
			ss.mu.Unlock()
		}
		if err != nil && err != io.EOF {
			st, _ := status.FromError(toRPCErr(err))
			ss.t.WriteStatus(ss.s, st)
			// A non-nil, non-EOF error ends the stream; writing the status
			// above lets the client observe the failure.
		}
		if channelz.IsOn() && err == nil {
			ss.t.IncrMsgSent()
		}
	}()

	// The stream's send compressor may have changed since the last send;
	// refresh the cached encoding.Compressor when the name differs.
	if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
		ss.comp = encoding.GetCompressor(sendCompressorsName)
		ss.sendCompressorName = sendCompressorsName
	}

	// Marshal (and possibly compress) m into header + payload wire frames.
	hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
	if err != nil {
		return err
	}

	// Enforce the configured maximum send size on the on-the-wire payload.
	if len(payload) > ss.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
	}
	if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
		return toRPCErr(err)
	}
	if len(ss.binlogs) != 0 {
		// Binary-log the server header (once per stream) before the first
		// logged message, if SendHeader has not already logged it.
		if !ss.serverHeaderBinlogged {
			h, _ := ss.s.Header()
			sh := &binarylog.ServerHeader{
				Header: h,
			}
			ss.serverHeaderBinlogged = true
			for _, binlog := range ss.binlogs {
				binlog.Log(ss.ctx, sh)
			}
		}
		sm := &binarylog.ServerMessage{
			Message: data,
		}
		for _, binlog := range ss.binlogs {
			binlog.Log(ss.ctx, sm)
		}
	}
	if len(ss.statsHandler) != 0 {
		// Report the uncompressed (data) and wire (payload) bytes to every
		// registered stats handler.
		for _, sh := range ss.statsHandler {
			sh.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now()))
		}
	}
	return nil
}
1685
1686 func (ss *serverStream) RecvMsg(m any) (err error) {
1687 defer func() {
1688 if ss.trInfo != nil {
1689 ss.mu.Lock()
1690 if ss.trInfo.tr != nil {
1691 if err == nil {
1692 ss.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
1693 } else if err != io.EOF {
1694 ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
1695 ss.trInfo.tr.SetError()
1696 }
1697 }
1698 ss.mu.Unlock()
1699 }
1700 if err != nil && err != io.EOF {
1701 st, _ := status.FromError(toRPCErr(err))
1702 ss.t.WriteStatus(ss.s, st)
1703
1704
1705
1706
1707
1708
1709 }
1710 if channelz.IsOn() && err == nil {
1711 ss.t.IncrMsgRecv()
1712 }
1713 }()
1714 var payInfo *payloadInfo
1715 if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
1716 payInfo = &payloadInfo{}
1717 }
1718 if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil {
1719 if err == io.EOF {
1720 if len(ss.binlogs) != 0 {
1721 chc := &binarylog.ClientHalfClose{}
1722 for _, binlog := range ss.binlogs {
1723 binlog.Log(ss.ctx, chc)
1724 }
1725 }
1726 return err
1727 }
1728 if err == io.ErrUnexpectedEOF {
1729 err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
1730 }
1731 return toRPCErr(err)
1732 }
1733 if len(ss.statsHandler) != 0 {
1734 for _, sh := range ss.statsHandler {
1735 sh.HandleRPC(ss.s.Context(), &stats.InPayload{
1736 RecvTime: time.Now(),
1737 Payload: m,
1738
1739 Data: payInfo.uncompressedBytes,
1740 Length: len(payInfo.uncompressedBytes),
1741 WireLength: payInfo.compressedLength + headerLen,
1742 CompressedLength: payInfo.compressedLength,
1743 })
1744 }
1745 }
1746 if len(ss.binlogs) != 0 {
1747 cm := &binarylog.ClientMessage{
1748 Message: payInfo.uncompressedBytes,
1749 }
1750 for _, binlog := range ss.binlogs {
1751 binlog.Log(ss.ctx, cm)
1752 }
1753 }
1754 return nil
1755 }
1756
1757
1758
// MethodFromServerStream returns the method string for the input stream,
// as recorded in the stream's context by the server; ok is false when no
// method is present. It delegates to Method.
func MethodFromServerStream(stream ServerStream) (string, bool) {
	return Method(stream.Context())
}
1762
1763
1764
1765
1766 func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) {
1767 if preparedMsg, ok := m.(*PreparedMsg); ok {
1768 return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil
1769 }
1770
1771
1772 data, err = encode(codec, m)
1773 if err != nil {
1774 return nil, nil, nil, err
1775 }
1776 compData, err := compress(data, cp, comp)
1777 if err != nil {
1778 return nil, nil, nil, err
1779 }
1780 hdr, payload = msgHeader(data, compData)
1781 return hdr, payload, data, nil
1782 }
1783