...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/endpointcreds

     1  package endpointcreds_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go-v2/credentials/endpointcreds"
    16  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    17  	"github.com/aws/smithy-go"
    18  )
    19  
    20  type mockClient func(*http.Request) (*http.Response, error)
    21  
    22  func (m mockClient) Do(r *http.Request) (*http.Response, error) {
    23  	return m(r)
    24  }
    25  
    26  func TestRetrieveRefreshableCredentials(t *testing.T) {
    27  	orig := sdk.NowTime
    28  	defer func() { sdk.NowTime = orig }()
    29  
    30  	p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
    31  		o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
    32  			expTime := time.Now().UTC().Add(1 * time.Hour).Format("2006-01-02T15:04:05Z")
    33  
    34  			return &http.Response{
    35  				StatusCode: 200,
    36  				Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{
    37    "AccessKeyID": "AKID",
    38    "SecretAccessKey": "SECRET",
    39    "Token": "TOKEN",
    40    "Expiration": "%s"
    41  }`, expTime)))),
    42  			}, nil
    43  		})
    44  	})
    45  	creds, err := p.Retrieve(context.Background())
    46  
    47  	if err != nil {
    48  		t.Fatalf("expect no error, got %v", err)
    49  	}
    50  
    51  	if e, a := "AKID", creds.AccessKeyID; e != a {
    52  		t.Errorf("expect %v, got %v", e, a)
    53  	}
    54  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
    55  		t.Errorf("expect %v, got %v", e, a)
    56  	}
    57  	if e, a := "TOKEN", creds.SessionToken; e != a {
    58  		t.Errorf("expect %v, got %v", e, a)
    59  	}
    60  	if creds.Expired() {
    61  		t.Errorf("expect not expired")
    62  	}
    63  
    64  	sdk.NowTime = func() time.Time {
    65  		return time.Now().Add(2 * time.Hour)
    66  	}
    67  	if !creds.Expired() {
    68  		t.Errorf("expect to be expired")
    69  	}
    70  }
    71  
    72  func TestRetrieveStaticCredentials(t *testing.T) {
    73  	orig := sdk.NowTime
    74  	defer func() { sdk.NowTime = orig }()
    75  
    76  	p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
    77  		o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
    78  			return &http.Response{
    79  				StatusCode: 200,
    80  				Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
    81    "AccessKeyID": "AKID",
    82    "SecretAccessKey": "SECRET"
    83  }`))),
    84  			}, nil
    85  		})
    86  	})
    87  	creds, err := p.Retrieve(context.Background())
    88  
    89  	if err != nil {
    90  		t.Fatalf("expect no error, got %v", err)
    91  	}
    92  
    93  	if e, a := "AKID", creds.AccessKeyID; e != a {
    94  		t.Errorf("expect %v, got %v", e, a)
    95  	}
    96  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
    97  		t.Errorf("expect %v, got %v", e, a)
    98  	}
    99  	if v := creds.SessionToken; len(v) != 0 {
   100  		t.Errorf("expect empty, got %v", v)
   101  	}
   102  
   103  	sdk.NowTime = func() time.Time {
   104  		return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
   105  	}
   106  
   107  	if creds.Expired() {
   108  		t.Errorf("expect not to be expired")
   109  	}
   110  }
   111  
   112  func TestAuthTokenProvider(t *testing.T) {
   113  	cases := map[string]struct {
   114  		AuthToken         string
   115  		AuthTokenProvider endpointcreds.AuthTokenProvider
   116  		ExpectAuthToken   string
   117  		ExpectError       bool
   118  	}{
   119  		"AuthToken": {
   120  			AuthToken:       "Basic abc123",
   121  			ExpectAuthToken: "Basic abc123",
   122  		},
   123  		"AuthFileToken": {
   124  			AuthToken: "Basic abc123",
   125  			AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
   126  				return "Hello %20world", nil
   127  			}),
   128  			ExpectAuthToken: "Hello %20world",
   129  		},
   130  		"RetrieveFileTokenError": {
   131  			AuthToken: "Basic abc123",
   132  			AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
   133  				return "", fmt.Errorf("test error")
   134  			}),
   135  			ExpectAuthToken: "Hello %20world",
   136  			ExpectError:     true,
   137  		},
   138  	}
   139  
   140  	for name, c := range cases {
   141  		t.Run(name, func(t *testing.T) {
   142  			orig := sdk.NowTime
   143  			defer func() { sdk.NowTime = orig }()
   144  
   145  			var actualToken string
   146  			p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
   147  				o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
   148  					actualToken = r.Header["Authorization"][0]
   149  					return &http.Response{
   150  						StatusCode: 200,
   151  						Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
   152    "AccessKeyID": "AKID",
   153    "SecretAccessKey": "SECRET"
   154  }`))),
   155  					}, nil
   156  				})
   157  				o.AuthorizationToken = c.AuthToken
   158  				o.AuthorizationTokenProvider = c.AuthTokenProvider
   159  			})
   160  			creds, err := p.Retrieve(context.Background())
   161  
   162  			if err != nil && !c.ExpectError {
   163  				t.Errorf("expect no error, got %v", err)
   164  			} else if err == nil && c.ExpectError {
   165  				t.Errorf("expect error, got nil")
   166  			}
   167  
   168  			if c.ExpectError {
   169  				return
   170  			}
   171  
   172  			if e, a := "AKID", creds.AccessKeyID; e != a {
   173  				t.Errorf("expect %v, got %v", e, a)
   174  			}
   175  			if e, a := "SECRET", creds.SecretAccessKey; e != a {
   176  				t.Errorf("expect %v, got %v", e, a)
   177  			}
   178  			if v := creds.SessionToken; len(v) != 0 {
   179  				t.Errorf("expect empty, got %v", v)
   180  			}
   181  			if e, a := c.ExpectAuthToken, actualToken; e != a {
   182  				t.Errorf("Expect %v, got %v", e, a)
   183  			}
   184  
   185  			sdk.NowTime = func() time.Time {
   186  				return time.Date(3000, 12, 16, 1, 30, 37, 0, time.UTC)
   187  			}
   188  
   189  			if creds.Expired() {
   190  				t.Errorf("expect not to be expired")
   191  			}
   192  		})
   193  	}
   194  }
   195  
   196  func TestFailedRetrieveCredentials(t *testing.T) {
   197  	p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
   198  		o.HTTPClient = mockClient(func(r *http.Request) (*http.Response, error) {
   199  			return &http.Response{
   200  				StatusCode: 400,
   201  				Body: ioutil.NopCloser(bytes.NewReader([]byte(`{
   202    "code": "Error",
   203    "message": "Message"
   204  }`))),
   205  				Header: http.Header{
   206  					"Content-Type": {"application/json"},
   207  				},
   208  			}, nil
   209  		})
   210  	})
   211  	creds, err := p.Retrieve(context.Background())
   212  
   213  	if err == nil {
   214  		t.Fatalf("expect error, got none")
   215  	}
   216  
   217  	if e, a := "failed to load credentials", err.Error(); !strings.Contains(a, e) {
   218  		t.Errorf("expect %v, got %v", e, a)
   219  	}
   220  
   221  	var apiError smithy.APIError
   222  	if !errors.As(err, &apiError) {
   223  		t.Fatalf("expect %T error, got %v", apiError, err)
   224  	}
   225  	if e, a := "Error", apiError.ErrorCode(); e != a {
   226  		t.Errorf("expect %v, got %v", e, a)
   227  	}
   228  	if e, a := "Message", apiError.ErrorMessage(); e != a {
   229  		t.Errorf("expect %v, got %v", e, a)
   230  	}
   231  
   232  	if v := creds.AccessKeyID; len(v) != 0 {
   233  		t.Errorf("expect empty, got %v", v)
   234  	}
   235  	if v := creds.SecretAccessKey; len(v) != 0 {
   236  		t.Errorf("expect empty, got %v", v)
   237  	}
   238  	if v := creds.SessionToken; len(v) != 0 {
   239  		t.Errorf("expect empty, got %v", v)
   240  	}
   241  	if creds.Expired() {
   242  		t.Errorf("expect empty creds not to be expired")
   243  	}
   244  }
   245  
   246  type mockClientN struct {
   247  	responses []*http.Response
   248  	index     int
   249  }
   250  
   251  func (c *mockClientN) Do(r *http.Request) (*http.Response, error) {
   252  	resp := c.responses[c.index]
   253  	c.index++
   254  	return resp, nil
   255  }
   256  
   257  func TestRetryHTTPStatusCode(t *testing.T) {
   258  	expTime := time.Now().UTC().Add(1 * time.Hour).Format("2006-01-02T15:04:05Z")
   259  	credsResp := fmt.Sprintf(`{"AccessKeyID":"AKID","SecretAccessKey":"SECRET","Token":"TOKEN","Expiration":"%s"}`, expTime)
   260  
   261  	p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
   262  		o.HTTPClient = &mockClientN{
   263  			responses: []*http.Response{
   264  				{
   265  					StatusCode: 429,
   266  					Body:       io.NopCloser(strings.NewReader("You have made too many requests.")),
   267  					Header: http.Header{
   268  						"Content-Type": {"text/plain"},
   269  					},
   270  				},
   271  				{
   272  					StatusCode: 500,
   273  					Body:       io.NopCloser(strings.NewReader("Internal server error.")),
   274  					Header: http.Header{
   275  						"Content-Type": {"text/plain"},
   276  					},
   277  				},
   278  				{
   279  					StatusCode: 200,
   280  					Body:       ioutil.NopCloser(strings.NewReader(credsResp)),
   281  					Header: http.Header{
   282  						"Content-Type": {"application/json"},
   283  					},
   284  				},
   285  			},
   286  		}
   287  	})
   288  
   289  	creds, err := p.Retrieve(context.Background())
   290  	if err != nil {
   291  		t.Fatalf("expect no error, got %v", err)
   292  	}
   293  
   294  	if e, a := "AKID", creds.AccessKeyID; e != a {
   295  		t.Errorf("expect %v, got %v", e, a)
   296  	}
   297  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
   298  		t.Errorf("expect %v, got %v", e, a)
   299  	}
   300  	if e, a := "TOKEN", creds.SessionToken; e != a {
   301  		t.Errorf("expect %v, got %v", e, a)
   302  	}
   303  	if creds.Expired() {
   304  		t.Errorf("expect not expired")
   305  	}
   306  
   307  	sdk.NowTime = func() time.Time {
   308  		return time.Now().Add(2 * time.Hour)
   309  	}
   310  	if !creds.Expired() {
   311  		t.Errorf("expect to be expired")
   312  	}
   313  }
   314  

View as plain text