...

Source file src/github.com/docker/distribution/context/http_test.go

Documentation: github.com/docker/distribution/context

     1  package context
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"net/http/httputil"
     7  	"net/url"
     8  	"reflect"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  func TestWithRequest(t *testing.T) {
    14  	var req http.Request
    15  
    16  	start := time.Now()
    17  	req.Method = "GET"
    18  	req.Host = "example.com"
    19  	req.RequestURI = "/test-test"
    20  	req.Header = make(http.Header)
    21  	req.Header.Set("Referer", "foo.com/referer")
    22  	req.Header.Set("User-Agent", "test/0.1")
    23  
    24  	ctx := WithRequest(Background(), &req)
    25  	for _, testcase := range []struct {
    26  		key      string
    27  		expected interface{}
    28  	}{
    29  		{
    30  			key:      "http.request",
    31  			expected: &req,
    32  		},
    33  		{
    34  			key: "http.request.id",
    35  		},
    36  		{
    37  			key:      "http.request.method",
    38  			expected: req.Method,
    39  		},
    40  		{
    41  			key:      "http.request.host",
    42  			expected: req.Host,
    43  		},
    44  		{
    45  			key:      "http.request.uri",
    46  			expected: req.RequestURI,
    47  		},
    48  		{
    49  			key:      "http.request.referer",
    50  			expected: req.Referer(),
    51  		},
    52  		{
    53  			key:      "http.request.useragent",
    54  			expected: req.UserAgent(),
    55  		},
    56  		{
    57  			key:      "http.request.remoteaddr",
    58  			expected: req.RemoteAddr,
    59  		},
    60  		{
    61  			key: "http.request.startedat",
    62  		},
    63  	} {
    64  		v := ctx.Value(testcase.key)
    65  
    66  		if v == nil {
    67  			t.Fatalf("value not found for %q", testcase.key)
    68  		}
    69  
    70  		if testcase.expected != nil && v != testcase.expected {
    71  			t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
    72  		}
    73  
    74  		// Key specific checks!
    75  		switch testcase.key {
    76  		case "http.request.id":
    77  			if _, ok := v.(string); !ok {
    78  				t.Fatalf("request id not a string: %v", v)
    79  			}
    80  		case "http.request.startedat":
    81  			vt, ok := v.(time.Time)
    82  			if !ok {
    83  				t.Fatalf("value not a time: %v", v)
    84  			}
    85  
    86  			now := time.Now()
    87  			if vt.After(now) {
    88  				t.Fatalf("time generated too late: %v > %v", vt, now)
    89  			}
    90  
    91  			if vt.Before(start) {
    92  				t.Fatalf("time generated too early: %v < %v", vt, start)
    93  			}
    94  		}
    95  	}
    96  }
    97  
    98  type testResponseWriter struct {
    99  	flushed bool
   100  	status  int
   101  	written int64
   102  	header  http.Header
   103  }
   104  
   105  func (trw *testResponseWriter) Header() http.Header {
   106  	if trw.header == nil {
   107  		trw.header = make(http.Header)
   108  	}
   109  
   110  	return trw.header
   111  }
   112  
   113  func (trw *testResponseWriter) Write(p []byte) (n int, err error) {
   114  	if trw.status == 0 {
   115  		trw.status = http.StatusOK
   116  	}
   117  
   118  	n = len(p)
   119  	trw.written += int64(n)
   120  	return
   121  }
   122  
   123  func (trw *testResponseWriter) WriteHeader(status int) {
   124  	trw.status = status
   125  }
   126  
   127  func (trw *testResponseWriter) Flush() {
   128  	trw.flushed = true
   129  }
   130  
   131  func TestWithResponseWriter(t *testing.T) {
   132  	trw := testResponseWriter{}
   133  	ctx, rw := WithResponseWriter(Background(), &trw)
   134  
   135  	if ctx.Value("http.response") != rw {
   136  		t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), rw)
   137  	}
   138  
   139  	grw, err := GetResponseWriter(ctx)
   140  	if err != nil {
   141  		t.Fatalf("error getting response writer: %v", err)
   142  	}
   143  
   144  	if grw != rw {
   145  		t.Fatalf("unexpected response writer returned: %#v != %#v", grw, rw)
   146  	}
   147  
   148  	if ctx.Value("http.response.status") != 0 {
   149  		t.Fatalf("response status should always be a number and should be zero here: %v != 0", ctx.Value("http.response.status"))
   150  	}
   151  
   152  	if n, err := rw.Write(make([]byte, 1024)); err != nil {
   153  		t.Fatalf("unexpected error writing: %v", err)
   154  	} else if n != 1024 {
   155  		t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
   156  	}
   157  
   158  	if ctx.Value("http.response.status") != http.StatusOK {
   159  		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
   160  	}
   161  
   162  	if ctx.Value("http.response.written") != int64(1024) {
   163  		t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
   164  	}
   165  
   166  	// Make sure flush propagates
   167  	rw.(http.Flusher).Flush()
   168  
   169  	if !trw.flushed {
   170  		t.Fatalf("response writer not flushed")
   171  	}
   172  
   173  	// Write another status and make sure context is correct. This normally
   174  	// wouldn't work except for in this contrived testcase.
   175  	rw.WriteHeader(http.StatusBadRequest)
   176  
   177  	if ctx.Value("http.response.status") != http.StatusBadRequest {
   178  		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
   179  	}
   180  }
   181  
   182  func TestWithVars(t *testing.T) {
   183  	var req http.Request
   184  	vars := map[string]string{
   185  		"foo": "asdf",
   186  		"bar": "qwer",
   187  	}
   188  
   189  	getVarsFromRequest = func(r *http.Request) map[string]string {
   190  		if r != &req {
   191  			t.Fatalf("unexpected request: %v != %v", r, req)
   192  		}
   193  
   194  		return vars
   195  	}
   196  
   197  	ctx := WithVars(Background(), &req)
   198  	for _, testcase := range []struct {
   199  		key      string
   200  		expected interface{}
   201  	}{
   202  		{
   203  			key:      "vars",
   204  			expected: vars,
   205  		},
   206  		{
   207  			key:      "vars.foo",
   208  			expected: "asdf",
   209  		},
   210  		{
   211  			key:      "vars.bar",
   212  			expected: "qwer",
   213  		},
   214  	} {
   215  		v := ctx.Value(testcase.key)
   216  
   217  		if !reflect.DeepEqual(v, testcase.expected) {
   218  			t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
   219  		}
   220  	}
   221  }
   222  
   223  // SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
   224  // RemoteAddr().  A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
   225  // at the transport layer to 127.0.0.1:<port> .  However, as the X-Forwarded-For header
   226  // just contains the IP address, it is different enough for testing.
   227  func TestRemoteAddr(t *testing.T) {
   228  	var expectedRemote string
   229  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   230  		defer r.Body.Close()
   231  
   232  		if r.RemoteAddr == expectedRemote {
   233  			t.Errorf("Unexpected matching remote addresses")
   234  		}
   235  
   236  		actualRemote := RemoteAddr(r)
   237  		if expectedRemote != actualRemote {
   238  			t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
   239  		}
   240  
   241  		w.WriteHeader(200)
   242  	}))
   243  
   244  	defer backend.Close()
   245  	backendURL, err := url.Parse(backend.URL)
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  
   250  	proxy := httputil.NewSingleHostReverseProxy(backendURL)
   251  	frontend := httptest.NewServer(proxy)
   252  	defer frontend.Close()
   253  
   254  	// X-Forwarded-For set by proxy
   255  	expectedRemote = "127.0.0.1"
   256  	proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
   257  	if err != nil {
   258  		t.Fatal(err)
   259  	}
   260  
   261  	_, err = http.DefaultClient.Do(proxyReq)
   262  	if err != nil {
   263  		t.Fatal(err)
   264  	}
   265  
   266  	// RemoteAddr in X-Real-Ip
   267  	getReq, err := http.NewRequest("GET", backend.URL, nil)
   268  	if err != nil {
   269  		t.Fatal(err)
   270  	}
   271  
   272  	expectedRemote = "1.2.3.4"
   273  	getReq.Header["X-Real-ip"] = []string{expectedRemote}
   274  	_, err = http.DefaultClient.Do(getReq)
   275  	if err != nil {
   276  		t.Fatal(err)
   277  	}
   278  
   279  	// Valid X-Real-Ip and invalid X-Forwarded-For
   280  	getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
   281  	_, err = http.DefaultClient.Do(getReq)
   282  	if err != nil {
   283  		t.Fatal(err)
   284  	}
   285  }
   286  

View as plain text