1
16
17 package proxy
18
19 import (
20 "bufio"
21 "bytes"
22 "compress/gzip"
23 "context"
24 "crypto/tls"
25 "crypto/x509"
26 "errors"
27 "fmt"
28 "io"
29 "net"
30 "net/http"
31 "net/http/httptest"
32 "net/http/httputil"
33 "net/url"
34 "reflect"
35 "strconv"
36 "strings"
37 "testing"
38 "time"
39
40 "github.com/stretchr/testify/assert"
41 "github.com/stretchr/testify/require"
42
43 "golang.org/x/net/websocket"
44
45 "k8s.io/apimachinery/pkg/util/httpstream"
46 utilnet "k8s.io/apimachinery/pkg/util/net"
47 )
48
49 const fakeStatusCode = 567
50
51 type fakeResponder struct {
52 t *testing.T
53 called bool
54 err error
55
56 w http.ResponseWriter
57 }
58
59 func (r *fakeResponder) Error(w http.ResponseWriter, req *http.Request, err error) {
60 if r.called {
61 r.t.Errorf("Error responder called again!\nprevious error: %v\nnew error: %v", r.err, err)
62 }
63
64 w.WriteHeader(fakeStatusCode)
65 _, writeErr := w.Write([]byte(err.Error()))
66 assert.NoError(r.t, writeErr)
67
68 r.called = true
69 r.err = err
70 }
71
// fakeConn is a stub net.Conn handed out by a test transport's dialer. Read
// and Write never move any bytes and always return err. Close and the
// deadline setters are no-ops that report success. Both address getters
// return nil.
type fakeConn struct {
	err error // returned by every Read and Write
}

// Read reports zero bytes read together with f.err.
func (f *fakeConn) Read([]byte) (int, error) {
	return 0, f.err
}

// Write reports zero bytes written together with f.err.
func (f *fakeConn) Write([]byte) (int, error) {
	return 0, f.err
}

// Close does nothing and always succeeds.
func (f *fakeConn) Close() error {
	return nil
}

