1
18
19 package grpc
20
21 import (
22 "context"
23 "errors"
24 "fmt"
25 "io"
26 "math"
27 "net"
28 "net/http"
29 "reflect"
30 "runtime"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "time"
35
36 "google.golang.org/grpc/codes"
37 "google.golang.org/grpc/credentials"
38 "google.golang.org/grpc/encoding"
39 "google.golang.org/grpc/encoding/proto"
40 "google.golang.org/grpc/grpclog"
41 "google.golang.org/grpc/internal"
42 "google.golang.org/grpc/internal/binarylog"
43 "google.golang.org/grpc/internal/channelz"
44 "google.golang.org/grpc/internal/grpcsync"
45 "google.golang.org/grpc/internal/grpcutil"
46 "google.golang.org/grpc/internal/transport"
47 "google.golang.org/grpc/keepalive"
48 "google.golang.org/grpc/metadata"
49 "google.golang.org/grpc/peer"
50 "google.golang.org/grpc/stats"
51 "google.golang.org/grpc/status"
52 "google.golang.org/grpc/tap"
53 )
54
const (
	// defaultServerMaxReceiveMessageSize caps inbound messages at 4 MiB by
	// default; override with MaxRecvMsgSize.
	defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
	// defaultServerMaxSendMessageSize is effectively unlimited by default;
	// override with MaxSendMsgSize.
	defaultServerMaxSendMessageSize = math.MaxInt32

	// listenerAddressForServeHTTP is the placeholder listener address used
	// to key connections accepted via ServeHTTP rather than Serve.
	listenerAddressForServeHTTP = "listenerAddressForServeHTTP"
)
66
67 func init() {
68 internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
69 return srv.opts.creds
70 }
71 internal.IsRegisteredMethod = func(srv *Server, method string) bool {
72 return srv.isRegisteredMethod(method)
73 }
74 internal.ServerFromContext = serverFromContext
75 internal.AddGlobalServerOptions = func(opt ...ServerOption) {
76 globalServerOptions = append(globalServerOptions, opt...)
77 }
78 internal.ClearGlobalServerOptions = func() {
79 globalServerOptions = nil
80 }
81 internal.BinaryLogger = binaryLogger
82 internal.JoinServerOptions = newJoinServerOption
83 internal.RecvBufferPool = recvBufferPool
84 }
85
86 var statusOK = status.New(codes.OK, "")
87 var logger = grpclog.Component("core")
88
// methodHandler is the generated-code entry point for a unary method: it
// decodes the request via dec and invokes the service implementation,
// optionally through interceptor.
type methodHandler func(srv any, ctx context.Context, dec func(any) error, interceptor UnaryServerInterceptor) (any, error)

// MethodDesc represents an RPC service's method specification.
type MethodDesc struct {
	MethodName string
	Handler    methodHandler
}
96
97
// ServiceDesc represents an RPC service's specification, as produced by the
// code generator and passed to RegisterService.
type ServiceDesc struct {
	ServiceName string
	// HandlerType is a pointer to the service interface; it is used only to
	// check, via reflection, that the registered implementation satisfies it.
	HandlerType any
	Methods     []MethodDesc
	Streams     []StreamDesc
	Metadata    any
}
107
108
109
// serviceInfo wraps information about a service, indexed for fast method
// lookup at dispatch time.
type serviceInfo struct {
	// serviceImpl is the user-provided service implementation.
	serviceImpl any
	methods     map[string]*MethodDesc
	streams     map[string]*StreamDesc
	mdata       any
}
117
118
// Server is a gRPC server to serve RPC requests.
type Server struct {
	opts serverOptions

	// mu guards the fields below it.
	mu sync.Mutex
	// lis is the set of active listeners; nil once the server is stopped.
	lis map[net.Listener]bool
	// conns maps a listener address to the set of server transports accepted
	// on it. The listenerAddressForServeHTTP key groups ServeHTTP-originated
	// connections.
	conns map[string]map[transport.ServerTransport]bool
	// serve records that Serve has been called; registration is rejected
	// afterwards.
	serve bool
	// drain records that graceful draining has begun; new transports are
	// drained immediately.
	drain bool
	// cv is signaled when connections close, so a graceful stop can wait.
	cv       *sync.Cond
	services map[string]*serviceInfo
	events   traceEventLog

	quit               *grpcsync.Event
	done               *grpcsync.Event
	channelzRemoveOnce sync.Once
	// serveWG counts active Serve goroutines for GracefulStop.
	serveWG sync.WaitGroup
	// handlersWG counts active method handler goroutines.
	handlersWG sync.WaitGroup

	channelz *channelz.Server

	// serverWorkerChannel feeds stream-handling closures to worker
	// goroutines when numServerWorkers > 0.
	serverWorkerChannel      chan func()
	serverWorkerChannelClose func()
}
145
// serverOptions is the accumulated configuration produced by applying
// ServerOptions; see the corresponding option constructors for the meaning
// of each field.
type serverOptions struct {
	creds                 credentials.TransportCredentials
	codec                 baseCodec
	cp                    Compressor
	dc                    Decompressor
	unaryInt              UnaryServerInterceptor
	streamInt             StreamServerInterceptor
	chainUnaryInts        []UnaryServerInterceptor
	chainStreamInts       []StreamServerInterceptor
	binaryLogger          binarylog.Logger
	inTapHandle           tap.ServerInHandle
	statsHandlers         []stats.Handler
	maxConcurrentStreams  uint32
	maxReceiveMessageSize int
	maxSendMessageSize    int
	unknownStreamDesc     *StreamDesc
	keepaliveParams       keepalive.ServerParameters
	keepalivePolicy       keepalive.EnforcementPolicy
	initialWindowSize     int32
	initialConnWindowSize int32
	writeBufferSize       int
	readBufferSize        int
	sharedWriteBuffer     bool
	connectionTimeout     time.Duration
	maxHeaderListSize     *uint32
	headerTableSize       *uint32
	numServerWorkers      uint32
	recvBufferPool        SharedBufferPool
	waitForHandlers       bool
}
176
// defaultServerOptions is the starting configuration before global and
// per-server options are applied.
var defaultServerOptions = serverOptions{
	maxConcurrentStreams:  math.MaxUint32,
	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
	maxSendMessageSize:    defaultServerMaxSendMessageSize,
	connectionTimeout:     120 * time.Second,
	writeBufferSize:       defaultWriteBufSize,
	readBufferSize:        defaultReadBufSize,
	recvBufferPool:        nopBufferPool{},
}

// globalServerOptions are applied to every server before its own options;
// managed via internal.AddGlobalServerOptions/ClearGlobalServerOptions.
var globalServerOptions []ServerOption
187
188
// A ServerOption sets options such as credentials, codec and keepalive
// parameters, etc.
type ServerOption interface {
	apply(*serverOptions)
}
192
193
194
195
196
197
198
199
// EmptyServerOption does not alter the server configuration. It can be
// embedded in another structure to build custom server options.
type EmptyServerOption struct{}

