1 package ochttp
2
3 import (
4 "bufio"
5 "bytes"
6 "context"
7 "crypto/tls"
8 "fmt"
9 "io"
10 "io/ioutil"
11 "net"
12 "net/http"
13 "net/http/httptest"
14 "strings"
15 "sync"
16 "testing"
17 "time"
18
19 "golang.org/x/net/http2"
20
21 "go.opencensus.io/stats/view"
22 "go.opencensus.io/trace"
23 )
24
// httpHandler returns a handler that answers every request with statusCode
// followed by a body of respSize zero bytes. The body is allocated per request.
func httpHandler(statusCode, respSize int) http.Handler {
	serve := func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(statusCode)
		// The write result is intentionally ignored; callers only inspect
		// what the surrounding instrumentation observed.
		rw.Write(make([]byte, respSize))
	}
	return http.HandlerFunc(serve)
}
32
// updateMean folds sample into a running mean that covers count-1 earlier
// samples and returns the mean over count samples. When count is 1 the sample
// itself becomes the mean, regardless of the incoming mean.
func updateMean(mean float64, sample, count int) float64 {
	x := float64(sample)
	if count != 1 {
		return mean + (x-mean)/float64(count)
	}
	return x
}
39
40 func TestHandlerStatsCollection(t *testing.T) {
41 if err := view.Register(DefaultServerViews...); err != nil {
42 t.Fatalf("Failed to register ochttp.DefaultServerViews error: %v", err)
43 }
44
45 views := []string{
46 "opencensus.io/http/server/request_count",
47 "opencensus.io/http/server/latency",
48 "opencensus.io/http/server/request_bytes",
49 "opencensus.io/http/server/response_bytes",
50 }
51
52
53 tests := []struct {
54 name, method, target string
55 count, statusCode, reqSize, respSize int
56 }{
57 {"get 200", "GET", "http://opencensus.io/request/one", 10, 200, 512, 512},
58 {"post 503", "POST", "http://opencensus.io/request/two", 5, 503, 1024, 16384},
59 {"no body 302", "GET", "http://opencensus.io/request/three", 2, 302, 0, 0},
60 }
61 totalCount, meanReqSize, meanRespSize := 0, 0.0, 0.0
62
63 for _, test := range tests {
64 t.Run(test.name, func(t *testing.T) {
65 body := bytes.NewBuffer(make([]byte, test.reqSize))
66 r := httptest.NewRequest(test.method, test.target, body)
67 w := httptest.NewRecorder()
68 mux := http.NewServeMux()
69 mux.Handle("/request/", httpHandler(test.statusCode, test.respSize))
70 h := &Handler{
71 Handler: mux,
72 StartOptions: trace.StartOptions{
73 Sampler: trace.NeverSample(),
74 },
75 }
76 for i := 0; i < test.count; i++ {
77 h.ServeHTTP(w, r)
78 totalCount++
79
80
81 meanReqSize = updateMean(meanReqSize, test.reqSize, totalCount)
82 meanRespSize = updateMean(meanRespSize, test.respSize, totalCount)
83 }
84 })
85 }
86
87 for _, viewName := range views {
88 v := view.Find(viewName)
89 if v == nil {
90 t.Errorf("view not found %q", viewName)
91 continue
92 }
93 rows, err := view.RetrieveData(viewName)
94 if err != nil {
95 t.Error(err)
96 continue
97 }
98 if got, want := len(rows), 1; got != want {
99 t.Errorf("len(%q) = %d; want %d", viewName, got, want)
100 continue
101 }
102 data := rows[0].Data
103
104 var count int
105 var sum float64
106 switch data := data.(type) {
107 case *view.CountData:
108 count = int(data.Value)
109 case *view.DistributionData:
110 count = int(data.Count)
111 sum = data.Sum()
112 default:
113 t.Errorf("Unknown data type: %v", data)
114 continue
115 }
116
117 if got, want := count, totalCount; got != want {
118 t.Fatalf("%s = %d; want %d", viewName, got, want)
119 }
120
121
122 switch viewName {
123 case "opencensus.io/http/server/request_bytes":
124 if got, want := sum, meanReqSize*float64(totalCount); got != want {
125 t.Fatalf("%s = %g; want %g", viewName, got, want)
126 }
127 case "opencensus.io/http/server/response_bytes":
128 if got, want := sum, meanRespSize*float64(totalCount); got != want {
129 t.Fatalf("%s = %g; want %g", viewName, got, want)
130 }
131 }
132 }
133 }
134
// testResponseWriterHijacker is a ResponseRecorder that additionally
// satisfies http.Hijacker. Its Hijack is a stub.
type testResponseWriterHijacker struct {
	httptest.ResponseRecorder
}

