1 package stscreds_test
2
3 import (
4 "context"
5 "fmt"
6 "testing"
7 "time"
8
9 "github.com/aws/aws-sdk-go-v2/aws"
10 "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
11 "github.com/aws/aws-sdk-go-v2/service/sts"
12 "github.com/aws/aws-sdk-go-v2/service/sts/types"
13 )
14
15 type mockAssumeRole struct {
16 TestInput func(*sts.AssumeRoleInput)
17 }
18
19 func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
20 if s.TestInput != nil {
21 s.TestInput(params)
22 }
23 expiry := time.Now().Add(60 * time.Minute)
24
25 return &sts.AssumeRoleOutput{
26 Credentials: &types.Credentials{
27
28 AccessKeyId: params.RoleArn,
29 SecretAccessKey: aws.String("assumedSecretAccessKey"),
30 SessionToken: aws.String("assumedSessionToken"),
31 Expiration: &expiry,
32 },
33 }, nil
34 }
35
// Placeholder fixture values shared by the tests in this file.
const (
	// roleARN is a dummy role ARN passed to the provider under test.
	roleARN = "00000000000000000000000000000000000"
	// tokenCode is a dummy MFA token code returned by test token providers.
	tokenCode = "00000000000000000000"
)
38
39 func TestAssumeRoleProvider(t *testing.T) {
40 stub := &mockAssumeRole{}
41 p := stscreds.NewAssumeRoleProvider(stub, roleARN)
42
43 creds, err := p.Retrieve(context.Background())
44 if err != nil {
45 t.Fatalf("Expect no error, %v", err)
46 }
47
48 if e, a := roleARN, creds.AccessKeyID; e != a {
49 t.Errorf("Expect access key ID to be reflected role ARN")
50 }
51 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
52 t.Errorf("Expect secret access key to match")
53 }
54 if e, a := "assumedSessionToken", creds.SessionToken; e != a {
55 t.Errorf("Expect session token to match")
56 }
57 }
58
59 func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
60 stub := &mockAssumeRole{
61 TestInput: func(in *sts.AssumeRoleInput) {
62 if e, a := "0123456789", *in.SerialNumber; e != a {
63 t.Errorf("expect %v, got %v", e, a)
64 }
65 if e, a := tokenCode, *in.TokenCode; e != a {
66 t.Errorf("expect %v, got %v", e, a)
67 }
68 },
69 }
70 p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
71 options.SerialNumber = aws.String("0123456789")
72 options.TokenProvider = func() (string, error) {
73 return tokenCode, nil
74 }
75 })
76
77 creds, err := p.Retrieve(context.Background())
78 if err != nil {
79 t.Fatalf("Expect no error, %v", err)
80 }
81
82 if e, a := roleARN, creds.AccessKeyID; e != a {
83 t.Errorf("Expect access key ID to be reflected role ARN")
84 }
85 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
86 t.Errorf("Expect secret access key to match")
87 }
88 if e, a := "assumedSessionToken", creds.SessionToken; e != a {
89 t.Errorf("Expect session token to match")
90 }
91 }
92
93 func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
94 stub := &mockAssumeRole{
95 TestInput: func(in *sts.AssumeRoleInput) {
96 t.Fatalf("API request should not of been called")
97 },
98 }
99 p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
100 options.SerialNumber = aws.String("0123456789")
101 options.TokenProvider = func() (string, error) {
102 return "", fmt.Errorf("error occurred")
103 }
104 })
105
106 creds, err := p.Retrieve(context.Background())
107 if err == nil {
108 t.Fatalf("expect error, got none")
109 }
110
111 if v := creds.AccessKeyID; len(v) != 0 {
112 t.Errorf("expect zero, got %v", v)
113 }
114 if v := creds.SecretAccessKey; len(v) != 0 {
115 t.Errorf("expect zero, got %v", v)
116 }
117 if v := creds.SessionToken; len(v) != 0 {
118 t.Errorf("expect zero, got %v", v)
119 }
120 }
121
122 func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
123 stub := &mockAssumeRole{
124 TestInput: func(in *sts.AssumeRoleInput) {
125 t.Fatalf("API request should not of been called")
126 },
127 }
128 p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
129 options.SerialNumber = aws.String("0123456789")
130 })
131
132 creds, err := p.Retrieve(context.Background())
133 if err == nil {
134 t.Fatalf("expect error, got none")
135 }
136
137 if v := creds.AccessKeyID; len(v) != 0 {
138 t.Errorf("expect zero, got %v", v)
139 }
140 if v := creds.SecretAccessKey; len(v) != 0 {
141 t.Errorf("expect zero, got %v", v)
142 }
143 if v := creds.SessionToken; len(v) != 0 {
144 t.Errorf("expect zero, got %v", v)
145 }
146 }
147
148 func TestAssumeRoleProvider_WithSourceIdentity(t *testing.T) {
149 const sourceIdentity = "Source-Identity"
150
151 stub := &mockAssumeRole{
152 TestInput: func(in *sts.AssumeRoleInput) {
153 if e, a := sourceIdentity, *in.SourceIdentity; e != a {
154 t.Fatalf("expect %v, got %v", e, a)
155 }
156 },
157 }
158 p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
159 options.SourceIdentity = aws.String(sourceIdentity)
160 })
161
162 creds, err := p.Retrieve(context.Background())
163 if err != nil {
164 t.Fatalf("Expect no error, %v", err)
165 }
166
167 if e, a := roleARN, creds.AccessKeyID; e != a {
168 t.Errorf("Expect access key ID to be reflected role ARN")
169 }
170 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
171 t.Errorf("Expect secret access key to match")
172 }
173 if e, a := "assumedSessionToken", creds.SessionToken; e != a {
174 t.Errorf("Expect session token to match")
175 }
176 }
177
178 func TestAssumeRoleProvider_WithTags(t *testing.T) {
179 stub := &mockAssumeRole{
180 TestInput: func(in *sts.AssumeRoleInput) {
181 if e, a := 1, len(in.Tags); e != a {
182 t.Fatalf("expect %v, got %v", e, a)
183 }
184 tag := in.Tags[0]
185 if e, a := "KEY", *tag.Key; e != a {
186 t.Errorf("expect %v, got %v", e, a)
187 }
188 if e, a := "value", *tag.Value; e != a {
189 t.Errorf("expect %v, got %v", e, a)
190 }
191 },
192 }
193 p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
194 options.Tags = []types.Tag{
195 {
196 Key: aws.String("KEY"),
197 Value: aws.String("value"),
198 },
199 }
200 })
201
202 creds, err := p.Retrieve(context.Background())
203 if err != nil {
204 t.Fatalf("Expect no error, %v", err)
205 }
206
207 if e, a := roleARN, creds.AccessKeyID; e != a {
208 t.Errorf("Expect access key ID to be reflected role ARN")
209 }
210 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
211 t.Errorf("Expect secret access key to match")
212 }
213 if e, a := "assumedSessionToken", creds.SessionToken; e != a {
214 t.Errorf("Expect session token to match")
215 }
216 }
217
218 func TestAssumeRoleProvider_WithTransitiveTagKeys(t *testing.T) {
219 stub := &mockAssumeRole{
220 TestInput: func(in *sts.AssumeRoleInput) {
221 if e, a := 1, len(in.TransitiveTagKeys); e != a {
222 t.Fatalf("expect %v, got %v", e, a)
223 }
224 if e, a := "KEY", in.TransitiveTagKeys[0]; e != a {
225 t.Errorf("expect %v, got %v", e, a)
226 }
227 },
228 }
229 p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
230 options.Tags = []types.Tag{
231 {
232 Key: aws.String("KEY"),
233 Value: aws.String("value"),
234 },
235 }
236 options.TransitiveTagKeys = []string{"KEY"}
237 })
238
239 creds, err := p.Retrieve(context.Background())
240 if err != nil {
241 t.Fatalf("Expect no error, %v", err)
242 }
243
244 if e, a := roleARN, creds.AccessKeyID; e != a {
245 t.Errorf("Expect access key ID to be reflected role ARN")
246 }
247 if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
248 t.Errorf("Expect secret access key to match")
249 }
250 if e, a := "assumedSessionToken", creds.SessionToken; e != a {
251 t.Errorf("Expect session token to match")
252 }
253 }
254
255 func BenchmarkAssumeRoleProvider(b *testing.B) {
256 stub := &mockAssumeRole{}
257 p := stscreds.NewAssumeRoleProvider(stub, roleARN)
258
259 b.ResetTimer()
260 for i := 0; i < b.N; i++ {
261 if _, err := p.Retrieve(context.Background()); err != nil {
262 b.Fatal(err)
263 }
264 }
265 }
266
View as plain text