...

Source file src/github.com/google/s2a-go/s2a_e2e_test.go

Documentation: github.com/google/s2a-go

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package s2a
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"crypto/tls"
    25  	"crypto/x509"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"net"
    30  	"net/http"
    31  	"os"
    32  	"path/filepath"
    33  	"testing"
    34  	"time"
    35  
    36  	_ "embed"
    37  
    38  	"github.com/google/s2a-go/fallback"
    39  	"github.com/google/s2a-go/internal/fakehandshaker/service"
    40  	"github.com/google/s2a-go/internal/v2/fakes2av2"
    41  	"github.com/google/s2a-go/retry"
    42  	"google.golang.org/grpc/credentials"
    43  	"google.golang.org/grpc/grpclog"
    44  	"google.golang.org/grpc/peer"
    45  
    46  	grpc "google.golang.org/grpc"
    47  
    48  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    49  	helloworldpb "github.com/google/s2a-go/internal/proto/examples/helloworld_go_proto"
    50  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    51  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    52  )
    53  
    54  const (
    55  	accessTokenEnvVariable = "S2A_ACCESS_TOKEN"
    56  	testAccessToken        = "test_access_token"
    57  	testV2AccessToken      = "valid_token"
    58  
    59  	applicationProtocol   = "grpc"
    60  	authType              = "s2a"
    61  	clientHostname        = "test_client_hostname"
    62  	serverSpiffeID        = "test_server_spiffe_id"
    63  	clientMessage         = "echo"
    64  	defaultE2ETestTimeout = time.Second * 5
    65  )
    66  
    67  var (
    68  	//go:embed testdata/client_cert.pem
    69  	clientCertpem []byte
    70  	//go:embed testdata/client_key.pem
    71  	clientKeypem []byte
    72  	//go:embed testdata/server_cert.pem
    73  	serverCertpem []byte
    74  	//go:embed testdata/server_key.pem
    75  	serverKeypem []byte
    76  
    77  	//go:embed testdata/mds_root_cert.pem
    78  	mdsRootCertPem []byte
    79  	//go:embed testdata/mds_server_cert.pem
    80  	mdsServerCertPem []byte
    81  	//go:embed testdata/mds_server_key.pem
    82  	mdsServerKeyPem []byte
    83  	//go:embed testdata/mds_client_cert.pem
    84  	mdsClientCertPem []byte
    85  	//go:embed testdata/mds_client_key.pem
    86  	mdsClientKeyPem []byte
    87  	//go:embed testdata/self_signed_cert.pem
    88  	selfSignedCertPem []byte
    89  	//go:embed testdata/self_signed_key.pem
    90  	selfSignedKeyPem []byte
    91  )
    92  
    93  // server is used to implement helloworld.GreeterServer.
    94  type server struct {
    95  	helloworldpb.UnimplementedGreeterServer
    96  }
    97  
    98  // SayHello implements helloworld.GreeterServer.
    99  func (s *server) SayHello(_ context.Context, in *helloworldpb.HelloRequest) (*helloworldpb.HelloReply, error) {
   100  	return &helloworldpb.HelloReply{Message: "Hello " + in.GetName()}, nil
   101  }
   102  
   103  // startFakeS2A starts up a fake S2A and returns the address that it is
   104  // listening on.
   105  func startFakeS2A(t *testing.T, enableLegacyMode bool, expToken string, serverTransportCreds credentials.TransportCredentials) string {
   106  	lis, err := net.Listen("tcp", ":")
   107  	if err != nil {
   108  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
   109  	}
   110  
   111  	var s *grpc.Server
   112  	if serverTransportCreds != nil {
   113  		s = grpc.NewServer(grpc.Creds(serverTransportCreds))
   114  	} else {
   115  		s = grpc.NewServer()
   116  	}
   117  
   118  	if enableLegacyMode {
   119  		s2apb.RegisterS2AServiceServer(s, &service.FakeHandshakerService{})
   120  	} else {
   121  		s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
   122  	}
   123  	go func() {
   124  		if err := s.Serve(lis); err != nil {
   125  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   126  		}
   127  	}()
   128  	return lis.Addr().String()
   129  }
   130  
   131  // startFakeS2AOnUDS starts up a fake S2A on UDS and returns the address that
   132  // it is listening on.
   133  func startFakeS2AOnUDS(t *testing.T, enableLegacyMode bool, expToken string) string {
   134  	dir, err := ioutil.TempDir("/tmp", "socket_dir")
   135  	if err != nil {
   136  		t.Errorf("Unable to create temporary directory: %v", err)
   137  	}
   138  	udsAddress := filepath.Join(dir, "socket")
   139  	lis, err := net.Listen("unix", filepath.Join(dir, "socket"))
   140  	if err != nil {
   141  		t.Errorf("net.Listen(unix, %s) failed: %v", udsAddress, err)
   142  	}
   143  	s := grpc.NewServer()
   144  	if enableLegacyMode {
   145  		s2apb.RegisterS2AServiceServer(s, &service.FakeHandshakerService{})
   146  	} else {
   147  		s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
   148  	}
   149  	go func() {
   150  		if err := s.Serve(lis); err != nil {
   151  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   152  		}
   153  	}()
   154  	return fmt.Sprintf("unix://%s", lis.Addr().String())
   155  }
   156  
   157  // startServer starts up a server and returns the address that it is listening
   158  // on.
   159  func startServer(t *testing.T, s2aAddress string, transportCreds credentials.TransportCredentials, enableLegacyMode bool) string {
   160  	serverOpts := &ServerOptions{
   161  		LocalIdentities:  []Identity{NewSpiffeID(serverSpiffeID)},
   162  		S2AAddress:       s2aAddress,
   163  		TransportCreds:   transportCreds,
   164  		EnableLegacyMode: enableLegacyMode,
   165  	}
   166  	creds, err := NewServerCreds(serverOpts)
   167  	if err != nil {
   168  		t.Errorf("NewServerCreds(%v) failed: %v", serverOpts, err)
   169  	}
   170  
   171  	lis, err := net.Listen("tcp", ":0")
   172  	if err != nil {
   173  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
   174  	}
   175  	s := grpc.NewServer(grpc.Creds(creds))
   176  	helloworldpb.RegisterGreeterServer(s, &server{})
   177  	go func() {
   178  		if err := s.Serve(lis); err != nil {
   179  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   180  		}
   181  	}()
   182  	return lis.Addr().String()
   183  }
   184  
   185  // runClient starts up a client and calls the server.
   186  func runClient(ctx context.Context, t *testing.T, clientS2AAddress string, transportCreds credentials.TransportCredentials, serverAddr string, enableLegacyMode bool, fallbackHandshake fallback.ClientHandshake) {
   187  	clientOpts := &ClientOptions{
   188  		TargetIdentities: []Identity{NewSpiffeID(serverSpiffeID)},
   189  		LocalIdentity:    NewHostname(clientHostname),
   190  		S2AAddress:       clientS2AAddress,
   191  		TransportCreds:   transportCreds,
   192  		EnableLegacyMode: enableLegacyMode,
   193  		FallbackOpts: &FallbackOptions{
   194  			FallbackClientHandshakeFunc: fallbackHandshake,
   195  		},
   196  	}
   197  	creds, err := NewClientCreds(clientOpts)
   198  	if err != nil {
   199  		t.Errorf("NewClientCreds(%v) failed: %v", clientOpts, err)
   200  	}
   201  	dialOptions := []grpc.DialOption{
   202  		grpc.WithTransportCredentials(creds),
   203  		grpc.WithBlock(),
   204  	}
   205  
   206  	grpclog.Info("Client dialing server at address: %v", serverAddr)
   207  	// Establish a connection to the server.
   208  	conn, err := grpc.Dial(serverAddr, dialOptions...)
   209  	if err != nil {
   210  		t.Errorf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err)
   211  	}
   212  	defer conn.Close()
   213  
   214  	// Contact the server.
   215  	peer := new(peer.Peer)
   216  	c := helloworldpb.NewGreeterClient(conn)
   217  	req := &helloworldpb.HelloRequest{Name: clientMessage}
   218  	grpclog.Infof("Client calling SayHello with request: %v", req)
   219  	resp, err := c.SayHello(ctx, req, grpc.Peer(peer), grpc.WaitForReady(true))
   220  	if err != nil {
   221  		t.Errorf("c.SayHello(%v, %v) failed: %v", ctx, req, err)
   222  	}
   223  	if got, want := resp.GetMessage(), "Hello "+clientMessage; got != want {
   224  		t.Errorf("r.GetMessage() = %v, want %v", got, want)
   225  	}
   226  	grpclog.Infof("Client received message from server: %s", resp.GetMessage())
   227  
   228  	if enableLegacyMode {
   229  		// Check the auth info.
   230  		authInfo, err := AuthInfoFromPeer(peer)
   231  		if err != nil {
   232  			t.Errorf("AuthInfoFromContext(peer) failed: %v", err)
   233  		}
   234  		s2aAuthInfo, ok := authInfo.(AuthInfo)
   235  		if !ok {
   236  			t.Errorf("authInfo is not an s2a.AuthInfo")
   237  		}
   238  		if got, want := s2aAuthInfo.AuthType(), authType; got != want {
   239  			t.Errorf("s2aAuthInfo.AuthType() = %v, want %v", got, want)
   240  		}
   241  		if got, want := s2aAuthInfo.ApplicationProtocol(), applicationProtocol; got != want {
   242  			t.Errorf("s2aAuthInfo.ApplicationProtocol() = %v, want %v", got, want)
   243  		}
   244  		if got, want := s2aAuthInfo.TLSVersion(), commonpb.TLSVersion_TLS1_3; got != want {
   245  			t.Errorf("s2aAuthInfo.TLSVersion() = %v, want %v", got, want)
   246  		}
   247  		if got, want := s2aAuthInfo.IsHandshakeResumed(), false; got != want {
   248  			t.Errorf("s2aAuthInfo.IsHandshakeResumed() = %v, want %v", got, want)
   249  		}
   250  		if got, want := s2aAuthInfo.SecurityLevel(), credentials.PrivacyAndIntegrity; got != want {
   251  			t.Errorf("s2aAuthInfo.SecurityLevel() = %v, want %v", got, want)
   252  		}
   253  	}
   254  }
   255  
   256  func TestV1EndToEndUsingFakeS2AOverTCP(t *testing.T) {
   257  	os.Setenv(accessTokenEnvVariable, "")
   258  
   259  	// Start the fake S2As for the client and server.
   260  	serverHandshakerAddr := startFakeS2A(t, true, "", nil)
   261  	grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
   262  	clientHandshakerAddr := startFakeS2A(t, true, "", nil)
   263  	grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
   264  
   265  	// Start the server.
   266  	serverAddr := startServer(t, serverHandshakerAddr, nil, true)
   267  	grpclog.Infof("Server running at address: %v", serverAddr)
   268  
   269  	// Finally, start up the client.
   270  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   271  	defer cancel()
   272  	runClient(ctx, t, clientHandshakerAddr, nil, serverAddr, true, nil)
   273  }
   274  
   275  func TestV2EndToEndUsingFakeS2AOverTCP(t *testing.T) {
   276  	os.Setenv(accessTokenEnvVariable, testV2AccessToken)
   277  	oldRetry := retry.NewRetryer
   278  	defer func() { retry.NewRetryer = oldRetry }()
   279  	testRetryer := retry.NewRetryer()
   280  	retry.NewRetryer = func() *retry.S2ARetryer {
   281  		return testRetryer
   282  	}
   283  	// Start the fake S2As for the client and server.
   284  	serverHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   285  	grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
   286  	clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   287  	grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
   288  
   289  	// Start the server.
   290  	serverAddr := startServer(t, serverHandshakerAddr, nil, false)
   291  	grpclog.Infof("Server running at address: %v", serverAddr)
   292  
   293  	// Finally, start up the client.
   294  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   295  	defer cancel()
   296  	runClient(ctx, t, clientHandshakerAddr, nil, serverAddr, false, nil)
   297  	if got, want := testRetryer.Attempts(), 0; got != want {
   298  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   299  	}
   300  }
   301  
   302  func TestV2EndToEndUsingFakeMTLSS2AOverTCP(t *testing.T) {
   303  	os.Setenv(accessTokenEnvVariable, "")
   304  	oldRetry := retry.NewRetryer
   305  	defer func() { retry.NewRetryer = oldRetry }()
   306  	testRetryer := retry.NewRetryer()
   307  	retry.NewRetryer = func() *retry.S2ARetryer {
   308  		return testRetryer
   309  	}
   310  	serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
   311  	// Start the fake S2As for the client and server.
   312  	serverHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
   313  	grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
   314  	clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
   315  	grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
   316  
   317  	clientTransportCreds := loadClientTransportCreds(t, mdsClientCertPem, mdsClientKeyPem)
   318  	// Start the server.
   319  	serverAddr := startServer(t, serverHandshakerAddr, clientTransportCreds, false)
   320  	grpclog.Infof("Server running at address: %v", serverAddr)
   321  
   322  	// Finally, start up the client.
   323  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   324  	defer cancel()
   325  	runClient(ctx, t, clientHandshakerAddr, clientTransportCreds, serverAddr, false, nil)
   326  	if got, want := testRetryer.Attempts(), 0; got != want {
   327  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   328  	}
   329  }
   330  
   331  func TestV2EndToEndUsingFakeMTLSS2AOverTCP_SelfSignedClientTransportCreds(t *testing.T) {
   332  	os.Setenv(accessTokenEnvVariable, "")
   333  	fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
   334  	oldRetry := retry.NewRetryer
   335  	defer func() { retry.NewRetryer = oldRetry }()
   336  	testRetryer := retry.NewRetryer()
   337  	retry.NewRetryer = func() *retry.S2ARetryer {
   338  		return testRetryer
   339  	}
   340  	serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
   341  	// Start the fake S2As for the client and server.
   342  	serverHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
   343  	grpclog.Infof("Fake handshaker for server running at address: %v", serverHandshakerAddr)
   344  	clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
   345  	grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
   346  
   347  	clientTransportCreds := loadClientTransportCreds(t, mdsClientCertPem, mdsClientKeyPem)
   348  	// Load self-signed client credentials.
   349  	selfSignedClientTransportCreds := loadClientTransportCreds(t, selfSignedCertPem, selfSignedKeyPem)
   350  	// Start the server.
   351  	serverAddr := startServer(t, serverHandshakerAddr, clientTransportCreds, false)
   352  	fallbackServerAddr := startFallbackServer(t)
   353  	t.Logf("server running at address: %v", serverAddr)
   354  	t.Logf("fallback server running at address: %v", fallbackServerAddr)
   355  
   356  	// Finally, start up the client.
   357  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   358  	defer cancel()
   359  	fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
   360  	if err != nil {
   361  		t.Errorf("error creating fallback handshake function: %v", err)
   362  	}
   363  	fallbackCalled := false
   364  	fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
   365  		fallbackCalled = true
   366  		return fallbackHandshake(ctx, targetServer, conn, err)
   367  	}
   368  
   369  	// Use self-signed cert to trigger handshake failure when connecting to MTLS-S2A gRPC server.
   370  	// This should cause retries and eventually fallback.
   371  	runClient(ctx, t, clientHandshakerAddr, selfSignedClientTransportCreds, serverAddr, false, fallbackHandshakeWrapper)
   372  	if !fallbackCalled {
   373  		t.Errorf("fallbackHandshake is not called")
   374  	}
   375  	if got, want := testRetryer.Attempts(), 5; got != want {
   376  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   377  	}
   378  }
   379  
   380  func loadServerTransportCreds(t *testing.T, cert, key []byte) credentials.TransportCredentials {
   381  	certificate, err := tls.X509KeyPair(cert, key)
   382  	if err != nil {
   383  		t.Errorf("failed to load S2A server cert/key: %v", err)
   384  	}
   385  	caPool := x509.NewCertPool()
   386  	if !caPool.AppendCertsFromPEM(mdsRootCertPem) {
   387  		t.Errorf("failed to add ca cert")
   388  	}
   389  	tlsConfig := &tls.Config{
   390  		ClientAuth:   tls.RequireAndVerifyClientCert,
   391  		Certificates: []tls.Certificate{certificate},
   392  		ClientCAs:    caPool,
   393  	}
   394  	return credentials.NewTLS(tlsConfig)
   395  }
   396  
   397  func loadClientTransportCreds(t *testing.T, cert, key []byte) credentials.TransportCredentials {
   398  	certificate, err := tls.X509KeyPair(cert, key)
   399  	if err != nil {
   400  		t.Errorf("failed to load S2A client cert/key: %v", err)
   401  	}
   402  	caPool := x509.NewCertPool()
   403  	if !caPool.AppendCertsFromPEM(mdsRootCertPem) {
   404  		t.Errorf("failed to add ca cert")
   405  	}
   406  	tlsConfig := &tls.Config{
   407  		Certificates: []tls.Certificate{certificate},
   408  		RootCAs:      caPool,
   409  	}
   410  	return credentials.NewTLS(tlsConfig)
   411  }
   412  
   413  // startFallbackServer runs a GRPC echo testing server and returns the address.
   414  // It's used to test the default fallback logic upon S2A failure.
   415  func startFallbackServer(t *testing.T) string {
   416  	lis, err := net.Listen("tcp", ":0")
   417  	if err != nil {
   418  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
   419  	}
   420  	cert, err := tls.X509KeyPair(serverCertpem, serverKeypem)
   421  	if err != nil {
   422  		t.Errorf("failure initializing tls.certificate: %v", err)
   423  	}
   424  	// Client certs are not required for the fallback server.
   425  	creds := credentials.NewTLS(&tls.Config{
   426  		MinVersion:   tls.VersionTLS13,
   427  		MaxVersion:   tls.VersionTLS13,
   428  		Certificates: []tls.Certificate{cert},
   429  	})
   430  	s := grpc.NewServer(grpc.Creds(creds))
   431  	helloworldpb.RegisterGreeterServer(s, &server{})
   432  	go func() {
   433  		if err := s.Serve(lis); err != nil {
   434  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   435  		}
   436  	}()
   437  	return lis.Addr().String()
   438  }
   439  func TestV2GRPCFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   440  	// Set for testing only.
   441  	fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
   442  	os.Setenv(accessTokenEnvVariable, testV2AccessToken)
   443  	oldRetry := retry.NewRetryer
   444  	defer func() { retry.NewRetryer = oldRetry }()
   445  	testRetryer := retry.NewRetryer()
   446  	retry.NewRetryer = func() *retry.S2ARetryer {
   447  		return testRetryer
   448  	}
   449  	// Start the fake S2A for the server.
   450  	serverHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   451  	grpclog.Infof("fake handshaker for server running at address: %v", serverHandshakerAddr)
   452  
   453  	// Start the server.
   454  	serverAddr := startServer(t, serverHandshakerAddr, nil, false)
   455  	fallbackServerAddr := startFallbackServer(t)
   456  	t.Logf("server running at address: %v", serverAddr)
   457  	t.Logf("fallback server running at address: %v", fallbackServerAddr)
   458  
   459  	// Finally, start up the client.
   460  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   461  	defer cancel()
   462  	fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
   463  	if err != nil {
   464  		t.Errorf("error creating fallback handshake function: %v", err)
   465  	}
   466  	fallbackCalled := false
   467  	fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
   468  		fallbackCalled = true
   469  		return fallbackHandshake(ctx, targetServer, conn, err)
   470  	}
   471  	runClient(ctx, t, "not_exist", nil, serverAddr, false, fallbackHandshakeWrapper)
   472  	if !fallbackCalled {
   473  		t.Errorf("fallbackHandshake is not called")
   474  	}
   475  	if got, want := testRetryer.Attempts(), 5; got != want {
   476  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   477  	}
   478  }
   479  
   480  func TestV2GRPCRetryAndFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   481  	// Set for testing only.
   482  	fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
   483  	// Set an invalid token to trigger failures and retries when talking to S2A.
   484  	os.Setenv(accessTokenEnvVariable, "invalid_token")
   485  	oldRetry := retry.NewRetryer
   486  	defer func() { retry.NewRetryer = oldRetry }()
   487  	testRetryer := retry.NewRetryer()
   488  	retry.NewRetryer = func() *retry.S2ARetryer {
   489  		return testRetryer
   490  	}
   491  	// Start the fake S2A for the server and client.
   492  	serverHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   493  	grpclog.Infof("fake handshaker for server running at address: %v", serverHandshakerAddr)
   494  	clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   495  	grpclog.Infof("Fake handshaker for client running at address: %v", clientHandshakerAddr)
   496  
   497  	// Start the server.
   498  	serverAddr := startServer(t, serverHandshakerAddr, nil, false)
   499  	fallbackServerAddr := startFallbackServer(t)
   500  	t.Logf("server running at address: %v", serverAddr)
   501  	t.Logf("fallback server running at address: %v", fallbackServerAddr)
   502  
   503  	// Finally, start up the client.
   504  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   505  	defer cancel()
   506  	fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
   507  	if err != nil {
   508  		t.Errorf("error creating fallback handshake function: %v", err)
   509  	}
   510  	fallbackCalled := false
   511  	fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
   512  		fallbackCalled = true
   513  		return fallbackHandshake(ctx, targetServer, conn, err)
   514  	}
   515  	runClient(ctx, t, clientHandshakerAddr, nil, serverAddr, false, fallbackHandshakeWrapper)
   516  	if !fallbackCalled {
   517  		t.Errorf("fallbackHandshake is not called")
   518  	}
   519  	if got, want := testRetryer.Attempts(), 5; got != want {
   520  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   521  	}
   522  }
   523  
   524  func TestV1EndToEndUsingTokens(t *testing.T) {
   525  	os.Setenv(accessTokenEnvVariable, testAccessToken)
   526  
   527  	// Start the handshaker servers for the client and server.
   528  	serverS2AAddress := startFakeS2A(t, true, "", nil)
   529  	grpclog.Infof("Fake S2A for server running at address: %v", serverS2AAddress)
   530  	clientS2AAddress := startFakeS2A(t, true, "", nil)
   531  	grpclog.Infof("Fake S2A for client running at address: %v", clientS2AAddress)
   532  
   533  	// Start the server.
   534  	serverAddr := startServer(t, serverS2AAddress, nil, true)
   535  	grpclog.Infof("Server running at address: %v", serverAddr)
   536  
   537  	// Finally, start up the client.
   538  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   539  	defer cancel()
   540  	runClient(ctx, t, clientS2AAddress, nil, serverAddr, true, nil)
   541  }
   542  
   543  func TestV2EndToEndUsingTokens(t *testing.T) {
   544  	os.Setenv(accessTokenEnvVariable, testV2AccessToken)
   545  
   546  	// Start the handshaker servers for the client and server.
   547  	serverS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
   548  	grpclog.Infof("Fake S2A for server running at address: %v", serverS2AAddress)
   549  	clientS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
   550  	grpclog.Infof("Fake S2A for client running at address: %v", clientS2AAddress)
   551  
   552  	// Start the server.
   553  	serverAddr := startServer(t, serverS2AAddress, nil, false)
   554  	grpclog.Infof("Server running at address: %v", serverAddr)
   555  
   556  	// Finally, start up the client.
   557  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   558  	defer cancel()
   559  	runClient(ctx, t, clientS2AAddress, nil, serverAddr, false, nil)
   560  }
   561  
   562  func TestV2EndToEndEmptyToken(t *testing.T) {
   563  	os.Unsetenv(accessTokenEnvVariable)
   564  
   565  	// Start the handshaker servers for the client and server.
   566  	serverS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
   567  	grpclog.Infof("Fake S2A for server running at address: %v", serverS2AAddress)
   568  	clientS2AAddress := startFakeS2A(t, false, testV2AccessToken, nil)
   569  	grpclog.Infof("Fake S2A for client running at address: %v", clientS2AAddress)
   570  
   571  	// Start the server.
   572  	serverAddr := startServer(t, serverS2AAddress, nil, false)
   573  	grpclog.Infof("Server running at address: %v", serverAddr)
   574  
   575  	// Finally, start up the client.
   576  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   577  	defer cancel()
   578  	runClient(ctx, t, clientS2AAddress, nil, serverAddr, false, nil)
   579  }
   580  
   581  func TestV1EndToEndUsingFakeS2AOnUDS(t *testing.T) {
   582  	os.Setenv(accessTokenEnvVariable, "")
   583  
   584  	// Start fake S2As for use by the client and server.
   585  	serverS2AAddress := startFakeS2AOnUDS(t, true, "")
   586  	grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddress)
   587  	clientS2AAddress := startFakeS2AOnUDS(t, true, "")
   588  	grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddress)
   589  
   590  	// Start the server.
   591  	serverAddress := startServer(t, serverS2AAddress, nil, true)
   592  	grpclog.Infof("Server running at address: %v", serverS2AAddress)
   593  
   594  	// Finally, start up the client.
   595  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   596  	defer cancel()
   597  	runClient(ctx, t, clientS2AAddress, nil, serverAddress, true, nil)
   598  }
   599  
   600  func TestV2EndToEndUsingFakeS2AOnUDS(t *testing.T) {
   601  	os.Setenv(accessTokenEnvVariable, testV2AccessToken)
   602  
   603  	// Start fake S2As for use by the client and server.
   604  	serverS2AAddress := startFakeS2AOnUDS(t, false, testV2AccessToken)
   605  	grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddress)
   606  	clientS2AAddress := startFakeS2AOnUDS(t, false, testV2AccessToken)
   607  	grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddress)
   608  
   609  	// Start the server.
   610  	serverAddress := startServer(t, serverS2AAddress, nil, false)
   611  	grpclog.Infof("Server running at address: %v", serverS2AAddress)
   612  
   613  	// Finally, start up the client.
   614  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   615  	defer cancel()
   616  	runClient(ctx, t, clientS2AAddress, nil, serverAddress, false, nil)
   617  }
   618  
   619  func TestNewTLSClientConfigFactoryWithTokenManager(t *testing.T) {
   620  	os.Setenv(accessTokenEnvVariable, "TestNewTLSClientConfigFactory_token")
   621  	s2AAddr := startFakeS2A(t, false, "TestNewTLSClientConfigFactory_token", nil)
   622  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   623  	defer cancel()
   624  
   625  	factory, err := NewTLSClientConfigFactory(&ClientOptions{
   626  		S2AAddress: s2AAddr,
   627  	})
   628  	if err != nil {
   629  		t.Errorf("NewTLSClientConfigFactory() failed: %v", err)
   630  	}
   631  
   632  	config, err := factory.Build(ctx, nil)
   633  	if err != nil {
   634  		t.Errorf("Build tls config failed: %v", err)
   635  	}
   636  
   637  	cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
   638  	if err != nil {
   639  		t.Fatalf("tls.X509KeyPair failed: %v", err)
   640  	}
   641  
   642  	if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
   643  		t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
   644  	}
   645  }
   646  
   647  func TestNewTLSClientConfigFactoryWithoutTokenManager(t *testing.T) {
   648  	os.Unsetenv(accessTokenEnvVariable)
   649  	s2AAddr := startFakeS2A(t, false, "ignored-value", nil)
   650  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETestTimeout)
   651  	defer cancel()
   652  
   653  	factory, err := NewTLSClientConfigFactory(&ClientOptions{
   654  		S2AAddress: s2AAddr,
   655  	})
   656  	if err != nil {
   657  		t.Errorf("NewTLSClientConfigFactory() failed: %v", err)
   658  	}
   659  
   660  	config, err := factory.Build(ctx, nil)
   661  	if err != nil {
   662  		t.Errorf("Build tls config failed: %v", err)
   663  	}
   664  
   665  	cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
   666  	if err != nil {
   667  		t.Fatalf("tls.X509KeyPair failed: %v", err)
   668  	}
   669  	if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
   670  		t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
   671  	}
   672  }
   673  
   674  // startHTTPServer runs an HTTP server on a random local port and serves a /hello endpoint.
   675  // The response of the /hello endpoint should be passed in via the `resp` parameter.
   676  // It returns the address of the server.
   677  func startHTTPServer(t *testing.T, resp string) string {
   678  	cert, _ := tls.X509KeyPair(serverCertpem, serverKeypem)
   679  	tlsConfig := tls.Config{
   680  		MinVersion:   tls.VersionTLS13,
   681  		MaxVersion:   tls.VersionTLS13,
   682  		Certificates: []tls.Certificate{cert},
   683  	}
   684  	s := http.NewServeMux()
   685  	s.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
   686  		fmt.Fprintf(w, resp)
   687  	})
   688  	lis, err := tls.Listen("tcp", ":0", &tlsConfig)
   689  	if err != nil {
   690  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
   691  	}
   692  	go func() {
   693  		http.Serve(lis, s)
   694  	}()
   695  	return lis.Addr().String()
   696  }
   697  
   698  // runHTTPClient starts an HTTP client and talks to an HTTP server using S2A.
   699  // It returns the response from the /hello endpoint.
   700  func runHTTPClient(t *testing.T, clientS2AAddress string, transportCreds credentials.TransportCredentials, serverAddr string, fallbackOpts *FallbackOptions) string {
   701  	dialTLSContext := NewS2ADialTLSContextFunc(&ClientOptions{
   702  		S2AAddress:     clientS2AAddress,
   703  		TransportCreds: transportCreds,
   704  		FallbackOpts:   fallbackOpts,
   705  	})
   706  
   707  	tr := http.Transport{
   708  		DialTLSContext: dialTLSContext,
   709  	}
   710  
   711  	client := &http.Client{Transport: &tr}
   712  	reqURL := fmt.Sprintf("https://%s/hello", serverAddr)
   713  	t.Logf("reqURL is set to: %v", reqURL)
   714  	req, err := http.NewRequest(http.MethodGet, reqURL, nil)
   715  	if err != nil {
   716  		t.Errorf("error creating new HTTP request: %v", err)
   717  	}
   718  	resp, err := client.Do(req)
   719  	if err != nil {
   720  		t.Errorf("error making client HTTP request: %v", err)
   721  	}
   722  	respBody, err := io.ReadAll(resp.Body)
   723  	if err != nil {
   724  		t.Errorf("error reading HTTP response: %v", err)
   725  	}
   726  	return string(respBody)
   727  }
   728  func TestHTTPEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   729  	os.Setenv(accessTokenEnvVariable, testV2AccessToken)
   730  	oldRetry := retry.NewRetryer
   731  	defer func() { retry.NewRetryer = oldRetry }()
   732  	testRetryer := retry.NewRetryer()
   733  	retry.NewRetryer = func() *retry.S2ARetryer {
   734  		return testRetryer
   735  	}
   736  
   737  	// Start the fake S2As for the client.
   738  	clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   739  	t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
   740  
   741  	// Start the server.
   742  	serverAddr := startHTTPServer(t, "hello")
   743  	t.Logf("HTTP server running at address: %v", serverAddr)
   744  
   745  	// Finally, start up the client.
   746  	resp := runHTTPClient(t, clientHandshakerAddr, nil, serverAddr, nil)
   747  
   748  	if got, want := resp, "hello"; got != want {
   749  		t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
   750  	}
   751  	if got, want := testRetryer.Attempts(), 0; got != want {
   752  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   753  	}
   754  }
   755  
   756  func TestHTTPEndToEndSUsingFakeMTLSS2AOverTCP(t *testing.T) {
   757  	os.Setenv(accessTokenEnvVariable, "")
   758  	oldRetry := retry.NewRetryer
   759  	defer func() { retry.NewRetryer = oldRetry }()
   760  	testRetryer := retry.NewRetryer()
   761  	retry.NewRetryer = func() *retry.S2ARetryer {
   762  		return testRetryer
   763  	}
   764  
   765  	// Start the fake S2As for the client.
   766  	serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
   767  	clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
   768  	t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
   769  
   770  	// Start the server.
   771  	serverAddr := startHTTPServer(t, "hello")
   772  	t.Logf("HTTP server running at address: %v", serverAddr)
   773  
   774  	// Finally, start up the client.
   775  	clientTransportCreds := loadClientTransportCreds(t, mdsClientCertPem, mdsClientKeyPem)
   776  	resp := runHTTPClient(t, clientHandshakerAddr, clientTransportCreds, serverAddr, nil)
   777  
   778  	if got, want := resp, "hello"; got != want {
   779  		t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
   780  	}
   781  	if got, want := testRetryer.Attempts(), 0; got != want {
   782  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   783  	}
   784  }
   785  
   786  func TestHTTPEndToEndSUsingFakeMTLSS2AOverTCP_SelfSignedClientTransportCreds(t *testing.T) {
   787  	fallback.FallbackTLSConfigHTTP.InsecureSkipVerify = true
   788  	os.Setenv(accessTokenEnvVariable, "")
   789  	oldRetry := retry.NewRetryer
   790  	defer func() { retry.NewRetryer = oldRetry }()
   791  	testRetryer := retry.NewRetryer()
   792  	retry.NewRetryer = func() *retry.S2ARetryer {
   793  		return testRetryer
   794  	}
   795  
   796  	// Start the fake S2As for the client.
   797  	serverTransportCreds := loadServerTransportCreds(t, mdsServerCertPem, mdsServerKeyPem)
   798  	clientHandshakerAddr := startFakeS2A(t, false, "", serverTransportCreds)
   799  	t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
   800  
   801  	serverAddr := startHTTPServer(t, "hello")
   802  	t.Logf("HTTP server running at address: %v", serverAddr)
   803  
   804  	fallbackServerAddr := startHTTPServer(t, "hello fallback")
   805  	t.Logf("fallback HTTP server running at address: %v", fallbackServerAddr)
   806  
   807  	// Configure fallback options.
   808  	fbDialer, fbAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackServerAddr)
   809  	if err != nil {
   810  		t.Errorf("error creating fallback dialer: %v", err)
   811  	}
   812  	fallbackOpts := &FallbackOptions{
   813  		FallbackDialer: &FallbackDialer{
   814  			Dialer:     fbDialer,
   815  			ServerAddr: fbAddr,
   816  		},
   817  	}
   818  	// Load self-signed client credentials.
   819  	selfSignedClientTransportCreds := loadClientTransportCreds(t, selfSignedCertPem, selfSignedKeyPem)
   820  	// Use self-signed cert to trigger handshake failure when connecting to MTLS-S2A gRPC server.
   821  	// This should cause retries and eventually fallback.
   822  	resp := runHTTPClient(t, clientHandshakerAddr, selfSignedClientTransportCreds, serverAddr, fallbackOpts)
   823  	if got, want := resp, "hello fallback"; got != want {
   824  		t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
   825  	}
   826  
   827  	if got, want := testRetryer.Attempts(), 5; got != want {
   828  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   829  	}
   830  }
   831  
   832  func TestHTTPFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   833  	fallback.FallbackTLSConfigHTTP.InsecureSkipVerify = true
   834  	os.Setenv(accessTokenEnvVariable, testV2AccessToken)
   835  	oldRetry := retry.NewRetryer
   836  	defer func() { retry.NewRetryer = oldRetry }()
   837  	testRetryer := retry.NewRetryer()
   838  	retry.NewRetryer = func() *retry.S2ARetryer {
   839  		return testRetryer
   840  	}
   841  
   842  	// Start the server.
   843  	serverAddr := startHTTPServer(t, "hello")
   844  	t.Logf("HTTP server running at address: %v", serverAddr)
   845  
   846  	fallbackServerAddr := startHTTPServer(t, "hello fallback")
   847  	t.Logf("fallback HTTP server running at address: %v", fallbackServerAddr)
   848  
   849  	// Configure fallback options.
   850  	fbDialer, fbAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackServerAddr)
   851  	if err != nil {
   852  		t.Errorf("error creating fallback dialer: %v", err)
   853  	}
   854  
   855  	fallbackOpts := &FallbackOptions{
   856  		FallbackDialer: &FallbackDialer{
   857  			Dialer:     fbDialer,
   858  			ServerAddr: fbAddr,
   859  		},
   860  	}
   861  	// Set wrong client S2A address to trigger S2A failure and fallback.
   862  	resp := runHTTPClient(t, "not_exist", nil, serverAddr, fallbackOpts)
   863  
   864  	if got, want := resp, "hello fallback"; got != want {
   865  		t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
   866  	}
   867  
   868  	if got, want := testRetryer.Attempts(), 5; got != want {
   869  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   870  	}
   871  }
   872  
   873  func TestHTTPRetryAndFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   874  	fallback.FallbackTLSConfigHTTP.InsecureSkipVerify = true
   875  	// Set an invalid token to trigger failures and retries when talking to S2A.
   876  	os.Setenv(accessTokenEnvVariable, "invalid_token")
   877  	oldRetry := retry.NewRetryer
   878  	defer func() { retry.NewRetryer = oldRetry }()
   879  	testRetryer := retry.NewRetryer()
   880  	retry.NewRetryer = func() *retry.S2ARetryer {
   881  		return testRetryer
   882  	}
   883  
   884  	// Start the fake S2As for the client.
   885  	clientHandshakerAddr := startFakeS2A(t, false, testV2AccessToken, nil)
   886  	t.Logf("fake handshaker for client running at address: %v", clientHandshakerAddr)
   887  
   888  	serverAddr := startHTTPServer(t, "hello")
   889  	t.Logf("HTTP server running at address: %v", serverAddr)
   890  
   891  	fallbackServerAddr := startHTTPServer(t, "hello fallback")
   892  	t.Logf("fallback HTTP server running at address: %v", fallbackServerAddr)
   893  
   894  	// Configure fallback options.
   895  	fbDialer, fbAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackServerAddr)
   896  	if err != nil {
   897  		t.Errorf("error creating fallback dialer: %v", err)
   898  	}
   899  
   900  	fallbackOpts := &FallbackOptions{
   901  		FallbackDialer: &FallbackDialer{
   902  			Dialer:     fbDialer,
   903  			ServerAddr: fbAddr,
   904  		},
   905  	}
   906  
   907  	resp := runHTTPClient(t, clientHandshakerAddr, nil, serverAddr, fallbackOpts)
   908  
   909  	if got, want := resp, "hello fallback"; got != want {
   910  		t.Errorf("expecting HTTP response:[%s], got [%s]", want, got)
   911  	}
   912  
   913  	if got, want := testRetryer.Attempts(), 5; got != want {
   914  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   915  	}
   916  }
   917  

View as plain text