1 package ec2rolecreds
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "io/ioutil"
10 "reflect"
11 "strings"
12 "testing"
13 "time"
14
15 "github.com/aws/aws-sdk-go-v2/aws"
16 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
17 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
18 "github.com/aws/aws-sdk-go-v2/internal/sdk"
19 "github.com/aws/smithy-go"
20 "github.com/aws/smithy-go/logging"
21 "github.com/aws/smithy-go/middleware"
22 )
23
// credsRespTmpl is the security-credentials document that mockClient serves
// when failAssume is false. Its single %s verb is filled with
// mockClient.expireOn to produce the Expiration field.
const credsRespTmpl = `{
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId" : "accessKey",
"SecretAccessKey" : "secret",
"Token" : "token",
"Expiration" : "%s",
"LastUpdated" : "2009-11-23T00:00:00Z"
}`

// credsFailRespTmpl is the document that mockClient serves when failAssume
// is true. It carries an error Code and Message instead of credentials, and
// it contains no format verbs.
const credsFailRespTmpl = `{
"Code": "ErrorCode",
"Message": "ErrorMsg",
"LastUpdated": "2009-11-23T00:00:00Z"
}`
39
40 type mockClient struct {
41 t *testing.T
42 roleName string
43 failAssume bool
44 expireOn string
45 }
46
47 func (c mockClient) GetMetadata(
48 ctx context.Context, params *imds.GetMetadataInput, optFns ...func(*imds.Options),
49 ) (
50 *imds.GetMetadataOutput, error,
51 ) {
52 switch params.Path {
53 case iamSecurityCredsPath:
54 return &imds.GetMetadataOutput{
55 Content: ioutil.NopCloser(strings.NewReader(c.roleName)),
56 }, nil
57
58 case iamSecurityCredsPath + c.roleName:
59 var w strings.Builder
60 if c.failAssume {
61 fmt.Fprintf(&w, credsFailRespTmpl)
62 } else {
63 fmt.Fprintf(&w, credsRespTmpl, c.expireOn)
64 }
65 return &imds.GetMetadataOutput{
66 Content: ioutil.NopCloser(strings.NewReader(w.String())),
67 }, nil
68 default:
69 return nil, fmt.Errorf("unexpected path, %v", params.Path)
70 }
71 }
72
73 var (
74 _ aws.AdjustExpiresByCredentialsCacheStrategy = (*Provider)(nil)
75 _ aws.HandleFailRefreshCredentialsCacheStrategy = (*Provider)(nil)
76 )
77
78 func TestProvider(t *testing.T) {
79 orig := sdk.NowTime
80 defer func() { sdk.NowTime = orig }()
81
82 p := New(func(options *Options) {
83 options.Client = mockClient{
84 roleName: "RoleName",
85 failAssume: false,
86 expireOn: "2014-12-16T01:51:37Z",
87 }
88 })
89
90 creds, err := p.Retrieve(context.Background())
91 if err != nil {
92 t.Fatalf("expect no error, got %v", err)
93 }
94 if e, a := "accessKey", creds.AccessKeyID; e != a {
95 t.Errorf("Expect access key ID to match")
96 }
97 if e, a := "secret", creds.SecretAccessKey; e != a {
98 t.Errorf("Expect secret access key to match")
99 }
100 if e, a := "token", creds.SessionToken; e != a {
101 t.Errorf("Expect session token to match")
102 }
103
104 sdk.NowTime = func() time.Time {
105 return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
106 }
107
108 if creds.Expired() {
109 t.Errorf("Expect not expired")
110 }
111 }
112
113 func TestProvider_FailAssume(t *testing.T) {
114 p := New(func(options *Options) {
115 options.Client = mockClient{
116 roleName: "RoleName",
117 failAssume: true,
118 expireOn: "2014-12-16T01:51:37Z",
119 }
120 })
121
122 creds, err := p.Retrieve(context.Background())
123 if err == nil {
124 t.Fatalf("expect error, got none")
125 }
126
127 var apiErr smithy.APIError
128 if !errors.As(err, &apiErr) {
129 t.Fatalf("expect %T error, got %v", apiErr, err)
130 }
131 if e, a := "ErrorCode", apiErr.ErrorCode(); e != a {
132 t.Errorf("expect %v code, got %v", e, a)
133 }
134 if e, a := "ErrorMsg", apiErr.ErrorMessage(); e != a {
135 t.Errorf("expect %v message, got %v", e, a)
136 }
137
138 nestedErr := errors.Unwrap(apiErr)
139 if nestedErr != nil {
140 t.Fatalf("expect no nested error, got %v", err)
141 }
142
143 if e, a := "", creds.AccessKeyID; e != a {
144 t.Errorf("Expect access key ID to match")
145 }
146 if e, a := "", creds.SecretAccessKey; e != a {
147 t.Errorf("Expect secret access key to match")
148 }
149 if e, a := "", creds.SessionToken; e != a {
150 t.Errorf("Expect session token to match")
151 }
152 }
153
154 func TestProvider_IsExpired(t *testing.T) {
155 orig := sdk.NowTime
156 defer func() { sdk.NowTime = orig }()
157
158 p := New(func(options *Options) {
159 options.Client = mockClient{
160 roleName: "RoleName",
161 failAssume: false,
162 expireOn: "2014-12-16T01:51:37Z",
163 }
164 })
165
166 sdk.NowTime = func() time.Time {
167 return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
168 }
169
170 creds, err := p.Retrieve(context.Background())
171 if err != nil {
172 t.Fatalf("expect no error, got %v", err)
173 }
174 if creds.Expired() {
175 t.Errorf("expect not to be expired")
176 }
177
178 sdk.NowTime = func() time.Time {
179 return time.Date(2014, 12, 16, 1, 55, 37, 0, time.UTC)
180 }
181
182 if !creds.Expired() {
183 t.Errorf("expect to be expired")
184 }
185 }
186
// byteReader is an endless io.Reader. Every Read fills the entire
// destination with the single byte value b and reports success.
type byteReader byte

