1 package context
2
3 import (
4 "net/http"
5 "net/http/httptest"
6 "net/http/httputil"
7 "net/url"
8 "reflect"
9 "testing"
10 "time"
11 )
12
13 func TestWithRequest(t *testing.T) {
14 var req http.Request
15
16 start := time.Now()
17 req.Method = "GET"
18 req.Host = "example.com"
19 req.RequestURI = "/test-test"
20 req.Header = make(http.Header)
21 req.Header.Set("Referer", "foo.com/referer")
22 req.Header.Set("User-Agent", "test/0.1")
23
24 ctx := WithRequest(Background(), &req)
25 for _, testcase := range []struct {
26 key string
27 expected interface{}
28 }{
29 {
30 key: "http.request",
31 expected: &req,
32 },
33 {
34 key: "http.request.id",
35 },
36 {
37 key: "http.request.method",
38 expected: req.Method,
39 },
40 {
41 key: "http.request.host",
42 expected: req.Host,
43 },
44 {
45 key: "http.request.uri",
46 expected: req.RequestURI,
47 },
48 {
49 key: "http.request.referer",
50 expected: req.Referer(),
51 },
52 {
53 key: "http.request.useragent",
54 expected: req.UserAgent(),
55 },
56 {
57 key: "http.request.remoteaddr",
58 expected: req.RemoteAddr,
59 },
60 {
61 key: "http.request.startedat",
62 },
63 } {
64 v := ctx.Value(testcase.key)
65
66 if v == nil {
67 t.Fatalf("value not found for %q", testcase.key)
68 }
69
70 if testcase.expected != nil && v != testcase.expected {
71 t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
72 }
73
74
75 switch testcase.key {
76 case "http.request.id":
77 if _, ok := v.(string); !ok {
78 t.Fatalf("request id not a string: %v", v)
79 }
80 case "http.request.startedat":
81 vt, ok := v.(time.Time)
82 if !ok {
83 t.Fatalf("value not a time: %v", v)
84 }
85
86 now := time.Now()
87 if vt.After(now) {
88 t.Fatalf("time generated too late: %v > %v", vt, now)
89 }
90
91 if vt.Before(start) {
92 t.Fatalf("time generated too early: %v < %v", vt, start)
93 }
94 }
95 }
96 }
97
// testResponseWriter is an in-memory stand-in for an http.ResponseWriter that
// records what was done to it instead of sending anything over the wire.
type testResponseWriter struct {
	flushed bool        // set to true by Flush
	status  int         // last status passed to WriteHeader, or 200 after a Write with no status
	written int64       // running total of bytes passed to Write
	header  http.Header // allocated on first call to Header
}

// Header returns the writer's header map, creating it on first use.
func (trw *testResponseWriter) Header() http.Header {
	if trw.header != nil {
		return trw.header
	}

	trw.header = make(http.Header)
	return trw.header
}

// Write discards p but accounts for its length. If no status has been
// recorded yet, it records http.StatusOK. It never returns an error.
func (trw *testResponseWriter) Write(p []byte) (int, error) {
	if trw.status == 0 {
		trw.status = http.StatusOK
	}

	size := len(p)
	trw.written += int64(size)
	return size, nil
}

// WriteHeader records status, overwriting any previously recorded value.
func (trw *testResponseWriter) WriteHeader(status int) {
	trw.status = status
}

// Flush marks the writer as flushed.
func (trw *testResponseWriter) Flush() {
	trw.flushed = true
}
130
131 func TestWithResponseWriter(t *testing.T) {
132 trw := testResponseWriter{}
133 ctx, rw := WithResponseWriter(Background(), &trw)
134
135 if ctx.Value("http.response") != rw {
136 t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), rw)
137 }
138
139 grw, err := GetResponseWriter(ctx)
140 if err != nil {
141 t.Fatalf("error getting response writer: %v", err)
142 }
143
144 if grw != rw {
145 t.Fatalf("unexpected response writer returned: %#v != %#v", grw, rw)
146 }
147
148 if ctx.Value("http.response.status") != 0 {
149 t.Fatalf("response status should always be a number and should be zero here: %v != 0", ctx.Value("http.response.status"))
150 }
151
152 if n, err := rw.Write(make([]byte, 1024)); err != nil {
153 t.Fatalf("unexpected error writing: %v", err)
154 } else if n != 1024 {
155 t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
156 }
157
158 if ctx.Value("http.response.status") != http.StatusOK {
159 t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
160 }
161
162 if ctx.Value("http.response.written") != int64(1024) {
163 t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
164 }
165
166
167 rw.(http.Flusher).Flush()
168
169 if !trw.flushed {
170 t.Fatalf("response writer not flushed")
171 }
172
173
174
175 rw.WriteHeader(http.StatusBadRequest)
176
177 if ctx.Value("http.response.status") != http.StatusBadRequest {
178 t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
179 }
180 }
181
182 func TestWithVars(t *testing.T) {
183 var req http.Request
184 vars := map[string]string{
185 "foo": "asdf",
186 "bar": "qwer",
187 }
188
189 getVarsFromRequest = func(r *http.Request) map[string]string {
190 if r != &req {
191 t.Fatalf("unexpected request: %v != %v", r, req)
192 }
193
194 return vars
195 }
196
197 ctx := WithVars(Background(), &req)
198 for _, testcase := range []struct {
199 key string
200 expected interface{}
201 }{
202 {
203 key: "vars",
204 expected: vars,
205 },
206 {
207 key: "vars.foo",
208 expected: "asdf",
209 },
210 {
211 key: "vars.bar",
212 expected: "qwer",
213 },
214 } {
215 v := ctx.Value(testcase.key)
216
217 if !reflect.DeepEqual(v, testcase.expected) {
218 t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
219 }
220 }
221 }
222
223
224
225
226
227 func TestRemoteAddr(t *testing.T) {
228 var expectedRemote string
229 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
230 defer r.Body.Close()
231
232 if r.RemoteAddr == expectedRemote {
233 t.Errorf("Unexpected matching remote addresses")
234 }
235
236 actualRemote := RemoteAddr(r)
237 if expectedRemote != actualRemote {
238 t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
239 }
240
241 w.WriteHeader(200)
242 }))
243
244 defer backend.Close()
245 backendURL, err := url.Parse(backend.URL)
246 if err != nil {
247 t.Fatal(err)
248 }
249
250 proxy := httputil.NewSingleHostReverseProxy(backendURL)
251 frontend := httptest.NewServer(proxy)
252 defer frontend.Close()
253
254
255 expectedRemote = "127.0.0.1"
256 proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
257 if err != nil {
258 t.Fatal(err)
259 }
260
261 _, err = http.DefaultClient.Do(proxyReq)
262 if err != nil {
263 t.Fatal(err)
264 }
265
266
267 getReq, err := http.NewRequest("GET", backend.URL, nil)
268 if err != nil {
269 t.Fatal(err)
270 }
271
272 expectedRemote = "1.2.3.4"
273 getReq.Header["X-Real-ip"] = []string{expectedRemote}
274 _, err = http.DefaultClient.Do(getReq)
275 if err != nil {
276 t.Fatal(err)
277 }
278
279
280 getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
281 _, err = http.DefaultClient.Do(getReq)
282 if err != nil {
283 t.Fatal(err)
284 }
285 }
286
View as plain text