...

Source file src/cloud.google.com/go/cloudsqlconn/internal/cloudsql/instance_test.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/cloudsql

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsql
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"crypto/rsa"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"encoding/pem"
    24  	"errors"
    25  	"testing"
    26  	"time"
    27  
    28  	"cloud.google.com/go/cloudsqlconn/errtype"
    29  	"cloud.google.com/go/cloudsqlconn/instance"
    30  	"cloud.google.com/go/cloudsqlconn/internal/mock"
    31  )
    32  
    33  type nullLogger struct{}
    34  
    35  func (nullLogger) Debugf(context.Context, string, ...interface{}) {}
    36  
    37  // genRSAKey generates an RSA key used for test.
    38  func genRSAKey() *rsa.PrivateKey {
    39  	key, err := rsa.GenerateKey(rand.Reader, 2048)
    40  	if err != nil {
    41  		panic(err) // unexpected, so just panic if it happens
    42  	}
    43  	return key
    44  }
    45  
    46  func testInstanceConnName() instance.ConnName {
    47  	cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
    48  	return cn
    49  }
    50  
    51  // RSAKey is used for test only.
    52  var RSAKey = genRSAKey()
    53  
    54  func TestConnectionInfoDBVersion(t *testing.T) {
    55  	ctx, cancel := context.WithCancel(context.Background())
    56  	defer cancel()
    57  	tests := []string{
    58  		"MYSQL_5_7", "POSTGRES_14", "SQLSERVER_2019_STANDARD", "MYSQL_8_0_18",
    59  	}
    60  	for _, wantEV := range tests {
    61  		inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithEngineVersion(wantEV))
    62  		client, cleanup, err := mock.NewSQLAdminService(
    63  			ctx,
    64  			mock.InstanceGetSuccess(inst, 1),
    65  			mock.CreateEphemeralSuccess(inst, 1),
    66  		)
    67  		if err != nil {
    68  			t.Fatalf("%s", err)
    69  		}
    70  		defer func() {
    71  			if err := cleanup(); err != nil {
    72  				t.Fatalf("%v", err)
    73  			}
    74  		}()
    75  		i := NewRefreshAheadCache(
    76  			testInstanceConnName(), nullLogger{}, client,
    77  			RSAKey, 30*time.Second, nil, "", false,
    78  		)
    79  		if err != nil {
    80  			t.Fatalf("failed to init instance: %v", err)
    81  		}
    82  
    83  		ci, err := i.ConnectionInfo(ctx)
    84  		if err != nil {
    85  			t.Fatalf("failed to retrieve engine version: %v", err)
    86  		}
    87  		if wantEV != ci.DBVersion {
    88  			t.Errorf("ConnectionInfo(%s) failed: want %v, got %v", wantEV, ci, err)
    89  		}
    90  
    91  	}
    92  }
    93  
    94  func TestConnectionInfo(t *testing.T) {
    95  	ctx := context.Background()
    96  	wantAddr := "0.0.0.0"
    97  	inst := mock.NewFakeCSQLInstance(
    98  		"my-project", "my-region", "my-instance", mock.WithPublicIP(wantAddr),
    99  	)
   100  	client, cleanup, err := mock.NewSQLAdminService(
   101  		ctx,
   102  		mock.InstanceGetSuccess(inst, 1),
   103  		mock.CreateEphemeralSuccess(inst, 1),
   104  	)
   105  	if err != nil {
   106  		t.Fatalf("%s", err)
   107  	}
   108  	defer func() {
   109  		if err := cleanup(); err != nil {
   110  			t.Fatalf("%v", err)
   111  		}
   112  	}()
   113  
   114  	i := NewRefreshAheadCache(
   115  		testInstanceConnName(), nullLogger{}, client,
   116  		RSAKey, 30*time.Second, nil, "", false,
   117  	)
   118  
   119  	ci, err := i.ConnectionInfo(ctx)
   120  	if err != nil {
   121  		t.Fatalf("failed to retrieve connect info: %v", err)
   122  	}
   123  
   124  	got, err := ci.Addr(PublicIP)
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	if got != wantAddr {
   129  		t.Fatalf(
   130  			"ConnectInfo returned unexpected IP address, want = %v, got = %v",
   131  			wantAddr, got,
   132  		)
   133  	}
   134  }
   135  
   136  func TestConnectionInfoTLSConfig(t *testing.T) {
   137  	cn := testInstanceConnName()
   138  	i := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
   139  	// Generate a client certificate with the client's public key and signed by
   140  	// the server's private key
   141  	cert, err := i.ClientCert(&RSAKey.PublicKey)
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  	// Now parse the bytes back out as structured data
   146  	// TODO: this should be done in the ClientCert method and not here.
   147  	b, _ := pem.Decode(cert)
   148  	clientCert, err := x509.ParseCertificate(b.Bytes)
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  
   153  	// Now self sign the server's cert
   154  	// TODO: this also should return structured data and handle the PEM
   155  	// encoding elsewhere
   156  	certBytes, err := mock.SelfSign(i.Cert, i.Key)
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	b, _ = pem.Decode(certBytes)
   161  	serverCert, err := x509.ParseCertificate(b.Bytes)
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  
   166  	// Assemble a connection info with the raw and parsed client cert
   167  	// and the self-signed server certificate
   168  	ci := ConnectionInfo{
   169  		ConnectionName: cn,
   170  		ClientCertificate: tls.Certificate{
   171  			Certificate: [][]byte{clientCert.Raw},
   172  			PrivateKey:  RSAKey,
   173  			Leaf:        clientCert,
   174  		},
   175  		ServerCaCert: serverCert,
   176  		DBVersion:    "doesn't matter here",
   177  		Expiration:   clientCert.NotAfter,
   178  	}
   179  
   180  	got := ci.TLSConfig()
   181  	wantServerName := cn.String()
   182  	if got.ServerName != wantServerName {
   183  		t.Fatalf(
   184  			"ConnectInfo return unexpected server name in TLS Config, "+
   185  				"want = %v, got = %v",
   186  			wantServerName, got.ServerName,
   187  		)
   188  	}
   189  
   190  	if got.MinVersion != tls.VersionTLS13 {
   191  		t.Fatalf(
   192  			"want TLS 1.3, got = %v", got.MinVersion,
   193  		)
   194  	}
   195  
   196  	if got.Certificates[0].Leaf != ci.ClientCertificate.Leaf {
   197  		t.Fatal("leaf certificates do not match")
   198  	}
   199  
   200  	verifyPeerCert := got.VerifyPeerCertificate
   201  	err = verifyPeerCert([][]byte{serverCert.Raw}, nil)
   202  	if err != nil {
   203  		t.Fatalf("expected to verify peer cert, got error: %v", err)
   204  	}
   205  
   206  	err = verifyPeerCert(nil, nil)
   207  	var wantErr *errtype.DialError
   208  	if !errors.As(err, &wantErr) {
   209  		t.Fatalf(
   210  			"when verify peer cert fails, want = %T, got = %v", wantErr, err,
   211  		)
   212  	}
   213  
   214  	// Ensure invalid certs result in an error
   215  	err = verifyPeerCert([][]byte{[]byte("not a cert")}, nil)
   216  	if !errors.As(err, &wantErr) {
   217  		t.Fatalf(
   218  			"when verify fails on invalid cert, want = %T, got = %v",
   219  			wantErr, err,
   220  		)
   221  	}
   222  
   223  	// Ensure the common name is verified againsts the expected name
   224  	badCert := mock.GenerateCertWithCommonName(i, "wrong:wrong")
   225  	err = verifyPeerCert([][]byte{badCert}, nil)
   226  	if !errors.As(err, &wantErr) {
   227  		t.Fatalf(
   228  			"when common names mismatch, want = %T, got = %v", wantErr, err,
   229  		)
   230  	}
   231  
   232  	// Verify an unreconigzed authority is rejected
   233  	other := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
   234  	cert, err = mock.SelfSign(other.Cert, other.Key)
   235  	if err != nil {
   236  		t.Fatalf("failed to sign certificate: %v", err)
   237  	}
   238  	b, _ = pem.Decode(cert)
   239  	err = verifyPeerCert([][]byte{b.Bytes}, nil)
   240  	if !errors.As(err, &wantErr) {
   241  		t.Fatalf("when certification fails, want = %T, got = %v", wantErr, err)
   242  	}
   243  }
   244  
   245  func TestConnectInfoAutoIP(t *testing.T) {
   246  	tcs := []struct {
   247  		desc   string
   248  		ips    []mock.FakeCSQLInstanceOption
   249  		wantIP string
   250  	}{
   251  		{
   252  			desc: "when public IP is enabled",
   253  			ips: []mock.FakeCSQLInstanceOption{
   254  				mock.WithPublicIP("8.8.8.8"),
   255  				mock.WithPrivateIP("10.0.0.1"),
   256  			},
   257  			wantIP: "8.8.8.8",
   258  		},
   259  		{
   260  			desc: "when only private IP is enabled",
   261  			ips: []mock.FakeCSQLInstanceOption{
   262  				mock.WithPrivateIP("10.0.0.1"),
   263  			},
   264  			wantIP: "10.0.0.1",
   265  		},
   266  	}
   267  
   268  	for _, tc := range tcs {
   269  		var opts []mock.FakeCSQLInstanceOption
   270  		opts = append(opts, mock.WithNoIPAddrs())
   271  		opts = append(opts, tc.ips...)
   272  		inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", opts...)
   273  		client, cleanup, err := mock.NewSQLAdminService(
   274  			context.Background(),
   275  			mock.InstanceGetSuccess(inst, 1),
   276  			mock.CreateEphemeralSuccess(inst, 1),
   277  		)
   278  		if err != nil {
   279  			t.Fatalf("%s", err)
   280  		}
   281  		defer func() {
   282  			if cErr := cleanup(); cErr != nil {
   283  				t.Fatalf("%v", cErr)
   284  			}
   285  		}()
   286  
   287  		i := NewRefreshAheadCache(
   288  			testInstanceConnName(), nullLogger{}, client,
   289  			RSAKey, 30*time.Second, nil, "", false,
   290  		)
   291  		if err != nil {
   292  			t.Fatalf("failed to create mock instance: %v", err)
   293  		}
   294  
   295  		ci, err := i.ConnectionInfo(context.Background())
   296  		if err != nil {
   297  			t.Fatalf("failed to retrieve connect info: %v", err)
   298  		}
   299  
   300  		got, err := ci.Addr(AutoIP)
   301  		if err != nil {
   302  			t.Fatal(err)
   303  		}
   304  		if got != tc.wantIP {
   305  			t.Fatalf(
   306  				"ConnectInfo returned unexpected IP address, want = %v, got = %v",
   307  				tc.wantIP, got,
   308  			)
   309  		}
   310  	}
   311  }
   312  
   313  func TestClose(t *testing.T) {
   314  	ctx := context.Background()
   315  
   316  	client, cleanup, err := mock.NewSQLAdminService(ctx)
   317  	if err != nil {
   318  		t.Fatalf("%s", err)
   319  	}
   320  	defer cleanup()
   321  
   322  	// Set up an instance and then close it immediately
   323  	i := NewRefreshAheadCache(
   324  		testInstanceConnName(), nullLogger{}, client,
   325  		RSAKey, 30*time.Second, nil, "", false,
   326  	)
   327  	i.Close()
   328  
   329  	_, err = i.ConnectionInfo(ctx)
   330  	if !errors.Is(err, context.Canceled) {
   331  		t.Fatalf("failed to retrieve connect info: %v", err)
   332  	}
   333  }
   334  
   335  func TestRefreshDuration(t *testing.T) {
   336  	now := time.Now()
   337  	tcs := []struct {
   338  		desc   string
   339  		expiry time.Time
   340  		want   time.Duration
   341  	}{
   342  		{
   343  			desc:   "when expiration is greater than 1 hour",
   344  			expiry: now.Add(4 * time.Hour),
   345  			want:   2 * time.Hour,
   346  		},
   347  		{
   348  			desc:   "when expiration is equal to 1 hour",
   349  			expiry: now.Add(time.Hour),
   350  			want:   30 * time.Minute,
   351  		},
   352  		{
   353  			desc:   "when expiration is less than 1 hour, but greater than 4 minutes",
   354  			expiry: now.Add(5 * time.Minute),
   355  			want:   time.Minute,
   356  		},
   357  		{
   358  			desc:   "when expiration is less than 4 minutes",
   359  			expiry: now.Add(3 * time.Minute),
   360  			want:   0,
   361  		},
   362  		{
   363  			desc:   "when expiration is now",
   364  			expiry: now,
   365  			want:   0,
   366  		},
   367  	}
   368  	for _, tc := range tcs {
   369  		t.Run(tc.desc, func(t *testing.T) {
   370  			got := refreshDuration(now, tc.expiry)
   371  			// round to the second to remove millisecond differences
   372  			if got.Round(time.Second) != tc.want {
   373  				t.Fatalf("time until refresh: want = %v, got = %v", tc.want, got)
   374  			}
   375  		})
   376  	}
   377  }
   378  

View as plain text