1 package config
2
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
)
17
18 func TestAssumeRole(t *testing.T) {
19 restoreEnv := initConfigTestEnv()
20 defer awstesting.PopEnv(restoreEnv)
21
22 os.Setenv("AWS_REGION", "us-east-1")
23 os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
24 os.Setenv("AWS_PROFILE", "assume_role_w_creds")
25
26 client := mockHTTPClient(func(r *http.Request) (*http.Response, error) {
27 return &http.Response{
28 StatusCode: 200,
29 Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(assumeRoleRespMsg,
30 time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))),
31 }, nil
32 })
33
34 config, err := LoadDefaultConfig(context.Background(), WithHTTPClient(client))
35 if err != nil {
36 t.Fatalf("expect no error, got %v", err)
37 }
38
39 creds, err := config.Credentials.Retrieve(context.Background())
40 if err != nil {
41 t.Fatalf("expect no error, got %v", err)
42 }
43 if e, a := "AKID", creds.AccessKeyID; e != a {
44 t.Errorf("expect %v, got %v", e, a)
45 }
46 if e, a := "SECRET", creds.SecretAccessKey; e != a {
47 t.Errorf("expect %v, got %v", e, a)
48 }
49 if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
50 t.Errorf("expect %v, got %v", e, a)
51 }
52 if e, a := "AssumeRoleProvider", creds.Source; !strings.Contains(a, e) {
53 t.Errorf("expect %v, to be in %v", e, a)
54 }
55 }
56
57 func TestAssumeRole_WithMFA(t *testing.T) {
58 restoreEnv := initConfigTestEnv()
59 defer awstesting.PopEnv(restoreEnv)
60
61 os.Setenv("AWS_REGION", "us-east-1")
62 os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
63 os.Setenv("AWS_PROFILE", "assume_role_w_creds")
64
65 client := mockHTTPClient(func(r *http.Request) (*http.Response, error) {
66 t.Helper()
67
68 if e, a := r.FormValue("SerialNumber"), "0123456789"; e != a {
69 t.Errorf("expect %v, got %v", e, a)
70 }
71 if e, a := r.FormValue("TokenCode"), "tokencode"; e != a {
72 t.Errorf("expect %v, got %v", e, a)
73 }
74 if e, a := "900", r.FormValue("DurationSeconds"); e != a {
75 t.Errorf("expect %v, got %v", e, a)
76 }
77
78 return &http.Response{
79 StatusCode: 200,
80 Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(assumeRoleRespMsg,
81 time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))),
82 }, nil
83 })
84
85 customProviderCalled := false
86 config, err := LoadDefaultConfig(context.Background(),
87 WithHTTPClient(client),
88 WithRegion("us-east-1"),
89 WithSharedConfigProfile("assume_role_w_mfa"),
90 WithAssumeRoleCredentialOptions(func(options *stscreds.AssumeRoleOptions) {
91 options.TokenProvider = func() (string, error) {
92 customProviderCalled = true
93 return "tokencode", nil
94 }
95 }),
96 )
97 if err != nil {
98 t.Fatalf("expect no error, got %v", err)
99 }
100
101 creds, err := config.Credentials.Retrieve(context.Background())
102 if err != nil {
103 t.Fatalf("expect no error, got %v", err)
104 }
105 if !customProviderCalled {
106 t.Errorf("expect true")
107 }
108
109 if e, a := "AKID", creds.AccessKeyID; e != a {
110 t.Errorf("expect %v, got %v", e, a)
111 }
112 if e, a := "SECRET", creds.SecretAccessKey; e != a {
113 t.Errorf("expect %v, got %v", e, a)
114 }
115 if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
116 t.Errorf("expect %v, got %v", e, a)
117 }
118 if e, a := "AssumeRoleProvider", creds.Source; !strings.Contains(a, e) {
119 t.Errorf("expect %v, to be in %v", e, a)
120 }
121 }
122
123 func TestAssumeRole_WithMFA_NoTokenProvider(t *testing.T) {
124 restoreEnv := initConfigTestEnv()
125 defer awstesting.PopEnv(restoreEnv)
126
127 os.Setenv("AWS_REGION", "us-east-1")
128 os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
129 os.Setenv("AWS_PROFILE", "assume_role_w_creds")
130
131 _, err := LoadDefaultConfig(context.Background(), WithSharedConfigProfile("assume_role_w_mfa"))
132 if e, a := (AssumeRoleTokenProviderNotSetError{}), err; e != a {
133 t.Errorf("expect %v, got %v", e, a)
134 }
135 }
136
137 func TestAssumeRole_InvalidSourceProfile(t *testing.T) {
138
139
140 restoreEnv := initConfigTestEnv()
141 defer awstesting.PopEnv(restoreEnv)
142
143 os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
144 os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile")
145
146 _, err := LoadDefaultConfig(context.Background())
147 if err == nil {
148 t.Fatalf("expect error, got none")
149 }
150
151 expectMsg := "failed to load assume role"
152 if e, a := expectMsg, err.Error(); !strings.Contains(a, e) {
153 t.Errorf("expect %v, to be in %v", e, a)
154 }
155 }
156
157 func TestAssumeRole_ExtendedDuration(t *testing.T) {
158 restoreEnv := initConfigTestEnv()
159 defer awstesting.PopEnv(restoreEnv)
160
161 os.Setenv("AWS_REGION", "us-east-1")
162 os.Setenv("AWS_CONFIG_FILE", testConfigFilename)
163 os.Setenv("AWS_PROFILE", "assume_role_w_creds_ext_dur")
164
165 client := mockHTTPClient(func(r *http.Request) (*http.Response, error) {
166 t.Helper()
167
168 if e, a := "1800", r.FormValue("DurationSeconds"); e != a {
169 t.Errorf("expect %v, got %v", e, a)
170 }
171
172 return &http.Response{
173 StatusCode: 200,
174 Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(
175 assumeRoleRespMsg,
176 time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z"))))),
177 }, nil
178 })
179
180 config, err := LoadDefaultConfig(context.Background(), WithHTTPClient(client))
181 if err != nil {
182 t.Fatalf("expect no error, got %v", err)
183 }
184
185 creds, err := config.Credentials.Retrieve(context.Background())
186 if err != nil {
187 t.Fatalf("expect no error, got %v", err)
188 }
189 if e, a := "AKID", creds.AccessKeyID; e != a {
190 t.Errorf("expect %v, got %v", e, a)
191 }
192 if e, a := "SECRET", creds.SecretAccessKey; e != a {
193 t.Errorf("expect %v, got %v", e, a)
194 }
195 if e, a := "SESSION_TOKEN", creds.SessionToken; e != a {
196 t.Errorf("expect %v, got %v", e, a)
197 }
198 if e, a := "AssumeRoleProvider", creds.Source; !strings.Contains(a, e) {
199 t.Errorf("expect %v, to be in %v", e, a)
200 }
201 }
202
View as plain text