...

Source file src/github.com/aws/aws-sdk-go-v2/config/resolve_credentials_test.go

Documentation: github.com/aws/aws-sdk-go-v2/config

     1  package config
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"runtime"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/aws/aws-sdk-go-v2/aws"
    19  	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
    20  	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
    21  	"github.com/aws/aws-sdk-go-v2/service/sso"
    22  	"github.com/aws/aws-sdk-go-v2/service/sts"
    23  	"github.com/aws/smithy-go"
    24  	"github.com/aws/smithy-go/middleware"
    25  	smithytime "github.com/aws/smithy-go/time"
    26  )
    27  
    28  func swapECSContainerURI(path string) func() {
    29  	o := ecsContainerEndpoint
    30  	ecsContainerEndpoint = path
    31  	return func() {
    32  		ecsContainerEndpoint = o
    33  	}
    34  }
    35  
    36  func setupCredentialsEndpoints(t *testing.T) (aws.EndpointResolverWithOptions, func()) {
    37  	ecsMetadataServer := httptest.NewServer(http.HandlerFunc(
    38  		func(w http.ResponseWriter, r *http.Request) {
    39  			if r.URL.Path == "/ECS" {
    40  				w.Write([]byte(ecsResponse))
    41  			} else {
    42  				w.Write([]byte(""))
    43  			}
    44  		}))
    45  	resetECSEndpoint := swapECSContainerURI(ecsMetadataServer.URL)
    46  
    47  	ec2MetadataServer := httptest.NewServer(http.HandlerFunc(
    48  		func(w http.ResponseWriter, r *http.Request) {
    49  			if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
    50  				w.Write([]byte(ec2MetadataResponse))
    51  			} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
    52  				w.Write([]byte("RoleName"))
    53  			} else if r.URL.Path == "/latest/api/token" {
    54  				header := w.Header()
    55  				// bounce the TTL header
    56  				const ttlHeader = "X-Aws-Ec2-Metadata-Token-Ttl-Seconds"
    57  				header.Set(ttlHeader, r.Header.Get(ttlHeader))
    58  				w.Write([]byte("validToken"))
    59  			} else {
    60  				w.Write([]byte(""))
    61  			}
    62  		}))
    63  
    64  	os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", ec2MetadataServer.URL)
    65  
    66  	stsServer := httptest.NewServer(http.HandlerFunc(
    67  		func(w http.ResponseWriter, r *http.Request) {
    68  			if err := r.ParseForm(); err != nil {
    69  				w.WriteHeader(500)
    70  				return
    71  			}
    72  
    73  			form := r.Form
    74  
    75  			switch form.Get("Action") {
    76  			case "AssumeRole":
    77  				w.Write([]byte(fmt.Sprintf(
    78  					assumeRoleRespMsg,
    79  					smithytime.FormatDateTime(time.Now().
    80  						Add(15*time.Minute)))))
    81  				return
    82  			case "AssumeRoleWithWebIdentity":
    83  				w.Write([]byte(fmt.Sprintf(assumeRoleWithWebIdentityResponse,
    84  					smithytime.FormatDateTime(time.Now().
    85  						Add(15*time.Minute)))))
    86  				return
    87  			default:
    88  				w.WriteHeader(404)
    89  				return
    90  			}
    91  		}))
    92  
    93  	ssoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    94  		w.Write([]byte(fmt.Sprintf(
    95  			getRoleCredentialsResponse,
    96  			time.Now().
    97  				Add(15*time.Minute).
    98  				UnixNano()/int64(time.Millisecond))))
    99  	}))
   100  
   101  	resolver := aws.EndpointResolverWithOptionsFunc(
   102  		func(service, region string, options ...interface{}) (aws.Endpoint, error) {
   103  			switch service {
   104  			case sts.ServiceID:
   105  				return aws.Endpoint{
   106  					URL: stsServer.URL,
   107  				}, nil
   108  			case sso.ServiceID:
   109  				return aws.Endpoint{
   110  					URL: ssoServer.URL,
   111  				}, nil
   112  			default:
   113  				return aws.Endpoint{},
   114  					fmt.Errorf("unknown service endpoint, %s", service)
   115  			}
   116  		})
   117  
   118  	return resolver, func() {
   119  		resetECSEndpoint()
   120  		ecsMetadataServer.Close()
   121  		ec2MetadataServer.Close()
   122  		ssoServer.Close()
   123  		stsServer.Close()
   124  	}
   125  }
   126  
   127  func ssoTestSetup() (fn func(), err error) {
   128  	dir, err := ioutil.TempDir(os.TempDir(), "sso-test")
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	cleanupTestDir := func() {
   134  		os.RemoveAll(dir)
   135  	}
   136  	defer func() {
   137  		if err != nil {
   138  			cleanupTestDir()
   139  		}
   140  	}()
   141  
   142  	cacheDir := filepath.Join(dir, ".aws", "sso", "cache")
   143  	err = os.MkdirAll(cacheDir, 0750)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	tokenFile, err := os.Create(filepath.Join(cacheDir, "eb5e43e71ce87dd92ec58903d76debd8ee42aefd.json"))
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	defer func() {
   154  		closeErr := tokenFile.Close()
   155  		if err == nil {
   156  			err = closeErr
   157  		} else if closeErr != nil {
   158  			err = fmt.Errorf("close error: %v, original error: %w", closeErr, err)
   159  		}
   160  	}()
   161  
   162  	_, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now().
   163  		Add(15*time.Minute).
   164  		Format(time.RFC3339)))
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	if runtime.GOOS == "windows" {
   170  		os.Setenv("USERPROFILE", dir)
   171  	} else {
   172  		os.Setenv("HOME", dir)
   173  	}
   174  
   175  	return cleanupTestDir, nil
   176  }
   177  
   178  func TestSharedConfigCredentialSource(t *testing.T) {
   179  	var configFileForWindows = filepath.Join("testdata", "config_source_shared_for_windows")
   180  	var configFile = filepath.Join("testdata", "config_source_shared")
   181  
   182  	var credFileForWindows = filepath.Join("testdata", "credentials_source_shared_for_windows")
   183  	var credFile = filepath.Join("testdata", "credentials_source_shared")
   184  
   185  	cases := map[string]struct {
   186  		name                 string
   187  		envProfile           string
   188  		configProfile        string
   189  		expectedError        string
   190  		expectedAccessKey    string
   191  		expectedSecretKey    string
   192  		expectedSessionToken string
   193  		expectedChain        []string
   194  		init                 func() (func(), error)
   195  		dependentOnOS        bool
   196  	}{
   197  		"credential source and source profile": {
   198  			envProfile:    "invalid_source_and_credential_source",
   199  			expectedError: "only one credential type may be specified per profile",
   200  			init: func() (func(), error) {
   201  				os.Setenv("AWS_ACCESS_KEY", "access_key")
   202  				os.Setenv("AWS_SECRET_KEY", "secret_key")
   203  				return func() {}, nil
   204  			},
   205  		},
   206  		"env var credential source": {
   207  			configProfile:        "env_var_credential_source",
   208  			expectedAccessKey:    "AKID",
   209  			expectedSecretKey:    "SECRET",
   210  			expectedSessionToken: "SESSION_TOKEN",
   211  			expectedChain: []string{
   212  				"assume_role_w_creds_role_arn_env",
   213  			},
   214  			init: func() (func(), error) {
   215  				os.Setenv("AWS_ACCESS_KEY", "access_key")
   216  				os.Setenv("AWS_SECRET_KEY", "secret_key")
   217  				return func() {}, nil
   218  			},
   219  		},
   220  		"ec2metadata credential source": {
   221  			envProfile: "ec2metadata",
   222  			expectedChain: []string{
   223  				"assume_role_w_creds_role_arn_ec2",
   224  			},
   225  			expectedAccessKey:    "AKID",
   226  			expectedSecretKey:    "SECRET",
   227  			expectedSessionToken: "SESSION_TOKEN",
   228  		},
   229  		"ecs container credential source": {
   230  			envProfile:           "ecscontainer",
   231  			expectedAccessKey:    "AKID",
   232  			expectedSecretKey:    "SECRET",
   233  			expectedSessionToken: "SESSION_TOKEN",
   234  			expectedChain: []string{
   235  				"assume_role_w_creds_role_arn_ecs",
   236  			},
   237  			init: func() (func(), error) {
   238  				os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
   239  				return func() {}, nil
   240  			},
   241  		},
   242  		"chained assume role with env creds": {
   243  			envProfile:           "chained_assume_role",
   244  			expectedAccessKey:    "AKID",
   245  			expectedSecretKey:    "SECRET",
   246  			expectedSessionToken: "SESSION_TOKEN",
   247  			expectedChain: []string{
   248  				"assume_role_w_creds_role_arn_chain",
   249  				"assume_role_w_creds_role_arn_ec2",
   250  			},
   251  		},
   252  		"credential process with no ARN set": {
   253  			envProfile:        "cred_proc_no_arn_set",
   254  			dependentOnOS:     true,
   255  			expectedAccessKey: "cred_proc_akid",
   256  			expectedSecretKey: "cred_proc_secret",
   257  		},
   258  		"credential process with ARN set": {
   259  			envProfile:           "cred_proc_arn_set",
   260  			dependentOnOS:        true,
   261  			expectedAccessKey:    "AKID",
   262  			expectedSecretKey:    "SECRET",
   263  			expectedSessionToken: "SESSION_TOKEN",
   264  			expectedChain: []string{
   265  				"assume_role_w_creds_proc_role_arn",
   266  			},
   267  		},
   268  		"chained assume role with credential process": {
   269  			envProfile:           "chained_cred_proc",
   270  			dependentOnOS:        true,
   271  			expectedAccessKey:    "AKID",
   272  			expectedSecretKey:    "SECRET",
   273  			expectedSessionToken: "SESSION_TOKEN",
   274  			expectedChain: []string{
   275  				"assume_role_w_creds_proc_source_prof",
   276  			},
   277  		},
   278  		"credential source overrides config source": {
   279  			envProfile:           "credentials_overide",
   280  			expectedAccessKey:    "AKID",
   281  			expectedSecretKey:    "SECRET",
   282  			expectedSessionToken: "SESSION_TOKEN",
   283  			expectedChain: []string{
   284  				"assume_role_w_creds_role_arn_ec2",
   285  			},
   286  			init: func() (func(), error) {
   287  				os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
   288  				return func() {}, nil
   289  			},
   290  		},
   291  		"only credential source": {
   292  			envProfile:           "only_credentials_source",
   293  			expectedAccessKey:    "AKID",
   294  			expectedSecretKey:    "SECRET",
   295  			expectedSessionToken: "SESSION_TOKEN",
   296  			expectedChain: []string{
   297  				"assume_role_w_creds_role_arn_ecs",
   298  			},
   299  			init: func() (func(), error) {
   300  				os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS")
   301  				return func() {}, nil
   302  			},
   303  		},
   304  		"sso credentials": {
   305  			envProfile:           "sso_creds",
   306  			expectedAccessKey:    "SSO_AKID",
   307  			expectedSecretKey:    "SSO_SECRET_KEY",
   308  			expectedSessionToken: "SSO_SESSION_TOKEN",
   309  			init: func() (func(), error) {
   310  				return ssoTestSetup()
   311  			},
   312  		},
   313  		"chained assume role with sso credentials": {
   314  			envProfile:           "source_sso_creds",
   315  			expectedAccessKey:    "AKID",
   316  			expectedSecretKey:    "SECRET",
   317  			expectedSessionToken: "SESSION_TOKEN",
   318  			expectedChain: []string{
   319  				"source_sso_creds_arn",
   320  			},
   321  			init: func() (func(), error) {
   322  				return ssoTestSetup()
   323  			},
   324  		},
   325  		"chained assume role with sso and static credentials": {
   326  			envProfile:           "assume_sso_and_static",
   327  			expectedAccessKey:    "AKID",
   328  			expectedSecretKey:    "SECRET",
   329  			expectedSessionToken: "SESSION_TOKEN",
   330  			expectedChain: []string{
   331  				"assume_sso_and_static_arn",
   332  			},
   333  		},
   334  		"invalid sso configuration": {
   335  			envProfile:    "sso_invalid",
   336  			expectedError: "profile \"sso_invalid\" is configured to use SSO but is missing required configuration: sso_region, sso_start_url",
   337  		},
   338  		"environment credentials with invalid sso": {
   339  			envProfile:        "sso_invalid",
   340  			expectedAccessKey: "access_key",
   341  			expectedSecretKey: "secret_key",
   342  			init: func() (func(), error) {
   343  				os.Setenv("AWS_ACCESS_KEY", "access_key")
   344  				os.Setenv("AWS_SECRET_KEY", "secret_key")
   345  				return func() {}, nil
   346  			},
   347  		},
   348  		"sso mixed with credential process provider": {
   349  			envProfile:           "sso_mixed_credproc",
   350  			expectedAccessKey:    "SSO_AKID",
   351  			expectedSecretKey:    "SSO_SECRET_KEY",
   352  			expectedSessionToken: "SSO_SESSION_TOKEN",
   353  			init: func() (func(), error) {
   354  				return ssoTestSetup()
   355  			},
   356  		},
   357  		"sso mixed with web identity token provider": {
   358  			envProfile:           "sso_mixed_webident",
   359  			expectedAccessKey:    "WEB_IDENTITY_AKID",
   360  			expectedSecretKey:    "WEB_IDENTITY_SECRET",
   361  			expectedSessionToken: "WEB_IDENTITY_SESSION_TOKEN",
   362  		},
   363  		"SSO Session missing region": {
   364  			envProfile:    "sso-session-missing-region",
   365  			expectedError: "profile \"sso-session-missing-region\" is configured to use SSO but is missing required configuration: sso_region",
   366  		},
   367  		"SSO Session mismatched region": {
   368  			envProfile:    "sso-session-mismatched-region",
   369  			expectedError: "sso_region in profile \"sso-session-mismatched-region\" must match sso_region in sso-session",
   370  		},
   371  		"web identity": {
   372  			envProfile:           "webident",
   373  			expectedAccessKey:    "WEB_IDENTITY_AKID",
   374  			expectedSecretKey:    "WEB_IDENTITY_SECRET",
   375  			expectedSessionToken: "WEB_IDENTITY_SESSION_TOKEN",
   376  		},
   377  	}
   378  
   379  	for name, c := range cases {
   380  		t.Run(name, func(t *testing.T) {
   381  			restoreEnv := awstesting.StashEnv()
   382  			defer awstesting.PopEnv(restoreEnv)
   383  
   384  			if c.dependentOnOS && runtime.GOOS == "windows" {
   385  				os.Setenv("AWS_CONFIG_FILE", configFileForWindows)
   386  				os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credFileForWindows)
   387  			} else {
   388  				os.Setenv("AWS_CONFIG_FILE", configFile)
   389  				os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credFile)
   390  			}
   391  
   392  			os.Setenv("AWS_REGION", "us-east-1")
   393  			if len(c.envProfile) != 0 {
   394  				os.Setenv("AWS_PROFILE", c.envProfile)
   395  			}
   396  
   397  			endpointResolver, cleanupFn := setupCredentialsEndpoints(t)
   398  			defer cleanupFn()
   399  
   400  			var cleanup func()
   401  			if c.init != nil {
   402  				var err error
   403  				cleanup, err = c.init()
   404  				if err != nil {
   405  					t.Fatalf("expect no error, got %v", err)
   406  				}
   407  				defer cleanup()
   408  			}
   409  
   410  			var credChain []string
   411  
   412  			loadOptions := []func(*LoadOptions) error{
   413  				WithEndpointResolverWithOptions(endpointResolver),
   414  				WithAPIOptions([]func(*middleware.Stack) error{
   415  					func(stack *middleware.Stack) error {
   416  						return stack.Initialize.Add(middleware.InitializeMiddlewareFunc("GetRoleArns",
   417  							func(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
   418  							) (
   419  								out middleware.InitializeOutput, metadata middleware.Metadata, err error,
   420  							) {
   421  								switch v := in.Parameters.(type) {
   422  								case *sts.AssumeRoleInput:
   423  									credChain = append(credChain, *v.RoleArn)
   424  								}
   425  
   426  								return next.HandleInitialize(ctx, in)
   427  							}), middleware.After)
   428  					},
   429  				}),
   430  			}
   431  
   432  			if len(c.configProfile) != 0 {
   433  				loadOptions = append(loadOptions, WithSharedConfigProfile(c.configProfile))
   434  			}
   435  
   436  			config, err := LoadDefaultConfig(context.Background(), loadOptions...)
   437  			if err != nil {
   438  				if len(c.expectedError) > 0 {
   439  					if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) {
   440  						t.Fatalf("expect %v, but got %v", e, a)
   441  					}
   442  					return
   443  				}
   444  				t.Fatalf("expect no error, got %v", err)
   445  			} else if len(c.expectedError) > 0 {
   446  				t.Fatalf("expect error, got none")
   447  			}
   448  
   449  			creds, err := config.Credentials.Retrieve(context.Background())
   450  			if err != nil {
   451  				t.Fatalf("expected no error, but received %v", err)
   452  			}
   453  
   454  			if e, a := c.expectedChain, credChain; !reflect.DeepEqual(e, a) {
   455  				t.Errorf("expected %v, but received %v", e, a)
   456  			}
   457  
   458  			if e, a := c.expectedAccessKey, creds.AccessKeyID; e != a {
   459  				t.Errorf("expected %v, but received %v", e, a)
   460  			}
   461  
   462  			if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a {
   463  				t.Errorf("expect %v, but received %v", e, a)
   464  			}
   465  
   466  			if e, a := c.expectedSessionToken, creds.SessionToken; e != a {
   467  				t.Errorf("expect %v, got %v", e, a)
   468  			}
   469  		})
   470  	}
   471  }
   472  
   473  func TestResolveCredentialsCacheOptions(t *testing.T) {
   474  	var cfg aws.Config
   475  	var optionsFnCalled bool
   476  
   477  	err := resolveCredentials(context.Background(), &cfg, configs{LoadOptions{
   478  		CredentialsCacheOptions: func(o *aws.CredentialsCacheOptions) {
   479  			optionsFnCalled = true
   480  			o.ExpiryWindow = time.Minute * 5
   481  		},
   482  	}})
   483  	if err != nil {
   484  		t.Fatalf("expect no error, got %v", err)
   485  	}
   486  
   487  	if !optionsFnCalled {
   488  		t.Errorf("expect options to be called")
   489  	}
   490  }
   491  
   492  func TestResolveCredentialsIMDSClient(t *testing.T) {
   493  	expectEnabled := func(t *testing.T, err error) {
   494  		if err == nil {
   495  			t.Fatalf("expect error got none")
   496  		}
   497  		if e, a := "expected HTTP client error", err.Error(); !strings.Contains(a, e) {
   498  			t.Fatalf("expected %v error in %v", e, a)
   499  		}
   500  	}
   501  
   502  	expectDisabled := func(t *testing.T, err error) {
   503  		var oe *smithy.OperationError
   504  		if !errors.As(err, &oe) {
   505  			t.Fatalf("unexpected error: %v", err)
   506  		} else {
   507  			e := errors.Unwrap(oe)
   508  			if e == nil {
   509  				t.Fatalf("unexpected empty operation error: %v", oe)
   510  			} else {
   511  				if !strings.HasPrefix(e.Error(), "access disabled to EC2 IMDS") {
   512  					t.Fatalf("unexpected operation error: %v", oe)
   513  				}
   514  			}
   515  		}
   516  	}
   517  
   518  	testcases := map[string]struct {
   519  		enabledState  imds.ClientEnableState
   520  		envvar        string
   521  		expectedState imds.ClientEnableState
   522  		expectedError func(*testing.T, error)
   523  	}{
   524  		"default no options": {
   525  			expectedState: imds.ClientDefaultEnableState,
   526  			expectedError: expectEnabled,
   527  		},
   528  
   529  		"state enabled": {
   530  			enabledState:  imds.ClientEnabled,
   531  			expectedState: imds.ClientEnabled,
   532  			expectedError: expectEnabled,
   533  		},
   534  		"state disabled": {
   535  			enabledState:  imds.ClientDisabled,
   536  			expectedState: imds.ClientDisabled,
   537  			expectedError: expectDisabled,
   538  		},
   539  
   540  		"env var DISABLED true": {
   541  			envvar:        "true",
   542  			expectedState: imds.ClientDisabled,
   543  			expectedError: expectDisabled,
   544  		},
   545  		"env var DISABLED false": {
   546  			envvar:        "false",
   547  			expectedState: imds.ClientEnabled,
   548  			expectedError: expectEnabled,
   549  		},
   550  
   551  		"option state enabled overrides env var DISABLED true": {
   552  			enabledState:  imds.ClientEnabled,
   553  			envvar:        "true",
   554  			expectedState: imds.ClientEnabled,
   555  			expectedError: expectEnabled,
   556  		},
   557  		"option state disabled overrides env var DISABLED false": {
   558  			enabledState:  imds.ClientDisabled,
   559  			envvar:        "false",
   560  			expectedState: imds.ClientDisabled,
   561  			expectedError: expectDisabled,
   562  		},
   563  	}
   564  
   565  	for name, tc := range testcases {
   566  		t.Run(name, func(t *testing.T) {
   567  			restoreEnv := awstesting.StashEnv()
   568  			defer awstesting.PopEnv(restoreEnv)
   569  
   570  			var httpClient HTTPClient
   571  			if tc.expectedState == imds.ClientDisabled {
   572  				httpClient = stubErrorClient{err: fmt.Errorf("expect HTTP client not to be called")}
   573  			} else {
   574  				httpClient = stubErrorClient{err: fmt.Errorf("expected HTTP client error")}
   575  			}
   576  
   577  			opts := []func(*LoadOptions) error{
   578  				WithRetryer(func() aws.Retryer { return aws.NopRetryer{} }),
   579  				WithHTTPClient(httpClient),
   580  				WithSharedConfigFiles([]string{}),
   581  			}
   582  
   583  			if tc.enabledState != imds.ClientDefaultEnableState {
   584  				opts = append(opts,
   585  					WithEC2IMDSClientEnableState(tc.enabledState),
   586  				)
   587  			}
   588  
   589  			if tc.envvar != "" {
   590  				os.Setenv("AWS_EC2_METADATA_DISABLED", tc.envvar)
   591  			}
   592  
   593  			c, err := LoadDefaultConfig(context.TODO(), opts...)
   594  			if err != nil {
   595  				t.Fatalf("could not load config: %s", err)
   596  			}
   597  
   598  			creds := c.Credentials
   599  
   600  			_, err = creds.Retrieve(context.TODO())
   601  			tc.expectedError(t, err)
   602  		})
   603  	}
   604  }
   605  
   606  type stubErrorClient struct {
   607  	err error
   608  }
   609  
   610  func (c stubErrorClient) Do(*http.Request) (*http.Response, error) { return nil, c.err }
   611  

View as plain text