1
18
19 package test
20
21 import (
22 "context"
23 "errors"
24 "fmt"
25 "net"
26 "strings"
27 "testing"
28 "time"
29
30 "google.golang.org/grpc"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/connectivity"
33 "google.golang.org/grpc/credentials"
34 "google.golang.org/grpc/credentials/insecure"
35 "google.golang.org/grpc/internal/testutils"
36 "google.golang.org/grpc/metadata"
37 "google.golang.org/grpc/resolver"
38 "google.golang.org/grpc/resolver/manual"
39 "google.golang.org/grpc/status"
40 "google.golang.org/grpc/tap"
41 "google.golang.org/grpc/testdata"
42
43 testgrpc "google.golang.org/grpc/interop/grpc_testing"
44 testpb "google.golang.org/grpc/interop/grpc_testing"
45 )
46
// bundlePerRPCOnly is a testCredsBundle mode in which the bundle hands out
// per-RPC credentials together with insecure transport credentials.
const bundlePerRPCOnly = "perRPCOnly"

// bundleTLSOnly is a testCredsBundle mode in which the bundle hands out TLS
// transport credentials and no per-RPC credentials.
const bundleTLSOnly = "tlsOnly"
51
52 type testCredsBundle struct {
53 t *testing.T
54 mode string
55 }
56
57 func (c *testCredsBundle) TransportCredentials() credentials.TransportCredentials {
58 if c.mode == bundlePerRPCOnly {
59 return insecure.NewCredentials()
60 }
61
62 creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
63 if err != nil {
64 c.t.Logf("Failed to load credentials: %v", err)
65 return nil
66 }
67 return creds
68 }
69
70 func (c *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
71 if c.mode == bundleTLSOnly {
72 return nil
73 }
74 return testPerRPCCredentials{authdata: authdata}
75 }
76
77 func (c *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
78 return &testCredsBundle{mode: mode}, nil
79 }
80
// TestCredsBundleBoth verifies that an RPC succeeds against a TLS server whose
// tap handle (authHandle) requires authdata, when the client's bundle supplies
// both TLS transport credentials and per-RPC credentials.
func (s) TestCredsBundleBoth(t *testing.T) {
	te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"})
	te.tapHandle = authHandle
	bundle := &testCredsBundle{t: t}
	te.customDialOptions = []grpc.DialOption{grpc.WithCredentialsBundle(bundle)}

	certFile := testdata.Path("x509/server1_cert.pem")
	keyFile := testdata.Path("x509/server1_key.pem")
	serverCreds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
	if err != nil {
		t.Fatalf("Failed to generate credentials %v", err)
	}
	te.customServerOptions = []grpc.ServerOption{grpc.Creds(serverCreds)}
	te.startServer(&testServer{})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn())
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	_, err = client.EmptyCall(ctx, &testpb.Empty{})
	if err != nil {
		t.Fatalf("Test failed. Reason: %v", err)
	}
}
105
// TestCredsBundleTransportCredentials verifies that an RPC succeeds against a
// TLS server when the client's bundle supplies only TLS transport credentials.
// No tap handle is installed, so no per-RPC metadata is required.
func (s) TestCredsBundleTransportCredentials(t *testing.T) {
	te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"})
	bundle := &testCredsBundle{t: t, mode: bundleTLSOnly}
	te.customDialOptions = []grpc.DialOption{grpc.WithCredentialsBundle(bundle)}

	certFile := testdata.Path("x509/server1_cert.pem")
	keyFile := testdata.Path("x509/server1_key.pem")
	serverCreds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
	if err != nil {
		t.Fatalf("Failed to generate credentials %v", err)
	}
	te.customServerOptions = []grpc.ServerOption{grpc.Creds(serverCreds)}
	te.startServer(&testServer{})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn())
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	_, err = client.EmptyCall(ctx, &testpb.Empty{})
	if err != nil {
		t.Fatalf("Test failed. Reason: %v", err)
	}
}
129
// TestCredsBundlePerRPCCredentials verifies that an RPC passes the server's
// authHandle tap check when the client's bundle supplies only per-RPC
// credentials. The server is started without custom transport credentials.
func (s) TestCredsBundlePerRPCCredentials(t *testing.T) {
	te := newTest(t, env{name: "creds-bundle", network: "tcp", security: "empty"})
	te.tapHandle = authHandle
	bundle := &testCredsBundle{t: t, mode: bundlePerRPCOnly}
	te.customDialOptions = []grpc.DialOption{grpc.WithCredentialsBundle(bundle)}
	te.startServer(&testServer{})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn())
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	_, err := client.EmptyCall(ctx, &testpb.Empty{})
	if err != nil {
		t.Fatalf("Test failed. Reason: %v", err)
	}
}
147
148 type clientTimeoutCreds struct {
149 credentials.TransportCredentials
150 timeoutReturned bool
151 }
152
153 func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
154 if !c.timeoutReturned {
155 c.timeoutReturned = true
156 return nil, nil, context.DeadlineExceeded
157 }
158 return rawConn, nil, nil
159 }
160
161 func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo {
162 return credentials.ProtocolInfo{}
163 }
164
165 func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials {
166 return nil
167 }
168
// TestNonFailFastRPCSucceedOnTimeoutCreds verifies that a wait-for-ready RPC
// still succeeds when the first client handshake fails with
// context.DeadlineExceeded (see clientTimeoutCreds). Success requires a later
// handshake attempt to go through.
func (s) TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
	te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "empty"})
	te.userAgent = testAppUA
	te.startServer(&testServer{security: te.e.security})
	defer te.tearDown()

	timeoutCreds := &clientTimeoutCreds{}
	client := testgrpc.NewTestServiceClient(te.clientConn(grpc.WithTransportCredentials(timeoutCreds)))

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	_, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true))
	if err != nil {
		te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want <nil>", err)
	}
}
184
185 type methodTestCreds struct{}
186
187 func (m *methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
188 ri, _ := credentials.RequestInfoFromContext(ctx)
189 return nil, status.Errorf(codes.Unknown, ri.Method)
190 }
191
192 func (m *methodTestCreds) RequireTransportSecurity() bool { return false }
193
// TestGRPCMethodAccessibleToCredsViaContextRequestInfo verifies that per-RPC
// credentials can read the full method name from the context's RequestInfo.
// methodTestCreds echoes the method name back as its error message, and the
// check is run for both a fail-fast and a wait-for-ready RPC.
func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) {
	const wantMethod = "/grpc.testing.TestService/EmptyCall"
	te := newTest(t, env{name: "context-request-info", network: "tcp"})
	te.userAgent = testAppUA
	te.startServer(&testServer{security: te.e.security})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn(grpc.WithPerRPCCredentials(&methodTestCreds{})))

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	// First a fail-fast call (no call options), then a wait-for-ready call.
	for _, callOpts := range [][]grpc.CallOption{nil, {grpc.WaitForReady(true)}} {
		_, err := client.EmptyCall(ctx, &testpb.Empty{}, callOpts...)
		if status.Convert(err).Message() != wantMethod {
			t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod)
		}
	}
}
214
215 const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails"
216
217 type clientAlwaysFailCred struct {
218 credentials.TransportCredentials
219 }
220
221 func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
222 return nil, nil, errors.New(clientAlwaysFailCredErrorMsg)
223 }
224 func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo {
225 return credentials.ProtocolInfo{}
226 }
227 func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials {
228 return nil
229 }
230
// TestFailFastRPCErrorOnBadCertificates verifies that fail-fast RPCs eventually
// report the error returned by the client's transport credentials during the
// handshake (clientAlwaysFailCredErrorMsg).
func (s) TestFailFastRPCErrorOnBadCertificates(t *testing.T) {
	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"})
	te.startServer(&testServer{security: te.e.security})
	defer te.tearDown()

	opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})}
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	cc, err := grpc.DialContext(ctx, te.srvAddr, opts...)
	if err != nil {
		t.Fatalf("Dial(_) = %v, want %v", err, nil)
	}
	defer cc.Close()

	tc := testgrpc.NewTestServiceClient(cc)
	// The credentials' error is not necessarily what the very first RPC
	// reports, so poll (up to 1000 attempts, 1ms apart) until it shows up.
	for i := 0; i < 1000; i++ {
		// Guard against a nil err: an unexpectedly successful RPC must end in
		// the Fatalf below, not a nil-pointer panic on err.Error().
		if _, err = tc.EmptyCall(ctx, &testpb.Empty{}); err != nil && strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
			return
		}
		time.Sleep(time.Millisecond)
	}
	te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
}
258
// TestWaitForReadyRPCErrorOnBadCertificates verifies that a wait-for-ready RPC
// on a channel whose handshakes always fail returns an error mentioning the
// credentials' handshake error (clientAlwaysFailCredErrorMsg).
func (s) TestWaitForReadyRPCErrorOnBadCertificates(t *testing.T) {
	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "empty", balancer: "round_robin"})
	te.startServer(&testServer{security: te.e.security})
	defer te.tearDown()

	opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})}
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	cc, err := grpc.DialContext(ctx, te.srvAddr, opts...)
	if err != nil {
		t.Fatalf("Dial(_) = %v, want %v", err, nil)
	}
	defer cc.Close()

	tc := testgrpc.NewTestServiceClient(cc)
	// Every handshake fails, so the wait-for-ready RPC cannot succeed. Give it
	// a short deadline.
	ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
	defer cancel()
	// err == nil is treated as a failure instead of panicking on err.Error().
	if _, err = tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err == nil || !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
	}
}
280
// authdata is the metadata attached by testPerRPCCredentials and checked by
// authHandle. The "-bin" key holds raw bytes 1, 2, 3 rather than printable
// text, presumably to exercise gRPC's binary-metadata handling.
var authdata = map[string]string{
	"test-key":      "test-value",
	"test-key2-bin": "\x01\x02\x03",
}
288
289 type testPerRPCCredentials struct {
290 authdata map[string]string
291 errChan chan error
292 }
293
294 func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
295 var err error
296 if cr.errChan != nil {
297 err = <-cr.errChan
298 }
299 return cr.authdata, err
300 }
301
302 func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
303 return false
304 }
305
// authHandle is a tap handle that rejects a stream unless its incoming
// metadata contains every authdata key and each key's first value equals the
// expected one. It returns ctx unchanged in all cases.
func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ctx, fmt.Errorf("didn't find metadata in context")
	}
	for k, vwant := range authdata {
		vgot, ok := md[k]
		// A key mapped to an empty value list is treated like a missing key.
		// This keeps the vgot[0] check below from panicking.
		if !ok || len(vgot) == 0 {
			return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
		}
		if vgot[0] != vwant {
			return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
		}
	}
	return ctx, nil
}
322
// TestPerRPCCredentialsViaDialOptions runs testPerRPCCredentialsViaDialOptions
// once for each environment returned by listTestEnv.
func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) {
	envs := listTestEnv()
	for i := range envs {
		testPerRPCCredentialsViaDialOptions(t, envs[i])
	}
}
328
// testPerRPCCredentialsViaDialOptions checks, in environment e, that per-RPC
// credentials configured through te.perRPCCreds (applied as a dial option, per
// the test name) satisfy the server's authHandle tap check.
func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
	te := newTest(t, e)
	te.tapHandle = authHandle
	te.perRPCCreds = testPerRPCCredentials{authdata: authdata}
	te.startServer(&testServer{security: e.security})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn())
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	_, err := client.EmptyCall(ctx, &testpb.Empty{})
	if err != nil {
		t.Fatalf("Test failed. Reason: %v", err)
	}
}
344
// TestPerRPCCredentialsViaCallOptions runs testPerRPCCredentialsViaCallOptions
// once for each environment returned by listTestEnv.
func (s) TestPerRPCCredentialsViaCallOptions(t *testing.T) {
	envs := listTestEnv()
	for i := range envs {
		testPerRPCCredentialsViaCallOptions(t, envs[i])
	}
}
350
// testPerRPCCredentialsViaCallOptions checks, in environment e, that per-RPC
// credentials passed as a call option (grpc.PerRPCCredentials) satisfy the
// server's authHandle tap check.
func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
	te := newTest(t, e)
	te.tapHandle = authHandle
	te.startServer(&testServer{security: e.security})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn())
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	callCreds := testPerRPCCredentials{authdata: authdata}
	_, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(callCreds))
	if err != nil {
		t.Fatalf("Test failed. Reason: %v", err)
	}
}
365
// TestPerRPCCredentialsViaDialOptionsAndCallOptions runs
// testPerRPCCredentialsViaDialOptionsAndCallOptions once for each environment
// returned by listTestEnv.
func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
	envs := listTestEnv()
	for i := range envs {
		testPerRPCCredentialsViaDialOptionsAndCallOptions(t, envs[i])
	}
}
371
// testPerRPCCredentialsViaDialOptionsAndCallOptions checks, in environment e,
// that when the same per-RPC credentials are configured both through
// te.perRPCCreds and as a call option, both sets of metadata are sent. The
// server's tap handle requires each authdata key to carry exactly two values,
// both equal to the expected value.
func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
	te := newTest(t, e)
	te.perRPCCreds = testPerRPCCredentials{authdata: authdata}

	requireDoubledAuthdata := func(ctx context.Context, _ *tap.Info) (context.Context, error) {
		md, found := metadata.FromIncomingContext(ctx)
		if !found {
			return ctx, fmt.Errorf("couldn't find metadata in context")
		}
		for key, want := range authdata {
			got, present := md[key]
			switch {
			case !present:
				return ctx, fmt.Errorf("couldn't find metadata for key %v", key)
			case len(got) != 2:
				return ctx, fmt.Errorf("len of value for key %v was %v, want 2", key, len(got))
			case got[0] != want || got[1] != want:
				return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", key, got, want, want)
			}
		}
		return ctx, nil
	}
	te.tapHandle = requireDoubledAuthdata
	te.startServer(&testServer{security: e.security})
	defer te.tearDown()

	client := testgrpc.NewTestServiceClient(te.clientConn())
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	callCreds := testPerRPCCredentials{authdata: authdata}
	_, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(callCreds))
	if err != nil {
		t.Fatalf("Test failed. Reason: %v", err)
	}
}
407
408 const testAuthority = "test.auth.ori.ty"
409
410 type authorityCheckCreds struct {
411 credentials.TransportCredentials
412 got string
413 }
414
415 func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
416 c.got = authority
417 return rawConn, nil, nil
418 }
419 func (c *authorityCheckCreds) Info() credentials.ProtocolInfo {
420 return credentials.ProtocolInfo{}
421 }
422 func (c *authorityCheckCreds) Clone() credentials.TransportCredentials {
423 return c
424 }
425
426
427
// TestCredsHandshakeAuthority verifies that the authority passed to the client
// handshake is the endpoint from the dial target, not the resolved listener
// address supplied by the manual resolver.
func (s) TestCredsHandshakeAuthority(t *testing.T) {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	cred := &authorityCheckCreds{}
	s := grpc.NewServer()
	go s.Serve(lis)
	defer s.Stop()

	r := manual.NewBuilderWithScheme("whatever")

	target := r.Scheme() + ":///" + testAuthority
	cc, err := grpc.Dial(target, grpc.WithTransportCredentials(cred), grpc.WithResolvers(r))
	if err != nil {
		// Report the target actually dialed. The listener address is only
		// handed to the channel later, through the manual resolver.
		t.Fatalf("grpc.Dial(%q) = %v", target, err)
	}
	defer cc.Close()
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	testutils.AwaitState(ctx, t, cc, connectivity.Ready)

	if cred.got != testAuthority {
		t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
	}
}
455
456
457
// TestCredsHandshakeServerNameAuthority verifies that when the resolved address
// carries a ServerName, that name (not the dial target's endpoint) is the
// authority passed to the client handshake.
func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) {
	const testServerName = "test.server.name"

	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	cred := &authorityCheckCreds{}
	s := grpc.NewServer()
	go s.Serve(lis)
	defer s.Stop()

	r := manual.NewBuilderWithScheme("whatever")

	target := r.Scheme() + ":///" + testAuthority
	cc, err := grpc.Dial(target, grpc.WithTransportCredentials(cred), grpc.WithResolvers(r))
	if err != nil {
		// Report the target actually dialed, not the listener address.
		t.Fatalf("grpc.Dial(%q) = %v", target, err)
	}
	defer cc.Close()
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}})

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	testutils.AwaitState(ctx, t, cc, connectivity.Ready)

	// The expected authority is the address's ServerName. The failure message
	// previously printed testAuthority as the wanted value.
	if cred.got != testServerName {
		t.Fatalf("client creds got authority: %q, want: %q", cred.got, testServerName)
	}
}
487
488 type serverDispatchCred struct {
489 rawConnCh chan net.Conn
490 }
491
492 func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
493 return rawConn, nil, nil
494 }
495 func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
496 select {
497 case c.rawConnCh <- rawConn:
498 default:
499 }
500 return nil, nil, credentials.ErrConnDispatched
501 }
502 func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
503 return credentials.ProtocolInfo{}
504 }
505 func (c *serverDispatchCred) Clone() credentials.TransportCredentials {
506 return nil
507 }
508 func (c *serverDispatchCred) OverrideServerName(s string) error {
509 return nil
510 }
511 func (c *serverDispatchCred) getRawConn() net.Conn {
512 return <-c.rawConnCh
513 }
514
// TestServerCredsDispatch verifies that when server credentials return
// credentials.ErrConnDispatched, the server leaves the raw connection open for
// whoever took it over (here, the test via serverDispatchCred.getRawConn).
func (s) TestServerCredsDispatch(t *testing.T) {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	cred := &serverDispatchCred{
		rawConnCh: make(chan net.Conn, 1),
	}
	s := grpc.NewServer(grpc.Creds(cred))
	go s.Serve(lis)
	defer s.Stop()

	cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred))
	if err != nil {
		t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
	}
	defer cc.Close()

	// The dispatched connection is now owned by the test, which therefore has
	// to close it.
	rawConn := cred.getRawConn()
	defer rawConn.Close()

	// Give the server time to (wrongly) close the dispatched connection before
	// checking that it is still writable.
	time.Sleep(100 * time.Millisecond)

	if n, err := rawConn.Write([]byte{0}); n <= 0 || err != nil {
		t.Errorf("Write() = %v, %v; want n>0, <nil>", n, err)
	}
}
542
View as plain text