1
2
3
4
5
6
7
8
9
10
11
12
13 package chttp
14
15 import (
16 "errors"
17 "io"
18 "net/http"
19 "net/http/httptest"
20 "strings"
21 "testing"
22
23 "gitlab.com/flimzy/testy"
24 )
25
26 func TestHTTPResponse(t *testing.T) {
27 tests := []struct {
28 name string
29 trace func(t *testing.T) *ClientTrace
30 resp *http.Response
31 finalResp *http.Response
32 }{
33 {
34 name: "no hook defined",
35 trace: func(_ *testing.T) *ClientTrace { return &ClientTrace{} },
36 resp: &http.Response{StatusCode: 200},
37 finalResp: &http.Response{StatusCode: 200},
38 },
39 {
40 name: "HTTPResponseBody/cloned response",
41 trace: func(t *testing.T) *ClientTrace {
42 return &ClientTrace{
43 HTTPResponseBody: func(r *http.Response) {
44 if r.StatusCode != 200 {
45 t.Errorf("Unexpected status code: %d", r.StatusCode)
46 }
47 r.StatusCode = 0
48 defer r.Body.Close()
49 if _, err := io.ReadAll(r.Body); err != nil {
50 t.Fatal(err)
51 }
52 },
53 }
54 },
55 resp: &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("testing"))},
56 finalResp: &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("testing"))},
57 },
58 {
59 name: "HTTPResponse/cloned response",
60 trace: func(t *testing.T) *ClientTrace {
61 return &ClientTrace{
62 HTTPResponse: func(r *http.Response) {
63 if r.StatusCode != 200 {
64 t.Errorf("Unexpected status code: %d", r.StatusCode)
65 }
66 r.StatusCode = 0
67 if r.Body != nil {
68 t.Errorf("non-nil body")
69 }
70 },
71 }
72 },
73 resp: &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("testing"))},
74 finalResp: &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("testing"))},
75 },
76 }
77 for _, test := range tests {
78 t.Run(test.name, func(t *testing.T) {
79 trace := test.trace(t)
80 trace.httpResponseBody(test.resp)
81 trace.httpResponse(test.resp)
82 if d := testy.DiffHTTPResponse(test.finalResp, test.resp); d != nil {
83 t.Error(d)
84 }
85 })
86 }
87 }
88
89 func TestHTTPRequest(t *testing.T) {
90 tests := []struct {
91 name string
92 trace func(t *testing.T) *ClientTrace
93 req *http.Request
94 finalReq *http.Request
95 }{
96 {
97 name: "no hook defined",
98 trace: func(_ *testing.T) *ClientTrace { return &ClientTrace{} },
99 req: httptest.NewRequest("PUT", "/", io.NopCloser(strings.NewReader("testing"))),
100 finalReq: httptest.NewRequest("PUT", "/", io.NopCloser(strings.NewReader("testing"))),
101 },
102 {
103 name: "HTTPRequestBody/cloned response",
104 trace: func(t *testing.T) *ClientTrace {
105 return &ClientTrace{
106 HTTPRequestBody: func(r *http.Request) {
107 if r.Method != "PUT" {
108 t.Errorf("Unexpected method: %s", r.Method)
109 }
110 r.Method = "unf"
111 defer r.Body.Close()
112 if _, err := io.ReadAll(r.Body); err != nil {
113 t.Fatal(err)
114 }
115 },
116 }
117 },
118 req: httptest.NewRequest("PUT", "/", io.NopCloser(strings.NewReader("testing"))),
119 finalReq: httptest.NewRequest("PUT", "/", io.NopCloser(strings.NewReader("testing"))),
120 },
121 {
122 name: "HTTPRequest/cloned response",
123 trace: func(t *testing.T) *ClientTrace {
124 return &ClientTrace{
125 HTTPRequest: func(r *http.Request) {
126 if r.Method != "PUT" {
127 t.Errorf("Unexpected method: %s", r.Method)
128 }
129 r.Method = "unf"
130 if r.Body != nil {
131 t.Errorf("non-nil body")
132 }
133 },
134 }
135 },
136 req: httptest.NewRequest("PUT", "/", io.NopCloser(strings.NewReader("testing"))),
137 finalReq: httptest.NewRequest("PUT", "/", io.NopCloser(strings.NewReader("testing"))),
138 },
139 {
140 name: "HTTPRequestBody/no body",
141 trace: func(t *testing.T) *ClientTrace {
142 return &ClientTrace{
143 HTTPRequestBody: func(r *http.Request) {
144 if r.Method != "GET" {
145 t.Errorf("Unexpected method: %s", r.Method)
146 }
147 r.Method = "unf"
148 if r.Body != nil {
149 t.Errorf("non-nil body")
150 }
151 },
152 }
153 },
154 req: func() *http.Request {
155 req, _ := http.NewRequest("GET", "/", nil)
156 return req
157 }(),
158 finalReq: func() *http.Request {
159 req, _ := http.NewRequest("GET", "/", nil)
160 req.Header.Add("Host", "example.com")
161 return req
162 }(),
163 },
164 }
165 for _, test := range tests {
166 t.Run(test.name, func(t *testing.T) {
167 trace := test.trace(t)
168 trace.httpRequestBody(test.req)
169 trace.httpRequest(test.req)
170 if d := testy.DiffHTTPRequest(test.finalReq, test.req); d != nil {
171 t.Error(d)
172 }
173 })
174 }
175 }
176
177 func TestReplayReadCloser(t *testing.T) {
178 tests := []struct {
179 name string
180 input io.ReadCloser
181 expected string
182 readErr string
183 closeErr string
184 }{
185 {
186 name: "no errors",
187 input: io.NopCloser(strings.NewReader("testing")),
188 expected: "testing",
189 },
190 {
191 name: "read error",
192 input: io.NopCloser(&errReader{Reader: strings.NewReader("testi"), err: errors.New("read error 1")}),
193 expected: "testi",
194 readErr: "read error 1",
195 },
196 {
197 name: "close error",
198 input: &errCloser{Reader: strings.NewReader("testin"), err: errors.New("close error 1")},
199 expected: "testin",
200 closeErr: "close error 1",
201 },
202 }
203 for _, test := range tests {
204 t.Run(test.name, func(t *testing.T) {
205 content, err := io.ReadAll(test.input.(io.Reader))
206 closeErr := test.input.Close()
207 rc := newReplay(content, err, closeErr)
208
209 result, resultErr := io.ReadAll(rc.(io.Reader))
210 resultCloseErr := rc.Close()
211 if d := testy.DiffText(test.expected, result); d != nil {
212 t.Error(d)
213 }
214 if err := resultErr; !testy.ErrorMatches(test.readErr, err) {
215 t.Errorf("Unexpected error: %s", err)
216 }
217 if err := resultCloseErr; !testy.ErrorMatches(test.closeErr, err) {
218 t.Errorf("Unexpected error: %s", err)
219 }
220 })
221 }
222 }
223
View as plain text