1 package api
2
3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "net/http"
8 "path/filepath"
9 "testing"
10 "time"
11
12 "github.com/stretchr/testify/assert"
13 )
14
15 func TestCacheResponse(t *testing.T) {
16 counter := 0
17 fakeHTTP := tripper{
18 roundTrip: func(req *http.Request) (*http.Response, error) {
19 counter += 1
20 body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
21 status := 200
22 if req.URL.Path == "/error" {
23 status = 500
24 }
25 return &http.Response{
26 StatusCode: status,
27 Body: io.NopCloser(bytes.NewBufferString(body)),
28 }, nil
29 },
30 }
31
32 cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
33
34 httpClient, err := NewHTTPClient(
35 ClientOptions{
36 Host: "github.com",
37 AuthToken: "token",
38 Transport: fakeHTTP,
39 EnableCache: true,
40 CacheDir: cacheDir,
41 LogIgnoreEnv: true,
42 },
43 )
44 assert.NoError(t, err)
45
46 do := func(method, url string, body io.Reader) (string, error) {
47 req, err := http.NewRequest(method, url, body)
48 if err != nil {
49 return "", err
50 }
51 res, err := httpClient.Do(req)
52 if err != nil {
53 return "", err
54 }
55 defer res.Body.Close()
56 resBody, err := io.ReadAll(res.Body)
57 if err != nil {
58 err = fmt.Errorf("ReadAll: %w", err)
59 }
60 return string(resBody), err
61 }
62
63 var res string
64
65 res, err = do("GET", "http://example.com/path", nil)
66 assert.NoError(t, err)
67 assert.Equal(t, "1: GET http://example.com/path", res)
68 res, err = do("GET", "http://example.com/path", nil)
69 assert.NoError(t, err)
70 assert.Equal(t, "1: GET http://example.com/path", res)
71
72 res, err = do("GET", "http://example.com/path2", nil)
73 assert.NoError(t, err)
74 assert.Equal(t, "2: GET http://example.com/path2", res)
75
76 res, err = do("POST", "http://example.com/path2", nil)
77 assert.NoError(t, err)
78 assert.Equal(t, "3: POST http://example.com/path2", res)
79
80 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
81 assert.NoError(t, err)
82 assert.Equal(t, "4: POST http://example.com/graphql", res)
83 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
84 assert.NoError(t, err)
85 assert.Equal(t, "4: POST http://example.com/graphql", res)
86
87 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
88 assert.NoError(t, err)
89 assert.Equal(t, "5: POST http://example.com/graphql", res)
90
91 res, err = do("GET", "http://example.com/error", nil)
92 assert.NoError(t, err)
93 assert.Equal(t, "6: GET http://example.com/error", res)
94 res, err = do("GET", "http://example.com/error", nil)
95 assert.NoError(t, err)
96 assert.Equal(t, "7: GET http://example.com/error", res)
97 }
98
99 func TestCacheResponseRequestCacheOptions(t *testing.T) {
100 counter := 0
101 fakeHTTP := tripper{
102 roundTrip: func(req *http.Request) (*http.Response, error) {
103 counter += 1
104 body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
105 status := 200
106 if req.URL.Path == "/error" {
107 status = 500
108 }
109 return &http.Response{
110 StatusCode: status,
111 Body: io.NopCloser(bytes.NewBufferString(body)),
112 }, nil
113 },
114 }
115
116 cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
117
118 httpClient, err := NewHTTPClient(
119 ClientOptions{
120 Host: "github.com",
121 AuthToken: "token",
122 Transport: fakeHTTP,
123 EnableCache: false,
124 CacheDir: cacheDir,
125 LogIgnoreEnv: true,
126 },
127 )
128 assert.NoError(t, err)
129
130 do := func(method, url string, body io.Reader) (string, error) {
131 req, err := http.NewRequest(method, url, body)
132 if err != nil {
133 return "", err
134 }
135 req.Header.Set("X-GH-CACHE-DIR", cacheDir)
136 req.Header.Set("X-GH-CACHE-TTL", "1h")
137 res, err := httpClient.Do(req)
138 if err != nil {
139 return "", err
140 }
141 defer res.Body.Close()
142 resBody, err := io.ReadAll(res.Body)
143 if err != nil {
144 err = fmt.Errorf("ReadAll: %w", err)
145 }
146 return string(resBody), err
147 }
148
149 var res string
150
151 res, err = do("GET", "http://example.com/path", nil)
152 assert.NoError(t, err)
153 assert.Equal(t, "1: GET http://example.com/path", res)
154 res, err = do("GET", "http://example.com/path", nil)
155 assert.NoError(t, err)
156 assert.Equal(t, "1: GET http://example.com/path", res)
157
158 res, err = do("GET", "http://example.com/path2", nil)
159 assert.NoError(t, err)
160 assert.Equal(t, "2: GET http://example.com/path2", res)
161
162 res, err = do("POST", "http://example.com/path2", nil)
163 assert.NoError(t, err)
164 assert.Equal(t, "3: POST http://example.com/path2", res)
165
166 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
167 assert.NoError(t, err)
168 assert.Equal(t, "4: POST http://example.com/graphql", res)
169 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
170 assert.NoError(t, err)
171 assert.Equal(t, "4: POST http://example.com/graphql", res)
172
173 res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
174 assert.NoError(t, err)
175 assert.Equal(t, "5: POST http://example.com/graphql", res)
176
177 res, err = do("GET", "http://example.com/error", nil)
178 assert.NoError(t, err)
179 assert.Equal(t, "6: GET http://example.com/error", res)
180 res, err = do("GET", "http://example.com/error", nil)
181 assert.NoError(t, err)
182 assert.Equal(t, "7: GET http://example.com/error", res)
183 }
184
185 func TestRequestCacheOptions(t *testing.T) {
186 req, err := http.NewRequest("GET", "some/url", nil)
187 assert.NoError(t, err)
188 req.Header.Set("X-GH-CACHE-DIR", "some/dir/path")
189 req.Header.Set("X-GH-CACHE-TTL", "1h")
190 dir, ttl := requestCacheOptions(req)
191 assert.Equal(t, dir, "some/dir/path")
192 assert.Equal(t, ttl, time.Hour)
193 }
194
View as plain text