...

Source file src/github.com/go-chi/chi/middleware/middleware_test.go

Documentation: github.com/go-chi/chi/middleware

     1  package middleware
     2  
     3  import (
     4  	"crypto/tls"
     5  	"io"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"path"
    10  	"reflect"
    11  	"runtime"
    12  	"testing"
    13  	"time"
    14  
    15  	"golang.org/x/net/http2"
    16  )
    17  
// NOTE: we must import `golang.org/x/net/http2` in order to explicitly enable
// HTTP/2 transports for certain tests. The non-test package code does not need
// this dependency, because net/http configures HTTP/2 under the hood on Go 1.7+.
    21  
    22  var testdataDir string
    23  
    24  func init() {
    25  	_, filename, _, _ := runtime.Caller(0)
    26  	testdataDir = path.Join(path.Dir(filename), "/../testdata")
    27  }
    28  
    29  func TestWrapWriterHTTP2(t *testing.T) {
    30  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    31  		_, fl := w.(http.Flusher)
    32  		if !fl {
    33  			t.Fatal("request should have been a http.Flusher")
    34  		}
    35  		_, hj := w.(http.Hijacker)
    36  		if hj {
    37  			t.Fatal("request should not have been a http.Hijacker")
    38  		}
    39  		_, rf := w.(io.ReaderFrom)
    40  		if rf {
    41  			t.Fatal("request should not have been a io.ReaderFrom")
    42  		}
    43  		_, ps := w.(http.Pusher)
    44  		if !ps {
    45  			t.Fatal("request should have been a http.Pusher")
    46  		}
    47  
    48  		w.Write([]byte("OK"))
    49  	})
    50  
    51  	wmw := func(next http.Handler) http.Handler {
    52  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    53  			next.ServeHTTP(NewWrapResponseWriter(w, r.ProtoMajor), r)
    54  		})
    55  	}
    56  
    57  	server := http.Server{
    58  		Addr:    ":7072",
    59  		Handler: wmw(handler),
    60  	}
    61  	// By serving over TLS, we get HTTP2 requests
    62  	go server.ListenAndServeTLS(testdataDir+"/cert.pem", testdataDir+"/key.pem")
    63  	defer server.Close()
    64  	// We need the server to start before making the request
    65  	time.Sleep(100 * time.Millisecond)
    66  
    67  	client := &http.Client{
    68  		Transport: &http2.Transport{
    69  			TLSClientConfig: &tls.Config{
    70  				// The certificates we are using are self signed
    71  				InsecureSkipVerify: true,
    72  			},
    73  		},
    74  	}
    75  
    76  	resp, err := client.Get("https://localhost:7072")
    77  	if err != nil {
    78  		t.Fatalf("could not get server: %v", err)
    79  	}
    80  	if resp.StatusCode != 200 {
    81  		t.Fatalf("non 200 response: %v", resp.StatusCode)
    82  	}
    83  }
    84  
    85  func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
    86  	req, err := http.NewRequest(method, ts.URL+path, body)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  		return nil, ""
    90  	}
    91  
    92  	resp, err := http.DefaultClient.Do(req)
    93  	if err != nil {
    94  		t.Fatal(err)
    95  		return nil, ""
    96  	}
    97  
    98  	respBody, err := ioutil.ReadAll(resp.Body)
    99  	if err != nil {
   100  		t.Fatal(err)
   101  		return nil, ""
   102  	}
   103  	defer resp.Body.Close()
   104  
   105  	return resp, string(respBody)
   106  }
   107  
   108  func testRequestNoRedirect(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
   109  	req, err := http.NewRequest(method, ts.URL+path, body)
   110  	if err != nil {
   111  		t.Fatal(err)
   112  		return nil, ""
   113  	}
   114  
   115  	// http client that doesn't redirect
   116  	httpClient := &http.Client{
   117  		CheckRedirect: func(req *http.Request, via []*http.Request) error {
   118  			return http.ErrUseLastResponse
   119  		},
   120  	}
   121  
   122  	resp, err := httpClient.Do(req)
   123  	if err != nil {
   124  		t.Fatal(err)
   125  		return nil, ""
   126  	}
   127  
   128  	respBody, err := ioutil.ReadAll(resp.Body)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  		return nil, ""
   132  	}
   133  	defer resp.Body.Close()
   134  
   135  	return resp, string(respBody)
   136  }
   137  
   138  func assertNoError(t *testing.T, err error) {
   139  	t.Helper()
   140  	if err != nil {
   141  		t.Fatalf("expecting no error")
   142  	}
   143  }
   144  
   145  func assertError(t *testing.T, err error) {
   146  	t.Helper()
   147  	if err == nil {
   148  		t.Fatalf("expecting error")
   149  	}
   150  }
   151  
   152  func assertEqual(t *testing.T, a, b interface{}) {
   153  	t.Helper()
   154  	if !reflect.DeepEqual(a, b) {
   155  		t.Fatalf("expecting values to be equal but got: '%v' and '%v'", a, b)
   156  	}
   157  }
   158  

View as plain text