1
18
19 package xds
20
21 import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "errors"
26 "fmt"
27 "net"
28 "os"
29 "strings"
30 "sync/atomic"
31 "testing"
32 "time"
33 "unsafe"
34
35 "google.golang.org/grpc/credentials"
36 "google.golang.org/grpc/credentials/tls/certprovider"
37 icredentials "google.golang.org/grpc/internal/credentials"
38 xdsinternal "google.golang.org/grpc/internal/credentials/xds"
39 "google.golang.org/grpc/internal/grpctest"
40 "google.golang.org/grpc/internal/testutils"
41 "google.golang.org/grpc/internal/xds/matcher"
42 "google.golang.org/grpc/resolver"
43 "google.golang.org/grpc/testdata"
44 )
45
const (
	// defaultTestTimeout bounds operations the tests expect to complete.
	defaultTestTimeout = time.Second
	// defaultTestShortTimeout bounds operations the tests expect to block,
	// e.g. a client handshake against a server that never responds.
	defaultTestShortTimeout = 10 * time.Millisecond
	// defaultTestCertSAN is used as the exact-match SAN in HandshakeInfo.
	defaultTestCertSAN = "abc.test.example.com"
	// authority is the authority string handed to ClientHandshake.
	authority = "authority"
)
52
53 type s struct {
54 grpctest.Tester
55 }
56
57 func Test(t *testing.T) {
58 grpctest.RunSubTests(t, s{})
59 }
60
61
62
63 func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
64 creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
65 if err != nil {
66 t.Fatal(err)
67 }
68 return creds
69 }
70
71
72
73
74
75
76 type testServer struct {
77 lis net.Listener
78 address string
79 handshakeFunc testHandshakeFunc
80 hsResult *testutils.Channel
81 }
82
83
84
85
// handshakeResult captures the server-side outcome of one handshake: the
// negotiated connection state on success, or the error that stopped it.
type handshakeResult struct {
	// connState is meaningful only when err is nil.
	connState tls.ConnectionState
	// err is the failure encountered by the server, if any.
	err error
}

