1 package stscreds_test
2
3 import (
4 "context"
5 "fmt"
6 "reflect"
7 "testing"
8 "time"
9
10 "github.com/aws/aws-sdk-go-v2/aws"
11 "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
12 "github.com/aws/aws-sdk-go-v2/internal/sdk"
13 "github.com/aws/aws-sdk-go-v2/service/sts"
14 "github.com/aws/aws-sdk-go-v2/service/sts/types"
15 )
16
17 type mockAssumeRoleWithWebIdentity func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error)
18
19 func (m mockAssumeRoleWithWebIdentity) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
20 return m(ctx, params, optFns...)
21 }
22
// mockErrorCode is a string-backed error whose value doubles as its error
// code, letting tests fabricate errors that expose an ErrorCode method.
type mockErrorCode string

// ErrorCode reports the underlying string as the error's code.
func (code mockErrorCode) ErrorCode() string { return string(code) }

// Error renders the code prefixed with "error code: ".
func (code mockErrorCode) Error() string {
	const prefix = "error code: "
	return prefix + code.ErrorCode()
}
32
33 func TestWebIdentityProviderRetrieve(t *testing.T) {
34 restorTime := sdk.TestingUseReferenceTime(time.Time{})
35 defer restorTime()
36
37 cases := map[string]struct {
38 mockClient mockAssumeRoleWithWebIdentity
39 roleARN string
40 tokenFilepath string
41 sessionName string
42 options func(*stscreds.WebIdentityRoleOptions)
43 expectedCredValue aws.Credentials
44 }{
45 "success": {
46 roleARN: "arn01234567890123456789",
47 tokenFilepath: "testdata/token.jwt",
48 options: func(o *stscreds.WebIdentityRoleOptions) {
49 o.RoleSessionName = "foo"
50 },
51 mockClient: func(
52 ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
53 ) (
54 *sts.AssumeRoleWithWebIdentityOutput, error,
55 ) {
56 if e, a := "foo", *params.RoleSessionName; e != a {
57 return nil, fmt.Errorf("expected %v, but received %v", e, a)
58 }
59 if params.DurationSeconds != nil {
60 return nil, fmt.Errorf("expect no duration seconds, got %v",
61 *params.DurationSeconds)
62 }
63 if params.Policy != nil {
64 return nil, fmt.Errorf("expect no policy, got %v",
65 *params.Policy)
66 }
67 return &sts.AssumeRoleWithWebIdentityOutput{
68 Credentials: &types.Credentials{
69 Expiration: aws.Time(sdk.NowTime()),
70 AccessKeyId: aws.String("access-key-id"),
71 SecretAccessKey: aws.String("secret-access-key"),
72 SessionToken: aws.String("session-token"),
73 },
74 }, nil
75 },
76 expectedCredValue: aws.Credentials{
77 AccessKeyID: "access-key-id",
78 SecretAccessKey: "secret-access-key",
79 SessionToken: "session-token",
80 Source: stscreds.WebIdentityProviderName,
81 CanExpire: true,
82 Expires: sdk.NowTime(),
83 },
84 },
85 "success with duration and policy": {
86 roleARN: "arn01234567890123456789",
87 tokenFilepath: "testdata/token.jwt",
88 options: func(o *stscreds.WebIdentityRoleOptions) {
89 o.Duration = 42 * time.Second
90 o.Policy = aws.String("super secret policy")
91 o.RoleSessionName = "foo"
92 },
93 mockClient: func(
94 ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
95 ) (
96 *sts.AssumeRoleWithWebIdentityOutput, error,
97 ) {
98 if e, a := "foo", *params.RoleSessionName; e != a {
99 return nil, fmt.Errorf("expected %v, but received %v", e, a)
100 }
101 if e, a := int32(42), aws.ToInt32(params.DurationSeconds); e != a {
102 return nil, fmt.Errorf("expect %v duration seconds, got %v", e, a)
103 }
104 if e, a := "super secret policy", aws.ToString(params.Policy); e != a {
105 return nil, fmt.Errorf("expect %v policy, got %v", e, a)
106 }
107 return &sts.AssumeRoleWithWebIdentityOutput{
108 Credentials: &types.Credentials{
109 Expiration: aws.Time(sdk.NowTime()),
110 AccessKeyId: aws.String("access-key-id"),
111 SecretAccessKey: aws.String("secret-access-key"),
112 SessionToken: aws.String("session-token"),
113 },
114 }, nil
115 },
116 expectedCredValue: aws.Credentials{
117 AccessKeyID: "access-key-id",
118 SecretAccessKey: "secret-access-key",
119 SessionToken: "session-token",
120 Source: stscreds.WebIdentityProviderName,
121 CanExpire: true,
122 Expires: sdk.NowTime(),
123 },
124 },
125 "configures token retry": {
126 roleARN: "arn01234567890123456789",
127 tokenFilepath: "testdata/token.jwt",
128 options: func(o *stscreds.WebIdentityRoleOptions) {
129 o.RoleSessionName = "foo"
130 },
131 mockClient: func(
132 ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
133 ) (
134 *sts.AssumeRoleWithWebIdentityOutput, error,
135 ) {
136 o := sts.Options{}
137 for _, fn := range optFns {
138 fn(&o)
139 }
140
141 errorCode := (&types.InvalidIdentityTokenException{}).ErrorCode()
142 if o.Retryer.IsErrorRetryable(mockErrorCode(errorCode)) != true {
143 return nil, fmt.Errorf("expected %v to be retryable", errorCode)
144 }
145
146 return &sts.AssumeRoleWithWebIdentityOutput{
147 Credentials: &types.Credentials{
148 Expiration: aws.Time(sdk.NowTime()),
149 AccessKeyId: aws.String("access-key-id"),
150 SecretAccessKey: aws.String("secret-access-key"),
151 SessionToken: aws.String("session-token"),
152 },
153 }, nil
154 },
155 expectedCredValue: aws.Credentials{
156 AccessKeyID: "access-key-id",
157 SecretAccessKey: "secret-access-key",
158 SessionToken: "session-token",
159 Source: stscreds.WebIdentityProviderName,
160 CanExpire: true,
161 Expires: sdk.NowTime(),
162 },
163 },
164 }
165
166 for name, c := range cases {
167 t.Run(name, func(t *testing.T) {
168 var optFns []func(*stscreds.WebIdentityRoleOptions)
169 if c.options != nil {
170 optFns = append(optFns, c.options)
171 }
172 p := stscreds.NewWebIdentityRoleProvider(
173 c.mockClient,
174 c.roleARN,
175 stscreds.IdentityTokenFile(c.tokenFilepath),
176 optFns...,
177 )
178 credValue, err := p.Retrieve(context.Background())
179 if err != nil {
180 t.Fatalf("expect no error, got %v", err)
181 }
182
183 if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
184 t.Errorf("expected %v, but received %v", e, a)
185 }
186 })
187 }
188 }
189
View as plain text