...

Source file src/github.com/google/s2a-go/internal/v2/s2av2_e2e_test.go

Documentation: github.com/google/s2a-go/internal/v2

     1  /*
     2   *
     3   * Copyright 2022 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package v2
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"crypto/tls"
    25  	"fmt"
    26  	"io/ioutil"
    27  	"net"
    28  	"os"
    29  	"path/filepath"
    30  	"testing"
    31  	"time"
    32  
    33  	_ "embed"
    34  
    35  	"github.com/google/s2a-go/fallback"
    36  	"github.com/google/s2a-go/internal/tokenmanager"
    37  	"github.com/google/s2a-go/internal/v2/fakes2av2"
    38  	"github.com/google/s2a-go/retry"
    39  	"google.golang.org/grpc/credentials"
    40  	"google.golang.org/grpc/grpclog"
    41  
    42  	grpc "google.golang.org/grpc"
    43  
    44  	commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
    45  	helloworldpb "github.com/google/s2a-go/internal/proto/examples/helloworld_go_proto"
    46  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    47  )
    48  
    49  const (
    50  	accessTokenEnvVariable = "S2A_ACCESS_TOKEN"
    51  	defaultE2ETimeout      = time.Second * 5
    52  	clientMessage          = "echo"
    53  )
    54  
    55  var (
    56  	//go:embed testdata/client_cert.pem
    57  	clientCertpem []byte
    58  	//go:embed testdata/client_key.pem
    59  	clientKeypem []byte
    60  	//go:embed testdata/server_cert.pem
    61  	serverCertpem []byte
    62  	//go:embed testdata/server_key.pem
    63  	serverKeypem []byte
    64  )
    65  
    66  // server implements the helloworld.GreeterServer.
    67  type server struct {
    68  	helloworldpb.UnimplementedGreeterServer
    69  }
    70  
    71  // SayHello implements helloworld.GreeterServer.
    72  func (s *server) SayHello(_ context.Context, in *helloworldpb.HelloRequest) (*helloworldpb.HelloReply, error) {
    73  	return &helloworldpb.HelloReply{Message: "Hello " + in.GetName()}, nil
    74  }
    75  
    76  // startFakeS2A starts up a fake S2A and returns the address that it is
    77  // listening on.
    78  func startFakeS2A(t *testing.T, expToken string) string {
    79  	lis, err := net.Listen("tcp", ":")
    80  	if err != nil {
    81  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
    82  	}
    83  	s := grpc.NewServer()
    84  	s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
    85  	go func() {
    86  		if err := s.Serve(lis); err != nil {
    87  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
    88  		}
    89  	}()
    90  	return lis.Addr().String()
    91  }
    92  
    93  // startFakeS2A starts up a fake S2A on UDS and returns the address that it is
    94  // listening on.
    95  func startFakeS2AOnUDS(t *testing.T, expToken string) string {
    96  	dir, err := ioutil.TempDir("/tmp", "socket_dir")
    97  	if err != nil {
    98  		t.Errorf("Unable to create temporary directory: %v", err)
    99  	}
   100  	udsAddress := filepath.Join(dir, "socket")
   101  	lis, err := net.Listen("unix", filepath.Join(dir, "socket"))
   102  	if err != nil {
   103  		t.Errorf("net.Listen(unix, %s) failed: %v", udsAddress, err)
   104  	}
   105  	s := grpc.NewServer()
   106  	s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{ExpectedToken: expToken})
   107  	go func() {
   108  		if err := s.Serve(lis); err != nil {
   109  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   110  		}
   111  	}()
   112  	return fmt.Sprintf("unix://%s", lis.Addr().String())
   113  }
   114  
   115  // startServer starts up a server and returns the address that it is listening
   116  // on.
   117  func startServer(t *testing.T, s2aAddress string, localIdentities []*commonpbv1.Identity) string {
   118  	// TODO(rmehta19): Pass verificationMode as a parameter to startServer.
   119  	creds, err := NewServerCreds(s2aAddress, nil, localIdentities, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil)
   120  	if err != nil {
   121  		t.Errorf("NewServerCreds(%s) failed: %v", s2aAddress, err)
   122  	}
   123  
   124  	lis, err := net.Listen("tcp", ":0")
   125  	if err != nil {
   126  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
   127  	}
   128  	s := grpc.NewServer(grpc.Creds(creds))
   129  	helloworldpb.RegisterGreeterServer(s, &server{})
   130  	go func() {
   131  		if err := s.Serve(lis); err != nil {
   132  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   133  		}
   134  	}()
   135  	return lis.Addr().String()
   136  }
   137  
   138  // startFallbackServer runs a GRPC echo testing server and returns the address.
   139  // It's used to test the default fallback logic upon S2A failure.
   140  func startFallbackServer(t *testing.T) string {
   141  	lis, err := net.Listen("tcp", ":0")
   142  	if err != nil {
   143  		t.Errorf("net.Listen(tcp, :0) failed: %v", err)
   144  	}
   145  	cert, err := tls.X509KeyPair(serverCertpem, serverKeypem)
   146  	if err != nil {
   147  		t.Errorf("failure initializing tls.certificate: %v", err)
   148  	}
   149  	// Client certs are not required for the fallback server.
   150  	creds := credentials.NewTLS(&tls.Config{
   151  		MinVersion:   tls.VersionTLS13,
   152  		MaxVersion:   tls.VersionTLS13,
   153  		Certificates: []tls.Certificate{cert},
   154  	})
   155  	s := grpc.NewServer(grpc.Creds(creds))
   156  	helloworldpb.RegisterGreeterServer(s, &server{})
   157  	go func() {
   158  		if err := s.Serve(lis); err != nil {
   159  			t.Errorf("s.Serve(%v) failed: %v", lis, err)
   160  		}
   161  	}()
   162  	return lis.Addr().String()
   163  }
   164  
   165  // runClient starts up a client and calls the server.
   166  func runClient(ctx context.Context, t *testing.T, clientS2AAddress, serverAddr string, localIdentity *commonpbv1.Identity, fallbackHandshake fallback.ClientHandshake) {
   167  	creds, err := NewClientCreds(clientS2AAddress, nil, localIdentity, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, fallbackHandshake, nil, nil)
   168  	if err != nil {
   169  		t.Errorf("NewClientCreds(%s) failed: %v", clientS2AAddress, err)
   170  	}
   171  	dialOptions := []grpc.DialOption{
   172  		grpc.WithTransportCredentials(creds),
   173  		grpc.WithBlock(),
   174  	}
   175  
   176  	grpclog.Info("Client dialing server at address: %v", serverAddr)
   177  	// Establish a connection to the server.
   178  	conn, err := grpc.Dial(serverAddr, dialOptions...)
   179  	if err != nil {
   180  		t.Errorf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err)
   181  	}
   182  	defer conn.Close()
   183  
   184  	// Contact the server.
   185  	c := helloworldpb.NewGreeterClient(conn)
   186  	req := &helloworldpb.HelloRequest{Name: clientMessage}
   187  	grpclog.Infof("Client calling SayHello with request: %v", req)
   188  	resp, err := c.SayHello(ctx, req, grpc.WaitForReady(true))
   189  	if err != nil {
   190  		t.Errorf("c.SayHello(%v, %v) failed: %v", ctx, req, err)
   191  	}
   192  	if got, want := resp.GetMessage(), "Hello "+clientMessage; got != want {
   193  		t.Errorf("r.GetMessage() = %v, want %v", got, want)
   194  	}
   195  	grpclog.Infof("Client received message from server: %s", resp.GetMessage())
   196  }
   197  
   198  func TestEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   199  	os.Setenv(accessTokenEnvVariable, "TestE2ETCP_token")
   200  	oldRetry := retry.NewRetryer
   201  	defer func() { retry.NewRetryer = oldRetry }()
   202  	testRetryer := retry.NewRetryer()
   203  	retry.NewRetryer = func() *retry.S2ARetryer {
   204  		return testRetryer
   205  	}
   206  	// Start the fake S2As for the client and server.
   207  	serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   208  	grpclog.Infof("Fake handshaker for server running at address: %v", serverS2AAddr)
   209  	clientS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   210  	grpclog.Infof("Fake handshaker for client running at address: %v", clientS2AAddr)
   211  
   212  	// Start the server.
   213  	localIdentities := []*commonpbv1.Identity{
   214  		{
   215  			IdentityOneof: &commonpbv1.Identity_Hostname{
   216  				Hostname: "test_rsa_server_identity",
   217  			},
   218  		},
   219  	}
   220  	serverAddr := startServer(t, serverS2AAddr, localIdentities)
   221  	grpclog.Infof("Server running at address: %v", serverAddr)
   222  
   223  	// Finally, start up the client.
   224  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   225  	defer cancel()
   226  	runClient(ctx, t, clientS2AAddr, serverAddr, &commonpbv1.Identity{
   227  		IdentityOneof: &commonpbv1.Identity_Hostname{
   228  			Hostname: "test_rsa_client_identity",
   229  		},
   230  	}, nil)
   231  	if got, want := testRetryer.Attempts(), 0; got != want {
   232  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   233  	}
   234  }
   235  
   236  func TestEndToEndUsingFakeS2AOverTCPEmptyId(t *testing.T) {
   237  	os.Setenv(accessTokenEnvVariable, "TestE2ETCP_token")
   238  	// Start the fake S2As for the client and server.
   239  	serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   240  	grpclog.Infof("Fake handshaker for server running at address: %v", serverS2AAddr)
   241  	clientS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   242  	grpclog.Infof("Fake handshaker for client running at address: %v", clientS2AAddr)
   243  
   244  	// Start the server.
   245  	var localIdentities []*commonpbv1.Identity
   246  	localIdentities = append(localIdentities, nil)
   247  	serverAddr := startServer(t, serverS2AAddr, localIdentities)
   248  	grpclog.Infof("Server running at address: %v", serverAddr)
   249  
   250  	// Finally, start up the client.
   251  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   252  	defer cancel()
   253  	runClient(ctx, t, clientS2AAddr, serverAddr, nil, nil)
   254  }
   255  
   256  func TestEndToEndUsingFakeS2AOnUDS(t *testing.T) {
   257  	os.Setenv(accessTokenEnvVariable, "TestE2EUDS_token")
   258  	// Start fake S2As for use by the client and server.
   259  	serverS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
   260  	grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddr)
   261  	clientS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
   262  	grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddr)
   263  
   264  	// Start the server.
   265  	localIdentities := []*commonpbv1.Identity{
   266  		{
   267  			IdentityOneof: &commonpbv1.Identity_Hostname{
   268  				Hostname: "test_rsa_server_identity",
   269  			},
   270  		},
   271  	}
   272  	serverAddr := startServer(t, serverS2AAddr, localIdentities)
   273  	grpclog.Infof("Server running at address: %v", serverAddr)
   274  
   275  	// Finally, start up the client.
   276  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   277  	defer cancel()
   278  	runClient(ctx, t, clientS2AAddr, serverAddr, &commonpbv1.Identity{
   279  		IdentityOneof: &commonpbv1.Identity_Hostname{
   280  			Hostname: "test_rsa_client_identity",
   281  		},
   282  	}, nil)
   283  }
   284  
   285  func TestEndToEndUsingFakeS2AOnUDSEmptyId(t *testing.T) {
   286  	os.Setenv(accessTokenEnvVariable, "TestE2EUDS_token")
   287  	// Start fake S2As for use by the client and server.
   288  	serverS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
   289  	grpclog.Infof("Fake S2A for server listening on UDS at address: %v", serverS2AAddr)
   290  	clientS2AAddr := startFakeS2AOnUDS(t, "TestE2EUDS_token")
   291  	grpclog.Infof("Fake S2A for client listening on UDS at address: %v", clientS2AAddr)
   292  
   293  	// Start the server.
   294  	var localIdentities []*commonpbv1.Identity
   295  	localIdentities = append(localIdentities, nil)
   296  	serverAddr := startServer(t, serverS2AAddr, localIdentities)
   297  	grpclog.Infof("Server running at address: %v", serverAddr)
   298  
   299  	// Finally, start up the client.
   300  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   301  	defer cancel()
   302  	runClient(ctx, t, clientS2AAddr, serverAddr, nil, nil)
   303  }
   304  
   305  func TestGRPCFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   306  	// Set for testing only.
   307  	fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
   308  	os.Setenv(accessTokenEnvVariable, "TestE2ETCP_token")
   309  	oldRetry := retry.NewRetryer
   310  	defer func() { retry.NewRetryer = oldRetry }()
   311  	testRetryer := retry.NewRetryer()
   312  	retry.NewRetryer = func() *retry.S2ARetryer {
   313  		return testRetryer
   314  	}
   315  
   316  	// Start the fake S2A for the server.
   317  	serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   318  	t.Logf("Fake handshaker for server running at address: %v", serverS2AAddr)
   319  
   320  	// Start the server.
   321  	localIdentities := []*commonpbv1.Identity{
   322  		{
   323  			IdentityOneof: &commonpbv1.Identity_Hostname{
   324  				Hostname: "test_rsa_server_identity",
   325  			},
   326  		},
   327  	}
   328  	serverAddr := startServer(t, serverS2AAddr, localIdentities)
   329  	fallbackServerAddr := startFallbackServer(t)
   330  	t.Logf("server running at address: %v", serverAddr)
   331  	t.Logf("fallback server running at address: %v", fallbackServerAddr)
   332  
   333  	// Finally, start up the client.
   334  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   335  	defer cancel()
   336  	fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
   337  	if err != nil {
   338  		t.Errorf("error creating fallback handshake function: %v", err)
   339  	}
   340  	fallbackCalled := false
   341  	fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
   342  		fallbackCalled = true
   343  		return fallbackHandshake(ctx, targetServer, conn, err)
   344  	}
   345  	// Set wrong S2A address for client to trigger S2A failure and fallback.
   346  	runClient(ctx, t, "not_exist", serverAddr, &commonpbv1.Identity{
   347  		IdentityOneof: &commonpbv1.Identity_Hostname{
   348  			Hostname: "test_rsa_client_identity",
   349  		},
   350  	}, fallbackHandshakeWrapper)
   351  
   352  	if !fallbackCalled {
   353  		t.Errorf("fallbackHandshake is not called")
   354  	}
   355  	if got, want := testRetryer.Attempts(), 5; got != want {
   356  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   357  	}
   358  }
   359  
   360  func TestGRPCRetryAndFallbackEndToEndUsingFakeS2AOverTCP(t *testing.T) {
   361  	// Set for testing only.
   362  	fallback.FallbackTLSConfigGRPC.InsecureSkipVerify = true
   363  	// Set an invalid token to trigger failures and retries when talking to S2A.
   364  	os.Setenv(accessTokenEnvVariable, "invalid_token")
   365  	oldRetry := retry.NewRetryer
   366  	defer func() { retry.NewRetryer = oldRetry }()
   367  	testRetryer := retry.NewRetryer()
   368  	retry.NewRetryer = func() *retry.S2ARetryer {
   369  		return testRetryer
   370  	}
   371  
   372  	clientS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   373  	grpclog.Infof("Fake handshaker for client running at address: %v", clientS2AAddr)
   374  	serverS2AAddr := startFakeS2A(t, "TestE2ETCP_token")
   375  	grpclog.Infof("Fake handshaker for server running at address: %v", serverS2AAddr)
   376  
   377  	// Start the server.
   378  	localIdentities := []*commonpbv1.Identity{
   379  		{
   380  			IdentityOneof: &commonpbv1.Identity_Hostname{
   381  				Hostname: "test_rsa_server_identity",
   382  			},
   383  		},
   384  	}
   385  	serverAddr := startServer(t, serverS2AAddr, localIdentities)
   386  	fallbackServerAddr := startFallbackServer(t)
   387  	t.Logf("server running at address: %v", serverAddr)
   388  	t.Logf("fallback server running at address: %v", fallbackServerAddr)
   389  
   390  	// Finally, start up the client.
   391  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   392  	defer cancel()
   393  	fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(fallbackServerAddr)
   394  	if err != nil {
   395  		t.Errorf("error creating fallback handshake function: %v", err)
   396  	}
   397  	fallbackCalled := false
   398  	fallbackHandshakeWrapper := func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error) {
   399  		fallbackCalled = true
   400  		return fallbackHandshake(ctx, targetServer, conn, err)
   401  	}
   402  	runClient(ctx, t, clientS2AAddr, serverAddr, &commonpbv1.Identity{
   403  		IdentityOneof: &commonpbv1.Identity_Hostname{
   404  			Hostname: "test_rsa_client_identity",
   405  		},
   406  	}, fallbackHandshakeWrapper)
   407  
   408  	if !fallbackCalled {
   409  		t.Errorf("fallbackHandshake is not called")
   410  	}
   411  	if got, want := testRetryer.Attempts(), 5; got != want {
   412  		t.Errorf("expecting retryer attempts count:[%v], got [%v]", want, got)
   413  	}
   414  }
   415  
   416  func TestNewClientTlsConfigWithTokenManager(t *testing.T) {
   417  	os.Setenv(accessTokenEnvVariable, "TestNewClientTlsConfig_token")
   418  	s2AAddr := startFakeS2A(t, "TestNewClientTlsConfig_token")
   419  	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
   420  	if err != nil {
   421  		t.Errorf("tokenmanager.NewSingleTokenAccessTokenManager() failed: %v", err)
   422  	}
   423  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   424  	defer cancel()
   425  	config, err := NewClientTLSConfig(ctx, s2AAddr, nil, accessTokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil)
   426  	if err != nil {
   427  		t.Errorf("NewClientTLSConfig() failed: %v", err)
   428  	}
   429  
   430  	cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
   431  	if err != nil {
   432  		t.Fatalf("tls.X509KeyPair failed: %v", err)
   433  	}
   434  	if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
   435  		t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
   436  	}
   437  }
   438  
   439  func TestNewClientTlsConfigWithoutTokenManager(t *testing.T) {
   440  	os.Unsetenv(accessTokenEnvVariable)
   441  	s2AAddr := startFakeS2A(t, "ignored-value")
   442  	var tokenManager tokenmanager.AccessTokenManager
   443  	ctx, cancel := context.WithTimeout(context.Background(), defaultE2ETimeout)
   444  	defer cancel()
   445  	config, err := NewClientTLSConfig(ctx, s2AAddr, nil, tokenManager, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, "test_server_name", nil)
   446  	if err != nil {
   447  		t.Errorf("NewClientTLSConfig() failed: %v", err)
   448  	}
   449  
   450  	cert, err := tls.X509KeyPair(clientCertpem, clientKeypem)
   451  	if err != nil {
   452  		t.Fatalf("tls.X509KeyPair failed: %v", err)
   453  	}
   454  	if got, want := config.Certificates[0].Certificate[0], cert.Certificate[0]; !bytes.Equal(got, want) {
   455  		t.Errorf("tls.Config has unexpected certificate: got: %v, want: %v", got, want)
   456  	}
   457  }
   458  

View as plain text