// testHandshakeFunc performs the server side of a handshake on a freshly
// accepted raw connection and reports how it went.
type testHandshakeFunc func(net.Conn) handshakeResult
94
95
96
97
98 func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
99 ts := &testServer{
100 handshakeFunc: f,
101 hsResult: testutils.NewChannel(),
102 }
103 ts.start()
104 return ts
105 }
106
107
108
109 func (ts *testServer) start() error {
110 lis, err := net.Listen("tcp", "localhost:0")
111 if err != nil {
112 return err
113 }
114 ts.lis = lis
115 ts.address = lis.Addr().String()
116 go ts.handleConn()
117 return nil
118 }
119
120
121
122
123 func (ts *testServer) handleConn() {
124 for {
125 rawConn, err := ts.lis.Accept()
126 if err != nil {
127
128 return
129 }
130 hsr := ts.handshakeFunc(rawConn)
131 ts.hsResult.Send(hsr)
132 }
133 }
134
135
136
137 func (ts *testServer) stop() {
138 ts.lis.Close()
139 }
140
141
142
143
144 func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
145 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
146 if err != nil {
147 return handshakeResult{err: err}
148 }
149 cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
150 conn := tls.Server(rawConn, cfg)
151 if err := conn.Handshake(); err != nil {
152 return handshakeResult{err: err}
153 }
154 return handshakeResult{connState: conn.ConnectionState()}
155 }
156
157
158
159 func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
160 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
161 if err != nil {
162 return handshakeResult{err: err}
163 }
164 pemData, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
165 if err != nil {
166 return handshakeResult{err: err}
167 }
168 roots := x509.NewCertPool()
169 roots.AppendCertsFromPEM(pemData)
170 cfg := &tls.Config{
171 Certificates: []tls.Certificate{cert},
172 ClientCAs: roots,
173 }
174 conn := tls.Server(rawConn, cfg)
175 if err := conn.Handshake(); err != nil {
176 return handshakeResult{err: err}
177 }
178 return handshakeResult{connState: conn.ConnectionState()}
179 }
180
181
182
183
184 type fakeProvider struct {
185 km *certprovider.KeyMaterial
186 err error
187 }
188
189 func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
190 return f.km, f.err
191 }
192
193 func (f *fakeProvider) Close() {}
194
195
196
197 func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
198 t.Helper()
199 cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
200 if err != nil {
201 t.Fatal(err)
202 }
203 return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}}
204 }
205
206
207
208 func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
209 pemData, err := os.ReadFile(testdata.Path(caPath))
210 if err != nil {
211 t.Fatal(err)
212 }
213 roots := x509.NewCertPool()
214 roots.AppendCertsFromPEM(pemData)
215 return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
216 }
217
218
219
220 func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
221
222
223
224 var sms []matcher.StringMatcher
225 if sanExactMatch != "" {
226 sms = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
227 }
228 info := xdsinternal.NewHandshakeInfo(root, identity, sms, false)
229 uPtr := unsafe.Pointer(info)
230 addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
231
232
233
234
235 return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
236 }
237
238
239
240 func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
241 if ai.AuthType() != "tls" {
242 return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
243 }
244 info, ok := ai.(credentials.TLSInfo)
245 if !ok {
246 return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
247 }
248 gotState := info.State
249
250
251
252 val, err := ts.hsResult.Receive(ctx)
253 if err != nil {
254 return fmt.Errorf("testServer failed to return handshake result: %v", err)
255 }
256 hsr := val.(handshakeResult)
257 if hsr.err != nil {
258 return fmt.Errorf("testServer handshake failure: %v", hsr.err)
259 }
260
261
262 if err := compareConnState(gotState, hsr.connState); err != nil {
263 return err
264 }
265 return nil
266 }
267
// compareConnState reports the first difference between got and want among
// Version, HandshakeComplete, CipherSuite and NegotiatedProtocol, checked in
// that order. Other fields are ignored. It returns nil when all four match.
func compareConnState(got, want tls.ConnectionState) error {
	if got.Version != want.Version {
		return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version)
	}
	if got.HandshakeComplete != want.HandshakeComplete {
		return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete)
	}
	if got.CipherSuite != want.CipherSuite {
		return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite)
	}
	if got.NegotiatedProtocol != want.NegotiatedProtocol {
		return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol)
	}
	return nil
}
281
282
283
284 func (s) TestClientCredsWithoutFallback(t *testing.T) {
285 if _, err := NewClientCredentials(ClientOptions{}); err == nil {
286 t.Fatal("NewClientCredentials() succeeded without specifying fallback")
287 }
288 }
289
290
291
292
293 func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
294 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
295 creds, err := NewClientCredentials(opts)
296 if err != nil {
297 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
298 }
299
300 pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
301 defer cancel()
302 ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
303 if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
304 t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
305 }
306 }
307
308
309
310 func (s) TestClientCredsProviderFailure(t *testing.T) {
311 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
312 creds, err := NewClientCredentials(opts)
313 if err != nil {
314 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
315 }
316
317 tests := []struct {
318 desc string
319 rootProvider certprovider.Provider
320 identityProvider certprovider.Provider
321 wantErr string
322 }{
323 {
324 desc: "erroring root provider",
325 rootProvider: &fakeProvider{err: errors.New("root provider error")},
326 wantErr: "root provider error",
327 },
328 {
329 desc: "erroring identity provider",
330 rootProvider: &fakeProvider{km: &certprovider.KeyMaterial{}},
331 identityProvider: &fakeProvider{err: errors.New("identity provider error")},
332 wantErr: "identity provider error",
333 },
334 }
335 for _, test := range tests {
336 t.Run(test.desc, func(t *testing.T) {
337 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
338 defer cancel()
339 ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
340 if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
341 t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
342 }
343 })
344 }
345 }
346
347
348 func (s) TestClientCredsSuccess(t *testing.T) {
349 tests := []struct {
350 desc string
351 handshakeFunc testHandshakeFunc
352 handshakeInfoCtx func(ctx context.Context) context.Context
353 }{
354 {
355 desc: "fallback",
356 handshakeFunc: testServerTLSHandshake,
357 handshakeInfoCtx: func(ctx context.Context) context.Context {
358
359
360 return ctx
361 },
362 },
363 {
364 desc: "TLS",
365 handshakeFunc: testServerTLSHandshake,
366 handshakeInfoCtx: func(ctx context.Context) context.Context {
367 return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
368 },
369 },
370 {
371 desc: "mTLS",
372 handshakeFunc: testServerMutualTLSHandshake,
373 handshakeInfoCtx: func(ctx context.Context) context.Context {
374 return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
375 },
376 },
377 {
378 desc: "mTLS with no acceptedSANs specified",
379 handshakeFunc: testServerMutualTLSHandshake,
380 handshakeInfoCtx: func(ctx context.Context) context.Context {
381 return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
382 },
383 },
384 }
385
386 for _, test := range tests {
387 t.Run(test.desc, func(t *testing.T) {
388 ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
389 defer ts.stop()
390
391 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
392 creds, err := NewClientCredentials(opts)
393 if err != nil {
394 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
395 }
396
397 conn, err := net.Dial("tcp", ts.address)
398 if err != nil {
399 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
400 }
401 defer conn.Close()
402
403 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
404 defer cancel()
405 _, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
406 if err != nil {
407 t.Fatalf("ClientHandshake() returned failed: %q", err)
408 }
409 if err := compareAuthInfo(ctx, ts, ai); err != nil {
410 t.Fatal(err)
411 }
412 })
413 }
414 }
415
416 func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
417 clientDone := make(chan struct{})
418
419
420
421 hErr := errors.New("server handshake error")
422 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
423 <-clientDone
424 return handshakeResult{err: hErr}
425 })
426 defer ts.stop()
427
428 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
429 creds, err := NewClientCredentials(opts)
430 if err != nil {
431 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
432 }
433
434 conn, err := net.Dial("tcp", ts.address)
435 if err != nil {
436 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
437 }
438 defer conn.Close()
439
440 sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
441 defer sCancel()
442 ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
443 if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
444 t.Fatal("ClientHandshake() succeeded when expected to timeout")
445 }
446 close(clientDone)
447
448
449
450 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
451 defer cancel()
452 val, err := ts.hsResult.Receive(ctx)
453 if err != nil {
454 t.Fatalf("testServer failed to return handshake result: %v", err)
455 }
456 hsr := val.(handshakeResult)
457 if hsr.err != hErr {
458 t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
459 }
460 }
461
462
463 func (s) TestClientCredsHandshakeFailure(t *testing.T) {
464 tests := []struct {
465 desc string
466 handshakeFunc testHandshakeFunc
467 rootProvider certprovider.Provider
468 san string
469 wantErr string
470 }{
471 {
472 desc: "cert validation failure",
473 handshakeFunc: testServerTLSHandshake,
474 rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"),
475 san: defaultTestCertSAN,
476 wantErr: "x509: certificate signed by unknown authority",
477 },
478 {
479 desc: "SAN mismatch",
480 handshakeFunc: testServerTLSHandshake,
481 rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"),
482 san: "bad-san",
483 wantErr: "do not match any of the accepted SANs",
484 },
485 }
486
487 for _, test := range tests {
488 t.Run(test.desc, func(t *testing.T) {
489 ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
490 defer ts.stop()
491
492 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
493 creds, err := NewClientCredentials(opts)
494 if err != nil {
495 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
496 }
497
498 conn, err := net.Dial("tcp", ts.address)
499 if err != nil {
500 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
501 }
502 defer conn.Close()
503
504 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
505 defer cancel()
506 ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
507 if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
508 t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
509 }
510 })
511 }
512 }
513
514
515
516
517
518
519 func (s) TestClientCredsProviderSwitch(t *testing.T) {
520 ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
521 defer ts.stop()
522
523 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
524 creds, err := NewClientCredentials(opts)
525 if err != nil {
526 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
527 }
528
529 conn, err := net.Dial("tcp", ts.address)
530 if err != nil {
531 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
532 }
533 defer conn.Close()
534
535 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
536 defer cancel()
537
538
539 root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
540 handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
541
542
543
544 uPtr := unsafe.Pointer(handshakeInfo)
545 addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
546 ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
547 if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
548 t.Fatal("ClientHandshake() succeeded when expected to fail")
549 }
550
551
552 _, err = ts.hsResult.Receive(ctx)
553 if err != nil {
554 t.Errorf("testServer failed to return handshake result: %v", err)
555 }
556
557 conn, err = net.Dial("tcp", ts.address)
558 if err != nil {
559 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
560 }
561 defer conn.Close()
562
563
564
565 root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
566 handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
567
568
569 atomic.StorePointer(&uPtr, unsafe.Pointer(handshakeInfo))
570 _, ai, err := creds.ClientHandshake(ctx, authority, conn)
571 if err != nil {
572 t.Fatalf("ClientHandshake() returned failed: %q", err)
573 }
574 if err := compareAuthInfo(ctx, ts, ai); err != nil {
575 t.Fatal(err)
576 }
577 }
578
579
580 func (s) TestClientClone(t *testing.T) {
581 opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
582 orig, err := NewClientCredentials(opts)
583 if err != nil {
584 t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
585 }
586
587
588
589
590 if clone := orig.Clone(); clone == orig {
591 t.Fatal("return value from Clone() doesn't point to new credentials instance")
592 }
593 }
594
// newStringP returns a pointer to a newly allocated string holding a copy of
// s. Each call yields a distinct pointer.
func newStringP(s string) *string {
	p := new(string)
	*p = s
	return p
}
598
View as plain text