1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cloudsqlconn
16
17 import (
18 "context"
19 "crypto/rand"
20 "crypto/rsa"
21 "crypto/tls"
22 _ "embed"
23 "errors"
24 "fmt"
25 "io"
26 "net"
27 "strings"
28 "sync"
29 "sync/atomic"
30 "time"
31
32 "cloud.google.com/go/cloudsqlconn/debug"
33 "cloud.google.com/go/cloudsqlconn/errtype"
34 "cloud.google.com/go/cloudsqlconn/instance"
35 "cloud.google.com/go/cloudsqlconn/internal/cloudsql"
36 "cloud.google.com/go/cloudsqlconn/internal/trace"
37 "github.com/google/uuid"
38 "golang.org/x/net/proxy"
39 "golang.org/x/oauth2"
40 "golang.org/x/oauth2/google"
41 "google.golang.org/api/option"
42 sqladmin "google.golang.org/api/sqladmin/v1beta4"
43 )
44
const (
	// defaultTCPKeepAlive is the TCP keep-alive period Dial applies to
	// *net.TCPConn connections unless overridden through dial options.
	defaultTCPKeepAlive = 30 * time.Second

	// serverProxyPort is the port on the instance address that Dial connects
	// to.
	serverProxyPort = "3307"

	// iamLoginScope is the OAuth2 scope requested for the IAM login token
	// source that NewDialer builds from default credentials.
	iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login"
)
54
// ErrDialerClosed is returned by Dial once the Dialer has been closed.
var ErrDialerClosed = errors.New("cloudsqlconn: dialer is closed")

var (
	// versionString is the library version; it is whitespace-trimmed and
	// appended to userAgent.
	// NOTE(review): nothing visible in this file assigns versionString; the
	// blank "embed" import suggests it is meant to be populated by a
	// //go:embed directive — confirm, otherwise the user agent has no version.
	versionString string

	// userAgent identifies this library to the SQL Admin API.
	userAgent = "cloud-sql-go-connector/" + strings.TrimSpace(versionString)
)

// The default RSA key (and any error from generating it) is produced at most
// once, on first use, and shared by every Dialer created without its own key.
var (
	defaultKey    *rsa.PrivateKey
	defaultKeyErr error
	keyOnce       sync.Once
)
69
70 func getDefaultKeys() (*rsa.PrivateKey, error) {
71 keyOnce.Do(func() {
72 defaultKey, defaultKeyErr = rsa.GenerateKey(rand.Reader, 2048)
73 })
74 return defaultKey, defaultKeyErr
75 }
76
77 type connectionInfoCache interface {
78 ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error)
79 UpdateRefresh(*bool)
80 ForceRefresh()
81 io.Closer
82 }
83
84
85
86 type monitoredCache struct {
87 openConns uint64
88
89 connectionInfoCache
90 }
91
92
93
94
95 type Dialer struct {
96 lock sync.RWMutex
97 cache map[instance.ConnName]monitoredCache
98 key *rsa.PrivateKey
99 refreshTimeout time.Duration
100
101 closed chan struct{}
102
103 sqladmin *sqladmin.Service
104 logger debug.ContextLogger
105
106
107
108
109
110
111 lazyRefresh bool
112
113
114
115 defaultDialConfig dialConfig
116
117
118
119 dialerID string
120
121
122
123 dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)
124
125
126 iamTokenSource oauth2.TokenSource
127 }
128
// Errors returned by NewDialer when token-source options conflict with the
// IAM authentication setting.
var (
	// errUseTokenSource is returned when an IAM AuthN token source is
	// configured but IAM AuthN is not enabled.
	errUseTokenSource = errors.New("use WithTokenSource when IAM AuthN is not enabled")
	// errUseIAMTokenSource is returned when IAM AuthN is enabled and only a
	// plain token source (not an IAM AuthN one) is configured.
	errUseIAMTokenSource = errors.New("use WithIAMAuthNTokenSources instead of WithTokenSource when IAM AuthN is enabled")
)
133
// nullLogger is the logger NewDialer installs when none is configured; it
// discards every message.
type nullLogger struct{}

