...

Source file src/github.com/aws/aws-sdk-go-v2/feature/ec2/imds/request_middleware.go

Documentation: github.com/aws/aws-sdk-go-v2/feature/ec2/imds

     1  package imds
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/url"
     9  	"path"
    10  	"time"
    11  
    12  	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
    13  	"github.com/aws/aws-sdk-go-v2/aws/retry"
    14  	"github.com/aws/smithy-go/middleware"
    15  	smithyhttp "github.com/aws/smithy-go/transport/http"
    16  )
    17  
    18  func addAPIRequestMiddleware(stack *middleware.Stack,
    19  	options Options,
    20  	operation string,
    21  	getPath func(interface{}) (string, error),
    22  	getOutput func(*smithyhttp.Response) (interface{}, error),
    23  ) (err error) {
    24  	err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
    25  	if err != nil {
    26  		return err
    27  	}
    28  
    29  	// Token Serializer build and state management.
    30  	if !options.disableAPIToken {
    31  		err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
    32  		if err != nil {
    33  			return err
    34  		}
    35  
    36  		err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
    37  		if err != nil {
    38  			return err
    39  		}
    40  	}
    41  
    42  	return nil
    43  }
    44  
// addRequestMiddleware registers the middleware shared by every IMDS
// operation on stack, in this order:
//   - the "ec2-imds" FeatureMetadata SDK agent key;
//   - an operationTimeout at the front of the Initialize step, using
//     defaultOperationTimeout unless options.DisableDefaultTimeout is set;
//   - the serializeRequest operation serializer, which sets the HTTP method
//     and URL path;
//   - the resolveEndpoint middleware, inserted just before that serializer;
//   - the deserializeResponse operation deserializer, followed by a
//     request/response logger driven by options.ClientLogMode;
//   - the logger middleware (addSetLoggerMiddleware);
//   - the protocol finalizers (auth scheme, identity, endpoint v2, signing);
//   - the retry middleware, last.
//
// Parameters:
//   - method is the HTTP method set on the request.
//   - operation is passed to the auth scheme resolution finalizer.
//   - getPath computes the request URL path from the operation input.
//   - getOutput builds the operation result from a 2xx response.
//
// The first registration error is returned. Protocol finalizer errors are
// wrapped with "add protocol finalizers".
func addRequestMiddleware(stack *middleware.Stack,
	options Options,
	method string,
	operation string,
	getPath func(interface{}) (string, error),
	getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
	// NOTE(review): presumably surfaced in the User-Agent header via the SDK
	// agent key mechanism — confirm in aws/middleware.
	err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
	if err != nil {
		return err
	}

	// Operation timeout
	err = stack.Initialize.Add(&operationTimeout{
		Disabled:       options.DisableDefaultTimeout,
		DefaultTimeout: defaultOperationTimeout,
	}, middleware.Before)
	if err != nil {
		return err
	}

	// Operation Serializer
	err = stack.Serialize.Add(&serializeRequest{
		GetPath: getPath,
		Method:  method,
	}, middleware.After)
	if err != nil {
		return err
	}

	// Operation endpoint resolver. It replaces the whole request URL, so it is
	// placed ahead of the serializer that fills in URL.Path.
	err = stack.Serialize.Insert(&resolveEndpoint{
		Endpoint:     options.Endpoint,
		EndpointMode: options.EndpointMode,
	}, "OperationSerializer", middleware.Before)
	if err != nil {
		return err
	}

	// Operation Deserializer
	err = stack.Deserialize.Add(&deserializeResponse{
		GetOutput: getOutput,
	}, middleware.After)
	if err != nil {
		return err
	}

	// Request/response wire logging, toggled per flag in options.ClientLogMode.
	err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
		LogRequest:          options.ClientLogMode.IsRequest(),
		LogRequestWithBody:  options.ClientLogMode.IsRequestWithBody(),
		LogResponse:         options.ClientLogMode.IsResponse(),
		LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
	}, middleware.After)
	if err != nil {
		return err
	}

	err = addSetLoggerMiddleware(stack, options)
	if err != nil {
		return err
	}

	// NOTE(review): registered before the retry middleware. Presumably
	// retry.AddRetryMiddlewares positions itself relative to one of these
	// finalizer IDs — confirm before reordering.
	if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
		return fmt.Errorf("add protocol finalizers: %w", err)
	}

	// Retry support
	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
		Retryer:          options.Retryer,
		LogRetryAttempts: options.ClientLogMode.IsRetries(),
	})
}
   117  
   118  func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
   119  	return middleware.AddSetLoggerMiddleware(stack, o.Logger)
   120  }
   121  
   122  type serializeRequest struct {
   123  	GetPath func(interface{}) (string, error)
   124  	Method  string
   125  }
   126  
   127  func (*serializeRequest) ID() string {
   128  	return "OperationSerializer"
   129  }
   130  
   131  func (m *serializeRequest) HandleSerialize(
   132  	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
   133  ) (
   134  	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
   135  ) {
   136  	request, ok := in.Request.(*smithyhttp.Request)
   137  	if !ok {
   138  		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
   139  	}
   140  
   141  	reqPath, err := m.GetPath(in.Parameters)
   142  	if err != nil {
   143  		return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
   144  	}
   145  
   146  	request.Request.URL.Path = reqPath
   147  	request.Request.Method = m.Method
   148  
   149  	return next.HandleSerialize(ctx, in)
   150  }
   151  
   152  type deserializeResponse struct {
   153  	GetOutput func(*smithyhttp.Response) (interface{}, error)
   154  }
   155  
   156  func (*deserializeResponse) ID() string {
   157  	return "OperationDeserializer"
   158  }
   159  
   160  func (m *deserializeResponse) HandleDeserialize(
   161  	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
   162  ) (
   163  	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
   164  ) {
   165  	out, metadata, err = next.HandleDeserialize(ctx, in)
   166  	if err != nil {
   167  		return out, metadata, err
   168  	}
   169  
   170  	resp, ok := out.RawResponse.(*smithyhttp.Response)
   171  	if !ok {
   172  		return out, metadata, fmt.Errorf(
   173  			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
   174  	}
   175  	defer resp.Body.Close()
   176  
   177  	// read the full body so that any operation timeouts cleanup will not race
   178  	// the body being read.
   179  	body, err := ioutil.ReadAll(resp.Body)
   180  	if err != nil {
   181  		return out, metadata, fmt.Errorf("read response body failed, %w", err)
   182  	}
   183  	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
   184  
   185  	// Anything that's not 200 |< 300 is error
   186  	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   187  		return out, metadata, &smithyhttp.ResponseError{
   188  			Response: resp,
   189  			Err:      fmt.Errorf("request to EC2 IMDS failed"),
   190  		}
   191  	}
   192  
   193  	result, err := m.GetOutput(resp)
   194  	if err != nil {
   195  		return out, metadata, fmt.Errorf(
   196  			"unable to get deserialized result for response, %w", err,
   197  		)
   198  	}
   199  	out.Result = result
   200  
   201  	return out, metadata, err
   202  }
   203  
   204  type resolveEndpoint struct {
   205  	Endpoint     string
   206  	EndpointMode EndpointModeState
   207  }
   208  
   209  func (*resolveEndpoint) ID() string {
   210  	return "ResolveEndpoint"
   211  }
   212  
   213  func (m *resolveEndpoint) HandleSerialize(
   214  	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
   215  ) (
   216  	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
   217  ) {
   218  
   219  	req, ok := in.Request.(*smithyhttp.Request)
   220  	if !ok {
   221  		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
   222  	}
   223  
   224  	var endpoint string
   225  	if len(m.Endpoint) > 0 {
   226  		endpoint = m.Endpoint
   227  	} else {
   228  		switch m.EndpointMode {
   229  		case EndpointModeStateIPv6:
   230  			endpoint = defaultIPv6Endpoint
   231  		case EndpointModeStateIPv4:
   232  			fallthrough
   233  		case EndpointModeStateUnset:
   234  			endpoint = defaultIPv4Endpoint
   235  		default:
   236  			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
   237  		}
   238  	}
   239  
   240  	req.URL, err = url.Parse(endpoint)
   241  	if err != nil {
   242  		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
   243  	}
   244  
   245  	return next.HandleSerialize(ctx, in)
   246  }
   247  
   248  const (
   249  	defaultOperationTimeout = 5 * time.Second
   250  )
   251  
   252  // operationTimeout adds a timeout on the middleware stack if the Context the
   253  // stack was called with does not have a deadline. The next middleware must
   254  // complete before the timeout, or the context will be canceled.
   255  //
   256  // If DefaultTimeout is zero, no default timeout will be used if the Context
   257  // does not have a timeout.
   258  //
   259  // The next middleware must also ensure that any resources that are also
   260  // canceled by the stack's context are completely consumed before returning.
   261  // Otherwise the timeout cleanup will race the resource being consumed
   262  // upstream.
   263  type operationTimeout struct {
   264  	Disabled       bool
   265  	DefaultTimeout time.Duration
   266  }
   267  
   268  func (*operationTimeout) ID() string { return "OperationTimeout" }
   269  
   270  func (m *operationTimeout) HandleInitialize(
   271  	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
   272  ) (
   273  	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
   274  ) {
   275  	if m.Disabled {
   276  		return next.HandleInitialize(ctx, input)
   277  	}
   278  
   279  	if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
   280  		var cancelFn func()
   281  		ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
   282  		defer cancelFn()
   283  	}
   284  
   285  	return next.HandleInitialize(ctx, input)
   286  }
   287  
   288  // appendURIPath joins a URI path component to the existing path with `/`
   289  // separators between the path components. If the path being added ends with a
   290  // trailing `/` that slash will be maintained.
   291  func appendURIPath(base, add string) string {
   292  	reqPath := path.Join(base, add)
   293  	if len(add) != 0 && add[len(add)-1] == '/' {
   294  		reqPath += "/"
   295  	}
   296  	return reqPath
   297  }
   298  
   299  func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
   300  	if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
   301  		return fmt.Errorf("add ResolveAuthScheme: %w", err)
   302  	}
   303  	if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
   304  		return fmt.Errorf("add GetIdentity: %w", err)
   305  	}
   306  	if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
   307  		return fmt.Errorf("add ResolveEndpointV2: %w", err)
   308  	}
   309  	if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
   310  		return fmt.Errorf("add Signing: %w", err)
   311  	}
   312  	return nil
   313  }
   314  

View as plain text