1 package imds
2
3 import (
4 "context"
5 "net/http"
6 "net/http/httptest"
7 "strconv"
8 "strings"
9 "testing"
10 "time"
11 )
12
13 func TestGetToken(t *testing.T) {
14 cases := map[string]struct {
15 TokenTTL time.Duration
16 Header http.Header
17 Body []byte
18 ExpectToken string
19 ExpectTokenTTL time.Duration
20 ExpectTrace []string
21 ExpectErr string
22 }{
23 "success": {
24 TokenTTL: 10 * time.Second,
25 Header: http.Header{
26 tokenTTLHeader: []string{"10"},
27 },
28 Body: []byte("tokenABC"),
29 ExpectToken: "tokenABC",
30 ExpectTokenTTL: 10 * time.Second,
31 ExpectTrace: []string{
32 getTokenPath,
33 },
34 },
35 }
36
37 ctx := context.Background()
38
39 for name, c := range cases {
40 t.Run(name, func(t *testing.T) {
41 trace := newRequestTrace()
42 server := httptest.NewServer(trace.WrapHandler(
43 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
44 actualTTL := r.Header.Get(tokenTTLHeader)
45 expectTTL := strconv.Itoa(int(c.TokenTTL / time.Second))
46 if expectTTL != actualTTL {
47 t.Errorf("expect %v token TTL request header, got %v",
48 expectTTL, actualTTL)
49 http.Error(w, http.StatusText(400), 400)
50 return
51 }
52
53 (&successAPIResponseHandler{t: t,
54 path: getTokenPath,
55 method: "PUT",
56 header: c.Header,
57 body: append([]byte{}, c.Body...),
58 }).ServeHTTP(w, r)
59 })))
60 defer server.Close()
61
62
63 client := New(Options{
64 Endpoint: server.URL,
65 })
66
67 resp, err := client.getToken(ctx, &getTokenInput{
68 TokenTTL: c.TokenTTL,
69 })
70 if len(c.ExpectErr) != 0 {
71 if err == nil {
72 t.Fatalf("expect error, got none")
73 }
74 if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
75 t.Fatalf("expect error to contain %v, got %v", e, a)
76 }
77 return
78 }
79 if err != nil {
80 t.Fatalf("expect no error, got %v", err)
81 }
82
83 if resp == nil {
84 t.Fatalf("expect resp, got none")
85 }
86
87 if e, a := c.ExpectToken, resp.Token; e != a {
88 t.Errorf("expect %v token, got %v", e, a)
89 }
90 if e, a := c.ExpectTokenTTL, resp.TokenTTL; e != a {
91 t.Errorf("expect %v token TTL, got %v", e, a)
92 }
93
94 if diff := cmpDiff(c.ExpectTrace, trace.requests); len(diff) != 0 {
95 t.Errorf("expect trace to match\n%s", diff)
96 }
97 })
98 }
99 }
100
View as plain text