...

Source file src/cloud.google.com/go/cloudsqlconn/dialer.go

Documentation: cloud.google.com/go/cloudsqlconn

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsqlconn
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"crypto/rsa"
    21  	"crypto/tls"
    22  	_ "embed"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net"
    27  	"strings"
    28  	"sync"
    29  	"sync/atomic"
    30  	"time"
    31  
    32  	"cloud.google.com/go/cloudsqlconn/debug"
    33  	"cloud.google.com/go/cloudsqlconn/errtype"
    34  	"cloud.google.com/go/cloudsqlconn/instance"
    35  	"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
    36  	"cloud.google.com/go/cloudsqlconn/internal/trace"
    37  	"github.com/google/uuid"
    38  	"golang.org/x/net/proxy"
    39  	"golang.org/x/oauth2"
    40  	"golang.org/x/oauth2/google"
    41  	"google.golang.org/api/option"
    42  	sqladmin "google.golang.org/api/sqladmin/v1beta4"
    43  )
    44  
    45  const (
    46  	// defaultTCPKeepAlive is the default keep alive value used on connections to a Cloud SQL instance.
    47  	defaultTCPKeepAlive = 30 * time.Second
    48  	// serverProxyPort is the port the server-side proxy receives connections on.
    49  	serverProxyPort = "3307"
    50  	// iamLoginScope is the OAuth2 scope used for tokens embedded in the ephemeral
    51  	// certificate.
    52  	iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login"
    53  )
    54  