// Read fills p with b, returning len(p) and a nil error.
func (b byteReader) Read(p []byte) (int, error) {
	fill := byte(b)
	for i := range p {
		p[i] = fill
	}
	return len(p), nil
}
195
196 func TestProvider_HandleFailToRetrieve(t *testing.T) {
197 origTime := sdk.NowTime
198 defer func() { sdk.NowTime = origTime }()
199 sdk.NowTime = func() time.Time {
200 return time.Date(2014, 04, 04, 0, 1, 0, 0, time.UTC)
201 }
202
203 origRand := sdkrand.Reader
204 defer func() { sdkrand.Reader = origRand }()
205 sdkrand.Reader = byteReader(0)
206
207 cases := map[string]struct {
208 creds aws.Credentials
209 err error
210 randReader io.Reader
211 expectCreds aws.Credentials
212 expectErr string
213 expectLogged string
214 }{
215 "expired low": {
216 randReader: byteReader(0),
217 creds: aws.Credentials{
218 CanExpire: true,
219 Expires: sdk.NowTime().Add(-5 * time.Minute),
220 },
221 err: fmt.Errorf("some error"),
222 expectCreds: aws.Credentials{
223 CanExpire: true,
224 Expires: sdk.NowTime().Add(5 * time.Minute),
225 },
226 expectLogged: fmt.Sprintf("again in 5 minutes"),
227 },
228 "expired high": {
229 randReader: byteReader(0xFF),
230 creds: aws.Credentials{
231 CanExpire: true,
232 Expires: sdk.NowTime().Add(-5 * time.Minute),
233 },
234 err: fmt.Errorf("some error"),
235 expectCreds: aws.Credentials{
236 CanExpire: true,
237 Expires: sdk.NowTime().Add(14*time.Minute + 59*time.Second),
238 },
239 expectLogged: fmt.Sprintf("again in 14 minutes"),
240 },
241 "not expired": {
242 randReader: byteReader(0xFF),
243 creds: aws.Credentials{
244 CanExpire: true,
245 Expires: sdk.NowTime().Add(10 * time.Minute),
246 },
247 err: fmt.Errorf("some error"),
248 expectCreds: aws.Credentials{
249 CanExpire: true,
250 Expires: sdk.NowTime().Add(10 * time.Minute),
251 },
252 },
253 "cannot expire": {
254 randReader: byteReader(0xFF),
255 creds: aws.Credentials{
256 CanExpire: false,
257 },
258 err: fmt.Errorf("some error"),
259 expectErr: "some error",
260 },
261 }
262
263 for name, c := range cases {
264 t.Run(name, func(t *testing.T) {
265 sdkrand.Reader = c.randReader
266 if sdkrand.Reader == nil {
267 sdkrand.Reader = byteReader(0)
268 }
269
270 var logBuf bytes.Buffer
271 logger := logging.LoggerFunc(func(class logging.Classification, format string, args ...interface{}) {
272 fmt.Fprintf(&logBuf, string(class)+" "+format, args...)
273 })
274 ctx := middleware.SetLogger(context.Background(), logger)
275
276 p := New()
277 creds, err := p.HandleFailToRefresh(ctx, c.creds, c.err)
278 if err == nil && len(c.expectErr) != 0 {
279 t.Fatalf("expect error %v, got none", c.expectErr)
280 }
281 if err != nil && len(c.expectErr) == 0 {
282 t.Fatalf("expect no error, got %v", err)
283 }
284 if err != nil && !strings.Contains(err.Error(), c.expectErr) {
285 t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
286 }
287 if c.expectErr != "" {
288 return
289 }
290
291 if len(c.expectLogged) != 0 && logBuf.Len() == 0 {
292 t.Errorf("expect %v logged, got none", c.expectLogged)
293 }
294 if e, a := c.expectLogged, logBuf.String(); !strings.Contains(a, e) {
295 t.Errorf("expect %v to be logged in %v", e, a)
296 }
297
298
299 creds.Expires = creds.Expires.Truncate(time.Second)
300
301 if diff := cmpDiff(c.expectCreds, creds); diff != "" {
302 t.Errorf("expect creds match\n%s", diff)
303 }
304 })
305 }
306 }
307
308 func TestProvider_AdjustExpiresBy(t *testing.T) {
309 origTime := sdk.NowTime
310 defer func() { sdk.NowTime = origTime }()
311 sdk.NowTime = func() time.Time {
312 return time.Date(2014, 04, 04, 0, 1, 0, 0, time.UTC)
313 }
314
315 cases := map[string]struct {
316 creds aws.Credentials
317 dur time.Duration
318 expectCreds aws.Credentials
319 }{
320 "modify expires": {
321 creds: aws.Credentials{
322 CanExpire: true,
323 Expires: sdk.NowTime().Add(1 * time.Hour),
324 },
325 dur: -5 * time.Minute,
326 expectCreds: aws.Credentials{
327 CanExpire: true,
328 Expires: sdk.NowTime().Add(55 * time.Minute),
329 },
330 },
331 "expiry too soon": {
332 creds: aws.Credentials{
333 CanExpire: true,
334 Expires: sdk.NowTime().Add(14*time.Minute + 59*time.Second),
335 },
336 dur: -5 * time.Minute,
337 expectCreds: aws.Credentials{
338 CanExpire: true,
339 Expires: sdk.NowTime().Add(14*time.Minute + 59*time.Second),
340 },
341 },
342 "cannot expire": {
343 creds: aws.Credentials{
344 CanExpire: false,
345 },
346 dur: 10 * time.Minute,
347 expectCreds: aws.Credentials{
348 CanExpire: false,
349 },
350 },
351 }
352
353 for name, c := range cases {
354 t.Run(name, func(t *testing.T) {
355 p := New()
356 creds, err := p.AdjustExpiresBy(c.creds, c.dur)
357
358 if err != nil {
359 t.Fatalf("expect no error, got %v", err)
360 }
361
362 if diff := cmpDiff(c.expectCreds, creds); diff != "" {
363 t.Errorf("expect creds match\n%s", diff)
364 }
365 })
366 }
367 }
368
// cmpDiff compares e (expected) and a (actual) with reflect.DeepEqual.
// It returns "" when they are equal. Otherwise it returns a two-line
// message that labels each side and prints struct field names (%+v), so
// mismatched fields are easy to spot.
func cmpDiff(e, a interface{}) string {
	if reflect.DeepEqual(e, a) {
		return ""
	}
	return fmt.Sprintf("expect: %+v\nactual: %+v", e, a)
}
375
View as plain text