...

Source file src/cloud.google.com/go/cloudsqlconn/dialer_test.go

Documentation: cloud.google.com/go/cloudsqlconn

     1  // Copyright 2021 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsqlconn
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"os"
    24  	"strings"
    25  	"sync"
    26  	"testing"
    27  	"time"
    28  
    29  	"cloud.google.com/go/cloudsqlconn/errtype"
    30  	"cloud.google.com/go/cloudsqlconn/instance"
    31  	"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
    32  	"cloud.google.com/go/cloudsqlconn/internal/mock"
    33  	"golang.org/x/oauth2"
    34  )
    35  
    36  // testSuccessfulDial uses the provided dialer to dial the specified instance
    37  // and verifies the connection works end to end.
    38  func testSuccessfulDial(
    39  	ctx context.Context, t *testing.T, d *Dialer, icn string, opts ...DialOption,
    40  ) {
    41  	conn, err := d.Dial(ctx, icn, opts...)
    42  	if err != nil {
    43  		t.Fatalf("expected Dial to succeed, but got error: %v", err)
    44  	}
    45  	defer func() { _ = conn.Close() }()
    46  
    47  	data, err := io.ReadAll(conn)
    48  	if err != nil {
    49  		t.Fatalf("expected ReadAll to succeed, got error %v", err)
    50  	}
    51  	if string(data) != "my-instance" {
    52  		t.Fatalf(
    53  			"expected known response from the server, but got %v",
    54  			string(data),
    55  		)
    56  	}
    57  }
    58  