// Hijack implements http.Hijacker. It returns a nil connection, a nil
// buffered reader/writer, and a nil error.
func (h *testResponseWriterHijacker) Hijack() (conn net.Conn, brw *bufio.ReadWriter, err error) {
	return conn, brw, err
}
142
143 func TestUnitTestHandlerProxiesHijack(t *testing.T) {
144 tests := []struct {
145 w http.ResponseWriter
146 hasHijack bool
147 }{
148 {httptest.NewRecorder(), false},
149 {nil, false},
150 {new(testResponseWriterHijacker), true},
151 }
152
153 for i, tt := range tests {
154 tw := &trackingResponseWriter{writer: tt.w}
155 w := tw.wrappedResponseWriter()
156 _, ttHijacker := w.(http.Hijacker)
157 if want, have := tt.hasHijack, ttHijacker; want != have {
158 t.Errorf("#%d Hijack got %t, want %t", i, have, want)
159 }
160 }
161 }
162
163
164
165
166 func TestHandlerProxiesHijack_HTTP1(t *testing.T) {
167 cst := httptest.NewServer(&Handler{
168 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
169 var writeMsg func(string)
170 defer func() {
171 err := recover()
172 writeMsg(fmt.Sprintf("Proto=%s\npanic=%v", r.Proto, err != nil))
173 }()
174 conn, _, _ := w.(http.Hijacker).Hijack()
175 writeMsg = func(msg string) {
176 fmt.Fprintf(conn, "%s 200\nContentLength: %d", r.Proto, len(msg))
177 fmt.Fprintf(conn, "\r\n\r\n%s", msg)
178 conn.Close()
179 }
180 }),
181 })
182 defer cst.Close()
183
184 testCases := []struct {
185 name string
186 tr *http.Transport
187 want string
188 }{
189 {
190 name: "http1-transport",
191 tr: new(http.Transport),
192 want: "Proto=HTTP/1.1\npanic=false",
193 },
194 {
195 name: "http2-transport",
196 tr: func() *http.Transport {
197 tr := new(http.Transport)
198 http2.ConfigureTransport(tr)
199 return tr
200 }(),
201 want: "Proto=HTTP/1.1\npanic=false",
202 },
203 }
204
205 for _, tc := range testCases {
206 c := &http.Client{Transport: &Transport{Base: tc.tr}}
207 res, err := c.Get(cst.URL)
208 if err != nil {
209 t.Errorf("(%s) unexpected error %v", tc.name, err)
210 continue
211 }
212 blob, _ := ioutil.ReadAll(res.Body)
213 res.Body.Close()
214 if g, w := string(blob), tc.want; g != w {
215 t.Errorf("(%s) got = %q; want = %q", tc.name, g, w)
216 }
217 }
218 }
219
220
221
222
223
224
225 func TestHandlerProxiesHijack_HTTP2(t *testing.T) {
226 cst := httptest.NewUnstartedServer(&Handler{
227 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
228 if _, ok := w.(http.Hijacker); ok {
229 conn, _, err := w.(http.Hijacker).Hijack()
230 if conn != nil {
231 data := fmt.Sprintf("Surprisingly got the Hijacker() Proto: %s", r.Proto)
232 fmt.Fprintf(conn, "%s 200\nContent-Length:%d\r\n\r\n%s", r.Proto, len(data), data)
233 conn.Close()
234 return
235 }
236
237 switch {
238 case err == nil:
239 fmt.Fprintf(w, "Unexpectedly did not encounter an error!")
240 default:
241 fmt.Fprintf(w, "Unexpected error: %v", err)
242 case strings.Contains(err.(error).Error(), "Hijack"):
243
244 for i := 0; i < 5; i++ {
245 fmt.Fprintf(w, "%d\n", i)
246 w.(http.Flusher).Flush()
247 }
248 }
249 } else {
250
251 for i := 0; i < 5; i++ {
252 fmt.Fprintf(w, "%d\n", i)
253 w.(http.Flusher).Flush()
254 }
255 }
256 }),
257 })
258 cst.TLS = &tls.Config{NextProtos: []string{"h2"}}
259 cst.StartTLS()
260 defer cst.Close()
261
262 if wantPrefix := "https://"; !strings.HasPrefix(cst.URL, wantPrefix) {
263 t.Fatalf("URL got = %q wantPrefix = %q", cst.URL, wantPrefix)
264 }
265
266 tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
267 http2.ConfigureTransport(tr)
268 c := &http.Client{Transport: tr}
269 res, err := c.Get(cst.URL)
270 if err != nil {
271 t.Fatalf("Unexpected error %v", err)
272 }
273 blob, _ := ioutil.ReadAll(res.Body)
274 res.Body.Close()
275 if g, w := string(blob), "0\n1\n2\n3\n4\n"; g != w {
276 t.Errorf("got = %q; want = %q", g, w)
277 }
278 }
279
280 func TestEnsureTrackingResponseWriterSetsStatusCode(t *testing.T) {
281
282
283
284
285 exporter := &spanExporter{cur: make(chan *trace.SpanData, 1)}
286 trace.RegisterExporter(exporter)
287 defer trace.UnregisterExporter(exporter)
288
289 tests := []struct {
290 res *http.Response
291 want trace.Status
292 }{
293 {res: &http.Response{StatusCode: 200}, want: trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
294 {res: &http.Response{StatusCode: 500}, want: trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
295 {res: &http.Response{StatusCode: 403}, want: trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}},
296 {res: &http.Response{StatusCode: 401}, want: trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}},
297 {res: &http.Response{StatusCode: 429}, want: trace.Status{Code: trace.StatusCodeResourceExhausted, Message: `RESOURCE_EXHAUSTED`}},
298 }
299
300 for _, tt := range tests {
301 t.Run(tt.want.Message, func(t *testing.T) {
302 ctx := context.Background()
303 prc, pwc := io.Pipe()
304 go func() {
305 pwc.Write([]byte("Foo"))
306 pwc.Close()
307 }()
308 inRes := tt.res
309 inRes.Body = prc
310 tr := &traceTransport{
311 base: &testResponseTransport{res: inRes},
312 formatSpanName: spanNameFromURL,
313 startOptions: trace.StartOptions{
314 Sampler: trace.AlwaysSample(),
315 },
316 }
317 req, err := http.NewRequest("POST", "https://example.org", bytes.NewReader([]byte("testing")))
318 if err != nil {
319 t.Fatalf("NewRequest error: %v", err)
320 }
321 req = req.WithContext(ctx)
322 res, err := tr.RoundTrip(req)
323 if err != nil {
324 t.Fatalf("RoundTrip error: %v", err)
325 }
326 _, _ = ioutil.ReadAll(res.Body)
327 res.Body.Close()
328
329 cur := <-exporter.cur
330 if got, want := cur.Status, tt.want; got != want {
331 t.Fatalf("SpanData:\ngot = (%#v)\nwant = (%#v)", got, want)
332 }
333 })
334 }
335 }
336
337 type spanExporter struct {
338 sync.Mutex
339 cur chan *trace.SpanData
340 }
341
342 var _ trace.Exporter = (*spanExporter)(nil)
343
344 func (se *spanExporter) ExportSpan(sd *trace.SpanData) {
345 se.Lock()
346 se.cur <- sd
347 se.Unlock()
348 }
349
// testResponseTransport is an http.RoundTripper stub. It ignores the request
// and always hands back the canned response res with a nil error.
type testResponseTransport struct {
	res *http.Response
}

