...

Source file src/cloud.google.com/go/cloudsqlconn/internal/cloudsql/refresh.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/cloudsql

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsql
    16  
    17  import (
    18  	"context"
    19  	"crypto/rsa"
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"encoding/pem"
    23  	"fmt"
    24  	"strings"
    25  	"time"
    26  
    27  	"cloud.google.com/go/cloudsqlconn/debug"
    28  	"cloud.google.com/go/cloudsqlconn/errtype"
    29  	"cloud.google.com/go/cloudsqlconn/instance"
    30  	"cloud.google.com/go/cloudsqlconn/internal/trace"
    31  	"golang.org/x/oauth2"
    32  	sqladmin "google.golang.org/api/sqladmin/v1beta4"
    33  )
    34  
    35  const (
    36  	// PublicIP is the value for public IP Cloud SQL instances.
    37  	PublicIP = "PUBLIC"
    38  	// PrivateIP is the value for private IP Cloud SQL instances.
    39  	PrivateIP = "PRIVATE"
    40  	// PSC is the value for private service connect Cloud SQL instances.
    41  	PSC = "PSC"
    42  	// AutoIP selects public IP if available and otherwise selects private
    43  	// IP.
    44  	AutoIP = "AutoIP"
    45  )
    46  
// metadata contains information about a Cloud SQL instance needed to create
// connections. fetchMetadata populates it from the SQL Admin API.
type metadata struct {
	// ipAddrs maps an address type (PublicIP, PrivateIP, or PSC) to the
	// address for that type; for PSC the value is the instance's DNS name.
	ipAddrs map[string]string
	// serverCaCert is the instance's server CA certificate, parsed from the
	// PEM data returned by the SQL Admin API.
	serverCaCert *x509.Certificate
	// version is the instance's database version string as reported by the
	// API (its prefix, e.g. "POSTGRES" or "MYSQL", is inspected when auto IAM
	// authn is requested).
	version string
}
    54  
    55  // fetchMetadata uses the Cloud SQL Admin APIs get method to retrieve the
    56  // information about a Cloud SQL instance that is used to create secure
    57  // connections.
    58  func fetchMetadata(
    59  	ctx context.Context, client *sqladmin.Service, inst instance.ConnName,
    60  ) (m metadata, err error) {
    61  
    62  	var end trace.EndSpanFunc
    63  	ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchMetadata")
    64  	defer func() { end(err) }()
    65  
    66  	db, err := retry50x(ctx, func(ctx2 context.Context) (*sqladmin.ConnectSettings, error) {
    67  		return client.Connect.Get(
    68  			inst.Project(), inst.Name(),
    69  		).Context(ctx2).Do()
    70  	}, exponentialBackoff)
    71  	if err != nil {
    72  		return metadata{}, errtype.NewRefreshError("failed to get instance metadata", inst.String(), err)
    73  	}
    74  	// validate the instance is supported for authenticated connections
    75  	if db.Region != inst.Region() {
    76  		msg := fmt.Sprintf(
    77  			"provided region was mismatched - got %s, want %s",
    78  			inst.Region(), db.Region,
    79  		)
    80  		return metadata{}, errtype.NewConfigError(msg, inst.String())
    81  	}
    82  	if db.BackendType != "SECOND_GEN" {
    83  		return metadata{}, errtype.NewConfigError(
    84  			"unsupported instance - only Second Generation instances are supported",
    85  			inst.String(),
    86  		)
    87  	}
    88  
    89  	// parse any ip addresses that might be used to connect
    90  	ipAddrs := make(map[string]string)
    91  	for _, ip := range db.IpAddresses {
    92  		switch ip.Type {
    93  		case "PRIMARY":
    94  			ipAddrs[PublicIP] = ip.IpAddress
    95  		case "PRIVATE":
    96  			ipAddrs[PrivateIP] = ip.IpAddress
    97  		}
    98  	}
    99  
   100  	// resolve DnsName into IP address for PSC
   101  	if db.DnsName != "" {
   102  		ipAddrs[PSC] = db.DnsName
   103  	}
   104  
   105  	if len(ipAddrs) == 0 {
   106  		return metadata{}, errtype.NewConfigError(
   107  			"cannot connect to instance - it has no supported IP addresses",
   108  			inst.String(),
   109  		)
   110  	}
   111  
   112  	// parse the server-side CA certificate
   113  	b, _ := pem.Decode([]byte(db.ServerCaCert.Cert))
   114  	if b == nil {
   115  		return metadata{}, errtype.NewRefreshError("failed to decode valid PEM cert", inst.String(), nil)
   116  	}
   117  	cert, err := x509.ParseCertificate(b.Bytes)
   118  	if err != nil {
   119  		return metadata{}, errtype.NewRefreshError(
   120  			fmt.Sprintf("failed to parse as X.509 certificate: %v", err),
   121  			inst.String(),
   122  			nil,
   123  		)
   124  	}
   125  
   126  	m = metadata{
   127  		ipAddrs:      ipAddrs,
   128  		serverCaCert: cert,
   129  		version:      db.DatabaseVersion,
   130  	}
   131  
   132  	return m, nil
   133  }
   134  
   135  func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) {
   136  	expiredToken := &oauth2.Token{
   137  		AccessToken:  tok.AccessToken,
   138  		TokenType:    tok.TokenType,
   139  		RefreshToken: tok.RefreshToken,
   140  		Expiry:       time.Time{}.Add(1), // Expired
   141  	}
   142  	return oauth2.ReuseTokenSource(expiredToken, ts).Token()
   143  }
   144  
   145  // fetchEphemeralCert uses the Cloud SQL Admin API's createEphemeral method to
   146  // create a signed TLS certificate that authorized to connect via the Cloud SQL
   147  // instance's serverside proxy. The cert if valid for approximately one hour.
   148  func fetchEphemeralCert(
   149  	ctx context.Context,
   150  	client *sqladmin.Service,
   151  	inst instance.ConnName,
   152  	key *rsa.PrivateKey,
   153  	ts oauth2.TokenSource,
   154  ) (c tls.Certificate, err error) {
   155  	var end trace.EndSpanFunc
   156  	ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchEphemeralCert")
   157  	defer func() { end(err) }()
   158  	clientPubKey, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
   159  	if err != nil {
   160  		return tls.Certificate{}, err
   161  	}
   162  
   163  	req := sqladmin.GenerateEphemeralCertRequest{
   164  		PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: clientPubKey, Type: "RSA PUBLIC KEY"})),
   165  	}
   166  	var tok *oauth2.Token
   167  	if ts != nil {
   168  		var tokErr error
   169  		tok, tokErr = ts.Token()
   170  		if tokErr != nil {
   171  			return tls.Certificate{}, errtype.NewRefreshError(
   172  				"failed to retrieve Oauth2 token",
   173  				inst.String(),
   174  				tokErr,
   175  			)
   176  		}
   177  		// Always refresh the token to ensure its expiration is far enough in
   178  		// the future.
   179  		tok, tokErr = refreshToken(ts, tok)
   180  		if tokErr != nil {
   181  			return tls.Certificate{}, errtype.NewRefreshError(
   182  				"failed to refresh Oauth2 token",
   183  				inst.String(),
   184  				tokErr,
   185  			)
   186  		}
   187  		req.AccessToken = tok.AccessToken
   188  	}
   189  	resp, err := retry50x(ctx, func(ctx2 context.Context) (*sqladmin.GenerateEphemeralCertResponse, error) {
   190  		return client.Connect.GenerateEphemeralCert(
   191  			inst.Project(), inst.Name(), &req,
   192  		).Context(ctx2).Do()
   193  	}, exponentialBackoff)
   194  	if err != nil {
   195  		return tls.Certificate{}, errtype.NewRefreshError(
   196  			"create ephemeral cert failed",
   197  			inst.String(),
   198  			err,
   199  		)
   200  	}
   201  
   202  	// parse the client cert
   203  	b, _ := pem.Decode([]byte(resp.EphemeralCert.Cert))
   204  	if b == nil {
   205  		return tls.Certificate{}, errtype.NewRefreshError(
   206  			"failed to decode valid PEM cert",
   207  			inst.String(),
   208  			nil,
   209  		)
   210  	}
   211  	clientCert, err := x509.ParseCertificate(b.Bytes)
   212  	if err != nil {
   213  		return tls.Certificate{}, errtype.NewRefreshError(
   214  			fmt.Sprintf("failed to parse as X.509 certificate: %v", err),
   215  			inst.String(),
   216  			nil,
   217  		)
   218  	}
   219  	if ts != nil {
   220  		// Adjust the certificate's expiration to be the earliest of the token's
   221  		// expiration or the certificate's expiration.
   222  		if tok.Expiry.Before(clientCert.NotAfter) {
   223  			clientCert.NotAfter = tok.Expiry
   224  		}
   225  	}
   226  
   227  	c = tls.Certificate{
   228  		Certificate: [][]byte{clientCert.Raw},
   229  		PrivateKey:  key,
   230  		Leaf:        clientCert,
   231  	}
   232  	return c, nil
   233  }
   234  
   235  // newRefresher creates a Refresher.
   236  func newRefresher(
   237  	l debug.ContextLogger,
   238  	svc *sqladmin.Service,
   239  	ts oauth2.TokenSource,
   240  	dialerID string,
   241  ) refresher {
   242  	return refresher{
   243  		dialerID: dialerID,
   244  		logger:   l,
   245  		client:   svc,
   246  		ts:       ts,
   247  	}
   248  }
   249  
