1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package externalaccount
16
17 import (
18 "bytes"
19 "context"
20 "crypto/hmac"
21 "crypto/sha256"
22 "encoding/hex"
23 "encoding/json"
24 "errors"
25 "fmt"
26 "net/http"
27 "net/url"
28 "os"
29 "path"
30 "sort"
31 "strings"
32 "time"
33
34 "cloud.google.com/go/auth/internal"
35 )
36
var (
	// getenv looks up environment variables for the AWS region and
	// static-credential checks in this file.
	// NOTE(review): the indirection over os.Getenv presumably exists so tests
	// can stub the environment — confirm.
	getenv = os.Getenv
)
41
const (
	// awsAlgorithm is the SigV4 algorithm identifier. It prefixes the
	// Authorization header value and is the first line of the string to sign.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// awsRequestType is the final component of the SigV4 credential scope and
	// the last scope input to the signing-key derivation.
	awsRequestType = "aws4_request"

	// awsSecurityTokenHeader carries the session token of temporary AWS
	// credentials on the signed request (set only when a token is present).
	awsSecurityTokenHeader = "x-amz-security-token"

	// awsIMDSv2SessionTokenHeader carries the IMDSv2 session token on metadata
	// server requests.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
	// awsIMDSv2SessionTTLHeader is sent on the PUT that requests a new IMDSv2
	// session token.
	awsIMDSv2SessionTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
	// awsIMDSv2SessionTTL is the requested session token lifetime; per the
	// header name the unit is seconds.
	awsIMDSv2SessionTTL = "300"

	// awsDateHeader holds the request timestamp when the request has no
	// "date" header of its own.
	awsDateHeader = "x-amz-date"

	// defaultRegionalCredentialVerificationURL is the GetCallerIdentity
	// endpoint used when none is configured. "{region}" is replaced with the
	// resolved AWS region before the request is signed.
	defaultRegionalCredentialVerificationURL = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// Environment variables consulted for the region and for static
	// credentials before falling back to the metadata server.
	awsAccessKeyIDEnvVar     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegionEnvVar   = "AWS_DEFAULT_REGION"
	awsRegionEnvVar          = "AWS_REGION"
	awsSecretAccessKeyEnvVar = "AWS_SECRET_ACCESS_KEY"
	awsSessionTokenEnvVar    = "AWS_SESSION_TOKEN"

	// awsTimeFormatLong is the layout of the full timestamp (x-amz-date header
	// and string to sign); awsTimeFormatShort is the date stamp used in the
	// credential scope. The trailing "Z" is a literal, not a zone directive.
	// NOTE(review): this assumes the time being formatted is already in UTC —
	// confirm that the package's Now returns UTC.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
	// awsProviderType is reported by providerType when no programmatic
	// credentials provider is configured.
	awsProviderType = "aws"
)
76
// awsSubjectProvider produces STS subject tokens from AWS credentials by
// signing a GetCallerIdentity request. Credentials and region come from a
// programmatic provider, the environment, or the AWS metadata server.
type awsSubjectProvider struct {
	// EnvironmentID is not read by the code visible here.
	// NOTE(review): presumably the configured AWS environment identifier —
	// confirm against the config parser.
	EnvironmentID string
	// RegionURL is the metadata endpoint queried by getRegion when the region
	// is not available from the environment or a programmatic provider.
	RegionURL string
	// RegionalCredVerificationURL is the GetCallerIdentity URL template; a
	// "{region}" placeholder is substituted. subjectToken fills in a default
	// when it is empty.
	RegionalCredVerificationURL string
	// CredVerificationURL is the metadata server security-credentials
	// endpoint: it returns the role name, and "<url>/<role>" returns the
	// credentials.
	CredVerificationURL string
	// IMDSv2SessionTokenURL, when non-empty, is PUT to obtain an IMDSv2
	// session token attached to metadata requests.
	IMDSv2SessionTokenURL string
	// TargetResource, when non-empty, is sent (and signed) as the
	// x-goog-cloud-target-resource header.
	TargetResource string
	// requestSigner is the SigV4 signer built by subjectToken from the
	// resolved credentials and region.
	requestSigner *awsRequestSigner
	// region is the AWS region resolved by subjectToken.
	region string
	// securityCredentialsProvider, when set, supplies credentials and region
	// instead of the environment or metadata server.
	securityCredentialsProvider AwsSecurityCredentialsProvider
	// reqOpts is passed through to securityCredentialsProvider calls.
	reqOpts *RequestOptions

	// Client is the HTTP client used for all metadata server requests.
	Client *http.Client
}
91
92 func (sp *awsSubjectProvider) subjectToken(ctx context.Context) (string, error) {
93
94 if sp.RegionalCredVerificationURL == "" {
95 sp.RegionalCredVerificationURL = defaultRegionalCredentialVerificationURL
96 }
97 if sp.requestSigner == nil {
98 headers := make(map[string]string)
99 if sp.shouldUseMetadataServer() {
100 awsSessionToken, err := sp.getAWSSessionToken(ctx)
101 if err != nil {
102 return "", err
103 }
104
105 if awsSessionToken != "" {
106 headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
107 }
108 }
109
110 awsSecurityCredentials, err := sp.getSecurityCredentials(ctx, headers)
111 if err != nil {
112 return "", err
113 }
114 if sp.region, err = sp.getRegion(ctx, headers); err != nil {
115 return "", err
116 }
117 sp.requestSigner = &awsRequestSigner{
118 RegionName: sp.region,
119 AwsSecurityCredentials: awsSecurityCredentials,
120 }
121 }
122
123
124
125 req, err := http.NewRequest("POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
126 if err != nil {
127 return "", err
128 }
129
130
131
132
133 if sp.TargetResource != "" {
134 req.Header.Set("x-goog-cloud-target-resource", sp.TargetResource)
135 }
136 sp.requestSigner.signRequest(req)
137
138
152
153 awsSignedReq := awsRequest{
154 URL: req.URL.String(),
155 Method: "POST",
156 }
157 for headerKey, headerList := range req.Header {
158 for _, headerValue := range headerList {
159 awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
160 Key: headerKey,
161 Value: headerValue,
162 })
163 }
164 }
165 sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
166 headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
167 if headerCompare == 0 {
168 return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
169 }
170 return headerCompare < 0
171 })
172
173 result, err := json.Marshal(awsSignedReq)
174 if err != nil {
175 return "", err
176 }
177 return url.QueryEscape(string(result)), nil
178 }
179
180 func (sp *awsSubjectProvider) providerType() string {
181 if sp.securityCredentialsProvider != nil {
182 return programmaticProviderType
183 }
184 return awsProviderType
185 }
186
187 func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
188 if sp.IMDSv2SessionTokenURL == "" {
189 return "", nil
190 }
191 req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
192 if err != nil {
193 return "", err
194 }
195 req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)
196
197 resp, err := sp.Client.Do(req)
198 if err != nil {
199 return "", err
200 }
201 defer resp.Body.Close()
202
203 respBody, err := internal.ReadAll(resp.Body)
204 if err != nil {
205 return "", err
206 }
207 if resp.StatusCode != http.StatusOK {
208 return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", respBody)
209 }
210 return string(respBody), nil
211 }
212
213 func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
214 if sp.securityCredentialsProvider != nil {
215 return sp.securityCredentialsProvider.AwsRegion(ctx, sp.reqOpts)
216 }
217 if canRetrieveRegionFromEnvironment() {
218 if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
219 return envAwsRegion, nil
220 }
221 return getenv(awsDefaultRegionEnvVar), nil
222 }
223
224 if sp.RegionURL == "" {
225 return "", errors.New("credentials: unable to determine AWS region")
226 }
227
228 req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
229 if err != nil {
230 return "", err
231 }
232
233 for name, value := range headers {
234 req.Header.Add(name, value)
235 }
236
237 resp, err := sp.Client.Do(req)
238 if err != nil {
239 return "", err
240 }
241 defer resp.Body.Close()
242
243 respBody, err := internal.ReadAll(resp.Body)
244 if err != nil {
245 return "", err
246 }
247
248 if resp.StatusCode != http.StatusOK {
249 return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", respBody)
250 }
251
252
253
254 bodyLen := len(respBody)
255 if bodyLen == 0 {
256 return "", nil
257 }
258 return string(respBody[:bodyLen-1]), nil
259 }
260
261 func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result *AwsSecurityCredentials, err error) {
262 if sp.securityCredentialsProvider != nil {
263 return sp.securityCredentialsProvider.AwsSecurityCredentials(ctx, sp.reqOpts)
264 }
265 if canRetrieveSecurityCredentialFromEnvironment() {
266 return &AwsSecurityCredentials{
267 AccessKeyID: getenv(awsAccessKeyIDEnvVar),
268 SecretAccessKey: getenv(awsSecretAccessKeyEnvVar),
269 SessionToken: getenv(awsSessionTokenEnvVar),
270 }, nil
271 }
272
273 roleName, err := sp.getMetadataRoleName(ctx, headers)
274 if err != nil {
275 return
276 }
277 credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
278 if err != nil {
279 return
280 }
281
282 if credentials.AccessKeyID == "" {
283 return result, errors.New("credentials: missing AccessKeyId credential")
284 }
285 if credentials.SecretAccessKey == "" {
286 return result, errors.New("credentials: missing SecretAccessKey credential")
287 }
288
289 return credentials, nil
290 }
291
292 func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (*AwsSecurityCredentials, error) {
293 var result *AwsSecurityCredentials
294
295 req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
296 if err != nil {
297 return result, err
298 }
299 for name, value := range headers {
300 req.Header.Add(name, value)
301 }
302
303 resp, err := sp.Client.Do(req)
304 if err != nil {
305 return result, err
306 }
307 defer resp.Body.Close()
308
309 respBody, err := internal.ReadAll(resp.Body)
310 if err != nil {
311 return result, err
312 }
313 if resp.StatusCode != http.StatusOK {
314 return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", respBody)
315 }
316 err = json.Unmarshal(respBody, &result)
317 return result, err
318 }
319
320 func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
321 if sp.CredVerificationURL == "" {
322 return "", errors.New("credentials: unable to determine the AWS metadata server security credentials endpoint")
323 }
324 req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
325 if err != nil {
326 return "", err
327 }
328 for name, value := range headers {
329 req.Header.Add(name, value)
330 }
331
332 resp, err := sp.Client.Do(req)
333 if err != nil {
334 return "", err
335 }
336 defer resp.Body.Close()
337
338 respBody, err := internal.ReadAll(resp.Body)
339 if err != nil {
340 return "", err
341 }
342 if resp.StatusCode != http.StatusOK {
343 return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", respBody)
344 }
345 return string(respBody), nil
346 }
347
348
// awsRequestSigner signs HTTP requests with AWS Signature Version 4 for a
// fixed region and set of credentials.
type awsRequestSigner struct {
	// RegionName is the region component of the SigV4 credential scope.
	RegionName string
	// AwsSecurityCredentials supplies the access key ID, the secret key used
	// to derive the signing key, and an optional session token.
	AwsSecurityCredentials *AwsSecurityCredentials
}
353
354
355
356 func (rs *awsRequestSigner) signRequest(req *http.Request) error {
357
358 signedRequest := cloneRequest(req)
359 timestamp := Now()
360 signedRequest.Header.Set("host", requestHost(req))
361 if rs.AwsSecurityCredentials.SessionToken != "" {
362 signedRequest.Header.Set(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
363 }
364 if signedRequest.Header.Get("date") == "" {
365 signedRequest.Header.Set(awsDateHeader, timestamp.Format(awsTimeFormatLong))
366 }
367 authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
368 if err != nil {
369 return err
370 }
371 signedRequest.Header.Set("Authorization", authorizationCode)
372 req.Header = signedRequest.Header
373 return nil
374 }
375
376 func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
377 canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
378 dateStamp := timestamp.Format(awsTimeFormatShort)
379 serviceName := ""
380
381 if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
382 serviceName = splitHost[0]
383 }
384 credentialScope := strings.Join([]string{dateStamp, rs.RegionName, serviceName, awsRequestType}, "/")
385 requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
386 if err != nil {
387 return "", err
388 }
389 requestHash, err := getSha256([]byte(requestString))
390 if err != nil {
391 return "", err
392 }
393
394 stringToSign := strings.Join([]string{awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash}, "\n")
395 signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
396 for _, signingInput := range []string{
397 dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
398 } {
399 signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
400 if err != nil {
401 return "", err
402 }
403 }
404
405 return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
406 }
407
// getSha256 returns the lower-case hex encoding of the SHA-256 digest of
// input. The error result is always nil.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
415
// getHmacSha256 returns the raw (not hex-encoded) HMAC-SHA256 of input under
// key.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	_, writeErr := mac.Write(input)
	if writeErr != nil {
		return nil, writeErr
	}
	return mac.Sum(make([]byte, 0, sha256.Size)), nil
}
423
// cloneRequest returns a shallow copy of r whose Header is a deep copy.
//
// All header values are copied into one shared backing slice. Each key gets a
// full-capacity sub-slice (len == cap), so appending to one key's values
// reallocates instead of overwriting a neighbour. A nil Header stays nil.
func cloneRequest(r *http.Request) *http.Request {
	clone := *r
	if r.Header == nil {
		return &clone
	}

	total := 0
	for _, vals := range r.Header {
		total += len(vals)
	}
	backing := make([]string, total)

	clone.Header = make(http.Header, len(r.Header))
	for key, vals := range r.Header {
		n := copy(backing, vals)
		clone.Header[key] = backing[:n:n]
		backing = backing[n:]
	}
	return &clone
}
445
// canonicalPath returns the SigV4 canonical URI for req: the escaped URL path
// normalized with path.Clean, or "/" when the path is empty.
func canonicalPath(req *http.Request) string {
	if escaped := req.URL.EscapedPath(); escaped != "" {
		return path.Clean(escaped)
	}
	return "/"
}
453
// canonicalQuery returns the SigV4 canonical query string for req. Parameters
// are sorted by name, repeated values are sorted, and names and values are
// percent-encoded.
//
// url.Values.Encode escapes with url.QueryEscape, which writes a space as "+".
// SigV4 requires "%20" instead. A literal "+" in a name or value is already
// emitted as "%2B", so every "+" left in the encoded output stands for a
// space and can be rewritten safely.
func canonicalQuery(req *http.Request) string {
	queryValues := req.URL.Query()
	for queryKey := range queryValues {
		sort.Strings(queryValues[queryKey])
	}
	return strings.ReplaceAll(queryValues.Encode(), "+", "%20")
}
461
// canonicalHeaders returns two values used in the SigV4 canonical request:
//   - the SignedHeaders list: lower-cased header names, sorted and joined
//     with ";";
//   - the canonical header block: one "name:v1,v2\n" line per header, in the
//     same order.
//
// Header keys that differ only in case are merged into a single entry.
// Original keys are visited in sorted order, so the merged value order does
// not depend on Go's random map iteration. Values are appended onto slices
// owned by this function, so req.Header's backing arrays are never written.
func canonicalHeaders(req *http.Request) (string, string) {
	originalKeys := make([]string, 0, len(req.Header))
	for k := range req.Header {
		originalKeys = append(originalKeys, k)
	}
	sort.Strings(originalKeys)

	var headers []string
	lowerCaseHeaders := make(map[string][]string, len(originalKeys))
	for _, k := range originalKeys {
		lk := strings.ToLower(k)
		if _, seen := lowerCaseHeaders[lk]; !seen {
			headers = append(headers, lk)
		}
		// The first append onto a nil slice allocates fresh storage, so no
		// slice from req.Header is ever aliased here.
		lowerCaseHeaders[lk] = append(lowerCaseHeaders[lk], req.Header[k]...)
	}
	sort.Strings(headers)

	var fullHeaders bytes.Buffer
	for _, header := range headers {
		fullHeaders.WriteString(header)
		fullHeaders.WriteByte(':')
		fullHeaders.WriteString(strings.Join(lowerCaseHeaders[header], ","))
		fullHeaders.WriteByte('\n')
	}

	return strings.Join(headers, ";"), fullHeaders.String()
}
489
490 func requestDataHash(req *http.Request) (string, error) {
491 var requestData []byte
492 if req.Body != nil {
493 requestBody, err := req.GetBody()
494 if err != nil {
495 return "", err
496 }
497 defer requestBody.Close()
498
499 requestData, err = internal.ReadAll(requestBody)
500 if err != nil {
501 return "", err
502 }
503 }
504
505 return getSha256(requestData)
506 }
507
// requestHost returns the host to sign for req. An explicit req.Host
// override wins; otherwise the host comes from req.URL. req.URL is only
// consulted when req.Host is empty.
func requestHost(req *http.Request) string {
	switch override := req.Host; override {
	case "":
		return req.URL.Host
	default:
		return override
	}
}
514
515 func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
516 dataHash, err := requestDataHash(req)
517 if err != nil {
518 return "", err
519 }
520 return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
521 }
522
// awsRequestHeader is one header of the signed request, in the key/value form
// embedded in the serialized subject token.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
527
// awsRequest is the JSON shape of the signed GetCallerIdentity request that
// subjectToken serializes and query-escapes into the subject token. Field order
// determines the order of keys in the marshaled JSON.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
533
534
535
536 func canRetrieveRegionFromEnvironment() bool {
537 return getenv(awsRegionEnvVar) != "" || getenv(awsDefaultRegionEnvVar) != ""
538 }
539
540
541 func canRetrieveSecurityCredentialFromEnvironment() bool {
542 return getenv(awsAccessKeyIDEnvVar) != "" && getenv(awsSecretAccessKeyEnvVar) != ""
543 }
544
545 func (sp *awsSubjectProvider) shouldUseMetadataServer() bool {
546 return sp.securityCredentialsProvider == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
547 }
548
View as plain text