// LocalAddr has no address to report.
func (fakeConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr has no address to report.
func (fakeConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline ignores t and succeeds.
func (fakeConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores t and succeeds.
func (fakeConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores t and succeeds.
func (fakeConn) SetWriteDeadline(t time.Time) error {
	return nil
}
84
// SimpleBackendHandler is a recording backend: it stores details of the last
// request it served and replies with a canned body and headers.
type SimpleBackendHandler struct {
	requestURL     url.URL
	requestHost    string
	requestHeader  http.Header
	requestBody    []byte
	requestMethod  string
	responseBody   string
	responseHeader map[string]string
	t              *testing.T
}

// ServeHTTP captures req's URL, Host, headers, method and body. It then adds
// every responseHeader entry to the reply and writes responseBody. If the
// request body cannot be read, the failure is reported through s.t (which
// must then be non-nil) and nothing is written.
func (s *SimpleBackendHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	s.requestURL, s.requestHost = *req.URL, req.Host
	s.requestHeader, s.requestMethod = req.Header, req.Method

	// requestBody is stored even when reading fails part-way.
	body, err := io.ReadAll(req.Body)
	s.requestBody = body
	if err != nil {
		s.t.Errorf("Unexpected error: %v", err)
		return
	}

	// Ranging over a nil map is a no-op, so no nil check is needed.
	replyHeader := w.Header()
	for name, value := range s.responseHeader {
		replyHeader.Add(name, value)
	}
	w.Write([]byte(s.responseBody))
}
115
// validateParameters reports a test error for every key in expected that is
// either missing from actual or whose first value in actual differs from the
// expected value. Parameters present only in actual are not checked. name
// prefixes every message to identify the test case.
func validateParameters(t *testing.T, name string, actual url.Values, expected map[string]string) {
	t.Helper()
	for k, v := range expected {
		actualValues, ok := actual[k]
		if !ok {
			t.Errorf("%s: Expected parameter %s not received", name, k)
			continue
		}
		// A key can be present with an empty value list; indexing [0] would
		// panic instead of reporting a mismatch.
		if len(actualValues) == 0 || actualValues[0] != v {
			t.Errorf("%s: Parameter %s values don't match. Actual: %#v, Expected: %s",
				name, k, actualValues, v)
		}
	}
}
129
// validateHeaders reports a test error for each header in expected that is
// absent from actual or whose first value differs, and for each header in
// notExpected that is present in actual. Names are canonicalized with
// http.CanonicalHeaderKey before lookup, so callers may spell them in any
// case. name prefixes every message to identify the test case.
func validateHeaders(t *testing.T, name string, actual http.Header, expected map[string]string, notExpected []string) {
	t.Helper()
	for k, v := range expected {
		// Raw map indexing would miss a non-canonical key such as "content-type".
		actualValues, ok := actual[http.CanonicalHeaderKey(k)]
		if !ok {
			t.Errorf("%s: Expected header %s not received", name, k)
			continue
		}
		if len(actualValues) == 0 || actualValues[0] != v {
			t.Errorf("%s: Header %s values don't match. Actual: %s, Expected: %s",
				name, k, actualValues, v)
		}
	}
	// Ranging over a nil slice is a no-op, so no nil check is needed.
	for _, h := range notExpected {
		if _, present := actual[http.CanonicalHeaderKey(h)]; present {
			t.Errorf("%s: unexpected header: %s", name, h)
		}
	}
}
151
152 func TestServeHTTP(t *testing.T) {
153 tests := []struct {
154 name string
155 method string
156 requestPath string
157 expectedPath string
158 requestBody string
159 requestParams map[string]string
160 requestHeader map[string]string
161 responseHeader map[string]string
162 expectedRespHeader map[string]string
163 notExpectedRespHeader []string
164 upgradeRequired bool
165 appendLocationPath bool
166 expectError func(err error) bool
167 useLocationHost bool
168 }{
169 {
170 name: "root path, simple get",
171 method: "GET",
172 requestPath: "/",
173 expectedPath: "/",
174 },
175 {
176 name: "no upgrade header sent",
177 method: "GET",
178 requestPath: "/",
179 upgradeRequired: true,
180 expectError: func(err error) bool {
181 return err != nil && strings.Contains(err.Error(), "Upgrade request required")
182 },
183 },
184 {
185 name: "simple path, get",
186 method: "GET",
187 requestPath: "/path/to/test",
188 expectedPath: "/path/to/test",
189 },
190 {
191 name: "request params",
192 method: "POST",
193 requestPath: "/some/path/",
194 expectedPath: "/some/path/",
195 requestParams: map[string]string{"param1": "value/1", "param2": "value%2"},
196 requestBody: "test request body",
197 },
198 {
199 name: "request headers",
200 method: "PUT",
201 requestPath: "/some/path",
202 expectedPath: "/some/path",
203 requestHeader: map[string]string{"Header1": "value1", "Header2": "value2"},
204 },
205 {
206 name: "empty path - slash should be added",
207 method: "GET",
208 requestPath: "",
209 expectedPath: "/",
210 },
211 {
212 name: "remove CORS headers",
213 method: "GET",
214 requestPath: "/some/path",
215 expectedPath: "/some/path",
216 responseHeader: map[string]string{
217 "Header1": "value1",
218 "Access-Control-Allow-Origin": "some.server",
219 "Access-Control-Allow-Methods": "GET"},
220 expectedRespHeader: map[string]string{
221 "Header1": "value1",
222 },
223 notExpectedRespHeader: []string{
224 "Access-Control-Allow-Origin",
225 "Access-Control-Allow-Methods",
226 },
227 },
228 {
229 name: "use location host",
230 method: "GET",
231 requestPath: "/some/path",
232 expectedPath: "/some/path",
233 useLocationHost: true,
234 },
235 {
236 name: "use location host - invalid upgrade",
237 method: "GET",
238 upgradeRequired: true,
239 requestHeader: map[string]string{
240 httpstream.HeaderConnection: httpstream.HeaderUpgrade,
241 },
242 expectError: func(err error) bool {
243 return err != nil && strings.Contains(err.Error(), "invalid upgrade response: status code 200")
244 },
245 requestPath: "/some/path",
246 expectedPath: "/some/path",
247 useLocationHost: true,
248 },
249 {
250 name: "append server path to request path",
251 method: "GET",
252 requestPath: "/base",
253 expectedPath: "/base/base",
254 appendLocationPath: true,
255 },
256 {
257 name: "append server path to request path with ending slash",
258 method: "GET",
259 requestPath: "/base/",
260 expectedPath: "/base/base/",
261 appendLocationPath: true,
262 },
263 {
264 name: "don't append server path to request path",
265 method: "GET",
266 requestPath: "/base",
267 expectedPath: "/base",
268 appendLocationPath: false,
269 },
270 }
271
272 for i, test := range tests {
273 func() {
274 backendResponse := "<html><head></head><body><a href=\"/test/path\">Hello</a></body></html>"
275 backendResponseHeader := test.responseHeader
276
277 if backendResponseHeader == nil && test.expectedRespHeader == nil {
278 backendResponseHeader = map[string]string{"Content-Type": "text/html"}
279 test.expectedRespHeader = map[string]string{"Content-Type": "text/html"}
280 }
281 backendHandler := &SimpleBackendHandler{
282 responseBody: backendResponse,
283 responseHeader: backendResponseHeader,
284 }
285 backendServer := httptest.NewServer(backendHandler)
286 defer backendServer.Close()
287
288 responder := &fakeResponder{t: t}
289 backendURL, _ := url.Parse(backendServer.URL)
290 backendURL.Path = test.requestPath
291 proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, test.upgradeRequired, responder)
292 proxyHandler.UseLocationHost = test.useLocationHost
293 proxyHandler.AppendLocationPath = test.appendLocationPath
294 proxyServer := httptest.NewServer(proxyHandler)
295 defer proxyServer.Close()
296 proxyURL, _ := url.Parse(proxyServer.URL)
297 proxyURL.Path = test.requestPath
298 paramValues := url.Values{}
299 for k, v := range test.requestParams {
300 paramValues[k] = []string{v}
301 }
302 proxyURL.RawQuery = paramValues.Encode()
303 var requestBody io.Reader
304 if test.requestBody != "" {
305 requestBody = bytes.NewBufferString(test.requestBody)
306 }
307 req, err := http.NewRequest(test.method, proxyURL.String(), requestBody)
308 if test.requestHeader != nil {
309 header := http.Header{}
310 for k, v := range test.requestHeader {
311 header.Add(k, v)
312 }
313 req.Header = header
314 }
315 if err != nil {
316 t.Errorf("Error creating client request: %v", err)
317 }
318 client := &http.Client{}
319 res, err := client.Do(req)
320 if err != nil {
321 t.Errorf("Error from proxy request: %v", err)
322 }
323
324
325 if test.useLocationHost && backendHandler.requestHost != backendURL.Host {
326 t.Errorf("Unexpected request host: %s", backendHandler.requestHost)
327 } else if !test.useLocationHost && backendHandler.requestHost == backendURL.Host {
328 t.Errorf("Unexpected request host: %s", backendHandler.requestHost)
329 }
330
331 if test.expectError != nil {
332 if !responder.called {
333 t.Errorf("%d: responder was not invoked", i)
334 return
335 }
336 if !test.expectError(responder.err) {
337 t.Errorf("%d: unexpected error: %v", i, responder.err)
338 }
339 return
340 }
341
342
343
344 if backendHandler.requestMethod != test.method {
345 t.Errorf("Unexpected request method: %s. Expected: %s",
346 backendHandler.requestMethod, test.method)
347 }
348
349
350 if string(backendHandler.requestBody) != test.requestBody {
351 t.Errorf("Unexpected request body: %s. Expected: %s",
352 string(backendHandler.requestBody), test.requestBody)
353 }
354
355
356 if backendHandler.requestURL.Path != test.expectedPath {
357 t.Errorf("Unexpected request path: %s", backendHandler.requestURL.Path)
358 }
359
360 validateParameters(t, test.name, backendHandler.requestURL.Query(), test.requestParams)
361
362
363 validateHeaders(t, test.name+" backend request", backendHandler.requestHeader,
364 test.requestHeader, nil)
365
366
367
368
369 validateHeaders(t, test.name+" backend headers", res.Header, test.expectedRespHeader, test.notExpectedRespHeader)
370
371
372 responseBody, err := io.ReadAll(res.Body)
373 if err != nil {
374 t.Errorf("Unexpected error reading response body: %v", err)
375 }
376 if rb := string(responseBody); rb != backendResponse {
377 t.Errorf("Did not get expected response body: %s. Expected: %s", rb, backendResponse)
378 }
379
380
381 if responder.called {
382 t.Errorf("Unexpected proxy handler error: %v", responder.err)
383 }
384 }()
385 }
386 }
387
// RoundTripperFunc adapts a plain function to the http.RoundTripper interface,
// in the same spirit as http.HandlerFunc for handlers.
type RoundTripperFunc func(req *http.Request) (*http.Response, error)

// RoundTrip calls fn with req and hands back its results unchanged.
func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
393
394 func TestProxyUpgrade(t *testing.T) {
395
396 localhostPool := x509.NewCertPool()
397 if !localhostPool.AppendCertsFromPEM(localhostCert) {
398 t.Errorf("error setting up localhostCert pool")
399 }
400 var d net.Dialer
401
402 testcases := map[string]struct {
403 ServerFunc func(http.Handler) *httptest.Server
404 ProxyTransport http.RoundTripper
405 UpgradeTransport UpgradeRequestRoundTripper
406 ExpectedAuth string
407 }{
408 "http": {
409 ServerFunc: httptest.NewServer,
410 ProxyTransport: nil,
411 },
412 "both client and server support http2, but force to http/1.1 for upgrade": {
413 ServerFunc: func(h http.Handler) *httptest.Server {
414 cert, err := tls.X509KeyPair(exampleCert, exampleKey)
415 if err != nil {
416 t.Errorf("https (invalid hostname): proxy_test: %v", err)
417 }
418 ts := httptest.NewUnstartedServer(h)
419 ts.TLS = &tls.Config{
420 Certificates: []tls.Certificate{cert},
421 NextProtos: []string{"http2", "http/1.1"},
422 }
423 ts.StartTLS()
424 return ts
425 },
426 ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{TLSClientConfig: &tls.Config{
427 NextProtos: []string{"http2", "http/1.1"},
428 InsecureSkipVerify: true,
429 }}),
430 },
431 "https (invalid hostname + InsecureSkipVerify)": {
432 ServerFunc: func(h http.Handler) *httptest.Server {
433 cert, err := tls.X509KeyPair(exampleCert, exampleKey)
434 if err != nil {
435 t.Errorf("https (invalid hostname): proxy_test: %v", err)
436 }
437 ts := httptest.NewUnstartedServer(h)
438 ts.TLS = &tls.Config{
439 Certificates: []tls.Certificate{cert},
440 }
441 ts.StartTLS()
442 return ts
443 },
444 ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}),
445 },
446 "https (valid hostname + RootCAs)": {
447 ServerFunc: func(h http.Handler) *httptest.Server {
448 cert, err := tls.X509KeyPair(localhostCert, localhostKey)
449 if err != nil {
450 t.Errorf("https (valid hostname): proxy_test: %v", err)
451 }
452 ts := httptest.NewUnstartedServer(h)
453 ts.TLS = &tls.Config{
454 Certificates: []tls.Certificate{cert},
455 }
456 ts.StartTLS()
457 return ts
458 },
459 ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
460 },
461 "https (valid hostname + RootCAs + custom dialer)": {
462 ServerFunc: func(h http.Handler) *httptest.Server {
463 cert, err := tls.X509KeyPair(localhostCert, localhostKey)
464 if err != nil {
465 t.Errorf("https (valid hostname): proxy_test: %v", err)
466 }
467 ts := httptest.NewUnstartedServer(h)
468 ts.TLS = &tls.Config{
469 Certificates: []tls.Certificate{cert},
470 }
471 ts.StartTLS()
472 return ts
473 },
474 ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
475 },
476 "https (valid hostname + RootCAs + custom dialer + bearer token)": {
477 ServerFunc: func(h http.Handler) *httptest.Server {
478 cert, err := tls.X509KeyPair(localhostCert, localhostKey)
479 if err != nil {
480 t.Errorf("https (valid hostname): proxy_test: %v", err)
481 }
482 ts := httptest.NewUnstartedServer(h)
483 ts.TLS = &tls.Config{
484 Certificates: []tls.Certificate{cert},
485 }
486 ts.StartTLS()
487 return ts
488 },
489 ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
490 UpgradeTransport: NewUpgradeRequestRoundTripper(
491 utilnet.SetOldTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
492 RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
493 req = utilnet.CloneRequest(req)
494 req.Header.Set("Authorization", "Bearer 1234")
495 return MirrorRequest.RoundTrip(req)
496 }),
497 ),
498 ExpectedAuth: "Bearer 1234",
499 },
500 }
501
502 for k, tc := range testcases {
503 tcName := k
504 backendPath := "/hello"
505 func() {
506 backend := http.NewServeMux()
507 backend.Handle("/hello", websocket.Handler(func(ws *websocket.Conn) {
508 if ws.Request().Header.Get("Authorization") != tc.ExpectedAuth {
509 t.Errorf("%s: unexpected headers on request: %v", k, ws.Request().Header)
510 defer ws.Close()
511 ws.Write([]byte("you failed"))
512 return
513 }
514 defer ws.Close()
515 body := make([]byte, 5)
516 ws.Read(body)
517 ws.Write([]byte("hello " + string(body)))
518 }))
519 backend.Handle("/redirect", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
520 http.Redirect(w, r, "/hello", http.StatusFound)
521 }))
522 backendServer := tc.ServerFunc(backend)
523 defer backendServer.Close()
524
525 serverURL, _ := url.Parse(backendServer.URL)
526 serverURL.Path = backendPath
527 proxyHandler := NewUpgradeAwareHandler(serverURL, tc.ProxyTransport, false, false, &noErrorsAllowed{t: t})
528 proxyHandler.UpgradeTransport = tc.UpgradeTransport
529 proxy := httptest.NewServer(proxyHandler)
530 defer proxy.Close()
531
532 ws, err := websocket.Dial("ws://"+proxy.Listener.Addr().String()+"/some/path", "", "http://127.0.0.1/")
533 if err != nil {
534 t.Fatalf("%s: websocket dial err: %s", tcName, err)
535 }
536 defer ws.Close()
537
538 if _, err := ws.Write([]byte("world")); err != nil {
539 t.Fatalf("%s: write err: %s", tcName, err)
540 }
541
542 response := make([]byte, 20)
543 n, err := ws.Read(response)
544 if err != nil {
545 t.Fatalf("%s: read err: %s", tcName, err)
546 }
547 if e, a := "hello world", string(response[0:n]); e != a {
548 t.Fatalf("%s: expected '%#v', got '%#v'", tcName, e, a)
549 }
550 }()
551 }
552 }
553
// noErrorsAllowed is a responder for tests in which the proxy must never
// report an error: any error it receives fails the test.
type noErrorsAllowed struct {
	t *testing.T
}

