1 package config
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "io/ioutil"
8 "net/http"
9 "net/http/httptest"
10 "os"
11 "path/filepath"
12 "reflect"
13 "runtime"
14 "strings"
15 "testing"
16 "time"
17
18 "github.com/aws/aws-sdk-go-v2/aws"
19 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
20 "github.com/aws/aws-sdk-go-v2/internal/awstesting"
21 "github.com/aws/aws-sdk-go-v2/service/sso"
22 "github.com/aws/aws-sdk-go-v2/service/sts"
23 "github.com/aws/smithy-go"
24 "github.com/aws/smithy-go/middleware"
25 smithytime "github.com/aws/smithy-go/time"
26 )
27
28 func swapECSContainerURI(path string) func() {
29 o := ecsContainerEndpoint
30 ecsContainerEndpoint = path
31 return func() {
32 ecsContainerEndpoint = o
33 }
34 }
35
36 func setupCredentialsEndpoints(t *testing.T) (aws.EndpointResolverWithOptions, func()) {
37 ecsMetadataServer := httptest.NewServer(http.HandlerFunc(
38 func(w http.ResponseWriter, r *http.Request) {
39 if r.URL.Path == "/ECS" {
40 w.Write([]byte(ecsResponse))
41 } else {
42 w.Write([]byte(""))
43 }
44 }))
45 resetECSEndpoint := swapECSContainerURI(ecsMetadataServer.URL)
46
47 ec2MetadataServer := httptest.NewServer(http.HandlerFunc(
48 func(w http.ResponseWriter, r *http.Request) {
49 if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
50 w.Write([]byte(ec2MetadataResponse))
51 } else if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
52 w.Write([]byte("RoleName"))
53 } else if r.URL.Path == "/latest/api/token" {
54 header := w.Header()
55
56 const ttlHeader = "X-Aws-Ec2-Metadata-Token-Ttl-Seconds"
57 header.Set(ttlHeader, r.Header.Get(ttlHeader))
58 w.Write([]byte("validToken"))
59 } else {
60 w.Write([]byte(""))
61 }
62 }))
63
64 os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", ec2MetadataServer.URL)
65
66 stsServer := httptest.NewServer(http.HandlerFunc(
67 func(w http.ResponseWriter, r *http.Request) {
68 if err := r.ParseForm(); err != nil {
69 w.WriteHeader(500)
70 return
71 }
72
73 form := r.Form
74
75 switch form.Get("Action") {
76 case "AssumeRole":
77 w.Write([]byte(fmt.Sprintf(
78 assumeRoleRespMsg,
79 smithytime.FormatDateTime(time.Now().
80 Add(15*time.Minute)))))
81 return
82 case "AssumeRoleWithWebIdentity":
83 w.Write([]byte(fmt.Sprintf(assumeRoleWithWebIdentityResponse,
84 smithytime.FormatDateTime(time.Now().
85 Add(15*time.Minute)))))
86 return
87 default:
88 w.WriteHeader(404)
89 return
90 }
91 }))
92
93 ssoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94 w.Write([]byte(fmt.Sprintf(
95 getRoleCredentialsResponse,
96 time.Now().
97 Add(15*time.Minute).
98 UnixNano()/int64(time.Millisecond))))
99 }))
100
101 resolver := aws.EndpointResolverWithOptionsFunc(
102 func(service, region string, options ...interface{}) (aws.Endpoint, error) {
103 switch service {
104 case sts.ServiceID:
105 return aws.Endpoint{
106 URL: stsServer.URL,
107 }, nil
108 case sso.ServiceID:
109 return aws.Endpoint{
110 URL: ssoServer.URL,
111 }, nil
112 default:
113 return aws.Endpoint{},
114 fmt.Errorf("unknown service endpoint, %s", service)
115 }
116 })
117
118 return resolver, func() {
119 resetECSEndpoint()
120 ecsMetadataServer.Close()
121 ec2MetadataServer.Close()
122 ssoServer.Close()
123 stsServer.Close()
124 }
125 }
126
127 func ssoTestSetup() (fn func(), err error) {
128 dir, err := ioutil.TempDir(os.TempDir(), "sso-test")
129 if err != nil {
130 return nil, err
131 }
132
133 cleanupTestDir := func() {
134 os.RemoveAll(dir)
135 }
136 defer func() {
137 if err != nil {
138 cleanupTestDir()
139 }
140 }()
141
142 cacheDir := filepath.Join(dir, ".aws", "sso", "cache")
143 err = os.MkdirAll(cacheDir, 0750)
144 if err != nil {
145 return nil, err
146 }
147
148 tokenFile, err := os.Create(filepath.Join(cacheDir, "eb5e43e71ce87dd92ec58903d76debd8ee42aefd.json"))
149 if err != nil {
150 return nil, err
151 }
152
153 defer func() {
154 closeErr := tokenFile.Close()
155 if err == nil {
156 err = closeErr
157 } else if closeErr != nil {
158 err = fmt.Errorf("close error: %v, original error: %w", closeErr, err)
159 }
160 }()
161
162 _, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now().
163 Add(15*time.Minute).
164 Format(time.RFC3339)))
165 if err != nil {
166 return nil, err
167 }
168
169 if runtime.GOOS == "windows" {
170 os.Setenv("USERPROFILE", dir)
171 } else {
172 os.Setenv("HOME", dir)
173 }
174
175 return cleanupTestDir, nil
176 }
177
178 func TestSharedConfigCredentialSource(t *testing.T) {
179 var configFileForWindows = filepath.Join("testdata", "config_source_shared_for_windows")
180 var configFile = filepath.Join("testdata", "config_source_shared")
181
182 var credFileForWindows = filepath.Join("testdata", "credentials_source_shared_for_windows")
183 var credFile = filepath.Join("testdata", "credentials_source_shared")
184
185 cases := map[string]struct {
186 name string
187 envProfile string
188 configProfile string
189 expectedError string
190 expectedAccessKey string
191 expectedSecretKey string
192 expectedSessionToken string
193 expectedChain []string
194 init func() (func(), error)
195 dependentOnOS bool
196 }{
197 "credential source and source profile": {
198 envProfile: "invalid_source_and_credential_source",
199 expectedError: "only one credential type may be specified per profile",
200 init: func() (func(), error) {
201 os.Setenv("AWS_ACCESS_KEY", "access_key")
202 os.Setenv("AWS_SECRET_KEY", "secret_key")
203 return func() {}, nil
204 },
205 },
206 "env var credential source": {
207 configProfile: "env_var_credential_source",
208 expectedAccessKey: "AKID",
209 expectedSecretKey: "SECRET",
210 expectedSessionToken: "SESSION_TOKEN",
211 expectedChain: []string{
212 "assume_role_w_creds_role_arn_env",
213 },
214 init: func() (func(), error) {
215 os.Setenv("AWS_ACCESS_KEY", "access_key")
216 os.Setenv("AWS_SECRET_KEY", "secret_key")
217 return func() {}, nil
218 },
219 },
220 "ec2metadata credential source": {
221 envProfile: "ec2metadata",
222 expectedChain: []string{
223 "assume_role_w_creds_role_arn_ec2",
224 },
225 expectedAccessKey: "AKID",
226 expectedSecretKey: "SECRET",
227 expectedSessionToken: "SESSION_TOKEN",
228 },
229 "ecs container credential source": {
230 envProfile: "ecscontainer",
231 expectedAccessKey: "AKID",
232 expectedSecretKey: "SECRET",
233 expectedSessionToken: "SESSION_TOKEN",
234 expectedChain: []string{
235 "assume_role_w_creds_role_arn_ecs",
236 },
237 init: func() (func(), error) {
238 os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
239 return func() {}, nil
240 },
241 },
242 "chained assume role with env creds": {
243 envProfile: "chained_assume_role",
244 expectedAccessKey: "AKID",
245 expectedSecretKey: "SECRET",
246 expectedSessionToken: "SESSION_TOKEN",
247 expectedChain: []string{
248 "assume_role_w_creds_role_arn_chain",
249 "assume_role_w_creds_role_arn_ec2",
250 },
251 },
252 "credential process with no ARN set": {
253 envProfile: "cred_proc_no_arn_set",
254 dependentOnOS: true,
255 expectedAccessKey: "cred_proc_akid",
256 expectedSecretKey: "cred_proc_secret",
257 },
258 "credential process with ARN set": {
259 envProfile: "cred_proc_arn_set",
260 dependentOnOS: true,
261 expectedAccessKey: "AKID",
262 expectedSecretKey: "SECRET",
263 expectedSessionToken: "SESSION_TOKEN",
264 expectedChain: []string{
265 "assume_role_w_creds_proc_role_arn",
266 },
267 },
268 "chained assume role with credential process": {
269 envProfile: "chained_cred_proc",
270 dependentOnOS: true,
271 expectedAccessKey: "AKID",
272 expectedSecretKey: "SECRET",
273 expectedSessionToken: "SESSION_TOKEN",
274 expectedChain: []string{
275 "assume_role_w_creds_proc_source_prof",
276 },
277 },
278 "credential source overrides config source": {
279 envProfile: "credentials_overide",
280 expectedAccessKey: "AKID",
281 expectedSecretKey: "SECRET",
282 expectedSessionToken: "SESSION_TOKEN",
283 expectedChain: []string{
284 "assume_role_w_creds_role_arn_ec2",
285 },
286 init: func() (func(), error) {
287 os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
288 return func() {}, nil
289 },
290 },
291 "only credential source": {
292 envProfile: "only_credentials_source",
293 expectedAccessKey: "AKID",
294 expectedSecretKey: "SECRET",
295 expectedSessionToken: "SESSION_TOKEN",
296 expectedChain: []string{
297 "assume_role_w_creds_role_arn_ecs",
298 },
299 init: func() (func(), error) {
300 os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
301 return func() {}, nil
302 },
303 },
304 "sso credentials": {
305 envProfile: "sso_creds",
306 expectedAccessKey: "SSO_AKID",
307 expectedSecretKey: "SSO_SECRET_KEY",
308 expectedSessionToken: "SSO_SESSION_TOKEN",
309 init: func() (func(), error) {
310 return ssoTestSetup()
311 },
312 },
313 "chained assume role with sso credentials": {
314 envProfile: "source_sso_creds",
315 expectedAccessKey: "AKID",
316 expectedSecretKey: "SECRET",
317 expectedSessionToken: "SESSION_TOKEN",
318 expectedChain: []string{
319 "source_sso_creds_arn",
320 },
321 init: func() (func(), error) {
322 return ssoTestSetup()
323 },
324 },
325 "chained assume role with sso and static credentials": {
326 envProfile: "assume_sso_and_static",
327 expectedAccessKey: "AKID",
328 expectedSecretKey: "SECRET",
329 expectedSessionToken: "SESSION_TOKEN",
330 expectedChain: []string{
331 "assume_sso_and_static_arn",
332 },
333 },
334 "invalid sso configuration": {
335 envProfile: "sso_invalid",
336 expectedError: "profile \"sso_invalid\" is configured to use SSO but is missing required configuration: sso_region, sso_start_url",
337 },
338 "environment credentials with invalid sso": {
339 envProfile: "sso_invalid",
340 expectedAccessKey: "access_key",
341 expectedSecretKey: "secret_key",
342 init: func() (func(), error) {
343 os.Setenv("AWS_ACCESS_KEY", "access_key")
344 os.Setenv("AWS_SECRET_KEY", "secret_key")
345 return func() {}, nil
346 },
347 },
348 "sso mixed with credential process provider": {
349 envProfile: "sso_mixed_credproc",
350 expectedAccessKey: "SSO_AKID",
351 expectedSecretKey: "SSO_SECRET_KEY",
352 expectedSessionToken: "SSO_SESSION_TOKEN",
353 init: func() (func(), error) {
354 return ssoTestSetup()
355 },
356 },
357 "sso mixed with web identity token provider": {
358 envProfile: "sso_mixed_webident",
359 expectedAccessKey: "WEB_IDENTITY_AKID",
360 expectedSecretKey: "WEB_IDENTITY_SECRET",
361 expectedSessionToken: "WEB_IDENTITY_SESSION_TOKEN",
362 },
363 "SSO Session missing region": {
364 envProfile: "sso-session-missing-region",
365 expectedError: "profile \"sso-session-missing-region\" is configured to use SSO but is missing required configuration: sso_region",
366 },
367 "SSO Session mismatched region": {
368 envProfile: "sso-session-mismatched-region",
369 expectedError: "sso_region in profile \"sso-session-mismatched-region\" must match sso_region in sso-session",
370 },
371 "web identity": {
372 envProfile: "webident",
373 expectedAccessKey: "WEB_IDENTITY_AKID",
374 expectedSecretKey: "WEB_IDENTITY_SECRET",
375 expectedSessionToken: "WEB_IDENTITY_SESSION_TOKEN",
376 },
377 }
378
379 for name, c := range cases {
380 t.Run(name, func(t *testing.T) {
381 restoreEnv := awstesting.StashEnv()
382 defer awstesting.PopEnv(restoreEnv)
383
384 if c.dependentOnOS && runtime.GOOS == "windows" {
385 os.Setenv("AWS_CONFIG_FILE", configFileForWindows)
386 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credFileForWindows)
387 } else {
388 os.Setenv("AWS_CONFIG_FILE", configFile)
389 os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credFile)
390 }
391
392 os.Setenv("AWS_REGION", "us-east-1")
393 if len(c.envProfile) != 0 {
394 os.Setenv("AWS_PROFILE", c.envProfile)
395 }
396
397 endpointResolver, cleanupFn := setupCredentialsEndpoints(t)
398 defer cleanupFn()
399
400 var cleanup func()
401 if c.init != nil {
402 var err error
403 cleanup, err = c.init()
404 if err != nil {
405 t.Fatalf("expect no error, got %v", err)
406 }
407 defer cleanup()
408 }
409
410 var credChain []string
411
412 loadOptions := []func(*LoadOptions) error{
413 WithEndpointResolverWithOptions(endpointResolver),
414 WithAPIOptions([]func(*middleware.Stack) error{
415 func(stack *middleware.Stack) error {
416 return stack.Initialize.Add(middleware.InitializeMiddlewareFunc("GetRoleArns",
417 func(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
418 ) (
419 out middleware.InitializeOutput, metadata middleware.Metadata, err error,
420 ) {
421 switch v := in.Parameters.(type) {
422 case *sts.AssumeRoleInput:
423 credChain = append(credChain, *v.RoleArn)
424 }
425
426 return next.HandleInitialize(ctx, in)
427 }), middleware.After)
428 },
429 }),
430 }
431
432 if len(c.configProfile) != 0 {
433 loadOptions = append(loadOptions, WithSharedConfigProfile(c.configProfile))
434 }
435
436 config, err := LoadDefaultConfig(context.Background(), loadOptions...)
437 if err != nil {
438 if len(c.expectedError) > 0 {
439 if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) {
440 t.Fatalf("expect %v, but got %v", e, a)
441 }
442 return
443 }
444 t.Fatalf("expect no error, got %v", err)
445 } else if len(c.expectedError) > 0 {
446 t.Fatalf("expect error, got none")
447 }
448
449 creds, err := config.Credentials.Retrieve(context.Background())
450 if err != nil {
451 t.Fatalf("expected no error, but received %v", err)
452 }
453
454 if e, a := c.expectedChain, credChain; !reflect.DeepEqual(e, a) {
455 t.Errorf("expected %v, but received %v", e, a)
456 }
457
458 if e, a := c.expectedAccessKey, creds.AccessKeyID; e != a {
459 t.Errorf("expected %v, but received %v", e, a)
460 }
461
462 if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a {
463 t.Errorf("expect %v, but received %v", e, a)
464 }
465
466 if e, a := c.expectedSessionToken, creds.SessionToken; e != a {
467 t.Errorf("expect %v, got %v", e, a)
468 }
469 })
470 }
471 }
472
473 func TestResolveCredentialsCacheOptions(t *testing.T) {
474 var cfg aws.Config
475 var optionsFnCalled bool
476
477 err := resolveCredentials(context.Background(), &cfg, configs{LoadOptions{
478 CredentialsCacheOptions: func(o *aws.CredentialsCacheOptions) {
479 optionsFnCalled = true
480 o.ExpiryWindow = time.Minute * 5
481 },
482 }})
483 if err != nil {
484 t.Fatalf("expect no error, got %v", err)
485 }
486
487 if !optionsFnCalled {
488 t.Errorf("expect options to be called")
489 }
490 }
491
492 func TestResolveCredentialsIMDSClient(t *testing.T) {
493 expectEnabled := func(t *testing.T, err error) {
494 if err == nil {
495 t.Fatalf("expect error got none")
496 }
497 if e, a := "expected HTTP client error", err.Error(); !strings.Contains(a, e) {
498 t.Fatalf("expected %v error in %v", e, a)
499 }
500 }
501
502 expectDisabled := func(t *testing.T, err error) {
503 var oe *smithy.OperationError
504 if !errors.As(err, &oe) {
505 t.Fatalf("unexpected error: %v", err)
506 } else {
507 e := errors.Unwrap(oe)
508 if e == nil {
509 t.Fatalf("unexpected empty operation error: %v", oe)
510 } else {
511 if !strings.HasPrefix(e.Error(), "access disabled to EC2 IMDS") {
512 t.Fatalf("unexpected operation error: %v", oe)
513 }
514 }
515 }
516 }
517
518 testcases := map[string]struct {
519 enabledState imds.ClientEnableState
520 envvar string
521 expectedState imds.ClientEnableState
522 expectedError func(*testing.T, error)
523 }{
524 "default no options": {
525 expectedState: imds.ClientDefaultEnableState,
526 expectedError: expectEnabled,
527 },
528
529 "state enabled": {
530 enabledState: imds.ClientEnabled,
531 expectedState: imds.ClientEnabled,
532 expectedError: expectEnabled,
533 },
534 "state disabled": {
535 enabledState: imds.ClientDisabled,
536 expectedState: imds.ClientDisabled,
537 expectedError: expectDisabled,
538 },
539
540 "env var DISABLED true": {
541 envvar: "true",
542 expectedState: imds.ClientDisabled,
543 expectedError: expectDisabled,
544 },
545 "env var DISABLED false": {
546 envvar: "false",
547 expectedState: imds.ClientEnabled,
548 expectedError: expectEnabled,
549 },
550
551 "option state enabled overrides env var DISABLED true": {
552 enabledState: imds.ClientEnabled,
553 envvar: "true",
554 expectedState: imds.ClientEnabled,
555 expectedError: expectEnabled,
556 },
557 "option state disabled overrides env var DISABLED false": {
558 enabledState: imds.ClientDisabled,
559 envvar: "false",
560 expectedState: imds.ClientDisabled,
561 expectedError: expectDisabled,
562 },
563 }
564
565 for name, tc := range testcases {
566 t.Run(name, func(t *testing.T) {
567 restoreEnv := awstesting.StashEnv()
568 defer awstesting.PopEnv(restoreEnv)
569
570 var httpClient HTTPClient
571 if tc.expectedState == imds.ClientDisabled {
572 httpClient = stubErrorClient{err: fmt.Errorf("expect HTTP client not to be called")}
573 } else {
574 httpClient = stubErrorClient{err: fmt.Errorf("expected HTTP client error")}
575 }
576
577 opts := []func(*LoadOptions) error{
578 WithRetryer(func() aws.Retryer { return aws.NopRetryer{} }),
579 WithHTTPClient(httpClient),
580 WithSharedConfigFiles([]string{}),
581 }
582
583 if tc.enabledState != imds.ClientDefaultEnableState {
584 opts = append(opts,
585 WithEC2IMDSClientEnableState(tc.enabledState),
586 )
587 }
588
589 if tc.envvar != "" {
590 os.Setenv("AWS_EC2_METADATA_DISABLED", tc.envvar)
591 }
592
593 c, err := LoadDefaultConfig(context.TODO(), opts...)
594 if err != nil {
595 t.Fatalf("could not load config: %s", err)
596 }
597
598 creds := c.Credentials
599
600 _, err = creds.Retrieve(context.TODO())
601 tc.expectedError(t, err)
602 })
603 }
604 }
605
// stubErrorClient is an HTTP client stub that never sends a request. Every
// call to Do reports the configured error.
type stubErrorClient struct {
	err error
}

// Do ignores the request and returns a nil response together with the
// stub's error.
func (s stubErrorClient) Do(_ *http.Request) (*http.Response, error) {
	var noResponse *http.Response
	return noResponse, s.err
}
611
View as plain text