1 package imds
2
3 import (
4 "bytes"
5 "context"
6 "encoding/hex"
7 "fmt"
8 "io"
9 "io/ioutil"
10 "net/http"
11 "net/http/httptest"
12 "reflect"
13 "strings"
14 "testing"
15 "time"
16
17 "github.com/aws/aws-sdk-go-v2/aws"
18
19 "github.com/aws/aws-sdk-go-v2/internal/awstesting"
20 "github.com/aws/aws-sdk-go-v2/internal/sdk"
21 "github.com/aws/smithy-go/middleware"
22 smithyhttp "github.com/aws/smithy-go/transport/http"
23 )
24
25 func TestAddRequestMiddleware(t *testing.T) {
26 cases := map[string]struct {
27 AddMiddleware func(*middleware.Stack, Options) error
28 ExpectInitialize []string
29 ExpectSerialize []string
30 ExpectBuild []string
31 ExpectFinalize []string
32 ExpectDeserialize []string
33 }{
34 "api request": {
35 AddMiddleware: func(stack *middleware.Stack, options Options) error {
36 return addAPIRequestMiddleware(stack, options,
37 "TestRequest",
38 func(interface{}) (string, error) {
39 return "/mockPath", nil
40 },
41 func(*smithyhttp.Response) (interface{}, error) {
42 return struct{}{}, nil
43 },
44 )
45 },
46 ExpectInitialize: []string{
47 (*operationTimeout)(nil).ID(),
48 "SetLogger",
49 },
50 ExpectSerialize: []string{
51 "ResolveEndpoint",
52 "OperationSerializer",
53 },
54 ExpectBuild: []string{
55 "UserAgent",
56 },
57 ExpectFinalize: []string{
58 "ResolveAuthScheme",
59 "GetIdentity",
60 "ResolveEndpointV2",
61 "Retry",
62 "APITokenProvider",
63 "RetryMetricsHeader",
64 "Signing",
65 },
66 ExpectDeserialize: []string{
67 "APITokenProvider",
68 "OperationDeserializer",
69 "RequestResponseLogger",
70 },
71 },
72
73 "base request": {
74 AddMiddleware: func(stack *middleware.Stack, options Options) error {
75 return addRequestMiddleware(stack, options, "POST", "TestRequest",
76 func(interface{}) (string, error) {
77 return "/mockPath", nil
78 },
79 func(*smithyhttp.Response) (interface{}, error) {
80 return struct{}{}, nil
81 },
82 )
83 },
84 ExpectInitialize: []string{
85 (*operationTimeout)(nil).ID(),
86 "SetLogger",
87 },
88 ExpectSerialize: []string{
89 "ResolveEndpoint",
90 "OperationSerializer",
91 },
92 ExpectBuild: []string{
93 "UserAgent",
94 },
95 ExpectFinalize: []string{
96 "ResolveAuthScheme",
97 "GetIdentity",
98 "ResolveEndpointV2",
99 "Retry",
100 "RetryMetricsHeader",
101 "Signing",
102 },
103 ExpectDeserialize: []string{
104 "OperationDeserializer",
105 "RequestResponseLogger",
106 },
107 },
108 }
109
110 for name, c := range cases {
111 t.Run(name, func(t *testing.T) {
112 client := New(Options{})
113
114 stack := middleware.NewStack("mockOp", smithyhttp.NewStackRequest)
115
116 if err := c.AddMiddleware(stack, client.options); err != nil {
117 t.Fatalf("expect no error adding middleware, got %v", err)
118 }
119
120 if diff := cmpDiff(c.ExpectInitialize, stack.Initialize.List()); len(diff) != 0 {
121 t.Errorf("expect initialize middleware\n%s", diff)
122 }
123
124 if diff := cmpDiff(c.ExpectSerialize, stack.Serialize.List()); len(diff) != 0 {
125 t.Errorf("expect serialize middleware\n%s", diff)
126 }
127
128 if diff := cmpDiff(c.ExpectBuild, stack.Build.List()); len(diff) != 0 {
129 t.Errorf("expect build middleware\n%s", diff)
130 }
131
132 if diff := cmpDiff(c.ExpectFinalize, stack.Finalize.List()); len(diff) != 0 {
133 t.Errorf("expect finalize middleware\n%s", diff)
134 }
135
136 if diff := cmpDiff(c.ExpectDeserialize, stack.Deserialize.List()); len(diff) != 0 {
137 t.Errorf("expect deserialize middleware\n%s", diff)
138 }
139 })
140 }
141 }
142
143 func TestOperationTimeoutMiddleware(t *testing.T) {
144 m := &operationTimeout{
145 DefaultTimeout: time.Nanosecond,
146 }
147
148 _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
149 middleware.InitializeHandlerFunc(func(
150 ctx context.Context, input middleware.InitializeInput,
151 ) (
152 out middleware.InitializeOutput, metadata middleware.Metadata, err error,
153 ) {
154 if _, ok := ctx.Deadline(); !ok {
155 return out, metadata, fmt.Errorf("expect context deadline to be set")
156 }
157
158 if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
159 return out, metadata, err
160 }
161
162 return out, metadata, nil
163 }))
164 if err == nil {
165 t.Fatalf("expect error got none")
166 }
167
168 if e, a := "deadline exceeded", err.Error(); !strings.Contains(a, e) {
169 t.Errorf("expect %q error in %q", e, a)
170 }
171 }
172
173 func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) {
174 m := &operationTimeout{}
175
176 _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
177 middleware.InitializeHandlerFunc(func(
178 ctx context.Context, input middleware.InitializeInput,
179 ) (
180 out middleware.InitializeOutput, metadata middleware.Metadata, err error,
181 ) {
182 if t, ok := ctx.Deadline(); ok {
183 return out, metadata, fmt.Errorf("expect no context deadline, got %v", t)
184 }
185
186 return out, metadata, nil
187 }))
188 if err != nil {
189 t.Fatalf("expect no error, got %v", err)
190 }
191 }
192
193 func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) {
194 m := &operationTimeout{
195 DefaultTimeout: time.Nanosecond,
196 }
197
198 expectDeadline := time.Now().Add(time.Hour)
199 ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline)
200 defer cancelFn()
201
202 _, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{},
203 middleware.InitializeHandlerFunc(func(
204 ctx context.Context, input middleware.InitializeInput,
205 ) (
206 out middleware.InitializeOutput, metadata middleware.Metadata, err error,
207 ) {
208 t, ok := ctx.Deadline()
209 if !ok {
210 return out, metadata, fmt.Errorf("expect context deadline to be set")
211 }
212 if e, a := expectDeadline, t; !e.Equal(a) {
213 return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a)
214 }
215
216 return out, metadata, nil
217 }))
218 if err != nil {
219 t.Fatalf("expect no error, got %v", err)
220 }
221 }
222
223 func TestOperationTimeoutMiddleware_Disabled(t *testing.T) {
224 m := &operationTimeout{
225 Disabled: true,
226 DefaultTimeout: time.Nanosecond,
227 }
228
229 _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
230 middleware.InitializeHandlerFunc(func(
231 ctx context.Context, input middleware.InitializeInput,
232 ) (
233 out middleware.InitializeOutput, metadata middleware.Metadata, err error,
234 ) {
235 if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
236 return out, metadata, err
237 }
238
239 return out, metadata, nil
240 }))
241 if err != nil {
242 t.Fatalf("expect no error, got %v", err)
243 }
244 }
245
246
247
248
249
250 func TestDeserailizeResponse_cacheBody(t *testing.T) {
251 type Output struct {
252 Content io.ReadCloser
253 }
254 m := &deserializeResponse{
255 GetOutput: func(resp *smithyhttp.Response) (interface{}, error) {
256 return &Output{
257 Content: resp.Body,
258 }, nil
259 },
260 }
261
262 expectBody := "hello world!"
263 originalBody := &bytesReader{
264 reader: strings.NewReader(expectBody),
265 }
266 if originalBody.closed {
267 t.Fatalf("expect original body not to be closed yet")
268 }
269
270 out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{},
271 middleware.DeserializeHandlerFunc(func(
272 ctx context.Context, input middleware.DeserializeInput,
273 ) (
274 out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
275 ) {
276 out.RawResponse = &smithyhttp.Response{
277 Response: &http.Response{
278 StatusCode: 200,
279 Status: "200 OK",
280 Header: http.Header{},
281 ContentLength: int64(originalBody.Len()),
282 Body: originalBody,
283 },
284 }
285 return out, metadata, nil
286 }))
287 if err != nil {
288 t.Fatalf("expect no error, got %v", err)
289 }
290
291 if !originalBody.closed {
292 t.Errorf("expect original body to be closed, was not")
293 }
294
295 result, ok := out.Result.(*Output)
296 if !ok {
297 t.Fatalf("expect result to be Output, got %T, %v", result, result)
298 }
299
300 actualBody, err := ioutil.ReadAll(result.Content)
301 if err != nil {
302 t.Fatalf("expect no error, got %v", err)
303 }
304 if e, a := expectBody, string(actualBody); e != a {
305 t.Errorf("expect %v body, got %v", e, a)
306 }
307 if err := result.Content.Close(); err != nil {
308 t.Fatalf("expect no error, got %v", err)
309 }
310 }
311
// bytesReader is an io.ReadCloser test double. It wraps any reader that also
// reports Len and records whether Close has been called. After Close, every
// Read returns (0, io.EOF) without touching the wrapped reader.
type bytesReader struct {
	reader interface {
		io.Reader
		Len() int
	}
	closed bool
}

