1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package mock
16
17 import (
18 "bytes"
19 "context"
20 "crypto/rand"
21 "crypto/rsa"
22 "crypto/tls"
23 "crypto/x509"
24 "crypto/x509/pkix"
25 "encoding/pem"
26 "fmt"
27 "math/big"
28 "testing"
29 "time"
30
31 "golang.org/x/oauth2"
32 )
33
34
35 type EmptyTokenSource struct{}
36
37
38 func (EmptyTokenSource) Token() (*oauth2.Token, error) {
39 return &oauth2.Token{Expiry: time.Now().Add(time.Hour)}, nil
40 }
41
42
43
44
45 type FakeCSQLInstance struct {
46 project string
47 region string
48 name string
49 dbVersion string
50
51 ipAddrs map[string]string
52 backendType string
53 DNSName string
54 signer SignFunc
55 clientSigner ClientSignFunc
56
57 Key *rsa.PrivateKey
58
59 Cert *x509.Certificate
60 }
61
62
63
64 func (f FakeCSQLInstance) String() string {
65 return fmt.Sprintf("%v:%v:%v", f.project, f.region, f.name)
66 }
67
68 func (f FakeCSQLInstance) signedCert() ([]byte, error) {
69 return f.signer(f.Cert, f.Key)
70 }
71
72
73
74 func (f FakeCSQLInstance) ClientCert(pubKey *rsa.PublicKey) ([]byte, error) {
75 return f.clientSigner(f.Cert, f.Key, pubKey)
76 }
77
78
79 type FakeCSQLInstanceOption func(f *FakeCSQLInstance)
80
81
82 func WithPublicIP(addr string) FakeCSQLInstanceOption {
83 return func(f *FakeCSQLInstance) {
84 f.ipAddrs["PUBLIC"] = addr
85 }
86 }
87
88
89 func WithPrivateIP(addr string) FakeCSQLInstanceOption {
90 return func(f *FakeCSQLInstance) {
91 f.ipAddrs["PRIVATE"] = addr
92 }
93 }
94
95
96 func WithPSC(dns string) FakeCSQLInstanceOption {
97 return func(f *FakeCSQLInstance) {
98 f.DNSName = dns
99 }
100 }
101
102
103 func WithCertExpiry(t time.Time) FakeCSQLInstanceOption {
104 return func(f *FakeCSQLInstance) {
105 f.Cert.NotAfter = t
106 }
107 }
108
109
110 func WithRegion(region string) FakeCSQLInstanceOption {
111 return func(f *FakeCSQLInstance) {
112 f.region = region
113 }
114 }
115
116
117 func WithFirstGenBackend() FakeCSQLInstanceOption {
118 return func(f *FakeCSQLInstance) {
119 f.backendType = "FIRST_GEN"
120 }
121 }
122
123
124 func WithEngineVersion(s string) FakeCSQLInstanceOption {
125 return func(f *FakeCSQLInstance) {
126 f.dbVersion = s
127 }
128 }
129
130
131
// SignFunc turns a certificate template and its private key into encoded
// certificate bytes.
type SignFunc = func(cert *x509.Certificate, key *rsa.PrivateKey) ([]byte, error)
133
134
135
136 func WithCertSigner(s SignFunc) FakeCSQLInstanceOption {
137 return func(f *FakeCSQLInstance) {
138 f.signer = s
139 }
140 }
141
142
143
144
// ClientSignFunc issues a certificate for clientKey, signed with the issuer
// certificate and private key, and returns the encoded certificate bytes.
type ClientSignFunc = func(issuer *x509.Certificate, issuerKey *rsa.PrivateKey, clientKey *rsa.PublicKey) ([]byte, error)
146
147
148
149 func WithClientCertSigner(s ClientSignFunc) FakeCSQLInstanceOption {
150 return func(f *FakeCSQLInstance) {
151 f.clientSigner = s
152 }
153 }
154
155
156
157 func WithNoIPAddrs() FakeCSQLInstanceOption {
158 return func(f *FakeCSQLInstance) {
159 f.ipAddrs = map[string]string{}
160 }
161 }
162
163
164 func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance {
165
166 key, cert, err := generateCerts(project, name)
167 if err != nil {
168 panic(err)
169 }
170
171 f := FakeCSQLInstance{
172 project: project,
173 region: region,
174 name: name,
175 ipAddrs: map[string]string{"PUBLIC": "0.0.0.0"},
176 DNSName: "",
177 dbVersion: "POSTGRES_12",
178 backendType: "SECOND_GEN",
179 signer: SelfSign,
180 clientSigner: SignWithClientKey,
181 Key: key,
182 Cert: cert,
183 }
184 for _, o := range opts {
185 o(&f)
186 }
187 return f
188 }
189
190
// SelfSign creates a certificate from template c that is signed by k and
// carries k's public key, using c as both subject and issuer. It returns the
// result as a single PEM "CERTIFICATE" block.
func SelfSign(c *x509.Certificate, k *rsa.PrivateKey) ([]byte, error) {
	der, err := x509.CreateCertificate(rand.Reader, c, c, &k.PublicKey, k)
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer
	if err := pem.Encode(&out, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
206
207
208
// SignWithClientKey issues a "Google Cloud SQL Client" certificate for
// clientKey, signed by k with c as the issuer. The certificate is valid from
// now until c.NotAfter and is returned as a single PEM "CERTIFICATE" block.
func SignWithClientKey(c *x509.Certificate, k *rsa.PrivateKey, clientKey *rsa.PublicKey) ([]byte, error) {
	subject := pkix.Name{
		Country:      []string{"US"},
		Organization: []string{"Google, Inc"},
		CommonName:   "Google Cloud SQL Client",
	}
	tmpl := x509.Certificate{
		SerialNumber:          new(big.Int),
		Subject:               subject,
		NotBefore:             time.Now(),
		NotAfter:              c.NotAfter,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}

	der, err := x509.CreateCertificate(rand.Reader, &tmpl, c, clientKey, k)
	if err != nil {
		return nil, err
	}

	var buf bytes.Buffer
	if err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
238
239
240
241 func GenerateCertWithCommonName(i FakeCSQLInstance, cn string) []byte {
242 cert := &x509.Certificate{
243 SerialNumber: &big.Int{},
244 Subject: pkix.Name{
245 CommonName: cn,
246 },
247 NotBefore: time.Now(),
248 NotAfter: time.Now().AddDate(0, 0, 1),
249 IsCA: true,
250 }
251 signed, err := x509.CreateCertificate(
252 rand.Reader, cert, i.Cert, &i.Key.PublicKey, i.Key)
253 if err != nil {
254 panic(err)
255 }
256 return signed
257 }
258
259
260
// generateCerts creates a 2048-bit RSA key and an (unsigned) CA certificate
// template for the instance. The template has subject common name
// "project:name", a zero serial number, and a validity window running from now
// to one day later. It is usable for both client and server authentication.
// Only the key generation can fail.
func generateCerts(project, name string) (*rsa.PrivateKey, *x509.Certificate, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	usages := []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}
	tmpl := &x509.Certificate{
		SerialNumber:          new(big.Int),
		Subject:               pkix.Name{CommonName: fmt.Sprintf("%s:%s", project, name)},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(0, 0, 1),
		IsCA:                  true,
		ExtKeyUsage:           usages,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	return priv, tmpl, nil
}
282
283
284
285
286 func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
287 certBytes, err := x509.CreateCertificate(
288 rand.Reader, i.Cert, i.Cert, &i.Key.PublicKey, i.Key)
289 if err != nil {
290 t.Fatalf("failed to create certificate: %v", err)
291 }
292
293 caPEM := &bytes.Buffer{}
294 err = pem.Encode(caPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
295 if err != nil {
296 t.Fatalf("pem.Encode: %v", err)
297 }
298
299 caKeyPEM := &bytes.Buffer{}
300 err = pem.Encode(caKeyPEM, &pem.Block{
301 Type: "RSA PRIVATE KEY",
302 Bytes: x509.MarshalPKCS1PrivateKey(i.Key),
303 })
304 if err != nil {
305 t.Fatalf("pem.Encode: %v", err)
306 }
307
308 serverCert, err := tls.X509KeyPair(caPEM.Bytes(), caKeyPEM.Bytes())
309 if err != nil {
310 t.Fatalf("failed to create X.509 Key Pair: %v", err)
311 }
312 ln, err := tls.Listen("tcp", ":3307", &tls.Config{
313 Certificates: []tls.Certificate{serverCert},
314 })
315 if err != nil {
316 t.Fatalf("failed to start listener: %v", err)
317 }
318 ctx, cancel := context.WithCancel(context.Background())
319 go func() {
320 for {
321 select {
322 case <-ctx.Done():
323 return
324 default:
325 conn, err := ln.Accept()
326 if err != nil {
327 return
328 }
329 _, _ = conn.Write([]byte(i.name))
330 _ = conn.Close()
331 }
332 }
333 }()
334 return func() {
335 cancel()
336 _ = ln.Close()
337 }
338 }
339
View as plain text