1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cloudsql
16
17 import (
18 "context"
19 "crypto/rand"
20 "crypto/rsa"
21 "crypto/tls"
22 "crypto/x509"
23 "encoding/pem"
24 "errors"
25 "testing"
26 "time"
27
28 "cloud.google.com/go/cloudsqlconn/errtype"
29 "cloud.google.com/go/cloudsqlconn/instance"
30 "cloud.google.com/go/cloudsqlconn/internal/mock"
31 )
32
// nullLogger is a stateless logger for tests: its Debugf method has an empty
// body, so every message handed to it is dropped.
type nullLogger struct{}

// Debugf accepts a context, a format string and arguments and does nothing
// with them.
func (nullLogger) Debugf(_ context.Context, _ string, _ ...interface{}) {}
36
37
// genRSAKey creates a fresh 2048-bit RSA private key using crypto/rand as
// the entropy source. It panics when key generation fails, which is
// acceptable because it is only used to set up test fixtures.
func genRSAKey() *rsa.PrivateKey {
	const bits = 2048
	priv, genErr := rsa.GenerateKey(rand.Reader, bits)
	if genErr == nil {
		return priv
	}
	panic(genErr)
}
45
46 func testInstanceConnName() instance.ConnName {
47 cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
48 return cn
49 }
50
51
52 var RSAKey = genRSAKey()
53
54 func TestConnectionInfoDBVersion(t *testing.T) {
55 ctx, cancel := context.WithCancel(context.Background())
56 defer cancel()
57 tests := []string{
58 "MYSQL_5_7", "POSTGRES_14", "SQLSERVER_2019_STANDARD", "MYSQL_8_0_18",
59 }
60 for _, wantEV := range tests {
61 inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithEngineVersion(wantEV))
62 client, cleanup, err := mock.NewSQLAdminService(
63 ctx,
64 mock.InstanceGetSuccess(inst, 1),
65 mock.CreateEphemeralSuccess(inst, 1),
66 )
67 if err != nil {
68 t.Fatalf("%s", err)
69 }
70 defer func() {
71 if err := cleanup(); err != nil {
72 t.Fatalf("%v", err)
73 }
74 }()
75 i := NewRefreshAheadCache(
76 testInstanceConnName(), nullLogger{}, client,
77 RSAKey, 30*time.Second, nil, "", false,
78 )
79 if err != nil {
80 t.Fatalf("failed to init instance: %v", err)
81 }
82
83 ci, err := i.ConnectionInfo(ctx)
84 if err != nil {
85 t.Fatalf("failed to retrieve engine version: %v", err)
86 }
87 if wantEV != ci.DBVersion {
88 t.Errorf("ConnectionInfo(%s) failed: want %v, got %v", wantEV, ci, err)
89 }
90
91 }
92 }
93
94 func TestConnectionInfo(t *testing.T) {
95 ctx := context.Background()
96 wantAddr := "0.0.0.0"
97 inst := mock.NewFakeCSQLInstance(
98 "my-project", "my-region", "my-instance", mock.WithPublicIP(wantAddr),
99 )
100 client, cleanup, err := mock.NewSQLAdminService(
101 ctx,
102 mock.InstanceGetSuccess(inst, 1),
103 mock.CreateEphemeralSuccess(inst, 1),
104 )
105 if err != nil {
106 t.Fatalf("%s", err)
107 }
108 defer func() {
109 if err := cleanup(); err != nil {
110 t.Fatalf("%v", err)
111 }
112 }()
113
114 i := NewRefreshAheadCache(
115 testInstanceConnName(), nullLogger{}, client,
116 RSAKey, 30*time.Second, nil, "", false,
117 )
118
119 ci, err := i.ConnectionInfo(ctx)
120 if err != nil {
121 t.Fatalf("failed to retrieve connect info: %v", err)
122 }
123
124 got, err := ci.Addr(PublicIP)
125 if err != nil {
126 t.Fatal(err)
127 }
128 if got != wantAddr {
129 t.Fatalf(
130 "ConnectInfo returned unexpected IP address, want = %v, got = %v",
131 wantAddr, got,
132 )
133 }
134 }
135
136 func TestConnectionInfoTLSConfig(t *testing.T) {
137 cn := testInstanceConnName()
138 i := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
139
140
141 cert, err := i.ClientCert(&RSAKey.PublicKey)
142 if err != nil {
143 t.Fatal(err)
144 }
145
146
147 b, _ := pem.Decode(cert)
148 clientCert, err := x509.ParseCertificate(b.Bytes)
149 if err != nil {
150 t.Fatal(err)
151 }
152
153
154
155
156 certBytes, err := mock.SelfSign(i.Cert, i.Key)
157 if err != nil {
158 t.Fatal(err)
159 }
160 b, _ = pem.Decode(certBytes)
161 serverCert, err := x509.ParseCertificate(b.Bytes)
162 if err != nil {
163 t.Fatal(err)
164 }
165
166
167
168 ci := ConnectionInfo{
169 ConnectionName: cn,
170 ClientCertificate: tls.Certificate{
171 Certificate: [][]byte{clientCert.Raw},
172 PrivateKey: RSAKey,
173 Leaf: clientCert,
174 },
175 ServerCaCert: serverCert,
176 DBVersion: "doesn't matter here",
177 Expiration: clientCert.NotAfter,
178 }
179
180 got := ci.TLSConfig()
181 wantServerName := cn.String()
182 if got.ServerName != wantServerName {
183 t.Fatalf(
184 "ConnectInfo return unexpected server name in TLS Config, "+
185 "want = %v, got = %v",
186 wantServerName, got.ServerName,
187 )
188 }
189
190 if got.MinVersion != tls.VersionTLS13 {
191 t.Fatalf(
192 "want TLS 1.3, got = %v", got.MinVersion,
193 )
194 }
195
196 if got.Certificates[0].Leaf != ci.ClientCertificate.Leaf {
197 t.Fatal("leaf certificates do not match")
198 }
199
200 verifyPeerCert := got.VerifyPeerCertificate
201 err = verifyPeerCert([][]byte{serverCert.Raw}, nil)
202 if err != nil {
203 t.Fatalf("expected to verify peer cert, got error: %v", err)
204 }
205
206 err = verifyPeerCert(nil, nil)
207 var wantErr *errtype.DialError
208 if !errors.As(err, &wantErr) {
209 t.Fatalf(
210 "when verify peer cert fails, want = %T, got = %v", wantErr, err,
211 )
212 }
213
214
215 err = verifyPeerCert([][]byte{[]byte("not a cert")}, nil)
216 if !errors.As(err, &wantErr) {
217 t.Fatalf(
218 "when verify fails on invalid cert, want = %T, got = %v",
219 wantErr, err,
220 )
221 }
222
223
224 badCert := mock.GenerateCertWithCommonName(i, "wrong:wrong")
225 err = verifyPeerCert([][]byte{badCert}, nil)
226 if !errors.As(err, &wantErr) {
227 t.Fatalf(
228 "when common names mismatch, want = %T, got = %v", wantErr, err,
229 )
230 }
231
232
233 other := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
234 cert, err = mock.SelfSign(other.Cert, other.Key)
235 if err != nil {
236 t.Fatalf("failed to sign certificate: %v", err)
237 }
238 b, _ = pem.Decode(cert)
239 err = verifyPeerCert([][]byte{b.Bytes}, nil)
240 if !errors.As(err, &wantErr) {
241 t.Fatalf("when certification fails, want = %T, got = %v", wantErr, err)
242 }
243 }
244
245 func TestConnectInfoAutoIP(t *testing.T) {
246 tcs := []struct {
247 desc string
248 ips []mock.FakeCSQLInstanceOption
249 wantIP string
250 }{
251 {
252 desc: "when public IP is enabled",
253 ips: []mock.FakeCSQLInstanceOption{
254 mock.WithPublicIP("8.8.8.8"),
255 mock.WithPrivateIP("10.0.0.1"),
256 },
257 wantIP: "8.8.8.8",
258 },
259 {
260 desc: "when only private IP is enabled",
261 ips: []mock.FakeCSQLInstanceOption{
262 mock.WithPrivateIP("10.0.0.1"),
263 },
264 wantIP: "10.0.0.1",
265 },
266 }
267
268 for _, tc := range tcs {
269 var opts []mock.FakeCSQLInstanceOption
270 opts = append(opts, mock.WithNoIPAddrs())
271 opts = append(opts, tc.ips...)
272 inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", opts...)
273 client, cleanup, err := mock.NewSQLAdminService(
274 context.Background(),
275 mock.InstanceGetSuccess(inst, 1),
276 mock.CreateEphemeralSuccess(inst, 1),
277 )
278 if err != nil {
279 t.Fatalf("%s", err)
280 }
281 defer func() {
282 if cErr := cleanup(); cErr != nil {
283 t.Fatalf("%v", cErr)
284 }
285 }()
286
287 i := NewRefreshAheadCache(
288 testInstanceConnName(), nullLogger{}, client,
289 RSAKey, 30*time.Second, nil, "", false,
290 )
291 if err != nil {
292 t.Fatalf("failed to create mock instance: %v", err)
293 }
294
295 ci, err := i.ConnectionInfo(context.Background())
296 if err != nil {
297 t.Fatalf("failed to retrieve connect info: %v", err)
298 }
299
300 got, err := ci.Addr(AutoIP)
301 if err != nil {
302 t.Fatal(err)
303 }
304 if got != tc.wantIP {
305 t.Fatalf(
306 "ConnectInfo returned unexpected IP address, want = %v, got = %v",
307 tc.wantIP, got,
308 )
309 }
310 }
311 }
312
313 func TestClose(t *testing.T) {
314 ctx := context.Background()
315
316 client, cleanup, err := mock.NewSQLAdminService(ctx)
317 if err != nil {
318 t.Fatalf("%s", err)
319 }
320 defer cleanup()
321
322
323 i := NewRefreshAheadCache(
324 testInstanceConnName(), nullLogger{}, client,
325 RSAKey, 30*time.Second, nil, "", false,
326 )
327 i.Close()
328
329 _, err = i.ConnectionInfo(ctx)
330 if !errors.Is(err, context.Canceled) {
331 t.Fatalf("failed to retrieve connect info: %v", err)
332 }
333 }
334
335 func TestRefreshDuration(t *testing.T) {
336 now := time.Now()
337 tcs := []struct {
338 desc string
339 expiry time.Time
340 want time.Duration
341 }{
342 {
343 desc: "when expiration is greater than 1 hour",
344 expiry: now.Add(4 * time.Hour),
345 want: 2 * time.Hour,
346 },
347 {
348 desc: "when expiration is equal to 1 hour",
349 expiry: now.Add(time.Hour),
350 want: 30 * time.Minute,
351 },
352 {
353 desc: "when expiration is less than 1 hour, but greater than 4 minutes",
354 expiry: now.Add(5 * time.Minute),
355 want: time.Minute,
356 },
357 {
358 desc: "when expiration is less than 4 minutes",
359 expiry: now.Add(3 * time.Minute),
360 want: 0,
361 },
362 {
363 desc: "when expiration is now",
364 expiry: now,
365 want: 0,
366 },
367 }
368 for _, tc := range tcs {
369 t.Run(tc.desc, func(t *testing.T) {
370 got := refreshDuration(now, tc.expiry)
371
372 if got.Round(time.Second) != tc.want {
373 t.Fatalf("time until refresh: want = %v, got = %v", tc.want, got)
374 }
375 })
376 }
377 }
378
View as plain text