1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cloudsql
16
17 import (
18 "context"
19 "crypto/rsa"
20 "crypto/tls"
21 "crypto/x509"
22 "encoding/pem"
23 "fmt"
24 "strings"
25 "time"
26
27 "cloud.google.com/go/cloudsqlconn/debug"
28 "cloud.google.com/go/cloudsqlconn/errtype"
29 "cloud.google.com/go/cloudsqlconn/instance"
30 "cloud.google.com/go/cloudsqlconn/internal/trace"
31 "golang.org/x/oauth2"
32 sqladmin "google.golang.org/api/sqladmin/v1beta4"
33 )
34
// Keys of the IP address map built by fetchMetadata; each identifies which
// kind of address an entry holds.
const (
	// PublicIP is the key under which the instance's "PRIMARY" (public)
	// address is stored.
	PublicIP = "PUBLIC"
	// PrivateIP is the key under which the instance's "PRIVATE" address is
	// stored.
	PrivateIP = "PRIVATE"
	// PSC is the key under which the instance's DNS name is stored when the
	// API reports one (presumably Private Service Connect).
	PSC = "PSC"
)

// AutoIP is not referenced in this part of the file; presumably it asks the
// dialer to choose an available address type automatically — confirm at the
// call sites.
const AutoIP = "AutoIP"
46
47
48
// metadata is the subset of an instance's Admin API connect settings that
// is needed to open a connection.
type metadata struct {
	// ipAddrs maps an address-type key (PublicIP, PrivateIP or PSC) to the
	// address to dial; for PSC the value is the instance's DNS name.
	ipAddrs map[string]string
	// serverCaCert is the instance's parsed server CA certificate
	// (presumably used to verify the server during the TLS handshake —
	// verification happens outside this file section).
	serverCaCert *x509.Certificate
	// version is the database version string reported by the API, e.g. one
	// starting with "POSTGRES" or "MYSQL".
	version string
}
54
55
56
57
58 func fetchMetadata(
59 ctx context.Context, client *sqladmin.Service, inst instance.ConnName,
60 ) (m metadata, err error) {
61
62 var end trace.EndSpanFunc
63 ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchMetadata")
64 defer func() { end(err) }()
65
66 db, err := retry50x(ctx, func(ctx2 context.Context) (*sqladmin.ConnectSettings, error) {
67 return client.Connect.Get(
68 inst.Project(), inst.Name(),
69 ).Context(ctx2).Do()
70 }, exponentialBackoff)
71 if err != nil {
72 return metadata{}, errtype.NewRefreshError("failed to get instance metadata", inst.String(), err)
73 }
74
75 if db.Region != inst.Region() {
76 msg := fmt.Sprintf(
77 "provided region was mismatched - got %s, want %s",
78 inst.Region(), db.Region,
79 )
80 return metadata{}, errtype.NewConfigError(msg, inst.String())
81 }
82 if db.BackendType != "SECOND_GEN" {
83 return metadata{}, errtype.NewConfigError(
84 "unsupported instance - only Second Generation instances are supported",
85 inst.String(),
86 )
87 }
88
89
90 ipAddrs := make(map[string]string)
91 for _, ip := range db.IpAddresses {
92 switch ip.Type {
93 case "PRIMARY":
94 ipAddrs[PublicIP] = ip.IpAddress
95 case "PRIVATE":
96 ipAddrs[PrivateIP] = ip.IpAddress
97 }
98 }
99
100
101 if db.DnsName != "" {
102 ipAddrs[PSC] = db.DnsName
103 }
104
105 if len(ipAddrs) == 0 {
106 return metadata{}, errtype.NewConfigError(
107 "cannot connect to instance - it has no supported IP addresses",
108 inst.String(),
109 )
110 }
111
112
113 b, _ := pem.Decode([]byte(db.ServerCaCert.Cert))
114 if b == nil {
115 return metadata{}, errtype.NewRefreshError("failed to decode valid PEM cert", inst.String(), nil)
116 }
117 cert, err := x509.ParseCertificate(b.Bytes)
118 if err != nil {
119 return metadata{}, errtype.NewRefreshError(
120 fmt.Sprintf("failed to parse as X.509 certificate: %v", err),
121 inst.String(),
122 nil,
123 )
124 }
125
126 m = metadata{
127 ipAddrs: ipAddrs,
128 serverCaCert: cert,
129 version: db.DatabaseVersion,
130 }
131
132 return m, nil
133 }
134
135 func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) {
136 expiredToken := &oauth2.Token{
137 AccessToken: tok.AccessToken,
138 TokenType: tok.TokenType,
139 RefreshToken: tok.RefreshToken,
140 Expiry: time.Time{}.Add(1),
141 }
142 return oauth2.ReuseTokenSource(expiredToken, ts).Token()
143 }
144
145
146
147
148 func fetchEphemeralCert(
149 ctx context.Context,
150 client *sqladmin.Service,
151 inst instance.ConnName,
152 key *rsa.PrivateKey,
153 ts oauth2.TokenSource,
154 ) (c tls.Certificate, err error) {
155 var end trace.EndSpanFunc
156 ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchEphemeralCert")
157 defer func() { end(err) }()
158 clientPubKey, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
159 if err != nil {
160 return tls.Certificate{}, err
161 }
162
163 req := sqladmin.GenerateEphemeralCertRequest{
164 PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: clientPubKey, Type: "RSA PUBLIC KEY"})),
165 }
166 var tok *oauth2.Token
167 if ts != nil {
168 var tokErr error
169 tok, tokErr = ts.Token()
170 if tokErr != nil {
171 return tls.Certificate{}, errtype.NewRefreshError(
172 "failed to retrieve Oauth2 token",
173 inst.String(),
174 tokErr,
175 )
176 }
177
178
179 tok, tokErr = refreshToken(ts, tok)
180 if tokErr != nil {
181 return tls.Certificate{}, errtype.NewRefreshError(
182 "failed to refresh Oauth2 token",
183 inst.String(),
184 tokErr,
185 )
186 }
187 req.AccessToken = tok.AccessToken
188 }
189 resp, err := retry50x(ctx, func(ctx2 context.Context) (*sqladmin.GenerateEphemeralCertResponse, error) {
190 return client.Connect.GenerateEphemeralCert(
191 inst.Project(), inst.Name(), &req,
192 ).Context(ctx2).Do()
193 }, exponentialBackoff)
194 if err != nil {
195 return tls.Certificate{}, errtype.NewRefreshError(
196 "create ephemeral cert failed",
197 inst.String(),
198 err,
199 )
200 }
201
202
203 b, _ := pem.Decode([]byte(resp.EphemeralCert.Cert))
204 if b == nil {
205 return tls.Certificate{}, errtype.NewRefreshError(
206 "failed to decode valid PEM cert",
207 inst.String(),
208 nil,
209 )
210 }
211 clientCert, err := x509.ParseCertificate(b.Bytes)
212 if err != nil {
213 return tls.Certificate{}, errtype.NewRefreshError(
214 fmt.Sprintf("failed to parse as X.509 certificate: %v", err),
215 inst.String(),
216 nil,
217 )
218 }
219 if ts != nil {
220
221
222 if tok.Expiry.Before(clientCert.NotAfter) {
223 clientCert.NotAfter = tok.Expiry
224 }
225 }
226
227 c = tls.Certificate{
228 Certificate: [][]byte{clientCert.Raw},
229 PrivateKey: key,
230 Leaf: clientCert,
231 }
232 return c, nil
233 }
234
235
236 func newRefresher(
237 l debug.ContextLogger,
238 svc *sqladmin.Service,
239 ts oauth2.TokenSource,
240 dialerID string,
241 ) refresher {
242 return refresher{
243 dialerID: dialerID,
244 logger: l,
245 client: svc,
246 ts: ts,
247 }
248 }
249
250
251
252 type refresher struct {
253
254 dialerID string
255 logger debug.ContextLogger
256 client *sqladmin.Service
257
258 ts oauth2.TokenSource
259 }
260
261
262
263 func (r refresher) ConnectionInfo(
264 ctx context.Context, cn instance.ConnName, k *rsa.PrivateKey, iamAuthNDial bool,
265 ) (ci ConnectionInfo, err error) {
266
267 var refreshEnd trace.EndSpanFunc
268 ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.RefreshConnection",
269 trace.AddInstanceName(cn.String()),
270 )
271 defer func() {
272 go trace.RecordRefreshResult(context.Background(), cn.String(), r.dialerID, err)
273 refreshEnd(err)
274 }()
275
276
277 type mdRes struct {
278 md metadata
279 err error
280 }
281 mdC := make(chan mdRes, 1)
282 go func() {
283 defer close(mdC)
284 md, err := fetchMetadata(ctx, r.client, cn)
285 mdC <- mdRes{md, err}
286 }()
287
288
289 type ecRes struct {
290 ec tls.Certificate
291 err error
292 }
293 ecC := make(chan ecRes, 1)
294 go func() {
295 defer close(ecC)
296 var iamTS oauth2.TokenSource
297 if iamAuthNDial {
298 iamTS = r.ts
299 }
300 ec, err := fetchEphemeralCert(ctx, r.client, cn, k, iamTS)
301 ecC <- ecRes{ec, err}
302 }()
303
304
305 var md metadata
306 select {
307 case r := <-mdC:
308 if r.err != nil {
309 return ConnectionInfo{}, fmt.Errorf("failed to get instance: %w", r.err)
310 }
311 md = r.md
312 case <-ctx.Done():
313 return ci, fmt.Errorf("refresh failed: %w", ctx.Err())
314 }
315 if iamAuthNDial {
316 if vErr := supportsAutoIAMAuthN(md.version); vErr != nil {
317 return ConnectionInfo{}, vErr
318 }
319 }
320
321 var ec tls.Certificate
322 select {
323 case r := <-ecC:
324 if r.err != nil {
325 return ConnectionInfo{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err)
326 }
327 ec = r.ec
328 case <-ctx.Done():
329 return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err())
330 }
331
332 return ConnectionInfo{
333 addrs: md.ipAddrs,
334 ServerCaCert: md.serverCaCert,
335 ClientCertificate: ec,
336 Expiration: ec.Leaf.NotAfter,
337 DBVersion: md.version,
338 ConnectionName: cn,
339 }, nil
340 }
341
342
343
// supportsAutoIAMAuthN reports whether the database version string names an
// engine that supports automatic IAM database authentication. It returns nil
// for versions beginning with "POSTGRES" or "MYSQL" (case-sensitive) and a
// descriptive error for anything else.
func supportsAutoIAMAuthN(version string) error {
	supported := [...]string{"POSTGRES", "MYSQL"}
	for _, engine := range supported {
		if strings.HasPrefix(version, engine) {
			return nil
		}
	}
	return fmt.Errorf("%s does not support Auto IAM DB Authentication", version)
}
354
View as plain text