...

Source file src/cloud.google.com/go/cloudsqlconn/internal/mock/cloudsql.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/mock

     1  // Copyright 2021 Google LLC
     2  
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mock
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/rand"
    21  	"crypto/rsa"
    22  	"crypto/tls"
    23  	"crypto/x509"
    24  	"crypto/x509/pkix"
    25  	"encoding/pem"
    26  	"fmt"
    27  	"math/big"
    28  	"testing"
    29  	"time"
    30  
    31  	"golang.org/x/oauth2"
    32  )
    33  
    34  // EmptyTokenSource is an Oauth2.TokenSource that returns empty tokens.
    35  type EmptyTokenSource struct{}
    36  
    37  // Token provides an empty oauth2.Token.
    38  func (EmptyTokenSource) Token() (*oauth2.Token, error) {
    39  	return &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, nil
    40  }
    41  
    42  // FakeCSQLInstance represents settings for a specific Cloud SQL instance.
    43  //
    44  // Use NewFakeCSQLInstance to instantiate.
    45  type FakeCSQLInstance struct {
    46  	project   string
    47  	region    string
    48  	name      string
    49  	dbVersion string
    50  	// ipAddrs is a map of IP type (PUBLIC or PRIVATE) to IP address.
    51  	ipAddrs      map[string]string
    52  	backendType  string
    53  	DNSName      string
    54  	signer       SignFunc
    55  	clientSigner ClientSignFunc
    56  	// Key is the server's private key
    57  	Key *rsa.PrivateKey
    58  	// Cert is the server's certificate
    59  	Cert *x509.Certificate
    60  }
    61  
    62  // String returns the instance connection name for the
    63  // instance.
    64  func (f FakeCSQLInstance) String() string {
    65  	return fmt.Sprintf("%v:%v:%v", f.project, f.region, f.name)
    66  }
    67  
    68  func (f FakeCSQLInstance) signedCert() ([]byte, error) {
    69  	return f.signer(f.Cert, f.Key)
    70  }
    71  
    72  // ClientCert creates an ephemeral client certificate signed with the Cloud SQL
    73  // instance's private key. The return value is PEM encoded.
    74  func (f FakeCSQLInstance) ClientCert(pubKey *rsa.PublicKey) ([]byte, error) {
    75  	return f.clientSigner(f.Cert, f.Key, pubKey)
    76  }
    77  
    78  // FakeCSQLInstanceOption is a function that configures a FakeCSQLInstance.
    79  type FakeCSQLInstanceOption func(f *FakeCSQLInstance)
    80  
    81  // WithPublicIP sets the public IP address to addr.
    82  func WithPublicIP(addr string) FakeCSQLInstanceOption {
    83  	return func(f *FakeCSQLInstance) {
    84  		f.ipAddrs["PUBLIC"] = addr
    85  	}
    86  }
    87  
    88  // WithPrivateIP sets the private IP address to addr.
    89  func WithPrivateIP(addr string) FakeCSQLInstanceOption {
    90  	return func(f *FakeCSQLInstance) {
    91  		f.ipAddrs["PRIVATE"] = addr
    92  	}
    93  }
    94  
    95  // WithPSC sets the PSC DnsName to addr.
    96  func WithPSC(dns string) FakeCSQLInstanceOption {
    97  	return func(f *FakeCSQLInstance) {
    98  		f.DNSName = dns
    99  	}
   100  }
   101  
   102  // WithCertExpiry sets the server certificate's expiration to t.
   103  func WithCertExpiry(t time.Time) FakeCSQLInstanceOption {
   104  	return func(f *FakeCSQLInstance) {
   105  		f.Cert.NotAfter = t
   106  	}
   107  }
   108  
   109  // WithRegion sets the server's region to the provided value.
   110  func WithRegion(region string) FakeCSQLInstanceOption {
   111  	return func(f *FakeCSQLInstance) {
   112  		f.region = region
   113  	}
   114  }
   115  
   116  // WithFirstGenBackend sets the server backend type to FIRST_GEN.
   117  func WithFirstGenBackend() FakeCSQLInstanceOption {
   118  	return func(f *FakeCSQLInstance) {
   119  		f.backendType = "FIRST_GEN"
   120  	}
   121  }
   122  
   123  // WithEngineVersion sets the "DB Version"
   124  func WithEngineVersion(s string) FakeCSQLInstanceOption {
   125  	return func(f *FakeCSQLInstance) {
   126  		f.dbVersion = s
   127  	}
   128  }
   129  
   130  // SignFunc is a function that signs the certificate using the provided key. The
   131  // result should be PEM-encoded.
   132  type SignFunc = func(*x509.Certificate, *rsa.PrivateKey) ([]byte, error)
   133  
   134  // WithCertSigner configures the signing function used to generate a signed
   135  // certificate.
   136  func WithCertSigner(s SignFunc) FakeCSQLInstanceOption {
   137  	return func(f *FakeCSQLInstance) {
   138  		f.signer = s
   139  	}
   140  }
   141  
   142  // ClientSignFunc is a function that produces a certificate signed using the
   143  // provided certificate, using the server's private key and the client's public
   144  // key. The result should be PEM-encoded.
   145  type ClientSignFunc = func(*x509.Certificate, *rsa.PrivateKey, *rsa.PublicKey) ([]byte, error)
   146  
   147  // WithClientCertSigner configures the signing function used to generate a
   148  // certificate signed with the client's public key.
   149  func WithClientCertSigner(s ClientSignFunc) FakeCSQLInstanceOption {
   150  	return func(f *FakeCSQLInstance) {
   151  		f.clientSigner = s
   152  	}
   153  }
   154  
   155  // WithNoIPAddrs configures a Fake Cloud SQL instance to have no IP
   156  // addresses.
   157  func WithNoIPAddrs() FakeCSQLInstanceOption {
   158  	return func(f *FakeCSQLInstance) {
   159  		f.ipAddrs = map[string]string{}
   160  	}
   161  }
   162  
   163  // NewFakeCSQLInstance returns a CloudSQLInst object for configuring mocks.
   164  func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance {
   165  	// TODO: consider options for this?
   166  	key, cert, err := generateCerts(project, name)
   167  	if err != nil {
   168  		panic(err)
   169  	}
   170  
   171  	f := FakeCSQLInstance{
   172  		project:      project,
   173  		region:       region,
   174  		name:         name,
   175  		ipAddrs:      map[string]string{"PUBLIC": "0.0.0.0"},
   176  		DNSName:      "",
   177  		dbVersion:    "POSTGRES_12", // default of no particular importance
   178  		backendType:  "SECOND_GEN",
   179  		signer:       SelfSign,
   180  		clientSigner: SignWithClientKey,
   181  		Key:          key,
   182  		Cert:         cert,
   183  	}
   184  	for _, o := range opts {
   185  		o(&f)
   186  	}
   187  	return f
   188  }
   189  
   190  // SelfSign produces a PEM encoded certificate that is self-signed.
   191  func SelfSign(c *x509.Certificate, k *rsa.PrivateKey) ([]byte, error) {
   192  	certBytes, err := x509.CreateCertificate(rand.Reader, c, c, &k.PublicKey, k)
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  	certPEM := new(bytes.Buffer)
   197  	err = pem.Encode(certPEM, &pem.Block{
   198  		Type:  "CERTIFICATE",
   199  		Bytes: certBytes,
   200  	})
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	return certPEM.Bytes(), nil
   205  }
   206  
   207  // SignWithClientKey produces a PEM encoded certificate signed by the parent
   208  // certificate c using the server's private key and the client's public key.
   209  func SignWithClientKey(c *x509.Certificate, k *rsa.PrivateKey, clientKey *rsa.PublicKey) ([]byte, error) {
   210  	// Create a signed cert from the client's public key.
   211  	cert := &x509.Certificate{ // TODO: Validate this format vs API
   212  		SerialNumber: &big.Int{},
   213  		Subject: pkix.Name{
   214  			Country:      []string{"US"},
   215  			Organization: []string{"Google, Inc"},
   216  			CommonName:   "Google Cloud SQL Client",
   217  		},
   218  		NotBefore:             time.Now(),
   219  		NotAfter:              c.NotAfter,
   220  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   221  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   222  		BasicConstraintsValid: true,
   223  	}
   224  	certBytes, err := x509.CreateCertificate(rand.Reader, cert, c, clientKey, k)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	certPEM := new(bytes.Buffer)
   229  	err = pem.Encode(certPEM, &pem.Block{
   230  		Type:  "CERTIFICATE",
   231  		Bytes: certBytes,
   232  	})
   233  	if err != nil {
   234  		return nil, err
   235  	}
   236  	return certPEM.Bytes(), nil
   237  }
   238  
   239  // GenerateCertWithCommonName produces a certificate signed by the Fake Cloud
   240  // SQL instance's CA with the specified common name cn.
   241  func GenerateCertWithCommonName(i FakeCSQLInstance, cn string) []byte {
   242  	cert := &x509.Certificate{
   243  		SerialNumber: &big.Int{},
   244  		Subject: pkix.Name{
   245  			CommonName: cn,
   246  		},
   247  		NotBefore: time.Now(),
   248  		NotAfter:  time.Now().AddDate(0, 0, 1),
   249  		IsCA:      true,
   250  	}
   251  	signed, err := x509.CreateCertificate(
   252  		rand.Reader, cert, i.Cert, &i.Key.PublicKey, i.Key)
   253  	if err != nil {
   254  		panic(err)
   255  	}
   256  	return signed
   257  }
   258  
   259  // generateCerts generates a private key, an X.509 certificate, and a TLS
   260  // certificate for a particular fake Cloud SQL database instance.
   261  func generateCerts(project, name string) (*rsa.PrivateKey, *x509.Certificate, error) {
   262  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   263  	if err != nil {
   264  		return nil, nil, err
   265  	}
   266  
   267  	cert := &x509.Certificate{
   268  		SerialNumber: &big.Int{},
   269  		Subject: pkix.Name{
   270  			CommonName: fmt.Sprintf("%s:%s", project, name),
   271  		},
   272  		NotBefore:             time.Now(),
   273  		NotAfter:              time.Now().AddDate(0, 0, 1),
   274  		IsCA:                  true,
   275  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   276  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   277  		BasicConstraintsValid: true,
   278  	}
   279  
   280  	return key, cert, nil
   281  }
   282  
   283  // StartServerProxy starts a fake server proxy and listens on the provided port
   284  // on all interfaces, configured with TLS as specified by the FakeCSQLInstance.
   285  // Callers should invoke the returned function to clean up all resources.
   286  func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
   287  	certBytes, err := x509.CreateCertificate(
   288  		rand.Reader, i.Cert, i.Cert, &i.Key.PublicKey, i.Key)
   289  	if err != nil {
   290  		t.Fatalf("failed to create certificate: %v", err)
   291  	}
   292  
   293  	caPEM := &bytes.Buffer{}
   294  	err = pem.Encode(caPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
   295  	if err != nil {
   296  		t.Fatalf("pem.Encode: %v", err)
   297  	}
   298  
   299  	caKeyPEM := &bytes.Buffer{}
   300  	err = pem.Encode(caKeyPEM, &pem.Block{
   301  		Type:  "RSA PRIVATE KEY",
   302  		Bytes: x509.MarshalPKCS1PrivateKey(i.Key),
   303  	})
   304  	if err != nil {
   305  		t.Fatalf("pem.Encode: %v", err)
   306  	}
   307  
   308  	serverCert, err := tls.X509KeyPair(caPEM.Bytes(), caKeyPEM.Bytes())
   309  	if err != nil {
   310  		t.Fatalf("failed to create X.509 Key Pair: %v", err)
   311  	}
   312  	ln, err := tls.Listen("tcp", ":3307", &tls.Config{
   313  		Certificates: []tls.Certificate{serverCert},
   314  	})
   315  	if err != nil {
   316  		t.Fatalf("failed to start listener: %v", err)
   317  	}
   318  	ctx, cancel := context.WithCancel(context.Background())
   319  	go func() {
   320  		for {
   321  			select {
   322  			case <-ctx.Done():
   323  				return
   324  			default:
   325  				conn, err := ln.Accept()
   326  				if err != nil {
   327  					return
   328  				}
   329  				_, _ = conn.Write([]byte(i.name))
   330  				_ = conn.Close()
   331  			}
   332  		}
   333  	}()
   334  	return func() {
   335  		cancel()
   336  		_ = ln.Close()
   337  	}
   338  }
   339  

View as plain text