// setupConfig holds all the configuration to use when setting up a dialer.
// It is consumed by setupDialer.
type setupConfig struct {
	// testInstance is the fake instance handed to mock.StartServerProxy.
	testInstance  mock.FakeCSQLInstance
	// skipServer, when true, prevents setupDialer from starting the mock
	// server proxy.
	skipServer    bool
	// skipVerify, when true, makes setupDialer ignore the error returned by
	// the mock SQL Admin service's cleanup function.
	skipVerify    bool
	// reqs are the requests passed to mock.NewSQLAdminService.
	reqs          []*mock.Request
	// dialerOptions, when non-nil, replaces (rather than extends) the default
	// options setupDialer passes to NewDialer.
	dialerOptions []Option
}
    67  
    68  // setupDialer configures a Dialer with an HTTP client configured to point at a
    69  // mock SQL Admin API. Use setupConfig to configure the expected requests.
    70  func setupDialer(t *testing.T, c setupConfig) *Dialer {
    71  	svc, cleanup, err := mock.NewSQLAdminService(
    72  		context.Background(),
    73  		c.reqs...,
    74  	)
    75  	if err != nil {
    76  		t.Fatalf("failed to init SQLAdminService: %v", err)
    77  	}
    78  	stop := func() {}
    79  	if !c.skipServer {
    80  		stop = mock.StartServerProxy(t, c.testInstance)
    81  	}
    82  	t.Cleanup(func() {
    83  		stop()
    84  		err := cleanup()
    85  		if !c.skipVerify && err != nil {
    86  			t.Fatalf("%v", err)
    87  		}
    88  	})
    89  
    90  	opts := []Option{
    91  		WithTokenSource(mock.EmptyTokenSource{}),
    92  		// give refresh plenty of time to complete in slower CI builds
    93  		WithRefreshTimeout(time.Minute),
    94  	}
    95  	if c.dialerOptions != nil {
    96  		opts = c.dialerOptions
    97  	}
    98  
    99  	d, err := NewDialer(context.Background(), opts...)
   100  	if err != nil {
   101  		t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
   102  	}
   103  	d.sqladmin = svc
   104  	return d
   105  }
   106  
   107  func TestDialerCanConnectToInstance(t *testing.T) {
   108  	inst := mock.NewFakeCSQLInstance(
   109  		"my-project", "my-region", "my-instance",
   110  	)
   111  	d := setupDialer(t, setupConfig{
   112  		testInstance: inst,
   113  		reqs: []*mock.Request{
   114  			mock.InstanceGetSuccess(inst, 1),
   115  			mock.CreateEphemeralSuccess(inst, 1),
   116  		},
   117  	})
   118  
   119  	testSuccessfulDial(
   120  		context.Background(), t, d,
   121  		inst.String(),
   122  	)
   123  }
   124  
   125  func TestDialWithAdminAPIErrors(t *testing.T) {
   126  	inst := mock.NewFakeCSQLInstance(
   127  		"my-project", "my-region", "my-instance",
   128  	)
   129  	// API server will respond with 40x's
   130  	d := setupDialer(t, setupConfig{testInstance: inst})
   131  
   132  	_, err := d.Dial(
   133  		context.Background(), inst.String(),
   134  	)
   135  	var wantErr *errtype.RefreshError
   136  	if !errors.As(err, &wantErr) {
   137  		t.Fatalf("when API call fails, want = %T, got = %v", wantErr, err)
   138  	}
   139  }
   140  
   141  func TestDialWithConfigurationErrors(t *testing.T) {
   142  	inst := mock.NewFakeCSQLInstance(
   143  		"my-project", "my-region", "my-instance",
   144  	)
   145  	d := setupDialer(t, setupConfig{
   146  		testInstance: inst,
   147  		reqs: []*mock.Request{
   148  			mock.InstanceGetSuccess(inst, 3),
   149  			mock.CreateEphemeralSuccess(inst, 3),
   150  		},
   151  		skipVerify: true,
   152  		skipServer: true,
   153  	})
   154  
   155  	_, err := d.Dial(
   156  		context.Background(),
   157  		// Try private IP of a public IP-only instance
   158  		inst.String(), WithPrivateIP(),
   159  	)
   160  	if err == nil {
   161  		t.Fatal("when IP type is invalid, want = error, got = nil")
   162  	}
   163  
   164  	_, err = d.Dial(
   165  		context.Background(), inst.String(),
   166  	)
   167  	if err == nil {
   168  		t.Fatal("when server proxy socket is unavailable, want = error, got = nil")
   169  	}
   170  }
   171  
   172  func TestDialWithExpiredCertificate(t *testing.T) {
   173  	inst := mock.NewFakeCSQLInstance(
   174  		"my-project", "my-region", "my-instance",
   175  		// Server certificate is expired
   176  		mock.WithCertExpiry(time.Now().Add(-time.Hour)),
   177  	)
   178  	d := setupDialer(t, setupConfig{
   179  		testInstance: inst,
   180  		reqs: []*mock.Request{
   181  			mock.InstanceGetSuccess(inst, 3),
   182  			mock.CreateEphemeralSuccess(inst, 3),
   183  		},
   184  		skipVerify: true,
   185  		skipServer: true,
   186  	})
   187  
   188  	_, err := d.Dial(context.Background(), inst.String())
   189  	if err == nil {
   190  		t.Fatal("when TLS handshake fails, want = error, got = nil")
   191  	}
   192  }
   193  
   194  func fakeServiceAccount(ud string) []byte {
   195  	sa := `
   196  		"type": "service_account",
   197  		"project_id": "a-project-id",
   198  		"private_key_id": "a-private-key-id",
   199  		"private_key": "a-private-key",
   200  		"client_email": "email@example.com",
   201  		"client_id": "12345",
   202  		"auth_uri": "https://accounts.google.com/o/oauth2/auth",
   203  		"token_uri": "https://oauth2.googleapis.com/token",
   204  		"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
   205  		"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/email%40example.com"
   206  	`
   207  	if ud != "" {
   208  		sa = sa + fmt.Sprintf(`, "universe_domain": "%s"`, ud)
   209  	}
   210  	return []byte(fmt.Sprintf(`{ %s }`, sa))
   211  }
   212  
   213  func TestIAMAuthn(t *testing.T) {
   214  	tcs := []struct {
   215  		desc         string
   216  		opts         Option
   217  		wantIAMAuthN bool
   218  	}{
   219  		{
   220  			desc: "When Credentials are provided with IAM Authn ENABLED",
   221  			opts: WithOptions(
   222  				WithIAMAuthN(),
   223  				WithCredentialsJSON(fakeServiceAccount("")),
   224  			),
   225  			wantIAMAuthN: true,
   226  		},
   227  		{
   228  			desc:         "When Credentials are provided with IAM Authn DISABLED",
   229  			opts:         WithCredentialsJSON(fakeServiceAccount("")),
   230  			wantIAMAuthN: false,
   231  		},
   232  	}
   233  
   234  	for _, tc := range tcs {
   235  		t.Run(tc.desc, func(t *testing.T) {
   236  			d, err := NewDialer(context.Background(), tc.opts)
   237  			if err != nil {
   238  				t.Fatalf("NewDialer failed with error = %v", err)
   239  			}
   240  			if gotIAMAuthN := d.defaultDialConfig.useIAMAuthN; gotIAMAuthN != tc.wantIAMAuthN {
   241  				t.Fatalf("want = %v, got = %v", tc.wantIAMAuthN, gotIAMAuthN)
   242  			}
   243  		})
   244  	}
   245  }
   246  
   247  func TestSQLServerFailsOnIAMAuthN(t *testing.T) {
   248  	inst := mock.NewFakeCSQLInstance("proj", "region", "inst",
   249  		mock.WithEngineVersion("SQLSERVER"),
   250  	)
   251  	d := setupDialer(t, setupConfig{
   252  		testInstance: inst,
   253  		reqs: []*mock.Request{
   254  			mock.InstanceGetSuccess(inst, 1),
   255  			mock.CreateEphemeralSuccess(inst, 1),
   256  		},
   257  		dialerOptions: []Option{
   258  			WithIAMAuthNTokenSources(
   259  				mock.EmptyTokenSource{},
   260  				mock.EmptyTokenSource{},
   261  			), WithIAMAuthN(),
   262  		},
   263  		skipVerify: true,
   264  	})
   265  
   266  	_, err := d.Dial(context.Background(), inst.String())
   267  	if err == nil {
   268  		t.Fatalf("version = %v, want error, got nil", "SQLSERVER")
   269  	}
   270  }
   271  
   272  func TestUniverseDomain(t *testing.T) {
   273  	tcs := []struct {
   274  		desc string
   275  		opts Option
   276  	}{
   277  		{
   278  			desc: "When universe domain matches GDU",
   279  			opts: WithOptions(
   280  				WithUniverseDomain("googleapis.com"),
   281  				WithCredentialsJSON(fakeServiceAccount("")),
   282  			),
   283  		},
   284  		{
   285  			desc: "When TPC universe matches TPC credential domain",
   286  			opts: WithOptions(
   287  				WithUniverseDomain("test-universe.test"),
   288  				WithCredentialsJSON(fakeServiceAccount("test-universe.test")),
   289  			),
   290  		},
   291  	}
   292  
   293  	for _, tc := range tcs {
   294  		t.Run(tc.desc, func(t *testing.T) {
   295  			_, err := NewDialer(context.Background(), tc.opts)
   296  			if err != nil {
   297  				t.Fatalf("NewDialer failed with error = %v", err)
   298  			}
   299  		})
   300  	}
   301  }
   302  
   303  func TestUniverseDomainErrors(t *testing.T) {
   304  	tcs := []struct {
   305  		desc string
   306  		opts Option
   307  	}{
   308  		{
   309  			desc: "When universe domain does not match ADC credentials from GDU",
   310  			opts: WithOptions(WithUniverseDomain("test-universe.test")),
   311  		},
   312  		{
   313  			desc: "When GDU does not match credential domain",
   314  			opts: WithOptions(WithCredentialsJSON(
   315  				fakeServiceAccount("test-universe.test"),
   316  			)),
   317  		},
   318  		{
   319  			desc: "WithUniverseDomain used alongside WithAdminAPIEndpoint",
   320  			opts: WithOptions(
   321  				WithUniverseDomain("googleapis.com"),
   322  				WithAdminAPIEndpoint("https://sqladmin.googleapis.com"),
   323  			),
   324  		},
   325  	}
   326  
   327  	for _, tc := range tcs {
   328  		t.Run(tc.desc, func(t *testing.T) {
   329  			_, err := NewDialer(context.Background(), tc.opts)
   330  			t.Log(err)
   331  			if err == nil {
   332  				t.Fatalf("Wanted universe domain mismatch, want error, got nil")
   333  			}
   334  		})
   335  	}
   336  }
   337  
   338  func TestDialerWithCustomDialFunc(t *testing.T) {
   339  	inst := mock.NewFakeCSQLInstance("proj", "region", "inst",
   340  		mock.WithEngineVersion("SQLSERVER"),
   341  	)
   342  	d := setupDialer(t, setupConfig{
   343  		testInstance: inst,
   344  		reqs: []*mock.Request{
   345  			mock.InstanceGetSuccess(inst, 1),
   346  			mock.CreateEphemeralSuccess(inst, 1),
   347  		},
   348  		dialerOptions: []Option{
   349  			WithTokenSource(mock.EmptyTokenSource{}),
   350  			WithDialFunc(func(context.Context, string, string) (net.Conn, error) {
   351  				return nil, errors.New("sentinel error")
   352  			}),
   353  		},
   354  	})
   355  
   356  	_, err := d.Dial(context.Background(), inst.String())
   357  	if !strings.Contains(err.Error(), "sentinel error") {
   358  		t.Fatalf("want = sentinel error, got = %v", err)
   359  	}
   360  }
   361  
   362  func TestDialerEngineVersion(t *testing.T) {
   363  	tests := []string{
   364  		"MYSQL_5_7", "POSTGRES_14", "SQLSERVER_2019_STANDARD", "MYSQL_8_0_18",
   365  	}
   366  	for _, wantEV := range tests {
   367  		t.Run(wantEV, func(t *testing.T) {
   368  			ctx, cancel := context.WithCancel(context.Background())
   369  			defer cancel()
   370  			inst := mock.NewFakeCSQLInstance(
   371  				"my-project", "my-region", "my-instance",
   372  				mock.WithEngineVersion(wantEV),
   373  			)
   374  			d := setupDialer(t, setupConfig{
   375  				testInstance: inst,
   376  				reqs: []*mock.Request{
   377  					mock.InstanceGetSuccess(inst, 1),
   378  					mock.CreateEphemeralSuccess(inst, 1),
   379  				},
   380  				dialerOptions: []Option{
   381  					WithTokenSource(mock.EmptyTokenSource{}),
   382  				},
   383  			})
   384  
   385  			gotEV, err := d.EngineVersion(ctx, inst.String())
   386  			if err != nil {
   387  				t.Fatalf("failed to retrieve engine version: %v", err)
   388  			}
   389  			if wantEV != gotEV {
   390  				t.Errorf(
   391  					"InstanceEngineVersion(%s) failed: want %v, got %v",
   392  					wantEV, gotEV, err,
   393  				)
   394  			}
   395  		})
   396  	}
   397  }
   398  
   399  func TestDialerUserAgent(t *testing.T) {
   400  	data, err := os.ReadFile("version.txt")
   401  	if err != nil {
   402  		t.Fatalf("failed to read version.txt: %v", err)
   403  	}
   404  	ver := strings.TrimSpace(string(data))
   405  	want := "cloud-sql-go-connector/" + ver
   406  	if want != userAgent {
   407  		t.Errorf("embed version mismatched: want %q, got %q", want, userAgent)
   408  	}
   409  }
   410  
   411  func TestWarmup(t *testing.T) {
   412  	ctx, cancel := context.WithCancel(context.Background())
   413  	defer cancel()
   414  
   415  	inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
   416  	tests := []struct {
   417  		desc          string
   418  		warmupOpts    []DialOption
   419  		dialOpts      []DialOption
   420  		expectedCalls []*mock.Request
   421  	}{
   422  		{
   423  			desc:       "warmup and dial are the same",
   424  			warmupOpts: []DialOption{WithDialIAMAuthN(true)},
   425  			dialOpts:   []DialOption{WithDialIAMAuthN(true)},
   426  			expectedCalls: []*mock.Request{
   427  				mock.InstanceGetSuccess(inst, 1),
   428  				mock.CreateEphemeralSuccess(inst, 1),
   429  			},
   430  		},
   431  		{
   432  			desc:       "warmup and dial are different",
   433  			warmupOpts: []DialOption{WithDialIAMAuthN(true)},
   434  			dialOpts:   []DialOption{WithDialIAMAuthN(false)},
   435  			expectedCalls: []*mock.Request{
   436  				mock.InstanceGetSuccess(inst, 2),
   437  				mock.CreateEphemeralSuccess(inst, 2),
   438  			},
   439  		},
   440  		{
   441  			desc:       "warmup and default dial are different",
   442  			warmupOpts: []DialOption{WithDialIAMAuthN(true)},
   443  			dialOpts:   []DialOption{},
   444  			expectedCalls: []*mock.Request{
   445  				mock.InstanceGetSuccess(inst, 2),
   446  				mock.CreateEphemeralSuccess(inst, 2),
   447  			},
   448  		},
   449  	}
   450  
   451  	for _, test := range tests {
   452  		t.Run(test.desc, func(t *testing.T) {
   453  			d := setupDialer(t, setupConfig{
   454  				testInstance: inst,
   455  				reqs:         test.expectedCalls,
   456  			})
   457  
   458  			// Warmup once with the "default" options
   459  			err := d.Warmup(ctx, inst.String(), test.warmupOpts...)
   460  			if err != nil {
   461  				t.Fatalf("Warmup failed: %v", err)
   462  			}
   463  			// Call EngineVersion to make sure we block until both API calls
   464  			// are completed.
   465  			_, err = d.EngineVersion(ctx, inst.String())
   466  			if err != nil {
   467  				t.Fatalf("Warmup failed: %v", err)
   468  			}
   469  			// Dial once with the "dial" options
   470  			testSuccessfulDial(
   471  				ctx, t, d,
   472  				inst.String(),
   473  				test.dialOpts...,
   474  			)
   475  		})
   476  	}
   477  }
   478  
   479  func TestDialDialerOptsConflicts(t *testing.T) {
   480  	ctx, cancel := context.WithCancel(context.Background())
   481  	defer cancel()
   482  
   483  	inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
   484  	tests := []struct {
   485  		desc          string
   486  		dialerOpts    []Option
   487  		dialOpts      []DialOption
   488  		expectedCalls []*mock.Request
   489  	}{
   490  		{
   491  			desc:       "dialer opts set and dial uses default",
   492  			dialerOpts: []Option{WithIAMAuthN()},
   493  			dialOpts:   []DialOption{},
   494  			expectedCalls: []*mock.Request{
   495  				mock.InstanceGetSuccess(inst, 1),
   496  				mock.CreateEphemeralSuccess(inst, 1),
   497  			},
   498  		},
   499  		{
   500  			desc:       "dialer and dial opts are the same",
   501  			dialerOpts: []Option{WithIAMAuthN()},
   502  			dialOpts:   []DialOption{WithDialIAMAuthN(true)},
   503  			expectedCalls: []*mock.Request{
   504  				mock.InstanceGetSuccess(inst, 1),
   505  				mock.CreateEphemeralSuccess(inst, 1),
   506  			},
   507  		},
   508  		{
   509  			desc:       "dialer and dial opts are different",
   510  			dialerOpts: []Option{WithIAMAuthN()},
   511  			dialOpts:   []DialOption{WithDialIAMAuthN(false)},
   512  			expectedCalls: []*mock.Request{
   513  				mock.InstanceGetSuccess(inst, 2),
   514  				mock.CreateEphemeralSuccess(inst, 2),
   515  			},
   516  		},
   517  	}
   518  
   519  	for _, tc := range tests {
   520  		t.Run(tc.desc, func(t *testing.T) {
   521  			d := setupDialer(t, setupConfig{
   522  				testInstance: inst,
   523  				reqs:         tc.expectedCalls,
   524  				dialerOptions: append(
   525  					tc.dialerOpts,
   526  					WithIAMAuthNTokenSources(
   527  						mock.EmptyTokenSource{}, mock.EmptyTokenSource{},
   528  					),
   529  				),
   530  			})
   531  
   532  			// Dial once with the "default" options
   533  			testSuccessfulDial(ctx, t, d, inst.String())
   534  
   535  			// Dial once with the "dial" options
   536  			testSuccessfulDial(ctx, t, d, inst.String(), tc.dialOpts...)
   537  		})
   538  	}
   539  }
   540  
   541  func TestTokenSourceWithIAMAuthN(t *testing.T) {
   542  	ts := oauth2.StaticTokenSource(&oauth2.Token{})
   543  	tcs := []struct {
   544  		desc    string
   545  		opts    []Option
   546  		wantErr bool
   547  	}{
   548  		{
   549  			desc:    "when token source is set with IAM AuthN",
   550  			opts:    []Option{WithTokenSource(ts), WithIAMAuthN()},
   551  			wantErr: true,
   552  		},
   553  		{
   554  			desc:    "when IAM AuthN token source is set without IAM AuthN",
   555  			opts:    []Option{WithIAMAuthNTokenSources(ts, ts)},
   556  			wantErr: true,
   557  		},
   558  	}
   559  	for _, tc := range tcs {
   560  		t.Run(tc.desc, func(t *testing.T) {
   561  			_, err := NewDialer(context.Background(), tc.opts...)
   562  			gotErr := err != nil
   563  			if tc.wantErr != gotErr {
   564  				t.Fatalf("err: want = %v, got = %v", tc.wantErr, gotErr)
   565  			}
   566  		})
   567  	}
   568  }
   569  
   570  func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
   571  	// When a dialer attempts to retrieve connection info for a
   572  	// non-existent instance, it should delete the instance from
   573  	// the cache and ensure no background refresh happens (which would be
   574  	// wasted cycles).
   575  	d, err := NewDialer(
   576  		context.Background(),
   577  		WithTokenSource(mock.EmptyTokenSource{}),
   578  	)
   579  	if err != nil {
   580  		t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
   581  	}
   582  
   583  	// Populate instance map with connection info cache that will always fail
   584  	// This allows the test to verify the error case path invoking close.
   585  	badInstanceConnectionName := "doesntexist:us-central1:doesntexist"
   586  	badCN, _ := instance.ParseConnName(badInstanceConnectionName)
   587  	spy := &spyConnectionInfoCache{
   588  		connectInfoCalls: []struct {
   589  			info cloudsql.ConnectionInfo
   590  			err  error
   591  		}{{
   592  			err: errors.New("connect info failed"),
   593  		}},
   594  	}
   595  	d.cache[badCN] = monitoredCache{connectionInfoCache: spy}
   596  
   597  	_, err = d.Dial(context.Background(), badInstanceConnectionName)
   598  	if err == nil {
   599  		t.Fatal("expected Dial to return error")
   600  	}
   601  
   602  	// Verify that the connection info cache was closed (to prevent
   603  	// further failed refresh operations)
   604  	if got, want := spy.CloseWasCalled(), true; got != want {
   605  		t.Fatal("Close was not called")
   606  	}
   607  
   608  	// Now verify that bad connection name has been deleted from map.
   609  	d.lock.RLock()
   610  	_, ok := d.cache[badCN]
   611  	d.lock.RUnlock()
   612  	if ok {
   613  		t.Fatal("bad instance was not removed from the cache")
   614  	}
   615  }
   616  
   617  func TestDialRefreshesExpiredCertificates(t *testing.T) {
   618  	d, err := NewDialer(context.Background(),
   619  		WithTokenSource(mock.EmptyTokenSource{}),
   620  	)
   621  	if err != nil {
   622  		t.Fatalf("expected NewDialer to succeed, but got error: %v", err)
   623  	}
   624  
   625  	sentinel := errors.New("connect info failed")
   626  	icn := "project:region:instance"
   627  	cn, _ := instance.ParseConnName(icn)
   628  	spy := &spyConnectionInfoCache{
   629  		connectInfoCalls: []struct {
   630  			info cloudsql.ConnectionInfo
   631  			err  error
   632  		}{
   633  			// First call returns expired certificate
   634  			{
   635  				// Certificate expired 10 hours ago.
   636  				info: cloudsql.ConnectionInfo{
   637  					Expiration: time.Now().Add(-10 * time.Hour),
   638  				},
   639  			},
   640  			// Second call errors to validate error path
   641  			{
   642  				err: sentinel,
   643  			},
   644  		},
   645  	}
   646  	d.cache[cn] = monitoredCache{connectionInfoCache: spy}
   647  
   648  	_, err = d.Dial(context.Background(), icn)
   649  	if !errors.Is(err, sentinel) {
   650  		t.Fatalf("expected Dial to return sentinel error, instead got = %v", err)
   651  	}
   652  
   653  	// Verify that the cache was refreshed
   654  	if got, want := spy.ForceRefreshWasCalled(), true; got != want {
   655  		t.Fatal("ForceRefresh was not called")
   656  	}
   657  
   658  	// Verify that the connection info cache was closed (to prevent
   659  	// further failed refresh operations)
   660  	if got, want := spy.CloseWasCalled(), true; got != want {
   661  		t.Fatal("Close was not called")
   662  	}
   663  
   664  	// Now verify that bad connection name has been deleted from map.
   665  	d.lock.RLock()
   666  	_, ok := d.cache[cn]
   667  	d.lock.RUnlock()
   668  	if ok {
   669  		t.Fatal("bad instance was not removed from the cache")
   670  	}
   671  
   672  }
   673  