func (EmptyServerOption) apply(*serverOptions) {}
203
204
205
206 type funcServerOption struct {
207 f func(*serverOptions)
208 }
209
210 func (fdo *funcServerOption) apply(do *serverOptions) {
211 fdo.f(do)
212 }
213
214 func newFuncServerOption(f func(*serverOptions)) *funcServerOption {
215 return &funcServerOption{
216 f: f,
217 }
218 }
219
220
221
222 type joinServerOption struct {
223 opts []ServerOption
224 }
225
226 func (mdo *joinServerOption) apply(do *serverOptions) {
227 for _, opt := range mdo.opts {
228 opt.apply(do)
229 }
230 }
231
232 func newJoinServerOption(opts ...ServerOption) ServerOption {
233 return &joinServerOption{opts: opts}
234 }
235
236
237
238
239
240
241
242
243
244 func SharedWriteBuffer(val bool) ServerOption {
245 return newFuncServerOption(func(o *serverOptions) {
246 o.sharedWriteBuffer = val
247 })
248 }
249
250
251
252
253
254 func WriteBufferSize(s int) ServerOption {
255 return newFuncServerOption(func(o *serverOptions) {
256 o.writeBufferSize = s
257 })
258 }
259
260
261
262
263
264 func ReadBufferSize(s int) ServerOption {
265 return newFuncServerOption(func(o *serverOptions) {
266 o.readBufferSize = s
267 })
268 }
269
270
271
272 func InitialWindowSize(s int32) ServerOption {
273 return newFuncServerOption(func(o *serverOptions) {
274 o.initialWindowSize = s
275 })
276 }
277
278
279
280 func InitialConnWindowSize(s int32) ServerOption {
281 return newFuncServerOption(func(o *serverOptions) {
282 o.initialConnWindowSize = s
283 })
284 }
285
286
287 func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
288 if kp.Time > 0 && kp.Time < internal.KeepaliveMinServerPingTime {
289 logger.Warning("Adjusting keepalive ping interval to minimum period of 1s")
290 kp.Time = internal.KeepaliveMinServerPingTime
291 }
292
293 return newFuncServerOption(func(o *serverOptions) {
294 o.keepaliveParams = kp
295 })
296 }
297
298
299 func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
300 return newFuncServerOption(func(o *serverOptions) {
301 o.keepalivePolicy = kep
302 })
303 }
304
305
306
307
308
309
310
311
312
313
314 func CustomCodec(codec Codec) ServerOption {
315 return newFuncServerOption(func(o *serverOptions) {
316 o.codec = codec
317 })
318 }
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343 func ForceServerCodec(codec encoding.Codec) ServerOption {
344 return newFuncServerOption(func(o *serverOptions) {
345 o.codec = codec
346 })
347 }
348
349
350
351
352
353
354
355
356
357 func RPCCompressor(cp Compressor) ServerOption {
358 return newFuncServerOption(func(o *serverOptions) {
359 o.cp = cp
360 })
361 }
362
363
364
365
366
367
368
369 func RPCDecompressor(dc Decompressor) ServerOption {
370 return newFuncServerOption(func(o *serverOptions) {
371 o.dc = dc
372 })
373 }
374
375
376
377
378
// MaxMsgSize returns a ServerOption to set the max message size in bytes the
// server can receive. It is a legacy alias for MaxRecvMsgSize.
func MaxMsgSize(m int) ServerOption {
	return MaxRecvMsgSize(m)
}
382
383
384
385 func MaxRecvMsgSize(m int) ServerOption {
386 return newFuncServerOption(func(o *serverOptions) {
387 o.maxReceiveMessageSize = m
388 })
389 }
390
391
392
393 func MaxSendMsgSize(m int) ServerOption {
394 return newFuncServerOption(func(o *serverOptions) {
395 o.maxSendMessageSize = m
396 })
397 }
398
399
400
401 func MaxConcurrentStreams(n uint32) ServerOption {
402 if n == 0 {
403 n = math.MaxUint32
404 }
405 return newFuncServerOption(func(o *serverOptions) {
406 o.maxConcurrentStreams = n
407 })
408 }
409
410
411 func Creds(c credentials.TransportCredentials) ServerOption {
412 return newFuncServerOption(func(o *serverOptions) {
413 o.creds = c
414 })
415 }
416
417
418
419
420 func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
421 return newFuncServerOption(func(o *serverOptions) {
422 if o.unaryInt != nil {
423 panic("The unary server interceptor was already set and may not be reset.")
424 }
425 o.unaryInt = i
426 })
427 }
428
429
430
431
432
433 func ChainUnaryInterceptor(interceptors ...UnaryServerInterceptor) ServerOption {
434 return newFuncServerOption(func(o *serverOptions) {
435 o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
436 })
437 }
438
439
440
441 func StreamInterceptor(i StreamServerInterceptor) ServerOption {
442 return newFuncServerOption(func(o *serverOptions) {
443 if o.streamInt != nil {
444 panic("The stream server interceptor was already set and may not be reset.")
445 }
446 o.streamInt = i
447 })
448 }
449
450
451
452
453
454 func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOption {
455 return newFuncServerOption(func(o *serverOptions) {
456 o.chainStreamInts = append(o.chainStreamInts, interceptors...)
457 })
458 }
459
460
461
462
463
464
465
466
467 func InTapHandle(h tap.ServerInHandle) ServerOption {
468 return newFuncServerOption(func(o *serverOptions) {
469 if o.inTapHandle != nil {
470 panic("The tap handle was already set and may not be reset.")
471 }
472 o.inTapHandle = h
473 })
474 }
475
476
477 func StatsHandler(h stats.Handler) ServerOption {
478 return newFuncServerOption(func(o *serverOptions) {
479 if h == nil {
480 logger.Error("ignoring nil parameter in grpc.StatsHandler ServerOption")
481
482
483 return
484 }
485 o.statsHandlers = append(o.statsHandlers, h)
486 })
487 }
488
489
490
491 func binaryLogger(bl binarylog.Logger) ServerOption {
492 return newFuncServerOption(func(o *serverOptions) {
493 o.binaryLogger = bl
494 })
495 }
496
497
498
499
500
501
502
503 func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
504 return newFuncServerOption(func(o *serverOptions) {
505 o.unknownStreamDesc = &StreamDesc{
506 StreamName: "unknown_service_handler",
507 Handler: streamHandler,
508
509 ClientStreams: true,
510 ServerStreams: true,
511 }
512 })
513 }
514
515
516
517
518
519
520
521
522
523
524 func ConnectionTimeout(d time.Duration) ServerOption {
525 return newFuncServerOption(func(o *serverOptions) {
526 o.connectionTimeout = d
527 })
528 }
529
530
531
// MaxHeaderListSizeServerOption is a ServerOption that sets the max
// (uncompressed) size of header list that the server is prepared to accept.
type MaxHeaderListSizeServerOption struct {
	MaxHeaderListSize uint32
}

