1
18
19 package xds
20
21 import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "errors"
26 "fmt"
27 "net"
28 "os"
29 "strings"
30 "testing"
31 "time"
32
33 "google.golang.org/grpc/credentials"
34 "google.golang.org/grpc/credentials/tls/certprovider"
35 xdsinternal "google.golang.org/grpc/internal/credentials/xds"
36 "google.golang.org/grpc/testdata"
37 )
38
39 func makeClientTLSConfig(t *testing.T, mTLS bool) *tls.Config {
40 t.Helper()
41
42 pemData, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
43 if err != nil {
44 t.Fatal(err)
45 }
46 roots := x509.NewCertPool()
47 roots.AppendCertsFromPEM(pemData)
48
49 var certs []tls.Certificate
50 if mTLS {
51 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem"))
52 if err != nil {
53 t.Fatal(err)
54 }
55 certs = append(certs, cert)
56 }
57
58 return &tls.Config{
59 Certificates: certs,
60 RootCAs: roots,
61 ServerName: "*.test.example.com",
62
63
64
65
66
67
68 InsecureSkipVerify: true,
69 }
70 }
71
72
73
74 func makeFallbackServerCreds(t *testing.T) credentials.TransportCredentials {
75 t.Helper()
76
77 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
78 if err != nil {
79 t.Fatal(err)
80 }
81 return creds
82 }
83
84 type errorCreds struct {
85 credentials.TransportCredentials
86 }
87
88
89
90 func (s) TestServerCredsWithoutFallback(t *testing.T) {
91 if _, err := NewServerCredentials(ServerOptions{}); err == nil {
92 t.Fatal("NewServerCredentials() succeeded without specifying fallback")
93 }
94 }
95
96 type wrapperConn struct {
97 net.Conn
98 xdsHI *xdsinternal.HandshakeInfo
99 deadline time.Time
100 handshakeInfoErr error
101 }
102
103 func (wc *wrapperConn) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) {
104 return wc.xdsHI, wc.handshakeInfoErr
105 }
106
107 func (wc *wrapperConn) GetDeadline() time.Time {
108 return wc.deadline
109 }
110
111 func newWrappedConn(conn net.Conn, xdsHI *xdsinternal.HandshakeInfo, deadline time.Time) *wrapperConn {
112 return &wrapperConn{Conn: conn, xdsHI: xdsHI, deadline: deadline}
113 }
114
115
116
117
118 func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) {
119 opts := ServerOptions{FallbackCreds: &errorCreds{}}
120 creds, err := NewServerCredentials(opts)
121 if err != nil {
122 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
123 }
124
125 info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false)
126 conn := newWrappedConn(nil, info, time.Time{})
127 if _, _, err := creds.ServerHandshake(conn); err == nil {
128 t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo")
129 }
130 }
131
132
133
134 func (s) TestServerCredsProviderFailure(t *testing.T) {
135 opts := ServerOptions{FallbackCreds: &errorCreds{}}
136 creds, err := NewServerCredentials(opts)
137 if err != nil {
138 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
139 }
140
141 tests := []struct {
142 desc string
143 rootProvider certprovider.Provider
144 identityProvider certprovider.Provider
145 wantErr string
146 }{
147 {
148 desc: "erroring identity provider",
149 identityProvider: &fakeProvider{err: errors.New("identity provider error")},
150 wantErr: "identity provider error",
151 },
152 {
153 desc: "erroring root provider",
154 identityProvider: &fakeProvider{km: &certprovider.KeyMaterial{}},
155 rootProvider: &fakeProvider{err: errors.New("root provider error")},
156 wantErr: "root provider error",
157 },
158 }
159 for _, test := range tests {
160 t.Run(test.desc, func(t *testing.T) {
161 info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false)
162 conn := newWrappedConn(nil, info, time.Time{})
163 if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
164 t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
165 }
166 })
167 }
168 }
169
170
171
172
173
174 func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) {
175 opts := ServerOptions{FallbackCreds: &errorCreds{}}
176 creds, err := NewServerCredentials(opts)
177 if err != nil {
178 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
179 }
180
181
182
183 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
184
185 conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout))
186 hiErr := errors.New("xdsHandshakeInfo error")
187 conn.handshakeInfoErr = hiErr
188
189
190
191
192 _, _, err := creds.ServerHandshake(conn)
193 if !errors.Is(err, hiErr) {
194 return handshakeResult{err: fmt.Errorf("ServerHandshake() returned err: %v, wantErr: %v", err, hiErr)}
195 }
196 return handshakeResult{}
197 })
198 defer ts.stop()
199
200
201
202 rawConn, err := net.Dial("tcp", ts.address)
203 if err != nil {
204 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
205 }
206 defer rawConn.Close()
207
208
209
210 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
211 defer cancel()
212 val, err := ts.hsResult.Receive(ctx)
213 if err != nil {
214 t.Fatalf("testServer failed to return handshake result: %v", err)
215 }
216 hsr := val.(handshakeResult)
217 if hsr.err != nil {
218 t.Fatalf("testServer handshake failure: %v", hsr.err)
219 }
220 }
221
222
223
224
225 func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
226 opts := ServerOptions{FallbackCreds: &errorCreds{}}
227 creds, err := NewServerCredentials(opts)
228 if err != nil {
229 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
230 }
231
232
233
234 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
235 hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true)
236
237
238
239 d := time.Now().Add(defaultTestShortTimeout)
240 rawConn.SetDeadline(d)
241 conn := newWrappedConn(rawConn, hi, d)
242
243
244 if _, _, err := creds.ServerHandshake(conn); err == nil {
245 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to timeout")}
246 }
247 return handshakeResult{}
248 })
249 defer ts.stop()
250
251
252
253 rawConn, err := net.Dial("tcp", ts.address)
254 if err != nil {
255 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
256 }
257 defer rawConn.Close()
258
259
260 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
261 defer cancel()
262 val, err := ts.hsResult.Receive(ctx)
263 if err != nil {
264 t.Fatalf("testServer failed to return handshake result: %v", err)
265 }
266 hsr := val.(handshakeResult)
267 if hsr.err != nil {
268 t.Fatalf("testServer handshake failure: %v", hsr.err)
269 }
270 }
271
272
273
274
275 func (s) TestServerCredsHandshakeFailure(t *testing.T) {
276 opts := ServerOptions{FallbackCreds: &errorCreds{}}
277 creds, err := NewServerCredentials(opts)
278 if err != nil {
279 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
280 }
281
282
283
284 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
285
286
287 hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
288
289
290
291
292 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
293
294
295 if _, _, err := creds.ServerHandshake(conn); err == nil {
296 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")}
297 }
298 return handshakeResult{}
299 })
300 defer ts.stop()
301
302
303 rawConn, err := net.Dial("tcp", ts.address)
304 if err != nil {
305 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
306 }
307 defer rawConn.Close()
308 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true))
309 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout))
310 if err := tlsConn.Handshake(); err != nil {
311 t.Fatal(err)
312 }
313
314
315
316 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
317 defer cancel()
318 val, err := ts.hsResult.Receive(ctx)
319 if err != nil {
320 t.Fatalf("testServer failed to return handshake result: %v", err)
321 }
322 hsr := val.(handshakeResult)
323 if hsr.err != nil {
324 t.Fatalf("testServer handshake failure: %v", hsr.err)
325 }
326 }
327
328
329 func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
330 tests := []struct {
331 desc string
332 fallbackCreds credentials.TransportCredentials
333 rootProvider certprovider.Provider
334 identityProvider certprovider.Provider
335 requireClientCert bool
336 }{
337 {
338 desc: "fallback",
339 fallbackCreds: makeFallbackServerCreds(t),
340 },
341 {
342 desc: "TLS",
343 fallbackCreds: &errorCreds{},
344 identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"),
345 },
346 {
347 desc: "mTLS",
348 fallbackCreds: &errorCreds{},
349 identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"),
350 rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"),
351 requireClientCert: true,
352 },
353 }
354
355 for _, test := range tests {
356 t.Run(test.desc, func(t *testing.T) {
357
358 opts := ServerOptions{FallbackCreds: test.fallbackCreds}
359 creds, err := NewServerCredentials(opts)
360 if err != nil {
361 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
362 }
363
364
365
366 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
367
368 hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert)
369
370
371
372
373 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
374
375
376
377
378 _, ai, err := creds.ServerHandshake(conn)
379 if err != nil {
380 return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)}
381 }
382 if ai.AuthType() != "tls" {
383 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")}
384 }
385 info, ok := ai.(credentials.TLSInfo)
386 if !ok {
387 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})}
388 }
389 return handshakeResult{connState: info.State}
390 })
391 defer ts.stop()
392
393
394 rawConn, err := net.Dial("tcp", ts.address)
395 if err != nil {
396 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
397 }
398 defer rawConn.Close()
399 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, test.requireClientCert))
400 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout))
401 if err := tlsConn.Handshake(); err != nil {
402 t.Fatal(err)
403 }
404
405
406
407
408 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
409 defer cancel()
410 val, err := ts.hsResult.Receive(ctx)
411 if err != nil {
412 t.Fatalf("testServer failed to return handshake result: %v", err)
413 }
414 hsr := val.(handshakeResult)
415 if hsr.err != nil {
416 t.Fatalf("testServer handshake failure: %v", hsr.err)
417 }
418
419
420
421
422 if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil {
423 t.Fatal(err)
424 }
425 })
426 }
427 }
428
429 func (s) TestServerCredsProviderSwitch(t *testing.T) {
430 opts := ServerOptions{FallbackCreds: &errorCreds{}}
431 creds, err := NewServerCredentials(opts)
432 if err != nil {
433 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
434 }
435
436
437
438
439 cnt := 0
440
441
442 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
443 cnt++
444 var hi *xdsinternal.HandshakeInfo
445 if cnt == 1 {
446
447
448 hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
449
450
451
452
453 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
454
455
456 if _, _, err := creds.ServerHandshake(conn); err == nil {
457 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")}
458 }
459 return handshakeResult{}
460 }
461
462 hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true)
463
464
465
466
467 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
468
469
470
471
472 _, ai, err := creds.ServerHandshake(conn)
473 if err != nil {
474 return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)}
475 }
476 if ai.AuthType() != "tls" {
477 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")}
478 }
479 info, ok := ai.(credentials.TLSInfo)
480 if !ok {
481 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})}
482 }
483 return handshakeResult{connState: info.State}
484 })
485 defer ts.stop()
486
487 for i := 0; i < 5; i++ {
488
489 rawConn, err := net.Dial("tcp", ts.address)
490 if err != nil {
491 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
492 }
493 defer rawConn.Close()
494 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true))
495 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout))
496 if err := tlsConn.Handshake(); err != nil {
497 t.Fatal(err)
498 }
499
500
501
502
503 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
504 defer cancel()
505 val, err := ts.hsResult.Receive(ctx)
506 if err != nil {
507 t.Fatalf("testServer failed to return handshake result: %v", err)
508 }
509 hsr := val.(handshakeResult)
510 if hsr.err != nil {
511 t.Fatalf("testServer handshake failure: %v", hsr.err)
512 }
513 if i == 0 {
514
515
516 continue
517 }
518
519
520
521 if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil {
522 t.Fatal(err)
523 }
524 }
525 }
526
527
528 func (s) TestServerClone(t *testing.T) {
529 opts := ServerOptions{FallbackCreds: makeFallbackServerCreds(t)}
530 orig, err := NewServerCredentials(opts)
531 if err != nil {
532 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
533 }
534
535
536
537
538 if clone := orig.Clone(); clone == orig {
539 t.Fatal("return value from Clone() doesn't point to new credentials instance")
540 }
541 }
542
View as plain text