var _ http.RoundTripper = (*testResponseTransport)(nil)

// RoundTrip implements http.RoundTripper by returning the canned response.
func (rt *testResponseTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := rt.res
	return canned, nil
}
359
360 func TestHandlerImplementsHTTPPusher(t *testing.T) {
361 cst := setupAndStartServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
362 pusher, ok := w.(http.Pusher)
363 if !ok {
364 w.Write([]byte("false"))
365 return
366 }
367 err := pusher.Push("/static.css", &http.PushOptions{
368 Method: "GET",
369 Header: http.Header{"Accept-Encoding": r.Header["Accept-Encoding"]},
370 })
371 if err != nil && false {
372
373
374
375 http.Error(w, err.Error(), http.StatusBadRequest)
376 return
377 }
378 w.Write([]byte("true"))
379 }), asHTTP2)
380 defer cst.Close()
381
382 tests := []struct {
383 rt http.RoundTripper
384 wantBody string
385 }{
386 {
387 rt: h1Transport(),
388 wantBody: "false",
389 },
390 {
391 rt: h2Transport(),
392 wantBody: "true",
393 },
394 {
395 rt: &Transport{Base: h1Transport()},
396 wantBody: "false",
397 },
398 {
399 rt: &Transport{Base: h2Transport()},
400 wantBody: "true",
401 },
402 }
403
404 for i, tt := range tests {
405 c := &http.Client{Transport: &Transport{Base: tt.rt}}
406 res, err := c.Get(cst.URL)
407 if err != nil {
408 t.Errorf("#%d: Unexpected error %v", i, err)
409 continue
410 }
411 body, _ := ioutil.ReadAll(res.Body)
412 _ = res.Body.Close()
413 if g, w := string(body), tt.wantBody; g != w {
414 t.Errorf("#%d: got = %q; want = %q", i, g, w)
415 }
416 }
417 }
418
// Marker strings used by the CloseNotify tests. handleCloseNotify logs
// nonNotifier, isNil or ended; hang is declared alongside them.
const (
	nonNotifier = "nonNotifier" // ResponseWriter is not an http.CloseNotifier
	isNil       = "isNil"       // CloseNotify returned a nil channel
	ended       = "ended"       // the CloseNotify channel fired
	hang        = "hang"
)

