...
1 package stscreds
2
3 import (
4 "context"
5 "fmt"
6 "io/ioutil"
7 "strconv"
8 "time"
9
10 "github.com/aws/aws-sdk-go-v2/aws"
11 "github.com/aws/aws-sdk-go-v2/aws/retry"
12 "github.com/aws/aws-sdk-go-v2/internal/sdk"
13 "github.com/aws/aws-sdk-go-v2/service/sts"
14 "github.com/aws/aws-sdk-go-v2/service/sts/types"
15 )
16
17 var invalidIdentityTokenExceptionCode = (&types.InvalidIdentityTokenException{}).ErrorCode()
18
// WebIdentityProviderName is the value recorded in aws.Credentials.Source for
// credentials obtained through WebIdentityRoleProvider.
const WebIdentityProviderName = "WebIdentityCredentials"
23
24
25 type AssumeRoleWithWebIdentityAPIClient interface {
26 AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error)
27 }
28
29
30
31 type WebIdentityRoleProvider struct {
32 options WebIdentityRoleOptions
33 }
34
35
36 type WebIdentityRoleOptions struct {
37
38 Client AssumeRoleWithWebIdentityAPIClient
39
40
41 TokenRetriever IdentityTokenRetriever
42
43
44 RoleARN string
45
46
47 RoleSessionName string
48
49
50
51
52
53
54
55
56
57 Duration time.Duration
58
59
60 Policy *string
61
62
63
64
65 PolicyARNs []types.PolicyDescriptorType
66 }
67
68
// IdentityTokenRetriever supplies the web identity token (a JWT) that
// WebIdentityRoleProvider exchanges for credentials.
type IdentityTokenRetriever interface {
	// GetIdentityToken returns the raw token bytes.
	GetIdentityToken() ([]byte, error)
}

// IdentityTokenFile is an IdentityTokenRetriever that reads the token from the
// file whose path is the string value.
type IdentityTokenFile string

// Compile-time check that IdentityTokenFile satisfies IdentityTokenRetriever.
var _ IdentityTokenRetriever = IdentityTokenFile("")

// GetIdentityToken reads and returns the entire contents of the token file.
// The contents are returned unmodified; no whitespace is trimmed.
//
// A read failure is wrapped with %w, so callers can inspect the cause, for
// example with errors.Is(err, os.ErrNotExist) for a missing file.
func (j IdentityTokenFile) GetIdentityToken() ([]byte, error) {
	b, err := ioutil.ReadFile(string(j))
	if err != nil {
		return nil, fmt.Errorf("unable to read file at %s: %w", string(j), err)
	}

	return b, nil
}
85
86
87
88 func NewWebIdentityRoleProvider(client AssumeRoleWithWebIdentityAPIClient, roleARN string, tokenRetriever IdentityTokenRetriever, optFns ...func(*WebIdentityRoleOptions)) *WebIdentityRoleProvider {
89 o := WebIdentityRoleOptions{
90 Client: client,
91 RoleARN: roleARN,
92 TokenRetriever: tokenRetriever,
93 }
94
95 for _, fn := range optFns {
96 fn(&o)
97 }
98
99 return &WebIdentityRoleProvider{options: o}
100 }
101
102
103
104
105 func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
106 b, err := p.options.TokenRetriever.GetIdentityToken()
107 if err != nil {
108 return aws.Credentials{}, fmt.Errorf("failed to retrieve jwt from provide source, %w", err)
109 }
110
111 sessionName := p.options.RoleSessionName
112 if len(sessionName) == 0 {
113
114
115 sessionName = strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
116 }
117 input := &sts.AssumeRoleWithWebIdentityInput{
118 PolicyArns: p.options.PolicyARNs,
119 RoleArn: &p.options.RoleARN,
120 RoleSessionName: &sessionName,
121 WebIdentityToken: aws.String(string(b)),
122 }
123 if p.options.Duration != 0 {
124
125 input.DurationSeconds = aws.Int32(int32(p.options.Duration / time.Second))
126 }
127 if p.options.Policy != nil {
128 input.Policy = p.options.Policy
129 }
130
131 resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, input, func(options *sts.Options) {
132 options.Retryer = retry.AddWithErrorCodes(options.Retryer, invalidIdentityTokenExceptionCode)
133 })
134 if err != nil {
135 return aws.Credentials{}, fmt.Errorf("failed to retrieve credentials, %w", err)
136 }
137
138
139
140
141 value := aws.Credentials{
142 AccessKeyID: aws.ToString(resp.Credentials.AccessKeyId),
143 SecretAccessKey: aws.ToString(resp.Credentials.SecretAccessKey),
144 SessionToken: aws.ToString(resp.Credentials.SessionToken),
145 Source: WebIdentityProviderName,
146 CanExpire: true,
147 Expires: *resp.Credentials.Expiration,
148 }
149 return value, nil
150 }
151
View as plain text