...

Source file src/github.com/google/s2a-go/internal/v2/s2av2_test.go

Documentation: github.com/google/s2a-go/internal/v2

     1  /*
     2   *
     3   * Copyright 2022 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package v2
    20  
    21  import (
    22  	"context"
    23  	"os"
    24  	"reflect"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/google/go-cmp/cmp"
    29  	"github.com/google/s2a-go/fallback"
    30  	"github.com/google/s2a-go/internal/tokenmanager"
    31  	"github.com/google/s2a-go/stream"
    32  	"google.golang.org/protobuf/testing/protocmp"
    33  
    34  	commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
    35  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    36  )
    37  
    38  var (
    39  	fakes2av2Address = "0.0.0.0:0"
    40  )
    41  
    42  func TestNewClientCreds(t *testing.T) {
    43  	os.Setenv("S2A_ACCESS_TOKEN", "TestNewClientCreds_s2a_access_token")
    44  	for _, tc := range []struct {
    45  		description string
    46  	}{
    47  		{
    48  			description: "static",
    49  		},
    50  	} {
    51  		t.Run(tc.description, func(t *testing.T) {
    52  			c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
    53  				IdentityOneof: &commonpbv1.Identity_Hostname{
    54  					Hostname: "test_rsa_client_identity",
    55  				},
    56  			}, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil, nil, nil)
    57  			if err != nil {
    58  				t.Fatalf("NewClientCreds() failed: %v", err)
    59  			}
    60  			if got, want := c.Info().SecurityProtocol, s2aSecurityProtocol; got != want {
    61  				t.Errorf("c.Info().SecurityProtocol = %v, want %v", got, want)
    62  			}
    63  			_, ok := c.(*s2av2TransportCreds)
    64  			if !ok {
    65  				t.Fatal("The created creds is not of type s2av2TransportCreds")
    66  			}
    67  		})
    68  	}
    69  }
    70  
    71  func TestNewServerCreds(t *testing.T) {
    72  	os.Setenv("S2A_ACCESS_TOKEN", "TestNewServerCreds_s2a_access_token")
    73  	for _, tc := range []struct {
    74  		description string
    75  	}{
    76  		{
    77  			description: "static",
    78  		},
    79  	} {
    80  		t.Run(tc.description, func(t *testing.T) {
    81  			localIdentities := []*commonpbv1.Identity{
    82  				{
    83  					IdentityOneof: &commonpbv1.Identity_Hostname{
    84  						Hostname: "test_rsa_server_identity",
    85  					},
    86  				},
    87  			}
    88  			c, err := NewServerCreds(fakes2av2Address, nil, localIdentities, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil)
    89  			if err != nil {
    90  				t.Fatalf("NewServerCreds() failed: %v", err)
    91  			}
    92  			if got, want := c.Info().SecurityProtocol, s2aSecurityProtocol; got != want {
    93  				t.Errorf("c.Info().SecurityProtocol = %v, want %v", got, want)
    94  			}
    95  			_, ok := c.(*s2av2TransportCreds)
    96  			if !ok {
    97  				t.Fatal("The created creds is not of type s2av2TransportCreds")
    98  			}
    99  		})
   100  	}
   101  }
   102  
   103  func TestClientHandshakeFail(t *testing.T) {
   104  	cc := &s2av2TransportCreds{isClient: false}
   105  	if _, _, err := cc.ClientHandshake(context.Background(), "", nil); err == nil {
   106  		t.Errorf("c.ClientHandshake(nil, \"\", nil) should fail with incorrect transport credentials")
   107  	}
   108  }
   109  
   110  func TestServerHandshakeFail(t *testing.T) {
   111  	sc := &s2av2TransportCreds{isClient: true}
   112  	if _, _, err := sc.ServerHandshake(nil); err == nil {
   113  		t.Errorf("c.ServerHandshake(nil) should fail with incorrect transport credentials")
   114  	}
   115  }
   116  
   117  func TestInfo(t *testing.T) {
   118  	os.Setenv("S2A_ACCESS_TOKEN", "TestInfo_s2a_access_token")
   119  	c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
   120  		IdentityOneof: &commonpbv1.Identity_Hostname{
   121  			Hostname: "test_rsa_client_identity",
   122  		},
   123  	}, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil, nil, nil)
   124  	if err != nil {
   125  		t.Fatalf("NewClientCreds() failed: %v", err)
   126  	}
   127  	info := c.Info()
   128  	if got, want := info.SecurityProtocol, "tls"; got != want {
   129  		t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
   130  	}
   131  }
   132  
   133  func TestCloneClient(t *testing.T) {
   134  	os.Setenv("S2A_ACCESS_TOKEN", "TestCloneClient_s2a_access_token")
   135  	fallbackFunc, err := fallback.DefaultFallbackClientHandshakeFunc("example.com")
   136  	if err != nil {
   137  		t.Errorf("error creating fallback handshake function: %v", err)
   138  	}
   139  	c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
   140  		IdentityOneof: &commonpbv1.Identity_Hostname{
   141  			Hostname: "test_rsa_client_identity",
   142  		},
   143  	}, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, fallbackFunc, nil, nil)
   144  	if err != nil {
   145  		t.Fatalf("NewClientCreds() failed: %v", err)
   146  	}
   147  	cc := c.Clone()
   148  	s2av2Creds, ok := c.(*s2av2TransportCreds)
   149  	if !ok {
   150  		t.Fatal("The created creds is not of type s2av2TransportCreds")
   151  	}
   152  	s2av2CloneCreds, ok := cc.(*s2av2TransportCreds)
   153  	if !ok {
   154  		t.Fatal("The created clone creds is not of type s2aTransportCreds")
   155  	}
   156  	if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
   157  		xToken, err := x.DefaultToken()
   158  		if err != nil {
   159  			t.Errorf("Failed to compare cloned creds: %v", err)
   160  		}
   161  		yToken, err := y.DefaultToken()
   162  		if err != nil {
   163  			t.Errorf("Failed to compare cloned creds: %v", err)
   164  		}
   165  		if xToken == yToken {
   166  			return true
   167  		}
   168  		return false
   169  	}), cmp.Comparer(func(x, y fallback.ClientHandshake) bool {
   170  		return reflect.ValueOf(x) == reflect.ValueOf(y)
   171  	})), true; got != want {
   172  		t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
   173  	}
   174  	// Change the values and verify the creds were deep copied.
   175  	s2av2CloneCreds.info.SecurityProtocol = "s2a"
   176  	if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
   177  		xToken, err := x.DefaultToken()
   178  		if err != nil {
   179  			t.Errorf("Failed to compare cloned creds: %v", err)
   180  		}
   181  		yToken, err := y.DefaultToken()
   182  		if err != nil {
   183  			t.Errorf("Failed to compare cloned creds: %v", err)
   184  		}
   185  		if xToken == yToken {
   186  			return true
   187  		}
   188  		return false
   189  	})), false; got != want {
   190  		t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
   191  	}
   192  }
   193  
   194  func TestCloneServer(t *testing.T) {
   195  	os.Setenv("S2A_ACCESS_TOKEN", "TestCloneServer_s2a_access_token")
   196  	localIdentities := []*commonpbv1.Identity{
   197  		{
   198  			IdentityOneof: &commonpbv1.Identity_Hostname{
   199  				Hostname: "test_rsa_server_identity",
   200  			},
   201  		},
   202  	}
   203  	c, err := NewServerCreds(fakes2av2Address, nil, localIdentities, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil)
   204  	if err != nil {
   205  		t.Fatalf("NewServerCreds() failed: %v", err)
   206  	}
   207  	cc := c.Clone()
   208  	s2av2Creds, ok := c.(*s2av2TransportCreds)
   209  	if !ok {
   210  		t.Fatal("The created creds is not of type s2av2TransportCreds")
   211  	}
   212  	s2av2CloneCreds, ok := cc.(*s2av2TransportCreds)
   213  	if !ok {
   214  		t.Fatal("The created clone creds is not of type s2aTransportCreds")
   215  	}
   216  	if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
   217  		xToken, err := x.DefaultToken()
   218  		if err != nil {
   219  			t.Errorf("Failed to compare cloned creds: %v", err)
   220  		}
   221  		yToken, err := y.DefaultToken()
   222  		if err != nil {
   223  			t.Errorf("Failed to compare cloned creds: %v", err)
   224  		}
   225  		if xToken == yToken {
   226  			return true
   227  		}
   228  		return false
   229  	})), true; got != want {
   230  		t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
   231  	}
   232  	// Change the values and verify the creds were deep copied.
   233  	s2av2CloneCreds.info.SecurityProtocol = "s2a"
   234  	if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
   235  		xToken, err := x.DefaultToken()
   236  		if err != nil {
   237  			t.Errorf("Failed to compare cloned creds: %v", err)
   238  		}
   239  		yToken, err := y.DefaultToken()
   240  		if err != nil {
   241  			t.Errorf("Failed to compare cloned creds: %v", err)
   242  		}
   243  		if xToken == yToken {
   244  			return true
   245  		}
   246  		return false
   247  	})), false; got != want {
   248  		t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
   249  	}
   250  }
   251  
   252  func TestOverrideServerName(t *testing.T) {
   253  	// Setup test.
   254  	os.Setenv("S2A_ACCESS_TOKEN", "TestOverrideServerName_s2a_access_token")
   255  	c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
   256  		IdentityOneof: &commonpbv1.Identity_Hostname{
   257  			Hostname: "test_rsa_client_identity",
   258  		},
   259  	}, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil, nil, nil)
   260  	s2av2Creds, ok := c.(*s2av2TransportCreds)
   261  	if !ok {
   262  		t.Fatal("The created creds is not of type s2av2TransportCreds")
   263  	}
   264  	if err != nil {
   265  		t.Fatalf("NewClientCreds() failed: %v", err)
   266  	}
   267  	if got, want := c.Info().ServerName, ""; got != want {
   268  		t.Errorf("c.Info().ServerName = %v, want %v", got, want)
   269  	}
   270  	if got, want := s2av2Creds.serverName, ""; got != want {
   271  		t.Errorf("c.serverName = %v, want %v", got, want)
   272  	}
   273  	for _, tc := range []struct {
   274  		description    string
   275  		override       string
   276  		wantServerName string
   277  		expectError    bool
   278  	}{
   279  		{
   280  			description:    "empty string",
   281  			override:       "",
   282  			wantServerName: "",
   283  		},
   284  		{
   285  			description:    "host only",
   286  			override:       "server.name",
   287  			wantServerName: "server.name",
   288  		},
   289  		{
   290  			description:    "invalid syntax",
   291  			override:       "server::",
   292  			wantServerName: "server::",
   293  		},
   294  		{
   295  			description:    "split host port",
   296  			override:       "host:port",
   297  			wantServerName: "host",
   298  		},
   299  	} {
   300  		t.Run(tc.description, func(t *testing.T) {
   301  			c.OverrideServerName(tc.override)
   302  			if got, want := c.Info().ServerName, tc.wantServerName; got != want {
   303  				t.Errorf("c.Info().ServerName = %v, want %v", got, want)
   304  			}
   305  			if got, want := s2av2Creds.serverName, tc.wantServerName; got != want {
   306  				t.Errorf("c.serverName = %v, want %v", got, want)
   307  			}
   308  		})
   309  	}
   310  }
   311  
   312  type s2ATestStream struct {
   313  	debug string
   314  }
   315  
   316  func (x s2ATestStream) Send(m *s2av2pb.SessionReq) error {
   317  	return nil
   318  }
   319  
   320  func (x s2ATestStream) Recv() (*s2av2pb.SessionResp, error) {
   321  	return nil, nil
   322  }
   323  
   324  func (x s2ATestStream) CloseSend() error {
   325  	return nil
   326  }
   327  
   328  func TestCreateStream(t *testing.T) {
   329  	for _, tc := range []struct {
   330  		description string
   331  	}{
   332  		{
   333  			description: "static",
   334  		},
   335  	} {
   336  		t.Run(tc.description, func(t *testing.T) {
   337  			s2AStream, err := createStream(context.TODO(), "fake address", nil, func(ctx context.Context, s2av2Address string) (stream.S2AStream, error) {
   338  				return s2ATestStream{debug: "test s2a stream"}, nil
   339  			})
   340  			if err != nil {
   341  				t.Fatalf("New S2AStream failed: %v", err)
   342  			}
   343  			testStream, ok := s2AStream.(s2ATestStream)
   344  			if !ok {
   345  				t.Fatal("The created stream is not of type s2ATestStream")
   346  			}
   347  			if testStream.debug != "test s2a stream" {
   348  				t.Errorf("The created stream is not the intended stream")
   349  			}
   350  		})
   351  	}
   352  }
   353  
   354  func TestGetS2ATimeout(t *testing.T) {
   355  	oldEnvValue := os.Getenv(s2aTimeoutEnv)
   356  	defer os.Setenv(s2aTimeoutEnv, oldEnvValue)
   357  
   358  	// Unset the environment var
   359  	os.Unsetenv(s2aTimeoutEnv)
   360  	if got, want := GetS2ATimeout(), defaultS2ATimeout; got != want {
   361  		t.Fatalf("GetS2ATimeout should return default if S2A_TIMEOUT is not set")
   362  	}
   363  
   364  	// Set the environment var to empty string
   365  	os.Setenv(s2aTimeoutEnv, "")
   366  	if got, want := GetS2ATimeout(), defaultS2ATimeout; got != want {
   367  		t.Fatalf("GetS2ATimeout should return default if S2A_TIMEOUT is set to empty string")
   368  	}
   369  
   370  	// Set a valid duration string
   371  	os.Setenv(s2aTimeoutEnv, "5s")
   372  	if got, want := GetS2ATimeout(), 5*time.Second; got != want {
   373  		t.Fatalf("expected timeout to be 5s")
   374  	}
   375  
   376  	// Set an invalid duration string
   377  	os.Setenv(s2aTimeoutEnv, "5abc")
   378  	if got, want := GetS2ATimeout(), defaultS2ATimeout; got != want {
   379  		t.Fatalf("expected timeout to be default if the set timeout is invalid")
   380  	}
   381  }
   382  

View as plain text