// Error fails the test with err. It writes nothing to w.
func (r *noErrorsAllowed) Error(w http.ResponseWriter, req *http.Request, err error) {
	tt := r.t
	tt.Error(err)
}
561
562 func TestProxyUpgradeConnectionErrorResponse(t *testing.T) {
563 var (
564 responder *fakeResponder
565 expectedErr = errors.New("EXPECTED")
566 )
567 proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
568 transport := &http.Transport{
569 Proxy: http.ProxyFromEnvironment,
570 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
571 return &fakeConn{err: expectedErr}, nil
572 },
573 MaxIdleConns: 100,
574 IdleConnTimeout: 90 * time.Second,
575 TLSHandshakeTimeout: 10 * time.Second,
576 ExpectContinueTimeout: 1 * time.Second,
577 }
578 responder = &fakeResponder{t: t, w: w}
579 proxyHandler := NewUpgradeAwareHandler(
580 &url.URL{
581 Host: "fake-backend",
582 },
583 transport,
584 false,
585 true,
586 responder,
587 )
588 proxyHandler.ServeHTTP(w, r)
589 }))
590 defer proxy.Close()
591
592
593 req, err := http.NewRequest("POST", "http://"+proxy.Listener.Addr().String()+"/some/path", nil)
594 require.NoError(t, err)
595 req.Header.Set(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
596 resp, err := http.DefaultClient.Do(req)
597 require.NoError(t, err)
598 defer resp.Body.Close()
599
600
601 assert.True(t, responder.called)
602 assert.Equal(t, fakeStatusCode, resp.StatusCode)
603 msg, err := io.ReadAll(resp.Body)
604 require.NoError(t, err)
605 assert.Contains(t, string(msg), expectedErr.Error())
606 }
607
608 func TestProxyUpgradeErrorResponseTerminates(t *testing.T) {
609 for _, code := range []int{400, 500} {
610 t.Run(fmt.Sprintf("code=%v", code), func(t *testing.T) {
611
612 backend := http.NewServeMux()
613 backend.Handle("/hello", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
614 w.WriteHeader(code)
615 w.Write([]byte(`some data`))
616 }))
617 backend.Handle("/there", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
618 t.Error("request to /there")
619 }))
620 backendServer := httptest.NewServer(backend)
621 defer backendServer.Close()
622 backendServerURL, _ := url.Parse(backendServer.URL)
623 backendServerURL.Path = "/hello"
624
625
626 proxyHandler := NewUpgradeAwareHandler(backendServerURL, nil, false, false, &noErrorsAllowed{t: t})
627 proxy := httptest.NewServer(proxyHandler)
628 defer proxy.Close()
629 proxyURL, _ := url.Parse(proxy.URL)
630
631 conn, err := net.Dial("tcp", proxyURL.Host)
632 require.NoError(t, err)
633 bufferedReader := bufio.NewReader(conn)
634
635
636 req, _ := http.NewRequest("GET", "/", nil)
637 req.Header.Set(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
638 require.NoError(t, req.Write(conn))
639
640 resp, err := http.ReadResponse(bufferedReader, nil)
641 require.NoError(t, err)
642 data, err := io.ReadAll(resp.Body)
643 require.NoError(t, err)
644 require.Equal(t, resp.StatusCode, code)
645 require.Equal(t, data, []byte(`some data`))
646 resp.Body.Close()
647
648
649 b := make([]byte, 1)
650 conn.SetReadDeadline(time.Now().Add(time.Second))
651 if _, err := conn.Read(b); err != io.EOF {
652 t.Errorf("expected EOF, got %v", err)
653 }
654
655
656 req, _ = http.NewRequest("GET", "/there", nil)
657 req.Write(conn)
658
659 time.Sleep(time.Second)
660
661
662 conn.Close()
663 })
664 }
665 }
666
667 func TestProxyUpgradeErrorResponse(t *testing.T) {
668 for _, code := range []int{200, 300, 302, 307} {
669 t.Run(fmt.Sprintf("code=%v", code), func(t *testing.T) {
670
671 backend := http.NewServeMux()
672 backend.Handle("/hello", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
673 http.Redirect(w, r, "https://example.com/there", code)
674 }))
675 backendServer := httptest.NewServer(backend)
676 defer backendServer.Close()
677 backendServerURL, _ := url.Parse(backendServer.URL)
678 backendServerURL.Path = "/hello"
679
680
681 proxyHandler := NewUpgradeAwareHandler(backendServerURL, nil, false, false, &fakeResponder{t: t})
682 proxy := httptest.NewServer(proxyHandler)
683 defer proxy.Close()
684 proxyURL, _ := url.Parse(proxy.URL)
685
686 conn, err := net.Dial("tcp", proxyURL.Host)
687 require.NoError(t, err)
688 bufferedReader := bufio.NewReader(conn)
689
690
691 req, _ := http.NewRequest("GET", "/", nil)
692 req.Header.Set(httpstream.HeaderConnection, httpstream.HeaderUpgrade)
693 require.NoError(t, req.Write(conn))
694
695 resp, err := http.ReadResponse(bufferedReader, nil)
696 require.NoError(t, err)
697 assert.Equal(t, fakeStatusCode, resp.StatusCode)
698 resp.Body.Close()
699
700
701 conn.Close()
702 })
703 }
704 }
705
706 func TestRejectForwardingRedirectsOption(t *testing.T) {
707 originalBody := []byte(`some data`)
708 testCases := []struct {
709 name string
710 rejectForwardingRedirects bool
711 serverStatusCode int
712 redirect string
713 expectStatusCode int
714 expectBody []byte
715 }{
716 {
717 name: "reject redirection enabled in proxy, backend server sending 200 response",
718 rejectForwardingRedirects: true,
719 serverStatusCode: 200,
720 expectStatusCode: 200,
721 expectBody: originalBody,
722 },
723 {
724 name: "reject redirection enabled in proxy, backend server sending 301 response",
725 rejectForwardingRedirects: true,
726 serverStatusCode: 301,
727 redirect: "/",
728 expectStatusCode: 502,
729 expectBody: []byte(`the backend attempted to redirect this request, which is not permitted`),
730 },
731 {
732 name: "reject redirection enabled in proxy, backend server sending 304 response with a location header",
733 rejectForwardingRedirects: true,
734 serverStatusCode: 304,
735 redirect: "/",
736 expectStatusCode: 502,
737 expectBody: []byte(`the backend attempted to redirect this request, which is not permitted`),
738 },
739 {
740 name: "reject redirection enabled in proxy, backend server sending 304 response with no location header",
741 rejectForwardingRedirects: true,
742 serverStatusCode: 304,
743 expectStatusCode: 304,
744 expectBody: []byte{},
745 },
746 {
747 name: "reject redirection disabled in proxy, backend server sending 200 response",
748 rejectForwardingRedirects: false,
749 serverStatusCode: 200,
750 expectStatusCode: 200,
751 expectBody: originalBody,
752 },
753 {
754 name: "reject redirection disabled in proxy, backend server sending 301 response",
755 rejectForwardingRedirects: false,
756 serverStatusCode: 301,
757 redirect: "/",
758 expectStatusCode: 301,
759 expectBody: originalBody,
760 },
761 }
762 for _, tc := range testCases {
763 t.Run(tc.name, func(t *testing.T) {
764
765 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
766 if tc.redirect != "" {
767 w.Header().Set("Location", tc.redirect)
768 }
769 w.WriteHeader(tc.serverStatusCode)
770 w.Write(originalBody)
771 }))
772 defer backendServer.Close()
773 backendServerURL, _ := url.Parse(backendServer.URL)
774
775
776 proxyHandler := NewUpgradeAwareHandler(backendServerURL, nil, false, false, &fakeResponder{t: t})
777 proxyHandler.RejectForwardingRedirects = tc.rejectForwardingRedirects
778 proxy := httptest.NewServer(proxyHandler)
779 defer proxy.Close()
780 proxyURL, _ := url.Parse(proxy.URL)
781
782 conn, err := net.Dial("tcp", proxyURL.Host)
783 require.NoError(t, err)
784 bufferedReader := bufio.NewReader(conn)
785
786 req, _ := http.NewRequest("GET", proxyURL.String(), nil)
787 require.NoError(t, req.Write(conn))
788
789 resp, err := http.ReadResponse(bufferedReader, nil)
790 require.NoError(t, err)
791 assert.Equal(t, tc.expectStatusCode, resp.StatusCode)
792 data, err := io.ReadAll(resp.Body)
793 require.NoError(t, err)
794 assert.Equal(t, tc.expectBody, data)
795 assert.Equal(t, int64(len(tc.expectBody)), resp.ContentLength)
796 resp.Body.Close()
797
798
799 conn.Close()
800 })
801 }
802 }
803
804 func TestDefaultProxyTransport(t *testing.T) {
805 tests := []struct {
806 name,
807 url,
808 location,
809 expectedScheme,
810 expectedHost,
811 expectedPathPrepend string
812 }{
813 {
814 name: "simple path",
815 url: "http://test.server:8080/a/test/location",
816 location: "http://localhost/location",
817 expectedScheme: "http",
818 expectedHost: "test.server:8080",
819 expectedPathPrepend: "/a/test",
820 },
821 {
822 name: "empty path",
823 url: "http://test.server:8080/a/test/",
824 location: "http://localhost",
825 expectedScheme: "http",
826 expectedHost: "test.server:8080",
827 expectedPathPrepend: "/a/test",
828 },
829 {
830 name: "location ending in slash",
831 url: "http://test.server:8080/a/test/",
832 location: "http://localhost/",
833 expectedScheme: "http",
834 expectedHost: "test.server:8080",
835 expectedPathPrepend: "/a/test",
836 },
837 }
838
839 for _, test := range tests {
840 locURL, _ := url.Parse(test.location)
841 URL, _ := url.Parse(test.url)
842 h := NewUpgradeAwareHandler(locURL, nil, false, false, nil)
843 result := h.defaultProxyTransport(URL, nil)
844 transport := result.(*corsRemovingTransport).RoundTripper.(*Transport)
845 if transport.Scheme != test.expectedScheme {
846 t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme)
847 }
848 if transport.Host != test.expectedHost {
849 t.Errorf("%s: unexpected host. Actual: %s, Expected: %s", test.name, transport.Host, test.expectedHost)
850 }
851 if transport.PathPrepend != test.expectedPathPrepend {
852 t.Errorf("%s: unexpected path prepend. Actual: %s, Expected: %s", test.name, transport.PathPrepend, test.expectedPathPrepend)
853 }
854 }
855 }
856
857 func TestProxyRequestContentLengthAndTransferEncoding(t *testing.T) {
858 chunk := func(data []byte) []byte {
859 out := &bytes.Buffer{}
860 chunker := httputil.NewChunkedWriter(out)
861 for _, b := range data {
862 if _, err := chunker.Write([]byte{b}); err != nil {
863 panic(err)
864 }
865 }
866 chunker.Close()
867 out.Write([]byte("\r\n"))
868 return out.Bytes()
869 }
870
871 zip := func(data []byte) []byte {
872 out := &bytes.Buffer{}
873 zipper := gzip.NewWriter(out)
874 if _, err := zipper.Write(data); err != nil {
875 panic(err)
876 }
877 zipper.Close()
878 return out.Bytes()
879 }
880
881 sampleData := []byte("abcde")
882
883 table := map[string]struct {
884 reqHeaders http.Header
885 reqBody []byte
886
887 expectedHeaders http.Header
888 expectedBody []byte
889 }{
890 "content-length": {
891 reqHeaders: http.Header{
892 "Content-Length": []string{"5"},
893 },
894 reqBody: sampleData,
895
896 expectedHeaders: http.Header{
897 "Content-Length": []string{"5"},
898 "Content-Encoding": nil,
899 "Transfer-Encoding": nil,
900 },
901 expectedBody: sampleData,
902 },
903
904 "content-length + gzip content-encoding": {
905 reqHeaders: http.Header{
906 "Content-Length": []string{strconv.Itoa(len(zip(sampleData)))},
907 "Content-Encoding": []string{"gzip"},
908 },
909 reqBody: zip(sampleData),
910
911 expectedHeaders: http.Header{
912 "Content-Length": []string{strconv.Itoa(len(zip(sampleData)))},
913 "Content-Encoding": []string{"gzip"},
914 "Transfer-Encoding": nil,
915 },
916 expectedBody: zip(sampleData),
917 },
918
919 "chunked transfer-encoding": {
920 reqHeaders: http.Header{
921 "Transfer-Encoding": []string{"chunked"},
922 },
923 reqBody: chunk(sampleData),
924
925 expectedHeaders: http.Header{
926 "Content-Length": nil,
927 "Content-Encoding": nil,
928 "Transfer-Encoding": nil,
929 },
930 expectedBody: sampleData,
931 },
932
933 "chunked transfer-encoding + gzip content-encoding": {
934 reqHeaders: http.Header{
935 "Content-Encoding": []string{"gzip"},
936 "Transfer-Encoding": []string{"chunked"},
937 },
938 reqBody: chunk(zip(sampleData)),
939
940 expectedHeaders: http.Header{
941 "Content-Length": nil,
942 "Content-Encoding": []string{"gzip"},
943 "Transfer-Encoding": nil,
944 },
945 expectedBody: zip(sampleData),
946 },
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964 }
965
966 successfulResponse := "backend passed tests"
967 for k, item := range table {
968
969 downstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
970
971 for header, v := range item.expectedHeaders {
972 if !reflect.DeepEqual(v, req.Header[header]) {
973 t.Errorf("%s: Expected headers for %s to be %v, got %v", k, header, v, req.Header[header])
974 }
975 }
976
977
978 body, err := io.ReadAll(req.Body)
979 if err != nil {
980 t.Errorf("%s: unexpected error %v", k, err)
981 }
982 req.Body.Close()
983
984
985 if req.ContentLength > 0 && req.ContentLength != int64(len(body)) {
986 t.Errorf("%s: ContentLength was %d, len(data) was %d", k, req.ContentLength, len(body))
987 }
988
989
990 if !bytes.Equal(item.expectedBody, body) {
991 t.Errorf("%s: Expected %q, got %q", k, string(item.expectedBody), string(body))
992 }
993
994
995 w.Write([]byte(successfulResponse))
996 }))
997 defer downstreamServer.Close()
998
999 responder := &fakeResponder{t: t}
1000 backendURL, _ := url.Parse(downstreamServer.URL)
1001 proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, false, responder)
1002 proxyServer := httptest.NewServer(proxyHandler)
1003 defer proxyServer.Close()
1004
1005
1006 conn, err := net.Dial(proxyServer.Listener.Addr().Network(), proxyServer.Listener.Addr().String())
1007 if err != nil {
1008 t.Errorf("unexpected error %v", err)
1009 continue
1010 }
1011 defer conn.Close()
1012
1013
1014 if item.reqHeaders == nil {
1015 item.reqHeaders = http.Header{}
1016 }
1017 item.reqHeaders.Add("Connection", "close")
1018 item.reqHeaders.Add("Host", proxyServer.Listener.Addr().String())
1019
1020
1021 if _, err := fmt.Fprint(conn, "POST / HTTP/1.1\r\n"); err != nil {
1022 t.Fatalf("%s unexpected error %v", k, err)
1023 }
1024 for header, values := range item.reqHeaders {
1025 for _, value := range values {
1026 if _, err := fmt.Fprintf(conn, "%s: %s\r\n", header, value); err != nil {
1027 t.Fatalf("%s: unexpected error %v", k, err)
1028 }
1029 }
1030 }
1031
1032 if _, err := fmt.Fprint(conn, "\r\n"); err != nil {
1033 t.Fatalf("%s: unexpected error %v", k, err)
1034 }
1035
1036 if _, err := conn.Write(item.reqBody); err != nil {
1037 t.Fatalf("%s: unexpected error %v", k, err)
1038 }
1039
1040
1041 response, err := io.ReadAll(conn)
1042 if err != nil {
1043 t.Errorf("%s: unexpected error %v", k, err)
1044 continue
1045 }
1046 if !strings.HasSuffix(string(response), successfulResponse) {
1047 t.Errorf("%s: Did not get successful response: %s", k, string(response))
1048 continue
1049 }
1050 }
1051 }
1052
1053 func TestFlushIntervalHeaders(t *testing.T) {
1054 const expected = "hi"
1055 stopCh := make(chan struct{})
1056 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1057 w.Header().Add("MyHeader", expected)
1058 w.WriteHeader(200)
1059 w.(http.Flusher).Flush()
1060 <-stopCh
1061 }))
1062 defer backend.Close()
1063 defer close(stopCh)
1064
1065 backendURL, err := url.Parse(backend.URL)
1066 if err != nil {
1067 t.Fatal(err)
1068 }
1069
1070 responder := &fakeResponder{t: t}
1071 proxyHandler := NewUpgradeAwareHandler(backendURL, nil, false, false, responder)
1072
1073 frontend := httptest.NewServer(proxyHandler)
1074 defer frontend.Close()
1075
1076 req, _ := http.NewRequest("GET", frontend.URL, nil)
1077 req.Close = true
1078
1079 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
1080 defer cancel()
1081 req = req.WithContext(ctx)
1082
1083 res, err := frontend.Client().Do(req)
1084 if err != nil {
1085 t.Fatalf("Get: %v", err)
1086 }
1087 defer res.Body.Close()
1088
1089 if res.Header.Get("MyHeader") != expected {
1090 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
1091 }
1092 }
1093
// fakeRT is an http.RoundTripper that never contacts a server: RoundTrip
// ignores the request and returns a nil response together with err.
type fakeRT struct {
	err error
}

