...

Source file src/github.com/go-kit/kit/transport/http/client_test.go

Documentation: github.com/go-kit/kit/transport/http

     1  package http_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"net/url"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	httptransport "github.com/go-kit/kit/transport/http"
    17  )
    18  
    19  type TestResponse struct {
    20  	Body   io.ReadCloser
    21  	String string
    22  }
    23  
    24  func TestHTTPClient(t *testing.T) {
    25  	var (
    26  		testbody = "testbody"
    27  		encode   = func(context.Context, *http.Request, interface{}) error { return nil }
    28  		decode   = func(_ context.Context, r *http.Response) (interface{}, error) {
    29  			buffer := make([]byte, len(testbody))
    30  			r.Body.Read(buffer)
    31  			return TestResponse{r.Body, string(buffer)}, nil
    32  		}
    33  		headers        = make(chan string, 1)
    34  		headerKey      = "X-Foo"
    35  		headerVal      = "abcde"
    36  		afterHeaderKey = "X-The-Dude"
    37  		afterHeaderVal = "Abides"
    38  		afterVal       = ""
    39  		afterFunc      = func(ctx context.Context, r *http.Response) context.Context {
    40  			afterVal = r.Header.Get(afterHeaderKey)
    41  			return ctx
    42  		}
    43  	)
    44  
    45  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    46  		headers <- r.Header.Get(headerKey)
    47  		w.Header().Set(afterHeaderKey, afterHeaderVal)
    48  		w.WriteHeader(http.StatusOK)
    49  		w.Write([]byte(testbody))
    50  	}))
    51  
    52  	client := httptransport.NewClient(
    53  		"GET",
    54  		mustParse(server.URL),
    55  		encode,
    56  		decode,
    57  		httptransport.ClientBefore(httptransport.SetRequestHeader(headerKey, headerVal)),
    58  		httptransport.ClientAfter(afterFunc),
    59  	)
    60  
    61  	res, err := client.Endpoint()(context.Background(), struct{}{})
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  
    66  	var have string
    67  	select {
    68  	case have = <-headers:
    69  	case <-time.After(time.Millisecond):
    70  		t.Fatalf("timeout waiting for %s", headerKey)
    71  	}
    72  	// Check that Request Header was successfully received
    73  	if want := headerVal; want != have {
    74  		t.Errorf("want %q, have %q", want, have)
    75  	}
    76  
    77  	// Check that Response header set from server was received in SetClientAfter
    78  	if want, have := afterVal, afterHeaderVal; want != have {
    79  		t.Errorf("want %q, have %q", want, have)
    80  	}
    81  
    82  	// Check that the response was successfully decoded
    83  	response, ok := res.(TestResponse)
    84  	if !ok {
    85  		t.Fatal("response should be TestResponse")
    86  	}
    87  	if want, have := testbody, response.String; want != have {
    88  		t.Errorf("want %q, have %q", want, have)
    89  	}
    90  
    91  	// Check that response body was closed
    92  	b := make([]byte, 1)
    93  	_, err = response.Body.Read(b)
    94  	if err == nil {
    95  		t.Fatal("wanted error, got none")
    96  	}
    97  	if doNotWant, have := io.EOF, err; doNotWant == have {
    98  		t.Errorf("do not want %q, have %q", doNotWant, have)
    99  	}
   100  }
   101  
   102  func TestHTTPClientBufferedStream(t *testing.T) {
   103  	// bodysize has a size big enought to make the resopnse.Body not an instant read
   104  	// so if the response is cancelled it wount be all readed and the test would fail
   105  	// The 6000 has not a particular meaning, it big enough to fulfill the usecase.
   106  	const bodysize = 6000
   107  	var (
   108  		testbody = string(make([]byte, bodysize))
   109  		encode   = func(context.Context, *http.Request, interface{}) error { return nil }
   110  		decode   = func(_ context.Context, r *http.Response) (interface{}, error) {
   111  			return TestResponse{r.Body, ""}, nil
   112  		}
   113  	)
   114  
   115  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   116  		w.WriteHeader(http.StatusOK)
   117  		w.Write([]byte(testbody))
   118  	}))
   119  
   120  	client := httptransport.NewClient(
   121  		"GET",
   122  		mustParse(server.URL),
   123  		encode,
   124  		decode,
   125  		httptransport.BufferedStream(true),
   126  	)
   127  
   128  	res, err := client.Endpoint()(context.Background(), struct{}{})
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  
   133  	// Check that the response was successfully decoded
   134  	response, ok := res.(TestResponse)
   135  	if !ok {
   136  		t.Fatal("response should be TestResponse")
   137  	}
   138  	defer response.Body.Close()
   139  	// Faking work
   140  	time.Sleep(time.Second * 1)
   141  
   142  	// Check that response body was NOT closed
   143  	b := make([]byte, len(testbody))
   144  	_, err = response.Body.Read(b)
   145  	if want, have := io.EOF, err; have != want {
   146  		t.Fatalf("want %q, have %q", want, have)
   147  	}
   148  	if want, have := testbody, string(b); want != have {
   149  		t.Errorf("want %q, have %q", want, have)
   150  	}
   151  }
   152  
   153  func TestClientFinalizer(t *testing.T) {
   154  	var (
   155  		headerKey    = "X-Henlo-Lizer"
   156  		headerVal    = "Helllo you stinky lizard"
   157  		responseBody = "go eat a fly ugly\n"
   158  		done         = make(chan struct{})
   159  		encode       = func(context.Context, *http.Request, interface{}) error { return nil }
   160  		decode       = func(_ context.Context, r *http.Response) (interface{}, error) {
   161  			return TestResponse{r.Body, ""}, nil
   162  		}
   163  	)
   164  
   165  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   166  		w.Header().Set(headerKey, headerVal)
   167  		w.Write([]byte(responseBody))
   168  	}))
   169  	defer server.Close()
   170  
   171  	client := httptransport.NewClient(
   172  		"GET",
   173  		mustParse(server.URL),
   174  		encode,
   175  		decode,
   176  		httptransport.ClientFinalizer(func(ctx context.Context, err error) {
   177  			responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header)
   178  			if want, have := headerVal, responseHeader.Get(headerKey); want != have {
   179  				t.Errorf("%s: want %q, have %q", headerKey, want, have)
   180  			}
   181  
   182  			responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64)
   183  			if want, have := int64(len(responseBody)), responseSize; want != have {
   184  				t.Errorf("response size: want %d, have %d", want, have)
   185  			}
   186  
   187  			close(done)
   188  		}),
   189  	)
   190  
   191  	_, err := client.Endpoint()(context.Background(), struct{}{})
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	select {
   197  	case <-done:
   198  	case <-time.After(time.Second):
   199  		t.Fatal("timeout waiting for finalizer")
   200  	}
   201  }
   202  
   203  func TestEncodeJSONRequest(t *testing.T) {
   204  	var header http.Header
   205  	var body string
   206  
   207  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   208  		b, err := ioutil.ReadAll(r.Body)
   209  		if err != nil && err != io.EOF {
   210  			t.Fatal(err)
   211  		}
   212  		header = r.Header
   213  		body = string(b)
   214  	}))
   215  
   216  	defer server.Close()
   217  
   218  	serverURL, err := url.Parse(server.URL)
   219  
   220  	if err != nil {
   221  		t.Fatal(err)
   222  	}
   223  
   224  	client := httptransport.NewClient(
   225  		"POST",
   226  		serverURL,
   227  		httptransport.EncodeJSONRequest,
   228  		func(context.Context, *http.Response) (interface{}, error) { return nil, nil },
   229  	).Endpoint()
   230  
   231  	for _, test := range []struct {
   232  		value interface{}
   233  		body  string
   234  	}{
   235  		{nil, "null\n"},
   236  		{12, "12\n"},
   237  		{1.2, "1.2\n"},
   238  		{true, "true\n"},
   239  		{"test", "\"test\"\n"},
   240  		{enhancedRequest{Foo: "foo"}, "{\"foo\":\"foo\"}\n"},
   241  	} {
   242  		if _, err := client(context.Background(), test.value); err != nil {
   243  			t.Error(err)
   244  			continue
   245  		}
   246  
   247  		if body != test.body {
   248  			t.Errorf("%v: actual %#v, expected %#v", test.value, body, test.body)
   249  		}
   250  	}
   251  
   252  	if _, err := client(context.Background(), enhancedRequest{Foo: "foo"}); err != nil {
   253  		t.Fatal(err)
   254  	}
   255  
   256  	if _, ok := header["X-Edward"]; !ok {
   257  		t.Fatalf("X-Edward value: actual %v, expected %v", nil, []string{"Snowden"})
   258  	}
   259  
   260  	if v := header.Get("X-Edward"); v != "Snowden" {
   261  		t.Errorf("X-Edward string: actual %v, expected %v", v, "Snowden")
   262  	}
   263  }
   264  
   265  func TestSetClient(t *testing.T) {
   266  	var (
   267  		encode = func(context.Context, *http.Request, interface{}) error { return nil }
   268  		decode = func(_ context.Context, r *http.Response) (interface{}, error) {
   269  			t, err := ioutil.ReadAll(r.Body)
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  			return string(t), nil
   274  		}
   275  	)
   276  
   277  	testHttpClient := httpClientFunc(func(req *http.Request) (*http.Response, error) {
   278  		return &http.Response{
   279  			StatusCode: http.StatusOK,
   280  			Request:    req,
   281  			Body:       ioutil.NopCloser(bytes.NewBufferString("hello, world!")),
   282  		}, nil
   283  	})
   284  
   285  	client := httptransport.NewClient(
   286  		"GET",
   287  		&url.URL{},
   288  		encode,
   289  		decode,
   290  		httptransport.SetClient(testHttpClient),
   291  	).Endpoint()
   292  
   293  	resp, err := client(context.Background(), nil)
   294  	if err != nil {
   295  		t.Fatal(err)
   296  	}
   297  	if r, ok := resp.(string); !ok || r != "hello, world!" {
   298  		t.Fatal("Expected response to be 'hello, world!' string")
   299  	}
   300  }
   301  
   302  func TestNewExplicitClient(t *testing.T) {
   303  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   304  		fmt.Fprintf(w, "%d", r.ContentLength)
   305  	}))
   306  	defer srv.Close()
   307  
   308  	req := func(ctx context.Context, request interface{}) (*http.Request, error) {
   309  		req, _ := http.NewRequest("POST", srv.URL, strings.NewReader(request.(string)))
   310  		return req, nil
   311  	}
   312  
   313  	dec := func(_ context.Context, resp *http.Response) (response interface{}, err error) {
   314  		buf, err := ioutil.ReadAll(resp.Body)
   315  		resp.Body.Close()
   316  		return string(buf), err
   317  	}
   318  
   319  	client := httptransport.NewExplicitClient(req, dec)
   320  
   321  	request := "hello world"
   322  	response, err := client.Endpoint()(context.Background(), request)
   323  	if err != nil {
   324  		t.Fatal(err)
   325  	}
   326  
   327  	if want, have := "11", response.(string); want != have {
   328  		t.Fatalf("want %q, have %q", want, have)
   329  	}
   330  }
   331  
   332  func mustParse(s string) *url.URL {
   333  	u, err := url.Parse(s)
   334  	if err != nil {
   335  		panic(err)
   336  	}
   337  	return u
   338  }
   339  
   340  type enhancedRequest struct {
   341  	Foo string `json:"foo"`
   342  }
   343  
   344  func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }
   345  
   346  type httpClientFunc func(req *http.Request) (*http.Response, error)
   347  
   348  func (f httpClientFunc) Do(req *http.Request) (*http.Response, error) {
   349  	return f(req)
   350  }
   351  

View as plain text