1 package middleware
2
3 import (
4 "crypto/tls"
5 "io"
6 "io/ioutil"
7 "net/http"
8 "net/http/httptest"
9 "path"
10 "reflect"
11 "runtime"
12 "testing"
13 "time"
14
15 "golang.org/x/net/http2"
16 )
17
18
19
20
21
// testdataDir is the slash-separated path of the "testdata" directory that
// sits one level above the directory containing this source file. It is
// filled in by init.
var testdataDir string

// init resolves testdataDir relative to the location of this source file, as
// reported by runtime.Caller.
func init() {
	_, thisFile, _, _ := runtime.Caller(0)
	srcDir := path.Dir(thisFile)
	testdataDir = path.Join(srcDir, "/../testdata")
}
28
29 func TestWrapWriterHTTP2(t *testing.T) {
30 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31 _, fl := w.(http.Flusher)
32 if !fl {
33 t.Fatal("request should have been a http.Flusher")
34 }
35 _, hj := w.(http.Hijacker)
36 if hj {
37 t.Fatal("request should not have been a http.Hijacker")
38 }
39 _, rf := w.(io.ReaderFrom)
40 if rf {
41 t.Fatal("request should not have been a io.ReaderFrom")
42 }
43 _, ps := w.(http.Pusher)
44 if !ps {
45 t.Fatal("request should have been a http.Pusher")
46 }
47
48 w.Write([]byte("OK"))
49 })
50
51 wmw := func(next http.Handler) http.Handler {
52 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53 next.ServeHTTP(NewWrapResponseWriter(w, r.ProtoMajor), r)
54 })
55 }
56
57 server := http.Server{
58 Addr: ":7072",
59 Handler: wmw(handler),
60 }
61
62 go server.ListenAndServeTLS(testdataDir+"/cert.pem", testdataDir+"/key.pem")
63 defer server.Close()
64
65 time.Sleep(100 * time.Millisecond)
66
67 client := &http.Client{
68 Transport: &http2.Transport{
69 TLSClientConfig: &tls.Config{
70
71 InsecureSkipVerify: true,
72 },
73 },
74 }
75
76 resp, err := client.Get("https://localhost:7072")
77 if err != nil {
78 t.Fatalf("could not get server: %v", err)
79 }
80 if resp.StatusCode != 200 {
81 t.Fatalf("non 200 response: %v", resp.StatusCode)
82 }
83 }
84
// testRequest sends a method request for path to the test server ts using
// http.DefaultClient and returns the response together with its fully read
// body as a string. body is the optional request body and may be nil.
//
// Any failure aborts the calling test via t.Fatal. The response body is
// closed before returning, so callers should rely on the returned string
// rather than on resp.Body.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails and t.Fatal unwinds this function.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
107
// testRequestNoRedirect behaves like a plain test request helper but uses a
// client that never follows redirects: a 3xx answer is returned to the caller
// as-is. It sends a method request for path to the test server ts and returns
// the response together with its fully read body as a string. body is the
// optional request body and may be nil.
//
// Any failure aborts the calling test via t.Fatal. The response body is
// closed before returning.
func testRequestNoRedirect(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	// http.ErrUseLastResponse makes the client hand back the redirect
	// response itself instead of following it.
	httpClient := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails and t.Fatal unwinds this function.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
137
// assertNoError fails the calling test immediately when err is non-nil. The
// unexpected error is included in the failure message so the cause is visible
// in the test output.
func assertNoError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf("expecting no error, got: %v", err)
	}
}
144
// assertError fails the calling test immediately when err is nil, i.e. when
// an error was expected but none occurred.
func assertError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		return
	}
	t.Fatalf("expecting error")
}
151
// assertEqual fails the calling test immediately unless a and b are deeply
// equal according to reflect.DeepEqual. Both values are printed in the
// failure message.
func assertEqual(t *testing.T, a, b interface{}) {
	t.Helper()
	if same := reflect.DeepEqual(a, b); same {
		return
	}
	t.Fatalf("expecting values to be equal but got: '%v' and '%v'", a, b)
}
158
View as plain text