...

Source file src/nhooyr.io/websocket/accept_test.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket
     5  
     6  import (
     7  	"bufio"
     8  	"errors"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"strings"
    13  	"testing"
    14  
    15  	"nhooyr.io/websocket/internal/test/assert"
    16  	"nhooyr.io/websocket/internal/test/xrand"
    17  )
    18  
    19  func TestAccept(t *testing.T) {
    20  	t.Parallel()
    21  
    22  	t.Run("badClientHandshake", func(t *testing.T) {
    23  		t.Parallel()
    24  
    25  		w := httptest.NewRecorder()
    26  		r := httptest.NewRequest("GET", "/", nil)
    27  
    28  		_, err := Accept(w, r, nil)
    29  		assert.Contains(t, err, "protocol violation")
    30  	})
    31  
    32  	t.Run("badOrigin", func(t *testing.T) {
    33  		t.Parallel()
    34  
    35  		w := httptest.NewRecorder()
    36  		r := httptest.NewRequest("GET", "/", nil)
    37  		r.Header.Set("Connection", "Upgrade")
    38  		r.Header.Set("Upgrade", "websocket")
    39  		r.Header.Set("Sec-WebSocket-Version", "13")
    40  		r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
    41  		r.Header.Set("Origin", "harhar.com")
    42  
    43  		_, err := Accept(w, r, nil)
    44  		assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`)
    45  	})
    46  
    47  	// #247
    48  	t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) {
    49  		t.Parallel()
    50  
    51  		w := httptest.NewRecorder()
    52  		r := httptest.NewRequest("GET", "/", nil)
    53  		r.Header.Set("Connection", "Upgrade")
    54  		r.Header.Set("Upgrade", "websocket")
    55  		r.Header.Set("Sec-WebSocket-Version", "13")
    56  		r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
    57  		r.Header.Set("Origin", "https://harhar.com")
    58  
    59  		_, err := Accept(w, r, nil)
    60  		assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`)
    61  	})
    62  
    63  	t.Run("badCompression", func(t *testing.T) {
    64  		t.Parallel()
    65  
    66  		newRequest := func(extensions string) *http.Request {
    67  			r := httptest.NewRequest("GET", "/", nil)
    68  			r.Header.Set("Connection", "Upgrade")
    69  			r.Header.Set("Upgrade", "websocket")
    70  			r.Header.Set("Sec-WebSocket-Version", "13")
    71  			r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
    72  			r.Header.Set("Sec-WebSocket-Extensions", extensions)
    73  			return r
    74  		}
    75  		errHijack := errors.New("hijack error")
    76  		newResponseWriter := func() http.ResponseWriter {
    77  			return mockHijacker{
    78  				ResponseWriter: httptest.NewRecorder(),
    79  				hijack: func() (net.Conn, *bufio.ReadWriter, error) {
    80  					return nil, nil, errHijack
    81  				},
    82  			}
    83  		}
    84  
    85  		t.Run("withoutFallback", func(t *testing.T) {
    86  			t.Parallel()
    87  
    88  			w := newResponseWriter()
    89  			r := newRequest("permessage-deflate; harharhar")
    90  			_, err := Accept(w, r, &AcceptOptions{
    91  				CompressionMode: CompressionNoContextTakeover,
    92  			})
    93  			assert.ErrorIs(t, errHijack, err)
    94  			assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
    95  		})
    96  		t.Run("withFallback", func(t *testing.T) {
    97  			t.Parallel()
    98  
    99  			w := newResponseWriter()
   100  			r := newRequest("permessage-deflate; harharhar, permessage-deflate")
   101  			_, err := Accept(w, r, &AcceptOptions{
   102  				CompressionMode: CompressionNoContextTakeover,
   103  			})
   104  			assert.ErrorIs(t, errHijack, err)
   105  			assert.Equal(t, "extension header",
   106  				w.Header().Get("Sec-WebSocket-Extensions"),
   107  				CompressionNoContextTakeover.opts().String(),
   108  			)
   109  		})
   110  	})
   111  
   112  	t.Run("requireHttpHijacker", func(t *testing.T) {
   113  		t.Parallel()
   114  
   115  		w := httptest.NewRecorder()
   116  		r := httptest.NewRequest("GET", "/", nil)
   117  		r.Header.Set("Connection", "Upgrade")
   118  		r.Header.Set("Upgrade", "websocket")
   119  		r.Header.Set("Sec-WebSocket-Version", "13")
   120  		r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
   121  
   122  		_, err := Accept(w, r, nil)
   123  		assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
   124  	})
   125  
   126  	t.Run("badHijack", func(t *testing.T) {
   127  		t.Parallel()
   128  
   129  		w := mockHijacker{
   130  			ResponseWriter: httptest.NewRecorder(),
   131  			hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
   132  				return nil, nil, errors.New("haha")
   133  			},
   134  		}
   135  
   136  		r := httptest.NewRequest("GET", "/", nil)
   137  		r.Header.Set("Connection", "Upgrade")
   138  		r.Header.Set("Upgrade", "websocket")
   139  		r.Header.Set("Sec-WebSocket-Version", "13")
   140  		r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
   141  
   142  		_, err := Accept(w, r, nil)
   143  		assert.Contains(t, err, `failed to hijack connection`)
   144  	})
   145  }
   146  
   147  func Test_verifyClientHandshake(t *testing.T) {
   148  	t.Parallel()
   149  
   150  	testCases := []struct {
   151  		name    string
   152  		method  string
   153  		http1   bool
   154  		h       map[string]string
   155  		success bool
   156  	}{
   157  		{
   158  			name: "badConnection",
   159  			h: map[string]string{
   160  				"Connection": "notUpgrade",
   161  			},
   162  		},
   163  		{
   164  			name: "badUpgrade",
   165  			h: map[string]string{
   166  				"Connection": "Upgrade",
   167  				"Upgrade":    "notWebSocket",
   168  			},
   169  		},
   170  		{
   171  			name:   "badMethod",
   172  			method: "POST",
   173  			h: map[string]string{
   174  				"Connection": "Upgrade",
   175  				"Upgrade":    "websocket",
   176  			},
   177  		},
   178  		{
   179  			name: "badWebSocketVersion",
   180  			h: map[string]string{
   181  				"Connection":            "Upgrade",
   182  				"Upgrade":               "websocket",
   183  				"Sec-WebSocket-Version": "14",
   184  			},
   185  		},
   186  		{
   187  			name: "missingWebSocketKey",
   188  			h: map[string]string{
   189  				"Connection":            "Upgrade",
   190  				"Upgrade":               "websocket",
   191  				"Sec-WebSocket-Version": "13",
   192  			},
   193  		},
   194  		{
   195  			name: "emptyWebSocketKey",
   196  			h: map[string]string{
   197  				"Connection":            "Upgrade",
   198  				"Upgrade":               "websocket",
   199  				"Sec-WebSocket-Version": "13",
   200  				"Sec-WebSocket-Key":     "",
   201  			},
   202  		},
   203  		{
   204  			name: "shortWebSocketKey",
   205  			h: map[string]string{
   206  				"Connection":            "Upgrade",
   207  				"Upgrade":               "websocket",
   208  				"Sec-WebSocket-Version": "13",
   209  				"Sec-WebSocket-Key":     xrand.Base64(15),
   210  			},
   211  		},
   212  		{
   213  			name: "invalidWebSocketKey",
   214  			h: map[string]string{
   215  				"Connection":            "Upgrade",
   216  				"Upgrade":               "websocket",
   217  				"Sec-WebSocket-Version": "13",
   218  				"Sec-WebSocket-Key":     "notbase64",
   219  			},
   220  		},
   221  		{
   222  			name: "extraWebSocketKey",
   223  			h: map[string]string{
   224  				"Connection":            "Upgrade",
   225  				"Upgrade":               "websocket",
   226  				"Sec-WebSocket-Version": "13",
   227  				// Kinda cheeky, but http headers are case-insensitive.
   228  				// If 2 sec keys are present, this is a failure condition.
   229  				"Sec-WebSocket-Key": xrand.Base64(16),
   230  				"sec-webSocket-key": xrand.Base64(16),
   231  			},
   232  		},
   233  		{
   234  			name: "badHTTPVersion",
   235  			h: map[string]string{
   236  				"Connection":            "Upgrade",
   237  				"Upgrade":               "websocket",
   238  				"Sec-WebSocket-Version": "13",
   239  				"Sec-WebSocket-Key":     xrand.Base64(16),
   240  			},
   241  			http1: true,
   242  		},
   243  		{
   244  			name: "success",
   245  			h: map[string]string{
   246  				"Connection":            "keep-alive, Upgrade",
   247  				"Upgrade":               "websocket",
   248  				"Sec-WebSocket-Version": "13",
   249  				"Sec-WebSocket-Key":     xrand.Base64(16),
   250  			},
   251  			success: true,
   252  		},
   253  		{
   254  			name: "successSecKeyExtraSpace",
   255  			h: map[string]string{
   256  				"Connection":            "keep-alive, Upgrade",
   257  				"Upgrade":               "websocket",
   258  				"Sec-WebSocket-Version": "13",
   259  				"Sec-WebSocket-Key":     "   " + xrand.Base64(16) + "  ",
   260  			},
   261  			success: true,
   262  		},
   263  	}
   264  
   265  	for _, tc := range testCases {
   266  		tc := tc
   267  		t.Run(tc.name, func(t *testing.T) {
   268  			t.Parallel()
   269  
   270  			r := httptest.NewRequest(tc.method, "/", nil)
   271  
   272  			r.ProtoMajor = 1
   273  			r.ProtoMinor = 1
   274  			if tc.http1 {
   275  				r.ProtoMinor = 0
   276  			}
   277  
   278  			for k, v := range tc.h {
   279  				r.Header.Add(k, v)
   280  			}
   281  
   282  			_, err := verifyClientRequest(httptest.NewRecorder(), r)
   283  			if tc.success {
   284  				assert.Success(t, err)
   285  			} else {
   286  				assert.Error(t, err)
   287  			}
   288  		})
   289  	}
   290  }
   291  
   292  func Test_selectSubprotocol(t *testing.T) {
   293  	t.Parallel()
   294  
   295  	testCases := []struct {
   296  		name            string
   297  		clientProtocols []string
   298  		serverProtocols []string
   299  		negotiated      string
   300  	}{
   301  		{
   302  			name:            "empty",
   303  			clientProtocols: nil,
   304  			serverProtocols: nil,
   305  			negotiated:      "",
   306  		},
   307  		{
   308  			name:            "basic",
   309  			clientProtocols: []string{"echo", "echo2"},
   310  			serverProtocols: []string{"echo2", "echo"},
   311  			negotiated:      "echo2",
   312  		},
   313  		{
   314  			name:            "none",
   315  			clientProtocols: []string{"echo", "echo3"},
   316  			serverProtocols: []string{"echo2", "echo4"},
   317  			negotiated:      "",
   318  		},
   319  		{
   320  			name:            "fallback",
   321  			clientProtocols: []string{"echo", "echo3"},
   322  			serverProtocols: []string{"echo2", "echo3"},
   323  			negotiated:      "echo3",
   324  		},
   325  		{
   326  			name:            "clientCasePresered",
   327  			clientProtocols: []string{"Echo1"},
   328  			serverProtocols: []string{"echo1"},
   329  			negotiated:      "Echo1",
   330  		},
   331  	}
   332  
   333  	for _, tc := range testCases {
   334  		tc := tc
   335  		t.Run(tc.name, func(t *testing.T) {
   336  			t.Parallel()
   337  
   338  			r := httptest.NewRequest("GET", "/", nil)
   339  			r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))
   340  
   341  			negotiated := selectSubprotocol(r, tc.serverProtocols)
   342  			assert.Equal(t, "negotiated", tc.negotiated, negotiated)
   343  		})
   344  	}
   345  }
   346  
   347  func Test_authenticateOrigin(t *testing.T) {
   348  	t.Parallel()
   349  
   350  	testCases := []struct {
   351  		name           string
   352  		origin         string
   353  		host           string
   354  		originPatterns []string
   355  		success        bool
   356  	}{
   357  		{
   358  			name:    "none",
   359  			success: true,
   360  			host:    "example.com",
   361  		},
   362  		{
   363  			name:    "invalid",
   364  			origin:  "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
   365  			host:    "example.com",
   366  			success: false,
   367  		},
   368  		{
   369  			name:    "unauthorized",
   370  			origin:  "https://example.com",
   371  			host:    "example1.com",
   372  			success: false,
   373  		},
   374  		{
   375  			name:    "authorized",
   376  			origin:  "https://example.com",
   377  			host:    "example.com",
   378  			success: true,
   379  		},
   380  		{
   381  			name:    "authorizedCaseInsensitive",
   382  			origin:  "https://examplE.com",
   383  			host:    "example.com",
   384  			success: true,
   385  		},
   386  		{
   387  			name:   "originPatterns",
   388  			origin: "https://two.examplE.com",
   389  			host:   "example.com",
   390  			originPatterns: []string{
   391  				"*.example.com",
   392  				"bar.com",
   393  			},
   394  			success: true,
   395  		},
   396  		{
   397  			name:   "originPatternsUnauthorized",
   398  			origin: "https://two.examplE.com",
   399  			host:   "example.com",
   400  			originPatterns: []string{
   401  				"exam3.com",
   402  				"bar.com",
   403  			},
   404  			success: false,
   405  		},
   406  	}
   407  
   408  	for _, tc := range testCases {
   409  		tc := tc
   410  		t.Run(tc.name, func(t *testing.T) {
   411  			t.Parallel()
   412  
   413  			r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
   414  			r.Header.Set("Origin", tc.origin)
   415  
   416  			err := authenticateOrigin(r, tc.originPatterns)
   417  			if tc.success {
   418  				assert.Success(t, err)
   419  			} else {
   420  				assert.Error(t, err)
   421  			}
   422  		})
   423  	}
   424  }
   425  
   426  func Test_selectDeflate(t *testing.T) {
   427  	t.Parallel()
   428  
   429  	testCases := []struct {
   430  		name     string
   431  		mode     CompressionMode
   432  		header   string
   433  		expCopts *compressionOptions
   434  		expOK    bool
   435  	}{
   436  		{
   437  			name:     "disabled",
   438  			mode:     CompressionDisabled,
   439  			expCopts: nil,
   440  			expOK:    false,
   441  		},
   442  		{
   443  			name:     "noClientSupport",
   444  			mode:     CompressionNoContextTakeover,
   445  			expCopts: nil,
   446  			expOK:    false,
   447  		},
   448  		{
   449  			name:   "permessage-deflate",
   450  			mode:   CompressionNoContextTakeover,
   451  			header: "permessage-deflate; client_max_window_bits",
   452  			expCopts: &compressionOptions{
   453  				clientNoContextTakeover: true,
   454  				serverNoContextTakeover: true,
   455  			},
   456  			expOK: true,
   457  		},
   458  		{
   459  			name:   "permessage-deflate/unknown-parameter",
   460  			mode:   CompressionNoContextTakeover,
   461  			header: "permessage-deflate; meow",
   462  			expOK:  false,
   463  		},
   464  		{
   465  			name:   "permessage-deflate/unknown-parameter",
   466  			mode:   CompressionNoContextTakeover,
   467  			header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
   468  			expCopts: &compressionOptions{
   469  				clientNoContextTakeover: true,
   470  				serverNoContextTakeover: true,
   471  			},
   472  			expOK: true,
   473  		},
   474  	}
   475  
   476  	for _, tc := range testCases {
   477  		tc := tc
   478  		t.Run(tc.name, func(t *testing.T) {
   479  			t.Parallel()
   480  
   481  			h := http.Header{}
   482  			h.Set("Sec-WebSocket-Extensions", tc.header)
   483  			copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
   484  			assert.Equal(t, "selected options", tc.expOK, ok)
   485  			assert.Equal(t, "compression options", tc.expCopts, copts)
   486  		})
   487  	}
   488  }
   489  
   490  type mockHijacker struct {
   491  	http.ResponseWriter
   492  	hijack func() (net.Conn, *bufio.ReadWriter, error)
   493  }
   494  
   495  var _ http.Hijacker = mockHijacker{}
   496  
   497  func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   498  	return mj.hijack()
   499  }
   500  

View as plain text