func (o MaxHeaderListSizeServerOption) apply(so *serverOptions) {
	so.maxHeaderListSize = &o.MaxHeaderListSize
}
539
540
541
542 func MaxHeaderListSize(s uint32) ServerOption {
543 return MaxHeaderListSizeServerOption{
544 MaxHeaderListSize: s,
545 }
546 }
547
548
549
550
551
552
553
554
555 func HeaderTableSize(s uint32) ServerOption {
556 return newFuncServerOption(func(o *serverOptions) {
557 o.headerTableSize = &s
558 })
559 }
560
561
562
563
564
565
566
567
568
569
570 func NumStreamWorkers(numServerWorkers uint32) ServerOption {
571
572
573
574
575 return newFuncServerOption(func(o *serverOptions) {
576 o.numServerWorkers = numServerWorkers
577 })
578 }
579
580
581
582
583
584
585
586
587
588
589 func WaitForHandlers(w bool) ServerOption {
590 return newFuncServerOption(func(o *serverOptions) {
591 o.waitForHandlers = w
592 })
593 }
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608 func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
609 return recvBufferPool(bufferPool)
610 }
611
612 func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
613 return newFuncServerOption(func(o *serverOptions) {
614 o.recvBufferPool = bufferPool
615 })
616 }
617
618
619
620
621
622
623 const serverWorkerResetThreshold = 1 << 16
624
625
626
627
628
629
630
631 func (s *Server) serverWorker() {
632 for completed := 0; completed < serverWorkerResetThreshold; completed++ {
633 f, ok := <-s.serverWorkerChannel
634 if !ok {
635 return
636 }
637 f()
638 }
639 go s.serverWorker()
640 }
641
642
643
644 func (s *Server) initServerWorkers() {
645 s.serverWorkerChannel = make(chan func())
646 s.serverWorkerChannelClose = grpcsync.OnceFunc(func() {
647 close(s.serverWorkerChannel)
648 })
649 for i := uint32(0); i < s.opts.numServerWorkers; i++ {
650 go s.serverWorker()
651 }
652 }
653
654
655
// NewServer creates a gRPC server which has no service registered and has
// not started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
	// Global options are applied first, so per-server options can override.
	opts := defaultServerOptions
	for _, o := range globalServerOptions {
		o.apply(&opts)
	}
	for _, o := range opt {
		o.apply(&opts)
	}
	s := &Server{
		lis:      make(map[net.Listener]bool),
		opts:     opts,
		conns:    make(map[string]map[transport.ServerTransport]bool),
		services: make(map[string]*serviceInfo),
		quit:     grpcsync.NewEvent(),
		done:     grpcsync.NewEvent(),
		channelz: channelz.RegisterServer(""),
	}
	// Collapse the interceptor chains into single interceptors now that all
	// options have been applied.
	chainUnaryServerInterceptors(s)
	chainStreamServerInterceptors(s)
	s.cv = sync.NewCond(&s.mu)
	if EnableTracing {
		// Attribute the trace event log to our caller's file:line.
		_, file, line, _ := runtime.Caller(1)
		s.events = newTraceEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
	}

	if s.opts.numServerWorkers > 0 {
		s.initServerWorkers()
	}

	channelz.Info(logger, s.channelz, "Server created")
	return s
}
688
689
690
691 func (s *Server) printf(format string, a ...any) {
692 if s.events != nil {
693 s.events.Printf(format, a...)
694 }
695 }
696
697
698
699 func (s *Server) errorf(format string, a ...any) {
700 if s.events != nil {
701 s.events.Errorf(format, a...)
702 }
703 }
704
705
706
707
// ServiceRegistrar wraps a single method that supports service registration.
// It enables users to pass concrete types other than grpc.Server to the
// service registration methods exported by generated code.
type ServiceRegistrar interface {
	// RegisterService registers a service and its implementation to the
	// concrete type implementing this interface. Generated code calls it
	// before the registrar begins serving.
	RegisterService(desc *ServiceDesc, impl any)
}
716
717
718
719
720
721 func (s *Server) RegisterService(sd *ServiceDesc, ss any) {
722 if ss != nil {
723 ht := reflect.TypeOf(sd.HandlerType).Elem()
724 st := reflect.TypeOf(ss)
725 if !st.Implements(ht) {
726 logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
727 }
728 }
729 s.register(sd, ss)
730 }
731
// register indexes sd's methods and streams by name under s.services.
// Registration after Serve, or a duplicate service name, is fatal.
func (s *Server) register(sd *ServiceDesc, ss any) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, ok := s.services[sd.ServiceName]; ok {
		logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}
	info := &serviceInfo{
		serviceImpl: ss,
		methods:     make(map[string]*MethodDesc),
		streams:     make(map[string]*StreamDesc),
		mdata:       sd.Metadata,
	}
	// Index by pointer into the descriptor's slices; the descriptors are
	// owned by generated code and not copied.
	for i := range sd.Methods {
		d := &sd.Methods[i]
		info.methods[d.MethodName] = d
	}
	for i := range sd.Streams {
		d := &sd.Streams[i]
		info.streams[d.StreamName] = d
	}
	s.services[sd.ServiceName] = info
}
758
759
// MethodInfo contains the information of an RPC method as reported by
// GetServiceInfo.
type MethodInfo struct {
	// Name is the method name only, without the service name or package name.
	Name string
	// IsClientStream indicates whether the RPC is a client streaming RPC.
	IsClientStream bool
	// IsServerStream indicates whether the RPC is a server streaming RPC.
	IsServerStream bool
}

// ServiceInfo contains unary RPC method info, streaming RPC method info and
// the metadata for a service.
type ServiceInfo struct {
	Methods []MethodInfo
	// Metadata is the metadata specified in ServiceDesc when registering.
	Metadata any
}
775
776
777
778 func (s *Server) GetServiceInfo() map[string]ServiceInfo {
779 ret := make(map[string]ServiceInfo)
780 for n, srv := range s.services {
781 methods := make([]MethodInfo, 0, len(srv.methods)+len(srv.streams))
782 for m := range srv.methods {
783 methods = append(methods, MethodInfo{
784 Name: m,
785 IsClientStream: false,
786 IsServerStream: false,
787 })
788 }
789 for m, d := range srv.streams {
790 methods = append(methods, MethodInfo{
791 Name: m,
792 IsClientStream: d.ClientStreams,
793 IsServerStream: d.ServerStreams,
794 })
795 }
796
797 ret[n] = ServiceInfo{
798 Methods: methods,
799 Metadata: srv.mdata,
800 }
801 }
802 return ret
803 }
804
805
806
807 var ErrServerStopped = errors.New("grpc: the server has been stopped")
808
// listenSocket pairs a net.Listener with its channelz socket entry so the
// entry can be removed when the listener closes.
type listenSocket struct {
	net.Listener
	channelz *channelz.Socket
}

