...

Source file src/google.golang.org/grpc/credentials/alts/alts_test.go

Documentation: google.golang.org/grpc/credentials/alts

     1  //go:build linux || windows
     2  // +build linux windows
     3  
     4  /*
     5   *
     6   * Copyright 2018 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     http://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package alts
    23  
    24  import (
    25  	"context"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials/alts/internal/handshaker"
    34  	"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
    35  	altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    36  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    37  	"google.golang.org/grpc/credentials/alts/internal/testutil"
    38  	"google.golang.org/grpc/internal/grpctest"
    39  	"google.golang.org/grpc/internal/testutils"
    40  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    41  	testpb "google.golang.org/grpc/interop/grpc_testing"
    42  	"google.golang.org/grpc/peer"
    43  	"google.golang.org/grpc/status"
    44  	"google.golang.org/protobuf/proto"
    45  )
    46  
    47  const (
    48  	defaultTestLongTimeout  = 60 * time.Second
    49  	defaultTestShortTimeout = 10 * time.Millisecond
    50  )
    51  
    52  type s struct {
    53  	grpctest.Tester
    54  }
    55  
    56  func init() {
    57  	// The vmOnGCP global variable MUST be forced to true. Otherwise, if
    58  	// this test is run anywhere except on a GCP VM, then an ALTS handshake
    59  	// will immediately fail.
    60  	once.Do(func() {})
    61  	vmOnGCP = true
    62  }
    63  
    64  func Test(t *testing.T) {
    65  	grpctest.RunSubTests(t, s{})
    66  }
    67  
    68  func (s) TestInfoServerName(t *testing.T) {
    69  	// This is not testing any handshaker functionality, so it's fine to only
    70  	// use NewServerCreds and not NewClientCreds.
    71  	alts := NewServerCreds(DefaultServerOptions())
    72  	if got, want := alts.Info().ServerName, ""; got != want {
    73  		t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
    74  	}
    75  }
    76  
    77  func (s) TestOverrideServerName(t *testing.T) {
    78  	wantServerName := "server.name"
    79  	// This is not testing any handshaker functionality, so it's fine to only
    80  	// use NewServerCreds and not NewClientCreds.
    81  	c := NewServerCreds(DefaultServerOptions())
    82  	c.OverrideServerName(wantServerName)
    83  	if got, want := c.Info().ServerName, wantServerName; got != want {
    84  		t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
    85  	}
    86  }
    87  
    88  func (s) TestCloneClient(t *testing.T) {
    89  	wantServerName := "server.name"
    90  	opt := DefaultClientOptions()
    91  	opt.TargetServiceAccounts = []string{"not", "empty"}
    92  	c := NewClientCreds(opt)
    93  	c.OverrideServerName(wantServerName)
    94  	cc := c.Clone()
    95  	if got, want := cc.Info().ServerName, wantServerName; got != want {
    96  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
    97  	}
    98  	cc.OverrideServerName("")
    99  	if got, want := c.Info().ServerName, wantServerName; got != want {
   100  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
   101  	}
   102  	if got, want := cc.Info().ServerName, ""; got != want {
   103  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   104  	}
   105  
   106  	ct := c.(*altsTC)
   107  	cct := cc.(*altsTC)
   108  
   109  	if ct.side != cct.side {
   110  		t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
   111  	}
   112  	if ct.hsAddress != cct.hsAddress {
   113  		t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
   114  	}
   115  	if !reflect.DeepEqual(ct.accounts, cct.accounts) {
   116  		t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
   117  	}
   118  }
   119  
   120  func (s) TestCloneServer(t *testing.T) {
   121  	wantServerName := "server.name"
   122  	c := NewServerCreds(DefaultServerOptions())
   123  	c.OverrideServerName(wantServerName)
   124  	cc := c.Clone()
   125  	if got, want := cc.Info().ServerName, wantServerName; got != want {
   126  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   127  	}
   128  	cc.OverrideServerName("")
   129  	if got, want := c.Info().ServerName, wantServerName; got != want {
   130  		t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
   131  	}
   132  	if got, want := cc.Info().ServerName, ""; got != want {
   133  		t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
   134  	}
   135  
   136  	ct := c.(*altsTC)
   137  	cct := cc.(*altsTC)
   138  
   139  	if ct.side != cct.side {
   140  		t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
   141  	}
   142  	if ct.hsAddress != cct.hsAddress {
   143  		t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
   144  	}
   145  	if !reflect.DeepEqual(ct.accounts, cct.accounts) {
   146  		t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
   147  	}
   148  }
   149  
   150  func (s) TestInfo(t *testing.T) {
   151  	// This is not testing any handshaker functionality, so it's fine to only
   152  	// use NewServerCreds and not NewClientCreds.
   153  	c := NewServerCreds(DefaultServerOptions())
   154  	info := c.Info()
   155  	if got, want := info.ProtocolVersion, ""; got != want {
   156  		t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
   157  	}
   158  	if got, want := info.SecurityProtocol, "alts"; got != want {
   159  		t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
   160  	}
   161  	if got, want := info.SecurityVersion, "1.0"; got != want {
   162  		t.Errorf("info.SecurityVersion=%v, want %v", got, want)
   163  	}
   164  	if got, want := info.ServerName, ""; got != want {
   165  		t.Errorf("info.ServerName=%v, want %v", got, want)
   166  	}
   167  }
   168  
   169  func (s) TestCompareRPCVersions(t *testing.T) {
   170  	for _, tc := range []struct {
   171  		v1     *altspb.RpcProtocolVersions_Version
   172  		v2     *altspb.RpcProtocolVersions_Version
   173  		output int
   174  	}{
   175  		{
   176  			version(3, 2),
   177  			version(2, 1),
   178  			1,
   179  		},
   180  		{
   181  			version(3, 2),
   182  			version(3, 1),
   183  			1,
   184  		},
   185  		{
   186  			version(2, 1),
   187  			version(3, 2),
   188  			-1,
   189  		},
   190  		{
   191  			version(3, 1),
   192  			version(3, 2),
   193  			-1,
   194  		},
   195  		{
   196  			version(3, 2),
   197  			version(3, 2),
   198  			0,
   199  		},
   200  	} {
   201  		if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
   202  			t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
   203  		}
   204  	}
   205  }
   206  
   207  func (s) TestCheckRPCVersions(t *testing.T) {
   208  	for _, tc := range []struct {
   209  		desc             string
   210  		local            *altspb.RpcProtocolVersions
   211  		peer             *altspb.RpcProtocolVersions
   212  		output           bool
   213  		maxCommonVersion *altspb.RpcProtocolVersions_Version
   214  	}{
   215  		{
   216  			"local.max > peer.max and local.min > peer.min",
   217  			versions(2, 1, 3, 2),
   218  			versions(1, 2, 2, 1),
   219  			true,
   220  			version(2, 1),
   221  		},
   222  		{
   223  			"local.max > peer.max and local.min < peer.min",
   224  			versions(1, 2, 3, 2),
   225  			versions(2, 1, 2, 1),
   226  			true,
   227  			version(2, 1),
   228  		},
   229  		{
   230  			"local.max > peer.max and local.min = peer.min",
   231  			versions(2, 1, 3, 2),
   232  			versions(2, 1, 2, 1),
   233  			true,
   234  			version(2, 1),
   235  		},
   236  		{
   237  			"local.max < peer.max and local.min > peer.min",
   238  			versions(2, 1, 2, 1),
   239  			versions(1, 2, 3, 2),
   240  			true,
   241  			version(2, 1),
   242  		},
   243  		{
   244  			"local.max = peer.max and local.min > peer.min",
   245  			versions(2, 1, 2, 1),
   246  			versions(1, 2, 2, 1),
   247  			true,
   248  			version(2, 1),
   249  		},
   250  		{
   251  			"local.max < peer.max and local.min < peer.min",
   252  			versions(1, 2, 2, 1),
   253  			versions(2, 1, 3, 2),
   254  			true,
   255  			version(2, 1),
   256  		},
   257  		{
   258  			"local.max < peer.max and local.min = peer.min",
   259  			versions(1, 2, 2, 1),
   260  			versions(1, 2, 3, 2),
   261  			true,
   262  			version(2, 1),
   263  		},
   264  		{
   265  			"local.max = peer.max and local.min < peer.min",
   266  			versions(1, 2, 2, 1),
   267  			versions(2, 1, 2, 1),
   268  			true,
   269  			version(2, 1),
   270  		},
   271  		{
   272  			"all equal",
   273  			versions(2, 1, 2, 1),
   274  			versions(2, 1, 2, 1),
   275  			true,
   276  			version(2, 1),
   277  		},
   278  		{
   279  			"max is smaller than min",
   280  			versions(2, 1, 1, 2),
   281  			versions(2, 1, 1, 2),
   282  			false,
   283  			nil,
   284  		},
   285  		{
   286  			"no overlap, local > peer",
   287  			versions(4, 3, 6, 5),
   288  			versions(1, 0, 2, 1),
   289  			false,
   290  			nil,
   291  		},
   292  		{
   293  			"no overlap, local < peer",
   294  			versions(1, 0, 2, 1),
   295  			versions(4, 3, 6, 5),
   296  			false,
   297  			nil,
   298  		},
   299  		{
   300  			"no overlap, max < min",
   301  			versions(6, 5, 4, 3),
   302  			versions(2, 1, 1, 0),
   303  			false,
   304  			nil,
   305  		},
   306  	} {
   307  		output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
   308  		if got, want := output, tc.output; got != want {
   309  			t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
   310  		}
   311  		if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
   312  			t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
   313  		}
   314  	}
   315  }
   316  
   317  // TestFullHandshake performs a full ALTS handshake between a test client and
   318  // server, where both client and server offload to a local, fake handshaker
   319  // service.
   320  func (s) TestFullHandshake(t *testing.T) {
   321  	// Start the fake handshaker service and the server.
   322  	var wait sync.WaitGroup
   323  	defer wait.Wait()
   324  	stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
   325  	defer stopHandshaker()
   326  	stopServer, serverAddress := startServer(t, handshakerAddress, &wait)
   327  	defer stopServer()
   328  
   329  	// Ping the server, authenticating with ALTS.
   330  	establishAltsConnection(t, handshakerAddress, serverAddress)
   331  
   332  	// Close open connections to the fake handshaker service.
   333  	if err := service.CloseForTesting(); err != nil {
   334  		t.Errorf("service.CloseForTesting() failed: %v", err)
   335  	}
   336  }
   337  
   338  // TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
   339  // between a test client and server, where both client and server offload to a
   340  // local, fake handshaker service.
   341  func (s) TestConcurrentHandshakes(t *testing.T) {
   342  	// Set the max number of concurrent handshakes to 3, so that we can
   343  	// test the handshaker behavior when handshakes are queued by
   344  	// performing more than 3 concurrent handshakes (specifically, 10).
   345  	handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3)
   346  
   347  	// Start the fake handshaker service and the server.
   348  	var wait sync.WaitGroup
   349  	defer wait.Wait()
   350  	stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
   351  	defer stopHandshaker()
   352  	stopServer, serverAddress := startServer(t, handshakerAddress, &wait)
   353  	defer stopServer()
   354  
   355  	// Ping the server, authenticating with ALTS.
   356  	var waitForConnections sync.WaitGroup
   357  	for i := 0; i < 10; i++ {
   358  		waitForConnections.Add(1)
   359  		go func() {
   360  			establishAltsConnection(t, handshakerAddress, serverAddress)
   361  			waitForConnections.Done()
   362  		}()
   363  	}
   364  	waitForConnections.Wait()
   365  
   366  	// Close open connections to the fake handshaker service.
   367  	if err := service.CloseForTesting(); err != nil {
   368  		t.Errorf("service.CloseForTesting() failed: %v", err)
   369  	}
   370  }
   371  
   372  func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
   373  	return &altspb.RpcProtocolVersions_Version{
   374  		Major: major,
   375  		Minor: minor,
   376  	}
   377  }
   378  
   379  func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
   380  	return &altspb.RpcProtocolVersions{
   381  		MinRpcVersion: version(minMajor, minMinor),
   382  		MaxRpcVersion: version(maxMajor, maxMinor),
   383  	}
   384  }
   385  
   386  func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
   387  	clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
   388  	conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
   389  	if err != nil {
   390  		t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
   391  	}
   392  	defer conn.Close()
   393  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
   394  	defer cancel()
   395  	c := testgrpc.NewTestServiceClient(conn)
   396  	var peer peer.Peer
   397  	success := false
   398  	for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
   399  		_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.Peer(&peer))
   400  		if err == nil {
   401  			success = true
   402  			break
   403  		}
   404  		if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
   405  			// The server is not ready yet or there were too many concurrent handshakes.
   406  			// Try again.
   407  			continue
   408  		}
   409  		t.Fatalf("c.UnaryCall() failed: %v", err)
   410  	}
   411  	if !success {
   412  		t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout)
   413  	}
   414  
   415  	// Check that peer.AuthInfo was populated with an ALTS AuthInfo
   416  	// instance. As a sanity check, also verify that the AuthType() and
   417  	// ApplicationProtocol() have the expected values.
   418  	if got, want := peer.AuthInfo.AuthType(), "alts"; got != want {
   419  		t.Errorf("authInfo.AuthType() = %s, want = %s", got, want)
   420  	}
   421  	authInfo, err := AuthInfoFromPeer(&peer)
   422  	if err != nil {
   423  		t.Errorf("AuthInfoFromPeer failed: %v", err)
   424  	}
   425  	if got, want := authInfo.ApplicationProtocol(), "grpc"; got != want {
   426  		t.Errorf("authInfo.ApplicationProtocol() = %s, want = %s", got, want)
   427  	}
   428  }
   429  
   430  func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
   431  	listener, err := testutils.LocalTCPListener()
   432  	if err != nil {
   433  		t.Fatalf("LocalTCPListener() failed: %v", err)
   434  	}
   435  	s := grpc.NewServer()
   436  	altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
   437  	wait.Add(1)
   438  	go func() {
   439  		defer wait.Done()
   440  		if err := s.Serve(listener); err != nil {
   441  			t.Errorf("failed to serve: %v", err)
   442  		}
   443  	}()
   444  	return func() { s.Stop() }, listener.Addr().String()
   445  }
   446  
   447  func startServer(t *testing.T, handshakerServiceAddress string, wait *sync.WaitGroup) (stop func(), address string) {
   448  	listener, err := testutils.LocalTCPListener()
   449  	if err != nil {
   450  		t.Fatalf("LocalTCPListener() failed: %v", err)
   451  	}
   452  	serverOpts := &ServerOptions{HandshakerServiceAddress: handshakerServiceAddress}
   453  	creds := NewServerCreds(serverOpts)
   454  	s := grpc.NewServer(grpc.Creds(creds))
   455  	testgrpc.RegisterTestServiceServer(s, &testServer{})
   456  	wait.Add(1)
   457  	go func() {
   458  		defer wait.Done()
   459  		if err := s.Serve(listener); err != nil {
   460  			t.Errorf("s.Serve(%v) failed: %v", listener, err)
   461  		}
   462  	}()
   463  	return func() { s.Stop() }, listener.Addr().String()
   464  }
   465  
   466  type testServer struct {
   467  	testgrpc.UnimplementedTestServiceServer
   468  }
   469  
   470  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   471  	return &testpb.SimpleResponse{
   472  		Payload: &testpb.Payload{},
   473  	}, nil
   474  }
   475  

View as plain text