...

Source file src/golang.org/x/oauth2/google/externalaccount/aws.go

Documentation: golang.org/x/oauth2/google/externalaccount

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package externalaccount
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/hmac"
    11  	"crypto/sha256"
    12  	"encoding/hex"
    13  	"encoding/json"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"io/ioutil"
    18  	"net/http"
    19  	"net/url"
    20  	"os"
    21  	"path"
    22  	"sort"
    23  	"strings"
    24  	"time"
    25  
    26  	"golang.org/x/oauth2"
    27  )
    28  
// AwsSecurityCredentials models AWS security credentials.
//
// Values are either built from the AWS_* environment variables, returned by a
// programmatic supplier, or JSON-decoded from the metadata server's
// security-credentials response (see getMetadataSecurityCredentials).
type AwsSecurityCredentials struct {
	// AccessKeyID is the AWS Access Key ID - Required.
	AccessKeyID string `json:"AccessKeyID"`
	// SecretAccessKey is the AWS Secret Access Key - Required.
	SecretAccessKey string `json:"SecretAccessKey"`
	// SessionToken is the AWS Session token. This should be provided for temporary AWS security credentials - Optional.
	// When decoded from JSON it is read from the "Token" key.
	SessionToken string `json:"Token"`
}
    38  
// awsRequestSigner is a utility class to sign http requests using an AWS V4 signature.
//
// RegionName is used both in the credential scope and in the signing-key
// derivation; AwsSecurityCredentials supplies the access key ID, the secret
// access key and the optional session token.
type awsRequestSigner struct {
	RegionName             string
	AwsSecurityCredentials *AwsSecurityCredentials
}
    44  
// getenv aliases os.Getenv for testing.
// Every AWS_* environment lookup in this file goes through it, so tests can
// substitute a fake environment without touching the real process environment.
var getenv = os.Getenv
    47  
