1
2
3
4
5 package externalaccount
6
7 import (
8 "bytes"
9 "context"
10 "crypto/hmac"
11 "crypto/sha256"
12 "encoding/hex"
13 "encoding/json"
14 "errors"
15 "fmt"
16 "io"
17 "io/ioutil"
18 "net/http"
19 "net/url"
20 "os"
21 "path"
22 "sort"
23 "strings"
24 "time"
25
26 "golang.org/x/oauth2"
27 )
28
29
// AwsSecurityCredentials carries the AWS key material used to sign requests.
//
// Values of this type are either built from the environment, returned by a
// programmatic supplier, or decoded from the JSON document served by the AWS
// metadata server. The JSON tags name the keys of that document; note that the
// session token is read from the "Token" key.
type AwsSecurityCredentials struct {
	// AccessKeyID is the AWS access key ID placed in the Credential= part of
	// the Authorization header.
	AccessKeyID string `json:"AccessKeyID"`

	// SecretAccessKey is the secret used to derive the signing key.
	SecretAccessKey string `json:"SecretAccessKey"`

	// SessionToken is an optional temporary session token. When non-empty it
	// is sent, and signed, in the x-amz-security-token header.
	SessionToken string `json:"Token"`
}
38
39
40 type awsRequestSigner struct {
41 RegionName string
42 AwsSecurityCredentials *AwsSecurityCredentials
43 }
44
45
// getenv looks up an environment variable. It is a package variable rather
// than a direct os.Getenv call, presumably so it can be swapped out in tests.
var getenv = os.Getenv

