...

Source file src/cloud.google.com/go/auth/credentials/internal/externalaccount/aws_provider_test.go

Documentation: cloud.google.com/go/auth/credentials/internal/externalaccount

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package externalaccount
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"errors"
    21  	"fmt"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	neturl "net/url"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	"cloud.google.com/go/auth/internal/credsfile"
    30  	"github.com/google/go-cmp/cmp"
    31  )
    32  
    33  type validateHeaders func(r *http.Request)
    34  
    35  const (
    36  	accessKeyID     = "accessKeyID"
    37  	secretAccessKey = "secret"
    38  	sessionToken    = "sessionTok"
    39  )
    40  
    41  var (
    42  	defaultTime            = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC)
    43  	secondDefaultTime      = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC)
    44  	requestSignerWithToken = &awsRequestSigner{
    45  		RegionName: "us-east-2",
    46  		AwsSecurityCredentials: &AwsSecurityCredentials{
    47  			AccessKeyID:     accessKeyID,
    48  			SecretAccessKey: secretAccessKey,
    49  			SessionToken:    sessionToken,
    50  		},
    51  	}
    52  )
    53  
    54  func TestAWSv4Signature_GetRequest(t *testing.T) {
    55  	input, _ := http.NewRequest("GET", "https://host.foo.com", nil)
    56  	setDefaultTime(input)
    57  
    58  	output, _ := http.NewRequest("GET", "https://host.foo.com", nil)
    59  	output.Header = http.Header{
    60  		"Host":          []string{"host.foo.com"},
    61  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
    62  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470"},
    63  	}
    64  
    65  	oldNow := Now
    66  	defer func() { Now = oldNow }()
    67  	Now = setTime(defaultTime)
    68  
    69  	testRequestSigner(t, defaultRequestSigner, input, output)
    70  }
    71  
    72  func TestAWSv4Signature_GetRequestWithRelativePath(t *testing.T) {
    73  	input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil)
    74  	setDefaultTime(input)
    75  
    76  	output, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil)
    77  	output.Header = http.Header{
    78  		"Host":          []string{"host.foo.com"},
    79  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
    80  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470"},
    81  	}
    82  
    83  	oldNow := Now
    84  	defer func() { Now = oldNow }()
    85  	Now = setTime(defaultTime)
    86  
    87  	testRequestSigner(t, defaultRequestSigner, input, output)
    88  }
    89  
    90  func TestAWSv4Signature_GetRequestWithDotPath(t *testing.T) {
    91  	input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil)
    92  	setDefaultTime(input)
    93  
    94  	output, _ := http.NewRequest("GET", "https://host.foo.com/./", nil)
    95  	output.Header = http.Header{
    96  		"Host":          []string{"host.foo.com"},
    97  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
    98  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470"},
    99  	}
   100  
   101  	oldNow := Now
   102  	defer func() { Now = oldNow }()
   103  	Now = setTime(defaultTime)
   104  
   105  	testRequestSigner(t, defaultRequestSigner, input, output)
   106  }
   107  
   108  func TestAWSv4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
   109  	input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil)
   110  	setDefaultTime(input)
   111  
   112  	output, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil)
   113  	output.Header = http.Header{
   114  		"Host":          []string{"host.foo.com"},
   115  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   116  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=910e4d6c9abafaf87898e1eb4c929135782ea25bb0279703146455745391e63a"},
   117  	}
   118  
   119  	oldNow := Now
   120  	defer func() { Now = oldNow }()
   121  	Now = setTime(defaultTime)
   122  
   123  	testRequestSigner(t, defaultRequestSigner, input, output)
   124  }
   125  
   126  func TestAWSv4Signature_GetRequestWithUtf8Path(t *testing.T) {
   127  	input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil)
   128  	setDefaultTime(input)
   129  
   130  	output, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil)
   131  	output.Header = http.Header{
   132  		"Host":          []string{"host.foo.com"},
   133  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   134  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=8d6634c189aa8c75c2e51e106b6b5121bed103fdb351f7d7d4381c738823af74"},
   135  	}
   136  
   137  	oldNow := Now
   138  	defer func() { Now = oldNow }()
   139  	Now = setTime(defaultTime)
   140  
   141  	testRequestSigner(t, defaultRequestSigner, input, output)
   142  }
   143  
   144  func TestAWSv4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
   145  	input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil)
   146  	setDefaultTime(input)
   147  
   148  	output, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil)
   149  	output.Header = http.Header{
   150  		"Host":          []string{"host.foo.com"},
   151  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   152  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=be7148d34ebccdc6423b19085378aa0bee970bdc61d144bd1a8c48c33079ab09"},
   153  	}
   154  
   155  	oldNow := Now
   156  	defer func() { Now = oldNow }()
   157  	Now = setTime(defaultTime)
   158  
   159  	testRequestSigner(t, defaultRequestSigner, input, output)
   160  }
   161  
   162  func TestAWSv4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
   163  	input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil)
   164  	setDefaultTime(input)
   165  
   166  	output, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil)
   167  	output.Header = http.Header{
   168  		"Host":          []string{"host.foo.com"},
   169  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   170  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=feb926e49e382bec75c9d7dcb2a1b6dc8aa50ca43c25d2bc51143768c0875acc"},
   171  	}
   172  
   173  	oldNow := Now
   174  	defer func() { Now = oldNow }()
   175  	Now = setTime(defaultTime)
   176  
   177  	testRequestSigner(t, defaultRequestSigner, input, output)
   178  }
   179  
   180  func TestAWSv4Signature_GetRequestWithUtf8Query(t *testing.T) {
   181  	input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil)
   182  	setDefaultTime(input)
   183  
   184  	output, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil)
   185  	output.Header = http.Header{
   186  		"Host":          []string{"host.foo.com"},
   187  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   188  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=6fb359e9a05394cc7074e0feb42573a2601abc0c869a953e8c5c12e4e01f1a8c"},
   189  	}
   190  
   191  	oldNow := Now
   192  	defer func() { Now = oldNow }()
   193  	Now = setTime(defaultTime)
   194  
   195  	testRequestSigner(t, defaultRequestSigner, input, output)
   196  }
   197  
   198  func TestAWSv4Signature_PostRequest(t *testing.T) {
   199  	input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   200  	setDefaultTime(input)
   201  	input.Header.Set("ZOO", "zoobar")
   202  
   203  	output, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   204  	output.Header = http.Header{
   205  		"Host":          []string{"host.foo.com"},
   206  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   207  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host;zoo, Signature=b7a95a52518abbca0964a999a880429ab734f35ebbf1235bd79a5de87756dc4a"},
   208  		"Zoo":           []string{"zoobar"},
   209  	}
   210  
   211  	oldNow := Now
   212  	defer func() { Now = oldNow }()
   213  	Now = setTime(defaultTime)
   214  
   215  	testRequestSigner(t, defaultRequestSigner, input, output)
   216  }
   217  
   218  func TestAWSv4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
   219  	input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   220  	setDefaultTime(input)
   221  	input.Header.Set("zoo", "ZOOBAR")
   222  
   223  	output, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   224  	output.Header = http.Header{
   225  		"Host":          []string{"host.foo.com"},
   226  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   227  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host;zoo, Signature=273313af9d0c265c531e11db70bbd653f3ba074c1009239e8559d3987039cad7"},
   228  		"Zoo":           []string{"ZOOBAR"},
   229  	}
   230  
   231  	oldNow := Now
   232  	defer func() { Now = oldNow }()
   233  	Now = setTime(defaultTime)
   234  
   235  	testRequestSigner(t, defaultRequestSigner, input, output)
   236  }
   237  
   238  func TestAWSv4Signature_PostRequestPhfft(t *testing.T) {
   239  	input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   240  	setDefaultTime(input)
   241  	input.Header.Set("p", "phfft")
   242  
   243  	output, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   244  	output.Header = http.Header{
   245  		"Host":          []string{"host.foo.com"},
   246  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   247  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host;p, Signature=debf546796015d6f6ded8626f5ce98597c33b47b9164cf6b17b4642036fcb592"},
   248  		"P":             []string{"phfft"},
   249  	}
   250  
   251  	oldNow := Now
   252  	defer func() { Now = oldNow }()
   253  	Now = setTime(defaultTime)
   254  
   255  	testRequestSigner(t, defaultRequestSigner, input, output)
   256  }
   257  
   258  func TestAWSv4Signature_PostRequestWithBody(t *testing.T) {
   259  	input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar"))
   260  	setDefaultTime(input)
   261  	input.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   262  
   263  	output, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
   264  	output.Header = http.Header{
   265  		"Host":          []string{"host.foo.com"},
   266  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   267  		"Content-Type":  []string{"application/x-www-form-urlencoded"},
   268  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=content-type;date;host, Signature=5a15b22cf462f047318703b92e6f4f38884e4a7ab7b1d6426ca46a8bd1c26cbc"},
   269  	}
   270  
   271  	oldNow := Now
   272  	defer func() { Now = oldNow }()
   273  	Now = setTime(defaultTime)
   274  
   275  	testRequestSigner(t, defaultRequestSigner, input, output)
   276  }
   277  
   278  func TestAWSv4Signature_PostRequestWithQueryString(t *testing.T) {
   279  	input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil)
   280  	setDefaultTime(input)
   281  
   282  	output, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil)
   283  	output.Header = http.Header{
   284  		"Host":          []string{"host.foo.com"},
   285  		"Date":          []string{"Mon, 09 Sep 2011 23:36:00 GMT"},
   286  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b6e3b79003ce0743a491606ba1035a804593b0efb1e20a11cba83f8c25a57a92"},
   287  	}
   288  
   289  	oldNow := Now
   290  	defer func() { Now = oldNow }()
   291  	Now = setTime(defaultTime)
   292  
   293  	testRequestSigner(t, defaultRequestSigner, input, output)
   294  }
   295  
   296  func TestAWSv4Signature_GetRequestWithSecurityToken(t *testing.T) {
   297  	input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
   298  
   299  	output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
   300  	output.Header = http.Header{
   301  		"Host":                 []string{"ec2.us-east-2.amazonaws.com"},
   302  		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/ec2/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=361dc730fd43b4330fa20603a7fbc305ef10b5be125d219ebef40f18569ef5b2"},
   303  		"X-Amz-Date":           []string{"20200811T065522Z"},
   304  		"X-Amz-Security-Token": []string{sessionToken},
   305  	}
   306  
   307  	oldNow := Now
   308  	defer func() { Now = oldNow }()
   309  	Now = setTime(secondDefaultTime)
   310  
   311  	testRequestSigner(t, requestSignerWithToken, input, output)
   312  }
   313  
   314  func TestAWSv4Signature_PostRequestWithSecurityToken(t *testing.T) {
   315  	input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
   316  
   317  	output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
   318  	output.Header = http.Header{
   319  		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=48541de09ff691ab53f9e017f5123ce338fecbadc1b278480bab3af221ca4f38"},
   320  		"Host":                 []string{"sts.us-east-2.amazonaws.com"},
   321  		"X-Amz-Date":           []string{"20200811T065522Z"},
   322  		"X-Amz-Security-Token": []string{sessionToken},
   323  	}
   324  
   325  	oldNow := Now
   326  	defer func() { Now = oldNow }()
   327  	Now = setTime(secondDefaultTime)
   328  
   329  	testRequestSigner(t, requestSignerWithToken, input, output)
   330  }
   331  
   332  func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) {
   333  	requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}"
   334  	input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
   335  	input.Header.Set("Content-Type", "application/x-amz-json-1.0")
   336  	input.Header.Set("x-amz-target", "DynamoDB_20120810.CreateTable")
   337  
   338  	output, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
   339  	output.Header = http.Header{
   340  		"Authorization":        []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/dynamodb/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-security-token;x-amz-target, Signature=ae7b75b3c0a9ef64626a4e9e6c3d503745dd0a109cb35d56d87c617eae804c00"},
   341  		"Host":                 []string{"dynamodb.us-east-2.amazonaws.com"},
   342  		"X-Amz-Date":           []string{"20200811T065522Z"},
   343  		"Content-Type":         []string{"application/x-amz-json-1.0"},
   344  		"X-Amz-Target":         []string{"DynamoDB_20120810.CreateTable"},
   345  		"X-Amz-Security-Token": []string{sessionToken},
   346  	}
   347  
   348  	oldNow := Now
   349  	defer func() { Now = oldNow }()
   350  	Now = setTime(secondDefaultTime)
   351  
   352  	testRequestSigner(t, requestSignerWithToken, input, output)
   353  }
   354  
   355  func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
   356  	var requestSigner = &awsRequestSigner{
   357  		RegionName: "us-east-2",
   358  		AwsSecurityCredentials: &AwsSecurityCredentials{
   359  			AccessKeyID:     accessKeyID,
   360  			SecretAccessKey: secretAccessKey,
   361  		},
   362  	}
   363  
   364  	input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
   365  
   366  	output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
   367  	output.Header = http.Header{
   368  		"Authorization": []string{"AWS4-HMAC-SHA256 Credential=" + accessKeyID + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=bff58112d4696faecff9c104c8b7b963141e8b3cc4ba46a0664938fe6d112e41"},
   369  		"Host":          []string{"sts.us-east-2.amazonaws.com"},
   370  		"X-Amz-Date":    []string{"20200811T065522Z"},
   371  	}
   372  
   373  	oldNow := Now
   374  	defer func() { Now = oldNow }()
   375  	Now = setTime(secondDefaultTime)
   376  
   377  	testRequestSigner(t, requestSigner, input, output)
   378  }
   379  
   380  type testAwsServer struct {
   381  	url                         string
   382  	securityCredentialURL       string
   383  	regionURL                   string
   384  	regionalCredVerificationURL string
   385  	imdsv2SessionTokenURL       string
   386  
   387  	Credentials map[string]string
   388  
   389  	WriteRolename            func(http.ResponseWriter, *http.Request)
   390  	WriteSecurityCredentials func(http.ResponseWriter, *http.Request)
   391  	WriteRegion              func(http.ResponseWriter, *http.Request)
   392  	WriteIMDSv2SessionToken  func(http.ResponseWriter, *http.Request)
   393  }
   394  
   395  func createAwsTestServer(url, regionURL, regionalCredVerificationURL, imdsv2SessionTokenURL string, rolename, region string, credentials map[string]string, imdsv2SessionToken string, validateHeaders validateHeaders) *testAwsServer {
   396  	server := &testAwsServer{
   397  		url:                         url,
   398  		securityCredentialURL:       fmt.Sprintf("%s/%s", url, rolename),
   399  		regionURL:                   regionURL,
   400  		regionalCredVerificationURL: regionalCredVerificationURL,
   401  		imdsv2SessionTokenURL:       imdsv2SessionTokenURL,
   402  		Credentials:                 credentials,
   403  		WriteRolename: func(w http.ResponseWriter, r *http.Request) {
   404  			validateHeaders(r)
   405  			w.Write([]byte(rolename))
   406  		},
   407  		WriteRegion: func(w http.ResponseWriter, r *http.Request) {
   408  			validateHeaders(r)
   409  			w.Write([]byte(region))
   410  		},
   411  		WriteIMDSv2SessionToken: func(w http.ResponseWriter, r *http.Request) {
   412  			validateHeaders(r)
   413  			w.Write([]byte(imdsv2SessionToken))
   414  		},
   415  	}
   416  
   417  	server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
   418  		validateHeaders(r)
   419  		jsonCredentials, _ := json.Marshal(server.Credentials)
   420  		w.Write(jsonCredentials)
   421  	}
   422  
   423  	return server
   424  }
   425  
   426  func createDefaultAwsTestServer() *testAwsServer {
   427  	return createAwsTestServer(
   428  		"/latest/meta-data/iam/security-credentials",
   429  		"/latest/meta-data/placement/availability-zone",
   430  		"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   431  		"",
   432  		"gcp-aws-role",
   433  		"us-east-2b",
   434  		map[string]string{
   435  			"SecretAccessKey": secretAccessKey,
   436  			"AccessKeyId":     accessKeyID,
   437  			"Token":           sessionToken,
   438  		},
   439  		"",
   440  		noHeaderValidation,
   441  	)
   442  }
   443  
   444  func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer {
   445  	validateSessionTokenHeaders := func(r *http.Request) {
   446  		if r.URL.Path == "/latest/api/token" {
   447  			headerValue := r.Header.Get(awsIMDSv2SessionTTLHeader)
   448  			if headerValue != awsIMDSv2SessionTTL {
   449  				t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTTLHeader, headerValue, awsIMDSv2SessionTTL)
   450  			}
   451  		} else {
   452  			headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
   453  			if headerValue != "sessiontoken" {
   454  				t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
   455  			}
   456  		}
   457  	}
   458  
   459  	return createAwsTestServer(
   460  		"/latest/meta-data/iam/security-credentials",
   461  		"/latest/meta-data/placement/availability-zone",
   462  		"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   463  		"/latest/api/token",
   464  		"gcp-aws-role",
   465  		"us-east-2b",
   466  		map[string]string{
   467  			"SecretAccessKey": secretAccessKey,
   468  			"AccessKeyId":     accessKeyID,
   469  			"Token":           sessionToken,
   470  		},
   471  		"sessiontoken",
   472  		validateSessionTokenHeaders,
   473  	)
   474  }
   475  
   476  func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   477  	switch p := r.URL.Path; p {
   478  	case server.url:
   479  		server.WriteRolename(w, r)
   480  	case server.securityCredentialURL:
   481  		server.WriteSecurityCredentials(w, r)
   482  	case server.regionURL:
   483  		server.WriteRegion(w, r)
   484  	case server.imdsv2SessionTokenURL:
   485  		server.WriteIMDSv2SessionToken(w, r)
   486  	}
   487  }
   488  
   489  func notFound(w http.ResponseWriter, r *http.Request) {
   490  	w.WriteHeader(404)
   491  	w.Write([]byte("Not Found"))
   492  }
   493  
   494  func noHeaderValidation(r *http.Request) {}
   495  
   496  func (server *testAwsServer) getCredentialSource(url string) *credsfile.CredentialSource {
   497  	return &credsfile.CredentialSource{
   498  		EnvironmentID:               "aws1",
   499  		URL:                         url + server.url,
   500  		RegionURL:                   url + server.regionURL,
   501  		RegionalCredVerificationURL: server.regionalCredVerificationURL,
   502  		IMDSv2SessionTokenURL:       url + server.imdsv2SessionTokenURL,
   503  	}
   504  }
   505  
   506  func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, sessionToken string) string {
   507  	req, _ := http.NewRequest("POST", url, nil)
   508  	req.Header.Set("x-goog-cloud-target-resource", cloneTestOpts().Audience)
   509  	signer := &awsRequestSigner{
   510  		RegionName: region,
   511  		AwsSecurityCredentials: &AwsSecurityCredentials{
   512  			AccessKeyID:     accessKeyID,
   513  			SecretAccessKey: secretAccessKey,
   514  			SessionToken:    sessionToken,
   515  		},
   516  	}
   517  	signer.signRequest(req)
   518  
   519  	result := awsRequest{
   520  		URL:    url,
   521  		Method: "POST",
   522  		Headers: []awsRequestHeader{
   523  			{
   524  				Key:   "Authorization",
   525  				Value: req.Header.Get("Authorization"),
   526  			}, {
   527  				Key:   "Host",
   528  				Value: req.Header.Get("Host"),
   529  			}, {
   530  				Key:   "X-Amz-Date",
   531  				Value: req.Header.Get("X-Amz-Date"),
   532  			},
   533  		},
   534  	}
   535  
   536  	if sessionToken != "" {
   537  		result.Headers = append(result.Headers, awsRequestHeader{
   538  			Key:   "X-Amz-Security-Token",
   539  			Value: sessionToken,
   540  		})
   541  	}
   542  
   543  	result.Headers = append(result.Headers, awsRequestHeader{
   544  		Key:   "X-Goog-Cloud-Target-Resource",
   545  		Value: cloneTestOpts().Audience,
   546  	})
   547  
   548  	str, _ := json.Marshal(result)
   549  	return neturl.QueryEscape(string(str))
   550  }
   551  
   552  func TestAWSCredential_BasicRequest(t *testing.T) {
   553  	server := createDefaultAwsTestServer()
   554  	ts := httptest.NewServer(server)
   555  
   556  	opts := cloneTestOpts()
   557  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   558  
   559  	oldGetenv := getenv
   560  	oldNow := Now
   561  	defer func() {
   562  		getenv = oldGetenv
   563  		Now = oldNow
   564  	}()
   565  	getenv = setEnvironment(map[string]string{})
   566  	Now = setTime(defaultTime)
   567  
   568  	base, err := newSubjectTokenProvider(opts)
   569  	if err != nil {
   570  		t.Fatalf("parse() failed %v", err)
   571  	}
   572  
   573  	got, err := base.subjectToken(context.Background())
   574  	if err != nil {
   575  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
   576  	}
   577  
   578  	if got, want := base.providerType(), awsProviderType; got != want {
   579  		t.Fatalf("got %q, want %q", got, want)
   580  	}
   581  
   582  	want := getExpectedSubjectToken(
   583  		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   584  		"us-east-2",
   585  		accessKeyID,
   586  		secretAccessKey,
   587  		sessionToken,
   588  	)
   589  
   590  	if got != want {
   591  		t.Errorf("got %q, want %q", got, want)
   592  	}
   593  }
   594  
   595  func TestAWSCredential_IMDSv2(t *testing.T) {
   596  	server := createDefaultAwsTestServerWithImdsv2(t)
   597  	ts := httptest.NewServer(server)
   598  
   599  	opts := cloneTestOpts()
   600  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   601  
   602  	oldGetenv := getenv
   603  	oldNow := Now
   604  	defer func() {
   605  		getenv = oldGetenv
   606  		Now = oldNow
   607  	}()
   608  	getenv = setEnvironment(map[string]string{})
   609  	Now = setTime(defaultTime)
   610  
   611  	base, err := newSubjectTokenProvider(opts)
   612  	if err != nil {
   613  		t.Fatalf("parse() failed %v", err)
   614  	}
   615  
   616  	got, err := base.subjectToken(context.Background())
   617  	if err != nil {
   618  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
   619  	}
   620  
   621  	want := getExpectedSubjectToken(
   622  		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   623  		"us-east-2",
   624  		accessKeyID,
   625  		secretAccessKey,
   626  		sessionToken,
   627  	)
   628  
   629  	if got != want {
   630  		t.Errorf("got %q, want %q", got, want)
   631  	}
   632  }
   633  
   634  func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
   635  	server := createDefaultAwsTestServer()
   636  	ts := httptest.NewServer(server)
   637  	delete(server.Credentials, "Token")
   638  
   639  	opts := cloneTestOpts()
   640  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   641  
   642  	oldGetenv := getenv
   643  	oldNow := Now
   644  	defer func() {
   645  		getenv = oldGetenv
   646  		Now = oldNow
   647  	}()
   648  	getenv = setEnvironment(map[string]string{})
   649  	Now = setTime(defaultTime)
   650  
   651  	base, err := newSubjectTokenProvider(opts)
   652  	if err != nil {
   653  		t.Fatalf("parse() failed %v", err)
   654  	}
   655  
   656  	got, err := base.subjectToken(context.Background())
   657  	if err != nil {
   658  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
   659  	}
   660  
   661  	want := getExpectedSubjectToken(
   662  		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   663  		"us-east-2",
   664  		accessKeyID,
   665  		secretAccessKey,
   666  		"",
   667  	)
   668  
   669  	if got != want {
   670  		t.Errorf("got %q, want %q", got, want)
   671  	}
   672  }
   673  
   674  func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
   675  	server := createDefaultAwsTestServer()
   676  	ts := httptest.NewServer(server)
   677  
   678  	opts := cloneTestOpts()
   679  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   680  
   681  	oldGetenv := getenv
   682  	oldNow := Now
   683  	defer func() {
   684  		getenv = oldGetenv
   685  		Now = oldNow
   686  	}()
   687  	getenv = setEnvironment(map[string]string{
   688  		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
   689  		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
   690  		"AWS_REGION":            "us-west-1",
   691  	})
   692  	Now = setTime(defaultTime)
   693  
   694  	base, err := newSubjectTokenProvider(opts)
   695  	if err != nil {
   696  		t.Fatalf("parse() failed %v", err)
   697  	}
   698  
   699  	got, err := base.subjectToken(context.Background())
   700  	if err != nil {
   701  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
   702  	}
   703  
   704  	want := getExpectedSubjectToken(
   705  		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   706  		"us-west-1",
   707  		"AKIDEXAMPLE",
   708  		"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
   709  		"",
   710  	)
   711  
   712  	if got != want {
   713  		t.Errorf("got %q, want %q", got, want)
   714  	}
   715  }
   716  
   717  func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
   718  	server := createDefaultAwsTestServer()
   719  	ts := httptest.NewServer(server)
   720  
   721  	opts := cloneTestOpts()
   722  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   723  
   724  	oldGetenv := getenv
   725  	oldNow := Now
   726  	defer func() {
   727  		getenv = oldGetenv
   728  		Now = oldNow
   729  	}()
   730  	getenv = setEnvironment(map[string]string{
   731  		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
   732  		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
   733  		"AWS_REGION":            "us-west-1",
   734  	})
   735  	Now = setTime(defaultTime)
   736  
   737  	base, err := newSubjectTokenProvider(opts)
   738  	if err != nil {
   739  		t.Fatalf("parse() failed %v", err)
   740  	}
   741  
   742  	got, err := base.subjectToken(context.Background())
   743  	if err != nil {
   744  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
   745  	}
   746  	want := getExpectedSubjectToken(
   747  		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   748  		"us-west-1",
   749  		"AKIDEXAMPLE",
   750  		"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
   751  		"",
   752  	)
   753  
   754  	if got != want {
   755  		t.Errorf("got %q, want %q", got, want)
   756  	}
   757  }
   758  
   759  func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
   760  	server := createDefaultAwsTestServer()
   761  	ts := httptest.NewServer(server)
   762  	opts := cloneTestOpts()
   763  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   764  
   765  	oldGetenv := getenv
   766  	oldNow := Now
   767  	defer func() {
   768  		getenv = oldGetenv
   769  		Now = oldNow
   770  	}()
   771  	getenv = setEnvironment(map[string]string{
   772  		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
   773  		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
   774  		"AWS_REGION":            "us-west-1",
   775  		"AWS_DEFAULT_REGION":    "us-east-1",
   776  	})
   777  	Now = setTime(defaultTime)
   778  
   779  	base, err := newSubjectTokenProvider(opts)
   780  	if err != nil {
   781  		t.Fatalf("parse() failed %v", err)
   782  	}
   783  
   784  	got, err := base.subjectToken(context.Background())
   785  	if err != nil {
   786  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
   787  	}
   788  	want := getExpectedSubjectToken(
   789  		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
   790  		"us-west-1",
   791  		"AKIDEXAMPLE",
   792  		"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
   793  		"",
   794  	)
   795  
   796  	if got != want {
   797  		t.Errorf("got %q, want %q", got, want)
   798  	}
   799  }
   800  
   801  func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
   802  	server := createDefaultAwsTestServer()
   803  	ts := httptest.NewServer(server)
   804  
   805  	opts := cloneTestOpts()
   806  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   807  	opts.CredentialSource.EnvironmentID = "aws3"
   808  
   809  	oldGetenv := getenv
   810  	defer func() {
   811  		getenv = oldGetenv
   812  	}()
   813  	getenv = setEnvironment(map[string]string{})
   814  
   815  	_, err := newSubjectTokenProvider(opts)
   816  	if got, want := err.Error(), "credentials: aws version '3' is not supported in the current build"; got != want {
   817  		t.Errorf("subjectToken = %q, want %q", got, want)
   818  	}
   819  }
   820  
   821  func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
   822  	server := createDefaultAwsTestServer()
   823  	ts := httptest.NewServer(server)
   824  
   825  	opts := cloneTestOpts()
   826  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   827  	opts.CredentialSource.RegionURL = ""
   828  
   829  	oldGetenv := getenv
   830  	defer func() {
   831  		getenv = oldGetenv
   832  	}()
   833  	getenv = setEnvironment(map[string]string{})
   834  
   835  	base, err := newSubjectTokenProvider(opts)
   836  	if err != nil {
   837  		t.Fatalf("parse() failed %v", err)
   838  	}
   839  
   840  	_, err = base.subjectToken(context.Background())
   841  	if err == nil {
   842  		t.Fatalf("retrieveSubjectToken() should have failed")
   843  	}
   844  
   845  	if got, want := err.Error(), "credentials: unable to determine AWS region"; got != want {
   846  		t.Errorf("subjectToken = %q, want %q", got, want)
   847  	}
   848  }
   849  
   850  func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
   851  	server := createDefaultAwsTestServer()
   852  	ts := httptest.NewServer(server)
   853  	server.WriteRegion = notFound
   854  
   855  	opts := cloneTestOpts()
   856  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   857  
   858  	oldGetenv := getenv
   859  	defer func() {
   860  		getenv = oldGetenv
   861  	}()
   862  	getenv = setEnvironment(map[string]string{})
   863  
   864  	base, err := newSubjectTokenProvider(opts)
   865  	if err != nil {
   866  		t.Fatalf("parse() failed %v", err)
   867  	}
   868  
   869  	_, err = base.subjectToken(context.Background())
   870  	if err == nil {
   871  		t.Fatalf("retrieveSubjectToken() should have failed")
   872  	}
   873  
   874  	if got, want := err.Error(), "credentials: unable to retrieve AWS region - Not Found"; got != want {
   875  		t.Errorf("subjectToken = %q, want %q", got, want)
   876  	}
   877  }
   878  
   879  func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
   880  	server := createDefaultAwsTestServer()
   881  	ts := httptest.NewServer(server)
   882  	server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
   883  		w.Write([]byte("{}"))
   884  	}
   885  
   886  	opts := cloneTestOpts()
   887  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   888  
   889  	oldGetenv := getenv
   890  	defer func() {
   891  		getenv = oldGetenv
   892  	}()
   893  	getenv = setEnvironment(map[string]string{})
   894  
   895  	base, err := newSubjectTokenProvider(opts)
   896  	if err != nil {
   897  		t.Fatalf("parse() failed %v", err)
   898  	}
   899  
   900  	_, err = base.subjectToken(context.Background())
   901  	if err == nil {
   902  		t.Fatalf("retrieveSubjectToken() should have failed")
   903  	}
   904  
   905  	if got, want := err.Error(), "credentials: missing AccessKeyId credential"; got != want {
   906  		t.Errorf("subjectToken = %q, want %q", got, want)
   907  	}
   908  }
   909  
   910  func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
   911  	server := createDefaultAwsTestServer()
   912  	ts := httptest.NewServer(server)
   913  	server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
   914  		w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
   915  	}
   916  
   917  	opts := cloneTestOpts()
   918  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   919  
   920  	oldGetenv := getenv
   921  	defer func() {
   922  		getenv = oldGetenv
   923  	}()
   924  	getenv = setEnvironment(map[string]string{})
   925  
   926  	base, err := newSubjectTokenProvider(opts)
   927  	if err != nil {
   928  		t.Fatalf("parse() failed %v", err)
   929  	}
   930  
   931  	_, err = base.subjectToken(context.Background())
   932  	if err == nil {
   933  		t.Fatalf("retrieveSubjectToken() should have failed")
   934  	}
   935  
   936  	if got, want := err.Error(), "credentials: missing SecretAccessKey credential"; got != want {
   937  		t.Errorf("subjectToken = %q, want %q", got, want)
   938  	}
   939  }
   940  
   941  func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
   942  	server := createDefaultAwsTestServer()
   943  	ts := httptest.NewServer(server)
   944  
   945  	opts := cloneTestOpts()
   946  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   947  	opts.CredentialSource.URL = ""
   948  
   949  	oldGetenv := getenv
   950  	defer func() {
   951  		getenv = oldGetenv
   952  	}()
   953  	getenv = setEnvironment(map[string]string{})
   954  
   955  	base, err := newSubjectTokenProvider(opts)
   956  	if err != nil {
   957  		t.Fatalf("parse() failed %v", err)
   958  	}
   959  
   960  	_, err = base.subjectToken(context.Background())
   961  	if err == nil {
   962  		t.Fatalf("retrieveSubjectToken() should have failed")
   963  	}
   964  
   965  	if got, want := err.Error(), "credentials: unable to determine the AWS metadata server security credentials endpoint"; got != want {
   966  		t.Errorf("subjectToken = %q, want %q", got, want)
   967  	}
   968  }
   969  
   970  func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
   971  	server := createDefaultAwsTestServer()
   972  	ts := httptest.NewServer(server)
   973  	server.WriteRolename = notFound
   974  
   975  	opts := cloneTestOpts()
   976  	opts.CredentialSource = server.getCredentialSource(ts.URL)
   977  
   978  	oldGetenv := getenv
   979  	defer func() {
   980  		getenv = oldGetenv
   981  	}()
   982  	getenv = setEnvironment(map[string]string{})
   983  
   984  	base, err := newSubjectTokenProvider(opts)
   985  	if err != nil {
   986  		t.Fatalf("parse() failed %v", err)
   987  	}
   988  
   989  	_, err = base.subjectToken(context.Background())
   990  	if err == nil {
   991  		t.Fatalf("retrieveSubjectToken() should have failed")
   992  	}
   993  
   994  	if got, want := err.Error(), "credentials: unable to retrieve AWS role name - Not Found"; got != want {
   995  		t.Errorf("subjectToken = %q, want %q", got, want)
   996  	}
   997  }
   998  
   999  func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
  1000  	server := createDefaultAwsTestServer()
  1001  	ts := httptest.NewServer(server)
  1002  	server.WriteSecurityCredentials = notFound
  1003  
  1004  	opts := cloneTestOpts()
  1005  	opts.CredentialSource = server.getCredentialSource(ts.URL)
  1006  
  1007  	oldGetenv := getenv
  1008  	defer func() {
  1009  		getenv = oldGetenv
  1010  	}()
  1011  	getenv = setEnvironment(map[string]string{})
  1012  
  1013  	base, err := newSubjectTokenProvider(opts)
  1014  	if err != nil {
  1015  		t.Fatalf("parse() failed %v", err)
  1016  	}
  1017  
  1018  	_, err = base.subjectToken(context.Background())
  1019  	if err == nil {
  1020  		t.Fatalf("retrieveSubjectToken() should have failed")
  1021  	}
  1022  
  1023  	if got, want := err.Error(), "credentials: unable to retrieve AWS security credentials - Not Found"; got != want {
  1024  		t.Errorf("subjectToken = %q, want %q", got, want)
  1025  	}
  1026  }
  1027  
  1028  func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
  1029  	server := createDefaultAwsTestServer()
  1030  	ts := httptest.NewServer(server)
  1031  
  1032  	metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1033  		t.Error("Metadata server should not have been called.")
  1034  	}))
  1035  
  1036  	opts := cloneTestOpts()
  1037  	opts.CredentialSource = server.getCredentialSource(ts.URL)
  1038  	opts.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL
  1039  
  1040  	oldGetenv := getenv
  1041  	oldNow := Now
  1042  	defer func() {
  1043  		getenv = oldGetenv
  1044  		Now = oldNow
  1045  	}()
  1046  	getenv = setEnvironment(map[string]string{
  1047  		"AWS_ACCESS_KEY_ID":     "AKIDEXAMPLE",
  1048  		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
  1049  		"AWS_REGION":            "us-west-1",
  1050  	})
  1051  	Now = setTime(defaultTime)
  1052  
  1053  	base, err := newSubjectTokenProvider(opts)
  1054  	if err != nil {
  1055  		t.Fatalf("parse() failed %v", err)
  1056  	}
  1057  
  1058  	got, err := base.subjectToken(context.Background())
  1059  	if err != nil {
  1060  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
  1061  	}
  1062  
  1063  	want := getExpectedSubjectToken(
  1064  		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
  1065  		"us-west-1",
  1066  		"AKIDEXAMPLE",
  1067  		"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
  1068  		"",
  1069  	)
  1070  
  1071  	if got != want {
  1072  		t.Errorf("got %q, want %q", got, want)
  1073  	}
  1074  }
  1075  
  1076  func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
  1077  	server := createDefaultAwsTestServerWithImdsv2(t)
  1078  	ts := httptest.NewServer(server)
  1079  
  1080  	opts := cloneTestOpts()
  1081  	opts.CredentialSource = server.getCredentialSource(ts.URL)
  1082  
  1083  	oldGetenv := getenv
  1084  	oldNow := Now
  1085  	defer func() {
  1086  		getenv = oldGetenv
  1087  		Now = oldNow
  1088  	}()
  1089  	getenv = setEnvironment(map[string]string{
  1090  		"AWS_ACCESS_KEY_ID":     accessKeyID,
  1091  		"AWS_SECRET_ACCESS_KEY": secretAccessKey,
  1092  	})
  1093  	Now = setTime(defaultTime)
  1094  
  1095  	base, err := newSubjectTokenProvider(opts)
  1096  	if err != nil {
  1097  		t.Fatalf("parse() failed %v", err)
  1098  	}
  1099  
  1100  	got, err := base.subjectToken(context.Background())
  1101  	if err != nil {
  1102  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
  1103  	}
  1104  
  1105  	want := getExpectedSubjectToken(
  1106  		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
  1107  		"us-east-2",
  1108  		accessKeyID,
  1109  		secretAccessKey,
  1110  		"",
  1111  	)
  1112  
  1113  	if got != want {
  1114  		t.Errorf("got %q, want %q", got, want)
  1115  	}
  1116  }
  1117  
  1118  func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
  1119  	server := createDefaultAwsTestServerWithImdsv2(t)
  1120  	ts := httptest.NewServer(server)
  1121  
  1122  	opts := cloneTestOpts()
  1123  	opts.CredentialSource = server.getCredentialSource(ts.URL)
  1124  
  1125  	oldGetenv := getenv
  1126  	oldNow := Now
  1127  	defer func() {
  1128  		getenv = oldGetenv
  1129  		Now = oldNow
  1130  	}()
  1131  	getenv = setEnvironment(map[string]string{
  1132  		"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
  1133  		"AWS_REGION":            "us-west-1",
  1134  	})
  1135  	Now = setTime(defaultTime)
  1136  
  1137  	base, err := newSubjectTokenProvider(opts)
  1138  	if err != nil {
  1139  		t.Fatalf("parse() failed %v", err)
  1140  	}
  1141  
  1142  	got, err := base.subjectToken(context.Background())
  1143  	if err != nil {
  1144  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
  1145  	}
  1146  
  1147  	want := getExpectedSubjectToken(
  1148  		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
  1149  		"us-west-1",
  1150  		accessKeyID,
  1151  		secretAccessKey,
  1152  		sessionToken,
  1153  	)
  1154  
  1155  	if got != want {
  1156  		t.Errorf("got %q, want %q", got, want)
  1157  	}
  1158  }
  1159  
  1160  func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
  1161  	server := createDefaultAwsTestServerWithImdsv2(t)
  1162  	ts := httptest.NewServer(server)
  1163  
  1164  	opts := cloneTestOpts()
  1165  	opts.CredentialSource = server.getCredentialSource(ts.URL)
  1166  
  1167  	oldGetenv := getenv
  1168  	oldNow := Now
  1169  	defer func() {
  1170  		getenv = oldGetenv
  1171  		Now = oldNow
  1172  	}()
  1173  	getenv = setEnvironment(map[string]string{
  1174  		"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
  1175  		"AWS_REGION":        "us-west-1",
  1176  	})
  1177  	Now = setTime(defaultTime)
  1178  
  1179  	base, err := newSubjectTokenProvider(opts)
  1180  	if err != nil {
  1181  		t.Fatalf("parse() failed %v", err)
  1182  	}
  1183  
  1184  	got, err := base.subjectToken(context.Background())
  1185  	if err != nil {
  1186  		t.Fatalf("retrieveSubjectToken() failed: %v", err)
  1187  	}
  1188  
  1189  	want := getExpectedSubjectToken(
  1190  		"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
  1191  		"us-west-1",
  1192  		accessKeyID,
  1193  		secretAccessKey,
  1194  		sessionToken,
  1195  	)
  1196  
  1197  	if got != want {
  1198  		t.Errorf("got %q, want %q", got, want)
  1199  	}
  1200  }
  1201  
  1202  func TestAWSCredential_Validations(t *testing.T) {
  1203  	var metadataServerValidityTests = []struct {
  1204  		name       string
  1205  		credSource *credsfile.CredentialSource
  1206  		errText    string
  1207  	}{
  1208  		{
  1209  			name: "No Metadata Server URLs",
  1210  			credSource: &credsfile.CredentialSource{
  1211  				EnvironmentID:         "aws1",
  1212  				RegionURL:             "",
  1213  				URL:                   "",
  1214  				IMDSv2SessionTokenURL: "",
  1215  			},
  1216  		}, {
  1217  			name: "IPv4 Metadata Server URLs",
  1218  			credSource: &credsfile.CredentialSource{
  1219  				EnvironmentID:         "aws1",
  1220  				RegionURL:             "http://169.254.169.254/latest/meta-data/placement/availability-zone",
  1221  				URL:                   "http://169.254.169.254/latest/meta-data/iam/security-credentials",
  1222  				IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
  1223  			},
  1224  		}, {
  1225  			name: "IPv6 Metadata Server URLs",
  1226  			credSource: &credsfile.CredentialSource{
  1227  				EnvironmentID:         "aws1",
  1228  				RegionURL:             "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
  1229  				URL:                   "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
  1230  				IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
  1231  			},
  1232  		},
  1233  	}
  1234  
  1235  	for _, tt := range metadataServerValidityTests {
  1236  		t.Run(tt.name, func(t *testing.T) {
  1237  			opts := cloneTestOpts()
  1238  			opts.CredentialSource = tt.credSource
  1239  
  1240  			oldGetenv := getenv
  1241  			defer func() { getenv = oldGetenv }()
  1242  			getenv = setEnvironment(map[string]string{})
  1243  
  1244  			_, err := newSubjectTokenProvider(opts)
  1245  			if err != nil {
  1246  				if tt.errText == "" {
  1247  					t.Errorf("Didn't expect an error, but got %v", err)
  1248  				} else if tt.errText != err.Error() {
  1249  					t.Errorf("got %v, want %v", err, tt.errText)
  1250  				}
  1251  			} else {
  1252  				if tt.errText != "" {
  1253  					t.Errorf("got nil, want %v", tt.errText)
  1254  				}
  1255  			}
  1256  		})
  1257  	}
  1258  }
  1259  
  1260  func TestAWSCredential_ProgrammaticAuth(t *testing.T) {
  1261  	opts := cloneTestOpts()
  1262  	opts.AwsSecurityCredentialsProvider = &fakeAwsCredsProvider{
  1263  		awsRegion: "us-east-2",
  1264  		creds: &AwsSecurityCredentials{
  1265  			AccessKeyID:     accessKeyID,
  1266  			SecretAccessKey: secretAccessKey,
  1267  			SessionToken:    sessionToken,
  1268  		},
  1269  	}
  1270  
  1271  	oldNow := Now
  1272  	defer func() {
  1273  		Now = oldNow
  1274  	}()
  1275  	Now = setTime(defaultTime)
  1276  
  1277  	base, err := newSubjectTokenProvider(opts)
  1278  	if err != nil {
  1279  		t.Fatalf("newSubjectTokenProvider() = %v", err)
  1280  	}
  1281  
  1282  	got, err := base.subjectToken(context.Background())
  1283  	if err != nil {
  1284  		t.Fatalf("subjectToken() = %v", err)
  1285  	}
  1286  
  1287  	want := getExpectedSubjectToken(
  1288  		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
  1289  		"us-east-2",
  1290  		accessKeyID,
  1291  		secretAccessKey,
  1292  		sessionToken,
  1293  	)
  1294  
  1295  	if got != want {
  1296  		t.Errorf("got %q, want %q", got, want)
  1297  	}
  1298  }
  1299  
  1300  func TestAWSCredential_ProgrammaticAuthNoSessionToken(t *testing.T) {
  1301  	opts := cloneTestOpts()
  1302  	opts.AwsSecurityCredentialsProvider = fakeAwsCredsProvider{
  1303  		awsRegion: "us-east-2",
  1304  		creds: &AwsSecurityCredentials{
  1305  			AccessKeyID:     accessKeyID,
  1306  			SecretAccessKey: secretAccessKey,
  1307  		},
  1308  	}
  1309  
  1310  	oldNow := Now
  1311  	defer func() {
  1312  		Now = oldNow
  1313  	}()
  1314  	Now = setTime(defaultTime)
  1315  
  1316  	base, err := newSubjectTokenProvider(opts)
  1317  	if err != nil {
  1318  		t.Fatalf("newSubjectTokenProvider() = %v", err)
  1319  	}
  1320  
  1321  	got, err := base.subjectToken(context.Background())
  1322  	if err != nil {
  1323  		t.Fatalf("subjectToken() = %v", err)
  1324  	}
  1325  
  1326  	want := getExpectedSubjectToken(
  1327  		"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
  1328  		"us-east-2",
  1329  		accessKeyID,
  1330  		secretAccessKey,
  1331  		"",
  1332  	)
  1333  
  1334  	if got != want {
  1335  		t.Errorf("got %q, want %q", got, want)
  1336  	}
  1337  }
  1338  
  1339  func TestAWSCredential_ProgrammaticAuthError(t *testing.T) {
  1340  	opts := cloneTestOpts()
  1341  	testErr := errors.New("test error")
  1342  	opts.AwsSecurityCredentialsProvider = fakeAwsCredsProvider{
  1343  		awsRegion: "us-east-2",
  1344  		credsErr:  testErr,
  1345  	}
  1346  
  1347  	base, err := newSubjectTokenProvider(opts)
  1348  	if err != nil {
  1349  		t.Fatalf("newSubjectTokenProvider() = %v", err)
  1350  	}
  1351  
  1352  	_, gotErr := base.subjectToken(context.Background())
  1353  	if gotErr == nil {
  1354  		t.Fatalf("subjectToken() = nil, want error")
  1355  	}
  1356  	if gotErr != testErr {
  1357  		t.Errorf("got = %v, want %v", err, testErr)
  1358  	}
  1359  }
  1360  
  1361  func TestAWSCredential_ProgrammaticAuthRegionError(t *testing.T) {
  1362  	opts := cloneTestOpts()
  1363  	testErr := errors.New("test error")
  1364  	opts.AwsSecurityCredentialsProvider = fakeAwsCredsProvider{
  1365  		regionErr: testErr,
  1366  		creds: &AwsSecurityCredentials{
  1367  			AccessKeyID:     accessKeyID,
  1368  			SecretAccessKey: secretAccessKey,
  1369  		},
  1370  	}
  1371  
  1372  	base, err := newSubjectTokenProvider(opts)
  1373  	if err != nil {
  1374  		t.Fatalf("newSubjectTokenProvider() = %v", err)
  1375  	}
  1376  
  1377  	_, gotErr := base.subjectToken(context.Background())
  1378  	if gotErr == nil {
  1379  		t.Fatalf("subjectToken() = nil, want error")
  1380  	}
  1381  	if gotErr != testErr {
  1382  		t.Errorf("got = %v, want %v", err, testErr)
  1383  	}
  1384  }
  1385  
  1386  func TestAWSCredential_ProgrammaticAuthOptions(t *testing.T) {
  1387  	opts := cloneTestOpts()
  1388  	wantOpts := &RequestOptions{Audience: opts.Audience, SubjectTokenType: opts.SubjectTokenType}
  1389  
  1390  	opts.AwsSecurityCredentialsProvider = fakeAwsCredsProvider{
  1391  		awsRegion: "us-east-2",
  1392  		creds: &AwsSecurityCredentials{
  1393  			AccessKeyID:     accessKeyID,
  1394  			SecretAccessKey: secretAccessKey,
  1395  		},
  1396  		reqOpts: wantOpts,
  1397  	}
  1398  
  1399  	base, err := newSubjectTokenProvider(opts)
  1400  	if err != nil {
  1401  		t.Fatalf("newSubjectTokenProvider() = %v", err)
  1402  	}
  1403  
  1404  	_, err = base.subjectToken(context.Background())
  1405  	if err != nil {
  1406  		t.Fatalf("subjectToken() = %v", err)
  1407  	}
  1408  }
  1409  
  1410  func setTime(testTime time.Time) func() time.Time {
  1411  	return func() time.Time {
  1412  		return testTime
  1413  	}
  1414  }
  1415  
  1416  func setEnvironment(env map[string]string) func(string) string {
  1417  	return func(key string) string {
  1418  		return env[key]
  1419  	}
  1420  }
  1421  
  1422  var defaultRequestSigner = &awsRequestSigner{
  1423  	RegionName: "us-east-1",
  1424  	AwsSecurityCredentials: &AwsSecurityCredentials{
  1425  		AccessKeyID:     "AKIDEXAMPLE",
  1426  		SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
  1427  	},
  1428  }
  1429  
  1430  func setDefaultTime(req *http.Request) {
  1431  	// Don't use time.Format for this
  1432  	// Our output signature expects this to be a Monday, even though Sept 9, 2011 is a Friday
  1433  	req.Header.Set("date", "Mon, 09 Sep 2011 23:36:00 GMT")
  1434  }
  1435  
  1436  func testRequestSigner(t *testing.T, rs *awsRequestSigner, input, wantOutput *http.Request) {
  1437  	t.Helper()
  1438  
  1439  	err := rs.signRequest(input)
  1440  	if err != nil {
  1441  		t.Fatal(err)
  1442  	}
  1443  
  1444  	if got, want := input.URL.String(), wantOutput.URL.String(); got != want {
  1445  		t.Errorf("url = %q, want %q", got, want)
  1446  	}
  1447  	if got, want := input.Method, wantOutput.Method; got != want {
  1448  		t.Errorf("method = %q, want %q", got, want)
  1449  	}
  1450  	for header := range wantOutput.Header {
  1451  		if got, want := input.Header[header], wantOutput.Header[header]; !cmp.Equal(got, want) {
  1452  			t.Errorf("header[%q] = %q, want %q", header, got, want)
  1453  		}
  1454  	}
  1455  }
  1456  