const (
	// defaultRegionalCredentialVerificationUrl is the AWS STS GetCallerIdentity
	// endpoint used when no regional verification URL is configured. The
	// "{region}" placeholder is replaced with the resolved AWS region.
	defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// The termination string for the AWS credential scope value as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// The AWS authorization header name for the security session token if available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// The name of the header containing the session token for metadata endpoint calls
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"

	// The header sent when requesting an IMDSv2 session token to state the
	// requested token lifetime.
	awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"

	// The requested IMDSv2 session token lifetime in seconds, as a header value.
	awsIMDSv2SessionTtl = "300"

	// The AWS authorization header name for the auto-generated date.
	awsDateHeader = "x-amz-date"

	// Supported AWS configuration environment variables.
	awsAccessKeyId     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegion   = "AWS_DEFAULT_REGION"
	awsRegion          = "AWS_REGION"
	awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
	awsSessionToken    = "AWS_SESSION_TOKEN"

	// Go time layouts for the full signing timestamp (x-amz-date header and
	// string to sign) and for the date component of the credential scope.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
    81  
    82  func getSha256(input []byte) (string, error) {
    83  	hash := sha256.New()
    84  	if _, err := hash.Write(input); err != nil {
    85  		return "", err
    86  	}
    87  	return hex.EncodeToString(hash.Sum(nil)), nil
    88  }
    89  
    90  func getHmacSha256(key, input []byte) ([]byte, error) {
    91  	hash := hmac.New(sha256.New, key)
    92  	if _, err := hash.Write(input); err != nil {
    93  		return nil, err
    94  	}
    95  	return hash.Sum(nil), nil
    96  }
    97  
    98  func cloneRequest(r *http.Request) *http.Request {
    99  	r2 := new(http.Request)
   100  	*r2 = *r
   101  	if r.Header != nil {
   102  		r2.Header = make(http.Header, len(r.Header))
   103  
   104  		// Find total number of values.
   105  		headerCount := 0
   106  		for _, headerValues := range r.Header {
   107  			headerCount += len(headerValues)
   108  		}
   109  		copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
   110  
   111  		for headerKey, headerValues := range r.Header {
   112  			headerCount = copy(copiedHeaders, headerValues)
   113  			r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
   114  			copiedHeaders = copiedHeaders[headerCount:]
   115  		}
   116  	}
   117  	return r2
   118  }
   119  
   120  func canonicalPath(req *http.Request) string {
   121  	result := req.URL.EscapedPath()
   122  	if result == "" {
   123  		return "/"
   124  	}
   125  	return path.Clean(result)
   126  }
   127  
   128  func canonicalQuery(req *http.Request) string {
   129  	queryValues := req.URL.Query()
   130  	for queryKey := range queryValues {
   131  		sort.Strings(queryValues[queryKey])
   132  	}
   133  	return queryValues.Encode()
   134  }
   135  
   136  func canonicalHeaders(req *http.Request) (string, string) {
   137  	// Header keys need to be sorted alphabetically.
   138  	var headers []string
   139  	lowerCaseHeaders := make(http.Header)
   140  	for k, v := range req.Header {
   141  		k := strings.ToLower(k)
   142  		if _, ok := lowerCaseHeaders[k]; ok {
   143  			// include additional values
   144  			lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
   145  		} else {
   146  			headers = append(headers, k)
   147  			lowerCaseHeaders[k] = v
   148  		}
   149  	}
   150  	sort.Strings(headers)
   151  
   152  	var fullHeaders bytes.Buffer
   153  	for _, header := range headers {
   154  		headerValue := strings.Join(lowerCaseHeaders[header], ",")
   155  		fullHeaders.WriteString(header)
   156  		fullHeaders.WriteRune(':')
   157  		fullHeaders.WriteString(headerValue)
   158  		fullHeaders.WriteRune('\n')
   159  	}
   160  
   161  	return strings.Join(headers, ";"), fullHeaders.String()
   162  }
   163  
   164  func requestDataHash(req *http.Request) (string, error) {
   165  	var requestData []byte
   166  	if req.Body != nil {
   167  		requestBody, err := req.GetBody()
   168  		if err != nil {
   169  			return "", err
   170  		}
   171  		defer requestBody.Close()
   172  
   173  		requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
   174  		if err != nil {
   175  			return "", err
   176  		}
   177  	}
   178  
   179  	return getSha256(requestData)
   180  }
   181  
   182  func requestHost(req *http.Request) string {
   183  	if req.Host != "" {
   184  		return req.Host
   185  	}
   186  	return req.URL.Host
   187  }
   188  
   189  func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
   190  	dataHash, err := requestDataHash(req)
   191  	if err != nil {
   192  		return "", err
   193  	}
   194  
   195  	return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
   196  }
   197  
   198  // SignRequest adds the appropriate headers to an http.Request
   199  // or returns an error if something prevented this.
   200  func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
   201  	signedRequest := cloneRequest(req)
   202  	timestamp := now()
   203  
   204  	signedRequest.Header.Add("host", requestHost(req))
   205  
   206  	if rs.AwsSecurityCredentials.SessionToken != "" {
   207  		signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
   208  	}
   209  
   210  	if signedRequest.Header.Get("date") == "" {
   211  		signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
   212  	}
   213  
   214  	authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
   215  	if err != nil {
   216  		return err
   217  	}
   218  	signedRequest.Header.Set("Authorization", authorizationCode)
   219  
   220  	req.Header = signedRequest.Header
   221  	return nil
   222  }
   223  
   224  func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
   225  	canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
   226  
   227  	dateStamp := timestamp.Format(awsTimeFormatShort)
   228  	serviceName := ""
   229  	if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
   230  		serviceName = splitHost[0]
   231  	}
   232  
   233  	credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
   234  
   235  	requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
   236  	if err != nil {
   237  		return "", err
   238  	}
   239  	requestHash, err := getSha256([]byte(requestString))
   240  	if err != nil {
   241  		return "", err
   242  	}
   243  
   244  	stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
   245  
   246  	signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
   247  	for _, signingInput := range []string{
   248  		dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
   249  	} {
   250  		signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
   251  		if err != nil {
   252  			return "", err
   253  		}
   254  	}
   255  
   256  	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
   257  }
   258  
