...

Source file src/cloud.google.com/go/auth/credentials/internal/externalaccount/aws_provider.go

Documentation: cloud.google.com/go/auth/credentials/internal/externalaccount

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package externalaccount
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/hmac"
    21  	"crypto/sha256"
    22  	"encoding/hex"
    23  	"encoding/json"
    24  	"errors"
    25  	"fmt"
    26  	"net/http"
    27  	"net/url"
    28  	"os"
    29  	"path"
    30  	"sort"
    31  	"strings"
    32  	"time"
    33  
    34  	"cloud.google.com/go/auth/internal"
    35  )
    36  
var (
	// getenv aliases os.Getenv so tests can substitute a fake environment
	// lookup; every environment-variable read in this file goes through it.
	getenv = os.Getenv
)
    41  
const (
	// AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// The termination string for the AWS credential scope value as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// The header that carries the AWS security session token, when one is available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// The name of the header containing the session token for metadata endpoint calls.
	awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"

	// Header sent with the IMDSv2 session-token PUT request to request a token
	// lifetime.
	awsIMDSv2SessionTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"

	// Requested IMDSv2 session-token lifetime, sent as the value of
	// awsIMDSv2SessionTTLHeader (seconds, per the header name).
	awsIMDSv2SessionTTL = "300"

	// The header that carries the auto-generated request timestamp.
	awsDateHeader = "x-amz-date"

	// Default STS GetCallerIdentity URL template; "{region}" is replaced with the
	// resolved AWS region before signing.
	defaultRegionalCredentialVerificationURL = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

	// Supported AWS configuration environment variables.
	awsAccessKeyIDEnvVar     = "AWS_ACCESS_KEY_ID"
	awsDefaultRegionEnvVar   = "AWS_DEFAULT_REGION"
	awsRegionEnvVar          = "AWS_REGION"
	awsSecretAccessKeyEnvVar = "AWS_SECRET_ACCESS_KEY"
	awsSessionTokenEnvVar    = "AWS_SESSION_TOKEN"

	// Go time layouts for the SigV4 long timestamp and short date stamp. The
	// trailing "Z" is a literal, so these assume the formatted time is already
	// UTC. NOTE(review): confirm Now() returns UTC.
	// awsProviderType is reported by providerType when no programmatic
	// credentials provider is configured.
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
	awsProviderType    = "aws"
)
    76  