var (
	// ErrDialerClosed is returned by Dial when a caller invokes it after the
	// Dialer has been closed.
	ErrDialerClosed = errors.New("cloudsqlconn: dialer is closed")
	// versionString indicates the version of this library. It is embedded
	// from version.txt at build time; with surrounding whitespace trimmed it
	// forms userAgent, which NewDialer uses to seed the user agent sent to
	// the SQL Admin API.
	//go:embed version.txt
	versionString string
	userAgent     = "cloud-sql-go-connector/" + strings.TrimSpace(versionString)

	// defaultKey is the default RSA public/private keypair used by the clients.
	// It is populated lazily by getDefaultKeys; defaultKeyErr records any
	// error from that generation, and keyOnce ensures generation runs at most
	// once per process.
	defaultKey    *rsa.PrivateKey
	defaultKeyErr error
	keyOnce       sync.Once
)
    69  
    70  func getDefaultKeys() (*rsa.PrivateKey, error) {
    71  	keyOnce.Do(func() {
    72  		defaultKey, defaultKeyErr = rsa.GenerateKey(rand.Reader, 2048)
    73  	})
    74  	return defaultKey, defaultKeyErr
    75  }
    76  
    77  type connectionInfoCache interface {
    78  	ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error)
    79  	UpdateRefresh(*bool)
    80  	ForceRefresh()
    81  	io.Closer
    82  }
    83  
    84  // monitoredCache is a wrapper around a connectionInfoCache that tracks the
    85  // number of connections to the associated instance.
    86  type monitoredCache struct {
    87  	openConns uint64
    88  
    89  	connectionInfoCache
    90  }
    91  
    92  // A Dialer is used to create connections to Cloud SQL instances.
    93  //
    94  // Use NewDialer to initialize a Dialer.
    95  type Dialer struct {
    96  	lock           sync.RWMutex
    97  	cache          map[instance.ConnName]monitoredCache
    98  	key            *rsa.PrivateKey
    99  	refreshTimeout time.Duration
   100  	// closed reports if the dialer has been closed.
   101  	closed chan struct{}
   102  
   103  	sqladmin *sqladmin.Service
   104  	logger   debug.ContextLogger
   105  
   106  	// lazyRefresh determines what kind of caching is used for ephemeral
   107  	// certificates. When lazyRefresh is true, the dialer will use a lazy
   108  	// cache, refresh certificates only when a connection attempt needs a fresh
   109  	// certificate. Otherwise, a refresh ahead cache will be used. The refresh
   110  	// ahead cache assumes a background goroutine may run consistently.
   111  	lazyRefresh bool
   112  
   113  	// defaultDialConfig holds the constructor level DialOptions, so that it
   114  	// can be copied and mutated by the Dial function.
   115  	defaultDialConfig dialConfig
   116  
   117  	// dialerID uniquely identifies a Dialer. Used for monitoring purposes,
   118  	// *only* when a client has configured OpenCensus exporters.
   119  	dialerID string
   120  
   121  	// dialFunc is the function used to connect to the address on the named
   122  	// network. By default, it is golang.org/x/net/proxy#Dial.
   123  	dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)
   124  
   125  	// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
   126  	iamTokenSource oauth2.TokenSource
   127  }
   128  
   129  var (
   130  	errUseTokenSource    = errors.New("use WithTokenSource when IAM AuthN is not enabled")
   131  	errUseIAMTokenSource = errors.New("use WithIAMAuthNTokenSources instead of WithTokenSource be used when IAM AuthN is enabled")
   132  )
   133  
   134  type nullLogger struct{}
   135  
   136  func (nullLogger) Debugf(_ context.Context, _ string, _ ...interface{}) {}
   137  
   138  // NewDialer creates a new Dialer.
   139  //
   140  // Initial calls to NewDialer make take longer than normal because generation of an
   141  // RSA keypair is performed. Calls with a WithRSAKeyPair DialOption or after a default
   142  // RSA keypair is generated will be faster.
   143  func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
   144  	cfg := &dialerConfig{
   145  		refreshTimeout:  cloudsql.RefreshTimeout,
   146  		dialFunc:        proxy.Dial,
   147  		logger:          nullLogger{},
   148  		useragents:      []string{userAgent},
   149  		serviceUniverse: "googleapis.com",
   150  	}
   151  	for _, opt := range opts {
   152  		opt(cfg)
   153  		if cfg.err != nil {
   154  			return nil, cfg.err
   155  		}
   156  	}
   157  	if cfg.useIAMAuthN && cfg.setTokenSource && !cfg.setIAMAuthNTokenSource {
   158  		return nil, errUseIAMTokenSource
   159  	}
   160  	if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
   161  		return nil, errUseTokenSource
   162  	}
   163  	// Add this to the end to make sure it's not overridden
   164  	cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))
   165  
   166  	// If callers have not provided a token source, either explicitly with
   167  	// WithTokenSource or implicitly with WithCredentialsJSON etc., then use the
   168  	// default token source.
   169  	if !cfg.setCredentials {
   170  		c, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
   171  		if err != nil {
   172  			return nil, fmt.Errorf("failed to create default credentials: %v", err)
   173  		}
   174  		ud, err := c.GetUniverseDomain()
   175  		if err != nil {
   176  			return nil, fmt.Errorf("failed to get universe domain: %v", err)
   177  		}
   178  		cfg.credentialsUniverse = ud
   179  		cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource))
   180  		scoped, err := google.DefaultTokenSource(ctx, iamLoginScope)
   181  		if err != nil {
   182  			return nil, fmt.Errorf("failed to create scoped token source: %v", err)
   183  		}
   184  		cfg.iamLoginTokenSource = scoped
   185  	}
   186  
   187  	if cfg.rsaKey == nil {
   188  		key, err := getDefaultKeys()
   189  		if err != nil {
   190  			return nil, fmt.Errorf("failed to generate RSA keys: %v", err)
   191  		}
   192  		cfg.rsaKey = key
   193  	}
   194  
   195  	if cfg.setUniverseDomain && cfg.setAdminAPIEndpoint {
   196  		return nil, errors.New(
   197  			"can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " +
   198  				"use WithAdminAPIEndpoint (it already contains the universe domain)",
   199  		)
   200  	}
   201  
   202  	if cfg.credentialsUniverse != "" && cfg.serviceUniverse != "" {
   203  		if cfg.credentialsUniverse != cfg.serviceUniverse {
   204  			return nil, fmt.Errorf(
   205  				"the configured service universe domain (%s) does not match the credential universe domain (%s)",
   206  				cfg.serviceUniverse, cfg.credentialsUniverse,
   207  			)
   208  		}
   209  	}
   210  
   211  	client, err := sqladmin.NewService(ctx, cfg.sqladminOpts...)
   212  	if err != nil {
   213  		return nil, fmt.Errorf("failed to create sqladmin client: %v", err)
   214  	}
   215  
   216  	dc := dialConfig{
   217  		ipType:       cloudsql.PublicIP,
   218  		tcpKeepAlive: defaultTCPKeepAlive,
   219  		useIAMAuthN:  cfg.useIAMAuthN,
   220  	}
   221  	for _, opt := range cfg.dialOpts {
   222  		opt(&dc)
   223  	}
   224  
   225  	if err := trace.InitMetrics(); err != nil {
   226  		return nil, err
   227  	}
   228  	d := &Dialer{
   229  		closed:            make(chan struct{}),
   230  		cache:             make(map[instance.ConnName]monitoredCache),
   231  		lazyRefresh:       cfg.lazyRefresh,
   232  		key:               cfg.rsaKey,
   233  		refreshTimeout:    cfg.refreshTimeout,
   234  		sqladmin:          client,
   235  		logger:            cfg.logger,
   236  		defaultDialConfig: dc,
   237  		dialerID:          uuid.New().String(),
   238  		iamTokenSource:    cfg.iamLoginTokenSource,
   239  		dialFunc:          cfg.dialFunc,
   240  	}
   241  	return d, nil
   242  }
   243  
   244  // Dial returns a net.Conn connected to the specified Cloud SQL instance. The
   245  // icn argument must be the instance's connection name, which is in the format
   246  // "project-name:region:instance-name".
   247  func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn net.Conn, err error) {
   248  	select {
   249  	case <-d.closed:
   250  		return nil, ErrDialerClosed
   251  	default:
   252  	}
   253  	startTime := time.Now()
   254  	var endDial trace.EndSpanFunc
   255  	ctx, endDial = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn.Dial",
   256  		trace.AddInstanceName(icn),
   257  		trace.AddDialerID(d.dialerID),
   258  	)
   259  	defer func() {
   260  		go trace.RecordDialError(context.Background(), icn, d.dialerID, err)
   261  		endDial(err)
   262  	}()
   263  	cn, err := instance.ParseConnName(icn)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  
   268  	cfg := d.defaultDialConfig
   269  	for _, opt := range opts {
   270  		opt(&cfg)
   271  	}
   272  
   273  	var endInfo trace.EndSpanFunc
   274  	ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
   275  	c := d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN)
   276  	ci, err := c.ConnectionInfo(ctx)
   277  	if err != nil {
   278  		d.lock.Lock()
   279  		defer d.lock.Unlock()
   280  		d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String())
   281  		// Stop all background refreshes
   282  		c.Close()
   283  		delete(d.cache, cn)
   284  		endInfo(err)
   285  		return nil, err
   286  	}
   287  	endInfo(err)
   288  
   289  	// If the client certificate has expired (as when the computer goes to
   290  	// sleep, and the refresh cycle cannot run), force a refresh immediately.
   291  	// The TLS handshake will not fail on an expired client certificate. It's
   292  	// not until the first read where the client cert error will be surfaced.
   293  	// So check that the certificate is valid before proceeding.
   294  	if !validClientCert(ctx, cn, d.logger, ci.Expiration) {
   295  		d.logger.Debugf(ctx, "[%v] Refreshing certificate now", cn.String())
   296  		c.ForceRefresh()
   297  		// Block on refreshed connection info
   298  		ci, err = c.ConnectionInfo(ctx)
   299  		if err != nil {
   300  			d.lock.Lock()
   301  			defer d.lock.Unlock()
   302  			d.logger.Debugf(ctx, "[%v] Removing connection info from cache", cn.String())
   303  			// Stop all background refreshes
   304  			c.Close()
   305  			delete(d.cache, cn)
   306  			return nil, err
   307  		}
   308  	}
   309  
   310  	var connectEnd trace.EndSpanFunc
   311  	ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
   312  	defer func() { connectEnd(err) }()
   313  	addr, err := ci.Addr(cfg.ipType)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	addr = net.JoinHostPort(addr, serverProxyPort)
   318  	f := d.dialFunc
   319  	if cfg.dialFunc != nil {
   320  		f = cfg.dialFunc
   321  	}
   322  	d.logger.Debugf(ctx, "[%v] Dialing %v", cn.String(), addr)
   323  	conn, err = f(ctx, "tcp", addr)
   324  	if err != nil {
   325  		d.logger.Debugf(ctx, "[%v] Dialing %v failed: %v", cn.String(), addr, err)
   326  		// refresh the instance info in case it caused the connection failure
   327  		c.ForceRefresh()
   328  		return nil, errtype.NewDialError("failed to dial", cn.String(), err)
   329  	}
   330  	if c, ok := conn.(*net.TCPConn); ok {
   331  		if err := c.SetKeepAlive(true); err != nil {
   332  			return nil, errtype.NewDialError("failed to set keep-alive", cn.String(), err)
   333  		}
   334  		if err := c.SetKeepAlivePeriod(cfg.tcpKeepAlive); err != nil {
   335  			return nil, errtype.NewDialError("failed to set keep-alive period", cn.String(), err)
   336  		}
   337  	}
   338  
   339  	tlsConn := tls.Client(conn, ci.TLSConfig())
   340  	err = tlsConn.HandshakeContext(ctx)
   341  	if err != nil {
   342  		d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
   343  		// refresh the instance info in case it caused the handshake failure
   344  		c.ForceRefresh()
   345  		_ = tlsConn.Close() // best effort close attempt
   346  		return nil, errtype.NewDialError("handshake failed", cn.String(), err)
   347  	}
   348  
   349  	latency := time.Since(startTime).Milliseconds()
   350  	go func() {
   351  		n := atomic.AddUint64(&c.openConns, 1)
   352  		trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
   353  		trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
   354  	}()
   355  
   356  	return newInstrumentedConn(tlsConn, func() {
   357  		n := atomic.AddUint64(&c.openConns, ^uint64(0))
   358  		trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
   359  	}), nil
   360  }
   361  
   362  // validClientCert checks that the ephemeral client certificate retrieved from
   363  // the cache is unexpired. The time comparisons strip the monotonic clock value
   364  // to ensure an accurate result, even after laptop sleep.
   365  func validClientCert(ctx context.Context, cn instance.ConnName, l debug.ContextLogger, expiration time.Time) bool {
   366  	// Use UTC() to strip monotonic clock value to guard against inaccurate
   367  	// comparisons, especially after laptop sleep.
   368  	// See the comments on the monotonic clock in the Go documentation for
   369  	// details: https://pkg.go.dev/time#hdr-Monotonic_Clocks
   370  	now := time.Now().UTC()
   371  	valid := expiration.UTC().After(now)
   372  	l.Debugf(
   373  		ctx,
   374  		"[%v] Now = %v, Current cert expiration = %v",
   375  		cn.String(),
   376  		now.Format(time.RFC3339),
   377  		expiration.UTC().Format(time.RFC3339),
   378  	)
   379  	l.Debugf(ctx, "[%v] Cert is valid = %v", cn.String(), valid)
   380  	return valid
   381  }
   382  
   383  // EngineVersion returns the engine type and version for the instance
   384  // connection name. The value will correspond to one of the following types for
   385  // the instance:
   386  // https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
   387  func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) {
   388  	cn, err := instance.ParseConnName(icn)
   389  	if err != nil {
   390  		return "", err
   391  	}
   392  	i := d.connectionInfoCache(ctx, cn, nil)
   393  	ci, err := i.ConnectionInfo(ctx)
   394  	if err != nil {
   395  		return "", err
   396  	}
   397  	return ci.DBVersion, nil
   398  }
   399  
   400  // Warmup starts the background refresh necessary to connect to the instance.
   401  // Use Warmup to start the refresh process early if you don't know when you'll
   402  // need to call "Dial".
   403  func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) error {
   404  	cn, err := instance.ParseConnName(icn)
   405  	if err != nil {
   406  		return err
   407  	}
   408  	cfg := d.defaultDialConfig
   409  	for _, opt := range opts {
   410  		opt(&cfg)
   411  	}
   412  	_ = d.connectionInfoCache(ctx, cn, &cfg.useIAMAuthN)
   413  	return nil
   414  }
   415  
   416  // newInstrumentedConn initializes an instrumentedConn that on closing will
   417  // decrement the number of open connects and record the result.
   418  func newInstrumentedConn(conn net.Conn, closeFunc func()) *instrumentedConn {
   419  	return &instrumentedConn{
   420  		Conn:      conn,
   421  		closeFunc: closeFunc,
   422  	}
   423  }
   424  
   425  // instrumentedConn wraps a net.Conn and invokes closeFunc when the connection
   426  // is closed.
   427  type instrumentedConn struct {
   428  	net.Conn
   429  	closeFunc func()
   430  }
   431  
   432  // Close delegates to the underlying net.Conn interface and reports the close
   433  // to the provided closeFunc only when Close returns no error.
   434  func (i *instrumentedConn) Close() error {
   435  	err := i.Conn.Close()
   436  	if err != nil {
   437  		return err
   438  	}
   439  	go i.closeFunc()
   440  	return nil
   441  }
   442  
   443  // Close closes the Dialer; it prevents the Dialer from refreshing the information
   444  // needed to connect.
   445  func (d *Dialer) Close() error {
   446  	// Check if Close has already been called.
   447  	select {
   448  	case <-d.closed:
   449  		return nil
   450  	default:
   451  	}
   452  	close(d.closed)
   453  	d.lock.Lock()
   454  	defer d.lock.Unlock()
   455  	for _, i := range d.cache {
   456  		i.Close()
   457  	}
   458  	return nil
   459  }
   460  
   461  // connectionInfoCache is a helper function for returning the appropriate
   462  // connection info Cache in a threadsafe way. It will create a new cache,
   463  // modify the existing one, or leave it unchanged as needed.
   464  func (d *Dialer) connectionInfoCache(
   465  	ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
   466  ) monitoredCache {
   467  	d.lock.RLock()
   468  	c, ok := d.cache[cn]
   469  	d.lock.RUnlock()
   470  	if !ok {
   471  		d.lock.Lock()
   472  		defer d.lock.Unlock()
   473  		// Recheck to ensure instance wasn't created or changed between locks
   474  		c, ok = d.cache[cn]
   475  		if !ok {
   476  			var useIAMAuthNDial bool
   477  			if useIAMAuthN != nil {
   478  				useIAMAuthNDial = *useIAMAuthN
   479  			}
   480  			d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String())
   481  			var cache connectionInfoCache
   482  			if d.lazyRefresh {
   483  				cache = cloudsql.NewLazyRefreshCache(
   484  					cn,
   485  					d.logger,
   486  					d.sqladmin, d.key,
   487  					d.refreshTimeout, d.iamTokenSource,
   488  					d.dialerID, useIAMAuthNDial,
   489  				)
   490  			} else {
   491  				cache = cloudsql.NewRefreshAheadCache(
   492  					cn,
   493  					d.logger,
   494  					d.sqladmin, d.key,
   495  					d.refreshTimeout, d.iamTokenSource,
   496  					d.dialerID, useIAMAuthNDial,
   497  				)
   498  			}
   499  			c = monitoredCache{connectionInfoCache: cache}
   500  			d.cache[cn] = c
   501  		}
   502  	}
   503  
   504  	c.UpdateRefresh(useIAMAuthN)
   505  
   506  	return c
   507  }
   508  

View as plain text