...

Source file src/google.golang.org/api/internal/cba_test.go

Documentation: google.golang.org/api/internal

     1  // Copyright 2020 Google LLC.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package internal
     6  
     7  import (
     8  	"crypto/tls"
     9  	"net/http"
    10  	"os"
    11  	"testing"
    12  	"time"
    13  )
    14  
    15  const (
    16  	testMTLSEndpoint           = "https://test.mtls.googleapis.com/"
    17  	testRegularEndpoint        = "https://test.googleapis.com/"
    18  	testEndpointTemplate       = "https://test.UNIVERSE_DOMAIN/"
    19  	testOverrideEndpoint       = "https://test.override.example.com/"
    20  	testUniverseDomain         = "example.com"
    21  	testUniverseDomainEndpoint = "https://test.example.com/"
    22  )
    23  
    24  var dummyClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil }
    25  
    26  func TestGetEndpoint(t *testing.T) {
    27  	testCases := []struct {
    28  		UserEndpoint            string
    29  		DefaultEndpoint         string
    30  		DefaultEndpointTemplate string
    31  		Want                    string
    32  		WantErr                 bool
    33  	}{
    34  		{
    35  			DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
    36  			Want:                    "https://foo.googleapis.com/bar/baz",
    37  		},
    38  		{
    39  			UserEndpoint:            "myhost:3999",
    40  			DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
    41  			Want:                    "https://myhost:3999/bar/baz",
    42  		},
    43  		{
    44  			UserEndpoint:            "https://host/path/to/bar",
    45  			DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
    46  			Want:                    "https://host/path/to/bar",
    47  		},
    48  		{
    49  			UserEndpoint:    "host:123",
    50  			DefaultEndpoint: "",
    51  			Want:            "host:123",
    52  		},
    53  		{
    54  			UserEndpoint:    "host:123",
    55  			DefaultEndpoint: "default:443",
    56  			Want:            "host:123",
    57  		},
    58  		{
    59  			UserEndpoint:    "host:123",
    60  			DefaultEndpoint: "default:443/bar/baz",
    61  			Want:            "host:123/bar/baz",
    62  		},
    63  	}
    64  
    65  	for _, tc := range testCases {
    66  		got, err := getEndpoint(&DialSettings{
    67  			Endpoint:                tc.UserEndpoint,
    68  			DefaultEndpoint:         tc.DefaultEndpoint,
    69  			DefaultEndpointTemplate: tc.DefaultEndpointTemplate,
    70  		}, nil)
    71  		if tc.WantErr && err == nil {
    72  			t.Errorf("want err, got nil err")
    73  			continue
    74  		}
    75  		if !tc.WantErr && err != nil {
    76  			t.Errorf("want nil err, got %v", err)
    77  			continue
    78  		}
    79  		if tc.Want != got {
    80  			t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpointTemplate, got, tc.Want)
    81  		}
    82  	}
    83  }
    84  
    85  func TestGetEndpointWithClientCertSource(t *testing.T) {
    86  
    87  	testCases := []struct {
    88  		UserEndpoint        string
    89  		DefaultEndpoint     string
    90  		DefaultMTLSEndpoint string
    91  		Want                string
    92  		WantErr             bool
    93  	}{
    94  		{
    95  			DefaultEndpoint:     "https://foo.googleapis.com/bar/baz",
    96  			DefaultMTLSEndpoint: "https://foo.mtls.googleapis.com/bar/baz",
    97  			Want:                "https://foo.mtls.googleapis.com/bar/baz",
    98  		},
    99  		{
   100  			DefaultEndpoint:     "https://staging-foo.sandbox.googleapis.com/bar/baz",
   101  			DefaultMTLSEndpoint: "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
   102  			Want:                "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
   103  		},
   104  		{
   105  			UserEndpoint:    "myhost:3999",
   106  			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
   107  			Want:            "https://myhost:3999/bar/baz",
   108  		},
   109  		{
   110  			UserEndpoint:    "https://host/path/to/bar",
   111  			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
   112  			Want:            "https://host/path/to/bar",
   113  		},
   114  		{
   115  			UserEndpoint:    "host:port",
   116  			DefaultEndpoint: "",
   117  			Want:            "host:port",
   118  		},
   119  	}
   120  
   121  	for _, tc := range testCases {
   122  		got, err := getEndpoint(&DialSettings{
   123  			Endpoint:            tc.UserEndpoint,
   124  			DefaultEndpoint:     tc.DefaultEndpoint,
   125  			DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
   126  		}, dummyClientCertSource)
   127  		if tc.WantErr && err == nil {
   128  			t.Errorf("want err, got nil err")
   129  			continue
   130  		}
   131  		if !tc.WantErr && err != nil {
   132  			t.Errorf("want nil err, got %v", err)
   133  			continue
   134  		}
   135  		if tc.Want != got {
   136  			t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
   137  		}
   138  	}
   139  }
   140  
   141  func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
   142  	testCases := []struct {
   143  		Desc          string
   144  		InputSettings *DialSettings
   145  		S2ARespFunc   func() (string, error)
   146  		WantEndpoint  string
   147  	}{
   148  		{
   149  			"has client cert",
   150  			&DialSettings{
   151  				DefaultMTLSEndpoint: testMTLSEndpoint,
   152  				DefaultEndpoint:     testRegularEndpoint,
   153  				ClientCertSource:    dummyClientCertSource,
   154  			},
   155  			validConfigResp,
   156  			testMTLSEndpoint,
   157  		},
   158  		{
   159  			"no client cert, S2A address not empty",
   160  			&DialSettings{
   161  				DefaultMTLSEndpoint: testMTLSEndpoint,
   162  				DefaultEndpoint:     testRegularEndpoint,
   163  			},
   164  			validConfigResp,
   165  			testMTLSEndpoint,
   166  		},
   167  		{
   168  			"no client cert, S2A address not empty, EnableDirectPath == true",
   169  			&DialSettings{
   170  				DefaultMTLSEndpoint: testMTLSEndpoint,
   171  				DefaultEndpoint:     testRegularEndpoint,
   172  				EnableDirectPath:    true,
   173  			},
   174  			validConfigResp,
   175  			testRegularEndpoint,
   176  		},
   177  		{
   178  			"no client cert, S2A address not empty, EnableDirectPathXds == true",
   179  			&DialSettings{
   180  				DefaultMTLSEndpoint: testMTLSEndpoint,
   181  				DefaultEndpoint:     testRegularEndpoint,
   182  				EnableDirectPathXds: true,
   183  			},
   184  			validConfigResp,
   185  			testRegularEndpoint,
   186  		},
   187  		{
   188  			"no client cert, S2A address empty",
   189  			&DialSettings{
   190  				DefaultMTLSEndpoint: testMTLSEndpoint,
   191  				DefaultEndpoint:     testRegularEndpoint,
   192  			},
   193  			invalidConfigResp,
   194  			testRegularEndpoint,
   195  		},
   196  		{
   197  			"no client cert, S2A address not empty, override endpoint",
   198  			&DialSettings{
   199  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   200  				DefaultEndpointTemplate: testEndpointTemplate,
   201  				Endpoint:                testOverrideEndpoint,
   202  			},
   203  			validConfigResp,
   204  			testOverrideEndpoint,
   205  		},
   206  		{
   207  			"no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
   208  			&DialSettings{
   209  				DefaultMTLSEndpoint:     "",
   210  				DefaultEndpointTemplate: testEndpointTemplate,
   211  			},
   212  			validConfigResp,
   213  			testRegularEndpoint,
   214  		},
   215  	}
   216  	defer setupTest()()
   217  
   218  	for _, tc := range testCases {
   219  		httpGetMetadataMTLSConfig = tc.S2ARespFunc
   220  		if tc.InputSettings.ClientCertSource != nil {
   221  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
   222  		} else {
   223  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
   224  		}
   225  		_, endpoint, _ := GetGRPCTransportConfigAndEndpoint(tc.InputSettings)
   226  		if tc.WantEndpoint != endpoint {
   227  			t.Errorf("%s: want endpoint: [%s], got [%s]", tc.Desc, tc.WantEndpoint, endpoint)
   228  		}
   229  		// Let the cached MTLS config expire at the end of each test case.
   230  		time.Sleep(2 * time.Millisecond)
   231  	}
   232  }
   233  
   234  func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {
   235  	testCases := []struct {
   236  		Desc          string
   237  		InputSettings *DialSettings
   238  		S2ARespFunc   func() (string, error)
   239  		WantEndpoint  string
   240  		DialFuncNil   bool
   241  	}{
   242  		{
   243  			"has client cert",
   244  			&DialSettings{
   245  				DefaultMTLSEndpoint: testMTLSEndpoint,
   246  				DefaultEndpoint:     testRegularEndpoint,
   247  				ClientCertSource:    dummyClientCertSource,
   248  			},
   249  			validConfigResp,
   250  			testMTLSEndpoint,
   251  			true,
   252  		},
   253  		{
   254  			"no client cert, S2A address not empty",
   255  			&DialSettings{
   256  				DefaultMTLSEndpoint: testMTLSEndpoint,
   257  				DefaultEndpoint:     testRegularEndpoint,
   258  			},
   259  			validConfigResp,
   260  			testMTLSEndpoint,
   261  			false,
   262  		},
   263  		{
   264  			"no client cert, S2A address not empty, EnableDirectPath == true",
   265  			&DialSettings{
   266  				DefaultMTLSEndpoint: testMTLSEndpoint,
   267  				DefaultEndpoint:     testRegularEndpoint,
   268  				EnableDirectPath:    true,
   269  			},
   270  			validConfigResp,
   271  			testRegularEndpoint,
   272  			true,
   273  		},
   274  		{
   275  			"no client cert, S2A address not empty, EnableDirectPathXds == true",
   276  			&DialSettings{
   277  				DefaultMTLSEndpoint: testMTLSEndpoint,
   278  				DefaultEndpoint:     testRegularEndpoint,
   279  				EnableDirectPathXds: true,
   280  			},
   281  			validConfigResp,
   282  			testRegularEndpoint,
   283  			true,
   284  		},
   285  		{
   286  			"no client cert, S2A address empty",
   287  			&DialSettings{
   288  				DefaultMTLSEndpoint: testMTLSEndpoint,
   289  				DefaultEndpoint:     testRegularEndpoint,
   290  			},
   291  			invalidConfigResp,
   292  			testRegularEndpoint,
   293  			true,
   294  		},
   295  		{
   296  			"no client cert, S2A address not empty, override endpoint",
   297  			&DialSettings{
   298  				DefaultMTLSEndpoint: testMTLSEndpoint,
   299  				DefaultEndpoint:     testRegularEndpoint,
   300  				Endpoint:            testOverrideEndpoint,
   301  			},
   302  			validConfigResp,
   303  			testOverrideEndpoint,
   304  			true,
   305  		},
   306  		{
   307  			"no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set",
   308  			&DialSettings{
   309  				DefaultMTLSEndpoint: "",
   310  				DefaultEndpoint:     testRegularEndpoint,
   311  			},
   312  			validConfigResp,
   313  			testRegularEndpoint,
   314  			true,
   315  		},
   316  		{
   317  			"no client cert, S2A address not empty, custom HTTP client",
   318  			&DialSettings{
   319  				DefaultMTLSEndpoint: testMTLSEndpoint,
   320  				DefaultEndpoint:     testRegularEndpoint,
   321  				HTTPClient:          http.DefaultClient,
   322  			},
   323  			validConfigResp,
   324  			testRegularEndpoint,
   325  			true,
   326  		},
   327  	}
   328  
   329  	defer setupTest()()
   330  
   331  	for _, tc := range testCases {
   332  		httpGetMetadataMTLSConfig = tc.S2ARespFunc
   333  		if tc.InputSettings.ClientCertSource != nil {
   334  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
   335  		} else {
   336  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
   337  		}
   338  		_, dialFunc, endpoint, err := GetHTTPTransportConfigAndEndpoint(tc.InputSettings)
   339  		if err != nil {
   340  			t.Fatalf("%s: err: %v", tc.Desc, err)
   341  		}
   342  		if tc.WantEndpoint != endpoint {
   343  			t.Errorf("%s: want endpoint: [%s], got [%s]", tc.Desc, tc.WantEndpoint, endpoint)
   344  		}
   345  		if want, got := tc.DialFuncNil, dialFunc == nil; want != got {
   346  			t.Errorf("%s: expecting returned dialFunc is nil: [%v], got [%v]", tc.Desc, tc.DialFuncNil, got)
   347  		}
   348  		// Let MTLS config expire at end of each test case.
   349  		time.Sleep(2 * time.Millisecond)
   350  	}
   351  }
   352  
   353  func setupTest() func() {
   354  	oldHTTPGet := httpGetMetadataMTLSConfig
   355  	oldExpiry := configExpiry
   356  	oldUseS2A := os.Getenv(googleAPIUseS2AEnv)
   357  	oldUseClientCert := os.Getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE")
   358  
   359  	configExpiry = time.Millisecond
   360  	os.Setenv(googleAPIUseS2AEnv, "true")
   361  
   362  	return func() {
   363  		httpGetMetadataMTLSConfig = oldHTTPGet
   364  		configExpiry = oldExpiry
   365  		os.Setenv(googleAPIUseS2AEnv, oldUseS2A)
   366  		os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", oldUseClientCert)
   367  	}
   368  }
   369  
   370  func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
   371  	testCases := []struct {
   372  		name         string
   373  		ds           *DialSettings
   374  		wantEndpoint string
   375  		wantErr      error
   376  	}{
   377  		{
   378  			name: "google default universe (GDU), no client cert",
   379  			ds: &DialSettings{
   380  				DefaultEndpoint:         testRegularEndpoint,
   381  				DefaultEndpointTemplate: testEndpointTemplate,
   382  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   383  			},
   384  			wantEndpoint: testRegularEndpoint,
   385  		},
   386  		{
   387  			name: "google default universe (GDU), client cert",
   388  			ds: &DialSettings{
   389  				DefaultEndpoint:         testRegularEndpoint,
   390  				DefaultEndpointTemplate: testEndpointTemplate,
   391  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   392  				ClientCertSource:        dummyClientCertSource,
   393  			},
   394  			wantEndpoint: testMTLSEndpoint,
   395  		},
   396  		{
   397  			name: "UniverseDomain, no client cert",
   398  			ds: &DialSettings{
   399  				DefaultEndpoint:         testRegularEndpoint,
   400  				DefaultEndpointTemplate: testEndpointTemplate,
   401  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   402  				UniverseDomain:          testUniverseDomain,
   403  			},
   404  			wantEndpoint: testUniverseDomainEndpoint,
   405  		},
   406  		{
   407  			name: "UniverseDomain, client cert",
   408  			ds: &DialSettings{
   409  				DefaultEndpoint:         testRegularEndpoint,
   410  				DefaultEndpointTemplate: testEndpointTemplate,
   411  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   412  				UniverseDomain:          testUniverseDomain,
   413  				ClientCertSource:        dummyClientCertSource,
   414  			},
   415  			wantEndpoint: testUniverseDomainEndpoint,
   416  			wantErr:      errUniverseNotSupportedMTLS,
   417  		},
   418  	}
   419  
   420  	for _, tc := range testCases {
   421  		if tc.ds.ClientCertSource != nil {
   422  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
   423  		} else {
   424  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
   425  		}
   426  		_, _, endpoint, err := GetHTTPTransportConfigAndEndpoint(tc.ds)
   427  		if err != nil {
   428  			if err != tc.wantErr {
   429  				t.Fatalf("%s: err: %v", tc.name, err)
   430  			}
   431  		} else {
   432  			if tc.wantEndpoint != endpoint {
   433  				t.Errorf("%s: want endpoint: [%s], got [%s]", tc.name, tc.wantEndpoint, endpoint)
   434  			}
   435  		}
   436  	}
   437  }
   438  
   439  func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
   440  	testCases := []struct {
   441  		name         string
   442  		ds           *DialSettings
   443  		wantEndpoint string
   444  		wantErr      error
   445  	}{
   446  		{
   447  			name: "google default universe (GDU), no client cert",
   448  			ds: &DialSettings{
   449  				DefaultEndpoint:         testRegularEndpoint,
   450  				DefaultEndpointTemplate: testEndpointTemplate,
   451  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   452  			},
   453  			wantEndpoint: testRegularEndpoint,
   454  		},
   455  		{
   456  			name: "google default universe (GDU), no client cert, endpoint",
   457  			ds: &DialSettings{
   458  				DefaultEndpoint:         testRegularEndpoint,
   459  				DefaultEndpointTemplate: testEndpointTemplate,
   460  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   461  				Endpoint:                testOverrideEndpoint,
   462  			},
   463  			wantEndpoint: testOverrideEndpoint,
   464  		},
   465  		{
   466  			name: "google default universe (GDU), client cert",
   467  			ds: &DialSettings{
   468  				DefaultEndpoint:         testRegularEndpoint,
   469  				DefaultEndpointTemplate: testEndpointTemplate,
   470  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   471  				ClientCertSource:        dummyClientCertSource,
   472  			},
   473  			wantEndpoint: testMTLSEndpoint,
   474  		},
   475  		{
   476  			name: "google default universe (GDU), client cert, endpoint",
   477  			ds: &DialSettings{
   478  				DefaultEndpoint:         testRegularEndpoint,
   479  				DefaultEndpointTemplate: testEndpointTemplate,
   480  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   481  				ClientCertSource:        dummyClientCertSource,
   482  				Endpoint:                testOverrideEndpoint,
   483  			},
   484  			wantEndpoint: testOverrideEndpoint,
   485  		},
   486  		{
   487  			name: "UniverseDomain, no client cert",
   488  			ds: &DialSettings{
   489  				DefaultEndpoint:         testRegularEndpoint,
   490  				DefaultEndpointTemplate: testEndpointTemplate,
   491  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   492  				UniverseDomain:          testUniverseDomain,
   493  			},
   494  			wantEndpoint: testUniverseDomainEndpoint,
   495  		},
   496  		{
   497  			name: "UniverseDomain, no client cert, endpoint",
   498  			ds: &DialSettings{
   499  				DefaultEndpoint:         testRegularEndpoint,
   500  				DefaultEndpointTemplate: testEndpointTemplate,
   501  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   502  				UniverseDomain:          testUniverseDomain,
   503  				Endpoint:                testOverrideEndpoint,
   504  			},
   505  			wantEndpoint: testOverrideEndpoint,
   506  		},
   507  		{
   508  			name: "UniverseDomain, client cert",
   509  			ds: &DialSettings{
   510  				DefaultEndpoint:         testRegularEndpoint,
   511  				DefaultEndpointTemplate: testEndpointTemplate,
   512  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   513  				UniverseDomain:          testUniverseDomain,
   514  				ClientCertSource:        dummyClientCertSource,
   515  			},
   516  			wantErr: errUniverseNotSupportedMTLS,
   517  		},
   518  		{
   519  			name: "UniverseDomain, client cert, endpoint",
   520  			ds: &DialSettings{
   521  				DefaultEndpoint:         testRegularEndpoint,
   522  				DefaultEndpointTemplate: testEndpointTemplate,
   523  				DefaultMTLSEndpoint:     testMTLSEndpoint,
   524  				UniverseDomain:          testUniverseDomain,
   525  				ClientCertSource:        dummyClientCertSource,
   526  				Endpoint:                testOverrideEndpoint,
   527  			},
   528  			wantEndpoint: testOverrideEndpoint,
   529  		},
   530  	}
   531  
   532  	for _, tc := range testCases {
   533  		if tc.ds.ClientCertSource != nil {
   534  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
   535  		} else {
   536  			os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
   537  		}
   538  		_, endpoint, err := GetGRPCTransportConfigAndEndpoint(tc.ds)
   539  		if err != nil {
   540  			if err != tc.wantErr {
   541  				t.Fatalf("%s: err: %v", tc.name, err)
   542  			}
   543  		} else {
   544  			if tc.wantEndpoint != endpoint {
   545  				t.Errorf("%s: want endpoint: [%s], got [%s]", tc.name, tc.wantEndpoint, endpoint)
   546  			}
   547  		}
   548  	}
   549  }
   550  

View as plain text