1
18
19 package handshaker
20
21 import (
22 "bytes"
23 "context"
24 "errors"
25 "fmt"
26 "io"
27 "net"
28 "strings"
29 "testing"
30
31 "github.com/google/go-cmp/cmp"
32 "github.com/google/go-cmp/cmp/cmpopts"
33 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
34 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
35 "github.com/google/s2a-go/internal/tokenmanager"
36 "golang.org/x/sync/errgroup"
37 grpc "google.golang.org/grpc"
38 "google.golang.org/protobuf/testing/protocmp"
39 )
40
41 var (
42 testAccessToken = "test_access_token"
43
44
45 testHSAddr = "handshaker_address"
46
47
48 testHostname = "localhost"
49
50
51
52 testClientHandshakerOptions = &ClientHandshakerOptions{
53 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
54 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
55 TLSCiphersuites: []commonpb.Ciphersuite{
56 commonpb.Ciphersuite_AES_128_GCM_SHA256,
57 commonpb.Ciphersuite_AES_256_GCM_SHA384,
58 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
59 },
60 TargetIdentities: []*commonpb.Identity{
61 {
62 IdentityOneof: &commonpb.Identity_SpiffeId{
63 SpiffeId: "target_spiffe_id",
64 },
65 },
66 {
67 IdentityOneof: &commonpb.Identity_Hostname{
68 Hostname: "target_hostname",
69 },
70 },
71 },
72 LocalIdentity: &commonpb.Identity{
73 IdentityOneof: &commonpb.Identity_SpiffeId{
74 SpiffeId: "client_local_spiffe_id",
75 },
76 },
77 TargetName: testHostname + ":1234",
78 }
79
80
81
82 testClientStart = &s2apb.ClientSessionStartReq{
83 ApplicationProtocols: []string{"grpc"},
84 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
85 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
86 TlsCiphersuites: []commonpb.Ciphersuite{
87 commonpb.Ciphersuite_AES_128_GCM_SHA256,
88 commonpb.Ciphersuite_AES_256_GCM_SHA384,
89 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
90 },
91 TargetIdentities: []*commonpb.Identity{
92 {
93 IdentityOneof: &commonpb.Identity_SpiffeId{
94 SpiffeId: "target_spiffe_id",
95 },
96 },
97 {
98 IdentityOneof: &commonpb.Identity_Hostname{
99 Hostname: "target_hostname",
100 },
101 },
102 },
103 LocalIdentity: &commonpb.Identity{
104 IdentityOneof: &commonpb.Identity_SpiffeId{
105 SpiffeId: "client_local_spiffe_id",
106 },
107 },
108 TargetName: testHostname,
109 }
110
111
112
113 testClientNext = &s2apb.SessionNextReq{
114 InBytes: []byte("ServerHelloServerFinished"),
115 }
116
117
118
119 testServerHandshakerOptions = &ServerHandshakerOptions{
120 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
121 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
122 TLSCiphersuites: []commonpb.Ciphersuite{
123 commonpb.Ciphersuite_AES_128_GCM_SHA256,
124 commonpb.Ciphersuite_AES_256_GCM_SHA384,
125 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
126 },
127 LocalIdentities: []*commonpb.Identity{
128 {
129 IdentityOneof: &commonpb.Identity_SpiffeId{
130 SpiffeId: "server_local_spiffe_id",
131 },
132 },
133 {
134 IdentityOneof: &commonpb.Identity_Hostname{
135 Hostname: "server_local_hostname",
136 },
137 },
138 },
139 }
140
141
142
143 testServerStart = &s2apb.ServerSessionStartReq{
144 ApplicationProtocols: []string{"grpc"},
145 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
146 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
147 TlsCiphersuites: []commonpb.Ciphersuite{
148 commonpb.Ciphersuite_AES_128_GCM_SHA256,
149 commonpb.Ciphersuite_AES_256_GCM_SHA384,
150 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
151 },
152 LocalIdentities: []*commonpb.Identity{
153 {
154 IdentityOneof: &commonpb.Identity_SpiffeId{
155 SpiffeId: "server_local_spiffe_id",
156 },
157 },
158 {
159 IdentityOneof: &commonpb.Identity_Hostname{
160 Hostname: "server_local_hostname",
161 },
162 },
163 },
164 InBytes: []byte("ClientHello"),
165 }
166
167
168
169 testServerNext = &s2apb.SessionNextReq{
170 InBytes: []byte("ClientFinished"),
171 }
172
173 testClientSessionResult = &s2apb.SessionResult{
174 ApplicationProtocol: "grpc",
175 State: &s2apb.SessionState{
176 TlsVersion: commonpb.TLSVersion_TLS1_3,
177 TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
178 InSequence: 0,
179 OutSequence: 0,
180 InKey: make([]byte, 32),
181 OutKey: make([]byte, 32),
182 },
183 PeerIdentity: &commonpb.Identity{
184 IdentityOneof: &commonpb.Identity_SpiffeId{
185 SpiffeId: "client_local_spiffe_id",
186 },
187 },
188 LocalIdentity: &commonpb.Identity{
189 IdentityOneof: &commonpb.Identity_SpiffeId{
190 SpiffeId: "server_local_spiffe_id",
191 },
192 },
193 LocalCertFingerprint: []byte("client_cert_fingerprint"),
194 PeerCertFingerprint: []byte("server_cert_fingerprint"),
195 }
196
197 testServerSessionResult = &s2apb.SessionResult{
198 ApplicationProtocol: "grpc",
199 State: &s2apb.SessionState{
200 TlsVersion: commonpb.TLSVersion_TLS1_3,
201 TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
202 InSequence: 0,
203 OutSequence: 0,
204 InKey: make([]byte, 32),
205 OutKey: make([]byte, 32),
206 },
207 PeerIdentity: &commonpb.Identity{
208 IdentityOneof: &commonpb.Identity_SpiffeId{
209 SpiffeId: "server_local_spiffe_id",
210 },
211 },
212 LocalIdentity: &commonpb.Identity{
213 IdentityOneof: &commonpb.Identity_SpiffeId{
214 SpiffeId: "client_local_spiffe_id",
215 },
216 },
217 LocalCertFingerprint: []byte("server_cert_fingerprint"),
218 PeerCertFingerprint: []byte("client_cert_fingerprint"),
219 }
220 testResultWithoutLocalIdentity = &s2apb.SessionResult{
221 ApplicationProtocol: "grpc",
222 State: &s2apb.SessionState{
223 TlsVersion: commonpb.TLSVersion_TLS1_3,
224 TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
225 InSequence: 0,
226 OutSequence: 0,
227 InKey: make([]byte, 32),
228 OutKey: make([]byte, 32),
229 },
230 PeerIdentity: &commonpb.Identity{
231 IdentityOneof: &commonpb.Identity_SpiffeId{
232 SpiffeId: "server_local_spiffe_id",
233 },
234 },
235 LocalCertFingerprint: []byte("server_cert_fingerprint"),
236 PeerCertFingerprint: []byte("client_cert_fingerprint"),
237 }
238 )
239
240
241
// fakeConn is an in-memory net.Conn. Reads drain the in buffer and writes are
// appended to the out buffer. Every other net.Conn method goes to the embedded
// interface, which no constructor in this file sets, so calling one would panic.
type fakeConn struct {
	net.Conn
	in  *bytes.Buffer
	out *bytes.Buffer
}