// spyConnectionInfoCache is a test double for connectionInfoCache. It replays
// a scripted sequence of ConnectionInfo results and records whether
// ForceRefresh and Close were called.
type spyConnectionInfoCache struct {
	// mu is held by every method that reads or writes the fields below.
	mu               sync.Mutex
	// connectInfoIndex is the position in connectInfoCalls that the next
	// ConnectionInfo call returns.
	connectInfoIndex int
	// connectInfoCalls is the scripted sequence of results, one entry per
	// expected ConnectionInfo call.
	connectInfoCalls []struct {
		info cloudsql.ConnectionInfo
		err  error
	}
	// closeWasCalled is set to true by Close.
	closeWasCalled        bool
	// forceRefreshWasCalled is set to true by ForceRefresh.
	forceRefreshWasCalled bool
	// embed interface to avoid having to implement irrelevant methods
	connectionInfoCache
}
   686  
   687  func (s *spyConnectionInfoCache) ConnectionInfo(
   688  	context.Context,
   689  ) (cloudsql.ConnectionInfo, error) {
   690  	s.mu.Lock()
   691  	defer s.mu.Unlock()
   692  	res := s.connectInfoCalls[s.connectInfoIndex]
   693  	s.connectInfoIndex++
   694  	return res.info, res.err
   695  }
   696  
   697  func (s *spyConnectionInfoCache) ForceRefresh() {
   698  	s.mu.Lock()
   699  	defer s.mu.Unlock()
   700  	s.forceRefreshWasCalled = true
   701  }
   702  
