...

Source file src/cloud.google.com/go/cloudsqlconn/internal/cloudsql/refresh_test.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/cloudsql

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsql
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/rsa"
    21  	"crypto/x509"
    22  	"encoding/pem"
    23  	"errors"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"cloud.google.com/go/cloudsqlconn/errtype"
    29  	"cloud.google.com/go/cloudsqlconn/internal/mock"
    30  	"golang.org/x/oauth2"
    31  )
    32  
    33  const testDialerID = "some-dialer-id"
    34  
    35  func TestRefresh(t *testing.T) {
    36  	wantPublicIP := "127.0.0.1"
    37  	wantPrivateIP := "10.0.0.1"
    38  	wantPSC := "abcde.12345.us-central1.sql.goog"
    39  	wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
    40  	cn := testInstanceConnName()
    41  	inst := mock.NewFakeCSQLInstance(
    42  		cn.Project(), cn.Region(), cn.Name(),
    43  		mock.WithPublicIP(wantPublicIP),
    44  		mock.WithPrivateIP(wantPrivateIP),
    45  		mock.WithPSC(wantPSC),
    46  		mock.WithCertExpiry(wantExpiry),
    47  	)
    48  	client, cleanup, err := mock.NewSQLAdminService(
    49  		context.Background(),
    50  		mock.InstanceGetSuccess(inst, 1),
    51  		mock.CreateEphemeralSuccess(inst, 1),
    52  	)
    53  	if err != nil {
    54  		t.Fatalf("failed to create test SQL admin service: %s", err)
    55  	}
    56  	defer func() {
    57  		if err := cleanup(); err != nil {
    58  			t.Fatalf("%v", err)
    59  		}
    60  	}()
    61  
    62  	r := newRefresher(nullLogger{}, client, nil, testDialerID)
    63  	rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false)
    64  	if err != nil {
    65  		t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
    66  	}
    67  
    68  	gotIP, ok := rr.addrs[PublicIP]
    69  	if !ok {
    70  		t.Fatal("metadata IP addresses did not include public address")
    71  	}
    72  	if wantPublicIP != gotIP {
    73  		t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPublicIP, gotIP)
    74  	}
    75  	gotIP, ok = rr.addrs[PrivateIP]
    76  	if !ok {
    77  		t.Fatal("metadata IP addresses did not include private address")
    78  	}
    79  	if wantPrivateIP != gotIP {
    80  		t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPrivateIP, gotIP)
    81  	}
    82  	gotPSC, ok := rr.addrs[PSC]
    83  	if !ok {
    84  		t.Fatal("metadata IP addresses did not include PSC endpoint")
    85  	}
    86  	if wantPSC != gotPSC {
    87  		t.Fatalf("metadata IP mismatch, want = %v. got = %v", wantPSC, gotPSC)
    88  	}
    89  	if cn != rr.ConnectionName {
    90  		t.Fatalf(
    91  			"connection name mismatch, want = %v, got = %v",
    92  			wantExpiry, rr.Expiration,
    93  		)
    94  	}
    95  	if wantExpiry != rr.Expiration {
    96  		t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiration)
    97  	}
    98  }
    99  func TestRefreshRetries50xResponses(t *testing.T) {
   100  	cn := testInstanceConnName()
   101  	inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
   102  		mock.WithEngineVersion("WANTED_VERSION"),
   103  	)
   104  	client, cleanup, err := mock.NewSQLAdminService(
   105  		context.Background(),
   106  		// First a 500, then a 200 response
   107  		mock.InstanceGet500(inst, 1),
   108  		mock.InstanceGetSuccess(inst, 1),
   109  		// First a 500, then a 200 response
   110  		mock.CreateEphemeral500(inst, 1),
   111  		mock.CreateEphemeralSuccess(inst, 1),
   112  	)
   113  	if err != nil {
   114  		t.Fatalf("failed to create test SQL admin service: %s", err)
   115  	}
   116  	defer func() {
   117  		if err := cleanup(); err != nil {
   118  			t.Fatalf("%v", err)
   119  		}
   120  	}()
   121  
   122  	r := newRefresher(nullLogger{}, client, nil, testDialerID)
   123  	rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false)
   124  	if err != nil {
   125  		t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
   126  	}
   127  	if rr.DBVersion != "WANTED_VERSION" {
   128  		t.Fatalf("DB version did not match expected, got = %v, want = %v",
   129  			rr.DBVersion, "WANTED_VERSION",
   130  		)
   131  	}
   132  }
   133  
   134  func TestRefreshFailsFast(t *testing.T) {
   135  	cn := testInstanceConnName()
   136  	inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
   137  	client, cleanup, err := mock.NewSQLAdminService(
   138  		context.Background(),
   139  		mock.InstanceGetSuccess(inst, 1),
   140  		mock.CreateEphemeralSuccess(inst, 1),
   141  	)
   142  	if err != nil {
   143  		t.Fatalf("failed to create test SQL admin service: %s", err)
   144  	}
   145  	defer cleanup()
   146  
   147  	r := newRefresher(nullLogger{}, client, nil, testDialerID)
   148  	_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
   149  	if err != nil {
   150  		t.Fatalf("expected no error, got = %v", err)
   151  	}
   152  
   153  	ctx, cancel := context.WithCancel(context.Background())
   154  	cancel()
   155  	// context is canceled
   156  	_, err = r.ConnectionInfo(ctx, cn, RSAKey, false)
   157  	if !errors.Is(err, context.Canceled) {
   158  		t.Fatalf("expected context.Canceled error, got = %v", err)
   159  	}
   160  }
   161  
