1
19
20 package test
21
22 import (
23 "context"
24 "fmt"
25 "testing"
26
27 "google.golang.org/grpc"
28 "google.golang.org/grpc/internal/stubserver"
29 "google.golang.org/grpc/internal/testutils"
30
31 testgrpc "google.golang.org/grpc/interop/grpc_testing"
32 testpb "google.golang.org/grpc/interop/grpc_testing"
33 )
34
// Context keys used by the interceptor tests in this file. Each key is its own
// unexported empty struct type, so a value stored under it can only be read
// back with that exact type and cannot collide with keys defined elsewhere.
type (
	// parentCtxkey keys the value attached to the context passed to the RPC.
	parentCtxkey struct{}
	// firstInterceptorCtxkey keys the value added by the first chained interceptor.
	firstInterceptorCtxkey struct{}
	// secondInterceptorCtxkey keys the value added by the second chained interceptor.
	secondInterceptorCtxkey struct{}
	// baseInterceptorCtxKey keys the value added by the base interceptor
	// (the one installed with grpc.WithUnaryInterceptor).
	baseInterceptorCtxKey struct{}
)

// Values stored under the matching context keys declared above.
const (
	parentCtxVal            = "parent"
	firstInterceptorCtxVal  = "firstInterceptor"
	secondInterceptorCtxVal = "secondInterceptor"
	baseInterceptorCtxVal   = "baseInterceptor"
)
46
47
48
49
50 func (s) TestUnaryClientInterceptor_ContextValuePropagation(t *testing.T) {
51 errCh := testutils.NewChannel()
52 unaryInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
53 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
54 errCh.Send(fmt.Errorf("unaryInt got %q in context.Val, want %q", got, parentCtxVal))
55 }
56 errCh.Send(nil)
57 return invoker(ctx, method, req, reply, cc, opts...)
58 }
59
60
61
62 ss := &stubserver.StubServer{
63 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
64 }
65 if err := ss.Start(nil, grpc.WithUnaryInterceptor(unaryInt)); err != nil {
66 t.Fatalf("Failed to start stub server: %v", err)
67 }
68 defer ss.Stop()
69
70 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
71 defer cancel()
72 if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil {
73 t.Fatalf("ss.Client.EmptyCall() failed: %v", err)
74 }
75 val, err := errCh.Receive(ctx)
76 if err != nil {
77 t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err)
78 }
79 if val != nil {
80 t.Fatalf("unary interceptor failed: %v", val)
81 }
82 }
83
84
85
86
87 func (s) TestChainUnaryClientInterceptor_ContextValuePropagation(t *testing.T) {
88 errCh := testutils.NewChannel()
89 firstInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
90 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
91 errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal))
92 }
93 if ctx.Value(firstInterceptorCtxkey{}) != nil {
94 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{}))
95 }
96 if ctx.Value(secondInterceptorCtxkey{}) != nil {
97 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{}))
98 }
99 firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal)
100 return invoker(firstCtx, method, req, reply, cc, opts...)
101 }
102
103 secondInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
104 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
105 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal))
106 }
107 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
108 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
109 }
110 if ctx.Value(secondInterceptorCtxkey{}) != nil {
111 errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{}))
112 }
113 secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal)
114 return invoker(secondCtx, method, req, reply, cc, opts...)
115 }
116
117 lastInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
118 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
119 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal))
120 }
121 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
122 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
123 }
124 if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal {
125 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal))
126 }
127 errCh.SendContext(ctx, nil)
128 return invoker(ctx, method, req, reply, cc, opts...)
129 }
130
131
132
133 ss := &stubserver.StubServer{
134 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
135 }
136 if err := ss.Start(nil, grpc.WithChainUnaryInterceptor(firstInt, secondInt, lastInt)); err != nil {
137 t.Fatalf("Failed to start stub server: %v", err)
138 }
139 defer ss.Stop()
140
141 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
142 defer cancel()
143 if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil {
144 t.Fatalf("ss.Client.EmptyCall() failed: %v", err)
145 }
146 val, err := errCh.Receive(ctx)
147 if err != nil {
148 t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err)
149 }
150 if val != nil {
151 t.Fatalf("unary interceptor failed: %v", val)
152 }
153 }
154
155
156
157
158
159 func (s) TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation(t *testing.T) {
160 errCh := testutils.NewChannel()
161 baseInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
162 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
163 errCh.SendContext(ctx, fmt.Errorf("base interceptor got %q in context.Val, want %q", got, parentCtxVal))
164 }
165 if ctx.Value(baseInterceptorCtxKey{}) != nil {
166 errCh.SendContext(ctx, fmt.Errorf("baseinterceptor should not have %T in context", baseInterceptorCtxKey{}))
167 }
168 baseCtx := context.WithValue(ctx, baseInterceptorCtxKey{}, baseInterceptorCtxVal)
169 return invoker(baseCtx, method, req, reply, cc, opts...)
170 }
171
172 chainInt := func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
173 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
174 errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, parentCtxVal))
175 }
176 if got, ok := ctx.Value(baseInterceptorCtxKey{}).(string); !ok || got != baseInterceptorCtxVal {
177 errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, baseInterceptorCtxVal))
178 }
179 errCh.SendContext(ctx, nil)
180 return invoker(ctx, method, req, reply, cc, opts...)
181 }
182
183
184
185 ss := &stubserver.StubServer{
186 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
187 }
188 if err := ss.Start(nil, grpc.WithUnaryInterceptor(baseInt), grpc.WithChainUnaryInterceptor(chainInt)); err != nil {
189 t.Fatalf("Failed to start stub server: %v", err)
190 }
191 defer ss.Stop()
192
193 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
194 defer cancel()
195 if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil {
196 t.Fatalf("ss.Client.EmptyCall() failed: %v", err)
197 }
198 val, err := errCh.Receive(ctx)
199 if err != nil {
200 t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err)
201 }
202 if val != nil {
203 t.Fatalf("unary interceptor failed: %v", val)
204 }
205 }
206
207
208
209
210 func (s) TestChainStreamClientInterceptor_ContextValuePropagation(t *testing.T) {
211 errCh := testutils.NewChannel()
212 firstInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
213 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
214 errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal))
215 }
216 if ctx.Value(firstInterceptorCtxkey{}) != nil {
217 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{}))
218 }
219 if ctx.Value(secondInterceptorCtxkey{}) != nil {
220 errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{}))
221 }
222 firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal)
223 return streamer(firstCtx, desc, cc, method, opts...)
224 }
225
226 secondInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
227 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
228 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal))
229 }
230 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
231 errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
232 }
233 if ctx.Value(secondInterceptorCtxkey{}) != nil {
234 errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{}))
235 }
236 secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal)
237 return streamer(secondCtx, desc, cc, method, opts...)
238 }
239
240 lastInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
241 if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal {
242 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal))
243 }
244 if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal {
245 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal))
246 }
247 if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal {
248 errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal))
249 }
250 errCh.SendContext(ctx, nil)
251 return streamer(ctx, desc, cc, method, opts...)
252 }
253
254
255
256 ss := &stubserver.StubServer{
257 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
258 if _, err := stream.Recv(); err != nil {
259 return err
260 }
261 return stream.Send(&testpb.StreamingOutputCallResponse{})
262 },
263 }
264 if err := ss.Start(nil, grpc.WithChainStreamInterceptor(firstInt, secondInt, lastInt)); err != nil {
265 t.Fatalf("Failed to start stub server: %v", err)
266 }
267 defer ss.Stop()
268
269 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
270 defer cancel()
271 if _, err := ss.Client.FullDuplexCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal)); err != nil {
272 t.Fatalf("ss.Client.FullDuplexCall() failed: %v", err)
273 }
274 val, err := errCh.Receive(ctx)
275 if err != nil {
276 t.Fatalf("timeout when waiting for stream interceptor to be invoked: %v", err)
277 }
278 if val != nil {
279 t.Fatalf("stream interceptor failed: %v", val)
280 }
281 }
282
View as plain text