1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cloudsql
16
17 import (
18 "bytes"
19 "context"
20 "crypto/rsa"
21 "crypto/x509"
22 "encoding/pem"
23 "errors"
24 "sync"
25 "testing"
26 "time"
27
28 "cloud.google.com/go/cloudsqlconn/errtype"
29 "cloud.google.com/go/cloudsqlconn/internal/mock"
30 "golang.org/x/oauth2"
31 )
32
// testDialerID is the placeholder dialer identifier that the tests in this
// file pass as the last argument to newRefresher.
const testDialerID = "some-dialer-id"
34
35 func TestRefresh(t *testing.T) {
36 wantPublicIP := "127.0.0.1"
37 wantPrivateIP := "10.0.0.1"
38 wantPSC := "abcde.12345.us-central1.sql.goog"
39 wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
40 cn := testInstanceConnName()
41 inst := mock.NewFakeCSQLInstance(
42 cn.Project(), cn.Region(), cn.Name(),
43 mock.WithPublicIP(wantPublicIP),
44 mock.WithPrivateIP(wantPrivateIP),
45 mock.WithPSC(wantPSC),
46 mock.WithCertExpiry(wantExpiry),
47 )
48 client, cleanup, err := mock.NewSQLAdminService(
49 context.Background(),
50 mock.InstanceGetSuccess(inst, 1),
51 mock.CreateEphemeralSuccess(inst, 1),
52 )
53 if err != nil {
54 t.Fatalf("failed to create test SQL admin service: %s", err)
55 }
56 defer func() {
57 if err := cleanup(); err != nil {
58 t.Fatalf("%v", err)
59 }
60 }()
61
62 r := newRefresher(nullLogger{}, client, nil, testDialerID)
63 rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false)
64 if err != nil {
65 t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
66 }
67
68 gotIP, ok := rr.addrs[PublicIP]
69 if !ok {
70 t.Fatal("metadata IP addresses did not include public address")
71 }
72 if wantPublicIP != gotIP {
73 t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPublicIP, gotIP)
74 }
75 gotIP, ok = rr.addrs[PrivateIP]
76 if !ok {
77 t.Fatal("metadata IP addresses did not include private address")
78 }
79 if wantPrivateIP != gotIP {
80 t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPrivateIP, gotIP)
81 }
82 gotPSC, ok := rr.addrs[PSC]
83 if !ok {
84 t.Fatal("metadata IP addresses did not include PSC endpoint")
85 }
86 if wantPSC != gotPSC {
87 t.Fatalf("metadata IP mismatch, want = %v. got = %v", wantPSC, gotPSC)
88 }
89 if cn != rr.ConnectionName {
90 t.Fatalf(
91 "connection name mismatch, want = %v, got = %v",
92 wantExpiry, rr.Expiration,
93 )
94 }
95 if wantExpiry != rr.Expiration {
96 t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiration)
97 }
98 }
99 func TestRefreshRetries50xResponses(t *testing.T) {
100 cn := testInstanceConnName()
101 inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
102 mock.WithEngineVersion("WANTED_VERSION"),
103 )
104 client, cleanup, err := mock.NewSQLAdminService(
105 context.Background(),
106
107 mock.InstanceGet500(inst, 1),
108 mock.InstanceGetSuccess(inst, 1),
109
110 mock.CreateEphemeral500(inst, 1),
111 mock.CreateEphemeralSuccess(inst, 1),
112 )
113 if err != nil {
114 t.Fatalf("failed to create test SQL admin service: %s", err)
115 }
116 defer func() {
117 if err := cleanup(); err != nil {
118 t.Fatalf("%v", err)
119 }
120 }()
121
122 r := newRefresher(nullLogger{}, client, nil, testDialerID)
123 rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false)
124 if err != nil {
125 t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
126 }
127 if rr.DBVersion != "WANTED_VERSION" {
128 t.Fatalf("DB version did not match expected, got = %v, want = %v",
129 rr.DBVersion, "WANTED_VERSION",
130 )
131 }
132 }
133
134 func TestRefreshFailsFast(t *testing.T) {
135 cn := testInstanceConnName()
136 inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
137 client, cleanup, err := mock.NewSQLAdminService(
138 context.Background(),
139 mock.InstanceGetSuccess(inst, 1),
140 mock.CreateEphemeralSuccess(inst, 1),
141 )
142 if err != nil {
143 t.Fatalf("failed to create test SQL admin service: %s", err)
144 }
145 defer cleanup()
146
147 r := newRefresher(nullLogger{}, client, nil, testDialerID)
148 _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
149 if err != nil {
150 t.Fatalf("expected no error, got = %v", err)
151 }
152
153 ctx, cancel := context.WithCancel(context.Background())
154 cancel()
155
156 _, err = r.ConnectionInfo(ctx, cn, RSAKey, false)
157 if !errors.Is(err, context.Canceled) {
158 t.Fatalf("expected context.Canceled error, got = %v", err)
159 }
160 }
161
// tokenResp is a single canned result for fakeTokenSource.Token: the token
// and the error that one call should return.
type tokenResp struct {
	// tok is the token returned to the caller; may be nil.
	tok *oauth2.Token
	// err is the error returned alongside tok.
	err error
}
166
// fakeTokenSource is a test double whose Token method replays a scripted
// list of responses in order and counts how many times it has been called.
type fakeTokenSource struct {
	// responses holds the results to return, one per Token call.
	responses []tokenResp
	// mu is held by Token and count while they access ct.
	mu sync.Mutex
	// ct is the number of Token calls served so far and the index of the
	// next response to return.
	ct int
}
172
173 func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
174 f.mu.Lock()
175 defer f.mu.Unlock()
176 resp := f.responses[f.ct]
177 f.ct++
178 return resp.tok, resp.err
179 }
180
181 func (f *fakeTokenSource) count() int {
182 f.mu.Lock()
183 defer f.mu.Unlock()
184 return f.ct
185 }
186
187 func TestRefreshAdjustsCertExpiry(t *testing.T) {
188 certExpiry := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
189 t1 := time.Now().Add(59 * time.Minute).UTC().Truncate(time.Second)
190 t2 := time.Now().Add(61 * time.Minute).UTC().Truncate(time.Second)
191 tcs := []struct {
192 desc string
193 resps []tokenResp
194 wantExpiry time.Time
195 }{
196 {
197 desc: "when the token's expiration comes BEFORE the cert",
198 resps: []tokenResp{
199 {tok: &oauth2.Token{}},
200 {tok: &oauth2.Token{Expiry: t1}},
201 },
202 wantExpiry: t1,
203 },
204 {
205 desc: "when the token's expiration comes AFTER the cert",
206 resps: []tokenResp{
207 {tok: &oauth2.Token{}},
208 {tok: &oauth2.Token{Expiry: t2}},
209 },
210 wantExpiry: certExpiry,
211 },
212 }
213 cn := testInstanceConnName()
214 inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance",
215 mock.WithCertExpiry(certExpiry))
216 client, cleanup, err := mock.NewSQLAdminService(
217 context.Background(),
218 mock.InstanceGetSuccess(inst, 2),
219 mock.CreateEphemeralSuccess(inst, 2),
220 )
221 if err != nil {
222 t.Fatalf("failed to create test SQL admin service: %s", err)
223 }
224 defer cleanup()
225
226 for _, tc := range tcs {
227 t.Run(tc.desc, func(t *testing.T) {
228 ts := &fakeTokenSource{responses: tc.resps}
229 r := newRefresher(nullLogger{}, client, ts, testDialerID)
230 rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
231 if err != nil {
232 t.Fatalf("want no error, got = %v", err)
233 }
234 if tc.wantExpiry != rr.Expiration {
235 t.Fatalf("want = %v, got = %v", tc.wantExpiry, rr.Expiration)
236 }
237 })
238 }
239 }
240
241 func TestRefreshWithIAMAuthErrors(t *testing.T) {
242 tcs := []struct {
243 desc string
244 resps []tokenResp
245 wantCount int
246 }{
247 {
248 desc: "when fetching a token fails",
249 resps: []tokenResp{{tok: nil, err: errors.New("fetch failed")}},
250 wantCount: 1,
251 },
252 {
253 desc: "when refreshing a token fails",
254 resps: []tokenResp{
255 {tok: &oauth2.Token{}, err: nil},
256 {tok: nil, err: errors.New("refresh failed")},
257 },
258 wantCount: 2,
259 },
260 }
261 cn := testInstanceConnName()
262 inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
263 client, cleanup, err := mock.NewSQLAdminService(
264 context.Background(),
265 mock.InstanceGetSuccess(inst, 2),
266 )
267 if err != nil {
268 t.Fatalf("failed to create test SQL admin service: %s", err)
269 }
270 defer cleanup()
271
272 for _, tc := range tcs {
273 t.Run(tc.desc, func(t *testing.T) {
274 ts := &fakeTokenSource{responses: tc.resps}
275 r := newRefresher(nullLogger{}, client, ts, testDialerID)
276 _, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
277 if err == nil {
278 t.Fatalf("expected get failed error, got = %v", err)
279 }
280 if count := ts.count(); count != tc.wantCount {
281 t.Fatalf("expected fake token source to be called %v time, got = %v", tc.wantCount, count)
282 }
283 })
284 }
285 }
286
287 func TestRefreshMetadataConfigError(t *testing.T) {
288 cn := testInstanceConnName()
289
290 testCases := []struct {
291 req *mock.Request
292 wantErr *errtype.ConfigError
293 desc string
294 }{
295 {
296 req: mock.InstanceGetSuccess(
297 mock.NewFakeCSQLInstance(
298 cn.Project(), cn.Region(), cn.Name(),
299 mock.WithRegion("my-region"),
300 mock.WithFirstGenBackend(),
301 ), 1),
302 wantErr: &errtype.ConfigError{},
303 desc: "When the instance isn't Second generation",
304 },
305 {
306 req: mock.InstanceGetSuccess(
307 mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
308 mock.WithRegion("some-other-region")), 1),
309 wantErr: &errtype.ConfigError{},
310 desc: "When the region does not match",
311 },
312 {
313 req: mock.InstanceGetSuccess(
314 mock.NewFakeCSQLInstance(
315 cn.Project(), cn.Region(), cn.Name(),
316 mock.WithRegion("my-region"),
317 mock.WithNoIPAddrs(),
318 ), 1),
319 wantErr: &errtype.ConfigError{},
320 desc: "When the instance has no supported IP addresses",
321 },
322 }
323
324 for i, tc := range testCases {
325 t.Run(tc.desc, func(t *testing.T) {
326 client, cleanup, err := mock.NewSQLAdminService(
327 context.Background(),
328 tc.req,
329 )
330 if err != nil {
331 t.Fatalf("failed to create test SQL admin service: %s", err)
332 }
333 defer cleanup()
334
335 r := newRefresher(nullLogger{}, client, nil, testDialerID)
336 _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
337 if !errors.As(err, &tc.wantErr) {
338 t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
339 }
340 })
341 }
342 }
343
344 func TestRefreshMetadataRefreshError(t *testing.T) {
345 cn := testInstanceConnName()
346
347 testCases := []struct {
348 req *mock.Request
349 wantErr *errtype.RefreshError
350 desc string
351 }{
352 {
353 req: mock.CreateEphemeralSuccess(
354 mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name()), 1),
355 wantErr: &errtype.RefreshError{},
356 desc: "When the Metadata call fails",
357 },
358 {
359 req: mock.InstanceGetSuccess(
360 mock.NewFakeCSQLInstance(
361 cn.Project(), cn.Region(), cn.Name(),
362 mock.WithRegion("my-region"),
363 mock.WithCertSigner(func(_ *x509.Certificate, _ *rsa.PrivateKey) ([]byte, error) {
364 return nil, nil
365 }),
366 ), 1),
367 wantErr: &errtype.RefreshError{},
368 desc: "When the server cert does not decode",
369 },
370 {
371 req: mock.InstanceGetSuccess(
372 mock.NewFakeCSQLInstance(
373 cn.Project(), cn.Region(), cn.Name(),
374 mock.WithRegion("my-region"),
375 mock.WithCertSigner(func(_ *x509.Certificate, _ *rsa.PrivateKey) ([]byte, error) {
376 certPEM := &bytes.Buffer{}
377 pem.Encode(certPEM, &pem.Block{
378 Type: "CERTIFICATE",
379 Bytes: []byte("hello"),
380 })
381 return certPEM.Bytes(), nil
382 }),
383 ), 1),
384 wantErr: &errtype.RefreshError{},
385 desc: "When the cert is not a valid X.509 cert",
386 },
387 }
388
389 for i, tc := range testCases {
390 t.Run(tc.desc, func(t *testing.T) {
391 client, cleanup, err := mock.NewSQLAdminService(
392 context.Background(),
393 tc.req,
394 )
395 if err != nil {
396 t.Fatalf("failed to create test SQL admin service: %s", err)
397 }
398 defer cleanup()
399
400 r := newRefresher(nullLogger{}, client, nil, testDialerID)
401 _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
402 if !errors.As(err, &tc.wantErr) {
403 t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
404 }
405 })
406 }
407 }
408
409 func TestRefreshWithFailedEphemeralCertCall(t *testing.T) {
410 cn := testInstanceConnName()
411 inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
412
413 testCases := []struct {
414 reqs []*mock.Request
415 wantErr *errtype.RefreshError
416 desc string
417 }{
418 {
419 reqs: []*mock.Request{mock.InstanceGetSuccess(inst, 1)},
420 wantErr: &errtype.RefreshError{},
421 desc: "When the CreateEphemeralCert call fails",
422 },
423 {
424 reqs: []*mock.Request{mock.InstanceGetSuccess(inst, 1),
425 mock.CreateEphemeralSuccess(
426 mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
427 mock.WithClientCertSigner(
428 func(*x509.Certificate, *rsa.PrivateKey, *rsa.PublicKey) ([]byte, error) {
429 return nil, nil
430 }),
431 ), 1),
432 },
433 wantErr: &errtype.RefreshError{},
434 desc: "When decoding the cert fails",
435 },
436 {
437 reqs: []*mock.Request{mock.InstanceGetSuccess(inst, 1),
438 mock.CreateEphemeralSuccess(
439 mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
440 mock.WithClientCertSigner(
441 func(*x509.Certificate, *rsa.PrivateKey, *rsa.PublicKey) ([]byte, error) {
442 certPEM := &bytes.Buffer{}
443 pem.Encode(certPEM, &pem.Block{
444 Type: "CERTIFICATE",
445 Bytes: []byte("hello"),
446 })
447 return certPEM.Bytes(), nil
448 }),
449 ), 1),
450 },
451 wantErr: &errtype.RefreshError{},
452 desc: "When parsing the cert fails",
453 },
454 }
455 for i, tc := range testCases {
456 client, cleanup, err := mock.NewSQLAdminService(
457 context.Background(),
458 tc.reqs...,
459 )
460 if err != nil {
461 t.Fatalf("failed to create test SQL admin service: %s", err)
462 }
463 defer cleanup()
464
465 r := newRefresher(nullLogger{}, client, nil, testDialerID)
466 _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false)
467
468 if !errors.As(err, &tc.wantErr) {
469 t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err)
470 }
471 }
472 }
473
View as plain text