...

Source file src/google.golang.org/grpc/credentials/xds/xds_client_test.go

Documentation: google.golang.org/grpc/credentials/xds

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package xds
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"errors"
    26  	"fmt"
    27  	"net"
    28  	"os"
    29  	"strings"
    30  	"sync/atomic"
    31  	"testing"
    32  	"time"
    33  	"unsafe"
    34  
    35  	"google.golang.org/grpc/credentials"
    36  	"google.golang.org/grpc/credentials/tls/certprovider"
    37  	icredentials "google.golang.org/grpc/internal/credentials"
    38  	xdsinternal "google.golang.org/grpc/internal/credentials/xds"
    39  	"google.golang.org/grpc/internal/grpctest"
    40  	"google.golang.org/grpc/internal/testutils"
    41  	"google.golang.org/grpc/internal/xds/matcher"
    42  	"google.golang.org/grpc/resolver"
    43  	"google.golang.org/grpc/testdata"
    44  )
    45  
    46  const (
    47  	defaultTestTimeout      = 1 * time.Second
    48  	defaultTestShortTimeout = 10 * time.Millisecond
    49  	defaultTestCertSAN      = "abc.test.example.com"
    50  	authority               = "authority"
    51  )
    52  
// s is the test suite type for this package: every test in this file is a
// method declared as `func (s) TestXxx(t *testing.T)`. It embeds
// grpctest.Tester, which presumably supplies shared per-test setup/teardown
// hooks — see the grpctest package to confirm.
type s struct {
	grpctest.Tester
}

// Test is the `go test` entry point for the suite. It hands a zero value of s
// to grpctest.RunSubTests, which presumably runs each of s's TestXxx methods as
// a subtest of t.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}
    60  
    61  // Helper function to create a real TLS client credentials which is used as
    62  // fallback credentials from multiple tests.
    63  func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
    64  	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	return creds
    69  }
    70  
// testServer is a no-op server which listens on a local TCP port for incoming
// connections, and performs a manual TLS handshake on each accepted raw
// connection using a user specified handshake function. It then makes the
// result of each handshake operation available through a channel for tests to
// inspect. Tests should stop the testServer as part of their cleanup.
type testServer struct {
	lis           net.Listener       // Listener created by start(); closed by stop().
	address       string             // Listening address of the test server.
	handshakeFunc testHandshakeFunc  // Test specified handshake function.
	hsResult      *testutils.Channel // Channel to deliver handshake results.
}
    82  
// handshakeResult wraps the result of the handshake operation on the test
// server. It consists of the TLS connection state and an error, if the
// handshake failed. The handshake functions in this file only populate
// connState when err is nil. This result is delivered on the `hsResult`
// channel on the testServer.
type handshakeResult struct {
	connState tls.ConnectionState // Server-side state of a successful handshake.
	err       error               // Non-nil if the server-side handshake failed.
}
    90  
