1 package endpointcreds_test
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "io/ioutil"
10 "net/http"
11 "strings"
12 "testing"
13 "time"
14
15 "github.com/aws/aws-sdk-go-v2/credentials/endpointcreds"
16 "github.com/aws/aws-sdk-go-v2/internal/sdk"
17 "github.com/aws/smithy-go"
18 )
19
// mockClient turns an ordinary function into an HTTP client so that each
// test can describe its canned server behavior inline.
type mockClient func(*http.Request) (*http.Response, error)

// Do hands the request to the wrapped function and returns whatever it
// produced, unchanged.
func (fn mockClient) Do(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
25
26 func TestRetrieveRefreshableCredentials(t *testing.T) {
27 orig := sdk.NowTime
28 defer func() { sdk.NowTime = orig }()
29
30 p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
31 o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
32 expTime := time.Now().UTC().Add(1 * time.Hour).Format("2006-01-02T15:04:05Z")
33
34 return &http.Response{
35 StatusCode: 200,
36 Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{
37 "AccessKeyID": "AKID",
38 "SecretAccessKey": "SECRET",
39 "Token": "TOKEN",
40 "Expiration": "%s"
41 }`, expTime)))),
42 }, nil
43 })
44 })
45 creds, err := p.Retrieve(context.Background())
46
47 if err != nil {
48 t.Fatalf("expect no error, got %v", err)
49 }
50
51 if e, a := "AKID", creds.AccessKeyID; e != a {
52 t.Errorf("expect %v, got %v", e, a)
53 }
54 if e, a := "SECRET", creds.SecretAccessKey; e != a {
55 t.Errorf("expect %v, got %v", e, a)
56 }
57 if e, a := "TOKEN", creds.SessionToken; e != a {
58 t.Errorf("expect %v, got %v", e, a)
59 }
60 if creds.Expired() {
61 t.Errorf("expect not expired")
62 }
63
64 sdk.NowTime = func() time.Time {
65 return time.Now().Add(2 * time.Hour)
66 }
67 if !creds.Expired() {
68 t.Errorf("expect to be expired")
69 }
70 }
71
72 func TestRetrieveStaticCredentials(t *testing.T) {
73 orig := sdk.NowTime
74 defer func() { sdk.NowTime = orig }()
75
76 p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
77 o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
78 return &http.Response{
79 StatusCode: 200,
80 Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
81 "AccessKeyID": "AKID",
82 "SecretAccessKey": "SECRET"
83 }`))),
84 }, nil
85 })
86 })
87 creds, err := p.Retrieve(context.Background())
88
89 if err != nil {
90 t.Fatalf("expect no error, got %v", err)
91 }
92
93 if e, a := "AKID", creds.AccessKeyID; e != a {
94 t.Errorf("expect %v, got %v", e, a)
95 }
96 if e, a := "SECRET", creds.SecretAccessKey; e != a {
97 t.Errorf("expect %v, got %v", e, a)
98 }
99 if v := creds.SessionToken; len(v) != 0 {
100 t.Errorf("expect empty, got %v", v)
101 }
102
103 sdk.NowTime = func() time.Time {
104 return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
105 }
106
107 if creds.Expired() {
108 t.Errorf("expect not to be expired")
109 }
110 }
111
112 func TestAuthTokenProvider(t *testing.T) {
113 cases := map[string]struct {
114 AuthToken string
115 AuthTokenProvider endpointcreds.AuthTokenProvider
116 ExpectAuthToken string
117 ExpectError bool
118 }{
119 "AuthToken": {
120 AuthToken: "Basic abc123",
121 ExpectAuthToken: "Basic abc123",
122 },
123 "AuthFileToken": {
124 AuthToken: "Basic abc123",
125 AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
126 return "Hello %20world", nil
127 }),
128 ExpectAuthToken: "Hello %20world",
129 },
130 "RetrieveFileTokenError": {
131 AuthToken: "Basic abc123",
132 AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
133 return "", fmt.Errorf("test error")
134 }),
135 ExpectAuthToken: "Hello %20world",
136 ExpectError: true,
137 },
138 }
139
140 for name, c := range cases {
141 t.Run(name, func(t *testing.T) {
142 orig := sdk.NowTime
143 defer func() { sdk.NowTime = orig }()
144
145 var actualToken string
146 p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
147 o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
148 actualToken = r.Header["Authorization"][0]
149 return &http.Response{
150 StatusCode: 200,
151 Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
152 "AccessKeyID": "AKID",
153 "SecretAccessKey": "SECRET"
154 }`))),
155 }, nil
156 })
157 o.AuthorizationToken = c.AuthToken
158 o.AuthorizationTokenProvider = c.AuthTokenProvider
159 })
160 creds, err := p.Retrieve(context.Background())
161
162 if err != nil && !c.ExpectError {
163 t.Errorf("expect no error, got %v", err)
164 } else if err == nil && c.ExpectError {
165 t.Errorf("expect error, got nil")
166 }
167
168 if c.ExpectError {
169 return
170 }
171
172 if e, a := "AKID", creds.AccessKeyID; e != a {
173 t.Errorf("expect %v, got %v", e, a)
174 }
175 if e, a := "SECRET", creds.SecretAccessKey; e != a {
176 t.Errorf("expect %v, got %v", e, a)
177 }
178 if v := creds.SessionToken; len(v) != 0 {
179 t.Errorf("expect empty, got %v", v)
180 }
181 if e, a := c.ExpectAuthToken, actualToken; e != a {
182 t.Errorf("Expect %v, got %v", e, a)
183 }
184
185 sdk.NowTime = func() time.Time {
186 return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
187 }
188
189 if creds.Expired() {
190 t.Errorf("expect not to be expired")
191 }
192 })
193 }
194 }
195
196 func TestFailedRetrieveCredentials(t *testing.T) {
197 p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
198 o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
199 return &http.Response{
200 StatusCode: 400,
201 Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
202 "code": "Error",
203 "message": "Message"
204 }`))),
205 Header: http.Header{
206 "Content-Type": {"application/json"},
207 },
208 }, nil
209 })
210 })
211 creds, err := p.Retrieve(context.Background())
212
213 if err == nil {
214 t.Fatalf("expect error, got none")
215 }
216
217 if e, a := "failed to load credentials", err.Error(); !strings.Contains(a, e) {
218 t.Errorf("expect %v, got %v", e, a)
219 }
220
221 var apiError smithy.APIError
222 if !errors.As(err, &apiError) {
223 t.Fatalf("expect %T error, got %v", apiError, err)
224 }
225 if e, a := "Error", apiError.ErrorCode(); e != a {
226 t.Errorf("expect %v, got %v", e, a)
227 }
228 if e, a := "Message", apiError.ErrorMessage(); e != a {
229 t.Errorf("expect %v, got %v", e, a)
230 }
231
232 if v := creds.AccessKeyID; len(v) != 0 {
233 t.Errorf("expect empty, got %v", v)
234 }
235 if v := creds.SecretAccessKey; len(v) != 0 {
236 t.Errorf("expect empty, got %v", v)
237 }
238 if v := creds.SessionToken; len(v) != 0 {
239 t.Errorf("expect empty, got %v", v)
240 }
241 if creds.Expired() {
242 t.Errorf("expect empty creds not to be expired")
243 }
244 }
245
// mockClientN is an HTTP client stub that replays a fixed sequence of
// responses, one per call to Do, in order.
type mockClientN struct {
	responses []*http.Response
	// index counts the calls made to Do, including calls made after the
	// canned responses run out.
	index int
}

