...

Source file src/github.com/aws/aws-sdk-go-v2/feature/ec2/imds/shared_test.go

Documentation: github.com/aws/aws-sdk-go-v2/feature/ec2/imds

     1  package imds
     2  
     3  import (
     4  	"net/http"
     5  	"strconv"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  )
    11  
    12  type APIHandlers interface {
    13  	GetAPITokenHandler() http.Handler
    14  	GetAPIHandler() http.Handler
    15  }
    16  
    17  func newTestServeMux(t *testing.T, handlers APIHandlers) *http.ServeMux {
    18  	mux := http.NewServeMux()
    19  
    20  	mux.Handle(getTokenPath, validateAPITokenRequest(t, handlers.GetAPITokenHandler()))
    21  	mux.Handle("/latest/", handlers.GetAPIHandler())
    22  
    23  	return mux
    24  }
    25  
    26  func validateAPITokenRequest(t *testing.T, handler http.Handler) http.Handler {
    27  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    28  		if e, a := "PUT", r.Method; e != a {
    29  			t.Errorf("expect %v, http method got %v", e, a)
    30  			http.Error(w, http.StatusText(400), 400)
    31  			return
    32  		}
    33  		if len(r.Header.Get(tokenTTLHeader)) == 0 {
    34  			t.Errorf("expect token TTL header to be present in the request headers, got none")
    35  			http.Error(w, http.StatusText(400), 400)
    36  			return
    37  		}
    38  
    39  		handler.ServeHTTP(w, r)
    40  	})
    41  }
    42  
    43  type secureAPIHandler struct {
    44  	t *testing.T
    45  
    46  	tokens     []string
    47  	tokenTTL   time.Duration
    48  	apiHandler http.Handler
    49  
    50  	activeToken atomic.Value
    51  }
    52  
    53  func newSecureAPIHandler(t *testing.T, tokens []string, tokenTTL time.Duration, apiHandler http.Handler) *secureAPIHandler {
    54  	return &secureAPIHandler{
    55  		t:          t,
    56  		tokens:     tokens,
    57  		tokenTTL:   tokenTTL,
    58  		apiHandler: apiHandler,
    59  	}
    60  }
    61  
    62  func (h *secureAPIHandler) GetAPITokenHandler() http.Handler {
    63  	return http.HandlerFunc(h.handleAPIToken)
    64  }
    65  
    66  func (h *secureAPIHandler) handleAPIToken(w http.ResponseWriter, r *http.Request) {
    67  	token := h.tokens[0]
    68  
    69  	// set the active token
    70  	h.storeActiveToken(token)
    71  
    72  	// rotate the token
    73  	if len(h.tokens) > 1 {
    74  		h.tokens = h.tokens[1:]
    75  	}
    76  
    77  	var tokenTTLHeaderVal string
    78  	if h.tokenTTL == 0 {
    79  		tokenTTLHeaderVal = r.Header.Get(tokenTTLHeader)
    80  	} else {
    81  		tokenTTLHeaderVal = strconv.Itoa(int(h.tokenTTL / time.Second))
    82  	}
    83  
    84  	// set the header and response body
    85  	w.Header().Set(tokenTTLHeader, tokenTTLHeaderVal)
    86  	activeToken := h.getActiveToken()
    87  
    88  	w.Write([]byte(activeToken))
    89  }
    90  
    91  func (h *secureAPIHandler) GetAPIHandler() http.Handler {
    92  	return http.HandlerFunc(h.handleAPI)
    93  }
    94  
    95  func (h *secureAPIHandler) handleAPI(w http.ResponseWriter, r *http.Request) {
    96  	token := h.getActiveToken()
    97  	if len(token) == 0 {
    98  		h.t.Errorf("expect token to have been requested, was not")
    99  		http.Error(w, http.StatusText(401), 401)
   100  		return
   101  	}
   102  
   103  	if e, a := token, r.Header.Get(tokenHeader); e != a {
   104  		h.t.Errorf("expect %v token, got %v", e, a)
   105  		http.Error(w, http.StatusText(401), 401)
   106  		return
   107  	}
   108  
   109  	// delegate to configure handler for the request
   110  	h.apiHandler.ServeHTTP(w, r)
   111  }
   112  
   113  func (h *secureAPIHandler) storeActiveToken(t string) {
   114  	h.activeToken.Store(t)
   115  }
   116  
   117  func (h *secureAPIHandler) getActiveToken() string {
   118  	activeToken := h.activeToken.Load()
   119  	v, ok := activeToken.(string)
   120  	if !ok {
   121  		h.t.Errorf("expect valid active token string, got %T, %v", v, v)
   122  	}
   123  
   124  	return v
   125  }
   126  
   127  type insecureAPIHandler struct {
   128  	t               *testing.T
   129  	apiTokenErrCode int
   130  	apiHandler      http.Handler
   131  }
   132  
   133  func newInsecureAPIHandler(t *testing.T, apiTokenErrCode int, apiHandler http.Handler) *insecureAPIHandler {
   134  	return &insecureAPIHandler{
   135  		t:               t,
   136  		apiTokenErrCode: apiTokenErrCode,
   137  		apiHandler:      apiHandler,
   138  	}
   139  }
   140  
   141  func (h *insecureAPIHandler) GetAPITokenHandler() http.Handler {
   142  	return http.HandlerFunc(h.handleAPIToken)
   143  }
   144  
   145  func (h *insecureAPIHandler) handleAPIToken(w http.ResponseWriter, r *http.Request) {
   146  	http.Error(w, http.StatusText(h.apiTokenErrCode), h.apiTokenErrCode)
   147  }
   148  
   149  func (h *insecureAPIHandler) GetAPIHandler() http.Handler {
   150  	return http.HandlerFunc(h.handleAPI)
   151  }
   152  
   153  func (h *insecureAPIHandler) handleAPI(w http.ResponseWriter, r *http.Request) {
   154  	if len(r.Header.Get(tokenHeader)) != 0 {
   155  		h.t.Errorf("request token found, expected none")
   156  		http.Error(w, http.StatusText(400), 400)
   157  		return
   158  	}
   159  
   160  	// delegate to configure handler for the request
   161  	h.apiHandler.ServeHTTP(w, r)
   162  }
   163  
   164  type unauthorizedAPIHandler struct {
   165  	t *testing.T
   166  
   167  	enabled          bool
   168  	secureAPIHandler *secureAPIHandler
   169  }
   170  
   171  func newUnauthorizedAPIHandler(t *testing.T, secureHandler *secureAPIHandler) *unauthorizedAPIHandler {
   172  	return &unauthorizedAPIHandler{
   173  		t:                t,
   174  		secureAPIHandler: secureHandler,
   175  	}
   176  }
   177  
   178  func (h *unauthorizedAPIHandler) GetAPITokenHandler() http.Handler {
   179  	return http.HandlerFunc(h.handleAPIToken)
   180  }
   181  
   182  func (h *unauthorizedAPIHandler) handleAPIToken(w http.ResponseWriter, r *http.Request) {
   183  	// Respond with 404 first, then token after 401 API handler response
   184  	if !h.enabled {
   185  		http.Error(w, http.StatusText(404), 404)
   186  		return
   187  	}
   188  
   189  	h.secureAPIHandler.GetAPITokenHandler().ServeHTTP(w, r)
   190  }
   191  
   192  func (h *unauthorizedAPIHandler) GetAPIHandler() http.Handler {
   193  	return http.HandlerFunc(h.handleAPI)
   194  }
   195  
   196  func (h *unauthorizedAPIHandler) handleAPI(w http.ResponseWriter, r *http.Request) {
   197  	// Respond with 401 first, then 200 for second. When enabled switch to
   198  	// secure flow.
   199  	if !h.enabled {
   200  		h.enabled = true
   201  		http.Error(w, http.StatusText(401), 401)
   202  		return
   203  	}
   204  
   205  	h.secureAPIHandler.GetAPIHandler().ServeHTTP(w, r)
   206  }
   207  
   208  type requestTrace struct {
   209  	requests []string
   210  	mu       sync.Mutex
   211  }
   212  
   213  func newRequestTrace() *requestTrace {
   214  	return &requestTrace{}
   215  }
   216  
   217  func (t *requestTrace) WrapHandler(handler http.Handler) http.Handler {
   218  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   219  		t.mu.Lock()
   220  		t.requests = append(t.requests, r.URL.Path)
   221  		t.mu.Unlock()
   222  
   223  		handler.ServeHTTP(w, r)
   224  	})
   225  }
   226  

View as plain text