// UpdateRefresh is a no-op; its argument is ignored.
func (s *spyConnectionInfoCache) UpdateRefresh(*bool) {}
   704  
   705  func (s *spyConnectionInfoCache) Close() error {
   706  	s.mu.Lock()
   707  	defer s.mu.Unlock()
   708  	s.closeWasCalled = true
   709  	return nil
   710  }
   711  
   712  func (s *spyConnectionInfoCache) CloseWasCalled() bool {
   713  	s.mu.Lock()
   714  	defer s.mu.Unlock()
   715  	return s.closeWasCalled
   716  }
   717  
   718  func (s *spyConnectionInfoCache) ForceRefreshWasCalled() bool {
   719  	s.mu.Lock()
   720  	defer s.mu.Unlock()
   721  	return s.forceRefreshWasCalled
   722  }
   723  
   724  func TestDialerSupportsOneOffDialFunction(t *testing.T) {
   725  	ctx := context.Background()
   726  	inst := mock.NewFakeCSQLInstance("p", "r", "i")
   727  	svc, cleanup, err := mock.NewSQLAdminService(
   728  		context.Background(),
   729  		mock.InstanceGetSuccess(inst, 1),
   730  		mock.CreateEphemeralSuccess(inst, 1),
   731  	)
   732  	if err != nil {
   733  		t.Fatalf("failed to init SQLAdminService: %v", err)
   734  	}
   735  	d, err := NewDialer(ctx, WithTokenSource(mock.EmptyTokenSource{}))
   736  	if err != nil {
   737  		t.Fatal(err)
   738  	}
   739  	d.sqladmin = svc
   740  	defer func() {
   741  		if err := d.Close(); err != nil {
   742  			t.Log(err)
   743  		}
   744  		_ = cleanup()
   745  	}()
   746  
   747  	sentinelErr := errors.New("dial func was called")
   748  	f := func(context.Context, string, string) (net.Conn, error) {
   749  		return nil, sentinelErr
   750  	}
   751  
   752  	if _, err := d.Dial(ctx, "p:r:i", WithOneOffDialFunc(f)); !errors.Is(err, sentinelErr) {
   753  		t.Fatal("one-off dial func was not called")
   754  	}
   755  }
   756  
   757  func TestDialerCloseReportsFriendlyError(t *testing.T) {
   758  	d, err := NewDialer(
   759  		context.Background(),
   760  		WithTokenSource(mock.EmptyTokenSource{}),
   761  	)
   762  	if err != nil {
   763  		t.Fatal(err)
   764  	}
   765  	_ = d.Close()
   766  
   767  	_, err = d.Dial(context.Background(), "p:r:i")
   768  	if !errors.Is(err, ErrDialerClosed) {
   769  		t.Fatalf("want = %v, got = %v", ErrDialerClosed, err)
   770  	}
   771  
   772  	// Ensure multiple calls to close don't panic
   773  	_ = d.Close()
   774  
   775  	_, err = d.Dial(context.Background(), "p:r:i")
   776  	if !errors.Is(err, ErrDialerClosed) {
   777  		t.Fatalf("want = %v, got = %v", ErrDialerClosed, err)
   778  	}
   779  }
   780  
   781  func TestDialerInitializesLazyCache(t *testing.T) {
   782  	cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
   783  	inst := mock.NewFakeCSQLInstance(
   784  		cn.Project(), cn.Region(), cn.Name(),
   785  	)
   786  	d := setupDialer(t, setupConfig{
   787  		testInstance: inst,
   788  		reqs: []*mock.Request{
   789  			mock.InstanceGetSuccess(inst, 1),
   790  			mock.CreateEphemeralSuccess(inst, 1),
   791  		},
   792  		dialerOptions: []Option{
   793  			WithTokenSource(mock.EmptyTokenSource{}),
   794  			WithLazyRefresh(),
   795  		},
   796  	})
   797  
   798  	// Initialize the connection info cache
   799  	_, err := d.Dial(context.Background(), inst.String())
   800  	if err != nil {
   801  		t.Fatal(err)
   802  	}
   803  
   804  	c, ok := d.cache[cn]
   805  	if !ok {
   806  		t.Fatal("cache was not populated")
   807  	}
   808  	switch tt := c.connectionInfoCache.(type) {
   809  	case *cloudsql.LazyRefreshCache:
   810  		// Pass -- the cache was initialized with the correct type
   811  	default:
   812  		t.Fatalf("dialer was initialized with non-lazy type: %T", tt)
   813  	}
   814  }
   815  

View as plain text