...

Source file src/nhooyr.io/websocket/dial_test.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket_test
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/rand"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	"nhooyr.io/websocket"
    19  	"nhooyr.io/websocket/internal/test/assert"
    20  	"nhooyr.io/websocket/internal/util"
    21  	"nhooyr.io/websocket/internal/xsync"
    22  )
    23  
    24  func TestBadDials(t *testing.T) {
    25  	t.Parallel()
    26  
    27  	t.Run("badReq", func(t *testing.T) {
    28  		t.Parallel()
    29  
    30  		testCases := []struct {
    31  			name   string
    32  			url    string
    33  			opts   *websocket.DialOptions
    34  			rand   util.ReaderFunc
    35  			nilCtx bool
    36  		}{
    37  			{
    38  				name: "badURL",
    39  				url:  "://noscheme",
    40  			},
    41  			{
    42  				name: "badURLScheme",
    43  				url:  "ftp://nhooyr.io",
    44  			},
    45  			{
    46  				name: "badTLS",
    47  				url:  "wss://totallyfake.nhooyr.io",
    48  			},
    49  			{
    50  				name: "badReader",
    51  				rand: func(p []byte) (int, error) {
    52  					return 0, io.EOF
    53  				},
    54  			},
    55  			{
    56  				name:   "nilContext",
    57  				url:    "http://localhost",
    58  				nilCtx: true,
    59  			},
    60  		}
    61  
    62  		for _, tc := range testCases {
    63  			tc := tc
    64  			t.Run(tc.name, func(t *testing.T) {
    65  				t.Parallel()
    66  
    67  				var ctx context.Context
    68  				var cancel func()
    69  				if !tc.nilCtx {
    70  					ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
    71  					defer cancel()
    72  				}
    73  
    74  				if tc.rand == nil {
    75  					tc.rand = rand.Reader.Read
    76  				}
    77  
    78  				_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
    79  				assert.Error(t, err)
    80  			})
    81  		}
    82  	})
    83  
    84  	t.Run("badResponse", func(t *testing.T) {
    85  		t.Parallel()
    86  
    87  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
    88  		defer cancel()
    89  
    90  		_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
    91  			HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
    92  				return &http.Response{
    93  					Body: io.NopCloser(strings.NewReader("hi")),
    94  				}, nil
    95  			}),
    96  		})
    97  		assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
    98  	})
    99  
   100  	t.Run("badBody", func(t *testing.T) {
   101  		t.Parallel()
   102  
   103  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   104  		defer cancel()
   105  
   106  		rt := func(r *http.Request) (*http.Response, error) {
   107  			h := http.Header{}
   108  			h.Set("Connection", "Upgrade")
   109  			h.Set("Upgrade", "websocket")
   110  			h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
   111  
   112  			return &http.Response{
   113  				StatusCode: http.StatusSwitchingProtocols,
   114  				Header:     h,
   115  				Body:       io.NopCloser(strings.NewReader("hi")),
   116  			}, nil
   117  		}
   118  
   119  		_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
   120  			HTTPClient: mockHTTPClient(rt),
   121  		})
   122  		assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
   123  	})
   124  }
   125  
   126  func Test_verifyHostOverride(t *testing.T) {
   127  	testCases := []struct {
   128  		name string
   129  		host string
   130  		exp  string
   131  	}{
   132  		{
   133  			name: "noOverride",
   134  			host: "",
   135  			exp:  "example.com",
   136  		},
   137  		{
   138  			name: "hostOverride",
   139  			host: "example.net",
   140  			exp:  "example.net",
   141  		},
   142  	}
   143  
   144  	for _, tc := range testCases {
   145  		tc := tc
   146  		t.Run(tc.name, func(t *testing.T) {
   147  			t.Parallel()
   148  
   149  			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   150  			defer cancel()
   151  
   152  			rt := func(r *http.Request) (*http.Response, error) {
   153  				assert.Equal(t, "Host", tc.exp, r.Host)
   154  
   155  				h := http.Header{}
   156  				h.Set("Connection", "Upgrade")
   157  				h.Set("Upgrade", "websocket")
   158  				h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
   159  
   160  				return &http.Response{
   161  					StatusCode: http.StatusSwitchingProtocols,
   162  					Header:     h,
   163  					Body:       mockBody{bytes.NewBufferString("hi")},
   164  				}, nil
   165  			}
   166  
   167  			c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
   168  				HTTPClient: mockHTTPClient(rt),
   169  				Host:       tc.host,
   170  			})
   171  			assert.Success(t, err)
   172  			c.CloseNow()
   173  		})
   174  	}
   175  
   176  }
   177  
   178  type mockBody struct {
   179  	*bytes.Buffer
   180  }
   181  
   182  func (mb mockBody) Close() error {
   183  	return nil
   184  }
   185  
   186  func Test_verifyServerHandshake(t *testing.T) {
   187  	t.Parallel()
   188  
   189  	testCases := []struct {
   190  		name     string
   191  		response func(w http.ResponseWriter)
   192  		success  bool
   193  	}{
   194  		{
   195  			name: "badStatus",
   196  			response: func(w http.ResponseWriter) {
   197  				w.WriteHeader(http.StatusOK)
   198  			},
   199  			success: false,
   200  		},
   201  		{
   202  			name: "badConnection",
   203  			response: func(w http.ResponseWriter) {
   204  				w.Header().Set("Connection", "???")
   205  				w.WriteHeader(http.StatusSwitchingProtocols)
   206  			},
   207  			success: false,
   208  		},
   209  		{
   210  			name: "badUpgrade",
   211  			response: func(w http.ResponseWriter) {
   212  				w.Header().Set("Connection", "Upgrade")
   213  				w.Header().Set("Upgrade", "???")
   214  				w.WriteHeader(http.StatusSwitchingProtocols)
   215  			},
   216  			success: false,
   217  		},
   218  		{
   219  			name: "badSecWebSocketAccept",
   220  			response: func(w http.ResponseWriter) {
   221  				w.Header().Set("Connection", "Upgrade")
   222  				w.Header().Set("Upgrade", "websocket")
   223  				w.Header().Set("Sec-WebSocket-Accept", "xd")
   224  				w.WriteHeader(http.StatusSwitchingProtocols)
   225  			},
   226  			success: false,
   227  		},
   228  		{
   229  			name: "badSecWebSocketProtocol",
   230  			response: func(w http.ResponseWriter) {
   231  				w.Header().Set("Connection", "Upgrade")
   232  				w.Header().Set("Upgrade", "websocket")
   233  				w.Header().Set("Sec-WebSocket-Protocol", "xd")
   234  				w.WriteHeader(http.StatusSwitchingProtocols)
   235  			},
   236  			success: false,
   237  		},
   238  		{
   239  			name: "unsupportedExtension",
   240  			response: func(w http.ResponseWriter) {
   241  				w.Header().Set("Connection", "Upgrade")
   242  				w.Header().Set("Upgrade", "websocket")
   243  				w.Header().Set("Sec-WebSocket-Extensions", "meow")
   244  				w.WriteHeader(http.StatusSwitchingProtocols)
   245  			},
   246  			success: false,
   247  		},
   248  		{
   249  			name: "unsupportedDeflateParam",
   250  			response: func(w http.ResponseWriter) {
   251  				w.Header().Set("Connection", "Upgrade")
   252  				w.Header().Set("Upgrade", "websocket")
   253  				w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
   254  				w.WriteHeader(http.StatusSwitchingProtocols)
   255  			},
   256  			success: false,
   257  		},
   258  		{
   259  			name: "success",
   260  			response: func(w http.ResponseWriter) {
   261  				w.Header().Set("Connection", "Upgrade")
   262  				w.Header().Set("Upgrade", "websocket")
   263  				w.WriteHeader(http.StatusSwitchingProtocols)
   264  			},
   265  			success: true,
   266  		},
   267  	}
   268  
   269  	for _, tc := range testCases {
   270  		tc := tc
   271  		t.Run(tc.name, func(t *testing.T) {
   272  			t.Parallel()
   273  
   274  			w := httptest.NewRecorder()
   275  			tc.response(w)
   276  			resp := w.Result()
   277  
   278  			r := httptest.NewRequest("GET", "/", nil)
   279  			key, err := websocket.SecWebSocketKey(rand.Reader)
   280  			assert.Success(t, err)
   281  			r.Header.Set("Sec-WebSocket-Key", key)
   282  
   283  			if resp.Header.Get("Sec-WebSocket-Accept") == "" {
   284  				resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
   285  			}
   286  
   287  			opts := &websocket.DialOptions{
   288  				Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
   289  			}
   290  			_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
   291  			if tc.success {
   292  				assert.Success(t, err)
   293  			} else {
   294  				assert.Error(t, err)
   295  			}
   296  		})
   297  	}
   298  }
   299  
   300  func mockHTTPClient(fn roundTripperFunc) *http.Client {
   301  	return &http.Client{
   302  		Transport: fn,
   303  	}
   304  }
   305  
   306  type roundTripperFunc func(*http.Request) (*http.Response, error)
   307  
   308  func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
   309  	return f(r)
   310  }
   311  
   312  func TestDialRedirect(t *testing.T) {
   313  	t.Parallel()
   314  
   315  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   316  	defer cancel()
   317  
   318  	_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
   319  		HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
   320  			resp := &http.Response{
   321  				Header: http.Header{},
   322  			}
   323  			if r.URL.Scheme != "https" {
   324  				resp.Header.Set("Location", "wss://example.com")
   325  				resp.StatusCode = http.StatusFound
   326  				return resp, nil
   327  			}
   328  			resp.Header.Set("Connection", "Upgrade")
   329  			resp.Header.Set("Upgrade", "meow")
   330  			resp.StatusCode = http.StatusSwitchingProtocols
   331  			return resp, nil
   332  		}),
   333  	})
   334  	assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
   335  }
   336  
   337  type forwardProxy struct {
   338  	hc *http.Client
   339  }
   340  
   341  func newForwardProxy() *forwardProxy {
   342  	return &forwardProxy{
   343  		hc: &http.Client{},
   344  	}
   345  }
   346  
   347  func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   348  	ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
   349  	defer cancel()
   350  
   351  	r = r.WithContext(ctx)
   352  	r.RequestURI = ""
   353  	resp, err := fc.hc.Do(r)
   354  	if err != nil {
   355  		http.Error(w, err.Error(), http.StatusBadRequest)
   356  		return
   357  	}
   358  	defer resp.Body.Close()
   359  
   360  	for k, v := range resp.Header {
   361  		w.Header()[k] = v
   362  	}
   363  	w.Header().Set("PROXIED", "true")
   364  	w.WriteHeader(resp.StatusCode)
   365  	if resprw, ok := resp.Body.(io.ReadWriter); ok {
   366  		c, brw, err := w.(http.Hijacker).Hijack()
   367  		if err != nil {
   368  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   369  			return
   370  		}
   371  		brw.Flush()
   372  
   373  		errc1 := xsync.Go(func() error {
   374  			_, err := io.Copy(c, resprw)
   375  			return err
   376  		})
   377  		errc2 := xsync.Go(func() error {
   378  			_, err := io.Copy(resprw, c)
   379  			return err
   380  		})
   381  		select {
   382  		case <-errc1:
   383  		case <-errc2:
   384  		case <-r.Context().Done():
   385  		}
   386  	} else {
   387  		io.Copy(w, resp.Body)
   388  	}
   389  }
   390  
   391  func TestDialViaProxy(t *testing.T) {
   392  	t.Parallel()
   393  
   394  	ps := httptest.NewServer(newForwardProxy())
   395  	defer ps.Close()
   396  
   397  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   398  		err := echoServer(w, r, nil)
   399  		assert.Success(t, err)
   400  	}))
   401  	defer s.Close()
   402  
   403  	psu, err := url.Parse(ps.URL)
   404  	assert.Success(t, err)
   405  	proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
   406  	proxyTransport.Proxy = http.ProxyURL(psu)
   407  
   408  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   409  	defer cancel()
   410  	c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
   411  		HTTPClient: &http.Client{
   412  			Transport: proxyTransport,
   413  		},
   414  	})
   415  	assert.Success(t, err)
   416  	assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
   417  
   418  	assertEcho(t, ctx, c)
   419  	assertClose(t, c)
   420  }
   421  

View as plain text