...

Source file src/google.golang.org/grpc/credentials/google/google_test.go

Documentation: google.golang.org/grpc/credentials/google

     1  /*
     2   *
     3   * Copyright 2021 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package google
    20  
    21  import (
    22  	"context"
    23  	"net"
    24  	"testing"
    25  
    26  	"google.golang.org/grpc/credentials"
    27  	icredentials "google.golang.org/grpc/internal/credentials"
    28  	"google.golang.org/grpc/internal/grpctest"
    29  	"google.golang.org/grpc/internal/xds"
    30  	"google.golang.org/grpc/resolver"
    31  )
    32  
    33  type s struct {
    34  	grpctest.Tester
    35  }
    36  
    37  func Test(t *testing.T) {
    38  	grpctest.RunSubTests(t, s{})
    39  }
    40  
    41  type testCreds struct {
    42  	credentials.TransportCredentials
    43  	typ string
    44  }
    45  
    46  func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    47  	return nil, &testAuthInfo{typ: c.typ}, nil
    48  }
    49  
    50  func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    51  	return nil, &testAuthInfo{typ: c.typ}, nil
    52  }
    53  
    54  type testAuthInfo struct {
    55  	typ string
    56  }
    57  
    58  func (t *testAuthInfo) AuthType() string {
    59  	return t.typ
    60  }
    61  
    62  var (
    63  	testTLS  = &testCreds{typ: "tls"}
    64  	testALTS = &testCreds{typ: "alts"}
    65  )
    66  
    67  func overrideNewCredsFuncs() func() {
    68  	origNewTLS := newTLS
    69  	newTLS = func() credentials.TransportCredentials {
    70  		return testTLS
    71  	}
    72  	origNewALTS := newALTS
    73  	newALTS = func() credentials.TransportCredentials {
    74  		return testALTS
    75  	}
    76  	origNewADC := newADC
    77  	newADC = func(context.Context) (credentials.PerRPCCredentials, error) {
    78  		// We do not use perRPC creds in this test. It is safe to return nil here.
    79  		return nil, nil
    80  	}
    81  
    82  	return func() {
    83  		newTLS = origNewTLS
    84  		newALTS = origNewALTS
    85  		newADC = origNewADC
    86  	}
    87  }
    88  
    89  // TestClientHandshakeBasedOnClusterName that by default (without switching
    90  // modes), ClientHandshake does either tls or alts base on the cluster name in
    91  // attributes.
    92  func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
    93  	defer overrideNewCredsFuncs()()
    94  	for bundleTyp, tc := range map[string]credentials.Bundle{
    95  		"defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
    96  		"defaultCreds":            NewDefaultCredentials(),
    97  		"computeCreds":            NewComputeEngineCredentials(),
    98  	} {
    99  		tests := []struct {
   100  			name    string
   101  			ctx     context.Context
   102  			wantTyp string
   103  		}{
   104  			{
   105  				name:    "no cluster name",
   106  				ctx:     context.Background(),
   107  				wantTyp: "tls",
   108  			},
   109  			{
   110  				name: "with non-CFE cluster name",
   111  				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
   112  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
   113  				}),
   114  				// non-CFE backends should use alts.
   115  				wantTyp: "alts",
   116  			},
   117  			{
   118  				name: "with CFE cluster name",
   119  				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
   120  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
   121  				}),
   122  				// CFE should use tls.
   123  				wantTyp: "tls",
   124  			},
   125  			{
   126  				name: "with xdstp CFE cluster name",
   127  				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
   128  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
   129  				}),
   130  				// CFE should use tls.
   131  				wantTyp: "tls",
   132  			},
   133  			{
   134  				name: "with xdstp non-CFE cluster name",
   135  				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
   136  					Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
   137  				}),
   138  				// non-CFE should use atls.
   139  				wantTyp: "alts",
   140  			},
   141  		}
   142  		for _, tt := range tests {
   143  			t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
   144  				_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
   145  				if err != nil {
   146  					t.Fatalf("ClientHandshake failed: %v", err)
   147  				}
   148  				if gotType := info.AuthType(); gotType != tt.wantTyp {
   149  					t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
   150  				}
   151  
   152  				_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
   153  				if err != nil {
   154  					t.Fatalf("ClientHandshake failed: %v", err)
   155  				}
   156  				// ServerHandshake should always do TLS.
   157  				if gotType := infoServer.AuthType(); gotType != "tls" {
   158  					t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
   159  				}
   160  			})
   161  		}
   162  	}
   163  }
   164  

View as plain text