// awsCredentialSource produces subject tokens from AWS credentials. Each token
// is a signed AWS STS GetCallerIdentity request serialized for Google STS.
type awsCredentialSource struct {
	// environmentID is not read anywhere in this file.
	environmentID string
	// regionURL is the metadata endpoint getRegion queries when the region is
	// not available from a supplier or the environment.
	regionURL string
	// regionalCredVerificationURL is the STS GetCallerIdentity URL template;
	// "{region}" is replaced with the resolved region. When empty,
	// defaultRegionalCredentialVerificationUrl is used.
	regionalCredVerificationURL string
	// credVerificationURL is the metadata security-credentials endpoint. The
	// role name is fetched from it, then credentials from credVerificationURL/<role>.
	credVerificationURL string
	// imdsv2SessionTokenURL, when non-empty, is PUT to obtain an IMDSv2 session
	// token that is attached to subsequent metadata requests.
	imdsv2SessionTokenURL string
	// targetResource, when non-empty, is sent and signed as the
	// x-goog-cloud-target-resource header.
	targetResource string
	// requestSigner signs the GetCallerIdentity request; subjectToken builds
	// one when it is nil.
	requestSigner *awsRequestSigner
	// region is the AWS region used for the regional STS endpoint.
	region string
	// ctx is attached to outgoing HTTP requests and passed to the supplier.
	ctx context.Context
	// client sends HTTP requests; doRequest falls back to
	// oauth2.NewClient(ctx, nil) when it is nil.
	client *http.Client
	// awsSecurityCredentialsSupplier, when set, supplies credentials and
	// region instead of the environment or the metadata server.
	awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
	// supplierOptions is passed to every supplier call.
	supplierOptions SupplierOptions
}
   273  
// awsRequestHeader is a single header of the signed GetCallerIdentity request,
// in the {"key": ..., "value": ...} form embedded in the subject token.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
   278  
