1 package http_test
2
3 import (
4 "context"
5 "errors"
6 "io/ioutil"
7 "net/http"
8 "net/http/httptest"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/go-kit/kit/endpoint"
14 httptransport "github.com/go-kit/kit/transport/http"
15 )
16
17 func TestServerBadDecode(t *testing.T) {
18 handler := httptransport.NewServer(
19 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
20 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") },
21 func(context.Context, http.ResponseWriter, interface{}) error { return nil },
22 )
23 server := httptest.NewServer(handler)
24 defer server.Close()
25 resp, _ := http.Get(server.URL)
26 if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
27 t.Errorf("want %d, have %d", want, have)
28 }
29 }
30
31 func TestServerBadEndpoint(t *testing.T) {
32 handler := httptransport.NewServer(
33 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") },
34 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
35 func(context.Context, http.ResponseWriter, interface{}) error { return nil },
36 )
37 server := httptest.NewServer(handler)
38 defer server.Close()
39 resp, _ := http.Get(server.URL)
40 if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
41 t.Errorf("want %d, have %d", want, have)
42 }
43 }
44
45 func TestServerBadEncode(t *testing.T) {
46 handler := httptransport.NewServer(
47 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
48 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
49 func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") },
50 )
51 server := httptest.NewServer(handler)
52 defer server.Close()
53 resp, _ := http.Get(server.URL)
54 if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
55 t.Errorf("want %d, have %d", want, have)
56 }
57 }
58
59 func TestServerErrorEncoder(t *testing.T) {
60 errTeapot := errors.New("teapot")
61 code := func(err error) int {
62 if errors.Is(err, errTeapot) {
63 return http.StatusTeapot
64 }
65 return http.StatusInternalServerError
66 }
67 handler := httptransport.NewServer(
68 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot },
69 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
70 func(context.Context, http.ResponseWriter, interface{}) error { return nil },
71 httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }),
72 )
73 server := httptest.NewServer(handler)
74 defer server.Close()
75 resp, _ := http.Get(server.URL)
76 if want, have := http.StatusTeapot, resp.StatusCode; want != have {
77 t.Errorf("want %d, have %d", want, have)
78 }
79 }
80
81 func TestServerHappyPath(t *testing.T) {
82 step, response := testServer(t)
83 step()
84 resp := <-response
85 defer resp.Body.Close()
86 buf, _ := ioutil.ReadAll(resp.Body)
87 if want, have := http.StatusOK, resp.StatusCode; want != have {
88 t.Errorf("want %d, have %d (%s)", want, have, buf)
89 }
90 }
91
92 func TestMultipleServerBefore(t *testing.T) {
93 var (
94 headerKey = "X-Henlo-Lizer"
95 headerVal = "Helllo you stinky lizard"
96 statusCode = http.StatusTeapot
97 responseBody = "go eat a fly ugly\n"
98 done = make(chan struct{})
99 )
100 handler := httptransport.NewServer(
101 endpoint.Nop,
102 func(context.Context, *http.Request) (interface{}, error) {
103 return struct{}{}, nil
104 },
105 func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
106 w.Header().Set(headerKey, headerVal)
107 w.WriteHeader(statusCode)
108 w.Write([]byte(responseBody))
109 return nil
110 },
111 httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
112 ctx = context.WithValue(ctx, "one", 1)
113
114 return ctx
115 }),
116 httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
117 if _, ok := ctx.Value("one").(int); !ok {
118 t.Error("Value was not set properly when multiple ServerBefores are used")
119 }
120
121 close(done)
122 return ctx
123 }),
124 )
125
126 server := httptest.NewServer(handler)
127 defer server.Close()
128 go http.Get(server.URL)
129
130 select {
131 case <-done:
132 case <-time.After(time.Second):
133 t.Fatal("timeout waiting for finalizer")
134 }
135 }
136
137 func TestMultipleServerAfter(t *testing.T) {
138 var (
139 headerKey = "X-Henlo-Lizer"
140 headerVal = "Helllo you stinky lizard"
141 statusCode = http.StatusTeapot
142 responseBody = "go eat a fly ugly\n"
143 done = make(chan struct{})
144 )
145 handler := httptransport.NewServer(
146 endpoint.Nop,
147 func(context.Context, *http.Request) (interface{}, error) {
148 return struct{}{}, nil
149 },
150 func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
151 w.Header().Set(headerKey, headerVal)
152 w.WriteHeader(statusCode)
153 w.Write([]byte(responseBody))
154 return nil
155 },
156 httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
157 ctx = context.WithValue(ctx, "one", 1)
158
159 return ctx
160 }),
161 httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
162 if _, ok := ctx.Value("one").(int); !ok {
163 t.Error("Value was not set properly when multiple ServerAfters are used")
164 }
165
166 close(done)
167 return ctx
168 }),
169 )
170
171 server := httptest.NewServer(handler)
172 defer server.Close()
173 go http.Get(server.URL)
174
175 select {
176 case <-done:
177 case <-time.After(time.Second):
178 t.Fatal("timeout waiting for finalizer")
179 }
180 }
181
182 func TestServerFinalizer(t *testing.T) {
183 var (
184 headerKey = "X-Henlo-Lizer"
185 headerVal = "Helllo you stinky lizard"
186 statusCode = http.StatusTeapot
187 responseBody = "go eat a fly ugly\n"
188 done = make(chan struct{})
189 )
190 handler := httptransport.NewServer(
191 endpoint.Nop,
192 func(context.Context, *http.Request) (interface{}, error) {
193 return struct{}{}, nil
194 },
195 func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
196 w.Header().Set(headerKey, headerVal)
197 w.WriteHeader(statusCode)
198 w.Write([]byte(responseBody))
199 return nil
200 },
201 httptransport.ServerFinalizer(func(ctx context.Context, code int, _ *http.Request) {
202 if want, have := statusCode, code; want != have {
203 t.Errorf("StatusCode: want %d, have %d", want, have)
204 }
205
206 responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header)
207 if want, have := headerVal, responseHeader.Get(headerKey); want != have {
208 t.Errorf("%s: want %q, have %q", headerKey, want, have)
209 }
210
211 responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64)
212 if want, have := int64(len(responseBody)), responseSize; want != have {
213 t.Errorf("response size: want %d, have %d", want, have)
214 }
215
216 close(done)
217 }),
218 )
219
220 server := httptest.NewServer(handler)
221 defer server.Close()
222 go http.Get(server.URL)
223
224 select {
225 case <-done:
226 case <-time.After(time.Second):
227 t.Fatal("timeout waiting for finalizer")
228 }
229 }
230
// enhancedResponse is a response fixture with a JSON payload that also
// supplies its own status code and headers.
type enhancedResponse struct {
	Foo string `json:"foo"`
}

