1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cloudsql
16
17 import (
18 "context"
19 "sync"
20 "testing"
21 "time"
22
23 "cloud.google.com/go/cloudsqlconn/instance"
24 "cloud.google.com/go/cloudsqlconn/internal/mock"
25 "golang.org/x/oauth2"
26 )
27
28 func TestLazyRefreshCacheConnectionInfo(t *testing.T) {
29 cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
30 inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
31 client, cleanup, err := mock.NewSQLAdminService(
32 context.Background(),
33 mock.InstanceGetSuccess(inst, 1),
34 mock.CreateEphemeralSuccess(inst, 1),
35 )
36 if err != nil {
37 t.Fatal(err)
38 }
39 defer func() {
40 if err := cleanup(); err != nil {
41 t.Fatalf("%v", err)
42 }
43 }()
44 c := NewLazyRefreshCache(
45 testInstanceConnName(), nullLogger{}, client,
46 RSAKey, 30*time.Second, nil, "", false,
47 )
48
49 ci, err := c.ConnectionInfo(context.Background())
50 if err != nil {
51 t.Fatal(err)
52 }
53 if ci.ConnectionName != cn {
54 t.Fatalf("want = %v, got = %v", cn, ci.ConnectionName)
55 }
56
57
58 _, err = c.ConnectionInfo(context.Background())
59 if err != nil {
60 t.Fatal(err)
61 }
62 }
63
64 func TestLazyRefreshCacheForceRefresh(t *testing.T) {
65 cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
66 inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
67 client, cleanup, err := mock.NewSQLAdminService(
68 context.Background(),
69 mock.InstanceGetSuccess(inst, 2),
70 mock.CreateEphemeralSuccess(inst, 2),
71 )
72 if err != nil {
73 t.Fatal(err)
74 }
75 defer func() {
76 if err := cleanup(); err != nil {
77 t.Fatalf("%v", err)
78 }
79 }()
80 c := NewLazyRefreshCache(
81 testInstanceConnName(), nullLogger{}, client,
82 RSAKey, 30*time.Second, nil, "", false,
83 )
84
85 _, err = c.ConnectionInfo(context.Background())
86 if err != nil {
87 t.Fatal(err)
88 }
89
90 c.ForceRefresh()
91
92 _, err = c.ConnectionInfo(context.Background())
93 if err != nil {
94 t.Fatal(err)
95 }
96 }
97
98
99 type spyTokenSource struct {
100 mu sync.Mutex
101 count int
102 }
103
104 func (s *spyTokenSource) Token() (*oauth2.Token, error) {
105 s.mu.Lock()
106 defer s.mu.Unlock()
107 s.count++
108 return &oauth2.Token{}, nil
109 }
110
111 func (s *spyTokenSource) callCount() int {
112 s.mu.Lock()
113 defer s.mu.Unlock()
114 return s.count
115 }
116
117 func TestLazyRefreshCacheUpdateRefresh(t *testing.T) {
118 cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
119 inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
120 client, cleanup, err := mock.NewSQLAdminService(
121 context.Background(),
122 mock.InstanceGetSuccess(inst, 2),
123 mock.CreateEphemeralSuccess(inst, 2),
124 )
125 if err != nil {
126 t.Fatal(err)
127 }
128 defer func() {
129 if err := cleanup(); err != nil {
130 t.Fatalf("%v", err)
131 }
132 }()
133
134 spy := &spyTokenSource{}
135 c := NewLazyRefreshCache(
136 testInstanceConnName(), nullLogger{}, client,
137 RSAKey, 30*time.Second, spy, "", false,
138 )
139
140 _, err = c.ConnectionInfo(context.Background())
141 if err != nil {
142 t.Fatal(err)
143 }
144
145 if got := spy.callCount(); got != 0 {
146 t.Fatal("oauth2.TokenSource was called, but should not have been")
147 }
148
149 c.UpdateRefresh(ptr(true))
150
151 _, err = c.ConnectionInfo(context.Background())
152 if err != nil {
153 t.Fatal(err)
154 }
155
156
157
158
159 if got, want := spy.callCount(), 2; got != want {
160 t.Fatalf(
161 "oauth2.TokenSource call count, got = %v, want = %v",
162 got, want,
163 )
164 }
165 }
166
// ptr allocates a new bool holding val and returns its address. This is
// useful where a *bool argument is needed, e.g. for UpdateRefresh.
func ptr(val bool) *bool {
	p := new(bool)
	*p = val
	return p
}
170
View as plain text