...

Source file src/github.com/aws/aws-sdk-go-v2/feature/ec2/imds/request_middleware_test.go

Documentation: github.com/aws/aws-sdk-go-v2/feature/ec2/imds

     1  package imds
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/aws/aws-sdk-go-v2/aws"
    18  
    19  	"github.com/aws/aws-sdk-go-v2/internal/awstesting"
    20  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    21  	"github.com/aws/smithy-go/middleware"
    22  	smithyhttp "github.com/aws/smithy-go/transport/http"
    23  )
    24  
    25  func TestAddRequestMiddleware(t *testing.T) {
    26  	cases := map[string]struct {
    27  		AddMiddleware     func(*middleware.Stack, Options) error
    28  		ExpectInitialize  []string
    29  		ExpectSerialize   []string
    30  		ExpectBuild       []string
    31  		ExpectFinalize    []string
    32  		ExpectDeserialize []string
    33  	}{
    34  		"api request": {
    35  			AddMiddleware: func(stack *middleware.Stack, options Options) error {
    36  				return addAPIRequestMiddleware(stack, options,
    37  					"TestRequest",
    38  					func(interface{}) (string, error) {
    39  						return "/mockPath", nil
    40  					},
    41  					func(*smithyhttp.Response) (interface{}, error) {
    42  						return struct{}{}, nil
    43  					},
    44  				)
    45  			},
    46  			ExpectInitialize: []string{
    47  				(*operationTimeout)(nil).ID(),
    48  				"SetLogger",
    49  			},
    50  			ExpectSerialize: []string{
    51  				"ResolveEndpoint",
    52  				"OperationSerializer",
    53  			},
    54  			ExpectBuild: []string{
    55  				"UserAgent",
    56  			},
    57  			ExpectFinalize: []string{
    58  				"ResolveAuthScheme",
    59  				"GetIdentity",
    60  				"ResolveEndpointV2",
    61  				"Retry",
    62  				"APITokenProvider",
    63  				"RetryMetricsHeader",
    64  				"Signing",
    65  			},
    66  			ExpectDeserialize: []string{
    67  				"APITokenProvider",
    68  				"OperationDeserializer",
    69  				"RequestResponseLogger",
    70  			},
    71  		},
    72  
    73  		"base request": {
    74  			AddMiddleware: func(stack *middleware.Stack, options Options) error {
    75  				return addRequestMiddleware(stack, options, "POST", "TestRequest",
    76  					func(interface{}) (string, error) {
    77  						return "/mockPath", nil
    78  					},
    79  					func(*smithyhttp.Response) (interface{}, error) {
    80  						return struct{}{}, nil
    81  					},
    82  				)
    83  			},
    84  			ExpectInitialize: []string{
    85  				(*operationTimeout)(nil).ID(),
    86  				"SetLogger",
    87  			},
    88  			ExpectSerialize: []string{
    89  				"ResolveEndpoint",
    90  				"OperationSerializer",
    91  			},
    92  			ExpectBuild: []string{
    93  				"UserAgent",
    94  			},
    95  			ExpectFinalize: []string{
    96  				"ResolveAuthScheme",
    97  				"GetIdentity",
    98  				"ResolveEndpointV2",
    99  				"Retry",
   100  				"RetryMetricsHeader",
   101  				"Signing",
   102  			},
   103  			ExpectDeserialize: []string{
   104  				"OperationDeserializer",
   105  				"RequestResponseLogger",
   106  			},
   107  		},
   108  	}
   109  
   110  	for name, c := range cases {
   111  		t.Run(name, func(t *testing.T) {
   112  			client := New(Options{})
   113  
   114  			stack := middleware.NewStack("mockOp", smithyhttp.NewStackRequest)
   115  
   116  			if err := c.AddMiddleware(stack, client.options); err != nil {
   117  				t.Fatalf("expect no error adding middleware, got %v", err)
   118  			}
   119  
   120  			if diff := cmpDiff(c.ExpectInitialize, stack.Initialize.List()); len(diff) != 0 {
   121  				t.Errorf("expect initialize middleware\n%s", diff)
   122  			}
   123  
   124  			if diff := cmpDiff(c.ExpectSerialize, stack.Serialize.List()); len(diff) != 0 {
   125  				t.Errorf("expect serialize middleware\n%s", diff)
   126  			}
   127  
   128  			if diff := cmpDiff(c.ExpectBuild, stack.Build.List()); len(diff) != 0 {
   129  				t.Errorf("expect build middleware\n%s", diff)
   130  			}
   131  
   132  			if diff := cmpDiff(c.ExpectFinalize, stack.Finalize.List()); len(diff) != 0 {
   133  				t.Errorf("expect finalize middleware\n%s", diff)
   134  			}
   135  
   136  			if diff := cmpDiff(c.ExpectDeserialize, stack.Deserialize.List()); len(diff) != 0 {
   137  				t.Errorf("expect deserialize middleware\n%s", diff)
   138  			}
   139  		})
   140  	}
   141  }
   142  
   143  func TestOperationTimeoutMiddleware(t *testing.T) {
   144  	m := &operationTimeout{
   145  		DefaultTimeout: time.Nanosecond,
   146  	}
   147  
   148  	_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
   149  		middleware.InitializeHandlerFunc(func(
   150  			ctx context.Context, input middleware.InitializeInput,
   151  		) (
   152  			out middleware.InitializeOutput, metadata middleware.Metadata, err error,
   153  		) {
   154  			if _, ok := ctx.Deadline(); !ok {
   155  				return out, metadata, fmt.Errorf("expect context deadline to be set")
   156  			}
   157  
   158  			if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
   159  				return out, metadata, err
   160  			}
   161  
   162  			return out, metadata, nil
   163  		}))
   164  	if err == nil {
   165  		t.Fatalf("expect error got none")
   166  	}
   167  
   168  	if e, a := "deadline exceeded", err.Error(); !strings.Contains(a, e) {
   169  		t.Errorf("expect %q error in %q", e, a)
   170  	}
   171  }
   172  
   173  func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) {
   174  	m := &operationTimeout{}
   175  
   176  	_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
   177  		middleware.InitializeHandlerFunc(func(
   178  			ctx context.Context, input middleware.InitializeInput,
   179  		) (
   180  			out middleware.InitializeOutput, metadata middleware.Metadata, err error,
   181  		) {
   182  			if t, ok := ctx.Deadline(); ok {
   183  				return out, metadata, fmt.Errorf("expect no context deadline, got %v", t)
   184  			}
   185  
   186  			return out, metadata, nil
   187  		}))
   188  	if err != nil {
   189  		t.Fatalf("expect no error, got %v", err)
   190  	}
   191  }
   192  
   193  func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) {
   194  	m := &operationTimeout{
   195  		DefaultTimeout: time.Nanosecond,
   196  	}
   197  
   198  	expectDeadline := time.Now().Add(time.Hour)
   199  	ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline)
   200  	defer cancelFn()
   201  
   202  	_, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{},
   203  		middleware.InitializeHandlerFunc(func(
   204  			ctx context.Context, input middleware.InitializeInput,
   205  		) (
   206  			out middleware.InitializeOutput, metadata middleware.Metadata, err error,
   207  		) {
   208  			t, ok := ctx.Deadline()
   209  			if !ok {
   210  				return out, metadata, fmt.Errorf("expect context deadline to be set")
   211  			}
   212  			if e, a := expectDeadline, t; !e.Equal(a) {
   213  				return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a)
   214  			}
   215  
   216  			return out, metadata, nil
   217  		}))
   218  	if err != nil {
   219  		t.Fatalf("expect no error, got %v", err)
   220  	}
   221  }
   222  
   223  func TestOperationTimeoutMiddleware_Disabled(t *testing.T) {
   224  	m := &operationTimeout{
   225  		Disabled:       true,
   226  		DefaultTimeout: time.Nanosecond,
   227  	}
   228  
   229  	_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
   230  		middleware.InitializeHandlerFunc(func(
   231  			ctx context.Context, input middleware.InitializeInput,
   232  		) (
   233  			out middleware.InitializeOutput, metadata middleware.Metadata, err error,
   234  		) {
   235  			if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
   236  				return out, metadata, err
   237  			}
   238  
   239  			return out, metadata, nil
   240  		}))
   241  	if err != nil {
   242  		t.Fatalf("expect no error, got %v", err)
   243  	}
   244  }
   245  
   246  // Ensure that the response body is read in the deserialize middleware,
   247  // ensuring that the timeoutOperation middleware won't race canceling the
   248  // context with the upstream reading the response body.
   249  //   - https://github.com/aws/aws-sdk-go-v2/issues/1253
   250  func TestDeserailizeResponse_cacheBody(t *testing.T) {
   251  	type Output struct {
   252  		Content io.ReadCloser
   253  	}
   254  	m := &deserializeResponse{
   255  		GetOutput: func(resp *smithyhttp.Response) (interface{}, error) {
   256  			return &Output{
   257  				Content: resp.Body,
   258  			}, nil
   259  		},
   260  	}
   261  
   262  	expectBody := "hello world!"
   263  	originalBody := &bytesReader{
   264  		reader: strings.NewReader(expectBody),
   265  	}
   266  	if originalBody.closed {
   267  		t.Fatalf("expect original body not to be closed yet")
   268  	}
   269  
   270  	out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{},
   271  		middleware.DeserializeHandlerFunc(func(
   272  			ctx context.Context, input middleware.DeserializeInput,
   273  		) (
   274  			out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
   275  		) {
   276  			out.RawResponse = &smithyhttp.Response{
   277  				Response: &http.Response{
   278  					StatusCode:    200,
   279  					Status:        "200 OK",
   280  					Header:        http.Header{},
   281  					ContentLength: int64(originalBody.Len()),
   282  					Body:          originalBody,
   283  				},
   284  			}
   285  			return out, metadata, nil
   286  		}))
   287  	if err != nil {
   288  		t.Fatalf("expect no error, got %v", err)
   289  	}
   290  
   291  	if !originalBody.closed {
   292  		t.Errorf("expect original body to be closed, was not")
   293  	}
   294  
   295  	result, ok := out.Result.(*Output)
   296  	if !ok {
   297  		t.Fatalf("expect result to be Output, got %T, %v", result, result)
   298  	}
   299  
   300  	actualBody, err := ioutil.ReadAll(result.Content)
   301  	if err != nil {
   302  		t.Fatalf("expect no error, got %v", err)
   303  	}
   304  	if e, a := expectBody, string(actualBody); e != a {
   305  		t.Errorf("expect %v body, got %v", e, a)
   306  	}
   307  	if err := result.Content.Close(); err != nil {
   308  		t.Fatalf("expect no error, got %v", err)
   309  	}
   310  }
   311  
   312  type bytesReader struct {
   313  	reader interface {
   314  		io.Reader
   315  		Len() int
   316  	}
   317  	closed bool
   318  }
   319  
   320  func (r *bytesReader) Len() int {
   321  	return r.reader.Len()
   322  }
   323  func (r *bytesReader) Close() error {
   324  	r.closed = true
   325  	return nil
   326  }
   327  func (r *bytesReader) Read(p []byte) (int, error) {
   328  	if r.closed {
   329  		return 0, io.EOF
   330  	}
   331  	return r.reader.Read(p)
   332  }
   333  
   334  type successAPIResponseHandler struct {
   335  	t      *testing.T
   336  	path   string
   337  	method string
   338  
   339  	// response
   340  	statusCode int
   341  	header     http.Header
   342  	body       []byte
   343  }
   344  
   345  func (h *successAPIResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   346  	if e, a := h.path, r.URL.Path; e != a {
   347  		h.t.Errorf("expect %v path, got %v", e, a)
   348  	}
   349  	if e, a := h.method, r.Method; e != a {
   350  		h.t.Errorf("expect %v method, got %v", e, a)
   351  	}
   352  
   353  	for k, vs := range h.header {
   354  		for _, v := range vs {
   355  			w.Header().Add(k, v)
   356  		}
   357  	}
   358  
   359  	if h.statusCode != 0 {
   360  		w.WriteHeader(h.statusCode)
   361  	}
   362  	w.Write(h.body)
   363  }
   364  
   365  func TestRequestGetToken(t *testing.T) {
   366  	cases := map[string]struct {
   367  		GetHandler     func(*testing.T) http.Handler
   368  		APICallCount   int
   369  		ExpectTrace    []string
   370  		ExpectContent  []byte
   371  		ExpectErr      string
   372  		EnableFallback aws.Ternary
   373  	}{
   374  		"secure": {
   375  			ExpectTrace: []string{
   376  				getTokenPath,
   377  				"/latest/foo",
   378  				"/latest/foo",
   379  			},
   380  			APICallCount: 2,
   381  			GetHandler: func(t *testing.T) http.Handler {
   382  				return newTestServeMux(t,
   383  					newSecureAPIHandler(t,
   384  						[]string{"tokenA"},
   385  						5*time.Minute,
   386  						&successAPIResponseHandler{t: t,
   387  							path:   "/latest/foo",
   388  							method: "GET",
   389  							body:   []byte("hello"),
   390  						},
   391  					))
   392  			},
   393  			ExpectContent: []byte("hello"),
   394  		},
   395  
   396  		"secure multi token": {
   397  			ExpectTrace: []string{
   398  				getTokenPath,
   399  				"/latest/foo",
   400  				getTokenPath,
   401  				"/latest/foo",
   402  				getTokenPath,
   403  				"/latest/foo",
   404  				getTokenPath,
   405  				"/latest/foo",
   406  			},
   407  			APICallCount: 4,
   408  			GetHandler: func(t *testing.T) http.Handler {
   409  				return newTestServeMux(t,
   410  					newSecureAPIHandler(t,
   411  						[]string{"tokenA", "tokenB", "tokenC"},
   412  						1,
   413  						http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   414  							h := &successAPIResponseHandler{t: t,
   415  								path:   "/latest/foo",
   416  								method: "GET",
   417  								body:   []byte("hello"),
   418  							}
   419  
   420  							time.Sleep(100 * time.Millisecond)
   421  							h.ServeHTTP(w, r)
   422  						}),
   423  					))
   424  			},
   425  			ExpectContent: []byte("hello"),
   426  		},
   427  
   428  		// disables API token, fallback to insecure API calls.
   429  		"insecure 405": {
   430  			ExpectTrace: []string{
   431  				getTokenPath,
   432  				"/latest/foo",
   433  				"/latest/foo",
   434  			},
   435  			APICallCount: 2,
   436  			GetHandler: func(t *testing.T) http.Handler {
   437  				return newTestServeMux(t,
   438  					newInsecureAPIHandler(t,
   439  						405,
   440  						&successAPIResponseHandler{t: t,
   441  							path:   "/latest/foo",
   442  							method: "GET",
   443  							body:   []byte("hello"),
   444  						},
   445  					))
   446  			},
   447  			ExpectContent: []byte("hello"),
   448  		},
   449  
   450  		"insecure 404": {
   451  			ExpectTrace: []string{
   452  				getTokenPath,
   453  				"/latest/foo",
   454  				"/latest/foo",
   455  			},
   456  			APICallCount: 2,
   457  			GetHandler: func(t *testing.T) http.Handler {
   458  				return newTestServeMux(t,
   459  					newInsecureAPIHandler(t,
   460  						404,
   461  						&successAPIResponseHandler{t: t,
   462  							path:   "/latest/foo",
   463  							method: "GET",
   464  							body:   []byte("hello"),
   465  						},
   466  					))
   467  			},
   468  			ExpectContent: []byte("hello"),
   469  		},
   470  
   471  		"insecure 403": {
   472  			ExpectTrace: []string{
   473  				getTokenPath,
   474  				"/latest/foo",
   475  				"/latest/foo",
   476  			},
   477  			APICallCount: 2,
   478  			GetHandler: func(t *testing.T) http.Handler {
   479  				return newTestServeMux(t,
   480  					newInsecureAPIHandler(t,
   481  						403,
   482  						&successAPIResponseHandler{t: t,
   483  							path:   "/latest/foo",
   484  							method: "GET",
   485  							body:   []byte("hello"),
   486  						},
   487  					))
   488  			},
   489  			ExpectContent: []byte("hello"),
   490  		},
   491  
   492  		// Token disabled and becomes re-enabled
   493  		"unauthorized 401 re-enable": {
   494  			ExpectTrace: []string{
   495  				getTokenPath,
   496  				"/latest/foo",
   497  				getTokenPath,
   498  				"/latest/foo",
   499  				"/latest/foo",
   500  			},
   501  			APICallCount: 2,
   502  			GetHandler: func(t *testing.T) http.Handler {
   503  				return newTestServeMux(t,
   504  					newUnauthorizedAPIHandler(t,
   505  						newSecureAPIHandler(t,
   506  							[]string{"tokenA"},
   507  							5*time.Minute,
   508  							&successAPIResponseHandler{t: t,
   509  								path:   "/latest/foo",
   510  								method: "GET",
   511  								body:   []byte("hello"),
   512  							},
   513  						)))
   514  			},
   515  			ExpectContent: []byte("hello"),
   516  		},
   517  
   518  		// Token and API call both fail
   519  		"bad request 400": {
   520  			ExpectTrace: []string{
   521  				getTokenPath,
   522  			},
   523  			APICallCount: 1,
   524  			GetHandler: func(t *testing.T) http.Handler {
   525  				return newTestServeMux(t,
   526  					newInsecureAPIHandler(t,
   527  						400,
   528  						http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   529  							t.Errorf("expected no call to API handler")
   530  							http.Error(w, "", 400)
   531  						}),
   532  					))
   533  			},
   534  			ExpectErr: "failed to get API token",
   535  		},
   536  
   537  		// retryable token error with fallback enabled (default)
   538  		"token failure fallback enabled": {
   539  			ExpectTrace: []string{
   540  				getTokenPath,
   541  				getTokenPath,
   542  				getTokenPath,
   543  				"/latest/foo",
   544  			},
   545  			APICallCount: 1,
   546  			GetHandler: func(t *testing.T) http.Handler {
   547  				return newTestServeMux(t,
   548  					newInsecureAPIHandler(t,
   549  						500,
   550  						&successAPIResponseHandler{t: t,
   551  							path:   "/latest/foo",
   552  							method: "GET",
   553  							body:   []byte("hello"),
   554  						},
   555  					))
   556  			},
   557  			ExpectContent: []byte("hello"),
   558  		},
   559  		// retryable token error with fallback disabled
   560  		"token failure fallback disabled": {
   561  			ExpectTrace: []string{
   562  				getTokenPath,
   563  				getTokenPath,
   564  				getTokenPath,
   565  			},
   566  			APICallCount: 1,
   567  			GetHandler: func(t *testing.T) http.Handler {
   568  				return newTestServeMux(t,
   569  					newInsecureAPIHandler(t,
   570  						500,
   571  						http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   572  							t.Errorf("expected no call to API handler")
   573  							http.Error(w, "", 400)
   574  						}),
   575  					))
   576  			},
   577  			ExpectErr:      "failed to get API token",
   578  			EnableFallback: aws.BoolTernary(false),
   579  		},
   580  		"insecure 403 fallback disabled": {
   581  			ExpectTrace: []string{
   582  				getTokenPath,
   583  			},
   584  			APICallCount: 1,
   585  			GetHandler: func(t *testing.T) http.Handler {
   586  				return newTestServeMux(t,
   587  					newInsecureAPIHandler(t,
   588  						403,
   589  						http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   590  							t.Errorf("expected no call to API handler")
   591  							http.Error(w, "", 400)
   592  						}),
   593  					))
   594  			},
   595  			ExpectErr:      "failed to get API token",
   596  			EnableFallback: aws.BoolTernary(false),
   597  		},
   598  	}
   599  
   600  	type mockRequestOutput struct {
   601  		Content io.ReadCloser
   602  	}
   603  
   604  	for name, c := range cases {
   605  		t.Run(name, func(t *testing.T) {
   606  			envs := awstesting.StashEnv()
   607  			defer awstesting.PopEnv(envs)
   608  
   609  			trace := newRequestTrace()
   610  			server := httptest.NewServer(trace.WrapHandler(c.GetHandler(t)))
   611  			defer server.Close()
   612  
   613  			client := New(Options{
   614  				Endpoint:       server.URL,
   615  				EnableFallback: c.EnableFallback,
   616  			})
   617  
   618  			ctx := context.Background()
   619  			var result interface{}
   620  			var err error
   621  			for i := 0; i < c.APICallCount; i++ {
   622  				result, _, err = client.invokeOperation(ctx, "TestRequest", struct{}{}, nil,
   623  					func(stack *middleware.Stack, options Options) error {
   624  						return addAPIRequestMiddleware(stack,
   625  							client.options.Copy(),
   626  							"TestRequest",
   627  							func(interface{}) (string, error) {
   628  								return "/latest/foo", nil
   629  							},
   630  							func(resp *smithyhttp.Response) (interface{}, error) {
   631  								return &mockRequestOutput{
   632  									Content: resp.Body,
   633  								}, nil
   634  							},
   635  						)
   636  					},
   637  				)
   638  			}
   639  			if diff := cmpDiff(c.ExpectTrace, trace.requests); len(diff) != 0 {
   640  				t.Errorf("expect trace to match\n%s", diff)
   641  			}
   642  
   643  			if len(c.ExpectErr) != 0 {
   644  				if err == nil {
   645  					t.Fatalf("expect error, got none")
   646  				}
   647  				if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
   648  					t.Fatalf("expect error to contain %v, got %v", e, a)
   649  				}
   650  				return
   651  			}
   652  			if err != nil {
   653  				t.Fatalf("expect no error, got %v", err)
   654  			}
   655  
   656  			out, ok := result.(*mockRequestOutput)
   657  			if !ok {
   658  				t.Fatalf("expect output result, got %T", result)
   659  			}
   660  
   661  			content, err := ioutil.ReadAll(out.Content)
   662  			if err != nil {
   663  				t.Fatalf("expect to read result, got %v", err)
   664  			}
   665  
   666  			if e, a := c.ExpectContent, content; !bytes.Equal(e, a) {
   667  				t.Errorf("expect results to match\nexpect:\n%s\nactual:\n%s",
   668  					hex.Dump(e), hex.Dump(a))
   669  			}
   670  		})
   671  	}
   672  }
   673  
   674  func cmpDiff(e, a interface{}) string {
   675  	if !reflect.DeepEqual(e, a) {
   676  		return fmt.Sprintf("%v != %v", e, a)
   677  	}
   678  	return ""
   679  }
   680  

View as plain text