// Close closes the underlying listener and unregisters it from channelz.
func (l *listenSocket) Close() error {
	err := l.Listener.Close()
	channelz.RemoveEntry(l.channelz.ID)
	channelz.Info(logger, l.channelz, "ListenSocket deleted")
	return err
}
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839 func (s *Server) Serve(lis net.Listener) error {
840 s.mu.Lock()
841 s.printf("serving")
842 s.serve = true
843 if s.lis == nil {
844
845 s.mu.Unlock()
846 lis.Close()
847 return ErrServerStopped
848 }
849
850 s.serveWG.Add(1)
851 defer func() {
852 s.serveWG.Done()
853 if s.quit.HasFired() {
854
855 <-s.done.Done()
856 }
857 }()
858
859 ls := &listenSocket{
860 Listener: lis,
861 channelz: channelz.RegisterSocket(&channelz.Socket{
862 SocketType: channelz.SocketTypeListen,
863 Parent: s.channelz,
864 RefName: lis.Addr().String(),
865 LocalAddr: lis.Addr(),
866 SocketOptions: channelz.GetSocketOption(lis)},
867 ),
868 }
869 s.lis[ls] = true
870
871 defer func() {
872 s.mu.Lock()
873 if s.lis != nil && s.lis[ls] {
874 ls.Close()
875 delete(s.lis, ls)
876 }
877 s.mu.Unlock()
878 }()
879
880 s.mu.Unlock()
881 channelz.Info(logger, ls.channelz, "ListenSocket created")
882
883 var tempDelay time.Duration
884 for {
885 rawConn, err := lis.Accept()
886 if err != nil {
887 if ne, ok := err.(interface {
888 Temporary() bool
889 }); ok && ne.Temporary() {
890 if tempDelay == 0 {
891 tempDelay = 5 * time.Millisecond
892 } else {
893 tempDelay *= 2
894 }
895 if max := 1 * time.Second; tempDelay > max {
896 tempDelay = max
897 }
898 s.mu.Lock()
899 s.printf("Accept error: %v; retrying in %v", err, tempDelay)
900 s.mu.Unlock()
901 timer := time.NewTimer(tempDelay)
902 select {
903 case <-timer.C:
904 case <-s.quit.Done():
905 timer.Stop()
906 return nil
907 }
908 continue
909 }
910 s.mu.Lock()
911 s.printf("done serving; Accept = %v", err)
912 s.mu.Unlock()
913
914 if s.quit.HasFired() {
915 return nil
916 }
917 return err
918 }
919 tempDelay = 0
920
921
922
923
924
925 s.serveWG.Add(1)
926 go func() {
927 s.handleRawConn(lis.Addr().String(), rawConn)
928 s.serveWG.Done()
929 }()
930 }
931 }
932
933
934
// handleRawConn performs the transport handshake on a just-accepted
// connection and, on success, forks a goroutine to serve its streams.
func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
	if s.quit.HasFired() {
		rawConn.Close()
		return
	}
	// Bound the handshake by the configured connection timeout; the deadline
	// is cleared once the transport is established.
	rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))

	st := s.newHTTP2Transport(rawConn)
	rawConn.SetDeadline(time.Time{})
	if st == nil {
		// Handshake failed; newHTTP2Transport already closed the conn as
		// needed.
		return
	}

	// Hand the transport back to connections that ask for it (NOTE(review):
	// presumably in-process/test conns implement this interface — confirm).
	if cc, ok := rawConn.(interface {
		PassServerTransport(transport.ServerTransport)
	}); ok {
		cc.PassServerTransport(st)
	}

	if !s.addConn(lisAddr, st) {
		return
	}
	go func() {
		s.serveStreams(context.Background(), st, rawConn)
		s.removeConn(lisAddr, st)
	}()
}
963
964
965
// newHTTP2Transport sets up an HTTP/2 transport over conn c (performing the
// handshake and credentials exchange) and returns it, or nil on failure.
func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
	config := &transport.ServerConfig{
		MaxStreams:            s.opts.maxConcurrentStreams,
		ConnectionTimeout:     s.opts.connectionTimeout,
		Credentials:           s.opts.creds,
		InTapHandle:           s.opts.inTapHandle,
		StatsHandlers:         s.opts.statsHandlers,
		KeepaliveParams:       s.opts.keepaliveParams,
		KeepalivePolicy:       s.opts.keepalivePolicy,
		InitialWindowSize:     s.opts.initialWindowSize,
		InitialConnWindowSize: s.opts.initialConnWindowSize,
		WriteBufferSize:       s.opts.writeBufferSize,
		ReadBufferSize:        s.opts.readBufferSize,
		SharedWriteBuffer:     s.opts.sharedWriteBuffer,
		ChannelzParent:        s.channelz,
		MaxHeaderListSize:     s.opts.maxHeaderListSize,
		HeaderTableSize:       s.opts.headerTableSize,
	}
	st, err := transport.NewServerTransport(c, config)
	if err != nil {
		s.mu.Lock()
		s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
		s.mu.Unlock()
		// ErrConnDispatched means the credentials took ownership of the
		// connection; do not close it here.
		if err != credentials.ErrConnDispatched {
			// io.EOF (client hang-up during handshake) is not worth logging.
			if err != io.EOF {
				channelz.Info(logger, s.channelz, "grpc: Server.Serve failed to create ServerTransport: ", err)
			}
			c.Close()
		}
		return nil
	}

	return st
}
1003
// serveStreams serves all streams on st, dispatching each to a worker (if
// configured and available) or a fresh goroutine. It blocks until the
// transport stops delivering streams, then closes the transport and notifies
// stats handlers. rawConn may be nil (e.g. for ServeHTTP transports).
func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
	ctx = transport.SetConnection(ctx, rawConn)
	ctx = peer.NewContext(ctx, st.Peer())
	for _, sh := range s.opts.statsHandlers {
		ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
			RemoteAddr: st.Peer().Addr,
			LocalAddr:  st.Peer().LocalAddr,
		})
		sh.HandleConn(ctx, &stats.ConnBegin{})
	}

	defer func() {
		st.Close(errors.New("finished serving streams for the server transport"))
		for _, sh := range s.opts.statsHandlers {
			sh.HandleConn(ctx, &stats.ConnEnd{})
		}
	}()

	// streamQuota bounds concurrently-running handlers to
	// maxConcurrentStreams; acquire happens before spawning the handler.
	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
	st.HandleStreams(ctx, func(stream *transport.Stream) {
		s.handlersWG.Add(1)
		streamQuota.acquire()
		f := func() {
			defer streamQuota.release()
			defer s.handlersWG.Done()
			s.handleStream(st, stream)
		}

		if s.opts.numServerWorkers > 0 {
			// Prefer an idle worker; fall through to a fresh goroutine if
			// all workers are busy.
			select {
			case s.serverWorkerChannel <- f:
				return
			default:
			}
		}
		go f()
	})
}
1043
1044 var _ http.Handler = (*Server)(nil)
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
// ServeHTTP implements the Go standard library's http.Handler interface by
// serving the gRPC request r over the provided ResponseWriter, reusing the
// server's stream-serving machinery.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers)
	if err != nil {
		// NOTE(review): NewServerHandlerTransport appears responsible for
		// replying to the client in the error case — confirm against the
		// transport implementation; nothing more to do here.
		return
	}
	if !s.addConn(listenerAddressForServeHTTP, st) {
		return
	}
	defer s.removeConn(listenerAddressForServeHTTP, st)
	// Serve synchronously: ServeHTTP must not return while streams are live.
	s.serveStreams(r.Context(), st, nil)
}
1087
1088 func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
1089 s.mu.Lock()
1090 defer s.mu.Unlock()
1091 if s.conns == nil {
1092 st.Close(errors.New("Server.addConn called when server has already been stopped"))
1093 return false
1094 }
1095 if s.drain {
1096
1097
1098 st.Drain("")
1099 }
1100
1101 if s.conns[addr] == nil {
1102
1103 s.conns[addr] = make(map[transport.ServerTransport]bool)
1104 }
1105 s.conns[addr][st] = true
1106 return true
1107 }
1108
1109 func (s *Server) removeConn(addr string, st transport.ServerTransport) {
1110 s.mu.Lock()
1111 defer s.mu.Unlock()
1112
1113 conns := s.conns[addr]
1114 if conns != nil {
1115 delete(conns, st)
1116 if len(conns) == 0 {
1117
1118
1119
1120 delete(s.conns, addr)
1121 }
1122 s.cv.Broadcast()
1123 }
1124 }
1125
// incrCallsStarted bumps the channelz started-calls counter and records the
// start timestamp.
func (s *Server) incrCallsStarted() {
	s.channelz.ServerMetrics.CallsStarted.Add(1)
	s.channelz.ServerMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
}

// incrCallsSucceeded bumps the channelz succeeded-calls counter.
func (s *Server) incrCallsSucceeded() {
	s.channelz.ServerMetrics.CallsSucceeded.Add(1)
}

