...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds/provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds

     1  package ec2rolecreds
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go-v2/aws"
    16  	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
    17  	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
    18  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    19  	"github.com/aws/smithy-go"
    20  	"github.com/aws/smithy-go/logging"
    21  	"github.com/aws/smithy-go/middleware"
    22  )
    23  
    24  const credsRespTmpl = `{
    25    "Code": "Success",
    26    "Type": "AWS-HMAC",
    27    "AccessKeyId" : "accessKey",
    28    "SecretAccessKey" : "secret",
    29    "Token" : "token",
    30    "Expiration" : "%s",
    31    "LastUpdated" : "2009-11-23T00:00:00Z"
    32  }`
    33  
    34  const credsFailRespTmpl = `{
    35    "Code": "ErrorCode",
    36    "Message": "ErrorMsg",
    37    "LastUpdated": "2009-11-23T00:00:00Z"
    38  }`
    39  
    40  type mockClient struct {
    41  	t          *testing.T
    42  	roleName   string
    43  	failAssume bool
    44  	expireOn   string
    45  }
    46  
    47  func (c mockClient) GetMetadata(
    48  	ctx context.Context, params *imds.GetMetadataInput, optFns ...func(*imds.Options),
    49  ) (
    50  	*imds.GetMetadataOutput, error,
    51  ) {
    52  	switch params.Path {
    53  	case iamSecurityCredsPath:
    54  		return &imds.GetMetadataOutput{
    55  			Content: ioutil.NopCloser(strings.NewReader(c.roleName)),
    56  		}, nil
    57  
    58  	case iamSecurityCredsPath + c.roleName:
    59  		var w strings.Builder
    60  		if c.failAssume {
    61  			fmt.Fprintf(&w, credsFailRespTmpl)
    62  		} else {
    63  			fmt.Fprintf(&w, credsRespTmpl, c.expireOn)
    64  		}
    65  		return &imds.GetMetadataOutput{
    66  			Content: ioutil.NopCloser(strings.NewReader(w.String())),
    67  		}, nil
    68  	default:
    69  		return nil, fmt.Errorf("unexpected path, %v", params.Path)
    70  	}
    71  }
    72  
    73  var (
    74  	_ aws.AdjustExpiresByCredentialsCacheStrategy   = (*Provider)(nil)
    75  	_ aws.HandleFailRefreshCredentialsCacheStrategy = (*Provider)(nil)
    76  )
    77  
    78  func TestProvider(t *testing.T) {
    79  	orig := sdk.NowTime
    80  	defer func() { sdk.NowTime = orig }()
    81  
    82  	p := New(func(options *Options) {
    83  		options.Client = mockClient{
    84  			roleName:   "RoleName",
    85  			failAssume: false,
    86  			expireOn:   "2014-12-16T01:51:37Z",
    87  		}
    88  	})
    89  
    90  	creds, err := p.Retrieve(context.Background())
    91  	if err != nil {
    92  		t.Fatalf("expect no error, got %v", err)
    93  	}
    94  	if e, a := "accessKey", creds.AccessKeyID; e != a {
    95  		t.Errorf("Expect access key ID to match")
    96  	}
    97  	if e, a := "secret", creds.SecretAccessKey; e != a {
    98  		t.Errorf("Expect secret access key to match")
    99  	}
   100  	if e, a := "token", creds.SessionToken; e != a {
   101  		t.Errorf("Expect session token to match")
   102  	}
   103  
   104  	sdk.NowTime = func() time.Time {
   105  		return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
   106  	}
   107  
   108  	if creds.Expired() {
   109  		t.Errorf("Expect not expired")
   110  	}
   111  }
   112  
   113  func TestProvider_FailAssume(t *testing.T) {
   114  	p := New(func(options *Options) {
   115  		options.Client = mockClient{
   116  			roleName:   "RoleName",
   117  			failAssume: true,
   118  			expireOn:   "2014-12-16T01:51:37Z",
   119  		}
   120  	})
   121  
   122  	creds, err := p.Retrieve(context.Background())
   123  	if err == nil {
   124  		t.Fatalf("expect error, got none")
   125  	}
   126  
   127  	var apiErr smithy.APIError
   128  	if !errors.As(err, &apiErr) {
   129  		t.Fatalf("expect %T error, got %v", apiErr, err)
   130  	}
   131  	if e, a := "ErrorCode", apiErr.ErrorCode(); e != a {
   132  		t.Errorf("expect %v code, got %v", e, a)
   133  	}
   134  	if e, a := "ErrorMsg", apiErr.ErrorMessage(); e != a {
   135  		t.Errorf("expect %v message, got %v", e, a)
   136  	}
   137  
   138  	nestedErr := errors.Unwrap(apiErr)
   139  	if nestedErr != nil {
   140  		t.Fatalf("expect no nested error, got %v", err)
   141  	}
   142  
   143  	if e, a := "", creds.AccessKeyID; e != a {
   144  		t.Errorf("Expect access key ID to match")
   145  	}
   146  	if e, a := "", creds.SecretAccessKey; e != a {
   147  		t.Errorf("Expect secret access key to match")
   148  	}
   149  	if e, a := "", creds.SessionToken; e != a {
   150  		t.Errorf("Expect session token to match")
   151  	}
   152  }
   153  
   154  func TestProvider_IsExpired(t *testing.T) {
   155  	orig := sdk.NowTime
   156  	defer func() { sdk.NowTime = orig }()
   157  
   158  	p := New(func(options *Options) {
   159  		options.Client = mockClient{
   160  			roleName:   "RoleName",
   161  			failAssume: false,
   162  			expireOn:   "2014-12-16T01:51:37Z",
   163  		}
   164  	})
   165  
   166  	sdk.NowTime = func() time.Time {
   167  		return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
   168  	}
   169  
   170  	creds, err := p.Retrieve(context.Background())
   171  	if err != nil {
   172  		t.Fatalf("expect no error, got %v", err)
   173  	}
   174  	if creds.Expired() {
   175  		t.Errorf("expect not to be expired")
   176  	}
   177  
   178  	sdk.NowTime = func() time.Time {
   179  		return time.Date(2014, 12, 16, 1, 55, 37, 0, time.UTC)
   180  	}
   181  
   182  	if !creds.Expired() {
   183  		t.Errorf("expect to be expired")
   184  	}
   185  }
   186  
   187  type byteReader byte
   188  
   189  func (b byteReader) Read(p []byte) (int, error) {
   190  	for i := 0; i < len(p); i++ {
   191  		p[i] = byte(b)
   192  	}
   193  	return len(p), nil
   194  }
   195  
   196  func TestProvider_HandleFailToRetrieve(t *testing.T) {
   197  	origTime := sdk.NowTime
   198  	defer func() { sdk.NowTime = origTime }()
   199  	sdk.NowTime = func() time.Time {
   200  		return time.Date(2014, 04, 04, 0, 1, 0, 0, time.UTC)
   201  	}
   202  
   203  	origRand := sdkrand.Reader
   204  	defer func() { sdkrand.Reader = origRand }()
   205  	sdkrand.Reader = byteReader(0)
   206  
   207  	cases := map[string]struct {
   208  		creds        aws.Credentials
   209  		err          error
   210  		randReader   io.Reader
   211  		expectCreds  aws.Credentials
   212  		expectErr    string
   213  		expectLogged string
   214  	}{
   215  		"expired low": {
   216  			randReader: byteReader(0),
   217  			creds: aws.Credentials{
   218  				CanExpire: true,
   219  				Expires:   sdk.NowTime().Add(-5 * time.Minute),
   220  			},
   221  			err: fmt.Errorf("some error"),
   222  			expectCreds: aws.Credentials{
   223  				CanExpire: true,
   224  				Expires:   sdk.NowTime().Add(5 * time.Minute),
   225  			},
   226  			expectLogged: fmt.Sprintf("again in 5 minutes"),
   227  		},
   228  		"expired high": {
   229  			randReader: byteReader(0xFF),
   230  			creds: aws.Credentials{
   231  				CanExpire: true,
   232  				Expires:   sdk.NowTime().Add(-5 * time.Minute),
   233  			},
   234  			err: fmt.Errorf("some error"),
   235  			expectCreds: aws.Credentials{
   236  				CanExpire: true,
   237  				Expires:   sdk.NowTime().Add(14*time.Minute + 59*time.Second),
   238  			},
   239  			expectLogged: fmt.Sprintf("again in 14 minutes"),
   240  		},
   241  		"not expired": {
   242  			randReader: byteReader(0xFF),
   243  			creds: aws.Credentials{
   244  				CanExpire: true,
   245  				Expires:   sdk.NowTime().Add(10 * time.Minute),
   246  			},
   247  			err: fmt.Errorf("some error"),
   248  			expectCreds: aws.Credentials{
   249  				CanExpire: true,
   250  				Expires:   sdk.NowTime().Add(10 * time.Minute),
   251  			},
   252  		},
   253  		"cannot expire": {
   254  			randReader: byteReader(0xFF),
   255  			creds: aws.Credentials{
   256  				CanExpire: false,
   257  			},
   258  			err:       fmt.Errorf("some error"),
   259  			expectErr: "some error",
   260  		},
   261  	}
   262  
   263  	for name, c := range cases {
   264  		t.Run(name, func(t *testing.T) {
   265  			sdkrand.Reader = c.randReader
   266  			if sdkrand.Reader == nil {
   267  				sdkrand.Reader = byteReader(0)
   268  			}
   269  
   270  			var logBuf bytes.Buffer
   271  			logger := logging.LoggerFunc(func(class logging.Classification, format string, args ...interface{}) {
   272  				fmt.Fprintf(&logBuf, string(class)+" "+format, args...)
   273  			})
   274  			ctx := middleware.SetLogger(context.Background(), logger)
   275  
   276  			p := New()
   277  			creds, err := p.HandleFailToRefresh(ctx, c.creds, c.err)
   278  			if err == nil && len(c.expectErr) != 0 {
   279  				t.Fatalf("expect error %v, got none", c.expectErr)
   280  			}
   281  			if err != nil && len(c.expectErr) == 0 {
   282  				t.Fatalf("expect no error, got %v", err)
   283  			}
   284  			if err != nil && !strings.Contains(err.Error(), c.expectErr) {
   285  				t.Fatalf("expect error to contain %v, got %v", c.expectErr, err)
   286  			}
   287  			if c.expectErr != "" {
   288  				return
   289  			}
   290  
   291  			if len(c.expectLogged) != 0 && logBuf.Len() == 0 {
   292  				t.Errorf("expect %v logged, got none", c.expectLogged)
   293  			}
   294  			if e, a := c.expectLogged, logBuf.String(); !strings.Contains(a, e) {
   295  				t.Errorf("expect %v to be logged in %v", e, a)
   296  			}
   297  
   298  			// Truncate time so it can be easily compared.
   299  			creds.Expires = creds.Expires.Truncate(time.Second)
   300  
   301  			if diff := cmpDiff(c.expectCreds, creds); diff != "" {
   302  				t.Errorf("expect creds match\n%s", diff)
   303  			}
   304  		})
   305  	}
   306  }
   307  
   308  func TestProvider_AdjustExpiresBy(t *testing.T) {
   309  	origTime := sdk.NowTime
   310  	defer func() { sdk.NowTime = origTime }()
   311  	sdk.NowTime = func() time.Time {
   312  		return time.Date(2014, 04, 04, 0, 1, 0, 0, time.UTC)
   313  	}
   314  
   315  	cases := map[string]struct {
   316  		creds       aws.Credentials
   317  		dur         time.Duration
   318  		expectCreds aws.Credentials
   319  	}{
   320  		"modify expires": {
   321  			creds: aws.Credentials{
   322  				CanExpire: true,
   323  				Expires:   sdk.NowTime().Add(1 * time.Hour),
   324  			},
   325  			dur: -5 * time.Minute,
   326  			expectCreds: aws.Credentials{
   327  				CanExpire: true,
   328  				Expires:   sdk.NowTime().Add(55 * time.Minute),
   329  			},
   330  		},
   331  		"expiry too soon": {
   332  			creds: aws.Credentials{
   333  				CanExpire: true,
   334  				Expires:   sdk.NowTime().Add(14*time.Minute + 59*time.Second),
   335  			},
   336  			dur: -5 * time.Minute,
   337  			expectCreds: aws.Credentials{
   338  				CanExpire: true,
   339  				Expires:   sdk.NowTime().Add(14*time.Minute + 59*time.Second),
   340  			},
   341  		},
   342  		"cannot expire": {
   343  			creds: aws.Credentials{
   344  				CanExpire: false,
   345  			},
   346  			dur: 10 * time.Minute,
   347  			expectCreds: aws.Credentials{
   348  				CanExpire: false,
   349  			},
   350  		},
   351  	}
   352  
   353  	for name, c := range cases {
   354  		t.Run(name, func(t *testing.T) {
   355  			p := New()
   356  			creds, err := p.AdjustExpiresBy(c.creds, c.dur)
   357  
   358  			if err != nil {
   359  				t.Fatalf("expect no error, got %v", err)
   360  			}
   361  
   362  			if diff := cmpDiff(c.expectCreds, creds); diff != "" {
   363  				t.Errorf("expect creds match\n%s", diff)
   364  			}
   365  		})
   366  	}
   367  }
   368  
   369  func cmpDiff(e, a interface{}) string {
   370  	if !reflect.DeepEqual(e, a) {
   371  		return fmt.Sprintf("%v != %v", e, a)
   372  	}
   373  	return ""
   374  }
   375  

View as plain text