...

Source file src/github.com/google/s2a-go/internal/handshaker/handshaker_test.go

Documentation: github.com/google/s2a-go/internal/handshaker

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package handshaker
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"strings"
    29  	"testing"
    30  
    31  	"github.com/google/go-cmp/cmp"
    32  	"github.com/google/go-cmp/cmp/cmpopts"
    33  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    34  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    35  	"github.com/google/s2a-go/internal/tokenmanager"
    36  	"golang.org/x/sync/errgroup"
    37  	grpc "google.golang.org/grpc"
    38  	"google.golang.org/protobuf/testing/protocmp"
    39  )
    40  
var (
	// testAccessToken is the only token fakeStream accepts when a test sets
	// expectToken; fakeAccessTokenManager instances hand it out.
	testAccessToken = "test_access_token"

	// testHSAddr is the handshaker service address used for testing.
	testHSAddr = "handshaker_address"

	// testHostname is the hostname of the server used for testing.
	testHostname = "localhost"

	// testClientHandshakerOptions are the client-side handshaker options used for
	// testing. Note that TargetName includes a ":1234" port.
	testClientHandshakerOptions = &ClientHandshakerOptions{
		MinTLSVersion: commonpb.TLSVersion_TLS1_2,
		MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
		TLSCiphersuites: []commonpb.Ciphersuite{
			commonpb.Ciphersuite_AES_128_GCM_SHA256,
			commonpb.Ciphersuite_AES_256_GCM_SHA384,
			commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
		},
		TargetIdentities: []*commonpb.Identity{
			{
				IdentityOneof: &commonpb.Identity_SpiffeId{
					SpiffeId: "target_spiffe_id",
				},
			},
			{
				IdentityOneof: &commonpb.Identity_Hostname{
					Hostname: "target_hostname",
				},
			},
		},
		LocalIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "client_local_spiffe_id",
			},
		},
		TargetName: testHostname + ":1234",
	}

	// testClientStart is the ClientSessionStartReq message that the S2A expects
	// to receive first from the test client. Its TargetName is the bare
	// testHostname, i.e. the ":1234" port from testClientHandshakerOptions is
	// expected to have been stripped by the client handshaker.
	testClientStart = &s2apb.ClientSessionStartReq{
		ApplicationProtocols: []string{"grpc"},
		MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
		MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
		TlsCiphersuites: []commonpb.Ciphersuite{
			commonpb.Ciphersuite_AES_128_GCM_SHA256,
			commonpb.Ciphersuite_AES_256_GCM_SHA384,
			commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
		},
		TargetIdentities: []*commonpb.Identity{
			{
				IdentityOneof: &commonpb.Identity_SpiffeId{
					SpiffeId: "target_spiffe_id",
				},
			},
			{
				IdentityOneof: &commonpb.Identity_Hostname{
					Hostname: "target_hostname",
				},
			},
		},
		LocalIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "client_local_spiffe_id",
			},
		},
		TargetName: testHostname,
	}

	// testClientNext is the SessionNextReq message that the S2A expects
	// to receive second from the test client.
	testClientNext = &s2apb.SessionNextReq{
		InBytes: []byte("ServerHelloServerFinished"),
	}

	// testServerHandshakerOptions are the server-side handshaker options used
	// for testing.
	testServerHandshakerOptions = &ServerHandshakerOptions{
		MinTLSVersion: commonpb.TLSVersion_TLS1_2,
		MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
		TLSCiphersuites: []commonpb.Ciphersuite{
			commonpb.Ciphersuite_AES_128_GCM_SHA256,
			commonpb.Ciphersuite_AES_256_GCM_SHA384,
			commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
		},
		LocalIdentities: []*commonpb.Identity{
			{
				IdentityOneof: &commonpb.Identity_SpiffeId{
					SpiffeId: "server_local_spiffe_id",
				},
			},
			{
				IdentityOneof: &commonpb.Identity_Hostname{
					Hostname: "server_local_hostname",
				},
			},
		},
	}

	// testServerStart is the ServerSessionStartReq message that the S2A expects
	// to receive from the test server. InBytes is the "ClientHello" that the
	// tests pre-load into the server's fake connection.
	testServerStart = &s2apb.ServerSessionStartReq{
		ApplicationProtocols: []string{"grpc"},
		MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
		MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
		TlsCiphersuites: []commonpb.Ciphersuite{
			commonpb.Ciphersuite_AES_128_GCM_SHA256,
			commonpb.Ciphersuite_AES_256_GCM_SHA384,
			commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
		},
		LocalIdentities: []*commonpb.Identity{
			{
				IdentityOneof: &commonpb.Identity_SpiffeId{
					SpiffeId: "server_local_spiffe_id",
				},
			},
			{
				IdentityOneof: &commonpb.Identity_Hostname{
					Hostname: "server_local_hostname",
				},
			},
		},
		InBytes: []byte("ClientHello"),
	}

	// testServerNext is the SessionNextReq message that the S2A expects to
	// receive second from the test server.
	testServerNext = &s2apb.SessionNextReq{
		InBytes: []byte("ClientFinished"),
	}

	// testClientSessionResult is the SessionResult that fakeStream returns to a
	// client-side handshaker once the handshake completes.
	// NOTE(review): the cert fingerprints are from the client's point of view
	// (local = client), but PeerIdentity/LocalIdentity look swapped (peer =
	// client, local = server) — confirm whether this is intentional.
	testClientSessionResult = &s2apb.SessionResult{
		ApplicationProtocol: "grpc",
		State: &s2apb.SessionState{
			TlsVersion:     commonpb.TLSVersion_TLS1_3,
			TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
			InSequence:     0,
			OutSequence:    0,
			InKey:          make([]byte, 32),
			OutKey:         make([]byte, 32),
		},
		PeerIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "client_local_spiffe_id",
			},
		},
		LocalIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "server_local_spiffe_id",
			},
		},
		LocalCertFingerprint: []byte("client_cert_fingerprint"),
		PeerCertFingerprint:  []byte("server_cert_fingerprint"),
	}

	// testServerSessionResult is the SessionResult that fakeStream returns to a
	// server-side handshaker once the handshake completes.
	// NOTE(review): as with testClientSessionResult, the fingerprints are from
	// the server's point of view but the identities look swapped — confirm.
	testServerSessionResult = &s2apb.SessionResult{
		ApplicationProtocol: "grpc",
		State: &s2apb.SessionState{
			TlsVersion:     commonpb.TLSVersion_TLS1_3,
			TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
			InSequence:     0,
			OutSequence:    0,
			InKey:          make([]byte, 32),
			OutKey:         make([]byte, 32),
		},
		PeerIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "server_local_spiffe_id",
			},
		},
		LocalIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "client_local_spiffe_id",
			},
		},
		LocalCertFingerprint: []byte("server_cert_fingerprint"),
		PeerCertFingerprint:  []byte("client_cert_fingerprint"),
	}

	// testResultWithoutLocalIdentity is the SessionResult that fakeStream
	// returns, on either side, when isLocalIdentityMissing is set. It has no
	// LocalIdentity field.
	testResultWithoutLocalIdentity = &s2apb.SessionResult{
		ApplicationProtocol: "grpc",
		State: &s2apb.SessionState{
			TlsVersion:     commonpb.TLSVersion_TLS1_3,
			TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
			InSequence:     0,
			OutSequence:    0,
			InKey:          make([]byte, 32),
			OutKey:         make([]byte, 32),
		},
		PeerIdentity: &commonpb.Identity{
			IdentityOneof: &commonpb.Identity_SpiffeId{
				SpiffeId: "server_local_spiffe_id",
			},
		},
		LocalCertFingerprint: []byte("server_cert_fingerprint"),
		PeerCertFingerprint:  []byte("client_cert_fingerprint"),
	}
)
   239  
   240  // fakeConn is a fake implementation of the net.Conn interface that is used for
   241  // testing.
   242  type fakeConn struct {
   243  	net.Conn
   244  	in  *bytes.Buffer
   245  	out *bytes.Buffer
   246  }
   247  
   248  func (fc *fakeConn) Read(b []byte) (n int, err error)  { return fc.in.Read(b) }
   249  func (fc *fakeConn) Write(b []byte) (n int, err error) { return fc.out.Write(b) }
   250  func (fc *fakeConn) Close() error                      { return nil }
   251  
   252  // fakeInvalidConn is a fake implementation of a invalid net.Conn interface
   253  // that is used for testing.
   254  type fakeInvalidConn struct {
   255  	net.Conn
   256  }
   257  
   258  func (fc *fakeInvalidConn) Read(_ []byte) (n int, err error)  { return 0, io.EOF }
   259  func (fc *fakeInvalidConn) Write(_ []byte) (n int, err error) { return 0, nil }
   260  func (fc *fakeInvalidConn) Close() error                      { return nil }
   261  
   262  // fakeStream is a fake implementation of the grpc.ClientStream interface that
   263  // is used for testing.
   264  type fakeStream struct {
   265  	grpc.ClientStream
   266  	t                   *testing.T
   267  	fc                  *fakeConn
   268  	expectedClientStart *s2apb.ClientSessionStartReq
   269  	expectedServerStart *s2apb.ServerSessionStartReq
   270  	expectToken         bool
   271  	// expectedResp is the expected SessionResp message from the handshaker
   272  	// service.
   273  	expectedResp *s2apb.SessionResp
   274  	// isFirstAccess indicates whether the first call to the handshaker service
   275  	// has been made or not.
   276  	isFirstAccess          bool
   277  	isClient               bool
   278  	isLocalIdentityMissing bool
   279  }
   280  
   281  func (fs *fakeStream) Recv() (*s2apb.SessionResp, error) {
   282  	resp := fs.expectedResp
   283  	fs.expectedResp = nil
   284  	return resp, nil
   285  }
   286  func (fs *fakeStream) Send(req *s2apb.SessionReq) error {
   287  	var resp *s2apb.SessionResp
   288  	if fs.expectToken {
   289  		if len(req.GetAuthMechanisms()) == 0 {
   290  			return fmt.Errorf("request to S2A did not contain any tokens")
   291  		}
   292  		// Ensure that every token appearing in the request has a valid token.
   293  		for _, authMechanism := range req.GetAuthMechanisms() {
   294  			if authMechanism.GetToken() != testAccessToken {
   295  				return fmt.Errorf("request to S2A contained invalid token")
   296  			}
   297  		}
   298  	}
   299  	if !fs.isFirstAccess {
   300  		// Generate the bytes to be returned by Recv() for the first handshake
   301  		// message.
   302  		fs.isFirstAccess = true
   303  		if fs.isClient {
   304  			if diff := cmp.Diff(req.GetClientStart(), fs.expectedClientStart, protocmp.Transform()); diff != "" {
   305  				return fmt.Errorf("client start message is incorrect, (-want +got):\n%s", diff)
   306  			}
   307  			resp = &s2apb.SessionResp{
   308  				OutFrames: []byte("ClientHello"),
   309  				// There are no consumed bytes for a client start message
   310  				BytesConsumed: 0,
   311  			}
   312  		} else {
   313  			// Expect a server start message.
   314  			if req.GetServerStart() == nil {
   315  				return errors.New("first request from server does not have server start")
   316  			}
   317  			if diff := cmp.Diff(req.GetServerStart(), fs.expectedServerStart, protocmp.Transform()); diff != "" {
   318  				return fmt.Errorf("server start message is incorrect, (-want +got):\n%s", diff)
   319  			}
   320  			fs.fc.in.Write([]byte("ClientFinished"))
   321  			resp = &s2apb.SessionResp{
   322  				OutFrames: []byte("ServerHelloServerFinished"),
   323  				// Simulate consuming the ClientHello message.
   324  				BytesConsumed: uint32(len("ClientHello")),
   325  			}
   326  		}
   327  	} else {
   328  		// Construct a SessionResp message that contains the handshake result.
   329  		if fs.isClient {
   330  			// Expect next message with "ServerHelloServerFinished".
   331  			if req.GetNext() == nil {
   332  				return errors.New("second request from client does not have next")
   333  			}
   334  			if got, want := cmp.Equal(req.GetNext(), testClientNext, protocmp.Transform()), true; got != want {
   335  				return errors.New("client next message is incorrect")
   336  			}
   337  			if fs.isLocalIdentityMissing {
   338  				resp = &s2apb.SessionResp{
   339  					Result:        testResultWithoutLocalIdentity,
   340  					BytesConsumed: uint32(len("ClientFinished")),
   341  				}
   342  			} else {
   343  				resp = &s2apb.SessionResp{
   344  					Result:        testClientSessionResult,
   345  					BytesConsumed: uint32(len("ServerHelloServerFinished")),
   346  				}
   347  			}
   348  		} else {
   349  			// Expect next message with "ClientFinished".
   350  			if req.GetNext() == nil {
   351  				return errors.New("second request from server does not have next")
   352  			}
   353  			if got, want := cmp.Equal(req.GetNext(), testServerNext, protocmp.Transform()), true; got != want {
   354  				return errors.New("server next message is incorrect")
   355  			}
   356  			if fs.isLocalIdentityMissing {
   357  				resp = &s2apb.SessionResp{
   358  					Result:        testResultWithoutLocalIdentity,
   359  					BytesConsumed: uint32(len("ClientFinished")),
   360  				}
   361  			} else {
   362  				resp = &s2apb.SessionResp{
   363  					Result:        testServerSessionResult,
   364  					BytesConsumed: uint32(len("ClientFinished")),
   365  				}
   366  			}
   367  		}
   368  	}
   369  	fs.expectedResp = resp
   370  	return nil
   371  }
   372  
   373  func (*fakeStream) CloseSend() error { return nil }
   374  
   375  // fakeInvalidStream is a fake implementation of an invalid grpc.ClientStream
   376  // interface that is used for testing.
   377  type fakeInvalidStream struct {
   378  	grpc.ClientStream
   379  }
   380  
   381  func (*fakeInvalidStream) Recv() (*s2apb.SessionResp, error) { return &s2apb.SessionResp{}, nil }
   382  func (*fakeInvalidStream) Send(*s2apb.SessionReq) error      { return nil }
   383  func (*fakeInvalidStream) CloseSend() error                  { return nil }
   384  
   385  type fakeAccessTokenManager struct {
   386  	acceptedIdentity   *commonpb.Identity
   387  	accessToken        string
   388  	allowEmptyIdentity bool
   389  }
   390  
   391  func (m *fakeAccessTokenManager) DefaultToken() (string, error) {
   392  	if !m.allowEmptyIdentity {
   393  		return "", fmt.Errorf("not allowed to get token for empty identity")
   394  	}
   395  	return m.accessToken, nil
   396  }
   397  
   398  func (m *fakeAccessTokenManager) Token(identity *commonpb.Identity) (string, error) {
   399  	if identity == nil || cmp.Equal(identity, &commonpb.Identity{}, protocmp.Transform()) {
   400  		if !m.allowEmptyIdentity {
   401  			return "", fmt.Errorf("not allowed to get token for empty identity")
   402  		}
   403  		return m.accessToken, nil
   404  	}
   405  	if cmp.Equal(identity, m.acceptedIdentity, protocmp.Transform()) {
   406  		return m.accessToken, nil
   407  	}
   408  	return "", fmt.Errorf("unable to get token")
   409  }
   410  
   411  // TestNewClientHandshaker creates a fake stream, and ensures that
   412  // newClientHandshaker returns a valid client-side handshaker instance.
   413  func TestNewClientHandshaker(t *testing.T) {
   414  	stream := &fakeStream{}
   415  	c := &fakeConn{}
   416  	chs := newClientHandshaker(stream, c, testHSAddr, testClientHandshakerOptions, &fakeAccessTokenManager{})
   417  	if chs.clientOpts != testClientHandshakerOptions || chs.conn != c {
   418  		t.Errorf("handshaker parameters incorrect")
   419  	}
   420  }
   421  
   422  // TestNewServerHandshaker creates a fake stream, and ensures that
   423  // newServerHandshaker returns a valid server-side handshaker instance.
   424  func TestNewServerHandshaker(t *testing.T) {
   425  	stream := &fakeStream{}
   426  	c := &fakeConn{}
   427  	shs := newServerHandshaker(stream, c, testHSAddr, testServerHandshakerOptions, &fakeAccessTokenManager{})
   428  	if shs.serverOpts != testServerHandshakerOptions || shs.conn != c {
   429  		t.Errorf("handshaker parameters incorrect")
   430  	}
   431  }
   432  