// incrCallsFailed bumps the channelz failed-calls counter.
func (s *Server) incrCallsFailed() {
	s.channelz.ServerMetrics.CallsFailed.Add(1)
}
1138
// sendResponse encodes msg with the stream's codec, optionally compresses
// it, enforces the configured max send size, writes it to the stream, and —
// on a successful write — notifies the stats handlers.
func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
	data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
	if err != nil {
		channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
		return err
	}
	compData, err := compress(data, cp, comp)
	if err != nil {
		channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
		return err
	}
	hdr, payload := msgHeader(data, compData)
	// The size limit applies to the on-the-wire (possibly compressed) payload.
	if len(payload) > s.opts.maxSendMessageSize {
		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
	}
	err = t.Write(stream, hdr, payload, opts)
	if err == nil {
		for _, sh := range s.opts.statsHandlers {
			sh.HandleRPC(ctx, outPayload(false, msg, data, payload, time.Now()))
		}
	}
	return err
}
1163
1164
1165 func chainUnaryServerInterceptors(s *Server) {
1166
1167
1168 interceptors := s.opts.chainUnaryInts
1169 if s.opts.unaryInt != nil {
1170 interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...)
1171 }
1172
1173 var chainedInt UnaryServerInterceptor
1174 if len(interceptors) == 0 {
1175 chainedInt = nil
1176 } else if len(interceptors) == 1 {
1177 chainedInt = interceptors[0]
1178 } else {
1179 chainedInt = chainUnaryInterceptors(interceptors)
1180 }
1181
1182 s.opts.unaryInt = chainedInt
1183 }
1184
// chainUnaryInterceptors returns a single interceptor that invokes
// interceptors[0] with a handler that recursively runs the rest of the
// chain, ending at the actual method handler.
func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *UnaryServerInfo, handler UnaryHandler) (any, error) {
		return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
	}
}