// Debugf does nothing.
func (nullLogger) Debugf(context.Context, string, ...interface{}) {}
137
138
139
140
141
142
143 func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
144 cfg := &dialerConfig{
145 refreshTimeout: cloudsql.RefreshTimeout,
146 dialFunc: proxy.Dial,
147 logger: nullLogger{},
148 useragents: []string{userAgent},
149 serviceUniverse: "googleapis.com",
150 }
151 for _, opt := range opts {
152 opt(cfg)
153 if cfg.err != nil {
154 return nil, cfg.err
155 }
156 }
157 if cfg.useIAMAuthN && cfg.setTokenSource && !cfg.setIAMAuthNTokenSource {
158 return nil, errUseIAMTokenSource
159 }
160 if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
161 return nil, errUseTokenSource
162 }
163
164 cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
165
166
167
168
169 if !cfg.setCredentials {
170 c, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
171 if err != nil {
172 return nil, fmt.Errorf("failed to create default credentials: %v", err)
173 }
174 ud, err := c.GetUniverseDomain()
175 if err != nil {
176 return nil, fmt.Errorf("failed to get universe domain: %v", err)
177 }
178 cfg.credentialsUniverse = ud
179 cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource))
180 scoped, err := google.DefaultTokenSource(ctx, iamLoginScope)
181 if err != nil {
182 return nil, fmt.Errorf("failed to create scoped token source: %v", err)
183 }
184 cfg.iamLoginTokenSource = scoped
185 }
186
187 if cfg.rsaKey == nil {
188 key, err := getDefaultKeys()
189 if err != nil {
190 return nil, fmt.Errorf("failed to generate RSA keys: %v", err)
191 }
192 cfg.rsaKey = key
193 }
194
195 if cfg.setUniverseDomain && cfg.setAdminAPIEndpoint {
196 return nil, errors.New(
197 "can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " +
198 "use WithAdminAPIEndpoint (it already contains the universe domain)",
199 )
200 }
201
202 if cfg.credentialsUniverse != "" && cfg.serviceUniverse != "" {
203 if cfg.credentialsUniverse != cfg.serviceUniverse {
204 return nil, fmt.Errorf(
205 "the configured service universe domain (%s) does not match the credential universe domain (%s)",
206 cfg.serviceUniverse, cfg.credentialsUniverse,
207 )
208 }
209 }
210
211 client, err := sqladmin.NewService(ctx, cfg.sqladminOpts...)
212 if err != nil {
213 return nil, fmt.Errorf("failed to create sqladmin client: %v", err)
214 }
215
216 dc := dialConfig{
217 ipType: cloudsql.PublicIP,
218 tcpKeepAlive: defaultTCPKeepAlive,
219 useIAMAuthN: cfg.useIAMAuthN,
220 }
221 for _, opt := range cfg.dialOpts {
222 opt(&dc)
223 }
224
225 if err := trace.InitMetrics(); err != nil {
226 return nil, err
227 }
228 d := &Dialer{
229 closed: make(chan struct{}),
230 cache: make(map[instance.ConnName]monitoredCache),
231 lazyRefresh: cfg.lazyRefresh,
232 key: cfg.rsaKey,
233 refreshTimeout: cfg.refreshTimeout,
234 sqladmin: client,
235 logger: cfg.logger,
236 defaultDialConfig: dc,
237 dialerID: uuid.New().String(),
238 iamTokenSource: cfg.iamLoginTokenSource,
239 dialFunc: cfg.dialFunc,
240 }
241 return d, nil
242 }
243
244
245
246
247 func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn net.Conn, err error) {
248 select {
249 case <-d.closed:
250 return nil, ErrDialerClosed
251 default:
252 }
253 startTime := time.Now()
254 var endDial trace.EndSpanFunc
255 ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn.Dial",
256 trace.AddInstanceName(icn),
257 trace.AddDialerID(d.dialerID),
258 )
259 defer func() {
260 go trace.RecordDialError(context.Background(), icn, d.dialerID, err)
261 endDial(err)
262 }()
263 cn, err := instance.ParseConnName(icn)
264 if err != nil {
265 return nil, err
266 }
267
268 cfg := d.defaultDialConfig
269 for _, opt := range opts {
270 opt(&cfg)
271 }
272
273 var endInfo trace.EndSpanFunc
274 ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
275 c := d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN)
276 ci, err := c.ConnectionInfo(ctx)
277 if err != nil {
278 d.lock.Lock()
279 defer d.lock.Unlock()
280 d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String())
281
282 c.Close()
283 delete(d.cache, cn)
284 endInfo(err)
285 return nil, err
286 }
287 endInfo(err)
288
289
290
291
292
293
294 if !validClientCert(ctx, cn, d.logger, ci.Expiration) {
295 d.logger.Debugf(ctx, "[%v] Refreshing certificate now", cn.String())
296 c.ForceRefresh()
297
298 ci, err = c.ConnectionInfo(ctx)
299 if err != nil {
300 d.lock.Lock()
301 defer d.lock.Unlock()
302 d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String())
303
304 c.Close()
305 delete(d.cache, cn)
306 return nil, err
307 }
308 }
309
310 var connectEnd trace.EndSpanFunc
311 ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
312 defer func() { connectEnd(err) }()
313 addr, err := ci.Addr(cfg.ipType)
314 if err != nil {
315 return nil, err
316 }
317 addr = net.JoinHostPort(addr, serverProxyPort)
318 f := d.dialFunc
319 if cfg.dialFunc != nil {
320 f = cfg.dialFunc
321 }
322 d.logger.Debugf(ctx, "[%v] Dialing %v", cn.String(), addr)
323 conn, err = f(ctx, "tcp", addr)
324 if err != nil {
325 d.logger.Debugf(ctx, "[%v] Dialing %v failed: %v", cn.String(), addr, err)
326
327 c.ForceRefresh()
328 return nil, errtype.NewDialError("failed to dial", cn.String(), err)
329 }
330 if c, ok := conn.(*net.TCPConn); ok {
331 if err := c.SetKeepAlive(true); err != nil {
332 return nil, errtype.NewDialError("failed to set keep-alive", cn.String(), err)
333 }
334 if err := c.SetKeepAlivePeriod(cfg.tcpKeepAlive); err != nil {
335 return nil, errtype.NewDialError("failed to set keep-alive period", cn.String(), err)
336 }
337 }
338
339 tlsConn := tls.Client(conn, ci.TLSConfig())
340 err = tlsConn.HandshakeContext(ctx)
341 if err != nil {
342 d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
343
344 c.ForceRefresh()
345 _ = tlsConn.Close()
346 return nil, errtype.NewDialError("handshake failed", cn.String(), err)
347 }
348
349 latency := time.Since(startTime).Milliseconds()
350 go func() {
351 n := atomic.AddUint64(&c.openConns, 1)
352 trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
353 trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
354 }()
355
356 return newInstrumentedConn(tlsConn, func() {
357 n := atomic.AddUint64(&c.openConns, ^uint64(0))
358 trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
359 }), nil
360 }
361
362
363
364
365 func validClientCert(ctx context.Context, cn instance.ConnName, l debug.ContextLogger, expiration time.Time) bool {
366
367
368
369
370 now := time.Now().UTC()
371 valid := expiration.UTC().After(now)
372 l.Debugf(
373 ctx,
374 "[%v] Now = %v, Current cert expiration = %v",
375 cn.String(),
376 now.Format(time.RFC3339),
377 expiration.UTC().Format(time.RFC3339),
378 )
379 l.Debugf(ctx, "[%v] Cert is valid = %v", cn.String(), valid)
380 return valid
381 }
382
383
384
385
386
387 func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) {
388 cn, err := instance.ParseConnName(icn)
389 if err != nil {
390 return "", err
391 }
392 i := d.connectionInfoCache(ctx, cn, nil)
393 ci, err := i.ConnectionInfo(ctx)
394 if err != nil {
395 return "", err
396 }
397 return ci.DBVersion, nil
398 }
399
400
401
402
403 func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) error {
404 cn, err := instance.ParseConnName(icn)
405 if err != nil {
406 return err
407 }
408 cfg := d.defaultDialConfig
409 for _, opt := range opts {
410 opt(&cfg)
411 }
412 _ = d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN)
413 return nil
414 }
415
416
417
// newInstrumentedConn wraps conn so that closeFunc runs, on its own
// goroutine, once the connection is closed.
func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn {
	return &instrumentedConn{
		Conn:      conn,
		closeFunc: closeFunc,
	}
}