// Values for setupAndStartServer's isHTTP2 parameter.
const (
	asHTTP1 = false
	asHTTP2 = true
)
428
429 func setupAndStartServer(hf func(http.ResponseWriter, *http.Request), isHTTP2 bool) *httptest.Server {
430 cst := httptest.NewUnstartedServer(&Handler{
431 Handler: http.HandlerFunc(hf),
432 })
433 if isHTTP2 {
434 http2.ConfigureServer(cst.Config, new(http2.Server))
435 cst.TLS = cst.Config.TLSConfig
436 cst.StartTLS()
437 } else {
438 cst.Start()
439 }
440
441 return cst
442 }
443
// insecureTLS returns a fresh client TLS config that skips certificate
// verification, for talking to httptest's TLS servers.
func insecureTLS() *tls.Config {
	cfg := new(tls.Config)
	cfg.InsecureSkipVerify = true
	return cfg
}

// h1Transport returns a new transport that uses insecureTLS. No HTTP/2
// configuration is applied to it here.
func h1Transport() *http.Transport {
	tr := new(http.Transport)
	tr.TLSClientConfig = insecureTLS()
	return tr
}
446 func h2Transport() *http.Transport {
447 tr := &http.Transport{TLSClientConfig: insecureTLS()}
448 http2.ConfigureTransport(tr)
449 return tr
450 }
451
// concurrentBuffer is an io.Writer that is safe for concurrent use. In these
// tests, server handlers write to it while the test goroutine reads it back
// with String. The zero value is ready to use.
type concurrentBuffer struct {
	sync.RWMutex
	bw *bytes.Buffer
}

// Write appends b under the exclusive lock. It allocates the underlying
// buffer on first use.
func (cw *concurrentBuffer) Write(b []byte) (int, error) {
	cw.Lock()
	defer cw.Unlock()

	if cw.bw == nil {
		cw.bw = new(bytes.Buffer)
	}
	return cw.bw.Write(b)
}