// StatusCode reports 402 Payment Required.
func (enhancedResponse) StatusCode() int {
	return http.StatusPaymentRequired
}

// Headers returns a single X-Edward header with the value "Snowden".
func (enhancedResponse) Headers() http.Header {
	hdr := make(http.Header, 1)
	hdr["X-Edward"] = []string{"Snowden"}
	return hdr
}
237
238 func TestEncodeJSONResponse(t *testing.T) {
239 handler := httptransport.NewServer(
240 func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
241 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
242 httptransport.EncodeJSONResponse,
243 )
244
245 server := httptest.NewServer(handler)
246 defer server.Close()
247
248 resp, err := http.Get(server.URL)
249 if err != nil {
250 t.Fatal(err)
251 }
252 if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
253 t.Errorf("StatusCode: want %d, have %d", want, have)
254 }
255 if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
256 t.Errorf("X-Edward: want %q, have %q", want, have)
257 }
258 buf, _ := ioutil.ReadAll(resp.Body)
259 if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
260 t.Errorf("Body: want %s, have %s", want, have)
261 }
262 }
263
// multiHeaderResponse is a response fixture whose only feature is a header
// carrying more than one value.
type multiHeaderResponse struct{}

// Headers returns a Vary header with the two values "Origin" and "User-Agent".
func (multiHeaderResponse) Headers() http.Header {
	vary := []string{"Origin", "User-Agent"}
	return http.Header{"Vary": vary}
}
269
270 func TestAddMultipleHeaders(t *testing.T) {
271 handler := httptransport.NewServer(
272 func(context.Context, interface{}) (interface{}, error) { return multiHeaderResponse{}, nil },
273 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
274 httptransport.EncodeJSONResponse,
275 )
276
277 server := httptest.NewServer(handler)
278 defer server.Close()
279
280 resp, err := http.Get(server.URL)
281 if err != nil {
282 t.Fatal(err)
283 }
284 expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
285 for k, vls := range resp.Header {
286 for _, v := range vls {
287 delete((expect[k]), v)
288 }
289 if len(expect[k]) != 0 {
290 t.Errorf("Header: unexpected header %s: %v", k, expect[k])
291 }
292 }
293 }
294
295 type multiHeaderResponseError struct {
296 multiHeaderResponse
297 msg string
298 }
299
300 func (m multiHeaderResponseError) Error() string {
301 return m.msg
302 }
303
304 func TestAddMultipleHeadersErrorEncoder(t *testing.T) {
305 errStr := "oh no"
306 handler := httptransport.NewServer(
307 func(context.Context, interface{}) (interface{}, error) {
308 return nil, multiHeaderResponseError{msg: errStr}
309 },
310 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
311 httptransport.EncodeJSONResponse,
312 )
313
314 server := httptest.NewServer(handler)
315 defer server.Close()
316
317 resp, err := http.Get(server.URL)
318 if err != nil {
319 t.Fatal(err)
320 }
321 expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
322 for k, vls := range resp.Header {
323 for _, v := range vls {
324 delete((expect[k]), v)
325 }
326 if len(expect[k]) != 0 {
327 t.Errorf("Header: unexpected header %s: %v", k, expect[k])
328 }
329 }
330 if b, _ := ioutil.ReadAll(resp.Body); errStr != string(b) {
331 t.Errorf("ErrorEncoder: got: %q, expected: %q", b, errStr)
332 }
333 }
334
// noContentResponse is a response fixture that supplies its own status code.
type noContentResponse struct{}

