1 package imds
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "io/ioutil"
8 "net/url"
9 "path"
10 "time"
11
12 awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
13 "github.com/aws/aws-sdk-go-v2/aws/retry"
14 "github.com/aws/smithy-go/middleware"
15 smithyhttp "github.com/aws/smithy-go/transport/http"
16 )
17
18 func addAPIRequestMiddleware(stack *middleware.Stack,
19 options Options,
20 operation string,
21 getPath func(interface{}) (string, error),
22 getOutput func(*smithyhttp.Response) (interface{}, error),
23 ) (err error) {
24 err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
25 if err != nil {
26 return err
27 }
28
29
30 if !options.disableAPIToken {
31 err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
32 if err != nil {
33 return err
34 }
35
36 err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
37 if err != nil {
38 return err
39 }
40 }
41
42 return nil
43 }
44
45 func addRequestMiddleware(stack *middleware.Stack,
46 options Options,
47 method string,
48 operation string,
49 getPath func(interface{}) (string, error),
50 getOutput func(*smithyhttp.Response) (interface{}, error),
51 ) (err error) {
52 err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
53 if err != nil {
54 return err
55 }
56
57
58 err = stack.Initialize.Add(&operationTimeout{
59 Disabled: options.DisableDefaultTimeout,
60 DefaultTimeout: defaultOperationTimeout,
61 }, middleware.Before)
62 if err != nil {
63 return err
64 }
65
66
67 err = stack.Serialize.Add(&serializeRequest{
68 GetPath: getPath,
69 Method: method,
70 }, middleware.After)
71 if err != nil {
72 return err
73 }
74
75
76 err = stack.Serialize.Insert(&resolveEndpoint{
77 Endpoint: options.Endpoint,
78 EndpointMode: options.EndpointMode,
79 }, "OperationSerializer", middleware.Before)
80 if err != nil {
81 return err
82 }
83
84
85 err = stack.Deserialize.Add(&deserializeResponse{
86 GetOutput: getOutput,
87 }, middleware.After)
88 if err != nil {
89 return err
90 }
91
92 err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
93 LogRequest: options.ClientLogMode.IsRequest(),
94 LogRequestWithBody: options.ClientLogMode.IsRequestWithBody(),
95 LogResponse: options.ClientLogMode.IsResponse(),
96 LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
97 }, middleware.After)
98 if err != nil {
99 return err
100 }
101
102 err = addSetLoggerMiddleware(stack, options)
103 if err != nil {
104 return err
105 }
106
107 if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
108 return fmt.Errorf("add protocol finalizers: %w", err)
109 }
110
111
112 return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
113 Retryer: options.Retryer,
114 LogRetryAttempts: options.ClientLogMode.IsRetries(),
115 })
116 }
117
118 func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
119 return middleware.AddSetLoggerMiddleware(stack, o.Logger)
120 }
121
122 type serializeRequest struct {
123 GetPath func(interface{}) (string, error)
124 Method string
125 }
126
127 func (*serializeRequest) ID() string {
128 return "OperationSerializer"
129 }
130
131 func (m *serializeRequest) HandleSerialize(
132 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
133 ) (
134 out middleware.SerializeOutput, metadata middleware.Metadata, err error,
135 ) {
136 request, ok := in.Request.(*smithyhttp.Request)
137 if !ok {
138 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
139 }
140
141 reqPath, err := m.GetPath(in.Parameters)
142 if err != nil {
143 return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
144 }
145
146 request.Request.URL.Path = reqPath
147 request.Request.Method = m.Method
148
149 return next.HandleSerialize(ctx, in)
150 }
151
152 type deserializeResponse struct {
153 GetOutput func(*smithyhttp.Response) (interface{}, error)
154 }
155
156 func (*deserializeResponse) ID() string {
157 return "OperationDeserializer"
158 }
159
160 func (m *deserializeResponse) HandleDeserialize(
161 ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
162 ) (
163 out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
164 ) {
165 out, metadata, err = next.HandleDeserialize(ctx, in)
166 if err != nil {
167 return out, metadata, err
168 }
169
170 resp, ok := out.RawResponse.(*smithyhttp.Response)
171 if !ok {
172 return out, metadata, fmt.Errorf(
173 "unexpected transport response type, %T, want %T", out.RawResponse, resp)
174 }
175 defer resp.Body.Close()
176
177
178
179 body, err := ioutil.ReadAll(resp.Body)
180 if err != nil {
181 return out, metadata, fmt.Errorf("read response body failed, %w", err)
182 }
183 resp.Body = ioutil.NopCloser(bytes.NewReader(body))
184
185
186 if resp.StatusCode < 200 || resp.StatusCode >= 300 {
187 return out, metadata, &smithyhttp.ResponseError{
188 Response: resp,
189 Err: fmt.Errorf("request to EC2 IMDS failed"),
190 }
191 }
192
193 result, err := m.GetOutput(resp)
194 if err != nil {
195 return out, metadata, fmt.Errorf(
196 "unable to get deserialized result for response, %w", err,
197 )
198 }
199 out.Result = result
200
201 return out, metadata, err
202 }
203
204 type resolveEndpoint struct {
205 Endpoint string
206 EndpointMode EndpointModeState
207 }
208
209 func (*resolveEndpoint) ID() string {
210 return "ResolveEndpoint"
211 }
212
213 func (m *resolveEndpoint) HandleSerialize(
214 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
215 ) (
216 out middleware.SerializeOutput, metadata middleware.Metadata, err error,
217 ) {
218
219 req, ok := in.Request.(*smithyhttp.Request)
220 if !ok {
221 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
222 }
223
224 var endpoint string
225 if len(m.Endpoint) > 0 {
226 endpoint = m.Endpoint
227 } else {
228 switch m.EndpointMode {
229 case EndpointModeStateIPv6:
230 endpoint = defaultIPv6Endpoint
231 case EndpointModeStateIPv4:
232 fallthrough
233 case EndpointModeStateUnset:
234 endpoint = defaultIPv4Endpoint
235 default:
236 return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
237 }
238 }
239
240 req.URL, err = url.Parse(endpoint)
241 if err != nil {
242 return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
243 }
244
245 return next.HandleSerialize(ctx, in)
246 }
247
// defaultOperationTimeout is the DefaultTimeout that addRequestMiddleware gives
// operationTimeout. It applies when the caller's context has no deadline and
// Options.DisableDefaultTimeout is not set.
const defaultOperationTimeout = 5 * time.Second
251
252
253
254
255
256
257
258
259
260
261
262
263 type operationTimeout struct {
264 Disabled bool
265 DefaultTimeout time.Duration
266 }
267
268 func (*operationTimeout) ID() string { return "OperationTimeout" }
269
270 func (m *operationTimeout) HandleInitialize(
271 ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
272 ) (
273 output middleware.InitializeOutput, metadata middleware.Metadata, err error,
274 ) {
275 if m.Disabled {
276 return next.HandleInitialize(ctx, input)
277 }
278
279 if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
280 var cancelFn func()
281 ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
282 defer cancelFn()
283 }
284
285 return next.HandleInitialize(ctx, input)
286 }
287
288
289
290
// appendURIPath joins add onto base using path.Join, which also cleans the
// result (collapsing duplicate slashes and resolving "." and ".."). Cleaning
// removes trailing slashes, so if add ended with "/" a single trailing slash
// is restored. The exception is a joined result that is already the root "/":
// appending there would produce "//".
func appendURIPath(base, add string) string {
	reqPath := path.Join(base, add)
	if n := len(add); n > 0 && add[n-1] == '/' && reqPath != "/" {
		reqPath += "/"
	}
	return reqPath
}
298
299 func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
300 if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
301 return fmt.Errorf("add ResolveAuthScheme: %w", err)
302 }
303 if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
304 return fmt.Errorf("add GetIdentity: %w", err)
305 }
306 if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
307 return fmt.Errorf("add ResolveEndpointV2: %w", err)
308 }
309 if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
310 return fmt.Errorf("add Signing: %w", err)
311 }
312 return nil
313 }
314
View as plain text