// getChainUnaryHandler builds the handler passed to interceptors[curr]: it
// invokes interceptors[curr+1], or finalHandler when curr is the last one.
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
	if curr == len(interceptors)-1 {
		return finalHandler
	}
	return func(ctx context.Context, req any) (any, error) {
		return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
	}
}
1199
// processUnaryRPC receives the single request message for stream, invokes
// the method handler (through the unary interceptor, if any), and writes the
// response and status. It also drives trace, stats, channelz, and binary
// logging for the call. The named return err is inspected by the deferred
// stats/channelz bookkeeping below.
func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
	shs := s.opts.statsHandlers
	if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
		if channelz.IsOn() {
			s.incrCallsStarted()
		}
		var statsBegin *stats.Begin
		for _, sh := range shs {
			beginTime := time.Now()
			statsBegin = &stats.Begin{
				BeginTime:      beginTime,
				IsClientStream: false,
				IsServerStream: false,
			}
			sh.HandleRPC(ctx, statsBegin)
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(&trInfo.firstLine, false)
		}
		// Deferred so that the final trace/stats/channelz records observe
		// the named return err. io.EOF is treated as a non-error here.
		defer func() {
			if trInfo != nil {
				if err != nil && err != io.EOF {
					trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
					trInfo.tr.SetError()
				}
				trInfo.tr.Finish()
			}

			for _, sh := range shs {
				end := &stats.End{
					BeginTime: statsBegin.BeginTime,
					EndTime:   time.Now(),
				}
				if err != nil && err != io.EOF {
					end.Error = toRPCErr(err)
				}
				sh.HandleRPC(ctx, end)
			}

			if channelz.IsOn() {
				if err != nil && err != io.EOF {
					s.incrCallsFailed()
				} else {
					s.incrCallsSucceeded()
				}
			}
		}()
	}
	// Collect binary loggers: the global one (per method) plus the
	// per-server one, if configured.
	var binlogs []binarylog.MethodLogger
	if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil {
		binlogs = append(binlogs, ml)
	}
	if s.opts.binaryLogger != nil {
		if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil {
			binlogs = append(binlogs, ml)
		}
	}
	if len(binlogs) != 0 {
		// Log the client header entry (metadata, timeout, authority, peer).
		md, _ := metadata.FromIncomingContext(ctx)
		logEntry := &binarylog.ClientHeader{
			Header:     md,
			MethodName: stream.Method(),
			PeerAddr:   nil,
		}
		if deadline, ok := ctx.Deadline(); ok {
			logEntry.Timeout = time.Until(deadline)
			if logEntry.Timeout < 0 {
				logEntry.Timeout = 0
			}
		}
		if a := md[":authority"]; len(a) > 0 {
			logEntry.Authority = a[0]
		}
		if peer, ok := peer.FromContext(ctx); ok {
			logEntry.PeerAddr = peer.Addr
		}
		for _, binlog := range binlogs {
			binlog.Log(ctx, logEntry)
		}
	}

	// comp/decomp are registered (encoding.Compressor) implementations;
	// cp/dc are the legacy server-configured Compressor/Decompressor, which
	// take precedence when their type matches.
	var comp, decomp encoding.Compressor
	var cp Compressor
	var dc Decompressor
	var sendCompressorName string

	// Pick the decompressor for the request based on grpc-encoding.
	if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
		dc = s.opts.dc
	} else if rc != "" && rc != encoding.Identity {
		decomp = encoding.GetCompressor(rc)
		if decomp == nil {
			st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
			t.WriteStatus(stream, st)
			return st.Err()
		}
	}

	// Pick the response compressor: server-configured one first, otherwise
	// mirror the client's request encoding if a registered compressor exists.
	if s.opts.cp != nil {
		cp = s.opts.cp
		sendCompressorName = cp.Type()
	} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
		comp = encoding.GetCompressor(rc)
		if comp != nil {
			sendCompressorName = comp.Name()
		}
	}

	if sendCompressorName != "" {
		if err := stream.SetSendCompress(sendCompressorName); err != nil {
			return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
		}
	}

	// payInfo is only needed when stats or binary logging will consume it.
	var payInfo *payloadInfo
	if len(shs) != 0 || len(binlogs) != 0 {
		payInfo = &payloadInfo{}
	}

	d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
	if err != nil {
		if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
			channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
		}
		return err
	}
	if channelz.IsOn() {
		t.IncrMsgRecv()
	}
	// df decodes the received bytes into the handler's request value and
	// reports the inbound payload to stats/binlog/trace.
	df := func(v any) error {
		defer cancel()

		if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
			return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
		}
		for _, sh := range shs {
			sh.HandleRPC(ctx, &stats.InPayload{
				RecvTime:         time.Now(),
				Payload:          v,
				Length:           len(d),
				WireLength:       payInfo.compressedLength + headerLen,
				CompressedLength: payInfo.compressedLength,
				Data:             d,
			})
		}
		if len(binlogs) != 0 {
			cm := &binarylog.ClientMessage{
				Message: d,
			}
			for _, binlog := range binlogs {
				binlog.Log(ctx, cm)
			}
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
		}
		return nil
	}
	ctx = NewContextWithServerTransportStream(ctx, stream)
	reply, appErr := md.Handler(info.serviceImpl, ctx, df, s.opts.unaryInt)
	if appErr != nil {
		appStatus, ok := status.FromError(appErr)
		if !ok {
			// Convert non-status application errors (e.g. context errors)
			// into an appropriate status.
			appStatus = status.FromContextError(appErr)
			appErr = appStatus.Err()
		}
		if trInfo != nil {
			trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
			trInfo.tr.SetError()
		}
		if e := t.WriteStatus(stream, appStatus); e != nil {
			channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
		}
		if len(binlogs) != 0 {
			if h, _ := stream.Header(); h.Len() > 0 {
				// Only log a header entry if headers were actually sent.
				sh := &binarylog.ServerHeader{
					Header: h,
				}
				for _, binlog := range binlogs {
					binlog.Log(ctx, sh)
				}
			}
			st := &binarylog.ServerTrailer{
				Trailer: stream.Trailer(),
				Err:     appErr,
			}
			for _, binlog := range binlogs {
				binlog.Log(ctx, st)
			}
		}
		return appErr
	}
	if trInfo != nil {
		trInfo.tr.LazyLog(stringer("OK"), false)
	}
	opts := &transport.Options{Last: true}

	// The handler (or an interceptor) may have changed the send compressor;
	// re-resolve it before writing the response.
	if stream.SendCompress() != sendCompressorName {
		comp = encoding.GetCompressor(stream.SendCompress())
	}
	if err := s.sendResponse(ctx, t, stream, reply, cp, opts, comp); err != nil {
		if err == io.EOF {
			// The entire stream is done (for unary RPC only).
			return err
		}
		if sts, ok := status.FromError(err); ok {
			if e := t.WriteStatus(stream, sts); e != nil {
				channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
			}
		} else {
			switch st := err.(type) {
			case transport.ConnectionError:
				// Nothing to do here: the connection is broken.
			default:
				panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
			}
		}
		if len(binlogs) != 0 {
			h, _ := stream.Header()
			sh := &binarylog.ServerHeader{
				Header: h,
			}
			st := &binarylog.ServerTrailer{
				Trailer: stream.Trailer(),
				Err:     appErr,
			}
			for _, binlog := range binlogs {
				binlog.Log(ctx, sh)
				binlog.Log(ctx, st)
			}
		}
		return err
	}
	if len(binlogs) != 0 {
		h, _ := stream.Header()
		sh := &binarylog.ServerHeader{
			Header: h,
		}
		sm := &binarylog.ServerMessage{
			Message: reply,
		}
		for _, binlog := range binlogs {
			binlog.Log(ctx, sh)
			binlog.Log(ctx, sm)
		}
	}
	if channelz.IsOn() {
		t.IncrMsgSent()
	}
	if trInfo != nil {
		trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
	}
	// Log the trailer before writing the final status so binlog entries are
	// complete even if WriteStatus fails.
	if len(binlogs) != 0 {
		st := &binarylog.ServerTrailer{
			Trailer: stream.Trailer(),
			Err:     appErr,
		}
		for _, binlog := range binlogs {
			binlog.Log(ctx, st)
		}
	}
	return t.WriteStatus(stream, statusOK)
}
1492
1493
1494 func chainStreamServerInterceptors(s *Server) {
1495
1496
1497 interceptors := s.opts.chainStreamInts
1498 if s.opts.streamInt != nil {
1499 interceptors = append([]StreamServerInterceptor{s.opts.streamInt}, s.opts.chainStreamInts...)
1500 }
1501
1502 var chainedInt StreamServerInterceptor
1503 if len(interceptors) == 0 {
1504 chainedInt = nil
1505 } else if len(interceptors) == 1 {
1506 chainedInt = interceptors[0]
1507 } else {
1508 chainedInt = chainStreamInterceptors(interceptors)
1509 }
1510
1511 s.opts.streamInt = chainedInt
1512 }
1513
1514 func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
1515 return func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
1516 return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
1517 }
1518 }
1519
1520 func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
1521 if curr == len(interceptors)-1 {
1522 return finalHandler
1523 }
1524 return func(srv any, stream ServerStream) error {
1525 return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
1526 }
1527 }
1528
// processStreamingRPC services one streaming RPC on stream: it builds a
// serverStream, wires up stats/trace/binary-log reporting and compression,
// invokes sd.Handler (through s.opts.streamInt if set), and writes the final
// status to the transport. info is nil when the stream is handled by
// opts.unknownStreamDesc. The named return err is read by the deferred
// reporting closure below.
func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) {
	if channelz.IsOn() {
		s.incrCallsStarted()
	}
	shs := s.opts.statsHandlers
	var statsBegin *stats.Begin
	if len(shs) != 0 {
		beginTime := time.Now()
		statsBegin = &stats.Begin{
			BeginTime:      beginTime,
			IsClientStream: sd.ClientStreams,
			IsServerStream: sd.ServerStreams,
		}
		for _, sh := range shs {
			sh.HandleRPC(ctx, statsBegin)
		}
	}
	ctx = NewContextWithServerTransportStream(ctx, stream)
	ss := &serverStream{
		ctx:                   ctx,
		t:                     t,
		s:                     stream,
		p:                     &parser{r: stream, recvBufferPool: s.opts.recvBufferPool},
		codec:                 s.getCodec(stream.ContentSubtype()),
		maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
		maxSendMessageSize:    s.opts.maxSendMessageSize,
		trInfo:                trInfo,
		statsHandler:          shs,
	}

	if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
		// Deferred so that trace finish, stats.End and channelz call counters
		// observe the final value of err, whatever path returns it.
		defer func() {
			if trInfo != nil {
				ss.mu.Lock()
				// io.EOF signals normal client-side stream termination, not a
				// failure, so it is not recorded as a trace error.
				if err != nil && err != io.EOF {
					ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
					ss.trInfo.tr.SetError()
				}
				ss.trInfo.tr.Finish()
				ss.trInfo.tr = nil
				ss.mu.Unlock()
			}

			if len(shs) != 0 {
				end := &stats.End{
					BeginTime: statsBegin.BeginTime,
					EndTime:   time.Now(),
				}
				if err != nil && err != io.EOF {
					end.Error = toRPCErr(err)
				}
				for _, sh := range shs {
					sh.HandleRPC(ctx, end)
				}
			}

			if channelz.IsOn() {
				if err != nil && err != io.EOF {
					s.incrCallsFailed()
				} else {
					s.incrCallsSucceeded()
				}
			}
		}()
	}

	// Collect binary-log method loggers from both the global binary logger and
	// the per-server one configured via options.
	if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil {
		ss.binlogs = append(ss.binlogs, ml)
	}
	if s.opts.binaryLogger != nil {
		if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil {
			ss.binlogs = append(ss.binlogs, ml)
		}
	}
	if len(ss.binlogs) != 0 {
		md, _ := metadata.FromIncomingContext(ctx)
		logEntry := &binarylog.ClientHeader{
			Header:     md,
			MethodName: stream.Method(),
			PeerAddr:   nil,
		}
		if deadline, ok := ctx.Deadline(); ok {
			logEntry.Timeout = time.Until(deadline)
			if logEntry.Timeout < 0 {
				logEntry.Timeout = 0
			}
		}
		if a := md[":authority"]; len(a) > 0 {
			logEntry.Authority = a[0]
		}
		if peer, ok := peer.FromContext(ss.Context()); ok {
			logEntry.PeerAddr = peer.Addr
		}
		for _, binlog := range ss.binlogs {
			binlog.Log(ctx, logEntry)
		}
	}

	// Inbound decompression: prefer the legacy Decompressor option when its
	// type matches the stream's grpc-encoding; otherwise look up a registered
	// compressor. An unknown non-identity encoding is an Unimplemented error.
	if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {
		ss.dc = s.opts.dc
	} else if rc != "" && rc != encoding.Identity {
		ss.decomp = encoding.GetCompressor(rc)
		if ss.decomp == nil {
			st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc)
			t.WriteStatus(ss.s, st)
			return st.Err()
		}
	}

	// Outbound compression: a server-configured Compressor wins; otherwise
	// mirror the client's encoding if a matching compressor is registered.
	if s.opts.cp != nil {
		ss.cp = s.opts.cp
		ss.sendCompressorName = s.opts.cp.Type()
	} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
		ss.comp = encoding.GetCompressor(rc)
		if ss.comp != nil {
			ss.sendCompressorName = rc
		}
	}

	if ss.sendCompressorName != "" {
		if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
			return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
		}
	}

	ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp)

	if trInfo != nil {
		trInfo.tr.LazyLog(&trInfo.firstLine, false)
	}
	var appErr error
	var server any
	if info != nil {
		server = info.serviceImpl
	}
	// Dispatch to the handler, via the (chained) stream interceptor if one is
	// configured.
	if s.opts.streamInt == nil {
		appErr = sd.Handler(server, ss)
	} else {
		info := &StreamServerInfo{
			FullMethod:     stream.Method(),
			IsClientStream: sd.ClientStreams,
			IsServerStream: sd.ServerStreams,
		}
		appErr = s.opts.streamInt(server, ss, info, sd.Handler)
	}
	if appErr != nil {
		appStatus, ok := status.FromError(appErr)
		if !ok {
			// Non-status errors are converted (context errors map to their
			// canonical codes) and the converted error is what gets reported.
			appStatus = status.FromContextError(appErr)
			appErr = appStatus.Err()
		}
		if trInfo != nil {
			ss.mu.Lock()
			ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
			ss.trInfo.tr.SetError()
			ss.mu.Unlock()
		}
		if len(ss.binlogs) != 0 {
			st := &binarylog.ServerTrailer{
				Trailer: ss.s.Trailer(),
				Err:     appErr,
			}
			for _, binlog := range ss.binlogs {
				binlog.Log(ctx, st)
			}
		}
		// WriteStatus error intentionally ignored; appErr is what the deferred
		// reporting above should observe.
		t.WriteStatus(ss.s, appStatus)
		return appErr
	}
	if trInfo != nil {
		ss.mu.Lock()
		ss.trInfo.tr.LazyLog(stringer("OK"), false)
		ss.mu.Unlock()
	}
	if len(ss.binlogs) != 0 {
		st := &binarylog.ServerTrailer{
			Trailer: ss.s.Trailer(),
			Err:     appErr,
		}
		for _, binlog := range ss.binlogs {
			binlog.Log(ctx, st)
		}
	}
	return t.WriteStatus(ss.s, statusOK)
}
1725
// handleStream parses the stream's full method name ("/service/method"),
// notifies stats handlers, and dispatches to processUnaryRPC or
// processStreamingRPC. Unmatched methods fall back to opts.unknownStreamDesc
// when configured; otherwise an Unimplemented status is written.
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
	ctx := stream.Context()
	ctx = contextWithServer(ctx, s)
	var ti *traceInfo
	if EnableTracing {
		tr := newTrace("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
		ctx = newTraceContext(ctx, tr)
		ti = &traceInfo{
			tr: tr,
			firstLine: firstLine{
				client:     false,
				remoteAddr: t.Peer().Addr,
			},
		}
		if dl, ok := ctx.Deadline(); ok {
			ti.firstLine.deadline = time.Until(dl)
		}
	}

	// Strip the leading '/' and split on the last '/'; no separator means the
	// method name is malformed.
	sm := stream.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	pos := strings.LastIndex(sm, "/")
	if pos == -1 {
		if ti != nil {
			ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true)
			ti.tr.SetError()
		}
		errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
		if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
			if ti != nil {
				ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
				ti.tr.SetError()
			}
			channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
		}
		if ti != nil {
			ti.tr.Finish()
		}
		return
	}
	service := sm[:pos]
	method := sm[pos+1:]

	// Tag the context and report InHeader to every stats handler before
	// dispatching.
	md, _ := metadata.FromIncomingContext(ctx)
	for _, sh := range s.opts.statsHandlers {
		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
		sh.HandleRPC(ctx, &stats.InHeader{
			FullMethod:  stream.Method(),
			RemoteAddr:  t.Peer().Addr,
			LocalAddr:   t.Peer().LocalAddr,
			Compression: stream.RecvCompress(),
			WireLength:  stream.HeaderWireLength(),
			Header:      md,
		})
	}

	// Propagate the (possibly stats-tagged) context back onto the stream so
	// later users of stream.Context() see it.
	stream.SetContext(ctx)

	srv, knownService := s.services[service]
	if knownService {
		if md, ok := srv.methods[method]; ok {
			s.processUnaryRPC(ctx, t, stream, srv, md, ti)
			return
		}
		if sd, ok := srv.streams[method]; ok {
			s.processStreamingRPC(ctx, t, stream, srv, sd, ti)
			return
		}
	}

	// Unknown service or method: hand off to the catch-all stream handler if
	// one was registered.
	if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
		s.processStreamingRPC(ctx, t, stream, nil, unknownDesc, ti)
		return
	}
	var errDesc string
	if !knownService {
		errDesc = fmt.Sprintf("unknown service %v", service)
	} else {
		errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
	}
	if ti != nil {
		ti.tr.LazyPrintf("%s", errDesc)
		ti.tr.SetError()
	}
	if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
		if ti != nil {
			ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
			ti.tr.SetError()
		}
		channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
	}
	if ti != nil {
		ti.tr.Finish()
	}
}
1824
1825
// streamKey is the context key under which a ServerTransportStream is stored
// by NewContextWithServerTransportStream.
type streamKey struct{}
1827
1828
1829
1830
1831
1832
1833
1834
// NewContextWithServerTransportStream returns a context derived from ctx with
// stream attached; ServerTransportStreamFromContext retrieves it.
func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
	return context.WithValue(ctx, streamKey{}, stream)
}
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
// ServerTransportStream is the minimal stream interface needed by the
// context-based helpers in this file (SetHeader, SendHeader, SetTrailer,
// Method) to manipulate an RPC's metadata from handler code.
type ServerTransportStream interface {
	Method() string
	SetHeader(md metadata.MD) error
	SendHeader(md metadata.MD) error
	SetTrailer(md metadata.MD) error
}
1856
1857
1858
1859
1860
1861
1862
1863
1864
// ServerTransportStreamFromContext returns the ServerTransportStream stored
// in ctx by NewContextWithServerTransportStream, or nil if none is present.
func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
	s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
	return s
}
1869
1870
1871
1872
1873
1874
// Stop stops the server immediately (non-graceful shutdown); see s.stop for
// the shared shutdown sequence.
func (s *Server) Stop() {
	s.stop(false)
}
1878
1879
1880
1881
// GracefulStop stops the server gracefully: transports are drained rather
// than closed; see s.stop for the shared shutdown sequence.
func (s *Server) GracefulStop() {
	s.stop(true)
}
1885
// stop is the shared implementation of Stop and GracefulStop. It closes all
// listeners, then either drains (graceful) or closes all transports, waits
// for live connections to finish, and tears down worker goroutines and the
// trace event log.
func (s *Server) stop(graceful bool) {
	s.quit.Fire()
	defer s.done.Fire()

	s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelz.ID) })
	s.mu.Lock()
	s.closeListenersLocked()
	// The lock is released before waiting: serve goroutines may need s.mu to
	// finish, and closing the listeners above stops new ones from starting.
	s.mu.Unlock()
	s.serveWG.Wait()

	s.mu.Lock()
	defer s.mu.Unlock()

	if graceful {
		s.drainAllServerTransportsLocked()
	} else {
		s.closeServerTransportsLocked()
	}

	// s.cv is signaled elsewhere as connections are removed from s.conns;
	// block here until all of them are gone.
	for len(s.conns) != 0 {
		s.cv.Wait()
	}
	s.conns = nil

	if s.opts.numServerWorkers > 0 {
		// Shut down the shared worker pool now that no connection remains to
		// feed it.
		s.serverWorkerChannelClose()
	}

	// In graceful mode (or when explicitly requested via waitForHandlers),
	// wait for in-flight handler goroutines to return.
	if graceful || s.opts.waitForHandlers {
		s.handlersWG.Wait()
	}

	if s.events != nil {
		s.events.Finish()
		s.events = nil
	}
}
1929
1930
1931 func (s *Server) closeServerTransportsLocked() {
1932 for _, conns := range s.conns {
1933 for st := range conns {
1934 st.Close(errors.New("Server.Stop called"))
1935 }
1936 }
1937 }
1938
1939
1940 func (s *Server) drainAllServerTransportsLocked() {
1941 if !s.drain {
1942 for _, conns := range s.conns {
1943 for st := range conns {
1944 st.Drain("graceful_stop")
1945 }
1946 }
1947 s.drain = true
1948 }
1949 }
1950
1951
// closeListenersLocked closes all listeners and clears s.lis so no new
// connections can be accepted. Callers must hold s.mu. Close errors are
// intentionally ignored — the server is shutting down.
func (s *Server) closeListenersLocked() {
	for lis := range s.lis {
		lis.Close()
	}
	s.lis = nil
}
1958
1959
1960
1961 func (s *Server) getCodec(contentSubtype string) baseCodec {
1962 if s.opts.codec != nil {
1963 return s.opts.codec
1964 }
1965 if contentSubtype == "" {
1966 return encoding.GetCodec(proto.Name)
1967 }
1968 codec := encoding.GetCodec(contentSubtype)
1969 if codec == nil {
1970 logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
1971 return encoding.GetCodec(proto.Name)
1972 }
1973 return codec
1974 }
1975
// serverKey is the context key under which the owning *Server is stored by
// contextWithServer.
type serverKey struct{}
1977
1978
// serverFromContext returns the *Server stored in ctx by contextWithServer,
// or nil if none is present.
func serverFromContext(ctx context.Context) *Server {
	s, _ := ctx.Value(serverKey{}).(*Server)
	return s
}
1983
1984
// contextWithServer returns a context derived from ctx with server attached;
// serverFromContext retrieves it.
func contextWithServer(ctx context.Context, server *Server) context.Context {
	return context.WithValue(ctx, serverKey{}, server)
}
1988
1989
1990
1991
1992 func (s *Server) isRegisteredMethod(serviceMethod string) bool {
1993 if serviceMethod != "" && serviceMethod[0] == '/' {
1994 serviceMethod = serviceMethod[1:]
1995 }
1996 pos := strings.LastIndex(serviceMethod, "/")
1997 if pos == -1 {
1998 return false
1999 }
2000 service := serviceMethod[:pos]
2001 method := serviceMethod[pos+1:]
2002 srv, knownService := s.services[service]
2003 if knownService {
2004 if _, ok := srv.methods[method]; ok {
2005 return true
2006 }
2007 if _, ok := srv.streams[method]; ok {
2008 return true
2009 }
2010 }
2011 return false
2012 }
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034 func SetHeader(ctx context.Context, md metadata.MD) error {
2035 if md.Len() == 0 {
2036 return nil
2037 }
2038 stream := ServerTransportStreamFromContext(ctx)
2039 if stream == nil {
2040 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
2041 }
2042 return stream.SetHeader(md)
2043 }
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053 func SendHeader(ctx context.Context, md metadata.MD) error {
2054 stream := ServerTransportStreamFromContext(ctx)
2055 if stream == nil {
2056 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
2057 }
2058 if err := stream.SendHeader(md); err != nil {
2059 return toRPCErr(err)
2060 }
2061 return nil
2062 }
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087 func SetSendCompressor(ctx context.Context, name string) error {
2088 stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
2089 if !ok || stream == nil {
2090 return fmt.Errorf("failed to fetch the stream from the given context")
2091 }
2092
2093 if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
2094 return fmt.Errorf("unable to set send compressor: %w", err)
2095 }
2096
2097 return stream.SetSendCompress(name)
2098 }
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109 func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
2110 stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
2111 if !ok || stream == nil {
2112 return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
2113 }
2114
2115 return stream.ClientAdvertisedCompressors(), nil
2116 }
2117
2118
2119
2120
2121
2122
2123
2124 func SetTrailer(ctx context.Context, md metadata.MD) error {
2125 if md.Len() == 0 {
2126 return nil
2127 }
2128 stream := ServerTransportStreamFromContext(ctx)
2129 if stream == nil {
2130 return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
2131 }
2132 return stream.SetTrailer(md)
2133 }
2134
2135
2136
2137 func Method(ctx context.Context) (string, bool) {
2138 s := ServerTransportStreamFromContext(ctx)
2139 if s == nil {
2140 return "", false
2141 }
2142 return s.Method(), true
2143 }
2144
2145
2146
2147 func validateSendCompressor(name string, clientCompressors []string) error {
2148 if name == encoding.Identity {
2149 return nil
2150 }
2151
2152 if !grpcutil.IsCompressorNameRegistered(name) {
2153 return fmt.Errorf("compressor not registered %q", name)
2154 }
2155
2156 for _, c := range clientCompressors {
2157 if c == name {
2158 return nil
2159 }
2160 }
2161 return fmt.Errorf("client does not support compressor %q", name)
2162 }
2163
2164
2165
// atomicSemaphore is a counting semaphore: n holds the number of free slots
// and wait is a 1-buffered channel used to block/unblock an acquirer when n
// goes negative. See acquire and release for the protocol.
type atomicSemaphore struct {
	n    atomic.Int64 // remaining quota; negative while an acquirer waits
	wait chan struct{}
}
2170
// acquire takes one slot, blocking on q.wait when the quota is exhausted
// (counter drops below zero) until a release hands one back.
func (q *atomicSemaphore) acquire() {
	if q.n.Add(-1) < 0 {
		// Quota exhausted; wait for a matching release to signal q.wait.
		<-q.wait
	}
}
2177
// release returns one slot. If the counter is still not positive after the
// increment, some acquirer observed a negative value and is (or will be)
// blocked on q.wait, so send exactly one token to wake it.
func (q *atomicSemaphore) release() {
	if q.n.Add(1) <= 0 {
		// A waiter exists for this slot; the 1-buffered channel lets this
		// send complete even if the waiter has not reached the receive yet.
		q.wait <- struct{}{}
	}
}
2188
2189 func newHandlerQuota(n uint32) *atomicSemaphore {
2190 a := &atomicSemaphore{wait: make(chan struct{}, 1)}
2191 a.n.Store(int64(n))
2192 return a
2193 }
2194