1 package imds
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "github.com/aws/aws-sdk-go-v2/aws"
8 "github.com/aws/smithy-go"
9 "github.com/aws/smithy-go/logging"
10 "net/http"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "github.com/aws/smithy-go/middleware"
16 smithyhttp "github.com/aws/smithy-go/transport/http"
17 )
18
const (
	// tokenHeader is the request header in which the IMDSv2 API token is
	// sent on metadata requests (see HandleFinalize).
	tokenHeader = "x-aws-ec2-metadata-token"

	// defaultTokenTTL is a five-minute token lifetime; presumably used when no
	// TTL is configured for the provider — confirm at the call site.
	defaultTokenTTL = time.Minute * 5
)
24
25 type tokenProvider struct {
26 client *Client
27 tokenTTL time.Duration
28
29 token *apiToken
30 tokenMux sync.RWMutex
31
32 disabled uint32
33 }
34
35 func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
36 return &tokenProvider{
37 client: client,
38 tokenTTL: ttl,
39 }
40 }
41
42
43
// apiToken is a cached API token together with the instant it stops being
// valid.
type apiToken struct {
	token   string
	expires time.Time
}

// timeNow is the clock consulted for token expiry. It is a package variable
// so it can be substituted (presumably by tests).
var timeNow = time.Now

// Expired reports whether the token's expiry instant is strictly in the past.
//
// Round(0) strips the monotonic clock reading from the current time, so the
// comparison falls back to wall-clock time. A token expiring exactly "now" is
// not yet expired.
func (t *apiToken) Expired() bool {
	current := timeNow().Round(0)
	return t.expires.Before(current)
}
57
58 func (t *tokenProvider) ID() string { return "APITokenProvider" }
59
60
61
62
63
64
65
66
67
68 func (t *tokenProvider) HandleFinalize(
69 ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
70 ) (
71 out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
72 ) {
73 if t.fallbackEnabled() && !t.enabled() {
74
75 return next.HandleFinalize(ctx, input)
76 }
77
78 req, ok := input.Request.(*smithyhttp.Request)
79 if !ok {
80 return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
81 }
82
83 tok, err := t.getToken(ctx)
84 if err != nil {
85
86 var bypassErr *bypassTokenRetrievalError
87 if errors.As(err, &bypassErr) {
88 return next.HandleFinalize(ctx, input)
89 }
90
91 return out, metadata, fmt.Errorf("failed to get API token, %w", err)
92 }
93
94 req.Header.Set(tokenHeader, tok.token)
95
96 return next.HandleFinalize(ctx, input)
97 }
98
99
100
101
102
103
104 func (t *tokenProvider) HandleDeserialize(
105 ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
106 ) (
107 out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
108 ) {
109 out, metadata, err = next.HandleDeserialize(ctx, input)
110 if err == nil {
111 return out, metadata, err
112 }
113
114 resp, ok := out.RawResponse.(*smithyhttp.Response)
115 if !ok {
116 return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
117 }
118
119 if resp.StatusCode == http.StatusUnauthorized {
120 t.enable()
121 err = &retryableError{Err: err, isRetryable: true}
122 }
123
124 return out, metadata, err
125 }
126
127 func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
128 if t.fallbackEnabled() && !t.enabled() {
129 return nil, &bypassTokenRetrievalError{
130 Err: fmt.Errorf("cannot get API token, provider disabled"),
131 }
132 }
133
134 t.tokenMux.RLock()
135 tok = t.token
136 t.tokenMux.RUnlock()
137
138 if tok != nil && !tok.Expired() {
139 return tok, nil
140 }
141
142 tok, err = t.updateToken(ctx)
143 if err != nil {
144 return nil, err
145 }
146
147 return tok, nil
148 }
149
150 func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
151 t.tokenMux.Lock()
152 defer t.tokenMux.Unlock()
153
154
155 if t.token != nil && !t.token.Expired() {
156 tok := t.token
157 return tok, nil
158 }
159
160 result, err := t.client.getToken(ctx, &getTokenInput{
161 TokenTTL: t.tokenTTL,
162 })
163 if err != nil {
164 var statusErr interface{ HTTPStatusCode() int }
165 if errors.As(err, &statusErr) {
166 switch statusErr.HTTPStatusCode() {
167
168 case http.StatusForbidden,
169 http.StatusNotFound,
170 http.StatusMethodNotAllowed:
171
172 if t.fallbackEnabled() {
173 logger := middleware.GetLogger(ctx)
174 logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
175 t.disable()
176 }
177
178
179 case http.StatusBadRequest:
180 return nil, err
181 }
182 }
183
184
185 var re *smithyhttp.RequestSendError
186 var ce *smithy.CanceledError
187 if errors.As(err, &re) || errors.As(err, &ce) {
188 atomic.StoreUint32(&t.disabled, 1)
189 }
190
191 if !t.fallbackEnabled() {
192
193
194
195 err = &retryableError{Err: err, isRetryable: false}
196 return nil, err
197 }
198
199
200
201
202 return nil, &bypassTokenRetrievalError{Err: err}
203 }
204
205 tok := &apiToken{
206 token: result.Token,
207 expires: timeNow().Add(result.TokenTTL),
208 }
209 t.token = tok
210
211 return tok, nil
212 }
213
214
215 func (t *tokenProvider) enabled() bool {
216 return atomic.LoadUint32(&t.disabled) == 0
217 }
218
219
220 func (t *tokenProvider) fallbackEnabled() bool {
221 switch t.client.options.EnableFallback {
222 case aws.FalseTernary:
223 return false
224 default:
225 return true
226 }
227 }
228
229
230
231 func (t *tokenProvider) disable() {
232 atomic.StoreUint32(&t.disabled, 1)
233 }
234
235
236
237 func (t *tokenProvider) enable() {
238 t.tokenMux.Lock()
239 t.token = nil
240 t.tokenMux.Unlock()
241 atomic.StoreUint32(&t.disabled, 0)
242 }
243
// bypassTokenRetrievalError reports that no API token could be obtained but
// the request may still proceed without one. HandleFinalize detects it with
// errors.As.
type bypassTokenRetrievalError struct {
	Err error // underlying cause of the failed retrieval
}

// Error describes the bypass and appends the underlying cause.
func (e *bypassTokenRetrievalError) Error() string {
	msg := fmt.Sprintf("bypass token retrieval, %v", e.Err)
	return msg
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e *bypassTokenRetrievalError) Unwrap() error {
	return e.Err
}
253
// retryableError wraps an error together with an explicit retry decision.
//
// The decision is exposed through RetryableError, presumably consulted by the
// SDK retryer through an interface check — confirm. Unwrap keeps the wrapped
// cause reachable by errors.Is/errors.As, e.g. the 401 response error wrapped
// by HandleDeserialize or the token-fetch failure wrapped by updateToken.
type retryableError struct {
	Err         error
	isRetryable bool
}

// RetryableError reports whether the wrapped failure should be retried.
func (e *retryableError) RetryableError() bool { return e.isRetryable }

// Error returns the wrapped error's message unchanged. With no wrapped error
// it returns "<nil>" instead of panicking on the nil Err.
func (e *retryableError) Error() string {
	if e.Err == nil {
		return "<nil>"
	}
	return e.Err.Error()
}

// Unwrap returns the wrapped error so callers can still inspect the original
// cause after the retry decision has been attached.
func (e *retryableError) Unwrap() error { return e.Err }
262
View as plain text