...

Source file src/github.com/go-kit/kit/transport/http/server_test.go

Documentation: github.com/go-kit/kit/transport/http

     1  package http_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/go-kit/kit/endpoint"
    14  	httptransport "github.com/go-kit/kit/transport/http"
    15  )
    16  
    17  func TestServerBadDecode(t *testing.T) {
    18  	handler := httptransport.NewServer(
    19  		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
    20  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") },
    21  		func(context.Context, http.ResponseWriter, interface{}) error { return nil },
    22  	)
    23  	server := httptest.NewServer(handler)
    24  	defer server.Close()
    25  	resp, _ := http.Get(server.URL)
    26  	if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
    27  		t.Errorf("want %d, have %d", want, have)
    28  	}
    29  }
    30  
    31  func TestServerBadEndpoint(t *testing.T) {
    32  	handler := httptransport.NewServer(
    33  		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") },
    34  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
    35  		func(context.Context, http.ResponseWriter, interface{}) error { return nil },
    36  	)
    37  	server := httptest.NewServer(handler)
    38  	defer server.Close()
    39  	resp, _ := http.Get(server.URL)
    40  	if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
    41  		t.Errorf("want %d, have %d", want, have)
    42  	}
    43  }
    44  
    45  func TestServerBadEncode(t *testing.T) {
    46  	handler := httptransport.NewServer(
    47  		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
    48  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
    49  		func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") },
    50  	)
    51  	server := httptest.NewServer(handler)
    52  	defer server.Close()
    53  	resp, _ := http.Get(server.URL)
    54  	if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
    55  		t.Errorf("want %d, have %d", want, have)
    56  	}
    57  }
    58  
    59  func TestServerErrorEncoder(t *testing.T) {
    60  	errTeapot := errors.New("teapot")
    61  	code := func(err error) int {
    62  		if errors.Is(err, errTeapot) {
    63  			return http.StatusTeapot
    64  		}
    65  		return http.StatusInternalServerError
    66  	}
    67  	handler := httptransport.NewServer(
    68  		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot },
    69  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
    70  		func(context.Context, http.ResponseWriter, interface{}) error { return nil },
    71  		httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }),
    72  	)
    73  	server := httptest.NewServer(handler)
    74  	defer server.Close()
    75  	resp, _ := http.Get(server.URL)
    76  	if want, have := http.StatusTeapot, resp.StatusCode; want != have {
    77  		t.Errorf("want %d, have %d", want, have)
    78  	}
    79  }
    80  
    81  func TestServerHappyPath(t *testing.T) {
    82  	step, response := testServer(t)
    83  	step()
    84  	resp := <-response
    85  	defer resp.Body.Close()
    86  	buf, _ := ioutil.ReadAll(resp.Body)
    87  	if want, have := http.StatusOK, resp.StatusCode; want != have {
    88  		t.Errorf("want %d, have %d (%s)", want, have, buf)
    89  	}
    90  }
    91  
    92  func TestMultipleServerBefore(t *testing.T) {
    93  	var (
    94  		headerKey    = "X-Henlo-Lizer"
    95  		headerVal    = "Helllo you stinky lizard"
    96  		statusCode   = http.StatusTeapot
    97  		responseBody = "go eat a fly ugly\n"
    98  		done         = make(chan struct{})
    99  	)
   100  	handler := httptransport.NewServer(
   101  		endpoint.Nop,
   102  		func(context.Context, *http.Request) (interface{}, error) {
   103  			return struct{}{}, nil
   104  		},
   105  		func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
   106  			w.Header().Set(headerKey, headerVal)
   107  			w.WriteHeader(statusCode)
   108  			w.Write([]byte(responseBody))
   109  			return nil
   110  		},
   111  		httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
   112  			ctx = context.WithValue(ctx, "one", 1)
   113  
   114  			return ctx
   115  		}),
   116  		httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
   117  			if _, ok := ctx.Value("one").(int); !ok {
   118  				t.Error("Value was not set properly when multiple ServerBefores are used")
   119  			}
   120  
   121  			close(done)
   122  			return ctx
   123  		}),
   124  	)
   125  
   126  	server := httptest.NewServer(handler)
   127  	defer server.Close()
   128  	go http.Get(server.URL)
   129  
   130  	select {
   131  	case <-done:
   132  	case <-time.After(time.Second):
   133  		t.Fatal("timeout waiting for finalizer")
   134  	}
   135  }
   136  
   137  func TestMultipleServerAfter(t *testing.T) {
   138  	var (
   139  		headerKey    = "X-Henlo-Lizer"
   140  		headerVal    = "Helllo you stinky lizard"
   141  		statusCode   = http.StatusTeapot
   142  		responseBody = "go eat a fly ugly\n"
   143  		done         = make(chan struct{})
   144  	)
   145  	handler := httptransport.NewServer(
   146  		endpoint.Nop,
   147  		func(context.Context, *http.Request) (interface{}, error) {
   148  			return struct{}{}, nil
   149  		},
   150  		func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
   151  			w.Header().Set(headerKey, headerVal)
   152  			w.WriteHeader(statusCode)
   153  			w.Write([]byte(responseBody))
   154  			return nil
   155  		},
   156  		httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
   157  			ctx = context.WithValue(ctx, "one", 1)
   158  
   159  			return ctx
   160  		}),
   161  		httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
   162  			if _, ok := ctx.Value("one").(int); !ok {
   163  				t.Error("Value was not set properly when multiple ServerAfters are used")
   164  			}
   165  
   166  			close(done)
   167  			return ctx
   168  		}),
   169  	)
   170  
   171  	server := httptest.NewServer(handler)
   172  	defer server.Close()
   173  	go http.Get(server.URL)
   174  
   175  	select {
   176  	case <-done:
   177  	case <-time.After(time.Second):
   178  		t.Fatal("timeout waiting for finalizer")
   179  	}
   180  }
   181  
   182  func TestServerFinalizer(t *testing.T) {
   183  	var (
   184  		headerKey    = "X-Henlo-Lizer"
   185  		headerVal    = "Helllo you stinky lizard"
   186  		statusCode   = http.StatusTeapot
   187  		responseBody = "go eat a fly ugly\n"
   188  		done         = make(chan struct{})
   189  	)
   190  	handler := httptransport.NewServer(
   191  		endpoint.Nop,
   192  		func(context.Context, *http.Request) (interface{}, error) {
   193  			return struct{}{}, nil
   194  		},
   195  		func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
   196  			w.Header().Set(headerKey, headerVal)
   197  			w.WriteHeader(statusCode)
   198  			w.Write([]byte(responseBody))
   199  			return nil
   200  		},
   201  		httptransport.ServerFinalizer(func(ctx context.Context, code int, _ *http.Request) {
   202  			if want, have := statusCode, code; want != have {
   203  				t.Errorf("StatusCode: want %d, have %d", want, have)
   204  			}
   205  
   206  			responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header)
   207  			if want, have := headerVal, responseHeader.Get(headerKey); want != have {
   208  				t.Errorf("%s: want %q, have %q", headerKey, want, have)
   209  			}
   210  
   211  			responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64)
   212  			if want, have := int64(len(responseBody)), responseSize; want != have {
   213  				t.Errorf("response size: want %d, have %d", want, have)
   214  			}
   215  
   216  			close(done)
   217  		}),
   218  	)
   219  
   220  	server := httptest.NewServer(handler)
   221  	defer server.Close()
   222  	go http.Get(server.URL)
   223  
   224  	select {
   225  	case <-done:
   226  	case <-time.After(time.Second):
   227  		t.Fatal("timeout waiting for finalizer")
   228  	}
   229  }
   230  
   231  type enhancedResponse struct {
   232  	Foo string `json:"foo"`
   233  }
   234  
   235  func (e enhancedResponse) StatusCode() int      { return http.StatusPaymentRequired }
   236  func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }
   237  
   238  func TestEncodeJSONResponse(t *testing.T) {
   239  	handler := httptransport.NewServer(
   240  		func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
   241  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
   242  		httptransport.EncodeJSONResponse,
   243  	)
   244  
   245  	server := httptest.NewServer(handler)
   246  	defer server.Close()
   247  
   248  	resp, err := http.Get(server.URL)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
   253  		t.Errorf("StatusCode: want %d, have %d", want, have)
   254  	}
   255  	if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
   256  		t.Errorf("X-Edward: want %q, have %q", want, have)
   257  	}
   258  	buf, _ := ioutil.ReadAll(resp.Body)
   259  	if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
   260  		t.Errorf("Body: want %s, have %s", want, have)
   261  	}
   262  }
   263  
   264  type multiHeaderResponse struct{}
   265  
   266  func (_ multiHeaderResponse) Headers() http.Header {
   267  	return http.Header{"Vary": []string{"Origin", "User-Agent"}}
   268  }
   269  
   270  func TestAddMultipleHeaders(t *testing.T) {
   271  	handler := httptransport.NewServer(
   272  		func(context.Context, interface{}) (interface{}, error) { return multiHeaderResponse{}, nil },
   273  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
   274  		httptransport.EncodeJSONResponse,
   275  	)
   276  
   277  	server := httptest.NewServer(handler)
   278  	defer server.Close()
   279  
   280  	resp, err := http.Get(server.URL)
   281  	if err != nil {
   282  		t.Fatal(err)
   283  	}
   284  	expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
   285  	for k, vls := range resp.Header {
   286  		for _, v := range vls {
   287  			delete((expect[k]), v)
   288  		}
   289  		if len(expect[k]) != 0 {
   290  			t.Errorf("Header: unexpected header %s: %v", k, expect[k])
   291  		}
   292  	}
   293  }
   294  
   295  type multiHeaderResponseError struct {
   296  	multiHeaderResponse
   297  	msg string
   298  }
   299  
   300  func (m multiHeaderResponseError) Error() string {
   301  	return m.msg
   302  }
   303  
   304  func TestAddMultipleHeadersErrorEncoder(t *testing.T) {
   305  	errStr := "oh no"
   306  	handler := httptransport.NewServer(
   307  		func(context.Context, interface{}) (interface{}, error) {
   308  			return nil, multiHeaderResponseError{msg: errStr}
   309  		},
   310  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
   311  		httptransport.EncodeJSONResponse,
   312  	)
   313  
   314  	server := httptest.NewServer(handler)
   315  	defer server.Close()
   316  
   317  	resp, err := http.Get(server.URL)
   318  	if err != nil {
   319  		t.Fatal(err)
   320  	}
   321  	expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
   322  	for k, vls := range resp.Header {
   323  		for _, v := range vls {
   324  			delete((expect[k]), v)
   325  		}
   326  		if len(expect[k]) != 0 {
   327  			t.Errorf("Header: unexpected header %s: %v", k, expect[k])
   328  		}
   329  	}
   330  	if b, _ := ioutil.ReadAll(resp.Body); errStr != string(b) {
   331  		t.Errorf("ErrorEncoder: got: %q, expected: %q", b, errStr)
   332  	}
   333  }
   334  
   335  type noContentResponse struct{}
   336  
   337  func (e noContentResponse) StatusCode() int { return http.StatusNoContent }
   338  
   339  func TestEncodeNoContent(t *testing.T) {
   340  	handler := httptransport.NewServer(
   341  		func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil },
   342  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
   343  		httptransport.EncodeJSONResponse,
   344  	)
   345  
   346  	server := httptest.NewServer(handler)
   347  	defer server.Close()
   348  
   349  	resp, err := http.Get(server.URL)
   350  	if err != nil {
   351  		t.Fatal(err)
   352  	}
   353  	if want, have := http.StatusNoContent, resp.StatusCode; want != have {
   354  		t.Errorf("StatusCode: want %d, have %d", want, have)
   355  	}
   356  	buf, _ := ioutil.ReadAll(resp.Body)
   357  	if want, have := 0, len(buf); want != have {
   358  		t.Errorf("Body: want no content, have %d bytes", have)
   359  	}
   360  }
   361  
   362  type enhancedError struct{}
   363  
   364  func (e enhancedError) Error() string                { return "enhanced error" }
   365  func (e enhancedError) StatusCode() int              { return http.StatusTeapot }
   366  func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil }
   367  func (e enhancedError) Headers() http.Header         { return http.Header{"X-Enhanced": []string{"1"}} }
   368  
   369  func TestEnhancedError(t *testing.T) {
   370  	handler := httptransport.NewServer(
   371  		func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
   372  		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
   373  		func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil },
   374  	)
   375  
   376  	server := httptest.NewServer(handler)
   377  	defer server.Close()
   378  
   379  	resp, err := http.Get(server.URL)
   380  	if err != nil {
   381  		t.Fatal(err)
   382  	}
   383  	defer resp.Body.Close()
   384  	if want, have := http.StatusTeapot, resp.StatusCode; want != have {
   385  		t.Errorf("StatusCode: want %d, have %d", want, have)
   386  	}
   387  	if want, have := "1", resp.Header.Get("X-Enhanced"); want != have {
   388  		t.Errorf("X-Enhanced: want %q, have %q", want, have)
   389  	}
   390  	buf, _ := ioutil.ReadAll(resp.Body)
   391  	if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
   392  		t.Errorf("Body: want %s, have %s", want, have)
   393  	}
   394  }
   395  
   396  func TestNoOpRequestDecoder(t *testing.T) {
   397  	resw := httptest.NewRecorder()
   398  	req, err := http.NewRequest(http.MethodGet, "/", nil)
   399  	if err != nil {
   400  		t.Error("Failed to create request")
   401  	}
   402  	handler := httptransport.NewServer(
   403  		func(ctx context.Context, request interface{}) (interface{}, error) {
   404  			if request != nil {
   405  				t.Error("Expected nil request in endpoint when using NopRequestDecoder")
   406  			}
   407  			return nil, nil
   408  		},
   409  		httptransport.NopRequestDecoder,
   410  		httptransport.EncodeJSONResponse,
   411  	)
   412  	handler.ServeHTTP(resw, req)
   413  	if resw.Code != http.StatusOK {
   414  		t.Errorf("Expected status code %d but got %d", http.StatusOK, resw.Code)
   415  	}
   416  }
   417  
   418  func testServer(t *testing.T) (step func(), resp <-chan *http.Response) {
   419  	var (
   420  		stepch   = make(chan bool)
   421  		endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil }
   422  		response = make(chan *http.Response)
   423  		handler  = httptransport.NewServer(
   424  			endpoint,
   425  			func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
   426  			func(context.Context, http.ResponseWriter, interface{}) error { return nil },
   427  			httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { return ctx }),
   428  			httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { return ctx }),
   429  		)
   430  	)
   431  	go func() {
   432  		server := httptest.NewServer(handler)
   433  		defer server.Close()
   434  		resp, err := http.Get(server.URL)
   435  		if err != nil {
   436  			t.Error(err)
   437  			return
   438  		}
   439  		response <- resp
   440  	}()
   441  	return func() { stepch <- true }, response
   442  }
   443  

View as plain text