1
2
3
4 package websocket
5
6 import (
7 "bufio"
8 "errors"
9 "net"
10 "net/http"
11 "net/http/httptest"
12 "strings"
13 "testing"
14
15 "nhooyr.io/websocket/internal/test/assert"
16 "nhooyr.io/websocket/internal/test/xrand"
17 )
18
19 func TestAccept(t *testing.T) {
20 t.Parallel()
21
22 t.Run("badClientHandshake", func(t *testing.T) {
23 t.Parallel()
24
25 w := httptest.NewRecorder()
26 r := httptest.NewRequest("GET", "/", nil)
27
28 _, err := Accept(w, r, nil)
29 assert.Contains(t, err, "protocol violation")
30 })
31
32 t.Run("badOrigin", func(t *testing.T) {
33 t.Parallel()
34
35 w := httptest.NewRecorder()
36 r := httptest.NewRequest("GET", "/", nil)
37 r.Header.Set("Connection", "Upgrade")
38 r.Header.Set("Upgrade", "websocket")
39 r.Header.Set("Sec-WebSocket-Version", "13")
40 r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
41 r.Header.Set("Origin", "harhar.com")
42
43 _, err := Accept(w, r, nil)
44 assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`)
45 })
46
47
48 t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) {
49 t.Parallel()
50
51 w := httptest.NewRecorder()
52 r := httptest.NewRequest("GET", "/", nil)
53 r.Header.Set("Connection", "Upgrade")
54 r.Header.Set("Upgrade", "websocket")
55 r.Header.Set("Sec-WebSocket-Version", "13")
56 r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
57 r.Header.Set("Origin", "https://harhar.com")
58
59 _, err := Accept(w, r, nil)
60 assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`)
61 })
62
63 t.Run("badCompression", func(t *testing.T) {
64 t.Parallel()
65
66 newRequest := func(extensions string) *http.Request {
67 r := httptest.NewRequest("GET", "/", nil)
68 r.Header.Set("Connection", "Upgrade")
69 r.Header.Set("Upgrade", "websocket")
70 r.Header.Set("Sec-WebSocket-Version", "13")
71 r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
72 r.Header.Set("Sec-WebSocket-Extensions", extensions)
73 return r
74 }
75 errHijack := errors.New("hijack error")
76 newResponseWriter := func() http.ResponseWriter {
77 return mockHijacker{
78 ResponseWriter: httptest.NewRecorder(),
79 hijack: func() (net.Conn, *bufio.ReadWriter, error) {
80 return nil, nil, errHijack
81 },
82 }
83 }
84
85 t.Run("withoutFallback", func(t *testing.T) {
86 t.Parallel()
87
88 w := newResponseWriter()
89 r := newRequest("permessage-deflate; harharhar")
90 _, err := Accept(w, r, &AcceptOptions{
91 CompressionMode: CompressionNoContextTakeover,
92 })
93 assert.ErrorIs(t, errHijack, err)
94 assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
95 })
96 t.Run("withFallback", func(t *testing.T) {
97 t.Parallel()
98
99 w := newResponseWriter()
100 r := newRequest("permessage-deflate; harharhar, permessage-deflate")
101 _, err := Accept(w, r, &AcceptOptions{
102 CompressionMode: CompressionNoContextTakeover,
103 })
104 assert.ErrorIs(t, errHijack, err)
105 assert.Equal(t, "extension header",
106 w.Header().Get("Sec-WebSocket-Extensions"),
107 CompressionNoContextTakeover.opts().String(),
108 )
109 })
110 })
111
112 t.Run("requireHttpHijacker", func(t *testing.T) {
113 t.Parallel()
114
115 w := httptest.NewRecorder()
116 r := httptest.NewRequest("GET", "/", nil)
117 r.Header.Set("Connection", "Upgrade")
118 r.Header.Set("Upgrade", "websocket")
119 r.Header.Set("Sec-WebSocket-Version", "13")
120 r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
121
122 _, err := Accept(w, r, nil)
123 assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
124 })
125
126 t.Run("badHijack", func(t *testing.T) {
127 t.Parallel()
128
129 w := mockHijacker{
130 ResponseWriter: httptest.NewRecorder(),
131 hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
132 return nil, nil, errors.New("haha")
133 },
134 }
135
136 r := httptest.NewRequest("GET", "/", nil)
137 r.Header.Set("Connection", "Upgrade")
138 r.Header.Set("Upgrade", "websocket")
139 r.Header.Set("Sec-WebSocket-Version", "13")
140 r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
141
142 _, err := Accept(w, r, nil)
143 assert.Contains(t, err, `failed to hijack connection`)
144 })
145 }
146
147 func Test_verifyClientHandshake(t *testing.T) {
148 t.Parallel()
149
150 testCases := []struct {
151 name string
152 method string
153 http1 bool
154 h map[string]string
155 success bool
156 }{
157 {
158 name: "badConnection",
159 h: map[string]string{
160 "Connection": "notUpgrade",
161 },
162 },
163 {
164 name: "badUpgrade",
165 h: map[string]string{
166 "Connection": "Upgrade",
167 "Upgrade": "notWebSocket",
168 },
169 },
170 {
171 name: "badMethod",
172 method: "POST",
173 h: map[string]string{
174 "Connection": "Upgrade",
175 "Upgrade": "websocket",
176 },
177 },
178 {
179 name: "badWebSocketVersion",
180 h: map[string]string{
181 "Connection": "Upgrade",
182 "Upgrade": "websocket",
183 "Sec-WebSocket-Version": "14",
184 },
185 },
186 {
187 name: "missingWebSocketKey",
188 h: map[string]string{
189 "Connection": "Upgrade",
190 "Upgrade": "websocket",
191 "Sec-WebSocket-Version": "13",
192 },
193 },
194 {
195 name: "emptyWebSocketKey",
196 h: map[string]string{
197 "Connection": "Upgrade",
198 "Upgrade": "websocket",
199 "Sec-WebSocket-Version": "13",
200 "Sec-WebSocket-Key": "",
201 },
202 },
203 {
204 name: "shortWebSocketKey",
205 h: map[string]string{
206 "Connection": "Upgrade",
207 "Upgrade": "websocket",
208 "Sec-WebSocket-Version": "13",
209 "Sec-WebSocket-Key": xrand.Base64(15),
210 },
211 },
212 {
213 name: "invalidWebSocketKey",
214 h: map[string]string{
215 "Connection": "Upgrade",
216 "Upgrade": "websocket",
217 "Sec-WebSocket-Version": "13",
218 "Sec-WebSocket-Key": "notbase64",
219 },
220 },
221 {
222 name: "extraWebSocketKey",
223 h: map[string]string{
224 "Connection": "Upgrade",
225 "Upgrade": "websocket",
226 "Sec-WebSocket-Version": "13",
227
228
229 "Sec-WebSocket-Key": xrand.Base64(16),
230 "sec-webSocket-key": xrand.Base64(16),
231 },
232 },
233 {
234 name: "badHTTPVersion",
235 h: map[string]string{
236 "Connection": "Upgrade",
237 "Upgrade": "websocket",
238 "Sec-WebSocket-Version": "13",
239 "Sec-WebSocket-Key": xrand.Base64(16),
240 },
241 http1: true,
242 },
243 {
244 name: "success",
245 h: map[string]string{
246 "Connection": "keep-alive, Upgrade",
247 "Upgrade": "websocket",
248 "Sec-WebSocket-Version": "13",
249 "Sec-WebSocket-Key": xrand.Base64(16),
250 },
251 success: true,
252 },
253 {
254 name: "successSecKeyExtraSpace",
255 h: map[string]string{
256 "Connection": "keep-alive, Upgrade",
257 "Upgrade": "websocket",
258 "Sec-WebSocket-Version": "13",
259 "Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
260 },
261 success: true,
262 },
263 }
264
265 for _, tc := range testCases {
266 tc := tc
267 t.Run(tc.name, func(t *testing.T) {
268 t.Parallel()
269
270 r := httptest.NewRequest(tc.method, "/", nil)
271
272 r.ProtoMajor = 1
273 r.ProtoMinor = 1
274 if tc.http1 {
275 r.ProtoMinor = 0
276 }
277
278 for k, v := range tc.h {
279 r.Header.Add(k, v)
280 }
281
282 _, err := verifyClientRequest(httptest.NewRecorder(), r)
283 if tc.success {
284 assert.Success(t, err)
285 } else {
286 assert.Error(t, err)
287 }
288 })
289 }
290 }
291
292 func Test_selectSubprotocol(t *testing.T) {
293 t.Parallel()
294
295 testCases := []struct {
296 name string
297 clientProtocols []string
298 serverProtocols []string
299 negotiated string
300 }{
301 {
302 name: "empty",
303 clientProtocols: nil,
304 serverProtocols: nil,
305 negotiated: "",
306 },
307 {
308 name: "basic",
309 clientProtocols: []string{"echo", "echo2"},
310 serverProtocols: []string{"echo2", "echo"},
311 negotiated: "echo2",
312 },
313 {
314 name: "none",
315 clientProtocols: []string{"echo", "echo3"},
316 serverProtocols: []string{"echo2", "echo4"},
317 negotiated: "",
318 },
319 {
320 name: "fallback",
321 clientProtocols: []string{"echo", "echo3"},
322 serverProtocols: []string{"echo2", "echo3"},
323 negotiated: "echo3",
324 },
325 {
326 name: "clientCasePresered",
327 clientProtocols: []string{"Echo1"},
328 serverProtocols: []string{"echo1"},
329 negotiated: "Echo1",
330 },
331 }
332
333 for _, tc := range testCases {
334 tc := tc
335 t.Run(tc.name, func(t *testing.T) {
336 t.Parallel()
337
338 r := httptest.NewRequest("GET", "/", nil)
339 r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))
340
341 negotiated := selectSubprotocol(r, tc.serverProtocols)
342 assert.Equal(t, "negotiated", tc.negotiated, negotiated)
343 })
344 }
345 }
346
347 func Test_authenticateOrigin(t *testing.T) {
348 t.Parallel()
349
350 testCases := []struct {
351 name string
352 origin string
353 host string
354 originPatterns []string
355 success bool
356 }{
357 {
358 name: "none",
359 success: true,
360 host: "example.com",
361 },
362 {
363 name: "invalid",
364 origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
365 host: "example.com",
366 success: false,
367 },
368 {
369 name: "unauthorized",
370 origin: "https://example.com",
371 host: "example1.com",
372 success: false,
373 },
374 {
375 name: "authorized",
376 origin: "https://example.com",
377 host: "example.com",
378 success: true,
379 },
380 {
381 name: "authorizedCaseInsensitive",
382 origin: "https://examplE.com",
383 host: "example.com",
384 success: true,
385 },
386 {
387 name: "originPatterns",
388 origin: "https://two.examplE.com",
389 host: "example.com",
390 originPatterns: []string{
391 "*.example.com",
392 "bar.com",
393 },
394 success: true,
395 },
396 {
397 name: "originPatternsUnauthorized",
398 origin: "https://two.examplE.com",
399 host: "example.com",
400 originPatterns: []string{
401 "exam3.com",
402 "bar.com",
403 },
404 success: false,
405 },
406 }
407
408 for _, tc := range testCases {
409 tc := tc
410 t.Run(tc.name, func(t *testing.T) {
411 t.Parallel()
412
413 r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
414 r.Header.Set("Origin", tc.origin)
415
416 err := authenticateOrigin(r, tc.originPatterns)
417 if tc.success {
418 assert.Success(t, err)
419 } else {
420 assert.Error(t, err)
421 }
422 })
423 }
424 }
425
426 func Test_selectDeflate(t *testing.T) {
427 t.Parallel()
428
429 testCases := []struct {
430 name string
431 mode CompressionMode
432 header string
433 expCopts *compressionOptions
434 expOK bool
435 }{
436 {
437 name: "disabled",
438 mode: CompressionDisabled,
439 expCopts: nil,
440 expOK: false,
441 },
442 {
443 name: "noClientSupport",
444 mode: CompressionNoContextTakeover,
445 expCopts: nil,
446 expOK: false,
447 },
448 {
449 name: "permessage-deflate",
450 mode: CompressionNoContextTakeover,
451 header: "permessage-deflate; client_max_window_bits",
452 expCopts: &compressionOptions{
453 clientNoContextTakeover: true,
454 serverNoContextTakeover: true,
455 },
456 expOK: true,
457 },
458 {
459 name: "permessage-deflate/unknown-parameter",
460 mode: CompressionNoContextTakeover,
461 header: "permessage-deflate; meow",
462 expOK: false,
463 },
464 {
465 name: "permessage-deflate/unknown-parameter",
466 mode: CompressionNoContextTakeover,
467 header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
468 expCopts: &compressionOptions{
469 clientNoContextTakeover: true,
470 serverNoContextTakeover: true,
471 },
472 expOK: true,
473 },
474 }
475
476 for _, tc := range testCases {
477 tc := tc
478 t.Run(tc.name, func(t *testing.T) {
479 t.Parallel()
480
481 h := http.Header{}
482 h.Set("Sec-WebSocket-Extensions", tc.header)
483 copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
484 assert.Equal(t, "selected options", tc.expOK, ok)
485 assert.Equal(t, "compression options", tc.expCopts, copts)
486 })
487 }
488 }
489
// mockHijacker wraps an http.ResponseWriter and satisfies http.Hijacker by
// delegating Hijack to the hijack callback, so a test decides what hijacking
// the connection returns.
type mockHijacker struct {
	http.ResponseWriter

	// hijack is called by Hijack; its results are returned unchanged.
	hijack func() (net.Conn, *bufio.ReadWriter, error)
}

// Compile-time check that the value type itself implements http.Hijacker,
// since tests pass mockHijacker values (not pointers) as the ResponseWriter.
var _ http.Hijacker = mockHijacker{}

// Hijack implements http.Hijacker by forwarding to m.hijack.
func (m mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	conn, brw, err := m.hijack()
	return conn, brw, err
}
500
View as plain text