// StatusCode reports 204 No Content.
func (noContentResponse) StatusCode() int {
	return http.StatusNoContent
}
338
339 func TestEncodeNoContent(t *testing.T) {
340 handler := httptransport.NewServer(
341 func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil },
342 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
343 httptransport.EncodeJSONResponse,
344 )
345
346 server := httptest.NewServer(handler)
347 defer server.Close()
348
349 resp, err := http.Get(server.URL)
350 if err != nil {
351 t.Fatal(err)
352 }
353 if want, have := http.StatusNoContent, resp.StatusCode; want != have {
354 t.Errorf("StatusCode: want %d, have %d", want, have)
355 }
356 buf, _ := ioutil.ReadAll(resp.Body)
357 if want, have := 0, len(buf); want != have {
358 t.Errorf("Body: want no content, have %d bytes", have)
359 }
360 }
361
// enhancedError is an error fixture that also supplies its own status code,
// JSON encoding and headers.
type enhancedError struct{}

// Error returns the fixed message "enhanced error".
func (enhancedError) Error() string {
	return "enhanced error"
}

// StatusCode reports 418 I'm a teapot.
func (enhancedError) StatusCode() int {
	return http.StatusTeapot
}

// MarshalJSON returns a fixed JSON object and never fails.
func (enhancedError) MarshalJSON() ([]byte, error) {
	const payload = `{"err":"enhanced"}`
	return []byte(payload), nil
}

// Headers returns a single X-Enhanced header with the value "1".
func (enhancedError) Headers() http.Header {
	hdr := make(http.Header, 1)
	hdr["X-Enhanced"] = []string{"1"}
	return hdr
}
368
369 func TestEnhancedError(t *testing.T) {
370 handler := httptransport.NewServer(
371 func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
372 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
373 func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil },
374 )
375
376 server := httptest.NewServer(handler)
377 defer server.Close()
378
379 resp, err := http.Get(server.URL)
380 if err != nil {
381 t.Fatal(err)
382 }
383 defer resp.Body.Close()
384 if want, have := http.StatusTeapot, resp.StatusCode; want != have {
385 t.Errorf("StatusCode: want %d, have %d", want, have)
386 }
387 if want, have := "1", resp.Header.Get("X-Enhanced"); want != have {
388 t.Errorf("X-Enhanced: want %q, have %q", want, have)
389 }
390 buf, _ := ioutil.ReadAll(resp.Body)
391 if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
392 t.Errorf("Body: want %s, have %s", want, have)
393 }
394 }
395
396 func TestNoOpRequestDecoder(t *testing.T) {
397 resw := httptest.NewRecorder()
398 req, err := http.NewRequest(http.MethodGet, "/", nil)
399 if err != nil {
400 t.Error("Failed to create request")
401 }
402 handler := httptransport.NewServer(
403 func(ctx context.Context, request interface{}) (interface{}, error) {
404 if request != nil {
405 t.Error("Expected nil request in endpoint when using NopRequestDecoder")
406 }
407 return nil, nil
408 },
409 httptransport.NopRequestDecoder,
410 httptransport.EncodeJSONResponse,
411 )
412 handler.ServeHTTP(resw, req)
413 if resw.Code != http.StatusOK {
414 t.Errorf("Expected status code %d but got %d", http.StatusOK, resw.Code)
415 }
416 }
417
418 func testServer(t *testing.T) (step func(), resp <-chan *http.Response) {
419 var (
420 stepch = make(chan bool)
421 endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil }
422 response = make(chan *http.Response)
423 handler = httptransport.NewServer(
424 endpoint,
425 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
426 func(context.Context, http.ResponseWriter, interface{}) error { return nil },
427 httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { return ctx }),
428 httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { return ctx }),
429 )
430 )
431 go func() {
432 server := httptest.NewServer(handler)
433 defer server.Close()
434 resp, err := http.Get(server.URL)
435 if err != nil {
436 t.Error(err)
437 return
438 }
439 response <- resp
440 }()
441 return func() { stepch <- true }, response
442 }
443
View as plain text