// Do returns the next canned response. Once the sequence is exhausted it
// returns an error instead of panicking with an index out of range, so an
// unexpected extra request appears as an ordinary test failure. The call is
// still counted in index so callers can assert the exact number of requests.
func (c *mockClientN) Do(r *http.Request) (*http.Response, error) {
	i := c.index
	c.index++
	if i >= len(c.responses) {
		return nil, fmt.Errorf("mockClientN: unexpected request #%d, only %d responses configured", i+1, len(c.responses))
	}
	return c.responses[i], nil
}
256
257 func TestRetryHTTPStatusCode(t *testing.T) {
258 expTime := time.Now().UTC().Add(1 * time.Hour).Format("2006-01-02T15:04:05Z")
259 credsResp := fmt.Sprintf(`{"AccessKeyID":"AKID","SecretAccessKey":"SECRET","Token":"TOKEN","Expiration":"%s"}`, expTime)
260
261 p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
262 o.HTTPClient = &mockClientN{
263 responses: []*http.Response{
264 {
265 StatusCode: 429,
266 Body: io.NopCloser(strings.NewReader("You have made too many requests.")),
267 Header: http.Header{
268 "Content-Type": {"text/plain"},
269 },
270 },
271 {
272 StatusCode: 500,
273 Body: io.NopCloser(strings.NewReader("Internal server error.")),
274 Header: http.Header{
275 "Content-Type": {"text/plain"},
276 },
277 },
278 {
279 StatusCode: 200,
280 Body: ioutil.NopCloser(strings.NewReader(credsResp)),
281 Header: http.Header{
282 "Content-Type": {"application/json"},
283 },
284 },
285 },
286 }
287 })
288
289 creds, err := p.Retrieve(context.Background())
290 if err != nil {
291 t.Fatalf("expect no error, got %v", err)
292 }
293
294 if e, a := "AKID", creds.AccessKeyID; e != a {
295 t.Errorf("expect %v, got %v", e, a)
296 }
297 if e, a := "SECRET", creds.SecretAccessKey; e != a {
298 t.Errorf("expect %v, got %v", e, a)
299 }
300 if e, a := "TOKEN", creds.SessionToken; e != a {
301 t.Errorf("expect %v, got %v", e, a)
302 }
303 if creds.Expired() {
304 t.Errorf("expect not expired")
305 }
306
307 sdk.NowTime = func() time.Time {
308 return time.Now().Add(2 * time.Hour)
309 }
310 if !creds.Expired() {
311 t.Errorf("expect to be expired")
312 }
313 }
314
View as plain text