// RoundTrip returns (nil, frt.err) for any request.
func (frt *fakeRT) RoundTrip(*http.Request) (*http.Response, error) {
	var none *http.Response
	return none, frt.err
}
1101
1102
1103 func TestErrorPropagation(t *testing.T) {
1104 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1105 panic("unreachable")
1106 }))
1107 defer backend.Close()
1108
1109 backendURL, err := url.Parse(backend.URL)
1110 if err != nil {
1111 t.Fatal(err)
1112 }
1113
1114 responder := &fakeResponder{t: t}
1115 expectedErr := errors.New("nasty error")
1116 proxyHandler := NewUpgradeAwareHandler(backendURL, &fakeRT{err: expectedErr}, true, false, responder)
1117
1118 frontend := httptest.NewServer(proxyHandler)
1119 defer frontend.Close()
1120
1121 req, _ := http.NewRequest("GET", frontend.URL, nil)
1122 req.Close = true
1123
1124 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
1125 defer cancel()
1126 req = req.WithContext(ctx)
1127
1128 res, err := frontend.Client().Do(req)
1129 if err != nil {
1130 t.Fatalf("Get: %v", err)
1131 }
1132 defer res.Body.Close()
1133 if res.StatusCode != fakeStatusCode {
1134 t.Fatalf("unexpected HTTP status code returned: %v, expected: %v", res.StatusCode, fakeStatusCode)
1135 }
1136 if !strings.Contains(responder.err.Error(), expectedErr.Error()) {
1137 t.Fatalf("responder got unexpected error: %v, expected the error to contain %q", responder.err.Error(), expectedErr.Error())
1138 }
1139 }
1140
1141 func TestProxyRedirectsforRootPath(t *testing.T) {
1142
1143 tests := []struct {
1144 name string
1145 method string
1146 requestPath string
1147 expectedHeader http.Header
1148 expectedStatusCode int
1149 redirect bool
1150 }{
1151 {
1152 name: "root path, simple get",
1153 method: "GET",
1154 requestPath: "",
1155 redirect: true,
1156 expectedStatusCode: 301,
1157 expectedHeader: http.Header{
1158 "Location": []string{"/"},
1159 },
1160 },
1161 {
1162 name: "root path, simple put",
1163 method: "PUT",
1164 requestPath: "",
1165 redirect: false,
1166 expectedStatusCode: 200,
1167 },
1168 {
1169 name: "root path, simple head",
1170 method: "HEAD",
1171 requestPath: "",
1172 redirect: true,
1173 expectedStatusCode: 301,
1174 expectedHeader: http.Header{
1175 "Location": []string{"/"},
1176 },
1177 },
1178 {
1179 name: "root path, simple delete with params",
1180 method: "DELETE",
1181 requestPath: "",
1182 redirect: false,
1183 expectedStatusCode: 200,
1184 },
1185 }
1186
1187 for _, test := range tests {
1188 func() {
1189 w := httptest.NewRecorder()
1190 req, err := http.NewRequest(test.method, test.requestPath, nil)
1191 if err != nil {
1192 t.Fatal(err)
1193 }
1194
1195 redirect := proxyRedirectsforRootPath(test.requestPath, w, req)
1196 if got, want := redirect, test.redirect; got != want {
1197 t.Errorf("Expected redirect state %v; got %v", want, got)
1198 }
1199
1200 res := w.Result()
1201 if got, want := res.StatusCode, test.expectedStatusCode; got != want {
1202 t.Errorf("Expected status code %d; got %d", want, got)
1203 }
1204
1205 if res.StatusCode == 301 && !reflect.DeepEqual(res.Header, test.expectedHeader) {
1206 t.Errorf("Expected location header to be %v, got %v", test.expectedHeader, res.Header)
1207 }
1208 }()
1209 }
1210 }
1211
1212
1213
1214
// exampleCert is a PEM-encoded 2048-bit RSA X.509 certificate kept as a test
// fixture. Its DER encodes issuer and subject O=Acme Co (self-issued), a single
// DNS SAN "example.com", basicConstraints CA:TRUE, extended key usage
// serverAuth, and a validity window of 1970-01-01 through 2084-01-29.
// The bytes must stay exactly as written; any edit invalidates the signature.
// NOTE(review): this looks like output of crypto/tls/generate_cert.go
// (--host example.com --ca --rsa-bits 2048); confirm and record the exact
// command so the fixture can be regenerated.
var exampleCert = []byte(`-----BEGIN CERTIFICATE-----
MIIDADCCAeigAwIBAgIQVHG3Fn9SdWayyLOZKCW1vzANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEArTCu9fiIclNgDdWHphewM+JW55dCb5yYGlJgCBvwbOx547M9p+tn
zm9QOhsdZDHDZsG9tqnWxE2Nc1HpIJyOlfYsOoonpEoG/Ep6nnK91ngj0bn/JlNy
+i/bwU4r97MOukvnOIQez9/D9jAJaOX2+b8/d4lRz9BsqiwJyg+ynZ5tVVYj7aMi
vXnd6HOnJmtqutOtr3beucJnkd6XbwRkLUcAYATT+ZihOWRbTuKqhCg6zGkJOoUG
f8sX61JjoilxiURA//ftGVbdTCU3DrmGmardp5NNOHbumMYU8Vhmqgx1Bqxb+9he
7G42uW5YWYK/GqJzgVPjjlB2dOGj9KrEWQIDAQABo1AwTjAOBgNVHQ8BAf8EBAMC
AqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUwAwEB/zAWBgNVHREE
DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAQEAig4AIi9xWs1+pLES
eeGGdSDoclplFpcbXANnsYYFyLf+8pcWgVi2bOmb2gXMbHFkB07MA82wRJAUTaA+
2iNXVQMhPCoA7J6ADUbww9doJX2S9HGyArhiV/MhHtE8txzMn2EKNLdhhk3N9rmV
x/qRbWAY1U2z4BpdrAR87Fe81Nlj7h45csW9K+eS+NgXipiNTIfEShKgCFM8EdxL
1WXg7r9AvYV3TNDPWTjLsm1rQzzZQ7Uvcf6deWiNodZd8MOT/BFLclDPTK6cF2Hr
UU4dq6G4kCwMSxWE4cM3HlZ4u1dyIt47VbkP0rtvkBCXx36y+NXYA5lzntchNFZP
uvEQdw==
-----END CERTIFICATE-----`)
1234
// exampleKey is a PEM-encoded PKCS#1 ("RSA PRIVATE KEY") 2048-bit RSA private
// key kept as a test fixture. Its modulus begins with the same bytes as the
// public key embedded in exampleCert, so it appears to be that certificate's
// private key. The key is published in source: never use it outside tests.
// The bytes must stay exactly as written.
var exampleKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEArTCu9fiIclNgDdWHphewM+JW55dCb5yYGlJgCBvwbOx547M9
p+tnzm9QOhsdZDHDZsG9tqnWxE2Nc1HpIJyOlfYsOoonpEoG/Ep6nnK91ngj0bn/
JlNy+i/bwU4r97MOukvnOIQez9/D9jAJaOX2+b8/d4lRz9BsqiwJyg+ynZ5tVVYj
7aMivXnd6HOnJmtqutOtr3beucJnkd6XbwRkLUcAYATT+ZihOWRbTuKqhCg6zGkJ
OoUGf8sX61JjoilxiURA//ftGVbdTCU3DrmGmardp5NNOHbumMYU8Vhmqgx1Bqxb
+9he7G42uW5YWYK/GqJzgVPjjlB2dOGj9KrEWQIDAQABAoIBAQClt4CiYaaF5ltx
wVDjz6TNcJUBUs3CKE+uWAYFnF5Ii1nyU876Pxj8Aaz9fHZ6Kde0GkwiXY7gFOj1
YHo2tzcELSKS/SEDZcYbYFTGCjq13g1AH74R+SV6WZLn+5m8kPvVrM1ZWap188H5
bmuCkRDqVmIvShkbRW7EwhC35J9fiuW3majC/sjmsxtxyP6geWmu4f5/Ttqahcdb
osPZIgIIPzqAkNtkLTi7+meHYI9wlrGhL7XZTwnJ1Oc/Y67zzmbthLYB5YFSLUew
rXT58jtSjX4gbiQyheBSrWxW08QE4qYg6jJlAdffHhWv72hJW2MCXhuXp8gJs/Do
XLRHGwSBAoGBAMdNtsbe4yae/QeHUPGxNW0ipa0yoTF6i+VYoxvqiRMzDM3+3L8k
dgI1rr4330SivqDahMA/odWtM/9rVwJI2B2QhZLMHA0n9ytH007OO9TghgVB12nN
xosRYBpKdHXyyvV/MUZl7Jux6zKIzRDWOkF95VVYPcAaxJqd1E5/jJ6JAoGBAN51
QrebA1w/jfydeqQTz1sK01sbO4HYj4qGfo/JarVqGEkm1azeBBPPRnHz3jNKnCkM
S4PpqRDased3NIcViXlAgoqPqivZ8mQa/Rb146l7WaTErASHsZ023OGrxsr/Ed6N
P3GrmvxVJjebaFNaQ9sP80dLkpgeas0t2TY8iQNRAoGATOcnx8TpUVW3vNfx29DN
FLdxxkrq9/SZVn3FMlhlXAsuva3B799ZybB9JNjaRdmmRNsMrkHfaFvU3JHGmRMS
kRXa9LHdgRYSwZiNaLMbUyDvlce6HxFPswmZU4u3NGvi9KeHk+pwSgN1BaLTvdNr
1ymE/FF4QlAR3LdZ3JBK6kECgYEA0wW4/CJ31ZIURoW8SNjh4iMqy0nR8SJVR7q9
Y/hU2TKDRyEnoIwaohAFayNCrLUh3W5kVAXa8roB+OgDVAECH5sqOfZ+HorofD19
x8II7ESujLZj1whBXDkm3ovsT7QWZ17lyBZZNvQvBKDPHgKKS8udowv1S4fPGENd
wS07a4ECgYEAwLSbmMIVJme0jFjsp5d1wOGA2Qi2ZwGIAVlsbnJtygrU/hSBfnu8
VfyJSCgg3fPe7kChWKlfcOebVKSb68LKRsz1Lz1KdbY0HOJFp/cT4lKmDAlRY9gq
LB4rdf46lV0mUkvd2/oofIbTrzukjQSnyfLawb/2uJGV1IkTcZcn9CI=
-----END RSA PRIVATE KEY-----`)
1262
View as plain text