// refresher manages the SQL Admin API access to instance metadata and to
// ephemeral certificates. It is used as a value (its ConnectionInfo method
// has a value receiver) and is constructed with newRefresher.
type refresher struct {
	// dialerID is the unique ID of the associated dialer.
	dialerID string
	// logger is the context-aware debug logger supplied at construction.
	logger debug.ContextLogger
	// client is the SQL Admin API client used for metadata and ephemeral
	// certificate requests.
	client *sqladmin.Service
	// ts is the TokenSource used for IAM DB AuthN.
	ts oauth2.TokenSource
}
   260  
   261  // ConnectionInfo immediately performs a full refresh operation using the Cloud
   262  // SQL Admin API.
   263  func (r refresher) ConnectionInfo(
   264  	ctx context.Context, cn instance.ConnName, k *rsa.PrivateKey, iamAuthNDial bool,
   265  ) (ci ConnectionInfo, err error) {
   266  
   267  	var refreshEnd trace.EndSpanFunc
   268  	ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.RefreshConnection",
   269  		trace.AddInstanceName(cn.String()),
   270  	)
   271  	defer func() {
   272  		go trace.RecordRefreshResult(context.Background(), cn.String(), r.dialerID, err)
   273  		refreshEnd(err)
   274  	}()
   275  
   276  	// start async fetching the instance's metadata
   277  	type mdRes struct {
   278  		md  metadata
   279  		err error
   280  	}
   281  	mdC := make(chan mdRes, 1)
   282  	go func() {
   283  		defer close(mdC)
   284  		md, err := fetchMetadata(ctx, r.client, cn)
   285  		mdC <- mdRes{md, err}
   286  	}()
   287  
   288  	// start async fetching the certs
   289  	type ecRes struct {
   290  		ec  tls.Certificate
   291  		err error
   292  	}
   293  	ecC := make(chan ecRes, 1)
   294  	go func() {
   295  		defer close(ecC)
   296  		var iamTS oauth2.TokenSource
   297  		if iamAuthNDial {
   298  			iamTS = r.ts
   299  		}
   300  		ec, err := fetchEphemeralCert(ctx, r.client, cn, k, iamTS)
   301  		ecC <- ecRes{ec, err}
   302  	}()
   303  
   304  	// wait for the results of each operation
   305  	var md metadata
   306  	select {
   307  	case r := <-mdC:
   308  		if r.err != nil {
   309  			return ConnectionInfo{}, fmt.Errorf("failed to get instance: %w", r.err)
   310  		}
   311  		md = r.md
   312  	case <-ctx.Done():
   313  		return ci, fmt.Errorf("refresh failed: %w", ctx.Err())
   314  	}
   315  	if iamAuthNDial {
   316  		if vErr := supportsAutoIAMAuthN(md.version); vErr != nil {
   317  			return ConnectionInfo{}, vErr
   318  		}
   319  	}
   320  
   321  	var ec tls.Certificate
   322  	select {
   323  	case r := <-ecC:
   324  		if r.err != nil {
   325  			return ConnectionInfo{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err)
   326  		}
   327  		ec = r.ec
   328  	case <-ctx.Done():
   329  		return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
   330  	}
   331  
   332  	return ConnectionInfo{
   333  		addrs:             md.ipAddrs,
   334  		ServerCaCert:      md.serverCaCert,
   335  		ClientCertificate: ec,
   336  		Expiration:        ec.Leaf.NotAfter,
   337  		DBVersion:         md.version,
   338  		ConnectionName:    cn,
   339  	}, nil
   340  }
   341  
   342  // supportsAutoIAMAuthN checks that the engine support automatic IAM authn. If
   343  // auto IAM authn was not request, this is a no-op.
   344  func supportsAutoIAMAuthN(version string) error {
   345  	switch {
   346  	case strings.HasPrefix(version, "POSTGRES"):
   347  		return nil
   348  	case strings.HasPrefix(version, "MYSQL"):
   349  		return nil
   350  	default:
   351  		return fmt.Errorf("%s does not support Auto IAM DB Authentication", version)
   352  	}
   353  }
   354  

View as plain text