// testHandshakeFunc is a configurable server-side handshake function for the
// testServer. It is invoked once per accepted raw connection, and the result it
// returns is published on the testServer's hsResult channel. Tests can set this
// to simulate different conditions like handshake success, failure, timeout
// etc.
type testHandshakeFunc func(net.Conn) handshakeResult
    94  
    95  // newTestServerWithHandshakeFunc starts a new testServer which listens for
    96  // connections on a local TCP port, and uses the provided custom handshake
    97  // function to perform TLS handshake.
    98  func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
    99  	ts := &testServer{
   100  		handshakeFunc: f,
   101  		hsResult:      testutils.NewChannel(),
   102  	}
   103  	ts.start()
   104  	return ts
   105  }
   106  
   107  // starts actually starts listening on a local TCP port, and spawns a goroutine
   108  // to handle new connections.
   109  func (ts *testServer) start() error {
   110  	lis, err := net.Listen("tcp", "localhost:0")
   111  	if err != nil {
   112  		return err
   113  	}
   114  	ts.lis = lis
   115  	ts.address = lis.Addr().String()
   116  	go ts.handleConn()
   117  	return nil
   118  }
   119  
   120  // handleconn accepts a new raw connection, and invokes the test provided
   121  // handshake function to perform TLS handshake, and returns the result on the
   122  // `hsResult` channel.
   123  func (ts *testServer) handleConn() {
   124  	for {
   125  		rawConn, err := ts.lis.Accept()
   126  		if err != nil {
   127  			// Once the listeners closed, Accept() will return with an error.
   128  			return
   129  		}
   130  		hsr := ts.handshakeFunc(rawConn)
   131  		ts.hsResult.Send(hsr)
   132  	}
   133  }
   134  
   135  // stop closes the associated listener which causes the connection handling
   136  // goroutine to exit.
   137  func (ts *testServer) stop() {
   138  	ts.lis.Close()
   139  }
   140  
   141  // A handshake function which simulates a successful handshake without client
   142  // authentication (server does not request for client certificate during the
   143  // handshake here).
   144  func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
   145  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   146  	if err != nil {
   147  		return handshakeResult{err: err}
   148  	}
   149  	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
   150  	conn := tls.Server(rawConn, cfg)
   151  	if err := conn.Handshake(); err != nil {
   152  		return handshakeResult{err: err}
   153  	}
   154  	return handshakeResult{connState: conn.ConnectionState()}
   155  }
   156  
   157  // A handshake function which simulates a successful handshake with mutual
   158  // authentication.
   159  func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
   160  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   161  	if err != nil {
   162  		return handshakeResult{err: err}
   163  	}
   164  	pemData, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
   165  	if err != nil {
   166  		return handshakeResult{err: err}
   167  	}
   168  	roots := x509.NewCertPool()
   169  	roots.AppendCertsFromPEM(pemData)
   170  	cfg := &tls.Config{
   171  		Certificates: []tls.Certificate{cert},
   172  		ClientCAs:    roots,
   173  	}
   174  	conn := tls.Server(rawConn, cfg)
   175  	if err := conn.Handshake(); err != nil {
   176  		return handshakeResult{err: err}
   177  	}
   178  	return handshakeResult{connState: conn.ConnectionState()}
   179  }
   180  
   181  // fakeProvider is an implementation of the certprovider.Provider interface
   182  // which returns the configured key material and error in calls to
   183  // KeyMaterial().
   184  type fakeProvider struct {
   185  	km  *certprovider.KeyMaterial
   186  	err error
   187  }
   188  
   189  func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
   190  	return f.km, f.err
   191  }
   192  
   193  func (f *fakeProvider) Close() {}
   194  
   195  // makeIdentityProvider creates a new instance of the fakeProvider returning the
   196  // identity key material specified in the provider file paths.
   197  func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
   198  	t.Helper()
   199  	cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}}
   204  }
   205  
   206  // makeRootProvider creates a new instance of the fakeProvider returning the
   207  // root key material specified in the provider file paths.
   208  func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
   209  	pemData, err := os.ReadFile(testdata.Path(caPath))
   210  	if err != nil {
   211  		t.Fatal(err)
   212  	}
   213  	roots := x509.NewCertPool()
   214  	roots.AppendCertsFromPEM(pemData)
   215  	return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
   216  }
   217  
   218  // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
   219  // context value added to it.
   220  func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
   221  	// Creating the HandshakeInfo and adding it to the attributes is very
   222  	// similar to what the CDS balancer would do when it intercepts calls to
   223  	// NewSubConn().
   224  	var sms []matcher.StringMatcher
   225  	if sanExactMatch != "" {
   226  		sms = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
   227  	}
   228  	info := xdsinternal.NewHandshakeInfo(root, identity, sms, false)
   229  	uPtr := unsafe.Pointer(info)
   230  	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
   231  
   232  	// Moving the attributes from the resolver.Address to the context passed to
   233  	// the handshaker is done in the transport layer. Since we directly call the
   234  	// handshaker in these tests, we need to do the same here.
   235  	return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
   236  }
   237  
   238  // compareAuthInfo compares the AuthInfo received on the client side after a
   239  // successful handshake with the authInfo available on the testServer.
   240  func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
   241  	if ai.AuthType() != "tls" {
   242  		return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
   243  	}
   244  	info, ok := ai.(credentials.TLSInfo)
   245  	if !ok {
   246  		return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
   247  	}
   248  	gotState := info.State
   249  
   250  	// Read the handshake result from the testServer which contains the TLS
   251  	// connection state and compare it with the one received on the client-side.
   252  	val, err := ts.hsResult.Receive(ctx)
   253  	if err != nil {
   254  		return fmt.Errorf("testServer failed to return handshake result: %v", err)
   255  	}
   256  	hsr := val.(handshakeResult)
   257  	if hsr.err != nil {
   258  		return fmt.Errorf("testServer handshake failure: %v", hsr.err)
   259  	}
   260  	// AuthInfo contains a variety of information. We only verify a subset here.
   261  	// This is the same subset which is verified in TLS credentials tests.
   262  	if err := compareConnState(gotState, hsr.connState); err != nil {
   263  		return err
   264  	}
   265  	return nil
   266  }
   267  
   268  func compareConnState(got, want tls.ConnectionState) error {
   269  	switch {
   270  	case got.Version != want.Version:
   271  		return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version)
   272  	case got.HandshakeComplete != want.HandshakeComplete:
   273  		return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete)
   274  	case got.CipherSuite != want.CipherSuite:
   275  		return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite)
   276  	case got.NegotiatedProtocol != want.NegotiatedProtocol:
   277  		return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol)
   278  	}
   279  	return nil
   280  }
   281  
   282  // TestClientCredsWithoutFallback verifies that the call to
   283  // NewClientCredentials() fails when no fallback is specified.
   284  func (s) TestClientCredsWithoutFallback(t *testing.T) {
   285  	if _, err := NewClientCredentials(ClientOptions{}); err == nil {
   286  		t.Fatal("NewClientCredentials() succeeded without specifying fallback")
   287  	}
   288  }
   289  
   290  // TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in
   291  // HandshakeInfo is invalid because it does not contain the expected certificate
   292  // providers.
   293  func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
   294  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   295  	creds, err := NewClientCredentials(opts)
   296  	if err != nil {
   297  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   298  	}
   299  
   300  	pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   301  	defer cancel()
   302  	ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
   303  	if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
   304  		t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
   305  	}
   306  }
   307  
   308  // TestClientCredsProviderFailure verifies the cases where an expected
   309  // certificate provider is missing in the HandshakeInfo value in the context.
   310  func (s) TestClientCredsProviderFailure(t *testing.T) {
   311  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   312  	creds, err := NewClientCredentials(opts)
   313  	if err != nil {
   314  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   315  	}
   316  
   317  	tests := []struct {
   318  		desc             string
   319  		rootProvider     certprovider.Provider
   320  		identityProvider certprovider.Provider
   321  		wantErr          string
   322  	}{
   323  		{
   324  			desc:         "erroring root provider",
   325  			rootProvider: &fakeProvider{err: errors.New("root provider error")},
   326  			wantErr:      "root provider error",
   327  		},
   328  		{
   329  			desc:             "erroring identity provider",
   330  			rootProvider:     &fakeProvider{km: &certprovider.KeyMaterial{}},
   331  			identityProvider: &fakeProvider{err: errors.New("identity provider error")},
   332  			wantErr:          "identity provider error",
   333  		},
   334  	}
   335  	for _, test := range tests {
   336  		t.Run(test.desc, func(t *testing.T) {
   337  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   338  			defer cancel()
   339  			ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
   340  			if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
   341  				t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
   342  			}
   343  		})
   344  	}
   345  }
   346  
   347  // TestClientCredsSuccess verifies successful client handshake cases.
   348  func (s) TestClientCredsSuccess(t *testing.T) {
   349  	tests := []struct {
   350  		desc             string
   351  		handshakeFunc    testHandshakeFunc
   352  		handshakeInfoCtx func(ctx context.Context) context.Context
   353  	}{
   354  		{
   355  			desc:          "fallback",
   356  			handshakeFunc: testServerTLSHandshake,
   357  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   358  				// Since we don't add a HandshakeInfo to the context, the
   359  				// ClientHandshake() method will delegate to the fallback.
   360  				return ctx
   361  			},
   362  		},
   363  		{
   364  			desc:          "TLS",
   365  			handshakeFunc: testServerTLSHandshake,
   366  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   367  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
   368  			},
   369  		},
   370  		{
   371  			desc:          "mTLS",
   372  			handshakeFunc: testServerMutualTLSHandshake,
   373  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   374  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
   375  			},
   376  		},
   377  		{
   378  			desc:          "mTLS with no acceptedSANs specified",
   379  			handshakeFunc: testServerMutualTLSHandshake,
   380  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   381  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
   382  			},
   383  		},
   384  	}
   385  
   386  	for _, test := range tests {
   387  		t.Run(test.desc, func(t *testing.T) {
   388  			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
   389  			defer ts.stop()
   390  
   391  			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   392  			creds, err := NewClientCredentials(opts)
   393  			if err != nil {
   394  				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   395  			}
   396  
   397  			conn, err := net.Dial("tcp", ts.address)
   398  			if err != nil {
   399  				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   400  			}
   401  			defer conn.Close()
   402  
   403  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   404  			defer cancel()
   405  			_, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
   406  			if err != nil {
   407  				t.Fatalf("ClientHandshake() returned failed: %q", err)
   408  			}
   409  			if err := compareAuthInfo(ctx, ts, ai); err != nil {
   410  				t.Fatal(err)
   411  			}
   412  		})
   413  	}
   414  }
   415  
   416  func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
   417  	clientDone := make(chan struct{})
   418  	// A handshake function which simulates a handshake timeout from the
   419  	// server-side by simply blocking on the client-side handshake to timeout
   420  	// and not writing any handshake data.
   421  	hErr := errors.New("server handshake error")
   422  	ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
   423  		<-clientDone
   424  		return handshakeResult{err: hErr}
   425  	})
   426  	defer ts.stop()
   427  
   428  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   429  	creds, err := NewClientCredentials(opts)
   430  	if err != nil {
   431  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   432  	}
   433  
   434  	conn, err := net.Dial("tcp", ts.address)
   435  	if err != nil {
   436  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   437  	}
   438  	defer conn.Close()
   439  
   440  	sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   441  	defer sCancel()
   442  	ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
   443  	if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
   444  		t.Fatal("ClientHandshake() succeeded when expected to timeout")
   445  	}
   446  	close(clientDone)
   447  
   448  	// Read the handshake result from the testServer and make sure the expected
   449  	// error is returned.
   450  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   451  	defer cancel()
   452  	val, err := ts.hsResult.Receive(ctx)
   453  	if err != nil {
   454  		t.Fatalf("testServer failed to return handshake result: %v", err)
   455  	}
   456  	hsr := val.(handshakeResult)
   457  	if hsr.err != hErr {
   458  		t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
   459  	}
   460  }
   461  
   462  // TestClientCredsHandshakeFailure verifies different handshake failure cases.
   463  func (s) TestClientCredsHandshakeFailure(t *testing.T) {
   464  	tests := []struct {
   465  		desc          string
   466  		handshakeFunc testHandshakeFunc
   467  		rootProvider  certprovider.Provider
   468  		san           string
   469  		wantErr       string
   470  	}{
   471  		{
   472  			desc:          "cert validation failure",
   473  			handshakeFunc: testServerTLSHandshake,
   474  			rootProvider:  makeRootProvider(t, "x509/client_ca_cert.pem"),
   475  			san:           defaultTestCertSAN,
   476  			wantErr:       "x509: certificate signed by unknown authority",
   477  		},
   478  		{
   479  			desc:          "SAN mismatch",
   480  			handshakeFunc: testServerTLSHandshake,
   481  			rootProvider:  makeRootProvider(t, "x509/server_ca_cert.pem"),
   482  			san:           "bad-san",
   483  			wantErr:       "do not match any of the accepted SANs",
   484  		},
   485  	}
   486  
   487  	for _, test := range tests {
   488  		t.Run(test.desc, func(t *testing.T) {
   489  			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
   490  			defer ts.stop()
   491  
   492  			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   493  			creds, err := NewClientCredentials(opts)
   494  			if err != nil {
   495  				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   496  			}
   497  
   498  			conn, err := net.Dial("tcp", ts.address)
   499  			if err != nil {
   500  				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   501  			}
   502  			defer conn.Close()
   503  
   504  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   505  			defer cancel()
   506  			ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
   507  			if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
   508  				t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
   509  			}
   510  		})
   511  	}
   512  }
   513  
   514  // TestClientCredsProviderSwitch verifies the case where the first attempt of
   515  // ClientHandshake fails because of a handshake failure. Then we update the
   516  // certificate provider and the second attempt succeeds. This is an
   517  // approximation of the flow of events when the control plane specifies new
   518  // security config which results in new certificate providers being used.
   519  func (s) TestClientCredsProviderSwitch(t *testing.T) {
   520  	ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
   521  	defer ts.stop()
   522  
   523  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   524  	creds, err := NewClientCredentials(opts)
   525  	if err != nil {
   526  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   527  	}
   528  
   529  	conn, err := net.Dial("tcp", ts.address)
   530  	if err != nil {
   531  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   532  	}
   533  	defer conn.Close()
   534  
   535  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   536  	defer cancel()
   537  	// Create a root provider which will fail the handshake because it does not
   538  	// use the correct trust roots.
   539  	root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
   540  	handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
   541  	// We need to repeat most of what newTestContextWithHandshakeInfo() does
   542  	// here because we need access to the underlying HandshakeInfo so that we
   543  	// can update it before the next call to ClientHandshake().
   544  	uPtr := unsafe.Pointer(handshakeInfo)
   545  	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
   546  	ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
   547  	if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
   548  		t.Fatal("ClientHandshake() succeeded when expected to fail")
   549  	}
   550  	// Drain the result channel on the test server so that we can inspect the
   551  	// result for the next handshake.
   552  	_, err = ts.hsResult.Receive(ctx)
   553  	if err != nil {
   554  		t.Errorf("testServer failed to return handshake result: %v", err)
   555  	}
   556  
   557  	conn, err = net.Dial("tcp", ts.address)
   558  	if err != nil {
   559  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   560  	}
   561  	defer conn.Close()
   562  
   563  	// Create a new root provider which uses the correct trust roots. And update
   564  	// the HandshakeInfo with the new provider.
   565  	root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
   566  	handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
   567  	// Update the existing pointer, which address attribute will continue to
   568  	// point to.
   569  	atomic.StorePointer(&uPtr, unsafe.Pointer(handshakeInfo))
   570  	_, ai, err := creds.ClientHandshake(ctx, authority, conn)
   571  	if err != nil {
   572  		t.Fatalf("ClientHandshake() returned failed: %q", err)
   573  	}
   574  	if err := compareAuthInfo(ctx, ts, ai); err != nil {
   575  		t.Fatal(err)
   576  	}
   577  }
   578  
   579  // TestClientClone verifies the Clone() method on client credentials.
   580  func (s) TestClientClone(t *testing.T) {
   581  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   582  	orig, err := NewClientCredentials(opts)
   583  	if err != nil {
   584  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   585  	}
   586  
   587  	// The credsImpl does not have any exported fields, and it does not make
   588  	// sense to use any cmp options to look deep into. So, all we make sure here
   589  	// is that the cloned object points to a different location in memory.
   590  	if clone := orig.Clone(); clone == orig {
   591  		t.Fatal("return value from Clone() doesn't point to new credentials instance")
   592  	}
   593  }
   594  
   595  func newStringP(s string) *string {
   596  	return &s
   597  }
   598  

View as plain text