1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cloudsql
16
17 import (
18 "context"
19 "crypto/rsa"
20 "crypto/tls"
21 "crypto/x509"
22 "fmt"
23 "sync"
24 "time"
25
26 "cloud.google.com/go/cloudsqlconn/debug"
27 "cloud.google.com/go/cloudsqlconn/errtype"
28 "cloud.google.com/go/cloudsqlconn/instance"
29 "golang.org/x/oauth2"
30 "golang.org/x/time/rate"
31 sqladmin "google.golang.org/api/sqladmin/v1beta4"
32 )
33
const (
	// refreshBuffer is how long before certificate expiry a refresh is
	// scheduled once less than an hour of validity remains. If less than
	// refreshBuffer remains, the refresh is scheduled immediately instead.
	refreshBuffer = 4 * time.Minute

	// refreshInterval is the spacing the refresh rate limiter enforces
	// between refresh attempts (rate.Every(refreshInterval)), apart from the
	// refreshBurst allowance.
	refreshInterval = 30 * time.Second

	// RefreshTimeout is a default upper bound for a single refresh cycle.
	// NOTE(review): it is not referenced in this part of the file; callers
	// presumably pass it as the refreshTimeout argument of
	// NewRefreshAheadCache. Because that timeout also covers the time spent
	// waiting on the rate limiter, it should stay larger than refreshInterval.
	RefreshTimeout = 60 * time.Second

	// refreshBurst is the rate limiter's burst size: the number of refreshes
	// that may run back-to-back before refreshInterval spacing applies.
	refreshBurst = 2
)
51
52
53
54
55 type refreshOperation struct {
56
57 ready chan struct{}
58
59 timer *time.Timer
60 result ConnectionInfo
61 err error
62 }
63
64
65
66
67 func (r *refreshOperation) cancel() bool {
68 return r.timer.Stop()
69 }
70
71
72
73 func (r *refreshOperation) isValid() bool {
74
75 select {
76 default:
77 return false
78 case <-r.ready:
79 if r.err != nil || time.Now().After(r.result.Expiration.Round(0)) {
80 return false
81 }
82 return true
83 }
84 }
85
86
87
88
89
90 type RefreshAheadCache struct {
91
92 openConns uint64
93
94 connName instance.ConnName
95 logger debug.ContextLogger
96 key *rsa.PrivateKey
97
98
99
100 refreshTimeout time.Duration
101
102 l *rate.Limiter
103 r refresher
104
105 mu sync.RWMutex
106 useIAMAuthNDial bool
107
108
109
110 cur *refreshOperation
111
112
113 next *refreshOperation
114
115
116
117 ctx context.Context
118 cancel context.CancelFunc
119 }
120
121
122 func NewRefreshAheadCache(
123 cn instance.ConnName,
124 l debug.ContextLogger,
125 client *sqladmin.Service,
126 key *rsa.PrivateKey,
127 refreshTimeout time.Duration,
128 ts oauth2.TokenSource,
129 dialerID string,
130 useIAMAuthNDial bool,
131 ) *RefreshAheadCache {
132 ctx, cancel := context.WithCancel(context.Background())
133 i := &RefreshAheadCache{
134 connName: cn,
135 logger: l,
136 key: key,
137 l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
138 r: newRefresher(
139 l,
140 client,
141 ts,
142 dialerID,
143 ),
144 refreshTimeout: refreshTimeout,
145 useIAMAuthNDial: useIAMAuthNDial,
146 ctx: ctx,
147 cancel: cancel,
148 }
149
150
151 i.mu.Lock()
152 i.cur = i.scheduleRefresh(0)
153 i.next = i.cur
154 i.mu.Unlock()
155 return i
156 }
157
158
159
160 func (i *RefreshAheadCache) Close() error {
161 i.mu.Lock()
162 defer i.mu.Unlock()
163 i.cancel()
164 i.cur.cancel()
165 i.next.cancel()
166 return nil
167 }
168
169
170
171 type ConnectionInfo struct {
172 ConnectionName instance.ConnName
173 ClientCertificate tls.Certificate
174 ServerCaCert *x509.Certificate
175 DBVersion string
176 Expiration time.Time
177
178 addrs map[string]string
179 }
180
181
182 func (c ConnectionInfo) Addr(ipType string) (string, error) {
183 var (
184 addr string
185 ok bool
186 )
187 switch ipType {
188 case AutoIP:
189
190 addr, ok = c.addrs[PublicIP]
191 if !ok {
192
193 addr, ok = c.addrs[PrivateIP]
194 }
195 default:
196 addr, ok = c.addrs[ipType]
197 }
198 if !ok {
199 err := errtype.NewConfigError(
200 fmt.Sprintf("instance does not have IP of type %q", ipType),
201 c.ConnectionName.String(),
202 )
203 return "", err
204 }
205 return addr, nil
206 }
207
208
209 func (c ConnectionInfo) TLSConfig() *tls.Config {
210 pool := x509.NewCertPool()
211 pool.AddCert(c.ServerCaCert)
212 return &tls.Config{
213 ServerName: c.ConnectionName.String(),
214 Certificates: []tls.Certificate{c.ClientCertificate},
215 RootCAs: pool,
216
217
218
219
220
221
222
223
224 InsecureSkipVerify: true,
225 VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool),
226 MinVersion: tls.VersionTLS13,
227 }
228 }
229
230
231
232
233
234
235 func verifyPeerCertificateFunc(
236 cn instance.ConnName, pool *x509.CertPool,
237 ) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
238 return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
239 if len(rawCerts) == 0 {
240 return errtype.NewDialError(
241 "no certificate to verify", cn.String(), nil,
242 )
243 }
244
245 cert, err := x509.ParseCertificate(rawCerts[0])
246 if err != nil {
247 return errtype.NewDialError(
248 "failed to parse X.509 certificate", cn.String(), err,
249 )
250 }
251
252 opts := x509.VerifyOptions{Roots: pool}
253 if _, err = cert.Verify(opts); err != nil {
254 return errtype.NewDialError(
255 "failed to verify certificate", cn.String(), err,
256 )
257 }
258
259 certInstanceName := fmt.Sprintf("%s:%s", cn.Project(), cn.Name())
260 if cert.Subject.CommonName != certInstanceName {
261 return errtype.NewDialError(
262 fmt.Sprintf(
263 "certificate had CN %q, expected %q",
264 cert.Subject.CommonName, certInstanceName,
265 ),
266 cn.String(),
267 nil,
268 )
269 }
270 return nil
271 }
272 }
273
274
275
276
277 func (i *RefreshAheadCache) ConnectionInfo(ctx context.Context) (ConnectionInfo, error) {
278 op, err := i.refreshOperation(ctx)
279 if err != nil {
280 return ConnectionInfo{}, err
281 }
282 return op.result, nil
283 }
284
285
286
287
288 func (i *RefreshAheadCache) UpdateRefresh(useIAMAuthNDial *bool) {
289 i.mu.Lock()
290 defer i.mu.Unlock()
291 if useIAMAuthNDial != nil && *useIAMAuthNDial != i.useIAMAuthNDial {
292
293 i.cur.cancel()
294 i.next.cancel()
295
296 i.useIAMAuthNDial = *useIAMAuthNDial
297
298 i.cur = i.scheduleRefresh(0)
299 i.next = i.cur
300 }
301 }
302
303
304
305
306 func (i *RefreshAheadCache) ForceRefresh() {
307 i.mu.Lock()
308 defer i.mu.Unlock()
309
310
311 if i.next.cancel() {
312 i.next = i.scheduleRefresh(0)
313 }
314
315
316 if !i.cur.isValid() {
317 i.cur = i.next
318 }
319 }
320
321
322
323 func (i *RefreshAheadCache) refreshOperation(ctx context.Context) (*refreshOperation, error) {
324 i.mu.RLock()
325 cur := i.cur
326 i.mu.RUnlock()
327 var err error
328 select {
329 case <-cur.ready:
330 err = cur.err
331 case <-ctx.Done():
332 err = ctx.Err()
333 case <-i.ctx.Done():
334 err = i.ctx.Err()
335 }
336 if err != nil {
337 return nil, err
338 }
339 return cur, nil
340 }
341
342
343
344
345 func refreshDuration(now, certExpiry time.Time) time.Duration {
346 d := certExpiry.Sub(now.Round(0))
347 if d < time.Hour {
348
349 if d < refreshBuffer {
350 return 0
351 }
352
353
354 return d - refreshBuffer
355 }
356 return d / 2
357 }
358
359
360
361
362 func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
363 r := &refreshOperation{}
364 r.ready = make(chan struct{})
365 r.timer = time.AfterFunc(d, func() {
366
367 if err := i.ctx.Err(); err != nil {
368 i.logger.Debugf(
369 context.Background(),
370 "[%v] Instance is closed, stopping refresh operations",
371 i.connName.String(),
372 )
373 r.err = err
374 close(r.ready)
375 return
376 }
377 i.logger.Debugf(
378 context.Background(),
379 "[%v] Connection info refresh operation started",
380 i.connName.String(),
381 )
382
383 ctx, cancel := context.WithTimeout(i.ctx, i.refreshTimeout)
384 defer cancel()
385
386
387
388 err := i.l.Wait(ctx)
389 if err != nil {
390 r.err = errtype.NewDialError(
391 "context was canceled or expired before refresh completed",
392 i.connName.String(),
393 nil,
394 )
395 } else {
396 r.result, r.err = i.r.ConnectionInfo(
397 ctx, i.connName, i.key, i.useIAMAuthNDial)
398 }
399 switch r.err {
400 case nil:
401 i.logger.Debugf(
402 ctx,
403 "[%v] Connection info refresh operation complete",
404 i.connName.String(),
405 )
406 i.logger.Debugf(
407 ctx,
408 "[%v] Current certificate expiration = %v",
409 i.connName.String(),
410 r.result.Expiration.UTC().Format(time.RFC3339),
411 )
412 default:
413 i.logger.Debugf(
414 ctx,
415 "[%v] Connection info refresh operation failed, err = %v",
416 i.connName.String(),
417 r.err,
418 )
419 }
420
421 close(r.ready)
422
423
424
425 i.mu.Lock()
426 defer i.mu.Unlock()
427
428
429 if r.err != nil {
430 i.logger.Debugf(
431 ctx,
432 "[%v] Connection info refresh operation scheduled immediately",
433 i.connName.String(),
434 )
435 i.next = i.scheduleRefresh(0)
436
437
438
439
440
441
442 if !i.cur.isValid() {
443 i.cur = r
444 }
445 return
446 }
447
448
449
450 i.cur = r
451 t := refreshDuration(time.Now(), i.cur.result.Expiration)
452 i.logger.Debugf(
453 ctx,
454 "[%v] Connection info refresh operation scheduled at %v (now + %v)",
455 i.connName.String(),
456 time.Now().Add(t).UTC().Format(time.RFC3339),
457 t.Round(time.Minute),
458 )
459 i.next = i.scheduleRefresh(t)
460 })
461 return r
462 }
463
View as plain text