// awsSubjectProvider produces subject tokens for AWS workloads. A token is a
// signed AWS STS GetCallerIdentity request, serialized for exchange at the
// GCP STS endpoint.
type awsSubjectProvider struct {
	// EnvironmentID is not read by the code in this file. Presumably it carries
	// the AWS environment identifier from the credential config — TODO confirm.
	EnvironmentID               string
	// RegionURL is the metadata endpoint queried for the region when neither a
	// programmatic provider nor the environment supplies one. Its response is
	// expected to be an availability zone such as "us-east-2b".
	RegionURL                   string
	// RegionalCredVerificationURL is the STS GetCallerIdentity URL template.
	// "{region}" is replaced with the resolved region. It is defaulted by
	// subjectToken when empty.
	RegionalCredVerificationURL string
	// CredVerificationURL is the metadata security-credentials endpoint. A GET
	// on it returns the role name; a GET on "<CredVerificationURL>/<role>"
	// returns the JSON credentials.
	CredVerificationURL         string
	// IMDSv2SessionTokenURL, if non-empty, is PUT to obtain an IMDSv2 session
	// token that is attached to the metadata requests.
	IMDSv2SessionTokenURL       string
	// TargetResource, if non-empty, is sent (and signed) in the
	// x-goog-cloud-target-resource header.
	TargetResource              string
	// requestSigner is created lazily by subjectToken and reused afterwards.
	requestSigner               *awsRequestSigner
	// region is resolved together with requestSigner and substituted into
	// RegionalCredVerificationURL.
	region                      string
	// securityCredentialsProvider, when non-nil, supplies the credentials and
	// region instead of the environment or the metadata server.
	securityCredentialsProvider AwsSecurityCredentialsProvider
	// reqOpts is passed through to securityCredentialsProvider.
	reqOpts                     *RequestOptions

	// Client performs the metadata server HTTP requests.
	Client *http.Client
}
    91  
    92  func (sp *awsSubjectProvider) subjectToken(ctx context.Context) (string, error) {
    93  	// Set Defaults
    94  	if sp.RegionalCredVerificationURL == "" {
    95  		sp.RegionalCredVerificationURL = defaultRegionalCredentialVerificationURL
    96  	}
    97  	if sp.requestSigner == nil {
    98  		headers := make(map[string]string)
    99  		if sp.shouldUseMetadataServer() {
   100  			awsSessionToken, err := sp.getAWSSessionToken(ctx)
   101  			if err != nil {
   102  				return "", err
   103  			}
   104  
   105  			if awsSessionToken != "" {
   106  				headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
   107  			}
   108  		}
   109  
   110  		awsSecurityCredentials, err := sp.getSecurityCredentials(ctx, headers)
   111  		if err != nil {
   112  			return "", err
   113  		}
   114  		if sp.region, err = sp.getRegion(ctx, headers); err != nil {
   115  			return "", err
   116  		}
   117  		sp.requestSigner = &awsRequestSigner{
   118  			RegionName:             sp.region,
   119  			AwsSecurityCredentials: awsSecurityCredentials,
   120  		}
   121  	}
   122  
   123  	// Generate the signed request to AWS STS GetCallerIdentity API.
   124  	// Use the required regional endpoint. Otherwise, the request will fail.
   125  	req, err := http.NewRequest("POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
   126  	if err != nil {
   127  		return "", err
   128  	}
   129  	// The full, canonical resource name of the workload identity pool
   130  	// provider, with or without the HTTPS prefix.
   131  	// Including this header as part of the signature is recommended to
   132  	// ensure data integrity.
   133  	if sp.TargetResource != "" {
   134  		req.Header.Set("x-goog-cloud-target-resource", sp.TargetResource)
   135  	}
   136  	sp.requestSigner.signRequest(req)
   137  
   138  	/*
   139  	   The GCP STS endpoint expects the headers to be formatted as:
   140  	   # [
   141  	   #   {key: 'x-amz-date', value: '...'},
   142  	   #   {key: 'Authorization', value: '...'},
   143  	   #   ...
   144  	   # ]
   145  	   # And then serialized as:
   146  	   # quote(json.dumps({
   147  	   #   url: '...',
   148  	   #   method: 'POST',
   149  	   #   headers: [{key: 'x-amz-date', value: '...'}, ...]
   150  	   # }))
   151  	*/
   152  
   153  	awsSignedReq := awsRequest{
   154  		URL:    req.URL.String(),
   155  		Method: "POST",
   156  	}
   157  	for headerKey, headerList := range req.Header {
   158  		for _, headerValue := range headerList {
   159  			awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
   160  				Key:   headerKey,
   161  				Value: headerValue,
   162  			})
   163  		}
   164  	}
   165  	sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
   166  		headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
   167  		if headerCompare == 0 {
   168  			return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
   169  		}
   170  		return headerCompare < 0
   171  	})
   172  
   173  	result, err := json.Marshal(awsSignedReq)
   174  	if err != nil {
   175  		return "", err
   176  	}
   177  	return url.QueryEscape(string(result)), nil
   178  }
   179  
   180  func (sp *awsSubjectProvider) providerType() string {
   181  	if sp.securityCredentialsProvider != nil {
   182  		return programmaticProviderType
   183  	}
   184  	return awsProviderType
   185  }
   186  
   187  func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
   188  	if sp.IMDSv2SessionTokenURL == "" {
   189  		return "", nil
   190  	}
   191  	req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
   192  	if err != nil {
   193  		return "", err
   194  	}
   195  	req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)
   196  
   197  	resp, err := sp.Client.Do(req)
   198  	if err != nil {
   199  		return "", err
   200  	}
   201  	defer resp.Body.Close()
   202  
   203  	respBody, err := internal.ReadAll(resp.Body)
   204  	if err != nil {
   205  		return "", err
   206  	}
   207  	if resp.StatusCode != http.StatusOK {
   208  		return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", respBody)
   209  	}
   210  	return string(respBody), nil
   211  }
   212  
   213  func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
   214  	if sp.securityCredentialsProvider != nil {
   215  		return sp.securityCredentialsProvider.AwsRegion(ctx, sp.reqOpts)
   216  	}
   217  	if canRetrieveRegionFromEnvironment() {
   218  		if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
   219  			return envAwsRegion, nil
   220  		}
   221  		return getenv(awsDefaultRegionEnvVar), nil
   222  	}
   223  
   224  	if sp.RegionURL == "" {
   225  		return "", errors.New("credentials: unable to determine AWS region")
   226  	}
   227  
   228  	req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
   229  	if err != nil {
   230  		return "", err
   231  	}
   232  
   233  	for name, value := range headers {
   234  		req.Header.Add(name, value)
   235  	}
   236  
   237  	resp, err := sp.Client.Do(req)
   238  	if err != nil {
   239  		return "", err
   240  	}
   241  	defer resp.Body.Close()
   242  
   243  	respBody, err := internal.ReadAll(resp.Body)
   244  	if err != nil {
   245  		return "", err
   246  	}
   247  
   248  	if resp.StatusCode != http.StatusOK {
   249  		return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", respBody)
   250  	}
   251  
   252  	// This endpoint will return the region in format: us-east-2b.
   253  	// Only the us-east-2 part should be used.
   254  	bodyLen := len(respBody)
   255  	if bodyLen == 0 {
   256  		return "", nil
   257  	}
   258  	return string(respBody[:bodyLen-1]), nil
   259  }
   260  
   261  func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result *AwsSecurityCredentials, err error) {
   262  	if sp.securityCredentialsProvider != nil {
   263  		return sp.securityCredentialsProvider.AwsSecurityCredentials(ctx, sp.reqOpts)
   264  	}
   265  	if canRetrieveSecurityCredentialFromEnvironment() {
   266  		return &AwsSecurityCredentials{
   267  			AccessKeyID:     getenv(awsAccessKeyIDEnvVar),
   268  			SecretAccessKey: getenv(awsSecretAccessKeyEnvVar),
   269  			SessionToken:    getenv(awsSessionTokenEnvVar),
   270  		}, nil
   271  	}
   272  
   273  	roleName, err := sp.getMetadataRoleName(ctx, headers)
   274  	if err != nil {
   275  		return
   276  	}
   277  	credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
   278  	if err != nil {
   279  		return
   280  	}
   281  
   282  	if credentials.AccessKeyID == "" {
   283  		return result, errors.New("credentials: missing AccessKeyId credential")
   284  	}
   285  	if credentials.SecretAccessKey == "" {
   286  		return result, errors.New("credentials: missing SecretAccessKey credential")
   287  	}
   288  
   289  	return credentials, nil
   290  }
   291  
   292  func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (*AwsSecurityCredentials, error) {
   293  	var result *AwsSecurityCredentials
   294  
   295  	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
   296  	if err != nil {
   297  		return result, err
   298  	}
   299  	for name, value := range headers {
   300  		req.Header.Add(name, value)
   301  	}
   302  
   303  	resp, err := sp.Client.Do(req)
   304  	if err != nil {
   305  		return result, err
   306  	}
   307  	defer resp.Body.Close()
   308  
   309  	respBody, err := internal.ReadAll(resp.Body)
   310  	if err != nil {
   311  		return result, err
   312  	}
   313  	if resp.StatusCode != http.StatusOK {
   314  		return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", respBody)
   315  	}
   316  	err = json.Unmarshal(respBody, &result)
   317  	return result, err
   318  }
   319  
   320  func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
   321  	if sp.CredVerificationURL == "" {
   322  		return "", errors.New("credentials: unable to determine the AWS metadata server security credentials endpoint")
   323  	}
   324  	req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
   325  	if err != nil {
   326  		return "", err
   327  	}
   328  	for name, value := range headers {
   329  		req.Header.Add(name, value)
   330  	}
   331  
   332  	resp, err := sp.Client.Do(req)
   333  	if err != nil {
   334  		return "", err
   335  	}
   336  	defer resp.Body.Close()
   337  
   338  	respBody, err := internal.ReadAll(resp.Body)
   339  	if err != nil {
   340  		return "", err
   341  	}
   342  	if resp.StatusCode != http.StatusOK {
   343  		return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", respBody)
   344  	}
   345  	return string(respBody), nil
   346  }
   347  
