1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package mock
16
17 import (
18 "context"
19 "crypto/rsa"
20 "crypto/x509"
21 "encoding/json"
22 "encoding/pem"
23 "fmt"
24 "io"
25 "net/http"
26 "net/http/httptest"
27 "sync"
28 "time"
29
30 "google.golang.org/api/option"
31 sqladmin "google.golang.org/api/sqladmin/v1beta4"
32 )
33
34
35
36
37
38
39 func httpClient(requests ...*Request) (*http.Client, string, func() error) {
40
41 s := httptest.NewTLSServer(http.HandlerFunc(
42 func(resp http.ResponseWriter, req *http.Request) {
43 for _, r := range requests {
44 if r.matches(req) {
45 r.handle(resp, req)
46 return
47 }
48 }
49
50 resp.WriteHeader(http.StatusNotImplemented)
51
52 resp.Write([]byte(fmt.Sprintf("unexpected request sent to mock client: %v", req)))
53 },
54 ))
55
56 cleanup := func() error {
57 s.Close()
58 for i, e := range requests {
59 if e.reqCt > 0 {
60 return fmt.Errorf("%d calls left for specified call in pos %d: %v", e.reqCt, i, e)
61 }
62 }
63 return nil
64 }
65
66 return s.Client(), s.URL, cleanup
67
68 }
69
70
71
72
// Request describes one expected call to the mock server and how to answer it.
//
// A Request accepts an incoming HTTP request when the method and path agree
// with reqMethod and reqPath (an empty value accepts anything) and reqCt is
// still positive. The embedded mutex guards reqCt, which drops by one every
// time the Request is matched.
type Request struct {
	sync.Mutex

	// Matching criteria; empty means "any".
	reqMethod, reqPath string
	// Number of calls this Request will still accept.
	reqCt int

	// handle writes the response for a request this Request accepted.
	handle func(resp http.ResponseWriter, req *http.Request)
}
82
83
84 func (r *Request) matches(hR *http.Request) bool {
85 r.Lock()
86 defer r.Unlock()
87 if r.reqMethod != "" && r.reqMethod != hR.Method {
88 return false
89 }
90 if r.reqPath != "" && r.reqPath != hR.URL.Path {
91 return false
92 }
93 if r.reqCt <= 0 {
94 return false
95 }
96 r.reqCt--
97 return true
98 }
99
100
101
102
103
104 func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
105 var ips []*sqladmin.IpMapping
106 for ipType, addr := range i.ipAddrs {
107 if ipType == "PUBLIC" {
108 ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
109 continue
110 }
111 if ipType == "PRIVATE" {
112 ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
113 }
114 }
115 certBytes, err := i.signedCert()
116 if err != nil {
117 panic(err)
118 }
119 db := &sqladmin.ConnectSettings{
120 BackendType: i.backendType,
121 DatabaseVersion: i.dbVersion,
122 DnsName: i.DNSName,
123 IpAddresses: ips,
124 Region: i.region,
125 ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
126 }
127
128 r := &Request{
129 reqMethod: http.MethodGet,
130 reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
131 reqCt: ct,
132 handle: func(resp http.ResponseWriter, _ *http.Request) {
133 b, err := db.MarshalJSON()
134 if err != nil {
135 http.Error(resp, err.Error(), http.StatusInternalServerError)
136 return
137 }
138 resp.WriteHeader(http.StatusOK)
139 resp.Write(b)
140 },
141 }
142 return r
143 }
144
145
146 func InstanceGet500(i FakeCSQLInstance, count int) *Request {
147 return &Request{
148 reqMethod: http.MethodGet,
149 reqPath: fmt.Sprintf(
150 "/sql/v1beta4/projects/%s/instances/%s/connectSettings",
151 i.project, i.name,
152 ),
153 reqCt: count,
154 handle: func(resp http.ResponseWriter, _ *http.Request) {
155 http.Error(resp, "server error", http.StatusInternalServerError)
156 },
157 }
158 }
159
160
161 func CreateEphemeral500(i FakeCSQLInstance, count int) *Request {
162 return &Request{
163 reqMethod: http.MethodPost,
164 reqPath: fmt.Sprintf(
165 "/sql/v1beta4/projects/%s/instances/%s:generateEphemeralCert",
166 i.project, i.name,
167 ),
168 reqCt: count,
169 handle: func(resp http.ResponseWriter, _ *http.Request) {
170 http.Error(resp, "server error", http.StatusInternalServerError)
171 },
172 }
173 }
174
175
176
177
178
179
180 func CreateEphemeralSuccess(i FakeCSQLInstance, ct int) *Request {
181 r := &Request{
182 reqMethod: http.MethodPost,
183 reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s:generateEphemeralCert", i.project, i.name),
184 reqCt: ct,
185 handle: func(resp http.ResponseWriter, req *http.Request) {
186
187 b, err := io.ReadAll(req.Body)
188 defer req.Body.Close()
189 if err != nil {
190 http.Error(resp, fmt.Errorf("unable to read body: %w", err).Error(), http.StatusBadRequest)
191 return
192 }
193 var eR sqladmin.GenerateEphemeralCertRequest
194 err = json.Unmarshal(b, &eR)
195 if err != nil {
196 http.Error(resp, fmt.Errorf("invalid or unexpected json: %w", err).Error(), http.StatusBadRequest)
197 return
198 }
199
200 bl, _ := pem.Decode([]byte(eR.PublicKey))
201 if bl == nil {
202 http.Error(resp, fmt.Errorf("unable to decode PublicKey: %w", err).Error(), http.StatusBadRequest)
203 return
204 }
205 pubKey, err := x509.ParsePKIXPublicKey(bl.Bytes)
206 if err != nil {
207 http.Error(resp, fmt.Errorf("unable to decode PublicKey: %w", err).Error(), http.StatusBadRequest)
208 return
209 }
210
211 certBytes, err := i.ClientCert(pubKey.(*rsa.PublicKey))
212 if err != nil {
213 http.Error(resp, fmt.Errorf("failed to sign client certificate: %v", err).Error(), http.StatusBadRequest)
214 return
215 }
216
217
218 c := &sqladmin.SslCert{
219 Cert: string(certBytes),
220 CommonName: "Google Cloud SQL Client",
221 CreateTime: time.Now().Format(time.RFC3339),
222 ExpirationTime: i.Cert.NotAfter.Format(time.RFC3339),
223 Instance: i.name,
224 }
225 certResp := sqladmin.GenerateEphemeralCertResponse{
226 EphemeralCert: c,
227 }
228 b, err = certResp.MarshalJSON()
229 if err != nil {
230 http.Error(resp, fmt.Errorf("unable to encode response: %w", err).Error(), http.StatusInternalServerError)
231 return
232 }
233 resp.WriteHeader(http.StatusOK)
234 resp.Write(b)
235 },
236 }
237 return r
238 }
239
240
241
242
243
244 func NewSQLAdminService(ctx context.Context, reqs ...*Request) (*sqladmin.Service, func() error, error) {
245 mc, url, cleanup := httpClient(reqs...)
246 client, err := sqladmin.NewService(
247 ctx,
248 option.WithHTTPClient(mc),
249 option.WithEndpoint(url),
250 )
251 return client, cleanup, err
252 }
253
View as plain text