const (
	// defaultRegionalCredentialVerificationUrl is the STS GetCallerIdentity
	// endpoint used when none is configured; "{region}" is replaced with the
	// resolved AWS region before the request is built.
	defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// Signature Version 4 algorithm identifier and credential-scope terminator.
	awsAlgorithm   = "AWS4-HMAC-SHA256"
	awsRequestType = "aws4_request"

	// Headers added to a request while signing it.
	awsSecurityTokenHeader = "x-amz-security-token"
	awsDateHeader          = "x-amz-date"

	// IMDSv2: a session token is requested with a TTL header (value in
	// seconds) and then attached to later metadata requests.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
	awsIMDSv2SessionTtlHeader   = "X-aws-ec2-metadata-token-ttl-seconds"
	awsIMDSv2SessionTtl         = "300"

	// Environment variables consulted for the region and the credentials.
	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegion   = "AWS_DEFAULT_REGION"
	awsRegion          = "AWS_REGION"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
	awsSessionToken    = "AWS_SESSION_TOKEN"

	// time.Format layouts for the SigV4 timestamp and date stamp.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
81
// getSha256 returns the lowercase hexadecimal SHA-256 digest of input.
//
// The error result is always nil (hashing a byte slice cannot fail); it is
// kept so existing callers that check it continue to compile.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
89
// getHmacSha256 returns the raw (not hex-encoded) HMAC-SHA256 of input under
// key. The error result is always nil; it is kept for caller compatibility.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	// hash.Hash documents that Write never returns an error.
	mac.Write(input)
	return mac.Sum(nil), nil
}
97
// cloneRequest returns a shallow copy of r whose Header map is deep-copied, so
// headers can be added to or changed on the clone without affecting r. Every
// other field (URL, Body, context, ...) is shared with r. A nil Header stays
// nil.
func cloneRequest(r *http.Request) *http.Request {
	clone := new(http.Request)
	*clone = *r
	if r.Header == nil {
		return clone
	}

	// One backing array holds every value. Each key gets a capped window into
	// it (three-index slice), so a later append on one key reallocates instead
	// of overwriting the values of the next key.
	total := 0
	for _, values := range r.Header {
		total += len(values)
	}
	backing := make([]string, total)

	clone.Header = make(http.Header, len(r.Header))
	for name, values := range r.Header {
		n := copy(backing, values)
		clone.Header[name] = backing[:n:n]
		backing = backing[n:]
	}
	return clone
}
119
// canonicalPath returns the canonical URI component of an AWS Signature
// Version 4 canonical request: the escaped request path cleaned with
// path.Clean, or "/" when the URL has no path at all.
func canonicalPath(req *http.Request) string {
	if escaped := req.URL.EscapedPath(); escaped != "" {
		return path.Clean(escaped)
	}
	return "/"
}
127
// canonicalQuery returns the canonical query string of an AWS Signature
// Version 4 canonical request. Parameters are sorted by name, the values of a
// repeated name are sorted, and names and values are percent-encoded and joined
// as "name=value" pairs separated by "&".
//
// url.Values.Encode escapes a space as "+", but SigV4 requires "%20". Encode
// escapes a literal '+' as "%2B", so every '+' left in its output stands for a
// space and can be rewritten safely.
func canonicalQuery(req *http.Request) string {
	queryValues := req.URL.Query()
	for queryKey := range queryValues {
		sort.Strings(queryValues[queryKey])
	}
	return strings.ReplaceAll(queryValues.Encode(), "+", "%20")
}
135
// canonicalHeaders builds the two header-derived parts of an AWS Signature
// Version 4 canonical request from req.Header.
//
// It returns, in order:
//   - the signed-headers list: lowercase header names, sorted and joined by ";";
//   - the canonical-headers block: one "name:value\n" line per name, in the
//     same order, where multiple values of a name are joined by ",".
//
// Names that differ only in case are merged into one entry. Each value has
// leading and trailing whitespace removed and internal whitespace runs
// collapsed to a single space, as SigV4 requires. req.Header is only read,
// never modified.
func canonicalHeaders(req *http.Request) (string, string) {
	// Visit the original keys in sorted order so that, when several keys differ
	// only in case, their values are merged in a deterministic order.
	keys := make([]string, 0, len(req.Header))
	for key := range req.Header {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var headers []string
	lowerCaseHeaders := make(map[string][]string, len(keys))
	for _, key := range keys {
		name := strings.ToLower(key)
		if _, seen := lowerCaseHeaders[name]; !seen {
			headers = append(headers, name)
			lowerCaseHeaders[name] = nil
		}
		// Build a slice owned by this function instead of aliasing the
		// request's own slice, so appends can never write into req.Header.
		for _, value := range req.Header[key] {
			lowerCaseHeaders[name] = append(lowerCaseHeaders[name], strings.Join(strings.Fields(value), " "))
		}
	}
	sort.Strings(headers)

	var fullHeaders bytes.Buffer
	for _, name := range headers {
		fullHeaders.WriteString(name)
		fullHeaders.WriteByte(':')
		fullHeaders.WriteString(strings.Join(lowerCaseHeaders[name], ","))
		fullHeaders.WriteByte('\n')
	}

	return strings.Join(headers, ";"), fullHeaders.String()
}
163
// requestDataHash returns the lowercase hex SHA-256 of the request payload, as
// required for the last line of a SigV4 canonical request. A request with no
// body hashes as the empty payload.
//
// The body is read through req.GetBody, so req.Body itself is not consumed.
// The whole payload is streamed into the hash, with no size cap: truncating
// would yield a hash, and therefore a signature, that does not match the bytes
// actually sent. An error is returned if the body cannot be re-read (GetBody is
// nil) or if reading it fails.
func requestDataHash(req *http.Request) (string, error) {
	hash := sha256.New()
	if req.Body != nil {
		if req.GetBody == nil {
			return "", errors.New("oauth2/google/externalaccount: request body cannot be re-read for signing (GetBody is nil)")
		}
		requestBody, err := req.GetBody()
		if err != nil {
			return "", err
		}
		defer requestBody.Close()

		if _, err := io.Copy(hash, requestBody); err != nil {
			return "", err
		}
	}
	return hex.EncodeToString(hash.Sum(nil)), nil
}
181
// requestHost returns the host a request is addressed to: req.Host when it is
// set, otherwise the host component of req.URL.
func requestHost(req *http.Request) string {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	return host
}
188
189 func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
190 dataHash, err := requestDataHash(req)
191 if err != nil {
192 return "", err
193 }
194
195 return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
196 }
197
198
199
200 func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
201 signedRequest := cloneRequest(req)
202 timestamp := now()
203
204 signedRequest.Header.Add("host", requestHost(req))
205
206 if rs.AwsSecurityCredentials.SessionToken != "" {
207 signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
208 }
209
210 if signedRequest.Header.Get("date") == "" {
211 signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
212 }
213
214 authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
215 if err != nil {
216 return err
217 }
218 signedRequest.Header.Set("Authorization", authorizationCode)
219
220 req.Header = signedRequest.Header
221 return nil
222 }
223
224 func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
225 canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
226
227 dateStamp := timestamp.Format(awsTimeFormatShort)
228 serviceName := ""
229 if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
230 serviceName = splitHost[0]
231 }
232
233 credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
234
235 requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
236 if err != nil {
237 return "", err
238 }
239 requestHash, err := getSha256([]byte(requestString))
240 if err != nil {
241 return "", err
242 }
243
244 stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
245
246 signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
247 for _, signingInput := range []string{
248 dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
249 } {
250 signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
251 if err != nil {
252 return "", err
253 }
254 }
255
256 return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
257 }
258
259 type awsCredentialSource struct {
260 environmentID string
261 regionURL string
262 regionalCredVerificationURL string
263 credVerificationURL string
264 imdsv2SessionTokenURL string
265 targetResource string
266 requestSigner *awsRequestSigner
267 region string
268 ctx context.Context
269 client *http.Client
270 awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
271 supplierOptions SupplierOptions
272 }
273
// awsRequestHeader is one header of a serialized signed request, encoded as
// {"key": ..., "value": ...}. A header with several values produces several
// entries with the same key.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}