// awsRequestSigner signs HTTP requests with an AWS Signature Version 4
// ("AWS4-HMAC-SHA256") Authorization header.
type awsRequestSigner struct {
	// RegionName is used in the credential scope and in signing-key derivation.
	RegionName             string
	// AwsSecurityCredentials supplies the access key ID, the secret access key
	// and the optional session token. Assumed non-nil by signRequest.
	AwsSecurityCredentials *AwsSecurityCredentials
}
   353  
   354  // signRequest adds the appropriate headers to an http.Request
   355  // or returns an error if something prevented this.
   356  func (rs *awsRequestSigner) signRequest(req *http.Request) error {
   357  	// req is assumed non-nil
   358  	signedRequest := cloneRequest(req)
   359  	timestamp := Now()
   360  	signedRequest.Header.Set("host", requestHost(req))
   361  	if rs.AwsSecurityCredentials.SessionToken != "" {
   362  		signedRequest.Header.Set(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
   363  	}
   364  	if signedRequest.Header.Get("date") == "" {
   365  		signedRequest.Header.Set(awsDateHeader, timestamp.Format(awsTimeFormatLong))
   366  	}
   367  	authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
   368  	if err != nil {
   369  		return err
   370  	}
   371  	signedRequest.Header.Set("Authorization", authorizationCode)
   372  	req.Header = signedRequest.Header
   373  	return nil
   374  }
   375  
   376  func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
   377  	canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
   378  	dateStamp := timestamp.Format(awsTimeFormatShort)
   379  	serviceName := ""
   380  
   381  	if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
   382  		serviceName = splitHost[0]
   383  	}
   384  	credentialScope := strings.Join([]string{dateStamp, rs.RegionName, serviceName, awsRequestType}, "/")
   385  	requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
   386  	if err != nil {
   387  		return "", err
   388  	}
   389  	requestHash, err := getSha256([]byte(requestString))
   390  	if err != nil {
   391  		return "", err
   392  	}
   393  
   394  	stringToSign := strings.Join([]string{awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash}, "\n")
   395  	signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
   396  	for _, signingInput := range []string{
   397  		dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
   398  	} {
   399  		signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
   400  		if err != nil {
   401  			return "", err
   402  		}
   403  	}
   404  
   405  	return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
   406  }
   407  
   408  func getSha256(input []byte) (string, error) {
   409  	hash := sha256.New()
   410  	if _, err := hash.Write(input); err != nil {
   411  		return "", err
   412  	}
   413  	return hex.EncodeToString(hash.Sum(nil)), nil
   414  }
   415  
   416  func getHmacSha256(key, input []byte) ([]byte, error) {
   417  	hash := hmac.New(sha256.New, key)
   418  	if _, err := hash.Write(input); err != nil {
   419  		return nil, err
   420  	}
   421  	return hash.Sum(nil), nil
   422  }
   423  
   424  func cloneRequest(r *http.Request) *http.Request {
   425  	r2 := new(http.Request)
   426  	*r2 = *r
   427  	if r.Header != nil {
   428  		r2.Header = make(http.Header, len(r.Header))
   429  
   430  		// Find total number of values.
   431  		headerCount := 0
   432  		for _, headerValues := range r.Header {
   433  			headerCount += len(headerValues)
   434  		}
   435  		copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
   436  
   437  		for headerKey, headerValues := range r.Header {
   438  			headerCount = copy(copiedHeaders, headerValues)
   439  			r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
   440  			copiedHeaders = copiedHeaders[headerCount:]
   441  		}
   442  	}
   443  	return r2
   444  }
   445  
   446  func canonicalPath(req *http.Request) string {
   447  	result := req.URL.EscapedPath()
   448  	if result == "" {
   449  		return "/"
   450  	}
   451  	return path.Clean(result)
   452  }
   453  
   454  func canonicalQuery(req *http.Request) string {
   455  	queryValues := req.URL.Query()
   456  	for queryKey := range queryValues {
   457  		sort.Strings(queryValues[queryKey])
   458  	}
   459  	return queryValues.Encode()
   460  }
   461  
   462  func canonicalHeaders(req *http.Request) (string, string) {
   463  	// Header keys need to be sorted alphabetically.
   464  	var headers []string
   465  	lowerCaseHeaders := make(http.Header)
   466  	for k, v := range req.Header {
   467  		k := strings.ToLower(k)
   468  		if _, ok := lowerCaseHeaders[k]; ok {
   469  			// include additional values
   470  			lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
   471  		} else {
   472  			headers = append(headers, k)
   473  			lowerCaseHeaders[k] = v
   474  		}
   475  	}
   476  	sort.Strings(headers)
   477  
   478  	var fullHeaders bytes.Buffer
   479  	for _, header := range headers {
   480  		headerValue := strings.Join(lowerCaseHeaders[header], ",")
   481  		fullHeaders.WriteString(header)
   482  		fullHeaders.WriteRune(':')
   483  		fullHeaders.WriteString(headerValue)
   484  		fullHeaders.WriteRune('\n')
   485  	}
   486  
   487  	return strings.Join(headers, ";"), fullHeaders.String()
   488  }
   489  
   490  func requestDataHash(req *http.Request) (string, error) {
   491  	var requestData []byte
   492  	if req.Body != nil {
   493  		requestBody, err := req.GetBody()
   494  		if err != nil {
   495  			return "", err
   496  		}
   497  		defer requestBody.Close()
   498  
   499  		requestData, err = internal.ReadAll(requestBody)
   500  		if err != nil {
   501  			return "", err
   502  		}
   503  	}
   504  
   505  	return getSha256(requestData)
   506  }
   507  
   508  func requestHost(req *http.Request) string {
   509  	if req.Host != "" {
   510  		return req.Host
   511  	}
   512  	return req.URL.Host
   513  }
   514  
   515  func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
   516  	dataHash, err := requestDataHash(req)
   517  	if err != nil {
   518  		return "", err
   519  	}
   520  	return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
   521  }
   522  
// awsRequestHeader is one header entry of the serialized subject token,
// encoded as {"key": ..., "value": ...}.
type awsRequestHeader struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}
   527  
// awsRequest is the JSON shape of the signed GetCallerIdentity request that
// subjectToken serializes and URL-escapes. encoding/json emits struct fields
// in declaration order, so the keys appear as "url", "method", "headers".
// Do not reorder the fields.
type awsRequest struct {
	URL     string             `json:"url"`
	Method  string             `json:"method"`
	Headers []awsRequestHeader `json:"headers"`
}
   533  
   534  // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
   535  // required.
   536  func canRetrieveRegionFromEnvironment() bool {
   537  	return getenv(awsRegionEnvVar) != "" || getenv(awsDefaultRegionEnvVar) != ""
   538  }
   539  
   540  // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
   541  func canRetrieveSecurityCredentialFromEnvironment() bool {
   542  	return getenv(awsAccessKeyIDEnvVar) != "" && getenv(awsSecretAccessKeyEnvVar) != ""
   543  }
   544  
   545  func (sp *awsSubjectProvider) shouldUseMetadataServer() bool {
   546  	return sp.securityCredentialsProvider == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
   547  }
   548  

View as plain text