1
18 package test
19
20 import (
21 "context"
22 "io"
23 "net"
24 "sync"
25 "testing"
26
27 "google.golang.org/grpc"
28 "google.golang.org/grpc/codes"
29 "google.golang.org/grpc/credentials"
30 "google.golang.org/grpc/internal/grpcsync"
31 "google.golang.org/grpc/internal/stubserver"
32 "google.golang.org/grpc/internal/transport"
33 "google.golang.org/grpc/status"
34
35 testgrpc "google.golang.org/grpc/interop/grpc_testing"
36 testpb "google.golang.org/grpc/interop/grpc_testing"
37 )
38
39
40 type connWrapperWithCloseCh struct {
41 net.Conn
42 close *grpcsync.Event
43 }
44
45
46 func (cw *connWrapperWithCloseCh) Close() error {
47 cw.close.Fire()
48 return cw.Conn.Close()
49 }
50
51
52
53 type transportRestartCheckCreds struct {
54 mu sync.Mutex
55 connections []*connWrapperWithCloseCh
56 }
57
58 func (c *transportRestartCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
59 return rawConn, nil, nil
60 }
61 func (c *transportRestartCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
62 c.mu.Lock()
63 defer c.mu.Unlock()
64 conn := &connWrapperWithCloseCh{Conn: rawConn, close: grpcsync.NewEvent()}
65 c.connections = append(c.connections, conn)
66 return conn, nil, nil
67 }
68 func (c *transportRestartCheckCreds) Info() credentials.ProtocolInfo {
69 return credentials.ProtocolInfo{}
70 }
71 func (c *transportRestartCheckCreds) Clone() credentials.TransportCredentials {
72 return c
73 }
74 func (c *transportRestartCheckCreds) OverrideServerName(s string) error {
75 return nil
76 }
77
78
79
80
81 func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
82
83 originalMaxStreamID := transport.MaxStreamID
84 transport.MaxStreamID = 4
85 defer func() {
86 transport.MaxStreamID = originalMaxStreamID
87 }()
88
89 ss := &stubserver.StubServer{
90 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
91 if _, err := stream.Recv(); err != nil {
92 return status.Errorf(codes.Internal, "unexpected error receiving: %v", err)
93 }
94 if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
95 return status.Errorf(codes.Internal, "unexpected error sending: %v", err)
96 }
97 if recv, err := stream.Recv(); err != io.EOF {
98 return status.Errorf(codes.Internal, "Recv = %v, %v; want _, io.EOF", recv, err)
99 }
100 return nil
101 },
102 }
103
104 creds := &transportRestartCheckCreds{}
105 if err := ss.Start(nil, grpc.WithTransportCredentials(creds)); err != nil {
106 t.Fatalf("Starting stubServer: %v", err)
107 }
108 defer ss.Stop()
109
110 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
111 defer cancel()
112
113 var streams []testgrpc.TestService_FullDuplexCallClient
114
115 const numStreams = 3
116
117
118 expectedNumConns := [numStreams]int{1, 1, 2}
119
120
121 for i := 0; i < numStreams; i++ {
122 s, err := ss.Client.FullDuplexCall(ctx)
123 if err != nil {
124 t.Fatalf("Creating FullDuplex stream: %v", err)
125 }
126 streams = append(streams, s)
127
128 if len(creds.connections) != expectedNumConns[i] {
129 t.Fatalf("Got number of connections created: %v, want: %v", len(creds.connections), expectedNumConns[i])
130 }
131 }
132
133
134 for i, stream := range streams {
135 if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
136 t.Fatalf("Sending on stream %d: %v", i, err)
137 }
138 if _, err := stream.Recv(); err != nil {
139 t.Fatalf("Receiving on stream %d: %v", i, err)
140 }
141 }
142
143 for i, stream := range streams {
144 if err := stream.CloseSend(); err != nil {
145 t.Fatalf("CloseSend() on stream %d: %v", i, err)
146 }
147 }
148
149
150 select {
151 case <-creds.connections[0].close.Done():
152 case <-ctx.Done():
153 t.Fatal("Timeout expired when waiting for first client transport to close")
154 }
155 }
156
View as plain text