...

Source file src/cloud.google.com/go/cloudsqlconn/internal/cloudsql/instance.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/cloudsql

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsql
    16  
    17  import (
    18  	"context"
    19  	"crypto/rsa"
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"fmt"
    23  	"sync"
    24  	"time"
    25  
    26  	"cloud.google.com/go/cloudsqlconn/debug"
    27  	"cloud.google.com/go/cloudsqlconn/errtype"
    28  	"cloud.google.com/go/cloudsqlconn/instance"
    29  	"golang.org/x/oauth2"
    30  	"golang.org/x/time/rate"
    31  	sqladmin "google.golang.org/api/sqladmin/v1beta4"
    32  )
    33  
    34  const (
    35  	// the refresh buffer is the amount of time before a refresh operation's
    36  	// certificate expires that a new refresh operation begins.
    37  	refreshBuffer = 4 * time.Minute
    38  
    39  	// refreshInterval is the amount of time between refresh attempts as
    40  	// enforced by the rate limiter.
    41  	refreshInterval = 30 * time.Second
    42  
    43  	// RefreshTimeout is the maximum amount of time to wait for a refresh
    44  	// cycle to complete. This value should be greater than the
    45  	// refreshInterval.
    46  	RefreshTimeout = 60 * time.Second
    47  
    48  	// refreshBurst is the initial burst allowed by the rate limiter.
    49  	refreshBurst = 2
    50  )
    51  
// refreshOperation is a pending result of a refresh operation of data used to
// connect securely. It should only be initialized by RefreshAheadCache (see
// scheduleRefresh) as part of a refresh cycle.
//
// result and err are written by the refresh before ready is closed, so they
// must only be read after a receive from ready has succeeded.
type refreshOperation struct {
	// ready is closed once the operation has finished, whether it succeeded
	// or failed.
	ready chan struct{}
	// timer triggers the refresh when it fires; stopping it (see cancel)
	// prevents a refresh that has not started yet from running.
	timer *time.Timer
	// result is the connection info produced by a successful refresh.
	result ConnectionInfo
	// err is the refresh failure, or nil when result is populated.
	err error
}
    63  
    64  // cancel prevents the the refresh operation from starting, if it hasn't
    65  // already started. Returns true if timer was stopped successfully, or false if
    66  // it has already started.
    67  func (r *refreshOperation) cancel() bool {
    68  	return r.timer.Stop()
    69  }
    70  
    71  // isValid returns true if this result is complete, successful, and is still
    72  // valid.
    73  func (r *refreshOperation) isValid() bool {
    74  	// verify the refreshOperation has finished running
    75  	select {
    76  	default:
    77  		return false
    78  	case <-r.ready:
    79  		if r.err != nil || time.Now().After(r.result.Expiration.Round(0)) {
    80  			return false
    81  		}
    82  		return true
    83  	}
    84  }
    85  
