...

Source file src/cloud.google.com/go/cloudsqlconn/internal/mock/sqladmin.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/mock

     1  // Copyright 2021 Google LLC
     2  
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mock
    16  
    17  import (
    18  	"context"
    19  	"crypto/rsa"
    20  	"crypto/x509"
    21  	"encoding/json"
    22  	"encoding/pem"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"net/http/httptest"
    27  	"sync"
    28  	"time"
    29  
    30  	"google.golang.org/api/option"
    31  	sqladmin "google.golang.org/api/sqladmin/v1beta4"
    32  )
    33  
    34  // httpClient returns an *http.Client, URL, and cleanup function. The http.Client is
    35  // configured to connect to test SSL Server at the returned URL. This server will
    36  // respond to HTTP requests defined, or return a 5xx server error for unexpected ones.
    37  // The cleanup function will close the server, and return an error if any expected calls
    38  // weren't received.
    39  func httpClient(requests ...*Request) (*http.Client, string, func() error) {
    40  	// Create a TLS Server that responses to the requests defined
    41  	s := httptest.NewTLSServer(http.HandlerFunc(
    42  		func(resp http.ResponseWriter, req *http.Request) {
    43  			for _, r := range requests {
    44  				if r.matches(req) {
    45  					r.handle(resp, req)
    46  					return
    47  				}
    48  			}
    49  			// Unexpected requests should throw an error
    50  			resp.WriteHeader(http.StatusNotImplemented)
    51  			// TODO: follow error format better?
    52  			resp.Write([]byte(fmt.Sprintf("unexpected request sent to mock client: %v", req)))
    53  		},
    54  	))
    55  	// cleanup stops the test server and checks for uncalled requests
    56  	cleanup := func() error {
    57  		s.Close()
    58  		for i, e := range requests {
    59  			if e.reqCt > 0 {
    60  				return fmt.Errorf("%d calls left for specified call in pos %d: %v", e.reqCt, i, e)
    61  			}
    62  		}
    63  		return nil
    64  	}
    65  
    66  	return s.Client(), s.URL, cleanup
    67  
    68  }
    69  
    70  // Request represents a HTTP request for a test Server to mock responses for.
    71  //
    72  // Use NewRequest to initialize new Requests.
    73  type Request struct {
    74  	sync.Mutex
    75  
    76  	reqMethod string
    77  	reqPath   string
    78  	reqCt     int
    79  
    80  	handle func(resp http.ResponseWriter, req *http.Request)
    81  }
    82  
    83  // matches returns true if a given http.Request should be handled by this Request.
    84  func (r *Request) matches(hR *http.Request) bool {
    85  	r.Lock()
    86  	defer r.Unlock()
    87  	if r.reqMethod != "" && r.reqMethod != hR.Method {
    88  		return false
    89  	}
    90  	if r.reqPath != "" && r.reqPath != hR.URL.Path {
    91  		return false
    92  	}
    93  	if r.reqCt <= 0 {
    94  		return false
    95  	}
    96  	r.reqCt--
    97  	return true
    98  }
    99  
   100  // InstanceGetSuccess returns a Request that responds to the `instance.get` SQL Admin
   101  // endpoint. It responds with a "StatusOK" and a DatabaseInstance object.
   102  //
   103  // https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get
   104  func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
   105  	var ips []*sqladmin.IpMapping
   106  	for ipType, addr := range i.ipAddrs {
   107  		if ipType == "PUBLIC" {
   108  			ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
   109  			continue
   110  		}
   111  		if ipType == "PRIVATE" {
   112  			ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
   113  		}
   114  	}
   115  	certBytes, err := i.signedCert()
   116  	if err != nil {
   117  		panic(err)
   118  	}
   119  	db := &sqladmin.ConnectSettings{
   120  		BackendType:     i.backendType,
   121  		DatabaseVersion: i.dbVersion,
   122  		DnsName:         i.DNSName,
   123  		IpAddresses:     ips,
   124  		Region:          i.region,
   125  		ServerCaCert:    &sqladmin.SslCert{Cert: string(certBytes)},
   126  	}
   127  
   128  	r := &Request{
   129  		reqMethod: http.MethodGet,
   130  		reqPath:   fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
   131  		reqCt:     ct,
   132  		handle: func(resp http.ResponseWriter, _ *http.Request) {
   133  			b, err := db.MarshalJSON()
   134  			if err != nil {
   135  				http.Error(resp, err.Error(), http.StatusInternalServerError)
   136  				return
   137  			}
   138  			resp.WriteHeader(http.StatusOK)
   139  			resp.Write(b)
   140  		},
   141  	}
   142  	return r
   143  }
   144  
   145  // InstanceGet500 returns a 500 HTTP response
   146  func InstanceGet500(i FakeCSQLInstance, count int) *Request {
   147  	return &Request{
   148  		reqMethod: http.MethodGet,
   149  		reqPath: fmt.Sprintf(
   150  			"/sql/v1beta4/projects/%s/instances/%s/connectSettings",
   151  			i.project, i.name,
   152  		),
   153  		reqCt: count,
   154  		handle: func(resp http.ResponseWriter, _ *http.Request) {
   155  			http.Error(resp, "server error", http.StatusInternalServerError)
   156  		},
   157  	}
   158  }
   159  
   160  // CreateEphemeral500 returns a 500 HTTP response.
   161  func CreateEphemeral500(i FakeCSQLInstance, count int) *Request {
   162  	return &Request{
   163  		reqMethod: http.MethodPost,
   164  		reqPath: fmt.Sprintf(
   165  			"/sql/v1beta4/projects/%s/instances/%s:generateEphemeralCert",
   166  			i.project, i.name,
   167  		),
   168  		reqCt: count,
   169  		handle: func(resp http.ResponseWriter, _ *http.Request) {
   170  			http.Error(resp, "server error", http.StatusInternalServerError)
   171  		},
   172  	}
   173  }
   174  
   175  // CreateEphemeralSuccess returns a Request that responds to the
   176  // `connect.generateEphemeralCert` SQL Admin endpoint. It responds with a
   177  // "StatusOK" and a SslCerts object.
   178  //
   179  // https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/connect/generateEphemeralCert
   180  func CreateEphemeralSuccess(i FakeCSQLInstance, ct int) *Request {
   181  	r := &Request{
   182  		reqMethod: http.MethodPost,
   183  		reqPath:   fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s:generateEphemeralCert", i.project, i.name),
   184  		reqCt:     ct,
   185  		handle: func(resp http.ResponseWriter, req *http.Request) {
   186  			// Read the body from the request.
   187  			b, err := io.ReadAll(req.Body)
   188  			defer req.Body.Close()
   189  			if err != nil {
   190  				http.Error(resp, fmt.Errorf("unable to read body: %w", err).Error(), http.StatusBadRequest)
   191  				return
   192  			}
   193  			var eR sqladmin.GenerateEphemeralCertRequest
   194  			err = json.Unmarshal(b, &eR)
   195  			if err != nil {
   196  				http.Error(resp, fmt.Errorf("invalid or unexpected json: %w", err).Error(), http.StatusBadRequest)
   197  				return
   198  			}
   199  			// Extract the certificate from the request.
   200  			bl, _ := pem.Decode([]byte(eR.PublicKey))
   201  			if bl == nil {
   202  				http.Error(resp, fmt.Errorf("unable to decode PublicKey: %w", err).Error(), http.StatusBadRequest)
   203  				return
   204  			}
   205  			pubKey, err := x509.ParsePKIXPublicKey(bl.Bytes)
   206  			if err != nil {
   207  				http.Error(resp, fmt.Errorf("unable to decode PublicKey: %w", err).Error(), http.StatusBadRequest)
   208  				return
   209  			}
   210  
   211  			certBytes, err := i.ClientCert(pubKey.(*rsa.PublicKey))
   212  			if err != nil {
   213  				http.Error(resp, fmt.Errorf("failed to sign client certificate: %v", err).Error(), http.StatusBadRequest)
   214  				return
   215  			}
   216  
   217  			// Return the signed cert to the client.
   218  			c := &sqladmin.SslCert{
   219  				Cert:           string(certBytes),
   220  				CommonName:     "Google Cloud SQL Client",
   221  				CreateTime:     time.Now().Format(time.RFC3339),
   222  				ExpirationTime: i.Cert.NotAfter.Format(time.RFC3339),
   223  				Instance:       i.name,
   224  			}
   225  			certResp := sqladmin.GenerateEphemeralCertResponse{
   226  				EphemeralCert: c,
   227  			}
   228  			b, err = certResp.MarshalJSON()
   229  			if err != nil {
   230  				http.Error(resp, fmt.Errorf("unable to encode response: %w", err).Error(), http.StatusInternalServerError)
   231  				return
   232  			}
   233  			resp.WriteHeader(http.StatusOK)
   234  			resp.Write(b)
   235  		},
   236  	}
   237  	return r
   238  }
   239  
   240  // NewSQLAdminService creates a SQL Admin API service backed by a mock HTTP
   241  // backend. Callers should use the cleanup function to close down the server. If
   242  // the cleanup function returns an error, a caller has not exercised all the
   243  // registered requests.
   244  func NewSQLAdminService(ctx context.Context, reqs ...*Request) (*sqladmin.Service, func() error, error) {
   245  	mc, url, cleanup := httpClient(reqs...)
   246  	client, err := sqladmin.NewService(
   247  		ctx,
   248  		option.WithHTTPClient(mc),
   249  		option.WithEndpoint(url),
   250  	)
   251  	return client, cleanup, err
   252  }
   253  

View as plain text