// Len delegates to the wrapped reader's Len. For a *strings.Reader, this is
// the number of bytes not yet read.
func (b *bytesReader) Len() int { return b.reader.Len() }

// Read forwards to the wrapped reader until Close has been called.
func (b *bytesReader) Read(p []byte) (int, error) {
	if !b.closed {
		return b.reader.Read(p)
	}
	return 0, io.EOF
}

// Close only marks the reader as closed; it never fails.
func (b *bytesReader) Close() error {
	b.closed = true
	return nil
}
333
// successAPIResponseHandler is an http.Handler test double. It asserts the
// incoming request's path and method, then replies with the configured
// headers, status code and body. Any mismatch or write failure is reported
// through t.
type successAPIResponseHandler struct {
	t      *testing.T
	path   string // expected request URL path
	method string // expected request method

	statusCode int         // response status; 0 leaves net/http's implicit 200 on first Write
	header     http.Header // headers added to the response
	body       []byte      // response body
}

// ServeHTTP implements http.Handler.
func (h *successAPIResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if e, a := h.path, r.URL.Path; e != a {
		h.t.Errorf("expect %v path, got %v", e, a)
	}
	if e, a := h.method, r.Method; e != a {
		h.t.Errorf("expect %v method, got %v", e, a)
	}

	// Headers must be set before WriteHeader/Write, which send them.
	for k, vs := range h.header {
		for _, v := range vs {
			w.Header().Add(k, v)
		}
	}

	if h.statusCode != 0 {
		w.WriteHeader(h.statusCode)
	}
	// Surface write failures instead of silently dropping them.
	if _, err := w.Write(h.body); err != nil {
		h.t.Errorf("expect no error writing response body, got %v", err)
	}
}
364
365 func TestRequestGetToken(t *testing.T) {
366 cases := map[string]struct {
367 GetHandler func(*testing.T) http.Handler
368 APICallCount int
369 ExpectTrace []string
370 ExpectContent []byte
371 ExpectErr string
372 EnableFallback aws.Ternary
373 }{
374 "secure": {
375 ExpectTrace: []string{
376 getTokenPath,
377 "/latest/foo",
378 "/latest/foo",
379 },
380 APICallCount: 2,
381 GetHandler: func(t *testing.T) http.Handler {
382 return newTestServeMux(t,
383 newSecureAPIHandler(t,
384 []string{"tokenA"},
385 5*time.Minute,
386 &successAPIResponseHandler{t: t,
387 path: "/latest/foo",
388 method: "GET",
389 body: []byte("hello"),
390 },
391 ))
392 },
393 ExpectContent: []byte("hello"),
394 },
395
396 "secure multi token": {
397 ExpectTrace: []string{
398 getTokenPath,
399 "/latest/foo",
400 getTokenPath,
401 "/latest/foo",
402 getTokenPath,
403 "/latest/foo",
404 getTokenPath,
405 "/latest/foo",
406 },
407 APICallCount: 4,
408 GetHandler: func(t *testing.T) http.Handler {
409 return newTestServeMux(t,
410 newSecureAPIHandler(t,
411 []string{"tokenA", "tokenB", "tokenC"},
412 1,
413 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
414 h := &successAPIResponseHandler{t: t,
415 path: "/latest/foo",
416 method: "GET",
417 body: []byte("hello"),
418 }
419
420 time.Sleep(100 * time.Millisecond)
421 h.ServeHTTP(w, r)
422 }),
423 ))
424 },
425 ExpectContent: []byte("hello"),
426 },
427
428
429 "insecure 405": {
430 ExpectTrace: []string{
431 getTokenPath,
432 "/latest/foo",
433 "/latest/foo",
434 },
435 APICallCount: 2,
436 GetHandler: func(t *testing.T) http.Handler {
437 return newTestServeMux(t,
438 newInsecureAPIHandler(t,
439 405,
440 &successAPIResponseHandler{t: t,
441 path: "/latest/foo",
442 method: "GET",
443 body: []byte("hello"),
444 },
445 ))
446 },
447 ExpectContent: []byte("hello"),
448 },
449
450 "insecure 404": {
451 ExpectTrace: []string{
452 getTokenPath,
453 "/latest/foo",
454 "/latest/foo",
455 },
456 APICallCount: 2,
457 GetHandler: func(t *testing.T) http.Handler {
458 return newTestServeMux(t,
459 newInsecureAPIHandler(t,
460 404,
461 &successAPIResponseHandler{t: t,
462 path: "/latest/foo",
463 method: "GET",
464 body: []byte("hello"),
465 },
466 ))
467 },
468 ExpectContent: []byte("hello"),
469 },
470
471 "insecure 403": {
472 ExpectTrace: []string{
473 getTokenPath,
474 "/latest/foo",
475 "/latest/foo",
476 },
477 APICallCount: 2,
478 GetHandler: func(t *testing.T) http.Handler {
479 return newTestServeMux(t,
480 newInsecureAPIHandler(t,
481 403,
482 &successAPIResponseHandler{t: t,
483 path: "/latest/foo",
484 method: "GET",
485 body: []byte("hello"),
486 },
487 ))
488 },
489 ExpectContent: []byte("hello"),
490 },
491
492
493 "unauthorized 401 re-enable": {
494 ExpectTrace: []string{
495 getTokenPath,
496 "/latest/foo",
497 getTokenPath,
498 "/latest/foo",
499 "/latest/foo",
500 },
501 APICallCount: 2,
502 GetHandler: func(t *testing.T) http.Handler {
503 return newTestServeMux(t,
504 newUnauthorizedAPIHandler(t,
505 newSecureAPIHandler(t,
506 []string{"tokenA"},
507 5*time.Minute,
508 &successAPIResponseHandler{t: t,
509 path: "/latest/foo",
510 method: "GET",
511 body: []byte("hello"),
512 },
513 )))
514 },
515 ExpectContent: []byte("hello"),
516 },
517
518
519 "bad request 400": {
520 ExpectTrace: []string{
521 getTokenPath,
522 },
523 APICallCount: 1,
524 GetHandler: func(t *testing.T) http.Handler {
525 return newTestServeMux(t,
526 newInsecureAPIHandler(t,
527 400,
528 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
529 t.Errorf("expected no call to API handler")
530 http.Error(w, "", 400)
531 }),
532 ))
533 },
534 ExpectErr: "failed to get API token",
535 },
536
537
538 "token failure fallback enabled": {
539 ExpectTrace: []string{
540 getTokenPath,
541 getTokenPath,
542 getTokenPath,
543 "/latest/foo",
544 },
545 APICallCount: 1,
546 GetHandler: func(t *testing.T) http.Handler {
547 return newTestServeMux(t,
548 newInsecureAPIHandler(t,
549 500,
550 &successAPIResponseHandler{t: t,
551 path: "/latest/foo",
552 method: "GET",
553 body: []byte("hello"),
554 },
555 ))
556 },
557 ExpectContent: []byte("hello"),
558 },
559
560 "token failure fallback disabled": {
561 ExpectTrace: []string{
562 getTokenPath,
563 getTokenPath,
564 getTokenPath,
565 },
566 APICallCount: 1,
567 GetHandler: func(t *testing.T) http.Handler {
568 return newTestServeMux(t,
569 newInsecureAPIHandler(t,
570 500,
571 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
572 t.Errorf("expected no call to API handler")
573 http.Error(w, "", 400)
574 }),
575 ))
576 },
577 ExpectErr: "failed to get API token",
578 EnableFallback: aws.BoolTernary(false),
579 },
580 "insecure 403 fallback disabled": {
581 ExpectTrace: []string{
582 getTokenPath,
583 },
584 APICallCount: 1,
585 GetHandler: func(t *testing.T) http.Handler {
586 return newTestServeMux(t,
587 newInsecureAPIHandler(t,
588 403,
589 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
590 t.Errorf("expected no call to API handler")
591 http.Error(w, "", 400)
592 }),
593 ))
594 },
595 ExpectErr: "failed to get API token",
596 EnableFallback: aws.BoolTernary(false),
597 },
598 }
599
600 type mockRequestOutput struct {
601 Content io.ReadCloser
602 }
603
604 for name, c := range cases {
605 t.Run(name, func(t *testing.T) {
606 envs := awstesting.StashEnv()
607 defer awstesting.PopEnv(envs)
608
609 trace := newRequestTrace()
610 server := httptest.NewServer(trace.WrapHandler(c.GetHandler(t)))
611 defer server.Close()
612
613 client := New(Options{
614 Endpoint: server.URL,
615 EnableFallback: c.EnableFallback,
616 })
617
618 ctx := context.Background()
619 var result interface{}
620 var err error
621 for i := 0; i < c.APICallCount; i++ {
622 result, _, err = client.invokeOperation(ctx, "TestRequest", struct{}{}, nil,
623 func(stack *middleware.Stack, options Options) error {
624 return addAPIRequestMiddleware(stack,
625 client.options.Copy(),
626 "TestRequest",
627 func(interface{}) (string, error) {
628 return "/latest/foo", nil
629 },
630 func(resp *smithyhttp.Response) (interface{}, error) {
631 return &mockRequestOutput{
632 Content: resp.Body,
633 }, nil
634 },
635 )
636 },
637 )
638 }
639 if diff := cmpDiff(c.ExpectTrace, trace.requests); len(diff) != 0 {
640 t.Errorf("expect trace to match\n%s", diff)
641 }
642
643 if len(c.ExpectErr) != 0 {
644 if err == nil {
645 t.Fatalf("expect error, got none")
646 }
647 if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
648 t.Fatalf("expect error to contain %v, got %v", e, a)
649 }
650 return
651 }
652 if err != nil {
653 t.Fatalf("expect no error, got %v", err)
654 }
655
656 out, ok := result.(*mockRequestOutput)
657 if !ok {
658 t.Fatalf("expect output result, got %T", result)
659 }
660
661 content, err := ioutil.ReadAll(out.Content)
662 if err != nil {
663 t.Fatalf("expect to read result, got %v", err)
664 }
665
666 if e, a := c.ExpectContent, content; !bytes.Equal(e, a) {
667 t.Errorf("expect results to match\nexpect:\n%s\nactual:\n%s",
668 hex.Dump(e), hex.Dump(a))
669 }
670 })
671 }
672 }
673
// cmpDiff compares e (expected) and a (actual) with reflect.DeepEqual. It
// returns "" when they are equal; otherwise it returns a labeled two-line
// description of both values. %#v is used because values that differ under
// DeepEqual can print identically with %v: a nil and an empty slice both
// print as "[]", and different types with the same contents look alike.
func cmpDiff(e, a interface{}) string {
	if reflect.DeepEqual(e, a) {
		return ""
	}
	return fmt.Sprintf("expect: %#v\nactual: %#v", e, a)
}
680
View as plain text