...

Source file src/go.opencensus.io/plugin/ochttp/server_test.go

Documentation: go.opencensus.io/plugin/ochttp

     1  package ochttp
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/tls"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"golang.org/x/net/http2"
    20  
    21  	"go.opencensus.io/stats/view"
    22  	"go.opencensus.io/trace"
    23  )
    24  
    25  func httpHandler(statusCode, respSize int) http.Handler {
    26  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    27  		w.WriteHeader(statusCode)
    28  		body := make([]byte, respSize)
    29  		w.Write(body)
    30  	})
    31  }
    32  
    33  func updateMean(mean float64, sample, count int) float64 {
    34  	if count == 1 {
    35  		return float64(sample)
    36  	}
    37  	return mean + (float64(sample)-mean)/float64(count)
    38  }
    39  
    40  func TestHandlerStatsCollection(t *testing.T) {
    41  	if err := view.Register(DefaultServerViews...); err != nil {
    42  		t.Fatalf("Failed to register ochttp.DefaultServerViews error: %v", err)
    43  	}
    44  
    45  	views := []string{
    46  		"opencensus.io/http/server/request_count",
    47  		"opencensus.io/http/server/latency",
    48  		"opencensus.io/http/server/request_bytes",
    49  		"opencensus.io/http/server/response_bytes",
    50  	}
    51  
    52  	// TODO: test latency measurements?
    53  	tests := []struct {
    54  		name, method, target                 string
    55  		count, statusCode, reqSize, respSize int
    56  	}{
    57  		{"get 200", "GET", "http://opencensus.io/request/one", 10, 200, 512, 512},
    58  		{"post 503", "POST", "http://opencensus.io/request/two", 5, 503, 1024, 16384},
    59  		{"no body 302", "GET", "http://opencensus.io/request/three", 2, 302, 0, 0},
    60  	}
    61  	totalCount, meanReqSize, meanRespSize := 0, 0.0, 0.0
    62  
    63  	for _, test := range tests {
    64  		t.Run(test.name, func(t *testing.T) {
    65  			body := bytes.NewBuffer(make([]byte, test.reqSize))
    66  			r := httptest.NewRequest(test.method, test.target, body)
    67  			w := httptest.NewRecorder()
    68  			mux := http.NewServeMux()
    69  			mux.Handle("/request/", httpHandler(test.statusCode, test.respSize))
    70  			h := &Handler{
    71  				Handler: mux,
    72  				StartOptions: trace.StartOptions{
    73  					Sampler: trace.NeverSample(),
    74  				},
    75  			}
    76  			for i := 0; i < test.count; i++ {
    77  				h.ServeHTTP(w, r)
    78  				totalCount++
    79  				// Distributions do not track sum directly, we must
    80  				// mimic their behaviour to avoid rounding failures.
    81  				meanReqSize = updateMean(meanReqSize, test.reqSize, totalCount)
    82  				meanRespSize = updateMean(meanRespSize, test.respSize, totalCount)
    83  			}
    84  		})
    85  	}
    86  
    87  	for _, viewName := range views {
    88  		v := view.Find(viewName)
    89  		if v == nil {
    90  			t.Errorf("view not found %q", viewName)
    91  			continue
    92  		}
    93  		rows, err := view.RetrieveData(viewName)
    94  		if err != nil {
    95  			t.Error(err)
    96  			continue
    97  		}
    98  		if got, want := len(rows), 1; got != want {
    99  			t.Errorf("len(%q) = %d; want %d", viewName, got, want)
   100  			continue
   101  		}
   102  		data := rows[0].Data
   103  
   104  		var count int
   105  		var sum float64
   106  		switch data := data.(type) {
   107  		case *view.CountData:
   108  			count = int(data.Value)
   109  		case *view.DistributionData:
   110  			count = int(data.Count)
   111  			sum = data.Sum()
   112  		default:
   113  			t.Errorf("Unknown data type: %v", data)
   114  			continue
   115  		}
   116  
   117  		if got, want := count, totalCount; got != want {
   118  			t.Fatalf("%s = %d; want %d", viewName, got, want)
   119  		}
   120  
   121  		// We can only check sum for distribution views.
   122  		switch viewName {
   123  		case "opencensus.io/http/server/request_bytes":
   124  			if got, want := sum, meanReqSize*float64(totalCount); got != want {
   125  				t.Fatalf("%s = %g; want %g", viewName, got, want)
   126  			}
   127  		case "opencensus.io/http/server/response_bytes":
   128  			if got, want := sum, meanRespSize*float64(totalCount); got != want {
   129  				t.Fatalf("%s = %g; want %g", viewName, got, want)
   130  			}
   131  		}
   132  	}
   133  }
   134  
   135  type testResponseWriterHijacker struct {
   136  	httptest.ResponseRecorder
   137  }
   138  
   139  func (trw *testResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   140  	return nil, nil, nil
   141  }
   142  
   143  func TestUnitTestHandlerProxiesHijack(t *testing.T) {
   144  	tests := []struct {
   145  		w         http.ResponseWriter
   146  		hasHijack bool
   147  	}{
   148  		{httptest.NewRecorder(), false},
   149  		{nil, false},
   150  		{new(testResponseWriterHijacker), true},
   151  	}
   152  
   153  	for i, tt := range tests {
   154  		tw := &trackingResponseWriter{writer: tt.w}
   155  		w := tw.wrappedResponseWriter()
   156  		_, ttHijacker := w.(http.Hijacker)
   157  		if want, have := tt.hasHijack, ttHijacker; want != have {
   158  			t.Errorf("#%d Hijack got %t, want %t", i, have, want)
   159  		}
   160  	}
   161  }
   162  
   163  // Integration test with net/http to ensure that our Handler proxies to its
   164  // response the call to (http.Hijack).Hijacker() and that that successfully
   165  // passes with HTTP/1.1 connections. See Issue #642
   166  func TestHandlerProxiesHijack_HTTP1(t *testing.T) {
   167  	cst := httptest.NewServer(&Handler{
   168  		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   169  			var writeMsg func(string)
   170  			defer func() {
   171  				err := recover()
   172  				writeMsg(fmt.Sprintf("Proto=%s\npanic=%v", r.Proto, err != nil))
   173  			}()
   174  			conn, _, _ := w.(http.Hijacker).Hijack()
   175  			writeMsg = func(msg string) {
   176  				fmt.Fprintf(conn, "%s 200\nContentLength: %d", r.Proto, len(msg))
   177  				fmt.Fprintf(conn, "\r\n\r\n%s", msg)
   178  				conn.Close()
   179  			}
   180  		}),
   181  	})
   182  	defer cst.Close()
   183  
   184  	testCases := []struct {
   185  		name string
   186  		tr   *http.Transport
   187  		want string
   188  	}{
   189  		{
   190  			name: "http1-transport",
   191  			tr:   new(http.Transport),
   192  			want: "Proto=HTTP/1.1\npanic=false",
   193  		},
   194  		{
   195  			name: "http2-transport",
   196  			tr: func() *http.Transport {
   197  				tr := new(http.Transport)
   198  				http2.ConfigureTransport(tr)
   199  				return tr
   200  			}(),
   201  			want: "Proto=HTTP/1.1\npanic=false",
   202  		},
   203  	}
   204  
   205  	for _, tc := range testCases {
   206  		c := &http.Client{Transport: &Transport{Base: tc.tr}}
   207  		res, err := c.Get(cst.URL)
   208  		if err != nil {
   209  			t.Errorf("(%s) unexpected error %v", tc.name, err)
   210  			continue
   211  		}
   212  		blob, _ := ioutil.ReadAll(res.Body)
   213  		res.Body.Close()
   214  		if g, w := string(blob), tc.want; g != w {
   215  			t.Errorf("(%s) got = %q; want = %q", tc.name, g, w)
   216  		}
   217  	}
   218  }
   219  
   220  // Integration test with net/http, x/net/http2 to ensure that our Handler proxies
   221  // to its response the call to (http.Hijack).Hijacker() and that that crashes
   222  // since http.Hijacker and HTTP/2.0 connections are incompatible, but the
   223  // detection is only at runtime and ensure that we can stream and flush to the
   224  // connection even after invoking Hijack(). See Issue #642.
   225  func TestHandlerProxiesHijack_HTTP2(t *testing.T) {
   226  	cst := httptest.NewUnstartedServer(&Handler{
   227  		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   228  			if _, ok := w.(http.Hijacker); ok {
   229  				conn, _, err := w.(http.Hijacker).Hijack()
   230  				if conn != nil {
   231  					data := fmt.Sprintf("Surprisingly got the Hijacker() Proto: %s", r.Proto)
   232  					fmt.Fprintf(conn, "%s 200\nContent-Length:%d\r\n\r\n%s", r.Proto, len(data), data)
   233  					conn.Close()
   234  					return
   235  				}
   236  
   237  				switch {
   238  				case err == nil:
   239  					fmt.Fprintf(w, "Unexpectedly did not encounter an error!")
   240  				default:
   241  					fmt.Fprintf(w, "Unexpected error: %v", err)
   242  				case strings.Contains(err.(error).Error(), "Hijack"):
   243  					// Confirmed HTTP/2.0, let's stream to it
   244  					for i := 0; i < 5; i++ {
   245  						fmt.Fprintf(w, "%d\n", i)
   246  						w.(http.Flusher).Flush()
   247  					}
   248  				}
   249  			} else {
   250  				// Confirmed HTTP/2.0, let's stream to it
   251  				for i := 0; i < 5; i++ {
   252  					fmt.Fprintf(w, "%d\n", i)
   253  					w.(http.Flusher).Flush()
   254  				}
   255  			}
   256  		}),
   257  	})
   258  	cst.TLS = &tls.Config{NextProtos: []string{"h2"}}
   259  	cst.StartTLS()
   260  	defer cst.Close()
   261  
   262  	if wantPrefix := "https://"; !strings.HasPrefix(cst.URL, wantPrefix) {
   263  		t.Fatalf("URL got = %q wantPrefix = %q", cst.URL, wantPrefix)
   264  	}
   265  
   266  	tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
   267  	http2.ConfigureTransport(tr)
   268  	c := &http.Client{Transport: tr}
   269  	res, err := c.Get(cst.URL)
   270  	if err != nil {
   271  		t.Fatalf("Unexpected error %v", err)
   272  	}
   273  	blob, _ := ioutil.ReadAll(res.Body)
   274  	res.Body.Close()
   275  	if g, w := string(blob), "0\n1\n2\n3\n4\n"; g != w {
   276  		t.Errorf("got = %q; want = %q", g, w)
   277  	}
   278  }
   279  
   280  func TestEnsureTrackingResponseWriterSetsStatusCode(t *testing.T) {
   281  	// Ensure that the trackingResponseWriter always sets the spanStatus on ending the span.
   282  	// Because we can only examine the Status after exporting, this test roundtrips a
   283  	// couple of requests and then later examines the exported spans.
   284  	// See Issue #700.
   285  	exporter := &spanExporter{cur: make(chan *trace.SpanData, 1)}
   286  	trace.RegisterExporter(exporter)
   287  	defer trace.UnregisterExporter(exporter)
   288  
   289  	tests := []struct {
   290  		res  *http.Response
   291  		want trace.Status
   292  	}{
   293  		{res: &http.Response{StatusCode: 200}, want: trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
   294  		{res: &http.Response{StatusCode: 500}, want: trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
   295  		{res: &http.Response{StatusCode: 403}, want: trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}},
   296  		{res: &http.Response{StatusCode: 401}, want: trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}},
   297  		{res: &http.Response{StatusCode: 429}, want: trace.Status{Code: trace.StatusCodeResourceExhausted, Message: `RESOURCE_EXHAUSTED`}},
   298  	}
   299  
   300  	for _, tt := range tests {
   301  		t.Run(tt.want.Message, func(t *testing.T) {
   302  			ctx := context.Background()
   303  			prc, pwc := io.Pipe()
   304  			go func() {
   305  				pwc.Write([]byte("Foo"))
   306  				pwc.Close()
   307  			}()
   308  			inRes := tt.res
   309  			inRes.Body = prc
   310  			tr := &traceTransport{
   311  				base:           &testResponseTransport{res: inRes},
   312  				formatSpanName: spanNameFromURL,
   313  				startOptions: trace.StartOptions{
   314  					Sampler: trace.AlwaysSample(),
   315  				},
   316  			}
   317  			req, err := http.NewRequest("POST", "https://example.org", bytes.NewReader([]byte("testing")))
   318  			if err != nil {
   319  				t.Fatalf("NewRequest error: %v", err)
   320  			}
   321  			req = req.WithContext(ctx)
   322  			res, err := tr.RoundTrip(req)
   323  			if err != nil {
   324  				t.Fatalf("RoundTrip error: %v", err)
   325  			}
   326  			_, _ = ioutil.ReadAll(res.Body)
   327  			res.Body.Close()
   328  
   329  			cur := <-exporter.cur
   330  			if got, want := cur.Status, tt.want; got != want {
   331  				t.Fatalf("SpanData:\ngot =  (%#v)\nwant = (%#v)", got, want)
   332  			}
   333  		})
   334  	}
   335  }
   336  
   337  type spanExporter struct {
   338  	sync.Mutex
   339  	cur chan *trace.SpanData
   340  }
   341  
   342  var _ trace.Exporter = (*spanExporter)(nil)
   343  
   344  func (se *spanExporter) ExportSpan(sd *trace.SpanData) {
   345  	se.Lock()
   346  	se.cur <- sd
   347  	se.Unlock()
   348  }
   349  
   350  type testResponseTransport struct {
   351  	res *http.Response
   352  }
   353  
   354  var _ http.RoundTripper = (*testResponseTransport)(nil)
   355  
   356  func (rb *testResponseTransport) RoundTrip(*http.Request) (*http.Response, error) {
   357  	return rb.res, nil
   358  }
   359  
   360  func TestHandlerImplementsHTTPPusher(t *testing.T) {
   361  	cst := setupAndStartServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   362  		pusher, ok := w.(http.Pusher)
   363  		if !ok {
   364  			w.Write([]byte("false"))
   365  			return
   366  		}
   367  		err := pusher.Push("/static.css", &http.PushOptions{
   368  			Method: "GET",
   369  			Header: http.Header{"Accept-Encoding": r.Header["Accept-Encoding"]},
   370  		})
   371  		if err != nil && false {
   372  			// TODO: (@odeke-em) consult with Go stdlib for why trying
   373  			// to configure even an HTTP/2 server and HTTP/2 transport
   374  			// still return http.ErrNotSupported even without using ochttp.Handler.
   375  			http.Error(w, err.Error(), http.StatusBadRequest)
   376  			return
   377  		}
   378  		w.Write([]byte("true"))
   379  	}), asHTTP2)
   380  	defer cst.Close()
   381  
   382  	tests := []struct {
   383  		rt       http.RoundTripper
   384  		wantBody string
   385  	}{
   386  		{
   387  			rt:       h1Transport(),
   388  			wantBody: "false",
   389  		},
   390  		{
   391  			rt:       h2Transport(),
   392  			wantBody: "true",
   393  		},
   394  		{
   395  			rt:       &Transport{Base: h1Transport()},
   396  			wantBody: "false",
   397  		},
   398  		{
   399  			rt:       &Transport{Base: h2Transport()},
   400  			wantBody: "true",
   401  		},
   402  	}
   403  
   404  	for i, tt := range tests {
   405  		c := &http.Client{Transport: &Transport{Base: tt.rt}}
   406  		res, err := c.Get(cst.URL)
   407  		if err != nil {
   408  			t.Errorf("#%d: Unexpected error %v", i, err)
   409  			continue
   410  		}
   411  		body, _ := ioutil.ReadAll(res.Body)
   412  		_ = res.Body.Close()
   413  		if g, w := string(body), tt.wantBody; g != w {
   414  			t.Errorf("#%d: got = %q; want = %q", i, g, w)
   415  		}
   416  	}
   417  }
   418  
   419  const (
   420  	isNil       = "isNil"
   421  	hang        = "hang"
   422  	ended       = "ended"
   423  	nonNotifier = "nonNotifier"
   424  
   425  	asHTTP1 = false
   426  	asHTTP2 = true
   427  )
   428  
   429  func setupAndStartServer(hf func(http.ResponseWriter, *http.Request), isHTTP2 bool) *httptest.Server {
   430  	cst := httptest.NewUnstartedServer(&Handler{
   431  		Handler: http.HandlerFunc(hf),
   432  	})
   433  	if isHTTP2 {
   434  		http2.ConfigureServer(cst.Config, new(http2.Server))
   435  		cst.TLS = cst.Config.TLSConfig
   436  		cst.StartTLS()
   437  	} else {
   438  		cst.Start()
   439  	}
   440  
   441  	return cst
   442  }
   443  
   444  func insecureTLS() *tls.Config     { return &tls.Config{InsecureSkipVerify: true} }
   445  func h1Transport() *http.Transport { return &http.Transport{TLSClientConfig: insecureTLS()} }
   446  func h2Transport() *http.Transport {
   447  	tr := &http.Transport{TLSClientConfig: insecureTLS()}
   448  	http2.ConfigureTransport(tr)
   449  	return tr
   450  }
   451  
   452  type concurrentBuffer struct {
   453  	sync.RWMutex
   454  	bw *bytes.Buffer
   455  }
   456  
   457  func (cw *concurrentBuffer) Write(b []byte) (int, error) {
   458  	cw.Lock()
   459  	defer cw.Unlock()
   460  
   461  	return cw.bw.Write(b)
   462  }
   463  
   464  func (cw *concurrentBuffer) String() string {
   465  	cw.Lock()
   466  	defer cw.Unlock()
   467  
   468  	return cw.bw.String()
   469  }
   470  
   471  func handleCloseNotify(outLog io.Writer) http.HandlerFunc {
   472  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   473  		cn, ok := w.(http.CloseNotifier)
   474  		if !ok {
   475  			fmt.Fprintln(outLog, nonNotifier)
   476  			return
   477  		}
   478  		ch := cn.CloseNotify()
   479  		if ch == nil {
   480  			fmt.Fprintln(outLog, isNil)
   481  			return
   482  		}
   483  
   484  		<-ch
   485  		fmt.Fprintln(outLog, ended)
   486  	})
   487  }
   488  
   489  func TestHandlerImplementsHTTPCloseNotify(t *testing.T) {
   490  	http1Log := &concurrentBuffer{bw: new(bytes.Buffer)}
   491  	http1Server := setupAndStartServer(handleCloseNotify(http1Log), asHTTP1)
   492  	http2Log := &concurrentBuffer{bw: new(bytes.Buffer)}
   493  	http2Server := setupAndStartServer(handleCloseNotify(http2Log), asHTTP2)
   494  
   495  	defer http1Server.Close()
   496  	defer http2Server.Close()
   497  
   498  	tests := []struct {
   499  		url  string
   500  		want string
   501  	}{
   502  		{url: http1Server.URL, want: nonNotifier},
   503  		{url: http2Server.URL, want: ended},
   504  	}
   505  
   506  	transports := []struct {
   507  		name string
   508  		rt   http.RoundTripper
   509  	}{
   510  		{name: "http2+ochttp", rt: &Transport{Base: h2Transport()}},
   511  		{name: "http1+ochttp", rt: &Transport{Base: h1Transport()}},
   512  		{name: "http1-ochttp", rt: h1Transport()},
   513  		{name: "http2-ochttp", rt: h2Transport()},
   514  	}
   515  
   516  	// Each transport invokes one of two server types, either HTTP/1 or HTTP/2
   517  	for _, trc := range transports {
   518  		// Try out all the transport combinations
   519  		for i, tt := range tests {
   520  			req, err := http.NewRequest("GET", tt.url, nil)
   521  			if err != nil {
   522  				t.Errorf("#%d: Unexpected error making request: %v", i, err)
   523  				continue
   524  			}
   525  
   526  			// Using a timeout to ensure that the request is cancelled and the server
   527  			// if its handler implements CloseNotify will see this as the client leaving.
   528  			ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
   529  			defer cancel()
   530  			req = req.WithContext(ctx)
   531  
   532  			client := &http.Client{Transport: trc.rt}
   533  			res, err := client.Do(req)
   534  			if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
   535  				t.Errorf("#%d: %sClient Unexpected error %v", i, trc.name, err)
   536  				continue
   537  			}
   538  			if res != nil && res.Body != nil {
   539  				io.CopyN(ioutil.Discard, res.Body, 5)
   540  				_ = res.Body.Close()
   541  			}
   542  		}
   543  	}
   544  
   545  	// Wait for a couple of milliseconds for the GoAway frames to be properly propagated
   546  	<-time.After(200 * time.Millisecond)
   547  
   548  	wantHTTP1Log := strings.Repeat("ended\n", len(transports))
   549  	wantHTTP2Log := strings.Repeat("ended\n", len(transports))
   550  	if g, w := http1Log.String(), wantHTTP1Log; g != w {
   551  		t.Errorf("HTTP1Log got\n\t%q\nwant\n\t%q", g, w)
   552  	}
   553  	if g, w := http2Log.String(), wantHTTP2Log; g != w {
   554  		t.Errorf("HTTP2Log got\n\t%q\nwant\n\t%q", g, w)
   555  	}
   556  }
   557  
   558  func testHealthEndpointSkipArray(r *http.Request) bool {
   559  	for _, toSkip := range []string{"/health", "/metrics"} {
   560  		if r.URL.Path == toSkip {
   561  			return true
   562  		}
   563  	}
   564  	return false
   565  }
   566  
   567  func TestIgnoreHealthEndpoints(t *testing.T) {
   568  	var spans int
   569  
   570  	client := &http.Client{}
   571  	tests := []struct {
   572  		path               string
   573  		healthEndpointFunc func(*http.Request) bool
   574  	}{
   575  		{"/healthz", nil},
   576  		{"/_ah/health", nil},
   577  		{"/healthz", testHealthEndpointSkipArray},
   578  		{"/_ah/health", testHealthEndpointSkipArray},
   579  		{"/health", testHealthEndpointSkipArray},
   580  		{"/metrics", testHealthEndpointSkipArray},
   581  	}
   582  	for _, tt := range tests {
   583  		t.Run(tt.path, func(t *testing.T) {
   584  			ts := httptest.NewServer(&Handler{
   585  				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   586  					span := trace.FromContext(r.Context())
   587  					if span != nil {
   588  						spans++
   589  					}
   590  					fmt.Fprint(w, "ok")
   591  				}),
   592  				StartOptions: trace.StartOptions{
   593  					Sampler: trace.AlwaysSample(),
   594  				},
   595  				IsHealthEndpoint: tt.healthEndpointFunc,
   596  			})
   597  			defer ts.Close()
   598  
   599  			resp, err := client.Get(ts.URL + tt.path)
   600  			if err != nil {
   601  				t.Fatalf("Cannot GET %q: %v", tt.path, err)
   602  			}
   603  			b, err := ioutil.ReadAll(resp.Body)
   604  			if err != nil {
   605  				t.Fatalf("Cannot read body for %q: %v", tt.path, err)
   606  			}
   607  
   608  			if got, want := string(b), "ok"; got != want {
   609  				t.Fatalf("Body for %q = %q; want %q", tt.path, got, want)
   610  			}
   611  			resp.Body.Close()
   612  		})
   613  	}
   614  
   615  	if spans > 0 {
   616  		t.Errorf("Got %v spans; want no spans", spans)
   617  	}
   618  }
   619  

View as plain text