// awsRequest is the JSON shape of the signed AWS STS request carried in the
// subject token: the request URL, the HTTP method and its headers.
// subjectToken emits the headers sorted by key, then by value.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
   284  
   285  func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
   286  	if cs.client == nil {
   287  		cs.client = oauth2.NewClient(cs.ctx, nil)
   288  	}
   289  	return cs.client.Do(req.WithContext(cs.ctx))
   290  }
   291  
   292  func canRetrieveRegionFromEnvironment() bool {
   293  	// The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
   294  	// required.
   295  	return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
   296  }
   297  
   298  func canRetrieveSecurityCredentialFromEnvironment() bool {
   299  	// Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
   300  	return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
   301  }
   302  
   303  func (cs awsCredentialSource) shouldUseMetadataServer() bool {
   304  	return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
   305  }
   306  
   307  func (cs awsCredentialSource) credentialSourceType() string {
   308  	if cs.awsSecurityCredentialsSupplier != nil {
   309  		return "programmatic"
   310  	}
   311  	return "aws"
   312  }
   313  
   314  func (cs awsCredentialSource) subjectToken() (string, error) {
   315  	// Set Defaults
   316  	if cs.regionalCredVerificationURL == "" {
   317  		cs.regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
   318  	}
   319  	if cs.requestSigner == nil {
   320  		headers := make(map[string]string)
   321  		if cs.shouldUseMetadataServer() {
   322  			awsSessionToken, err := cs.getAWSSessionToken()
   323  			if err != nil {
   324  				return "", err
   325  			}
   326  
   327  			if awsSessionToken != "" {
   328  				headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
   329  			}
   330  		}
   331  
   332  		awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
   333  		if err != nil {
   334  			return "", err
   335  		}
   336  		cs.region, err = cs.getRegion(headers)
   337  		if err != nil {
   338  			return "", err
   339  		}
   340  
   341  		cs.requestSigner = &awsRequestSigner{
   342  			RegionName:             cs.region,
   343  			AwsSecurityCredentials: awsSecurityCredentials,
   344  		}
   345  	}
   346  
   347  	// Generate the signed request to AWS STS GetCallerIdentity API.
   348  	// Use the required regional endpoint. Otherwise, the request will fail.
   349  	req, err := http.NewRequest("POST", strings.Replace(cs.regionalCredVerificationURL, "{region}", cs.region, 1), nil)
   350  	if err != nil {
   351  		return "", err
   352  	}
   353  	// The full, canonical resource name of the workload identity pool
   354  	// provider, with or without the HTTPS prefix.
   355  	// Including this header as part of the signature is recommended to
   356  	// ensure data integrity.
   357  	if cs.targetResource != "" {
   358  		req.Header.Add("x-goog-cloud-target-resource", cs.targetResource)
   359  	}
   360  	cs.requestSigner.SignRequest(req)
   361  
   362  	/*
   363  	   The GCP STS endpoint expects the headers to be formatted as:
   364  	   # [
   365  	   #   {key: 'x-amz-date', value: '...'},
   366  	   #   {key: 'Authorization', value: '...'},
   367  	   #   ...
   368  	   # ]
   369  	   # And then serialized as:
   370  	   # quote(json.dumps({
   371  	   #   url: '...',
   372  	   #   method: 'POST',
   373  	   #   headers: [{key: 'x-amz-date', value: '...'}, ...]
   374  	   # }))
   375  	*/
   376  
   377  	awsSignedReq := awsRequest{
   378  		URL:    req.URL.String(),
   379  		Method: "POST",
   380  	}
   381  	for headerKey, headerList := range req.Header {
   382  		for _, headerValue := range headerList {
   383  			awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
   384  				Key:   headerKey,
   385  				Value: headerValue,
   386  			})
   387  		}
   388  	}
   389  	sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
   390  		headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
   391  		if headerCompare == 0 {
   392  			return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
   393  		}
   394  		return headerCompare < 0
   395  	})
   396  
   397  	result, err := json.Marshal(awsSignedReq)
   398  	if err != nil {
   399  		return "", err
   400  	}
   401  	return url.QueryEscape(string(result)), nil
   402  }
   403  
   404  func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
   405  	if cs.imdsv2SessionTokenURL == "" {
   406  		return "", nil
   407  	}
   408  
   409  	req, err := http.NewRequest("PUT", cs.imdsv2SessionTokenURL, nil)
   410  	if err != nil {
   411  		return "", err
   412  	}
   413  
   414  	req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
   415  
   416  	resp, err := cs.doRequest(req)
   417  	if err != nil {
   418  		return "", err
   419  	}
   420  	defer resp.Body.Close()
   421  
   422  	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
   423  	if err != nil {
   424  		return "", err
   425  	}
   426  
   427  	if resp.StatusCode != 200 {
   428  		return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS session token - %s", string(respBody))
   429  	}
   430  
   431  	return string(respBody), nil
   432  }
   433  
   434  func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
   435  	if cs.awsSecurityCredentialsSupplier != nil {
   436  		return cs.awsSecurityCredentialsSupplier.AwsRegion(cs.ctx, cs.supplierOptions)
   437  	}
   438  	if canRetrieveRegionFromEnvironment() {
   439  		if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
   440  			cs.region = envAwsRegion
   441  			return envAwsRegion, nil
   442  		}
   443  		return getenv("AWS_DEFAULT_REGION"), nil
   444  	}
   445  
   446  	if cs.regionURL == "" {
   447  		return "", errors.New("oauth2/google/externalaccount: unable to determine AWS region")
   448  	}
   449  
   450  	req, err := http.NewRequest("GET", cs.regionURL, nil)
   451  	if err != nil {
   452  		return "", err
   453  	}
   454  
   455  	for name, value := range headers {
   456  		req.Header.Add(name, value)
   457  	}
   458  
   459  	resp, err := cs.doRequest(req)
   460  	if err != nil {
   461  		return "", err
   462  	}
   463  	defer resp.Body.Close()
   464  
   465  	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
   466  	if err != nil {
   467  		return "", err
   468  	}
   469  
   470  	if resp.StatusCode != 200 {
   471  		return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS region - %s", string(respBody))
   472  	}
   473  
   474  	// This endpoint will return the region in format: us-east-2b.
   475  	// Only the us-east-2 part should be used.
   476  	respBodyEnd := 0
   477  	if len(respBody) > 1 {
   478  		respBodyEnd = len(respBody) - 1
   479  	}
   480  	return string(respBody[:respBodyEnd]), nil
   481  }
   482  
   483  func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result *AwsSecurityCredentials, err error) {
   484  	if cs.awsSecurityCredentialsSupplier != nil {
   485  		return cs.awsSecurityCredentialsSupplier.AwsSecurityCredentials(cs.ctx, cs.supplierOptions)
   486  	}
   487  	if canRetrieveSecurityCredentialFromEnvironment() {
   488  		return &AwsSecurityCredentials{
   489  			AccessKeyID:     getenv(awsAccessKeyId),
   490  			SecretAccessKey: getenv(awsSecretAccessKey),
   491  			SessionToken:    getenv(awsSessionToken),
   492  		}, nil
   493  	}
   494  
   495  	roleName, err := cs.getMetadataRoleName(headers)
   496  	if err != nil {
   497  		return
   498  	}
   499  
   500  	credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
   501  	if err != nil {
   502  		return
   503  	}
   504  
   505  	if credentials.AccessKeyID == "" {
   506  		return result, errors.New("oauth2/google/externalaccount: missing AccessKeyId credential")
   507  	}
   508  
   509  	if credentials.SecretAccessKey == "" {
   510  		return result, errors.New("oauth2/google/externalaccount: missing SecretAccessKey credential")
   511  	}
   512  
   513  	return &credentials, nil
   514  }
   515  
   516  func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) {
   517  	var result AwsSecurityCredentials
   518  
   519  	req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.credVerificationURL, roleName), nil)
   520  	if err != nil {
   521  		return result, err
   522  	}
   523  
   524  	for name, value := range headers {
   525  		req.Header.Add(name, value)
   526  	}
   527  
   528  	resp, err := cs.doRequest(req)
   529  	if err != nil {
   530  		return result, err
   531  	}
   532  	defer resp.Body.Close()
   533  
   534  	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
   535  	if err != nil {
   536  		return result, err
   537  	}
   538  
   539  	if resp.StatusCode != 200 {
   540  		return result, fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS security credentials - %s", string(respBody))
   541  	}
   542  
   543  	err = json.Unmarshal(respBody, &result)
   544  	return result, err
   545  }
   546  
   547  func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
   548  	if cs.credVerificationURL == "" {
   549  		return "", errors.New("oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint")
   550  	}
   551  
   552  	req, err := http.NewRequest("GET", cs.credVerificationURL, nil)
   553  	if err != nil {
   554  		return "", err
   555  	}
   556  
   557  	for name, value := range headers {
   558  		req.Header.Add(name, value)
   559  	}
   560  
   561  	resp, err := cs.doRequest(req)
   562  	if err != nil {
   563  		return "", err
   564  	}
   565  	defer resp.Body.Close()
   566  
   567  	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
   568  	if err != nil {
   569  		return "", err
   570  	}
   571  
   572  	if resp.StatusCode != 200 {
   573  		return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS role name - %s", string(respBody))
   574  	}
   575  
   576  	return string(respBody), nil
   577  }
   578  

View as plain text