// tokenResp is one canned reply for fakeTokenSource: the token and error that
// a single Token call hands back.
type tokenResp struct {
	tok *oauth2.Token
	err error
}
   166  
// fakeTokenSource is a scripted stand-in for an oauth2 token source: each call
// to Token returns the next entry of responses, in order.
type fakeTokenSource struct {
	responses []tokenResp // canned replies, consumed front to back
	mu        sync.Mutex  // held by Token and count while they access ct
	ct        int         // number of Token calls so far; index of the next reply
}
   172  
   173  func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
   174  	f.mu.Lock()
   175  	defer f.mu.Unlock()
   176  	resp := f.responses[f.ct]
   177  	f.ct++
   178  	return resp.tok, resp.err
   179  }
   180  
   181  func (f *fakeTokenSource) count() int {
   182  	f.mu.Lock()
   183  	defer f.mu.Unlock()
   184  	return f.ct
   185  }
   186  
   187  func TestRefreshAdjustsCertExpiry(t *testing.T) {
   188  	certExpiry := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
   189  	t1 := time.Now().Add(59 * time.Minute).UTC().Truncate(time.Second)
   190  	t2 := time.Now().Add(61 * time.Minute).UTC().Truncate(time.Second)
   191  	tcs := []struct {
   192  		desc       string
   193  		resps      []tokenResp
   194  		wantExpiry time.Time
   195  	}{
   196  		{
   197  			desc: "when the token's expiration comes BEFORE the cert",
   198  			resps: []tokenResp{
   199  				{tok: &oauth2.Token{}},
   200  				{tok: &oauth2.Token{Expiry: t1}},
   201  			},
   202  			wantExpiry: t1,
   203  		},
   204  		{
   205  			desc: "when the token's expiration comes AFTER the cert",
   206  			resps: []tokenResp{
   207  				{tok: &oauth2.Token{}},
   208  				{tok: &oauth2.Token{Expiry: t2}},
   209  			},
   210  			wantExpiry: certExpiry,
   211  		},
   212  	}
   213  	cn := testInstanceConnName()
   214  	inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance",
   215  		mock.WithCertExpiry(certExpiry))
   216  	client, cleanup, err := mock.NewSQLAdminService(
   217  		context.Background(),
   218  		mock.InstanceGetSuccess(inst, 2),
   219  		mock.CreateEphemeralSuccess(inst, 2),
   220  	)
   221  	if err != nil {
   222  		t.Fatalf("failed to create test SQL admin service: %s", err)
   223  	}
   224  	defer cleanup()
   225  
   226  	for _, tc := range tcs {
   227  		t.Run(tc.desc, func(t *testing.T) {
   228  			ts := &fakeTokenSource{responses: tc.resps}
   229  			r := newRefresher(nullLogger{}, client, ts, testDialerID)
   230  			rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
   231  			if err != nil {
   232  				t.Fatalf("want no error, got = %v", err)
   233  			}
   234  			if tc.wantExpiry != rr.Expiration {
   235  				t.Fatalf("want = %v, got = %v", tc.wantExpiry, rr.Expiration)
   236  			}
   237  		})
   238  	}
   239  }
   240  
   241  func TestRefreshWithIAMAuthErrors(t *testing.T) {
   242  	tcs := []struct {
   243  		desc      string
   244  		resps     []tokenResp
   245  		wantCount int
   246  	}{
   247  		{
   248  			desc:      "when fetching a token fails",
   249  			resps:     []tokenResp{{tok: nil, err: errors.New("fetch failed")}},
   250  			wantCount: 1,
   251  		},
   252  		{
   253  			desc: "when refreshing a token fails",
   254  			resps: []tokenResp{
   255  				{tok: &oauth2.Token{}, err: nil},
   256  				{tok: nil, err: errors.New("refresh failed")},
   257  			},
   258  			wantCount: 2,
   259  		},
   260  	}
   261  	cn := testInstanceConnName()
   262  	inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
   263  	client, cleanup, err := mock.NewSQLAdminService(
   264  		context.Background(),
   265  		mock.InstanceGetSuccess(inst, 2),
   266  	)
   267  	if err != nil {
   268  		t.Fatalf("failed to create test SQL admin service: %s", err)
   269  	}
   270  	defer cleanup()
   271  
   272  	for _, tc := range tcs {
   273  		t.Run(tc.desc, func(t *testing.T) {
   274  			ts := &fakeTokenSource{responses: tc.resps}
   275  			r := newRefresher(nullLogger{}, client, ts, testDialerID)
   276  			_, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
   277  			if err == nil {
   278  				t.Fatalf("expected get failed error, got = %v", err)
   279  			}
   280  			if count := ts.count(); count != tc.wantCount {
   281  				t.Fatalf("expected fake token source to be called %v time, got = %v", tc.wantCount, count)
   282  			}
   283  		})
   284  	}
   285  }
   286  
   287  func TestRefreshMetadataConfigError(t *testing.T) {
   288  	cn := testInstanceConnName()
   289  
   290  	testCases := []struct {
   291  		req     *mock.Request
   292  		wantErr *errtype.ConfigError
   293  		desc    string
   294  	}{
   295  		{
   296  			req: mock.InstanceGetSuccess(
   297  				mock.NewFakeCSQLInstance(
   298  					cn.Project(), cn.Region(), cn.Name(),
   299  					mock.WithRegion("my-region"),
   300  					mock.WithFirstGenBackend(),
   301  				), 1),
   302  			wantErr: &errtype.ConfigError{},
   303  			desc:    "When the instance isn't Second generation",
   304  		},
   305  		{
   306  			req: mock.InstanceGetSuccess(
   307  				mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
   308  					mock.WithRegion("some-other-region")), 1),
   309  			wantErr: &errtype.ConfigError{},
   310  			desc:    "When the region does not match",
   311  		},
   312  		{
   313  			req: mock.InstanceGetSuccess(
   314  				mock.NewFakeCSQLInstance(
   315  					cn.Project(), cn.Region(), cn.Name(),
   316  					mock.WithRegion("my-region"),
   317  					mock.WithNoIPAddrs(),
   318  				), 1),
   319  			wantErr: &errtype.ConfigError{},
   320  			desc:    "When the instance has no supported IP addresses",
   321  		},
   322  	}
   323  
   324  	for i, tc := range testCases {
   325  		t.Run(tc.desc, func(t *testing.T) {
   326  			client, cleanup, err := mock.NewSQLAdminService(
   327  				context.Background(),
   328  				tc.req,
   329  			)
   330  			if err != nil {
   331  				t.Fatalf("failed to create test SQL admin service: %s", err)
   332  			}
   333  			defer cleanup()
   334  
   335  			r := newRefresher(nullLogger{}, client, nil, testDialerID)
   336  			_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
   337  			if !errors.As(err, &tc.wantErr) {
   338  				t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
   339  			}
   340  		})
   341  	}
   342  }
   343  
   344  func TestRefreshMetadataRefreshError(t *testing.T) {
   345  	cn := testInstanceConnName()
   346  
   347  	testCases := []struct {
   348  		req     *mock.Request
   349  		wantErr *errtype.RefreshError
   350  		desc    string
   351  	}{
   352  		{
   353  			req: mock.CreateEphemeralSuccess(
   354  				mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name()), 1),
   355  			wantErr: &errtype.RefreshError{},
   356  			desc:    "When the Metadata call fails",
   357  		},
   358  		{
   359  			req: mock.InstanceGetSuccess(
   360  				mock.NewFakeCSQLInstance(
   361  					cn.Project(), cn.Region(), cn.Name(),
   362  					mock.WithRegion("my-region"),
   363  					mock.WithCertSigner(func(_ *x509.Certificate, _ *rsa.PrivateKey) ([]byte, error) {
   364  						return nil, nil
   365  					}),
   366  				), 1),
   367  			wantErr: &errtype.RefreshError{},
   368  			desc:    "When the server cert does not decode",
   369  		},
   370  		{
   371  			req: mock.InstanceGetSuccess(
   372  				mock.NewFakeCSQLInstance(
   373  					cn.Project(), cn.Region(), cn.Name(),
   374  					mock.WithRegion("my-region"),
   375  					mock.WithCertSigner(func(_ *x509.Certificate, _ *rsa.PrivateKey) ([]byte, error) {
   376  						certPEM := &bytes.Buffer{}
   377  						pem.Encode(certPEM, &pem.Block{
   378  							Type:  "CERTIFICATE",
   379  							Bytes: []byte("hello"), // woops no cert
   380  						})
   381  						return certPEM.Bytes(), nil
   382  					}),
   383  				), 1),
   384  			wantErr: &errtype.RefreshError{},
   385  			desc:    "When the cert is not a valid X.509 cert",
   386  		},
   387  	}
   388  
   389  	for i, tc := range testCases {
   390  		t.Run(tc.desc, func(t *testing.T) {
   391  			client, cleanup, err := mock.NewSQLAdminService(
   392  				context.Background(),
   393  				tc.req,
   394  			)
   395  			if err != nil {
   396  				t.Fatalf("failed to create test SQL admin service: %s", err)
   397  			}
   398  			defer cleanup()
   399  
   400  			r := newRefresher(nullLogger{}, client, nil, testDialerID)
   401  			_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
   402  			if !errors.As(err, &tc.wantErr) {
   403  				t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
   404  			}
   405  		})
   406  	}
   407  }
   408  
   409  func TestRefreshWithFailedEphemeralCertCall(t *testing.T) {
   410  	cn := testInstanceConnName()
   411  	inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
   412  
   413  	testCases := []struct {
   414  		reqs    []*mock.Request
   415  		wantErr *errtype.RefreshError
   416  		desc    string
   417  	}{
   418  		{
   419  			reqs:    []*mock.Request{mock.InstanceGetSuccess(inst, 1)}, // no ephemeral cert call registered
   420  			wantErr: &errtype.RefreshError{},
   421  			desc:    "When the CreateEphemeralCert call fails",
   422  		},
   423  		{
   424  			reqs: []*mock.Request{mock.InstanceGetSuccess(inst, 1),
   425  				mock.CreateEphemeralSuccess(
   426  					mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
   427  						mock.WithClientCertSigner(
   428  							func(*x509.Certificate, *rsa.PrivateKey, *rsa.PublicKey) ([]byte, error) {
   429  								return nil, nil
   430  							}),
   431  					), 1),
   432  			},
   433  			wantErr: &errtype.RefreshError{},
   434  			desc:    "When decoding the cert fails", // SQL Admin API fail
   435  		},
   436  		{
   437  			reqs: []*mock.Request{mock.InstanceGetSuccess(inst, 1),
   438  				mock.CreateEphemeralSuccess(
   439  					mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
   440  						mock.WithClientCertSigner(
   441  							func(*x509.Certificate, *rsa.PrivateKey, *rsa.PublicKey) ([]byte, error) {
   442  								certPEM := &bytes.Buffer{}
   443  								pem.Encode(certPEM, &pem.Block{
   444  									Type:  "CERTIFICATE",
   445  									Bytes: []byte("hello"), // woops no cert
   446  								})
   447  								return certPEM.Bytes(), nil
   448  							}),
   449  					), 1),
   450  			},
   451  			wantErr: &errtype.RefreshError{},
   452  			desc:    "When parsing the cert fails", // SQL Admin API fail
   453  		},
   454  	}
   455  	for i, tc := range testCases {
   456  		client, cleanup, err := mock.NewSQLAdminService(
   457  			context.Background(),
   458  			tc.reqs...,
   459  		)
   460  		if err != nil {
   461  			t.Fatalf("failed to create test SQL admin service: %s", err)
   462  		}
   463  		defer cleanup()
   464  
   465  		r := newRefresher(nullLogger{}, client, nil, testDialerID)
   466  		_, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
   467  
   468  		if !errors.As(err, &tc.wantErr) {
   469  			t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
   470  		}
   471  	}
   472  }
   473  

View as plain text