// RefreshAheadCache manages the information used to connect to the Cloud SQL
// instance by periodically calling the Cloud SQL Admin API. It automatically
// refreshes the required information approximately 4 minutes before the
// previous certificate expires (every ~56 minutes).
//
// Connection requests load the current refreshOperation (cur) and wait on it
// if it has not completed yet. A chain of timer-driven refreshes (see
// scheduleRefresh) replaces cur as results arrive and keeps the upcoming
// refresh in next.
type RefreshAheadCache struct {
	// openConns is the number of open connections to the instance.
	// NOTE(review): not used in this file; presumably updated with
	// sync/atomic elsewhere. A uint64 placed first in the struct is the usual
	// way to keep it 64-bit aligned for atomic ops on 32-bit platforms —
	// confirm before reordering fields.
	openConns uint64

	// connName identifies the instance, logger receives debug output, and
	// key is the RSA key handed to the refresher on every refresh. All three
	// are set once by NewRefreshAheadCache.
	connName instance.ConnName
	logger   debug.ContextLogger
	key      *rsa.PrivateKey

	// refreshTimeout sets the maximum duration a refresh cycle can run
	// for.
	refreshTimeout time.Duration
	// l controls the rate at which refresh cycles are run.
	l *rate.Limiter
	// r performs the actual fetch of connection info; it is built from the
	// sqladmin client and token source in NewRefreshAheadCache.
	r refresher

	// mu is held for writing whenever useIAMAuthNDial, cur, or next is
	// changed, and for reading when a connection request loads cur.
	// useIAMAuthNDial is forwarded to the refresher on each refresh and is
	// changed by UpdateRefresh.
	mu              sync.RWMutex
	useIAMAuthNDial bool
	// cur represents the current refreshOperation that will be used to
	// create connections. If a valid complete refreshOperation isn't
	// available it's possible for cur to be equal to next.
	cur *refreshOperation
	// next represents a future or ongoing refreshOperation. Once complete,
	// it will replace cur and schedule a replacement to occur.
	next *refreshOperation

	// ctx is the default ctx for refresh operations. Canceling it prevents
	// new refresh operations from being triggered.
	ctx    context.Context
	cancel context.CancelFunc
}
   120  
   121  // NewRefreshAheadCache initializes a new Instance given an instance connection name
   122  func NewRefreshAheadCache(
   123  	cn instance.ConnName,
   124  	l debug.ContextLogger,
   125  	client *sqladmin.Service,
   126  	key *rsa.PrivateKey,
   127  	refreshTimeout time.Duration,
   128  	ts oauth2.TokenSource,
   129  	dialerID string,
   130  	useIAMAuthNDial bool,
   131  ) *RefreshAheadCache {
   132  	ctx, cancel := context.WithCancel(context.Background())
   133  	i := &RefreshAheadCache{
   134  		connName: cn,
   135  		logger:   l,
   136  		key:      key,
   137  		l:        rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
   138  		r: newRefresher(
   139  			l,
   140  			client,
   141  			ts,
   142  			dialerID,
   143  		),
   144  		refreshTimeout:  refreshTimeout,
   145  		useIAMAuthNDial: useIAMAuthNDial,
   146  		ctx:             ctx,
   147  		cancel:          cancel,
   148  	}
   149  	// For the initial refresh operation, set cur = next so that connection
   150  	// requests block until the first refresh is complete.
   151  	i.mu.Lock()
   152  	i.cur = i.scheduleRefresh(0)
   153  	i.next = i.cur
   154  	i.mu.Unlock()
   155  	return i
   156  }
   157  
   158  // Close closes the instance; it stops the refresh cycle and prevents it from
   159  // making additional calls to the Cloud SQL Admin API.
   160  func (i *RefreshAheadCache) Close() error {
   161  	i.mu.Lock()
   162  	defer i.mu.Unlock()
   163  	i.cancel()
   164  	i.cur.cancel()
   165  	i.next.cancel()
   166  	return nil
   167  }
   168  