// fakeAwsCredsProvider is a configurable test double assigned to
// opts.AwsSecurityCredentialsProvider in the programmatic-auth tests. It returns
// canned regions, credentials or errors, and can optionally verify the
// RequestOptions it is called with.
type fakeAwsCredsProvider struct {
	// credsErr, when non-nil, is returned by AwsSecurityCredentials.
	credsErr  error
	// regionErr, when non-nil, is returned by AwsRegion.
	regionErr error
	// awsRegion is the region returned by AwsRegion.
	awsRegion string
	// creds is the value returned by AwsSecurityCredentials.
	creds     *AwsSecurityCredentials
	// reqOpts, when non-nil, is compared field-by-field (Audience,
	// SubjectTokenType) with the options passed to each method; a mismatch
	// makes the method return an error.
	reqOpts   *RequestOptions
}
  1464  
  1465  func (acp fakeAwsCredsProvider) AwsRegion(ctx context.Context, opts *RequestOptions) (string, error) {
  1466  	if acp.regionErr != nil {
  1467  		return "", acp.regionErr
  1468  	}
  1469  	if acp.reqOpts != nil {
  1470  		if acp.reqOpts.Audience != opts.Audience {
  1471  			return "", errors.New("audience does not match")
  1472  		}
  1473  		if acp.reqOpts.SubjectTokenType != opts.SubjectTokenType {
  1474  			return "", errors.New("audience does not match")
  1475  		}
  1476  	}
  1477  	return acp.awsRegion, nil
  1478  }
  1479  
  1480  func (acp fakeAwsCredsProvider) AwsSecurityCredentials(ctx context.Context, opts *RequestOptions) (*AwsSecurityCredentials, error) {
  1481  	if acp.credsErr != nil {
  1482  		return nil, acp.credsErr
  1483  	}
  1484  	if acp.reqOpts != nil {
  1485  		if acp.reqOpts.Audience != opts.Audience {
  1486  			return nil, errors.New("Audience does not match")
  1487  		}
  1488  		if acp.reqOpts.SubjectTokenType != opts.SubjectTokenType {
  1489  			return nil, errors.New("Audience does not match")
  1490  		}
  1491  	}
  1492  	return acp.creds, nil
  1493  }
  1494  

View as plain text