// String returns everything written so far, or "" if nothing was written.
// It only reads, so the shared lock is enough.
func (cw *concurrentBuffer) String() string {
	cw.RLock()
	defer cw.RUnlock()

	if cw.bw == nil {
		// (*bytes.Buffer)(nil).String() would report "<nil>".
		return ""
	}
	return cw.bw.String()
}
470
471 func handleCloseNotify(outLog io.Writer) http.HandlerFunc {
472 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
473 cn, ok := w.(http.CloseNotifier)
474 if !ok {
475 fmt.Fprintln(outLog, nonNotifier)
476 return
477 }
478 ch := cn.CloseNotify()
479 if ch == nil {
480 fmt.Fprintln(outLog, isNil)
481 return
482 }
483
484 <-ch
485 fmt.Fprintln(outLog, ended)
486 })
487 }
488
489 func TestHandlerImplementsHTTPCloseNotify(t *testing.T) {
490 http1Log := &concurrentBuffer{bw: new(bytes.Buffer)}
491 http1Server := setupAndStartServer(handleCloseNotify(http1Log), asHTTP1)
492 http2Log := &concurrentBuffer{bw: new(bytes.Buffer)}
493 http2Server := setupAndStartServer(handleCloseNotify(http2Log), asHTTP2)
494
495 defer http1Server.Close()
496 defer http2Server.Close()
497
498 tests := []struct {
499 url string
500 want string
501 }{
502 {url: http1Server.URL, want: nonNotifier},
503 {url: http2Server.URL, want: ended},
504 }
505
506 transports := []struct {
507 name string
508 rt http.RoundTripper
509 }{
510 {name: "http2+ochttp", rt: &Transport{Base: h2Transport()}},
511 {name: "http1+ochttp", rt: &Transport{Base: h1Transport()}},
512 {name: "http1-ochttp", rt: h1Transport()},
513 {name: "http2-ochttp", rt: h2Transport()},
514 }
515
516
517 for _, trc := range transports {
518
519 for i, tt := range tests {
520 req, err := http.NewRequest("GET", tt.url, nil)
521 if err != nil {
522 t.Errorf("#%d: Unexpected error making request: %v", i, err)
523 continue
524 }
525
526
527
528 ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
529 defer cancel()
530 req = req.WithContext(ctx)
531
532 client := &http.Client{Transport: trc.rt}
533 res, err := client.Do(req)
534 if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
535 t.Errorf("#%d: %sClient Unexpected error %v", i, trc.name, err)
536 continue
537 }
538 if res != nil && res.Body != nil {
539 io.CopyN(ioutil.Discard, res.Body, 5)
540 _ = res.Body.Close()
541 }
542 }
543 }
544
545
546 <-time.After(200 * time.Millisecond)
547
548 wantHTTP1Log := strings.Repeat("ended\n", len(transports))
549 wantHTTP2Log := strings.Repeat("ended\n", len(transports))
550 if g, w := http1Log.String(), wantHTTP1Log; g != w {
551 t.Errorf("HTTP1Log got\n\t%q\nwant\n\t%q", g, w)
552 }
553 if g, w := http2Log.String(), wantHTTP2Log; g != w {
554 t.Errorf("HTTP2Log got\n\t%q\nwant\n\t%q", g, w)
555 }
556 }
557
// testHealthEndpointSkipArray is an IsHealthEndpoint predicate. It reports
// true only when the request path is exactly "/health" or "/metrics".
func testHealthEndpointSkipArray(r *http.Request) bool {
	switch r.URL.Path {
	case "/health", "/metrics":
		return true
	default:
		return false
	}
}
566
567 func TestIgnoreHealthEndpoints(t *testing.T) {
568 var spans int
569
570 client := &http.Client{}
571 tests := []struct {
572 path string
573 healthEndpointFunc func(*http.Request) bool
574 }{
575 {"/healthz", nil},
576 {"/_ah/health", nil},
577 {"/healthz", testHealthEndpointSkipArray},
578 {"/_ah/health", testHealthEndpointSkipArray},
579 {"/health", testHealthEndpointSkipArray},
580 {"/metrics", testHealthEndpointSkipArray},
581 }
582 for _, tt := range tests {
583 t.Run(tt.path, func(t *testing.T) {
584 ts := httptest.NewServer(&Handler{
585 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
586 span := trace.FromContext(r.Context())
587 if span != nil {
588 spans++
589 }
590 fmt.Fprint(w, "ok")
591 }),
592 StartOptions: trace.StartOptions{
593 Sampler: trace.AlwaysSample(),
594 },
595 IsHealthEndpoint: tt.healthEndpointFunc,
596 })
597 defer ts.Close()
598
599 resp, err := client.Get(ts.URL + tt.path)
600 if err != nil {
601 t.Fatalf("Cannot GET %q: %v", tt.path, err)
602 }
603 b, err := ioutil.ReadAll(resp.Body)
604 if err != nil {
605 t.Fatalf("Cannot read body for %q: %v", tt.path, err)
606 }
607
608 if got, want := string(b), "ok"; got != want {
609 t.Fatalf("Body for %q = %q; want %q", tt.path, got, want)
610 }
611 resp.Body.Close()
612 })
613 }
614
615 if spans > 0 {
616 t.Errorf("Got %v spans; want no spans", spans)
617 }
618 }
619
View as plain text