1 package ec2rolecreds
2
3 import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "math"
9 "path"
10 "strings"
11 "time"
12
13 "github.com/aws/aws-sdk-go-v2/aws"
14 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
15 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
16 "github.com/aws/aws-sdk-go-v2/internal/sdk"
17 "github.com/aws/smithy-go"
18 "github.com/aws/smithy-go/logging"
19 "github.com/aws/smithy-go/middleware"
20 )
21
22
const (
	// ProviderName is the value stored in the Source field of every
	// aws.Credentials value returned by this package's Provider.
	ProviderName = "EC2RoleProvider"
)
24
25
26
27 type GetMetadataAPIClient interface {
28 GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
29 }
30
31
32
33
34
35
36
37
38
// Provider retrieves credentials for the IAM role exposed by the EC2 instance
// metadata service (IMDS), using the client configured in its Options.
//
// Construct it with New: New fills in a default IMDS client when none is
// given, whereas a zero-value Provider has a nil Client.
type Provider struct {
	// options is the configuration produced by New.
	options Options
}
42
43
// Options configures a Provider created by New.
type Options struct {
	// Client performs the GetMetadata calls against EC2 IMDS. If it is still
	// nil after all option functions have run, New replaces it with
	// imds.New(imds.Options{}).
	Client GetMetadataAPIClient
}
51
52
53
54 func New(optFns ...func(*Options)) *Provider {
55 options := Options{}
56
57 for _, fn := range optFns {
58 fn(&options)
59 }
60
61 if options.Client == nil {
62 options.Client = imds.New(imds.Options{})
63 }
64
65 return &Provider{
66 options: options,
67 }
68 }
69
70
71
72 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
73 credsList, err := requestCredList(ctx, p.options.Client)
74 if err != nil {
75 return aws.Credentials{Source: ProviderName}, err
76 }
77
78 if len(credsList) == 0 {
79 return aws.Credentials{Source: ProviderName},
80 fmt.Errorf("unexpected empty EC2 IMDS role list")
81 }
82 credsName := credsList[0]
83
84 roleCreds, err := requestCred(ctx, p.options.Client, credsName)
85 if err != nil {
86 return aws.Credentials{Source: ProviderName}, err
87 }
88
89 creds := aws.Credentials{
90 AccessKeyID: roleCreds.AccessKeyID,
91 SecretAccessKey: roleCreds.SecretAccessKey,
92 SessionToken: roleCreds.Token,
93 Source: ProviderName,
94
95 CanExpire: true,
96 Expires: roleCreds.Expiration,
97 }
98
99
100
101 if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
102 creds.Expires = anHour
103 }
104
105 return creds, nil
106 }
107
108
109
110
111
112
113 func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
114 aws.Credentials, error,
115 ) {
116 if !prevCreds.CanExpire {
117 return aws.Credentials{}, err
118 }
119
120 if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
121 return prevCreds, nil
122 }
123
124 newCreds := prevCreds
125 randFloat64, err := sdkrand.CryptoRandFloat64()
126 if err != nil {
127 return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
128 }
129
130
131 expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
132 newCreds.Expires = sdk.NowTime().Add(expireOffset)
133
134 logger := middleware.GetLogger(ctx)
135 logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
136
137 return newCreds, nil
138 }
139
140
141
142
143 func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
144 aws.Credentials, error,
145 ) {
146 if !creds.CanExpire {
147 return creds, nil
148 }
149 if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
150 return creds, nil
151 }
152
153 creds.Expires = creds.Expires.Add(dur)
154 return creds, nil
155 }
156
157
158
// ec2RoleCredRespBody is the JSON document that requestCred decodes from the
// IMDS path for a single role beneath iamSecurityCredsPath. encoding/json
// matches object keys to these field names case-insensitively when there is
// no exact match.
type ec2RoleCredRespBody struct {
	// Credential fields, copied into aws.Credentials by Retrieve.
	Expiration      time.Time
	AccessKeyID     string
	SecretAccessKey string
	Token           string

	// Outcome fields. requestCred treats any Code other than "Success"
	// (compared case-insensitively) as an error and reports Code and Message
	// through a smithy.GenericAPIError.
	Code    string
	Message string
}
170
const (
	// iamSecurityCredsPath is the IMDS path that lists role names, one per
	// line. Joining a role name onto it addresses that role's credentials
	// document.
	iamSecurityCredsPath = "/iam/security-credentials/"
)
172
173
174
175
176 func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
177 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
178 Path: iamSecurityCredsPath,
179 })
180 if err != nil {
181 return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
182 }
183 defer resp.Content.Close()
184
185 credsList := []string{}
186 s := bufio.NewScanner(resp.Content)
187 for s.Scan() {
188 credsList = append(credsList, s.Text())
189 }
190
191 if err := s.Err(); err != nil {
192 return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
193 }
194
195 return credsList, nil
196 }
197
198
199
200
201
202 func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
203 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
204 Path: path.Join(iamSecurityCredsPath, credsName),
205 })
206 if err != nil {
207 return ec2RoleCredRespBody{},
208 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
209 credsName, err)
210 }
211 defer resp.Content.Close()
212
213 var respCreds ec2RoleCredRespBody
214 if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
215 return ec2RoleCredRespBody{},
216 fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
217 credsName, err)
218 }
219
220 if !strings.EqualFold(respCreds.Code, "Success") {
221
222 return ec2RoleCredRespBody{},
223 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
224 credsName,
225 &smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
226 }
227
228 return respCreds, nil
229 }
230
View as plain text