// Read drains bytes previously queued in fc.in.
func (fc *fakeConn) Read(b []byte) (n int, err error) {
	n, err = fc.in.Read(b)
	return n, err
}

// Write records b in fc.out.
func (fc *fakeConn) Write(b []byte) (n int, err error) {
	n, err = fc.out.Write(b)
	return n, err
}

// Close is a no-op that always succeeds.
func (fc *fakeConn) Close() error {
	return nil
}
251
252
253
// fakeInvalidConn is a net.Conn whose peer never sends anything: every Read
// reports io.EOF, while Write and Close succeed without doing anything.
type fakeInvalidConn struct {
	net.Conn
}

// Read reports end-of-stream without filling the buffer.
func (fc *fakeInvalidConn) Read(_ []byte) (n int, err error) {
	err = io.EOF
	return
}

// Write discards its input and reports zero bytes written with no error.
func (fc *fakeInvalidConn) Write(_ []byte) (n int, err error) {
	return
}

// Close is a no-op that always succeeds.
func (fc *fakeInvalidConn) Close() error {
	return nil
}
261
262
263
264 type fakeStream struct {
265 grpc.ClientStream
266 t *testing.T
267 fc *fakeConn
268 expectedClientStart *s2apb.ClientSessionStartReq
269 expectedServerStart *s2apb.ServerSessionStartReq
270 expectToken bool
271
272
273 expectedResp *s2apb.SessionResp
274
275
276 isFirstAccess bool
277 isClient bool
278 isLocalIdentityMissing bool
279 }
280
281 func (fs *fakeStream) Recv() (*s2apb.SessionResp, error) {
282 resp := fs.expectedResp
283 fs.expectedResp = nil
284 return resp, nil
285 }
286 func (fs *fakeStream) Send(req *s2apb.SessionReq) error {
287 var resp *s2apb.SessionResp
288 if fs.expectToken {
289 if len(req.GetAuthMechanisms()) == 0 {
290 return fmt.Errorf("request to S2A did not contain any tokens")
291 }
292
293 for _, authMechanism := range req.GetAuthMechanisms() {
294 if authMechanism.GetToken() != testAccessToken {
295 return fmt.Errorf("request to S2A contained invalid token")
296 }
297 }
298 }
299 if !fs.isFirstAccess {
300
301
302 fs.isFirstAccess = true
303 if fs.isClient {
304 if diff := cmp.Diff(req.GetClientStart(), fs.expectedClientStart, protocmp.Transform()); diff != "" {
305 return fmt.Errorf("client start message is incorrect, (-want +got):\n%s", diff)
306 }
307 resp = &s2apb.SessionResp{
308 OutFrames: []byte("ClientHello"),
309
310 BytesConsumed: 0,
311 }
312 } else {
313
314 if req.GetServerStart() == nil {
315 return errors.New("first request from server does not have server start")
316 }
317 if diff := cmp.Diff(req.GetServerStart(), fs.expectedServerStart, protocmp.Transform()); diff != "" {
318 return fmt.Errorf("server start message is incorrect, (-want +got):\n%s", diff)
319 }
320 fs.fc.in.Write([]byte("ClientFinished"))
321 resp = &s2apb.SessionResp{
322 OutFrames: []byte("ServerHelloServerFinished"),
323
324 BytesConsumed: uint32(len("ClientHello")),
325 }
326 }
327 } else {
328
329 if fs.isClient {
330
331 if req.GetNext() == nil {
332 return errors.New("second request from client does not have next")
333 }
334 if got, want := cmp.Equal(req.GetNext(), testClientNext, protocmp.Transform()), true; got != want {
335 return errors.New("client next message is incorrect")
336 }
337 if fs.isLocalIdentityMissing {
338 resp = &s2apb.SessionResp{
339 Result: testResultWithoutLocalIdentity,
340 BytesConsumed: uint32(len("ClientFinished")),
341 }
342 } else {
343 resp = &s2apb.SessionResp{
344 Result: testClientSessionResult,
345 BytesConsumed: uint32(len("ServerHelloServerFinished")),
346 }
347 }
348 } else {
349
350 if req.GetNext() == nil {
351 return errors.New("second request from server does not have next")
352 }
353 if got, want := cmp.Equal(req.GetNext(), testServerNext, protocmp.Transform()), true; got != want {
354 return errors.New("server next message is incorrect")
355 }
356 if fs.isLocalIdentityMissing {
357 resp = &s2apb.SessionResp{
358 Result: testResultWithoutLocalIdentity,
359 BytesConsumed: uint32(len("ClientFinished")),
360 }
361 } else {
362 resp = &s2apb.SessionResp{
363 Result: testServerSessionResult,
364 BytesConsumed: uint32(len("ClientFinished")),
365 }
366 }
367 }
368 }
369 fs.expectedResp = resp
370 return nil
371 }
372
373 func (*fakeStream) CloseSend() error { return nil }
374
375
376
377 type fakeInvalidStream struct {
378 grpc.ClientStream
379 }
380
381 func (*fakeInvalidStream) Recv() (*s2apb.SessionResp, error) { return &s2apb.SessionResp{}, nil }
382 func (*fakeInvalidStream) Send(*s2apb.SessionReq) error { return nil }
383 func (*fakeInvalidStream) CloseSend() error { return nil }
384
385 type fakeAccessTokenManager struct {
386 acceptedIdentity *commonpb.Identity
387 accessToken string
388 allowEmptyIdentity bool
389 }
390
391 func (m *fakeAccessTokenManager) DefaultToken() (string, error) {
392 if !m.allowEmptyIdentity {
393 return "", fmt.Errorf("not allowed to get token for empty identity")
394 }
395 return m.accessToken, nil
396 }
397
398 func (m *fakeAccessTokenManager) Token(identity *commonpb.Identity) (string, error) {
399 if identity == nil || cmp.Equal(identity, &commonpb.Identity{}, protocmp.Transform()) {
400 if !m.allowEmptyIdentity {
401 return "", fmt.Errorf("not allowed to get token for empty identity")
402 }
403 return m.accessToken, nil
404 }
405 if cmp.Equal(identity, m.acceptedIdentity, protocmp.Transform()) {
406 return m.accessToken, nil
407 }
408 return "", fmt.Errorf("unable to get token")
409 }
410
411
412
413 func TestNewClientHandshaker(t *testing.T) {
414 stream := &fakeStream{}
415 c := &fakeConn{}
416 chs := newClientHandshaker(stream, c, testHSAddr, testClientHandshakerOptions, &fakeAccessTokenManager{})
417 if chs.clientOpts != testClientHandshakerOptions || chs.conn != c {
418 t.Errorf("handshaker parameters incorrect")
419 }
420 }
421
422
423
424 func TestNewServerHandshaker(t *testing.T) {
425 stream := &fakeStream{}
426 c := &fakeConn{}
427 shs := newServerHandshaker(stream, c, testHSAddr, testServerHandshakerOptions, &fakeAccessTokenManager{})
428 if shs.serverOpts != testServerHandshakerOptions || shs.conn != c {
429 t.Errorf("handshaker parameters incorrect")
430 }
431 }
432
433 func TestClientHandshakeSuccess(t *testing.T) {
434 for _, tc := range []struct {
435 description string
436 options *ClientHandshakerOptions
437 tokenManager tokenmanager.AccessTokenManager
438 expectedClientStart *s2apb.ClientSessionStartReq
439 }{
440 {
441 description: "full client options",
442 options: testClientHandshakerOptions,
443 expectedClientStart: testClientStart,
444 },
445 {
446 description: "full client options with no port in target name",
447 options: &ClientHandshakerOptions{
448 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
449 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
450 TLSCiphersuites: []commonpb.Ciphersuite{
451 commonpb.Ciphersuite_AES_128_GCM_SHA256,
452 commonpb.Ciphersuite_AES_256_GCM_SHA384,
453 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
454 },
455 TargetIdentities: []*commonpb.Identity{
456 {
457 IdentityOneof: &commonpb.Identity_SpiffeId{
458 SpiffeId: "target_spiffe_id",
459 },
460 },
461 {
462 IdentityOneof: &commonpb.Identity_Hostname{
463 Hostname: "target_hostname",
464 },
465 },
466 },
467 LocalIdentity: &commonpb.Identity{
468 IdentityOneof: &commonpb.Identity_SpiffeId{
469 SpiffeId: "client_local_spiffe_id",
470 },
471 },
472 TargetName: testHostname,
473 },
474 expectedClientStart: &s2apb.ClientSessionStartReq{
475 ApplicationProtocols: []string{"grpc"},
476 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
477 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
478 TlsCiphersuites: []commonpb.Ciphersuite{
479 commonpb.Ciphersuite_AES_128_GCM_SHA256,
480 commonpb.Ciphersuite_AES_256_GCM_SHA384,
481 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
482 },
483 TargetIdentities: []*commonpb.Identity{
484 {
485 IdentityOneof: &commonpb.Identity_SpiffeId{
486 SpiffeId: "target_spiffe_id",
487 },
488 },
489 {
490 IdentityOneof: &commonpb.Identity_Hostname{
491 Hostname: "target_hostname",
492 },
493 },
494 },
495 LocalIdentity: &commonpb.Identity{
496 IdentityOneof: &commonpb.Identity_SpiffeId{
497 SpiffeId: "client_local_spiffe_id",
498 },
499 },
500 TargetName: testHostname,
501 },
502 },
503 {
504 description: "full client options with no local identity",
505 options: &ClientHandshakerOptions{
506 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
507 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
508 TLSCiphersuites: []commonpb.Ciphersuite{
509 commonpb.Ciphersuite_AES_128_GCM_SHA256,
510 commonpb.Ciphersuite_AES_256_GCM_SHA384,
511 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
512 },
513 TargetIdentities: []*commonpb.Identity{
514 {
515 IdentityOneof: &commonpb.Identity_SpiffeId{
516 SpiffeId: "target_spiffe_id",
517 },
518 },
519 {
520 IdentityOneof: &commonpb.Identity_Hostname{
521 Hostname: "target_hostname",
522 },
523 },
524 },
525 TargetName: testHostname + ":1234",
526 },
527 expectedClientStart: &s2apb.ClientSessionStartReq{
528 ApplicationProtocols: []string{"grpc"},
529 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
530 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
531 TlsCiphersuites: []commonpb.Ciphersuite{
532 commonpb.Ciphersuite_AES_128_GCM_SHA256,
533 commonpb.Ciphersuite_AES_256_GCM_SHA384,
534 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
535 },
536 TargetIdentities: []*commonpb.Identity{
537 {
538 IdentityOneof: &commonpb.Identity_SpiffeId{
539 SpiffeId: "target_spiffe_id",
540 },
541 },
542 {
543 IdentityOneof: &commonpb.Identity_Hostname{
544 Hostname: "target_hostname",
545 },
546 },
547 },
548 TargetName: testHostname,
549 },
550 },
551 {
552 description: "full client options, sending tokens",
553 options: testClientHandshakerOptions,
554 expectedClientStart: testClientStart,
555 tokenManager: &fakeAccessTokenManager{
556 accessToken: testAccessToken,
557 acceptedIdentity: &commonpb.Identity{
558 IdentityOneof: &commonpb.Identity_SpiffeId{
559 SpiffeId: "client_local_spiffe_id",
560 },
561 },
562 },
563 },
564 {
565 description: "full client options with no local identity, sending tokens",
566 options: &ClientHandshakerOptions{
567 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
568 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
569 TLSCiphersuites: []commonpb.Ciphersuite{
570 commonpb.Ciphersuite_AES_128_GCM_SHA256,
571 commonpb.Ciphersuite_AES_256_GCM_SHA384,
572 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
573 },
574 TargetIdentities: []*commonpb.Identity{
575 {
576 IdentityOneof: &commonpb.Identity_SpiffeId{
577 SpiffeId: "target_spiffe_id",
578 },
579 },
580 {
581 IdentityOneof: &commonpb.Identity_Hostname{
582 Hostname: "target_hostname",
583 },
584 },
585 },
586 TargetName: testHostname + ":1234",
587 },
588 expectedClientStart: &s2apb.ClientSessionStartReq{
589 ApplicationProtocols: []string{"grpc"},
590 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
591 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
592 TlsCiphersuites: []commonpb.Ciphersuite{
593 commonpb.Ciphersuite_AES_128_GCM_SHA256,
594 commonpb.Ciphersuite_AES_256_GCM_SHA384,
595 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
596 },
597 TargetIdentities: []*commonpb.Identity{
598 {
599 IdentityOneof: &commonpb.Identity_SpiffeId{
600 SpiffeId: "target_spiffe_id",
601 },
602 },
603 {
604 IdentityOneof: &commonpb.Identity_Hostname{
605 Hostname: "target_hostname",
606 },
607 },
608 },
609 TargetName: testHostname,
610 },
611 tokenManager: &fakeAccessTokenManager{
612 accessToken: testAccessToken,
613 allowEmptyIdentity: true,
614 },
615 },
616 } {
617 t.Run(tc.description, func(t *testing.T) {
618
619 var errg errgroup.Group
620 stream := &fakeStream{
621 t: t,
622 isClient: true,
623 expectedClientStart: tc.expectedClientStart,
624 expectToken: (tc.tokenManager != nil),
625 }
626 in := bytes.NewBuffer([]byte("ServerHelloServerFinished"))
627 c := &fakeConn{
628 in: in,
629 out: new(bytes.Buffer),
630 }
631
632
633 chs := newClientHandshaker(stream, c, testHSAddr, tc.options, tc.tokenManager)
634 errg.Go(func() error {
635 newConn, auth, err := chs.ClientHandshake(context.Background())
636 if err != nil {
637 return err
638 }
639 if auth.AuthType() != "s2a" {
640 return errors.New("s2a auth type incorrect")
641 }
642 if newConn == nil {
643 return errors.New("expected non-nil net.Conn")
644 }
645 if err := chs.Close(); err != nil {
646 t.Errorf("chs.Close() failed: %v", err)
647 }
648 return nil
649 })
650
651 if err := errg.Wait(); err != nil {
652 t.Errorf("client handshake failed: %v", err)
653 }
654 })
655 }
656 }
657
658 func TestServerHandshakeSuccess(t *testing.T) {
659 for _, tc := range []struct {
660 description string
661 options *ServerHandshakerOptions
662 tokenManager tokenmanager.AccessTokenManager
663 expectedServerStart *s2apb.ServerSessionStartReq
664 }{
665 {
666 description: "full server options",
667 options: testServerHandshakerOptions,
668 expectedServerStart: testServerStart,
669 },
670 {
671 description: "full server options with no local identities",
672 options: &ServerHandshakerOptions{
673 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
674 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
675 TLSCiphersuites: []commonpb.Ciphersuite{
676 commonpb.Ciphersuite_AES_128_GCM_SHA256,
677 commonpb.Ciphersuite_AES_256_GCM_SHA384,
678 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
679 },
680 },
681 expectedServerStart: &s2apb.ServerSessionStartReq{
682 ApplicationProtocols: []string{"grpc"},
683 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
684 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
685 TlsCiphersuites: []commonpb.Ciphersuite{
686 commonpb.Ciphersuite_AES_128_GCM_SHA256,
687 commonpb.Ciphersuite_AES_256_GCM_SHA384,
688 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
689 },
690 InBytes: []byte("ClientHello"),
691 },
692 },
693 {
694 description: "full server options, sending tokens",
695 options: testServerHandshakerOptions,
696 expectedServerStart: testServerStart,
697 tokenManager: &fakeAccessTokenManager{
698 accessToken: testAccessToken,
699 acceptedIdentity: &commonpb.Identity{
700 IdentityOneof: &commonpb.Identity_SpiffeId{
701 SpiffeId: "server_local_spiffe_id",
702 },
703 },
704 },
705 },
706 {
707 description: "full server options with no local identity, sending tokens",
708 options: &ServerHandshakerOptions{
709 MinTLSVersion: commonpb.TLSVersion_TLS1_2,
710 MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
711 TLSCiphersuites: []commonpb.Ciphersuite{
712 commonpb.Ciphersuite_AES_128_GCM_SHA256,
713 commonpb.Ciphersuite_AES_256_GCM_SHA384,
714 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
715 },
716 },
717 expectedServerStart: &s2apb.ServerSessionStartReq{
718 ApplicationProtocols: []string{"grpc"},
719 MinTlsVersion: commonpb.TLSVersion_TLS1_2,
720 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
721 TlsCiphersuites: []commonpb.Ciphersuite{
722 commonpb.Ciphersuite_AES_128_GCM_SHA256,
723 commonpb.Ciphersuite_AES_256_GCM_SHA384,
724 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
725 },
726 InBytes: []byte("ClientHello"),
727 },
728 tokenManager: &fakeAccessTokenManager{
729 accessToken: testAccessToken,
730 allowEmptyIdentity: true,
731 },
732 },
733 } {
734 t.Run(tc.description, func(t *testing.T) {
735
736 var errg errgroup.Group
737 in := bytes.NewBuffer([]byte("ClientHello"))
738 c := &fakeConn{
739 in: in,
740 out: new(bytes.Buffer),
741 }
742 stream := &fakeStream{
743 t: t,
744 fc: c,
745 isClient: false,
746 expectedServerStart: tc.expectedServerStart,
747 expectToken: (tc.tokenManager != nil),
748 }
749
750
751 shs := newServerHandshaker(stream, c, testHSAddr, tc.options, tc.tokenManager)
752 errg.Go(func() error {
753 newConn, auth, err := shs.ServerHandshake(context.Background())
754 if err != nil {
755 return err
756 }
757 if auth.AuthType() != "s2a" {
758 return errors.New("s2a auth type incorrect")
759 }
760 if newConn == nil {
761 return errors.New("expected non-nil net.Conn")
762 }
763 if err = shs.Close(); err != nil {
764 t.Errorf("shs.Close() failed: %v", err)
765 }
766 return nil
767 })
768
769 if err := errg.Wait(); err != nil {
770 t.Errorf("server handshake failed: %v", err)
771 }
772 })
773 }
774 }
775
776
777
778
779 func TestS2ARejectsTokenFromClient(t *testing.T) {
780 stream := &fakeStream{
781 t: t,
782 isClient: true,
783 expectToken: true,
784 }
785 in := bytes.NewBuffer([]byte("ServerHelloServerFinished"))
786 c := &fakeConn{
787 in: in,
788 out: new(bytes.Buffer),
789 }
790 tokenManager := &fakeAccessTokenManager{
791 accessToken: "bad_access_token",
792 acceptedIdentity: &commonpb.Identity{
793 IdentityOneof: &commonpb.Identity_SpiffeId{
794 SpiffeId: "client_local_spiffe_id",
795 },
796 },
797 }
798
799 chs := newClientHandshaker(stream, c, testHSAddr, testClientHandshakerOptions, tokenManager)
800 _, _, err := chs.ClientHandshake(context.Background())
801 if err == nil {
802 t.Errorf("expected non-nil error from call to chs.ClientHandshake()")
803 }
804 if !strings.Contains(err.Error(), "request to S2A contained invalid token") {
805 t.Errorf("chs.ClientHandshake() produced unexpected error: %v", err)
806 }
807 }
808
809 func TestS2ARejectsTokenFromServer(t *testing.T) {
810 stream := &fakeStream{
811 t: t,
812 isClient: false,
813 expectToken: true,
814 }
815 in := bytes.NewBuffer([]byte("ClientHello"))
816 c := &fakeConn{
817 in: in,
818 out: new(bytes.Buffer),
819 }
820 tokenManager := &fakeAccessTokenManager{
821 accessToken: "bad_access_token",
822 acceptedIdentity: &commonpb.Identity{
823 IdentityOneof: &commonpb.Identity_SpiffeId{
824 SpiffeId: "server_local_spiffe_id",
825 },
826 },
827 }
828
829 chs := newServerHandshaker(stream, c, testHSAddr, testServerHandshakerOptions, tokenManager)
830 _, _, err := chs.ServerHandshake(context.Background())
831 if err == nil {
832 t.Errorf("expected non-nil error from call to chs.ServerHandshake()")
833 }
834 if !strings.Contains(err.Error(), "request to S2A contained invalid token") {
835 t.Errorf("chs.ServerHandshake() produced unexpected error: %v", err)
836 }
837 }
838
839 func TestInvalidHandshaker(t *testing.T) {
840 emptyCHS := &s2aHandshaker{
841 isClient: false,
842 }
843 _, _, err := emptyCHS.ClientHandshake(context.Background())
844 if err == nil {
845 t.Error("ClientHandshake() should fail with server-side handshaker service")
846 }
847 emptySHS := &s2aHandshaker{
848 isClient: true,
849 }
850 _, _, err = emptySHS.ServerHandshake(context.Background())
851 if err == nil {
852 t.Error("ServerHandshake() should fail with client-side handshaker service")
853 }
854 }
855
856
857
858 func TestPeerNotResponding(t *testing.T) {
859 stream := &fakeInvalidStream{}
860 c := &fakeInvalidConn{}
861 chs := &s2aHandshaker{
862 stream: stream,
863 conn: c,
864 clientOpts: testClientHandshakerOptions,
865 isClient: true,
866 hsAddr: testHSAddr,
867 }
868 _, authInfo, err := chs.ClientHandshake(context.Background())
869 if authInfo != nil {
870 t.Error("expected non-nil S2A authInfo")
871 }
872 if got, want := err, errPeerNotResponding; got != want {
873 t.Errorf("ClientHandshake() = %v, want %v", got, want)
874 }
875 if err = chs.Close(); err != nil {
876 t.Errorf("chs.Close() failed: %v", err)
877 }
878 }
879
880
881
882 func TestLocalIdentityNotSet(t *testing.T) {
883 var errg errgroup.Group
884 stream := &fakeStream{
885 t: t,
886 isClient: true,
887 isLocalIdentityMissing: true,
888 }
889 in := bytes.NewBuffer([]byte("ServerHelloServerFinished"))
890 c := &fakeConn{
891 in: in,
892 out: new(bytes.Buffer),
893 }
894 chs := &s2aHandshaker{
895 stream: stream,
896 conn: c,
897 clientOpts: testClientHandshakerOptions,
898 isClient: true,
899 hsAddr: testHSAddr,
900 }
901 errg.Go(func() error {
902 newConn, auth, err := chs.ClientHandshake(context.Background())
903 if cmp.Equal(err, errors.New("local identity must be populated in session result"), cmpopts.EquateErrors()) {
904 return fmt.Errorf("unexpected error: %v", err)
905 }
906 if auth != nil {
907 return errors.New("expected nil credentials.AuthInfo")
908 }
909 if newConn != nil {
910 return errors.New("expected nil net.Conn")
911 }
912 return nil
913 })
914
915 if err := errg.Wait(); err != nil {
916 t.Errorf("client handshake failed: %v", err)
917 }
918 }
919
920 func TestGetAuthMechanismsForClient(t *testing.T) {
921 sortProtos := cmpopts.SortSlices(func(m1, m2 *s2apb.AuthenticationMechanism) bool { return m1.String() < m2.String() })
922 for _, tc := range []struct {
923 description string
924 options *ClientHandshakerOptions
925 tokenManager tokenmanager.AccessTokenManager
926 expectedAuthMechanisms []*s2apb.AuthenticationMechanism
927 }{
928 {
929 description: "token manager is nil",
930 tokenManager: nil,
931 expectedAuthMechanisms: nil,
932 },
933 {
934 description: "token manager expects empty identity",
935 tokenManager: &fakeAccessTokenManager{
936 accessToken: testAccessToken,
937 allowEmptyIdentity: true,
938 },
939 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
940 {
941 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
942 Token: testAccessToken,
943 },
944 },
945 },
946 },
947 {
948 description: "token manager does not expect empty identity",
949 tokenManager: &fakeAccessTokenManager{
950 allowEmptyIdentity: false,
951 },
952 expectedAuthMechanisms: nil,
953 },
954 {
955 description: "token manager expects SPIFFE ID",
956 options: &ClientHandshakerOptions{
957 LocalIdentity: &commonpb.Identity{
958 IdentityOneof: &commonpb.Identity_SpiffeId{
959 SpiffeId: "allowed_spiffe_id",
960 },
961 },
962 },
963 tokenManager: &fakeAccessTokenManager{
964 accessToken: testAccessToken,
965 acceptedIdentity: &commonpb.Identity{
966 IdentityOneof: &commonpb.Identity_SpiffeId{
967 SpiffeId: "allowed_spiffe_id",
968 },
969 },
970 },
971 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
972 {
973 Identity: &commonpb.Identity{
974 IdentityOneof: &commonpb.Identity_SpiffeId{
975 SpiffeId: "allowed_spiffe_id",
976 },
977 },
978 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
979 Token: testAccessToken,
980 },
981 },
982 },
983 },
984 {
985 description: "token manager does not expect hostname",
986 options: &ClientHandshakerOptions{
987 LocalIdentity: &commonpb.Identity{
988 IdentityOneof: &commonpb.Identity_Hostname{
989 Hostname: "disallowed_hostname",
990 },
991 },
992 },
993 tokenManager: &fakeAccessTokenManager{},
994 expectedAuthMechanisms: nil,
995 },
996 } {
997 t.Run(tc.description, func(t *testing.T) {
998 handshaker := newClientHandshaker(nil, nil, "", tc.options, tc.tokenManager)
999 authMechanisms := handshaker.getAuthMechanisms()
1000 if got, want := (authMechanisms == nil), (tc.expectedAuthMechanisms == nil); got != want {
1001 t.Errorf("authMechanisms == nil: %t, tc.expectedAuthMechanisms == nil: %t", got, want)
1002 }
1003 if authMechanisms != nil && tc.expectedAuthMechanisms != nil {
1004 if diff := cmp.Diff(authMechanisms, tc.expectedAuthMechanisms, protocmp.Transform(), sortProtos); diff != "" {
1005 t.Errorf("handshaker.getAuthMechanisms() returned incorrect slice, (-want +got):\n%s", diff)
1006 }
1007 }
1008 })
1009 }
1010 }
1011
1012 func TestGetAuthMechanismsForServer(t *testing.T) {
1013 sortProtos := cmpopts.SortSlices(func(m1, m2 *s2apb.AuthenticationMechanism) bool { return m1.String() < m2.String() })
1014 for _, tc := range []struct {
1015 description string
1016 options *ServerHandshakerOptions
1017 tokenManager tokenmanager.AccessTokenManager
1018 expectedAuthMechanisms []*s2apb.AuthenticationMechanism
1019 }{
1020 {
1021 description: "token manager is nil",
1022 tokenManager: nil,
1023 expectedAuthMechanisms: nil,
1024 },
1025 {
1026 description: "token manager expects empty identity",
1027 tokenManager: &fakeAccessTokenManager{
1028 accessToken: testAccessToken,
1029 allowEmptyIdentity: true,
1030 },
1031 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
1032 {
1033 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
1034 Token: testAccessToken,
1035 },
1036 },
1037 },
1038 },
1039 {
1040 description: "token manager does not expect empty identity",
1041 tokenManager: &fakeAccessTokenManager{
1042 allowEmptyIdentity: false,
1043 },
1044 expectedAuthMechanisms: nil,
1045 },
1046 {
1047 description: "token manager expects 2 SPIFFE IDs",
1048 options: &ServerHandshakerOptions{
1049 LocalIdentities: []*commonpb.Identity{
1050 {
1051 IdentityOneof: &commonpb.Identity_SpiffeId{
1052 SpiffeId: "allowed_spiffe_id",
1053 },
1054 },
1055 {
1056 IdentityOneof: &commonpb.Identity_SpiffeId{
1057 SpiffeId: "allowed_spiffe_id",
1058 },
1059 },
1060 },
1061 },
1062 tokenManager: &fakeAccessTokenManager{
1063 accessToken: testAccessToken,
1064 acceptedIdentity: &commonpb.Identity{
1065 IdentityOneof: &commonpb.Identity_SpiffeId{
1066 SpiffeId: "allowed_spiffe_id",
1067 },
1068 },
1069 },
1070 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
1071 {
1072 Identity: &commonpb.Identity{
1073 IdentityOneof: &commonpb.Identity_SpiffeId{
1074 SpiffeId: "allowed_spiffe_id",
1075 },
1076 },
1077 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
1078 Token: testAccessToken,
1079 },
1080 },
1081 {
1082 Identity: &commonpb.Identity{
1083 IdentityOneof: &commonpb.Identity_SpiffeId{
1084 SpiffeId: "allowed_spiffe_id",
1085 },
1086 },
1087 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
1088 Token: testAccessToken,
1089 },
1090 },
1091 },
1092 },
1093 {
1094 description: "token manager expects a SPIFFE ID but does not expect hostname",
1095 options: &ServerHandshakerOptions{
1096 LocalIdentities: []*commonpb.Identity{
1097 {
1098 IdentityOneof: &commonpb.Identity_SpiffeId{
1099 SpiffeId: "allowed_spiffe_id",
1100 },
1101 },
1102 {
1103 IdentityOneof: &commonpb.Identity_Hostname{
1104 Hostname: "disallowed_hostname",
1105 },
1106 },
1107 },
1108 },
1109 tokenManager: &fakeAccessTokenManager{
1110 accessToken: testAccessToken,
1111 acceptedIdentity: &commonpb.Identity{
1112 IdentityOneof: &commonpb.Identity_SpiffeId{
1113 SpiffeId: "allowed_spiffe_id",
1114 },
1115 },
1116 },
1117 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
1118 {
1119 Identity: &commonpb.Identity{
1120 IdentityOneof: &commonpb.Identity_SpiffeId{
1121 SpiffeId: "allowed_spiffe_id",
1122 },
1123 },
1124 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
1125 Token: testAccessToken,
1126 },
1127 },
1128 },
1129 },
1130 } {
1131 t.Run(tc.description, func(t *testing.T) {
1132 handshaker := newServerHandshaker(nil, nil, "", tc.options, tc.tokenManager)
1133 authMechanisms := handshaker.getAuthMechanisms()
1134 if got, want := (authMechanisms == nil), (tc.expectedAuthMechanisms == nil); got != want {
1135 t.Errorf("authMechanisms == nil: %t, tc.expectedAuthMechanisms == nil: %t", got, want)
1136 }
1137 if authMechanisms != nil && tc.expectedAuthMechanisms != nil {
1138 if diff := cmp.Diff(authMechanisms, tc.expectedAuthMechanisms, protocmp.Transform(), sortProtos); diff != "" {
1139 t.Errorf("handshaker.getAuthMechanisms() returned incorrect slice, (-want +got):\n%s", diff)
1140 }
1141 }
1142 })
1143 }
1144 }
1145
View as plain text