...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/stscreds/web_identity_provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/stscreds

     1  package stscreds_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/aws/aws-sdk-go-v2/aws"
    11  	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
    12  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    13  	"github.com/aws/aws-sdk-go-v2/service/sts"
    14  	"github.com/aws/aws-sdk-go-v2/service/sts/types"
    15  )
    16  
    17  type mockAssumeRoleWithWebIdentity func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error)
    18  
    19  func (m mockAssumeRoleWithWebIdentity) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
    20  	return m(ctx, params, optFns...)
    21  }
    22  
    23  type mockErrorCode string
    24  
    25  func (m mockErrorCode) ErrorCode() string {
    26  	return string(m)
    27  }
    28  
    29  func (m mockErrorCode) Error() string {
    30  	return "error code: " + string(m)
    31  }
    32  
    33  func TestWebIdentityProviderRetrieve(t *testing.T) {
    34  	restorTime := sdk.TestingUseReferenceTime(time.Time{})
    35  	defer restorTime()
    36  
    37  	cases := map[string]struct {
    38  		mockClient        mockAssumeRoleWithWebIdentity
    39  		roleARN           string
    40  		tokenFilepath     string
    41  		sessionName       string
    42  		options           func(*stscreds.WebIdentityRoleOptions)
    43  		expectedCredValue aws.Credentials
    44  	}{
    45  		"success": {
    46  			roleARN:       "arn01234567890123456789",
    47  			tokenFilepath: "testdata/token.jwt",
    48  			options: func(o *stscreds.WebIdentityRoleOptions) {
    49  				o.RoleSessionName = "foo"
    50  			},
    51  			mockClient: func(
    52  				ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
    53  			) (
    54  				*sts.AssumeRoleWithWebIdentityOutput, error,
    55  			) {
    56  				if e, a := "foo", *params.RoleSessionName; e != a {
    57  					return nil, fmt.Errorf("expected %v, but received %v", e, a)
    58  				}
    59  				if params.DurationSeconds != nil {
    60  					return nil, fmt.Errorf("expect no duration seconds, got %v",
    61  						*params.DurationSeconds)
    62  				}
    63  				if params.Policy != nil {
    64  					return nil, fmt.Errorf("expect no policy, got %v",
    65  						*params.Policy)
    66  				}
    67  				return &sts.AssumeRoleWithWebIdentityOutput{
    68  					Credentials: &types.Credentials{
    69  						Expiration:      aws.Time(sdk.NowTime()),
    70  						AccessKeyId:     aws.String("access-key-id"),
    71  						SecretAccessKey: aws.String("secret-access-key"),
    72  						SessionToken:    aws.String("session-token"),
    73  					},
    74  				}, nil
    75  			},
    76  			expectedCredValue: aws.Credentials{
    77  				AccessKeyID:     "access-key-id",
    78  				SecretAccessKey: "secret-access-key",
    79  				SessionToken:    "session-token",
    80  				Source:          stscreds.WebIdentityProviderName,
    81  				CanExpire:       true,
    82  				Expires:         sdk.NowTime(),
    83  			},
    84  		},
    85  		"success with duration and policy": {
    86  			roleARN:       "arn01234567890123456789",
    87  			tokenFilepath: "testdata/token.jwt",
    88  			options: func(o *stscreds.WebIdentityRoleOptions) {
    89  				o.Duration = 42 * time.Second
    90  				o.Policy = aws.String("super secret policy")
    91  				o.RoleSessionName = "foo"
    92  			},
    93  			mockClient: func(
    94  				ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
    95  			) (
    96  				*sts.AssumeRoleWithWebIdentityOutput, error,
    97  			) {
    98  				if e, a := "foo", *params.RoleSessionName; e != a {
    99  					return nil, fmt.Errorf("expected %v, but received %v", e, a)
   100  				}
   101  				if e, a := int32(42), aws.ToInt32(params.DurationSeconds); e != a {
   102  					return nil, fmt.Errorf("expect %v duration seconds, got %v", e, a)
   103  				}
   104  				if e, a := "super secret policy", aws.ToString(params.Policy); e != a {
   105  					return nil, fmt.Errorf("expect %v policy, got %v", e, a)
   106  				}
   107  				return &sts.AssumeRoleWithWebIdentityOutput{
   108  					Credentials: &types.Credentials{
   109  						Expiration:      aws.Time(sdk.NowTime()),
   110  						AccessKeyId:     aws.String("access-key-id"),
   111  						SecretAccessKey: aws.String("secret-access-key"),
   112  						SessionToken:    aws.String("session-token"),
   113  					},
   114  				}, nil
   115  			},
   116  			expectedCredValue: aws.Credentials{
   117  				AccessKeyID:     "access-key-id",
   118  				SecretAccessKey: "secret-access-key",
   119  				SessionToken:    "session-token",
   120  				Source:          stscreds.WebIdentityProviderName,
   121  				CanExpire:       true,
   122  				Expires:         sdk.NowTime(),
   123  			},
   124  		},
   125  		"configures token retry": {
   126  			roleARN:       "arn01234567890123456789",
   127  			tokenFilepath: "testdata/token.jwt",
   128  			options: func(o *stscreds.WebIdentityRoleOptions) {
   129  				o.RoleSessionName = "foo"
   130  			},
   131  			mockClient: func(
   132  				ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
   133  			) (
   134  				*sts.AssumeRoleWithWebIdentityOutput, error,
   135  			) {
   136  				o := sts.Options{}
   137  				for _, fn := range optFns {
   138  					fn(&o)
   139  				}
   140  
   141  				errorCode := (&types.InvalidIdentityTokenException{}).ErrorCode()
   142  				if o.Retryer.IsErrorRetryable(mockErrorCode(errorCode)) != true {
   143  					return nil, fmt.Errorf("expected %v to be retryable", errorCode)
   144  				}
   145  
   146  				return &sts.AssumeRoleWithWebIdentityOutput{
   147  					Credentials: &types.Credentials{
   148  						Expiration:      aws.Time(sdk.NowTime()),
   149  						AccessKeyId:     aws.String("access-key-id"),
   150  						SecretAccessKey: aws.String("secret-access-key"),
   151  						SessionToken:    aws.String("session-token"),
   152  					},
   153  				}, nil
   154  			},
   155  			expectedCredValue: aws.Credentials{
   156  				AccessKeyID:     "access-key-id",
   157  				SecretAccessKey: "secret-access-key",
   158  				SessionToken:    "session-token",
   159  				Source:          stscreds.WebIdentityProviderName,
   160  				CanExpire:       true,
   161  				Expires:         sdk.NowTime(),
   162  			},
   163  		},
   164  	}
   165  
   166  	for name, c := range cases {
   167  		t.Run(name, func(t *testing.T) {
   168  			var optFns []func(*stscreds.WebIdentityRoleOptions)
   169  			if c.options != nil {
   170  				optFns = append(optFns, c.options)
   171  			}
   172  			p := stscreds.NewWebIdentityRoleProvider(
   173  				c.mockClient,
   174  				c.roleARN,
   175  				stscreds.IdentityTokenFile(c.tokenFilepath),
   176  				optFns...,
   177  			)
   178  			credValue, err := p.Retrieve(context.Background())
   179  			if err != nil {
   180  				t.Fatalf("expect no error, got %v", err)
   181  			}
   182  
   183  			if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
   184  				t.Errorf("expected %v, but received %v", e, a)
   185  			}
   186  		})
   187  	}
   188  }
   189  

View as plain text