1 package imds
2
3 import (
4 "bytes"
5 "context"
6 "encoding/hex"
7 "io/ioutil"
8 "net/http/httptest"
9 "strings"
10 "testing"
11 "time"
12 )
13
14 func TestGetUserData(t *testing.T) {
15 cases := map[string]struct {
16 RespStatusCode int
17 ExpectContent []byte
18 ExpectTrace []string
19 ExpectErr string
20 }{
21 "get data": {
22 ExpectContent: []byte("success"),
23 ExpectTrace: []string{
24 getTokenPath,
25 getUserDataPath,
26 },
27 },
28 "get data error": {
29 RespStatusCode: 400,
30 ExpectTrace: []string{
31 getTokenPath,
32 getUserDataPath,
33 },
34 ExpectErr: "EC2 IMDS failed",
35 },
36 }
37
38 ctx := context.Background()
39
40 for name, c := range cases {
41 t.Run(name, func(t *testing.T) {
42 trace := newRequestTrace()
43 server := httptest.NewServer(trace.WrapHandler(
44 newTestServeMux(t,
45 newSecureAPIHandler(t,
46 []string{"tokenA"},
47 5*time.Minute,
48 &successAPIResponseHandler{t: t,
49 path: getUserDataPath,
50 method: "GET",
51 statusCode: c.RespStatusCode,
52 body: append([]byte{}, c.ExpectContent...),
53 },
54 ))))
55 defer server.Close()
56
57
58 client := New(Options{
59 Endpoint: server.URL,
60 })
61
62 resp, err := client.GetUserData(ctx, nil)
63 if len(c.ExpectErr) != 0 {
64 if err == nil {
65 t.Fatalf("expect error, got none")
66 }
67 if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
68 t.Fatalf("expect error to contain %v, got %v", e, a)
69 }
70 return
71 }
72 if err != nil {
73 t.Fatalf("expect no error, got %v", err)
74 }
75 if resp == nil {
76 t.Fatalf("expect resp, got none")
77 }
78
79 actualContent, err := ioutil.ReadAll(resp.Content)
80 if err != nil {
81 t.Fatalf("expect to read content, got %v", err)
82 }
83
84 if e, a := c.ExpectContent, actualContent; !bytes.Equal(e, a) {
85 t.Errorf("expect content to be equal\nexpect:\n%s\nactual:\n%s",
86 hex.Dump(e), hex.Dump(a))
87 }
88
89 if diff := cmpDiff(c.ExpectTrace, trace.requests); len(diff) != 0 {
90 t.Errorf("expect trace to match\n%s", diff)
91 }
92 })
93 }
94 }
95
View as plain text