// TestClientHandshakeSuccess runs a full client-side handshake against
// fakeStream for several ClientHandshakerOptions variants: with and without a
// local identity, with and without a port in TargetName, and with and without
// an access token manager. For each variant it checks that the handshake
// succeeds, reports the "s2a" auth type, returns a non-nil connection, and
// that the handshaker closes without error.
func TestClientHandshakeSuccess(t *testing.T) {
	for _, tc := range []struct {
		description         string
		options             *ClientHandshakerOptions
		tokenManager        tokenmanager.AccessTokenManager
		expectedClientStart *s2apb.ClientSessionStartReq
	}{
		{
			description:         "full client options",
			options:             testClientHandshakerOptions,
			expectedClientStart: testClientStart,
		},
		{
			description: "full client options with no port in target name",
			options: &ClientHandshakerOptions{
				MinTLSVersion: commonpb.TLSVersion_TLS1_2,
				MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
				TLSCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				TargetIdentities: []*commonpb.Identity{
					{
						IdentityOneof: &commonpb.Identity_SpiffeId{
							SpiffeId: "target_spiffe_id",
						},
					},
					{
						IdentityOneof: &commonpb.Identity_Hostname{
							Hostname: "target_hostname",
						},
					},
				},
				LocalIdentity: &commonpb.Identity{
					IdentityOneof: &commonpb.Identity_SpiffeId{
						SpiffeId: "client_local_spiffe_id",
					},
				},
				TargetName: testHostname,
			},
			// A port-less TargetName is expected to be forwarded unchanged.
			expectedClientStart: &s2apb.ClientSessionStartReq{
				ApplicationProtocols: []string{"grpc"},
				MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
				MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
				TlsCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				TargetIdentities: []*commonpb.Identity{
					{
						IdentityOneof: &commonpb.Identity_SpiffeId{
							SpiffeId: "target_spiffe_id",
						},
					},
					{
						IdentityOneof: &commonpb.Identity_Hostname{
							Hostname: "target_hostname",
						},
					},
				},
				LocalIdentity: &commonpb.Identity{
					IdentityOneof: &commonpb.Identity_SpiffeId{
						SpiffeId: "client_local_spiffe_id",
					},
				},
				TargetName: testHostname,
			},
		},
		{
			description: "full client options with no local identity",
			options: &ClientHandshakerOptions{
				MinTLSVersion: commonpb.TLSVersion_TLS1_2,
				MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
				TLSCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				TargetIdentities: []*commonpb.Identity{
					{
						IdentityOneof: &commonpb.Identity_SpiffeId{
							SpiffeId: "target_spiffe_id",
						},
					},
					{
						IdentityOneof: &commonpb.Identity_Hostname{
							Hostname: "target_hostname",
						},
					},
				},
				TargetName: testHostname + ":1234",
			},
			// No LocalIdentity is expected in the start message, and the
			// ":1234" port is expected to be stripped from TargetName.
			expectedClientStart: &s2apb.ClientSessionStartReq{
				ApplicationProtocols: []string{"grpc"},
				MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
				MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
				TlsCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				TargetIdentities: []*commonpb.Identity{
					{
						IdentityOneof: &commonpb.Identity_SpiffeId{
							SpiffeId: "target_spiffe_id",
						},
					},
					{
						IdentityOneof: &commonpb.Identity_Hostname{
							Hostname: "target_hostname",
						},
					},
				},
				TargetName: testHostname,
			},
		},
		{
			description:         "full client options, sending tokens",
			options:             testClientHandshakerOptions,
			expectedClientStart: testClientStart,
			// Hands out testAccessToken only for the client's local identity.
			tokenManager: &fakeAccessTokenManager{
				accessToken: testAccessToken,
				acceptedIdentity: &commonpb.Identity{
					IdentityOneof: &commonpb.Identity_SpiffeId{
						SpiffeId: "client_local_spiffe_id",
					},
				},
			},
		},
		{
			description: "full client options with no local identity, sending tokens",
			options: &ClientHandshakerOptions{
				MinTLSVersion: commonpb.TLSVersion_TLS1_2,
				MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
				TLSCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				TargetIdentities: []*commonpb.Identity{
					{
						IdentityOneof: &commonpb.Identity_SpiffeId{
							SpiffeId: "target_spiffe_id",
						},
					},
					{
						IdentityOneof: &commonpb.Identity_Hostname{
							Hostname: "target_hostname",
						},
					},
				},
				TargetName: testHostname + ":1234",
			},
			expectedClientStart: &s2apb.ClientSessionStartReq{
				ApplicationProtocols: []string{"grpc"},
				MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
				MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
				TlsCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				TargetIdentities: []*commonpb.Identity{
					{
						IdentityOneof: &commonpb.Identity_SpiffeId{
							SpiffeId: "target_spiffe_id",
						},
					},
					{
						IdentityOneof: &commonpb.Identity_Hostname{
							Hostname: "target_hostname",
						},
					},
				},
				TargetName: testHostname,
			},
			// With no local identity, the token must come from the
			// empty-identity path, so it is explicitly allowed here.
			tokenManager: &fakeAccessTokenManager{
				accessToken:        testAccessToken,
				allowEmptyIdentity: true,
			},
		},
	} {
		t.Run(tc.description, func(t *testing.T) {
			// Set up all fakes and input data.
			var errg errgroup.Group
			// The fake S2A demands tokens exactly when the case supplies a
			// token manager.
			stream := &fakeStream{
				t:                   t,
				isClient:            true,
				expectedClientStart: tc.expectedClientStart,
				expectToken:         (tc.tokenManager != nil),
			}
			// Pre-load the server's reply flight; fakeStream expects the
			// client to forward exactly these bytes in its next request.
			in := bytes.NewBuffer([]byte("ServerHelloServerFinished"))
			c := &fakeConn{
				in:  in,
				out: new(bytes.Buffer),
			}

			// Do the handshake.
			chs := newClientHandshaker(stream, c, testHSAddr, tc.options, tc.tokenManager)
			errg.Go(func() error {
				newConn, auth, err := chs.ClientHandshake(context.Background())
				if err != nil {
					return err
				}
				if auth.AuthType() != "s2a" {
					return errors.New("s2a auth type incorrect")
				}
				if newConn == nil {
					return errors.New("expected non-nil net.Conn")
				}
				if err := chs.Close(); err != nil {
					t.Errorf("chs.Close() failed: %v", err)
				}
				return nil
			})

			// Wait surfaces the first error returned by the handshake goroutine.
			if err := errg.Wait(); err != nil {
				t.Errorf("client handshake failed: %v", err)
			}
		})
	}
}
   657  