// instrumentedConn is a net.Conn that runs a callback when it is closed. Dial
// uses the callback to keep its open-connections metric up to date.
type instrumentedConn struct {
	net.Conn
	closeFunc func()
	// closeOnce makes closeFunc run at most once, however many times Close is
	// called.
	closeOnce sync.Once
}

// Close closes the underlying connection and returns its error.
//
// On the first call, closeFunc is scheduled on a new goroutine even if the
// underlying Close reports an error. crypto/tls, for instance, can return an
// error from Close after the connection has already been closed, and skipping
// the callback would leave the open-connection count permanently inflated.
// Later calls only forward to the underlying Close.
func (i *instrumentedConn) Close() error {
	err := i.Conn.Close()
	i.closeOnce.Do(func() { go i.closeFunc() })
	return err
}
442
443
444
445 func (d *Dialer) Close() error {
446
447 select {
448 case <-d.closed:
449 return nil
450 default:
451 }
452 close(d.closed)
453 d.lock.Lock()
454 defer d.lock.Unlock()
455 for _, i := range d.cache {
456 i.Close()
457 }
458 return nil
459 }
460
461
462
463
464 func (d *Dialer) connectionInfoCache(
465 ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
466 ) monitoredCache {
467 d.lock.RLock()
468 c, ok := d.cache[cn]
469 d.lock.RUnlock()
470 if !ok {
471 d.lock.Lock()
472 defer d.lock.Unlock()
473
474 c, ok = d.cache[cn]
475 if !ok {
476 var useIAMAuthNDial bool
477 if useIAMAuthN != nil {
478 useIAMAuthNDial = *useIAMAuthN
479 }
480 d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
481 var cache connectionInfoCache
482 if d.lazyRefresh {
483 cache = cloudsql.NewLazyRefreshCache(
484 cn,
485 d.logger,
486 d.sqladmin, d.key,
487 d.refreshTimeout, d.iamTokenSource,
488 d.dialerID, useIAMAuthNDial,
489 )
490 } else {
491 cache = cloudsql.NewRefreshAheadCache(
492 cn,
493 d.logger,
494 d.sqladmin, d.key,
495 d.refreshTimeout, d.iamTokenSource,
496 d.dialerID, useIAMAuthNDial,
497 )
498 }
499 c = monitoredCache{connectionInfoCache: cache}
500 d.cache[cn] = c
501 }
502 }
503
504 c.UpdateRefresh(useIAMAuthN)
505
506 return c
507 }
508
View as plain text