// awsRequest is the JSON form of the signed STS request that becomes the
// subject token: its URL, HTTP method and full list of headers.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
284
285 func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
286 if cs.client == nil {
287 cs.client = oauth2.NewClient(cs.ctx, nil)
288 }
289 return cs.client.Do(req.WithContext(cs.ctx))
290 }
291
292 func canRetrieveRegionFromEnvironment() bool {
293
294
295 return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
296 }
297
298 func canRetrieveSecurityCredentialFromEnvironment() bool {
299
300 return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
301 }
302
303 func (cs awsCredentialSource) shouldUseMetadataServer() bool {
304 return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
305 }
306
307 func (cs awsCredentialSource) credentialSourceType() string {
308 if cs.awsSecurityCredentialsSupplier != nil {
309 return "programmatic"
310 }
311 return "aws"
312 }
313
314 func (cs awsCredentialSource) subjectToken() (string, error) {
315
316 if cs.regionalCredVerificationURL == "" {
317 cs.regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
318 }
319 if cs.requestSigner == nil {
320 headers := make(map[string]string)
321 if cs.shouldUseMetadataServer() {
322 awsSessionToken, err := cs.getAWSSessionToken()
323 if err != nil {
324 return "", err
325 }
326
327 if awsSessionToken != "" {
328 headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
329 }
330 }
331
332 awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
333 if err != nil {
334 return "", err
335 }
336 cs.region, err = cs.getRegion(headers)
337 if err != nil {
338 return "", err
339 }
340
341 cs.requestSigner = &awsRequestSigner{
342 RegionName: cs.region,
343 AwsSecurityCredentials: awsSecurityCredentials,
344 }
345 }
346
347
348
349 req, err := http.NewRequest("POST", strings.Replace(cs.regionalCredVerificationURL, "{region}", cs.region, 1), nil)
350 if err != nil {
351 return "", err
352 }
353
354
355
356
357 if cs.targetResource != "" {
358 req.Header.Add("x-goog-cloud-target-resource", cs.targetResource)
359 }
360 cs.requestSigner.SignRequest(req)
361
362
376
377 awsSignedReq := awsRequest{
378 URL: req.URL.String(),
379 Method: "POST",
380 }
381 for headerKey, headerList := range req.Header {
382 for _, headerValue := range headerList {
383 awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
384 Key: headerKey,
385 Value: headerValue,
386 })
387 }
388 }
389 sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
390 headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
391 if headerCompare == 0 {
392 return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
393 }
394 return headerCompare < 0
395 })
396
397 result, err := json.Marshal(awsSignedReq)
398 if err != nil {
399 return "", err
400 }
401 return url.QueryEscape(string(result)), nil
402 }
403
404 func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
405 if cs.imdsv2SessionTokenURL == "" {
406 return "", nil
407 }
408
409 req, err := http.NewRequest("PUT", cs.imdsv2SessionTokenURL, nil)
410 if err != nil {
411 return "", err
412 }
413
414 req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
415
416 resp, err := cs.doRequest(req)
417 if err != nil {
418 return "", err
419 }
420 defer resp.Body.Close()
421
422 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
423 if err != nil {
424 return "", err
425 }
426
427 if resp.StatusCode != 200 {
428 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS session token - %s", string(respBody))
429 }
430
431 return string(respBody), nil
432 }
433
434 func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
435 if cs.awsSecurityCredentialsSupplier != nil {
436 return cs.awsSecurityCredentialsSupplier.AwsRegion(cs.ctx, cs.supplierOptions)
437 }
438 if canRetrieveRegionFromEnvironment() {
439 if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
440 cs.region = envAwsRegion
441 return envAwsRegion, nil
442 }
443 return getenv("AWS_DEFAULT_REGION"), nil
444 }
445
446 if cs.regionURL == "" {
447 return "", errors.New("oauth2/google/externalaccount: unable to determine AWS region")
448 }
449
450 req, err := http.NewRequest("GET", cs.regionURL, nil)
451 if err != nil {
452 return "", err
453 }
454
455 for name, value := range headers {
456 req.Header.Add(name, value)
457 }
458
459 resp, err := cs.doRequest(req)
460 if err != nil {
461 return "", err
462 }
463 defer resp.Body.Close()
464
465 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
466 if err != nil {
467 return "", err
468 }
469
470 if resp.StatusCode != 200 {
471 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS region - %s", string(respBody))
472 }
473
474
475
476 respBodyEnd := 0
477 if len(respBody) > 1 {
478 respBodyEnd = len(respBody) - 1
479 }
480 return string(respBody[:respBodyEnd]), nil
481 }
482
483 func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result *AwsSecurityCredentials, err error) {
484 if cs.awsSecurityCredentialsSupplier != nil {
485 return cs.awsSecurityCredentialsSupplier.AwsSecurityCredentials(cs.ctx, cs.supplierOptions)
486 }
487 if canRetrieveSecurityCredentialFromEnvironment() {
488 return &AwsSecurityCredentials{
489 AccessKeyID: getenv(awsAccessKeyId),
490 SecretAccessKey: getenv(awsSecretAccessKey),
491 SessionToken: getenv(awsSessionToken),
492 }, nil
493 }
494
495 roleName, err := cs.getMetadataRoleName(headers)
496 if err != nil {
497 return
498 }
499
500 credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
501 if err != nil {
502 return
503 }
504
505 if credentials.AccessKeyID == "" {
506 return result, errors.New("oauth2/google/externalaccount: missing AccessKeyId credential")
507 }
508
509 if credentials.SecretAccessKey == "" {
510 return result, errors.New("oauth2/google/externalaccount: missing SecretAccessKey credential")
511 }
512
513 return &credentials, nil
514 }
515
516 func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) {
517 var result AwsSecurityCredentials
518
519 req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.credVerificationURL, roleName), nil)
520 if err != nil {
521 return result, err
522 }
523
524 for name, value := range headers {
525 req.Header.Add(name, value)
526 }
527
528 resp, err := cs.doRequest(req)
529 if err != nil {
530 return result, err
531 }
532 defer resp.Body.Close()
533
534 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
535 if err != nil {
536 return result, err
537 }
538
539 if resp.StatusCode != 200 {
540 return result, fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS security credentials - %s", string(respBody))
541 }
542
543 err = json.Unmarshal(respBody, &result)
544 return result, err
545 }
546
547 func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
548 if cs.credVerificationURL == "" {
549 return "", errors.New("oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint")
550 }
551
552 req, err := http.NewRequest("GET", cs.credVerificationURL, nil)
553 if err != nil {
554 return "", err
555 }
556
557 for name, value := range headers {
558 req.Header.Add(name, value)
559 }
560
561 resp, err := cs.doRequest(req)
562 if err != nil {
563 return "", err
564 }
565 defer resp.Body.Close()
566
567 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
568 if err != nil {
569 return "", err
570 }
571
572 if resp.StatusCode != 200 {
573 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS role name - %s", string(respBody))
574 }
575
576 return string(respBody), nil
577 }
578
View as plain text