// ConnectionInfo contains all necessary information to connect securely to the
// server-side Proxy running on a Cloud SQL instance.
//
// Fields:
//   - ConnectionName identifies the instance. TLSConfig uses it as the TLS
//     ServerName and to check the server certificate's Common Name.
//   - ClientCertificate is presented to the server during the TLS handshake.
//   - ServerCaCert is the root the server's certificate must chain to.
//   - DBVersion is the database version string reported by the refresher
//     (not interpreted in this file).
//   - Expiration is when this information stops being valid (logged by the
//     refresh loop as the certificate expiration). It is used to decide
//     validity and when to refresh next.
//   - addrs maps an IP type (e.g. PublicIP, PrivateIP) to an IP address or
//     DNS name. It is read by Addr.
type ConnectionInfo struct {
	ConnectionName    instance.ConnName
	ClientCertificate tls.Certificate
	ServerCaCert      *x509.Certificate
	DBVersion         string
	Expiration        time.Time

	addrs map[string]string
}
   180  
   181  // Addr returns the IP address or DNS name for the given IP type.
   182  func (c ConnectionInfo) Addr(ipType string) (string, error) {
   183  	var (
   184  		addr string
   185  		ok   bool
   186  	)
   187  	switch ipType {
   188  	case AutoIP:
   189  		// Try Public first
   190  		addr, ok = c.addrs[PublicIP]
   191  		if !ok {
   192  			// Try Private second
   193  			addr, ok = c.addrs[PrivateIP]
   194  		}
   195  	default:
   196  		addr, ok = c.addrs[ipType]
   197  	}
   198  	if !ok {
   199  		err := errtype.NewConfigError(
   200  			fmt.Sprintf("instance does not have IP of type %q", ipType),
   201  			c.ConnectionName.String(),
   202  		)
   203  		return "", err
   204  	}
   205  	return addr, nil
   206  }
   207  
   208  // TLSConfig constructs a TLS configuration for the given connection info.
   209  func (c ConnectionInfo) TLSConfig() *tls.Config {
   210  	pool := x509.NewCertPool()
   211  	pool.AddCert(c.ServerCaCert)
   212  	return &tls.Config{
   213  		ServerName:   c.ConnectionName.String(),
   214  		Certificates: []tls.Certificate{c.ClientCertificate},
   215  		RootCAs:      pool,
   216  		// We need to set InsecureSkipVerify to true due to
   217  		// https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194
   218  		// https://tip.golang.org/doc/go1.11#crypto/x509
   219  		//
   220  		// Since we have a secure channel to the Cloud SQL API which we use to
   221  		// retrieve the certificates, we instead need to implement our own
   222  		// VerifyPeerCertificate function that will verify that the certificate
   223  		// is OK.
   224  		InsecureSkipVerify:    true,
   225  		VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool),
   226  		MinVersion:            tls.VersionTLS13,
   227  	}
   228  }
   229  
   230  // verifyPeerCertificateFunc creates a VerifyPeerCertificate func that
   231  // verifies that the peer certificate is in the cert pool. We need to define
   232  // our own because CloudSQL instances use the instance name (e.g.,
   233  // my-project:my-instance) instead of a valid domain name for the certificate's
   234  // Common Name.
   235  func verifyPeerCertificateFunc(
   236  	cn instance.ConnName, pool *x509.CertPool,
   237  ) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
   238  	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
   239  		if len(rawCerts) == 0 {
   240  			return errtype.NewDialError(
   241  				"no certificate to verify", cn.String(), nil,
   242  			)
   243  		}
   244  
   245  		cert, err := x509.ParseCertificate(rawCerts[0])
   246  		if err != nil {
   247  			return errtype.NewDialError(
   248  				"failed to parse X.509 certificate", cn.String(), err,
   249  			)
   250  		}
   251  
   252  		opts := x509.VerifyOptions{Roots: pool}
   253  		if _, err = cert.Verify(opts); err != nil {
   254  			return errtype.NewDialError(
   255  				"failed to verify certificate", cn.String(), err,
   256  			)
   257  		}
   258  
   259  		certInstanceName := fmt.Sprintf("%s:%s", cn.Project(), cn.Name())
   260  		if cert.Subject.CommonName != certInstanceName {
   261  			return errtype.NewDialError(
   262  				fmt.Sprintf(
   263  					"certificate had CN %q, expected %q",
   264  					cert.Subject.CommonName, certInstanceName,
   265  				),
   266  				cn.String(),
   267  				nil,
   268  			)
   269  		}
   270  		return nil
   271  	}
   272  }
   273  
   274  // ConnectionInfo returns an IP address specified by ipType (i.e., public or
   275  // private) and a TLS config that can be used to connect to a Cloud SQL
   276  // instance.
   277  func (i *RefreshAheadCache) ConnectionInfo(ctx context.Context) (ConnectionInfo, error) {
   278  	op, err := i.refreshOperation(ctx)
   279  	if err != nil {
   280  		return ConnectionInfo{}, err
   281  	}
   282  	return op.result, nil
   283  }
   284  
   285  // UpdateRefresh cancels all existing refresh attempts and schedules new
   286  // attempts with the provided config only if it differs from the current
   287  // configuration.
   288  func (i *RefreshAheadCache) UpdateRefresh(useIAMAuthNDial *bool) {
   289  	i.mu.Lock()
   290  	defer i.mu.Unlock()
   291  	if useIAMAuthNDial != nil && *useIAMAuthNDial != i.useIAMAuthNDial {
   292  		// Cancel any pending refreshes
   293  		i.cur.cancel()
   294  		i.next.cancel()
   295  
   296  		i.useIAMAuthNDial = *useIAMAuthNDial
   297  		// reschedule a new refresh immediately
   298  		i.cur = i.scheduleRefresh(0)
   299  		i.next = i.cur
   300  	}
   301  }
   302  
   303  // ForceRefresh triggers an immediate refresh operation to be scheduled and
   304  // used for future connection attempts. Until the refresh completes, the
   305  // existing connection info will be available for use if valid.
   306  func (i *RefreshAheadCache) ForceRefresh() {
   307  	i.mu.Lock()
   308  	defer i.mu.Unlock()
   309  	// If the next refresh hasn't started yet, we can cancel it and start an
   310  	// immediate one
   311  	if i.next.cancel() {
   312  		i.next = i.scheduleRefresh(0)
   313  	}
   314  	// block all sequential connection attempts on the next refresh operation
   315  	// if current is invalid
   316  	if !i.cur.isValid() {
   317  		i.cur = i.next
   318  	}
   319  }
   320  
   321  // refreshOperation returns the most recent refresh operation
   322  // waiting for it to complete if necessary
   323  func (i *RefreshAheadCache) refreshOperation(ctx context.Context) (*refreshOperation, error) {
   324  	i.mu.RLock()
   325  	cur := i.cur
   326  	i.mu.RUnlock()
   327  	var err error
   328  	select {
   329  	case <-cur.ready:
   330  		err = cur.err
   331  	case <-ctx.Done():
   332  		err = ctx.Err()
   333  	case <-i.ctx.Done():
   334  		err = i.ctx.Err()
   335  	}
   336  	if err != nil {
   337  		return nil, err
   338  	}
   339  	return cur, nil
   340  }
   341  
   342  // refreshDuration returns the duration to wait before starting the next
   343  // refresh. Usually that duration will be half of the time until certificate
   344  // expiration.
   345  func refreshDuration(now, certExpiry time.Time) time.Duration {
   346  	d := certExpiry.Sub(now.Round(0))
   347  	if d < time.Hour {
   348  		// Something is wrong with the certificate, refresh now.
   349  		if d < refreshBuffer {
   350  			return 0
   351  		}
   352  		// Otherwise wait until 4 minutes before expiration for next
   353  		// refresh cycle.
   354  		return d - refreshBuffer
   355  	}
   356  	return d / 2
   357  }
   358  
   359  // scheduleRefresh schedules a refresh operation to be triggered after a given
   360  // duration. The returned refreshOperation can be used to either Cancel or Wait
   361  // for the operation's completion.
   362  func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
   363  	r := &refreshOperation{}
   364  	r.ready = make(chan struct{})
   365  	r.timer = time.AfterFunc(d, func() {
   366  		// instance has been closed, don't schedule anything
   367  		if err := i.ctx.Err(); err != nil {
   368  			i.logger.Debugf(
   369  				context.Background(),
   370  				"[%v] Instance is closed, stopping refresh operations",
   371  				i.connName.String(),
   372  			)
   373  			r.err = err
   374  			close(r.ready)
   375  			return
   376  		}
   377  		i.logger.Debugf(
   378  			context.Background(),
   379  			"[%v] Connection info refresh operation started",
   380  			i.connName.String(),
   381  		)
   382  
   383  		ctx, cancel := context.WithTimeout(i.ctx, i.refreshTimeout)
   384  		defer cancel()
   385  
   386  		// avoid refreshing too often to try not to tax the SQL Admin
   387  		// API quotas
   388  		err := i.l.Wait(ctx)
   389  		if err != nil {
   390  			r.err = errtype.NewDialError(
   391  				"context was canceled or expired before refresh completed",
   392  				i.connName.String(),
   393  				nil,
   394  			)
   395  		} else {
   396  			r.result, r.err = i.r.ConnectionInfo(
   397  				ctx, i.connName, i.key, i.useIAMAuthNDial)
   398  		}
   399  		switch r.err {
   400  		case nil:
   401  			i.logger.Debugf(
   402  				ctx,
   403  				"[%v] Connection info refresh operation complete",
   404  				i.connName.String(),
   405  			)
   406  			i.logger.Debugf(
   407  				ctx,
   408  				"[%v] Current certificate expiration = %v",
   409  				i.connName.String(),
   410  				r.result.Expiration.UTC().Format(time.RFC3339),
   411  			)
   412  		default:
   413  			i.logger.Debugf(
   414  				ctx,
   415  				"[%v] Connection info refresh operation failed, err = %v",
   416  				i.connName.String(),
   417  				r.err,
   418  			)
   419  		}
   420  
   421  		close(r.ready)
   422  
   423  		// Once the refresh is complete, update "current" with working
   424  		// refreshOperation and schedule a new refresh
   425  		i.mu.Lock()
   426  		defer i.mu.Unlock()
   427  
   428  		// if failed, scheduled the next refresh immediately
   429  		if r.err != nil {
   430  			i.logger.Debugf(
   431  				ctx,
   432  				"[%v] Connection info refresh operation scheduled immediately",
   433  				i.connName.String(),
   434  			)
   435  			i.next = i.scheduleRefresh(0)
   436  			// If the latest refreshOperation is bad, avoid replacing the
   437  			// used refreshOperation while it's still valid and potentially
   438  			// able to provide successful connections. TODO: This
   439  			// means that errors while the current refreshOperation is still
   440  			// valid are suppressed. We should try to surface
   441  			// errors in a more meaningful way.
   442  			if !i.cur.isValid() {
   443  				i.cur = r
   444  			}
   445  			return
   446  		}
   447  
   448  		// Update the current results, and schedule the next refresh in
   449  		// the future
   450  		i.cur = r
   451  		t := refreshDuration(time.Now(), i.cur.result.Expiration)
   452  		i.logger.Debugf(
   453  			ctx,
   454  			"[%v] Connection info refresh operation scheduled at %v (now + %v)",
   455  			i.connName.String(),
   456  			time.Now().Add(t).UTC().Format(time.RFC3339),
   457  			t.Round(time.Minute),
   458  		)
   459  		i.next = i.scheduleRefresh(t)
   460  	})
   461  	return r
   462  }
   463  

View as plain text