// TestServerHandshakeSuccess runs a full server-side handshake against
// fakeStream for several ServerHandshakerOptions variants: with and without
// local identities, and with and without an access token manager. For each
// variant it checks that the handshake succeeds, reports the "s2a" auth type,
// returns a non-nil connection, and that the handshaker closes without error.
func TestServerHandshakeSuccess(t *testing.T) {
	for _, tc := range []struct {
		description         string
		options             *ServerHandshakerOptions
		tokenManager        tokenmanager.AccessTokenManager
		expectedServerStart *s2apb.ServerSessionStartReq
	}{
		{
			description:         "full server options",
			options:             testServerHandshakerOptions,
			expectedServerStart: testServerStart,
		},
		{
			description: "full server options with no local identities",
			options: &ServerHandshakerOptions{
				MinTLSVersion: commonpb.TLSVersion_TLS1_2,
				MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
				TLSCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
			},
			// No LocalIdentities are expected in the start message.
			expectedServerStart: &s2apb.ServerSessionStartReq{
				ApplicationProtocols: []string{"grpc"},
				MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
				MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
				TlsCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				InBytes: []byte("ClientHello"),
			},
		},
		{
			description:         "full server options, sending tokens",
			options:             testServerHandshakerOptions,
			expectedServerStart: testServerStart,
			// Hands out testAccessToken only for the server's SPIFFE identity.
			tokenManager: &fakeAccessTokenManager{
				accessToken: testAccessToken,
				acceptedIdentity: &commonpb.Identity{
					IdentityOneof: &commonpb.Identity_SpiffeId{
						SpiffeId: "server_local_spiffe_id",
					},
				},
			},
		},
		{
			description: "full server options with no local identity, sending tokens",
			options: &ServerHandshakerOptions{
				MinTLSVersion: commonpb.TLSVersion_TLS1_2,
				MaxTLSVersion: commonpb.TLSVersion_TLS1_3,
				TLSCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
			},
			expectedServerStart: &s2apb.ServerSessionStartReq{
				ApplicationProtocols: []string{"grpc"},
				MinTlsVersion:        commonpb.TLSVersion_TLS1_2,
				MaxTlsVersion:        commonpb.TLSVersion_TLS1_3,
				TlsCiphersuites: []commonpb.Ciphersuite{
					commonpb.Ciphersuite_AES_128_GCM_SHA256,
					commonpb.Ciphersuite_AES_256_GCM_SHA384,
					commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
				},
				InBytes: []byte("ClientHello"),
			},
			// With no local identities, the token must come from the
			// empty-identity path, so it is explicitly allowed here.
			tokenManager: &fakeAccessTokenManager{
				accessToken:        testAccessToken,
				allowEmptyIdentity: true,
			},
		},
	} {
		t.Run(tc.description, func(t *testing.T) {
			// Set up all fakes and input data.
			var errg errgroup.Group
			// Pre-load the client's first flight; fakeStream expects it as the
			// start message's InBytes.
			in := bytes.NewBuffer([]byte("ClientHello"))
			c := &fakeConn{
				in:  in,
				out: new(bytes.Buffer),
			}
			// fc lets fakeStream queue "ClientFinished" on c.in after the
			// first response, for the handshaker's second read.
			stream := &fakeStream{
				t:                   t,
				fc:                  c,
				isClient:            false,
				expectedServerStart: tc.expectedServerStart,
				expectToken:         (tc.tokenManager != nil),
			}

			// Do the handshake.
			shs := newServerHandshaker(stream, c, testHSAddr, tc.options, tc.tokenManager)
			errg.Go(func() error {
				newConn, auth, err := shs.ServerHandshake(context.Background())
				if err != nil {
					return err
				}
				if auth.AuthType() != "s2a" {
					return errors.New("s2a auth type incorrect")
				}
				if newConn == nil {
					return errors.New("expected non-nil net.Conn")
				}
				if err = shs.Close(); err != nil {
					t.Errorf("shs.Close() failed: %v", err)
				}
				return nil
			})

			// Wait surfaces the first error returned by the handshake goroutine.
			if err := errg.Wait(); err != nil {
				t.Errorf("server handshake failed: %v", err)
			}
		})
	}
}
   775  
   776  // Note that there is no need to test the case where S2A is expecting a token
   777  // and the application does not send a token, because this case is functionally
   778  // the same as the application sending an invalid token.
   779  func TestS2ARejectsTokenFromClient(t *testing.T) {
   780  	stream := &fakeStream{
   781  		t:           t,
   782  		isClient:    true,
   783  		expectToken: true,
   784  	}
   785  	in := bytes.NewBuffer([]byte("ServerHelloServerFinished"))
   786  	c := &fakeConn{
   787  		in:  in,
   788  		out: new(bytes.Buffer),
   789  	}
   790  	tokenManager := &fakeAccessTokenManager{
   791  		accessToken: "bad_access_token",
   792  		acceptedIdentity: &commonpb.Identity{
   793  			IdentityOneof: &commonpb.Identity_SpiffeId{
   794  				SpiffeId: "client_local_spiffe_id",
   795  			},
   796  		},
   797  	}
   798  
   799  	chs := newClientHandshaker(stream, c, testHSAddr, testClientHandshakerOptions, tokenManager)
   800  	_, _, err := chs.ClientHandshake(context.Background())
   801  	if err == nil {
   802  		t.Errorf("expected non-nil error from call to chs.ClientHandshake()")
   803  	}
   804  	if !strings.Contains(err.Error(), "request to S2A contained invalid token") {
   805  		t.Errorf("chs.ClientHandshake() produced unexpected error: %v", err)
   806  	}
   807  }
   808  
   809  func TestS2ARejectsTokenFromServer(t *testing.T) {
   810  	stream := &fakeStream{
   811  		t:           t,
   812  		isClient:    false,
   813  		expectToken: true,
   814  	}
   815  	in := bytes.NewBuffer([]byte("ClientHello"))
   816  	c := &fakeConn{
   817  		in:  in,
   818  		out: new(bytes.Buffer),
   819  	}
   820  	tokenManager := &fakeAccessTokenManager{
   821  		accessToken: "bad_access_token",
   822  		acceptedIdentity: &commonpb.Identity{
   823  			IdentityOneof: &commonpb.Identity_SpiffeId{
   824  				SpiffeId: "server_local_spiffe_id",
   825  			},
   826  		},
   827  	}
   828  
   829  	chs := newServerHandshaker(stream, c, testHSAddr, testServerHandshakerOptions, tokenManager)
   830  	_, _, err := chs.ServerHandshake(context.Background())
   831  	if err == nil {
   832  		t.Errorf("expected non-nil error from call to chs.ServerHandshake()")
   833  	}
   834  	if !strings.Contains(err.Error(), "request to S2A contained invalid token") {
   835  		t.Errorf("chs.ServerHandshake() produced unexpected error: %v", err)
   836  	}
   837  }
   838  
   839  func TestInvalidHandshaker(t *testing.T) {
   840  	emptyCHS := &s2aHandshaker{
   841  		isClient: false,
   842  	}
   843  	_, _, err := emptyCHS.ClientHandshake(context.Background())
   844  	if err == nil {
   845  		t.Error("ClientHandshake() should fail with server-side handshaker service")
   846  	}
   847  	emptySHS := &s2aHandshaker{
   848  		isClient: true,
   849  	}
   850  	_, _, err = emptySHS.ServerHandshake(context.Background())
   851  	if err == nil {
   852  		t.Error("ServerHandshake() should fail with client-side handshaker service")
   853  	}
   854  }
   855  
   856  // TestPeerNotResponding uses an invalid net.Conn instance and performs a
   857  // client-side handshake to test the case when the peer is not responding.
   858  func TestPeerNotResponding(t *testing.T) {
   859  	stream := &fakeInvalidStream{}
   860  	c := &fakeInvalidConn{}
   861  	chs := &s2aHandshaker{
   862  		stream:     stream,
   863  		conn:       c,
   864  		clientOpts: testClientHandshakerOptions,
   865  		isClient:   true,
   866  		hsAddr:     testHSAddr,
   867  	}
   868  	_, authInfo, err := chs.ClientHandshake(context.Background())
   869  	if authInfo != nil {
   870  		t.Error("expected non-nil S2A authInfo")
   871  	}
   872  	if got, want := err, errPeerNotResponding; got != want {
   873  		t.Errorf("ClientHandshake() = %v, want %v", got, want)
   874  	}
   875  	if err = chs.Close(); err != nil {
   876  		t.Errorf("chs.Close() failed: %v", err)
   877  	}
   878  }
   879  
   880  // TestLocalIdentityNotSet performs a client-side handshake that fails
   881  // because the local identity is not set in the handshake result.
   882  func TestLocalIdentityNotSet(t *testing.T) {
   883  	var errg errgroup.Group
   884  	stream := &fakeStream{
   885  		t:                      t,
   886  		isClient:               true,
   887  		isLocalIdentityMissing: true,
   888  	}
   889  	in := bytes.NewBuffer([]byte("ServerHelloServerFinished"))
   890  	c := &fakeConn{
   891  		in:  in,
   892  		out: new(bytes.Buffer),
   893  	}
   894  	chs := &s2aHandshaker{
   895  		stream:     stream,
   896  		conn:       c,
   897  		clientOpts: testClientHandshakerOptions,
   898  		isClient:   true,
   899  		hsAddr:     testHSAddr,
   900  	}
   901  	errg.Go(func() error {
   902  		newConn, auth, err := chs.ClientHandshake(context.Background())
   903  		if cmp.Equal(err, errors.New("local identity must be populated in session result"), cmpopts.EquateErrors()) {
   904  			return fmt.Errorf("unexpected error: %v", err)
   905  		}
   906  		if auth != nil {
   907  			return errors.New("expected nil credentials.AuthInfo")
   908  		}
   909  		if newConn != nil {
   910  			return errors.New("expected nil net.Conn")
   911  		}
   912  		return nil
   913  	})
   914  
   915  	if err := errg.Wait(); err != nil {
   916  		t.Errorf("client handshake failed: %v", err)
   917  	}
   918  }
   919  
   920  func TestGetAuthMechanismsForClient(t *testing.T) {
   921  	sortProtos := cmpopts.SortSlices(func(m1, m2 *s2apb.AuthenticationMechanism) bool { return m1.String() < m2.String() })
   922  	for _, tc := range []struct {
   923  		description            string
   924  		options                *ClientHandshakerOptions
   925  		tokenManager           tokenmanager.AccessTokenManager
   926  		expectedAuthMechanisms []*s2apb.AuthenticationMechanism
   927  	}{
   928  		{
   929  			description:            "token manager is nil",
   930  			tokenManager:           nil,
   931  			expectedAuthMechanisms: nil,
   932  		},
   933  		{
   934  			description: "token manager expects empty identity",
   935  			tokenManager: &fakeAccessTokenManager{
   936  				accessToken:        testAccessToken,
   937  				allowEmptyIdentity: true,
   938  			},
   939  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
   940  				{
   941  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   942  						Token: testAccessToken,
   943  					},
   944  				},
   945  			},
   946  		},
   947  		{
   948  			description: "token manager does not expect empty identity",
   949  			tokenManager: &fakeAccessTokenManager{
   950  				allowEmptyIdentity: false,
   951  			},
   952  			expectedAuthMechanisms: nil,
   953  		},
   954  		{
   955  			description: "token manager expects SPIFFE ID",
   956  			options: &ClientHandshakerOptions{
   957  				LocalIdentity: &commonpb.Identity{
   958  					IdentityOneof: &commonpb.Identity_SpiffeId{
   959  						SpiffeId: "allowed_spiffe_id",
   960  					},
   961  				},
   962  			},
   963  			tokenManager: &fakeAccessTokenManager{
   964  				accessToken: testAccessToken,
   965  				acceptedIdentity: &commonpb.Identity{
   966  					IdentityOneof: &commonpb.Identity_SpiffeId{
   967  						SpiffeId: "allowed_spiffe_id",
   968  					},
   969  				},
   970  			},
   971  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
   972  				{
   973  					Identity: &commonpb.Identity{
   974  						IdentityOneof: &commonpb.Identity_SpiffeId{
   975  							SpiffeId: "allowed_spiffe_id",
   976  						},
   977  					},
   978  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   979  						Token: testAccessToken,
   980  					},
   981  				},
   982  			},
   983  		},
   984  		{
   985  			description: "token manager does not expect hostname",
   986  			options: &ClientHandshakerOptions{
   987  				LocalIdentity: &commonpb.Identity{
   988  					IdentityOneof: &commonpb.Identity_Hostname{
   989  						Hostname: "disallowed_hostname",
   990  					},
   991  				},
   992  			},
   993  			tokenManager:           &fakeAccessTokenManager{},
   994  			expectedAuthMechanisms: nil,
   995  		},
   996  	} {
   997  		t.Run(tc.description, func(t *testing.T) {
   998  			handshaker := newClientHandshaker(nil, nil, "", tc.options, tc.tokenManager)
   999  			authMechanisms := handshaker.getAuthMechanisms()
  1000  			if got, want := (authMechanisms == nil), (tc.expectedAuthMechanisms == nil); got != want {
  1001  				t.Errorf("authMechanisms == nil: %t, tc.expectedAuthMechanisms == nil: %t", got, want)
  1002  			}
  1003  			if authMechanisms != nil && tc.expectedAuthMechanisms != nil {
  1004  				if diff := cmp.Diff(authMechanisms, tc.expectedAuthMechanisms, protocmp.Transform(), sortProtos); diff != "" {
  1005  					t.Errorf("handshaker.getAuthMechanisms() returned incorrect slice, (-want +got):\n%s", diff)
  1006  				}
  1007  			}
  1008  		})
  1009  	}
  1010  }
  1011  
  1012  func TestGetAuthMechanismsForServer(t *testing.T) {
  1013  	sortProtos := cmpopts.SortSlices(func(m1, m2 *s2apb.AuthenticationMechanism) bool { return m1.String() < m2.String() })
  1014  	for _, tc := range []struct {
  1015  		description            string
  1016  		options                *ServerHandshakerOptions
  1017  		tokenManager           tokenmanager.AccessTokenManager
  1018  		expectedAuthMechanisms []*s2apb.AuthenticationMechanism
  1019  	}{
  1020  		{
  1021  			description:            "token manager is nil",
  1022  			tokenManager:           nil,
  1023  			expectedAuthMechanisms: nil,
  1024  		},
  1025  		{
  1026  			description: "token manager expects empty identity",
  1027  			tokenManager: &fakeAccessTokenManager{
  1028  				accessToken:        testAccessToken,
  1029  				allowEmptyIdentity: true,
  1030  			},
  1031  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
  1032  				{
  1033  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
  1034  						Token: testAccessToken,
  1035  					},
  1036  				},
  1037  			},
  1038  		},
  1039  		{
  1040  			description: "token manager does not expect empty identity",
  1041  			tokenManager: &fakeAccessTokenManager{
  1042  				allowEmptyIdentity: false,
  1043  			},
  1044  			expectedAuthMechanisms: nil,
  1045  		},
  1046  		{
  1047  			description: "token manager expects 2 SPIFFE IDs",
  1048  			options: &ServerHandshakerOptions{
  1049  				LocalIdentities: []*commonpb.Identity{
  1050  					{
  1051  						IdentityOneof: &commonpb.Identity_SpiffeId{
  1052  							SpiffeId: "allowed_spiffe_id",
  1053  						},
  1054  					},
  1055  					{
  1056  						IdentityOneof: &commonpb.Identity_SpiffeId{
  1057  							SpiffeId: "allowed_spiffe_id",
  1058  						},
  1059  					},
  1060  				},
  1061  			},
  1062  			tokenManager: &fakeAccessTokenManager{
  1063  				accessToken: testAccessToken,
  1064  				acceptedIdentity: &commonpb.Identity{
  1065  					IdentityOneof: &commonpb.Identity_SpiffeId{
  1066  						SpiffeId: "allowed_spiffe_id",
  1067  					},
  1068  				},
  1069  			},
  1070  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
  1071  				{
  1072  					Identity: &commonpb.Identity{
  1073  						IdentityOneof: &commonpb.Identity_SpiffeId{
  1074  							SpiffeId: "allowed_spiffe_id",
  1075  						},
  1076  					},
  1077  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
  1078  						Token: testAccessToken,
  1079  					},
  1080  				},
  1081  				{
  1082  					Identity: &commonpb.Identity{
  1083  						IdentityOneof: &commonpb.Identity_SpiffeId{
  1084  							SpiffeId: "allowed_spiffe_id",
  1085  						},
  1086  					},
  1087  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
  1088  						Token: testAccessToken,
  1089  					},
  1090  				},
  1091  			},
  1092  		},
  1093  		{
  1094  			description: "token manager expects a SPIFFE ID but does not expect hostname",
  1095  			options: &ServerHandshakerOptions{
  1096  				LocalIdentities: []*commonpb.Identity{
  1097  					{
  1098  						IdentityOneof: &commonpb.Identity_SpiffeId{
  1099  							SpiffeId: "allowed_spiffe_id",
  1100  						},
  1101  					},
  1102  					{
  1103  						IdentityOneof: &commonpb.Identity_Hostname{
  1104  							Hostname: "disallowed_hostname",
  1105  						},
  1106  					},
  1107  				},
  1108  			},
  1109  			tokenManager: &fakeAccessTokenManager{
  1110  				accessToken: testAccessToken,
  1111  				acceptedIdentity: &commonpb.Identity{
  1112  					IdentityOneof: &commonpb.Identity_SpiffeId{
  1113  						SpiffeId: "allowed_spiffe_id",
  1114  					},
  1115  				},
  1116  			},
  1117  			expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
  1118  				{
  1119  					Identity: &commonpb.Identity{
  1120  						IdentityOneof: &commonpb.Identity_SpiffeId{
  1121  							SpiffeId: "allowed_spiffe_id",
  1122  						},
  1123  					},
  1124  					MechanismOneof: &s2apb.AuthenticationMechanism_Token{
  1125  						Token: testAccessToken,
  1126  					},
  1127  				},
  1128  			},
  1129  		},
  1130  	} {
  1131  		t.Run(tc.description, func(t *testing.T) {
  1132  			handshaker := newServerHandshaker(nil, nil, "", tc.options, tc.tokenManager)
  1133  			authMechanisms := handshaker.getAuthMechanisms()
  1134  			if got, want := (authMechanisms == nil), (tc.expectedAuthMechanisms == nil); got != want {
  1135  				t.Errorf("authMechanisms == nil: %t, tc.expectedAuthMechanisms == nil: %t", got, want)
  1136  			}
  1137  			if authMechanisms != nil && tc.expectedAuthMechanisms != nil {
  1138  				if diff := cmp.Diff(authMechanisms, tc.expectedAuthMechanisms, protocmp.Transform(), sortProtos); diff != "" {
  1139  					t.Errorf("handshaker.getAuthMechanisms() returned incorrect slice, (-want +got):\n%s", diff)
  1140  				}
  1141  			}
  1142  		})
  1143  	}
  1144  }
  1145  

View as plain text