1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package certs
17
18 import (
19 "crypto/rand"
20 "crypto/rsa"
21 "crypto/tls"
22 "crypto/x509"
23 "encoding/pem"
24 "errors"
25 "fmt"
26 "math"
27 mrand "math/rand"
28 "net/http"
29 "strings"
30 "sync"
31 "time"
32
33 "github.com/GoogleCloudPlatform/cloudsql-proxy/logging"
34 "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/util"
35 "golang.org/x/oauth2"
36 "google.golang.org/api/googleapi"
37 sqladmin "google.golang.org/api/sqladmin/v1beta4"
38 )
39
// defaultUserAgent is the User-Agent sent with Cloud SQL Admin API requests
// when the caller does not provide one.
const (
	defaultUserAgent = "custom cloud_sql_proxy version >= 1.10"
)
41
42
43
44
45
// NewCertSource returns a RemoteCertSource that calls the Cloud SQL Admin API
// at host (when non-empty) through client c, identifying itself with
// defaultUserAgent. When checkRegion is false, a missing or mismatched region
// in an instance connection string is logged instead of treated as an error.
func NewCertSource(host string, c *http.Client, checkRegion bool) *RemoteCertSource {
	opts := RemoteOpts{UserAgent: defaultUserAgent}
	opts.APIBasePath = host
	opts.IgnoreRegion = !checkRegion
	return NewCertSourceOpts(c, opts)
}
53
54
55
56 type RemoteOpts struct {
57
58
59
60 APIBasePath string
61
62
63
64
65 IgnoreRegion bool
66
67
68
69 UserAgent string
70
71
72 IPAddrTypeOpts []string
73
74
75 EnableIAMLogin bool
76
77
78 TokenSource oauth2.TokenSource
79
80
81
82
83 DelayKeyGenerate bool
84 }
85
86
87
88
89
90
// NewCertSourceOpts returns a RemoteCertSource that uses c to call the Cloud
// SQL Admin API, configured by opts:
//   - APIBasePath, when non-empty, overrides the API client's base path.
//   - UserAgent falls back to defaultUserAgent when empty.
//   - IPAddrTypeOpts falls back to {"PUBLIC", "PRIVATE"}; any "PUBLIC" entry
//     (case-insensitive) is stored as "PRIMARY". The caller's slice is never
//     modified or retained.
//   - Unless DelayKeyGenerate is set, RSA key generation is started in a
//     background goroutine; otherwise it happens on first use.
//
// It panics if the Admin API client cannot be created from c.
func NewCertSourceOpts(c *http.Client, opts RemoteOpts) *RemoteCertSource {
	serv, err := sqladmin.New(c)
	if err != nil {
		// This constructor has no error result, so a client construction
		// failure is fatal.
		panic(err)
	}
	if opts.APIBasePath != "" {
		serv.BasePath = opts.APIBasePath
	}
	ua := opts.UserAgent
	if ua == "" {
		ua = defaultUserAgent
	}
	serv.UserAgent = ua

	ipAddrTypes := opts.IPAddrTypeOpts
	if len(ipAddrTypes) == 0 {
		ipAddrTypes = []string{"PUBLIC", "PRIVATE"}
	}

	// Normalize into a fresh slice. opts is a copy, but its slice still shares
	// a backing array with the caller's, so rewriting entries in place would
	// silently change the caller's configuration. The copy also keeps later
	// caller edits from leaking into this cert source.
	normalized := make([]string, len(ipAddrTypes))
	for i, t := range ipAddrTypes {
		// A public address is matched under the type name "PRIMARY".
		if strings.ToUpper(t) == "PUBLIC" {
			t = "PRIMARY"
		}
		normalized[i] = t
	}

	certSource := &RemoteCertSource{
		serv:           serv,
		checkRegion:    !opts.IgnoreRegion,
		IPAddrTypes:    normalized,
		EnableIAMLogin: opts.EnableIAMLogin,
		TokenSource:    opts.TokenSource,
	}
	if !opts.DelayKeyGenerate {
		// Begin generating the key now. generateKey is guarded by a sync.Once,
		// so a later caller waits for and reuses this result.
		go certSource.generateKey()
	}

	return certSource
}
131
132
133
134
135
136 type RemoteCertSource struct {
137
138 keyOnce sync.Once
139
140 key *rsa.PrivateKey
141
142 serv *sqladmin.Service
143
144
145
146 checkRegion bool
147
148 IPAddrTypes []string
149
150 EnableIAMLogin bool
151
152 TokenSource oauth2.TokenSource
153 }
154
155
156
157 const (
158 baseBackoff = float64(200 * time.Millisecond)
159 backoffMult = 1.618
160 backoffRetries = 5
161 )
162
// backoffAPIRetry calls do, retrying with jittered exponential backoff while
// it fails with a 5xx *googleapi.Error, for at most backoffRetries attempts.
//
// do receives the zero time on the first attempt. On each retry it receives a
// timestamp 30 seconds in the past, which callers forward to the API as a
// stale-read time. desc and instance are used only in log and error messages.
//
// Outcomes:
//   - nil, or an error that is not a *googleapi.Error: returned immediately.
//   - 403 with reason "insufficientPermissions": wrapped with a hint to
//     enable the Cloud SQL Admin API.
//   - any other 403, or a 404: wrapped with a hint to check access and the
//     instance name.
//   - any other status below 500: returned unchanged.
//   - 5xx on every attempt: the last error is returned, without a pointless
//     sleep after the final attempt.
func backoffAPIRetry(desc, instance string, do func(staleRead time.Time) error) error {
	var (
		err       error
		staleRead time.Time
	)
	for attempt := 0; attempt < backoffRetries; attempt++ {
		err = do(staleRead)
		gErr, ok := err.(*googleapi.Error)
		switch {
		case !ok:
			// Covers success (err == nil) as well as non-API errors.
			return err
		case gErr.Code == 403 && len(gErr.Errors) > 0 && gErr.Errors[0].Reason == "insufficientPermissions":
			return fmt.Errorf("ensure that the Cloud SQL API is enabled for your project (https://console.cloud.google.com/flows/enableapi?apiid=sqladmin). Error during %s %s: %v", desc, instance, err)
		case gErr.Code == 404 || gErr.Code == 403:
			return fmt.Errorf("ensure that the account has access to %q (and make sure there's no typo in that name). Error during %s %s: %v", instance, desc, instance, err)
		case gErr.Code < 500:
			return err
		}

		// A 5xx error is treated as transient. After the final attempt no
		// retry follows, so sleeping (and logging "retrying") would only delay
		// the error and mislead the reader of the logs.
		if attempt == backoffRetries-1 {
			return err
		}

		exp := float64(attempt+1) + mrand.Float64()
		sleep := time.Duration(baseBackoff * math.Pow(backoffMult, exp))
		logging.Errorf("Error in %s %s: %v; retrying in %v", desc, instance, err, sleep)
		time.Sleep(sleep)

		// Retries accept data up to 30s stale, presumably so the backend can
		// answer without the freshest read — confirm with the API docs.
		staleRead = time.Now().UTC().Add(-30 * time.Second)
	}
	return err
}
195
// refreshToken returns the result of ts.Token(). It achieves this by handing
// oauth2.ReuseTokenSource a copy of tok whose expiry is already in the past:
// the copy is never considered valid, so the reuse source always consults ts
// instead of returning the cached token.
func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) {
	// oauth2 treats a zero Expiry as "never expires", so the smallest
	// non-zero instant (1ns after the zero time) is used to mark expiry.
	stale := new(oauth2.Token)
	stale.AccessToken = tok.AccessToken
	stale.TokenType = tok.TokenType
	stale.RefreshToken = tok.RefreshToken
	stale.Expiry = time.Time{}.Add(1)

	src := oauth2.ReuseTokenSource(stale, ts)
	return src.Token()
}
205
206
207
// Local returns the client certificate used to connect to instance. The
// source's RSA public key is sent to the Admin API's GenerateEphemeralCert
// endpoint (with retries via backoffAPIRetry), and the signed certificate is
// paired with the matching private key.
//
// When EnableIAMLogin is set, a freshly refreshed access token from
// TokenSource is attached to the request, with trailing '.' characters
// removed. The returned Leaf's NotAfter is also lowered to the token's expiry
// when that comes first. Only the parsed Leaf is adjusted; the raw DER
// certificate is unchanged. NOTE(review): presumably this lets callers
// refresh the certificate before the token expires — confirm.
//
// An error is returned if IAM login is enabled without a TokenSource, or if
// the API response carries no ephemeral certificate.
func (s *RemoteCertSource) Local(instance string) (tls.Certificate, error) {
	pkix, err := x509.MarshalPKIXPublicKey(s.generateKey().Public())
	if err != nil {
		return tls.Certificate{}, err
	}

	p, r, n := util.SplitName(instance)
	regionName := fmt.Sprintf("%s~%s", r, n)
	pubKey := string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"}))
	generateEphemeralCertRequest := &sqladmin.GenerateEphemeralCertRequest{
		PublicKey: pubKey,
	}

	var tok *oauth2.Token
	if s.EnableIAMLogin {
		if s.TokenSource == nil {
			// Without this check the Token call below panics on a nil interface.
			return tls.Certificate{}, fmt.Errorf("IAM login is enabled for %q but no token source is configured", instance)
		}
		var tokErr error
		tok, tokErr = s.TokenSource.Token()
		if tokErr != nil {
			return tls.Certificate{}, tokErr
		}
		// Force a refresh so the certificate is bound to a token with the
		// longest possible remaining lifetime.
		tok, tokErr = refreshToken(s.TokenSource, tok)
		if tokErr != nil {
			return tls.Certificate{}, tokErr
		}
		generateEphemeralCertRequest.AccessToken = strings.TrimRight(tok.AccessToken, ".")
	}
	req := s.serv.Connect.GenerateEphemeralCert(p, regionName, generateEphemeralCertRequest)

	var data *sqladmin.GenerateEphemeralCertResponse
	err = backoffAPIRetry("generateEphemeral for", instance, func(staleRead time.Time) error {
		if !staleRead.IsZero() {
			// The request body was handed to req by pointer, so setting
			// ReadTime here is expected to affect the retried call.
			generateEphemeralCertRequest.ReadTime = staleRead.Format(time.RFC3339)
		}
		var doErr error
		data, doErr = req.Do()
		return doErr
	})
	if err != nil {
		return tls.Certificate{}, err
	}
	if data == nil || data.EphemeralCert == nil {
		return tls.Certificate{}, fmt.Errorf("no ephemeral certificate returned for instance %q", instance)
	}

	c, err := parseCert(data.EphemeralCert.Cert)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("couldn't parse ephemeral certificate for instance %q: %v", instance, err)
	}

	// oauth2 uses a zero Expiry to mean "never expires". The zero time is
	// before any real NotAfter, so without the IsZero guard such a token would
	// make the certificate look expired immediately.
	if s.EnableIAMLogin && !tok.Expiry.IsZero() && tok.Expiry.Before(c.NotAfter) {
		c.NotAfter = tok.Expiry
	}

	return tls.Certificate{
		Certificate: [][]byte{c.Raw},
		PrivateKey:  s.generateKey(),
		Leaf:        c,
	}, nil
}
270
// parseCert decodes the first PEM block in pemCert and parses its contents as
// an X.509 certificate; anything after that first block is ignored. If no PEM
// block can be found, the error message includes the full input.
func parseCert(pemCert string) (*x509.Certificate, error) {
	block, _ := pem.Decode([]byte(pemCert))
	if block != nil {
		return x509.ParseCertificate(block.Bytes)
	}
	return nil, errors.New("invalid PEM: " + pemCert)
}
278
279
// generateKey returns the source's 2048-bit RSA private key, creating it on
// the first call. Because creation runs under s.keyOnce, concurrent and later
// callers all receive the same key. It panics if key generation fails.
func (s *RemoteCertSource) generateKey() *rsa.PrivateKey {
	create := func() {
		began := time.Now()
		k, genErr := rsa.GenerateKey(rand.Reader, 2048)
		if genErr != nil {
			panic(genErr)
		}
		logging.Verbosef("Generated RSA key in %v", time.Since(began))
		s.key = k
	}
	s.keyOnce.Do(create)
	return s.key
}
292
293
// findIPAddr returns the address for the first type in s.IPAddrTypes (most
// preferred first) that appears among data.IpAddresses. Types are compared
// case-insensitively. If no requested type is present, it returns an error
// listing the requested types and every address the instance reports.
func (s *RemoteCertSource) findIPAddr(data *sqladmin.ConnectSettings, instance string) (ipAddrInUse string, err error) {
	for _, want := range s.IPAddrTypes {
		for _, ip := range data.IpAddresses {
			if strings.EqualFold(ip.Type, want) {
				return ip.IpAddress, nil
			}
		}
	}

	// Collect the pieces and join once, instead of repeated string
	// concatenation.
	available := make([]string, 0, len(data.IpAddresses))
	for _, ip := range data.IpAddresses {
		available = append(available, fmt.Sprintf("(TYPE=%v, IP_ADDR=%v)", ip.Type, ip.IpAddress))
	}
	return "", fmt.Errorf("user input IP address type %v does not match the instance %v, the instance's IP addresses are %v",
		s.IPAddrTypes, instance, strings.Join(available, ""))
}
313
314
// Remote fetches instance's connection settings from the Admin API (with
// retries via backoffAPIRetry) and returns:
//   - cert: the server CA certificate.
//   - addr: the address of the first preferred IP type the instance has (see
//     findIPAddr).
//   - name: the instance name in "project:name" form (region dropped).
//   - version: the database version reported by the API.
//
// The legacy region "us-central" is treated as "us-central1". A missing or
// mismatched region in the connection string is an error when checkRegion is
// set; otherwise it is only logged. First-generation instances, instances
// without IP addresses, and responses without a server CA certificate are
// rejected. On error, all other results are zero values.
func (s *RemoteCertSource) Remote(instance string) (cert *x509.Certificate, addr, name, version string, err error) {
	p, region, n := util.SplitName(instance)
	regionName := fmt.Sprintf("%s~%s", region, n)
	req := s.serv.Connect.Get(p, regionName)

	var data *sqladmin.ConnectSettings
	err = backoffAPIRetry("get instance", instance, func(staleRead time.Time) error {
		if !staleRead.IsZero() {
			req.ReadTime(staleRead.Format(time.RFC3339))
		}
		var doErr error
		data, doErr = req.Do()
		return doErr
	})
	if err != nil {
		return nil, "", "", "", err
	}

	// Normalize the legacy region name so it compares equal to the region
	// used in connection strings.
	if data.Region == "us-central" {
		data.Region = "us-central1"
	}
	if data.Region != region {
		var regionErr error
		if region == "" {
			regionErr = fmt.Errorf("instance %v doesn't provide region", instance)
		} else {
			regionErr = fmt.Errorf(`for connection string "%s": got region %q, want %q`, instance, region, data.Region)
		}
		if s.checkRegion {
			return nil, "", "", "", regionErr
		}
		logging.Errorf("%v", regionErr)
		logging.Errorf("WARNING: specifying the correct region in an instance string will become required in a future version!")
	}

	if len(data.IpAddresses) == 0 {
		return nil, "", "", "", fmt.Errorf("no IP address found for %v", instance)
	}
	if data.BackendType == "FIRST_GEN" {
		logging.Errorf("WARNING: proxy client does not support first generation Cloud SQL instances.")
		return nil, "", "", "", fmt.Errorf("%q is a first generation instance", instance)
	}

	ipAddrInUse, err := s.findIPAddr(data, instance)
	if err != nil {
		return nil, "", "", "", err
	}

	if data.ServerCaCert == nil {
		return nil, "", "", "", fmt.Errorf("no server CA certificate returned for instance %q", instance)
	}
	c, err := parseCert(data.ServerCaCert.Cert)
	if err != nil {
		// Don't hand back a half-populated result alongside the error.
		return nil, "", "", "", fmt.Errorf("couldn't parse server CA certificate for instance %q: %v", instance, err)
	}

	return c, ipAddrInUse, p + ":" + n, data.DatabaseVersion, nil
}
368
View as plain text