...

Source file src/github.com/aws/aws-sdk-go-v2/feature/ec2/imds/token_provider.go

Documentation: github.com/aws/aws-sdk-go-v2/feature/ec2/imds

     1  package imds
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/aws/aws-sdk-go-v2/aws"
     8  	"github.com/aws/smithy-go"
     9  	"github.com/aws/smithy-go/logging"
    10  	"net/http"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/aws/smithy-go/middleware"
    16  	smithyhttp "github.com/aws/smithy-go/transport/http"
    17  )
    18  
    19  const (
    20  	// Headers for Token and TTL
    21  	tokenHeader     = "x-aws-ec2-metadata-token"
    22  	defaultTokenTTL = 5 * time.Minute
    23  )
    24  
    25  type tokenProvider struct {
    26  	client   *Client
    27  	tokenTTL time.Duration
    28  
    29  	token    *apiToken
    30  	tokenMux sync.RWMutex
    31  
    32  	disabled uint32 // Atomic updated
    33  }
    34  
    35  func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
    36  	return &tokenProvider{
    37  		client:   client,
    38  		tokenTTL: ttl,
    39  	}
    40  }
    41  
    42  // apiToken provides the API token used by all operation calls for th EC2
    43  // Instance metadata service.
    44  type apiToken struct {
    45  	token   string
    46  	expires time.Time
    47  }
    48  
    49  var timeNow = time.Now
    50  
    51  // Expired returns if the token is expired.
    52  func (t *apiToken) Expired() bool {
    53  	// Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
    54  	// time is always based on reported wall-clock time.
    55  	return timeNow().Round(0).After(t.expires)
    56  }
    57  
    58  func (t *tokenProvider) ID() string { return "APITokenProvider" }
    59  
    60  // HandleFinalize is the finalize stack middleware, that if the token provider is
    61  // enabled, will attempt to add the cached API token to the request. If the API
    62  // token is not cached, it will be retrieved in a separate API call, getToken.
    63  //
    64  // For retry attempts, handler must be added after attempt retryer.
    65  //
    66  // If request for getToken fails the token provider may be disabled from future
    67  // requests, depending on the response status code.
    68  func (t *tokenProvider) HandleFinalize(
    69  	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
    70  ) (
    71  	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
    72  ) {
    73  	if t.fallbackEnabled() && !t.enabled() {
    74  		// short-circuits to insecure data flow if token provider is disabled.
    75  		return next.HandleFinalize(ctx, input)
    76  	}
    77  
    78  	req, ok := input.Request.(*smithyhttp.Request)
    79  	if !ok {
    80  		return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
    81  	}
    82  
    83  	tok, err := t.getToken(ctx)
    84  	if err != nil {
    85  		// If the error allows the token to downgrade to insecure flow allow that.
    86  		var bypassErr *bypassTokenRetrievalError
    87  		if errors.As(err, &bypassErr) {
    88  			return next.HandleFinalize(ctx, input)
    89  		}
    90  
    91  		return out, metadata, fmt.Errorf("failed to get API token, %w", err)
    92  	}
    93  
    94  	req.Header.Set(tokenHeader, tok.token)
    95  
    96  	return next.HandleFinalize(ctx, input)
    97  }
    98  
    99  // HandleDeserialize is the deserialize stack middleware for determining if the
   100  // operation the token provider is decorating failed because of a 401
   101  // unauthorized status code. If the operation failed for that reason the token
   102  // provider needs to be re-enabled so that it can start adding the API token to
   103  // operation calls.
   104  func (t *tokenProvider) HandleDeserialize(
   105  	ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
   106  ) (
   107  	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
   108  ) {
   109  	out, metadata, err = next.HandleDeserialize(ctx, input)
   110  	if err == nil {
   111  		return out, metadata, err
   112  	}
   113  
   114  	resp, ok := out.RawResponse.(*smithyhttp.Response)
   115  	if !ok {
   116  		return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
   117  	}
   118  
   119  	if resp.StatusCode == http.StatusUnauthorized { // unauthorized
   120  		t.enable()
   121  		err = &retryableError{Err: err, isRetryable: true}
   122  	}
   123  
   124  	return out, metadata, err
   125  }
   126  
   127  func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
   128  	if t.fallbackEnabled() && !t.enabled() {
   129  		return nil, &bypassTokenRetrievalError{
   130  			Err: fmt.Errorf("cannot get API token, provider disabled"),
   131  		}
   132  	}
   133  
   134  	t.tokenMux.RLock()
   135  	tok = t.token
   136  	t.tokenMux.RUnlock()
   137  
   138  	if tok != nil && !tok.Expired() {
   139  		return tok, nil
   140  	}
   141  
   142  	tok, err = t.updateToken(ctx)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	return tok, nil
   148  }
   149  
   150  func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
   151  	t.tokenMux.Lock()
   152  	defer t.tokenMux.Unlock()
   153  
   154  	// Prevent multiple requests to update retrieving the token.
   155  	if t.token != nil && !t.token.Expired() {
   156  		tok := t.token
   157  		return tok, nil
   158  	}
   159  
   160  	result, err := t.client.getToken(ctx, &getTokenInput{
   161  		TokenTTL: t.tokenTTL,
   162  	})
   163  	if err != nil {
   164  		var statusErr interface{ HTTPStatusCode() int }
   165  		if errors.As(err, &statusErr) {
   166  			switch statusErr.HTTPStatusCode() {
   167  			// Disable future get token if failed because of 403, 404, or 405
   168  			case http.StatusForbidden,
   169  				http.StatusNotFound,
   170  				http.StatusMethodNotAllowed:
   171  
   172  				if t.fallbackEnabled() {
   173  					logger := middleware.GetLogger(ctx)
   174  					logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
   175  					t.disable()
   176  				}
   177  
   178  			// 400 errors are terminal, and need to be upstreamed
   179  			case http.StatusBadRequest:
   180  				return nil, err
   181  			}
   182  		}
   183  
   184  		// Disable if request send failed or timed out getting response
   185  		var re *smithyhttp.RequestSendError
   186  		var ce *smithy.CanceledError
   187  		if errors.As(err, &re) || errors.As(err, &ce) {
   188  			atomic.StoreUint32(&t.disabled, 1)
   189  		}
   190  
   191  		if !t.fallbackEnabled() {
   192  			// NOTE: getToken() is an implementation detail of some outer operation
   193  			// (e.g. GetMetadata). It has its own retries that have already been exhausted.
   194  			// Mark the underlying error as a terminal error.
   195  			err = &retryableError{Err: err, isRetryable: false}
   196  			return nil, err
   197  		}
   198  
   199  		// Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
   200  		// and allow the request to proceed. Future requests _may_ re-attempt fetching a
   201  		// token if not disabled.
   202  		return nil, &bypassTokenRetrievalError{Err: err}
   203  	}
   204  
   205  	tok := &apiToken{
   206  		token:   result.Token,
   207  		expires: timeNow().Add(result.TokenTTL),
   208  	}
   209  	t.token = tok
   210  
   211  	return tok, nil
   212  }
   213  
   214  // enabled returns if the token provider is current enabled or not.
   215  func (t *tokenProvider) enabled() bool {
   216  	return atomic.LoadUint32(&t.disabled) == 0
   217  }
   218  
   219  // fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
   220  func (t *tokenProvider) fallbackEnabled() bool {
   221  	switch t.client.options.EnableFallback {
   222  	case aws.FalseTernary:
   223  		return false
   224  	default:
   225  		return true
   226  	}
   227  }
   228  
   229  // disable disables the token provider and it will no longer attempt to inject
   230  // the token, nor request updates.
   231  func (t *tokenProvider) disable() {
   232  	atomic.StoreUint32(&t.disabled, 1)
   233  }
   234  
   235  // enable enables the token provide to start refreshing tokens, and adding them
   236  // to the pending request.
   237  func (t *tokenProvider) enable() {
   238  	t.tokenMux.Lock()
   239  	t.token = nil
   240  	t.tokenMux.Unlock()
   241  	atomic.StoreUint32(&t.disabled, 0)
   242  }
   243  
   244  type bypassTokenRetrievalError struct {
   245  	Err error
   246  }
   247  
   248  func (e *bypassTokenRetrievalError) Error() string {
   249  	return fmt.Sprintf("bypass token retrieval, %v", e.Err)
   250  }
   251  
   252  func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }
   253  
   254  type retryableError struct {
   255  	Err         error
   256  	isRetryable bool
   257  }
   258  
   259  func (e *retryableError) RetryableError() bool { return e.isRetryable }
   260  
   261  func (e *retryableError) Error() string { return e.Err.Error() }
   262  

View as plain text