1
2
3 package websocket_test
4
5 import (
6 "bytes"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "net/http/httptest"
13 "os"
14 "os/exec"
15 "strings"
16 "testing"
17 "time"
18
19 "nhooyr.io/websocket"
20 "nhooyr.io/websocket/internal/errd"
21 "nhooyr.io/websocket/internal/test/assert"
22 "nhooyr.io/websocket/internal/test/wstest"
23 "nhooyr.io/websocket/internal/test/xrand"
24 "nhooyr.io/websocket/internal/xsync"
25 "nhooyr.io/websocket/wsjson"
26 )
27
28 func TestConn(t *testing.T) {
29 t.Parallel()
30
31 t.Run("fuzzData", func(t *testing.T) {
32 t.Parallel()
33
34 compressionMode := func() websocket.CompressionMode {
35 return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1))
36 }
37
38 for i := 0; i < 5; i++ {
39 t.Run("", func(t *testing.T) {
40 tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
41 CompressionMode: compressionMode(),
42 CompressionThreshold: xrand.Int(9999),
43 }, &websocket.AcceptOptions{
44 CompressionMode: compressionMode(),
45 CompressionThreshold: xrand.Int(9999),
46 })
47
48 tt.goEchoLoop(c2)
49
50 c1.SetReadLimit(131072)
51
52 for i := 0; i < 5; i++ {
53 err := wstest.Echo(tt.ctx, c1, 131072)
54 assert.Success(t, err)
55 }
56
57 err := c1.Close(websocket.StatusNormalClosure, "")
58 assert.Success(t, err)
59 })
60 }
61 })
62
63 t.Run("badClose", func(t *testing.T) {
64 tt, c1, c2 := newConnTest(t, nil, nil)
65
66 c2.CloseRead(tt.ctx)
67
68 err := c1.Close(-1, "")
69 assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
70 })
71
72 t.Run("ping", func(t *testing.T) {
73 tt, c1, c2 := newConnTest(t, nil, nil)
74
75 c1.CloseRead(tt.ctx)
76 c2.CloseRead(tt.ctx)
77
78 for i := 0; i < 10; i++ {
79 err := c1.Ping(tt.ctx)
80 assert.Success(t, err)
81 }
82
83 err := c1.Close(websocket.StatusNormalClosure, "")
84 assert.Success(t, err)
85 })
86
87 t.Run("badPing", func(t *testing.T) {
88 tt, c1, c2 := newConnTest(t, nil, nil)
89
90 c2.CloseRead(tt.ctx)
91
92 ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
93 defer cancel()
94
95 err := c1.Ping(ctx)
96 assert.Contains(t, err, "failed to wait for pong")
97 })
98
99 t.Run("concurrentWrite", func(t *testing.T) {
100 tt, c1, c2 := newConnTest(t, nil, nil)
101
102 tt.goDiscardLoop(c2)
103
104 msg := xrand.Bytes(xrand.Int(9999))
105 const count = 100
106 errs := make(chan error, count)
107
108 for i := 0; i < count; i++ {
109 go func() {
110 select {
111 case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
112 case <-tt.ctx.Done():
113 return
114 }
115 }()
116 }
117
118 for i := 0; i < count; i++ {
119 select {
120 case err := <-errs:
121 assert.Success(t, err)
122 case <-tt.ctx.Done():
123 t.Fatal(tt.ctx.Err())
124 }
125 }
126
127 err := c1.Close(websocket.StatusNormalClosure, "")
128 assert.Success(t, err)
129 })
130
131 t.Run("concurrentWriteError", func(t *testing.T) {
132 tt, c1, _ := newConnTest(t, nil, nil)
133
134 _, err := c1.Writer(tt.ctx, websocket.MessageText)
135 assert.Success(t, err)
136
137 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
138 defer cancel()
139
140 err = c1.Write(ctx, websocket.MessageText, []byte("x"))
141 if !errors.Is(err, context.DeadlineExceeded) {
142 t.Fatalf("unexpected error: %#v", err)
143 }
144 })
145
146 t.Run("netConn", func(t *testing.T) {
147 tt, c1, c2 := newConnTest(t, nil, nil)
148
149 n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
150 n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
151
152
153 d, _ := tt.ctx.Deadline()
154 n1.SetDeadline(d)
155 n1.SetDeadline(time.Time{})
156
157 assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
158 assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String())
159 assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network())
160
161 errs := xsync.Go(func() error {
162 _, err := n2.Write([]byte("hello"))
163 if err != nil {
164 return err
165 }
166 return n2.Close()
167 })
168
169 b, err := io.ReadAll(n1)
170 assert.Success(t, err)
171
172 _, err = n1.Read(nil)
173 assert.Equal(t, "read error", err, io.EOF)
174
175 select {
176 case err := <-errs:
177 assert.Success(t, err)
178 case <-tt.ctx.Done():
179 t.Fatal(tt.ctx.Err())
180 }
181
182 assert.Equal(t, "read msg", []byte("hello"), b)
183 })
184
185 t.Run("netConn/BadMsg", func(t *testing.T) {
186 tt, c1, c2 := newConnTest(t, nil, nil)
187
188 n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
189 n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
190
191 c2.CloseRead(tt.ctx)
192 errs := xsync.Go(func() error {
193 _, err := n2.Write([]byte("hello"))
194 return err
195 })
196
197 _, err := io.ReadAll(n1)
198 assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
199
200 select {
201 case err := <-errs:
202 assert.Success(t, err)
203 case <-tt.ctx.Done():
204 t.Fatal(tt.ctx.Err())
205 }
206 })
207
208 t.Run("netConn/readLimit", func(t *testing.T) {
209 tt, c1, c2 := newConnTest(t, nil, nil)
210
211 n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
212 n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
213
214 s := strings.Repeat("papa", 1<<20)
215 errs := xsync.Go(func() error {
216 _, err := n2.Write([]byte(s))
217 if err != nil {
218 return err
219 }
220 return n2.Close()
221 })
222
223 b, err := io.ReadAll(n1)
224 assert.Success(t, err)
225
226 _, err = n1.Read(nil)
227 assert.Equal(t, "read error", err, io.EOF)
228
229 select {
230 case err := <-errs:
231 assert.Success(t, err)
232 case <-tt.ctx.Done():
233 t.Fatal(tt.ctx.Err())
234 }
235
236 assert.Equal(t, "read msg", s, string(b))
237 })
238
239 t.Run("netConn/pastDeadline", func(t *testing.T) {
240 tt, c1, c2 := newConnTest(t, nil, nil)
241
242 n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
243 n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
244
245 n1.SetDeadline(time.Now().Add(-time.Minute))
246 n2.SetDeadline(time.Now().Add(-time.Minute))
247
248
249 })
250
251 t.Run("wsjson", func(t *testing.T) {
252 tt, c1, c2 := newConnTest(t, nil, nil)
253
254 tt.goEchoLoop(c2)
255
256 c1.SetReadLimit(1 << 30)
257
258 exp := xrand.String(xrand.Int(131072))
259
260 werr := xsync.Go(func() error {
261 return wsjson.Write(tt.ctx, c1, exp)
262 })
263
264 var act interface{}
265 err := wsjson.Read(tt.ctx, c1, &act)
266 assert.Success(t, err)
267 assert.Equal(t, "read msg", exp, act)
268
269 select {
270 case err := <-werr:
271 assert.Success(t, err)
272 case <-tt.ctx.Done():
273 t.Fatal(tt.ctx.Err())
274 }
275
276 err = c1.Close(websocket.StatusNormalClosure, "")
277 assert.Success(t, err)
278 })
279
280 t.Run("HTTPClient.Timeout", func(t *testing.T) {
281 tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
282 HTTPClient: &http.Client{Timeout: time.Second * 5},
283 }, nil)
284
285 tt.goEchoLoop(c2)
286
287 c1.SetReadLimit(1 << 30)
288
289 exp := xrand.String(xrand.Int(131072))
290
291 werr := xsync.Go(func() error {
292 return wsjson.Write(tt.ctx, c1, exp)
293 })
294
295 var act interface{}
296 err := wsjson.Read(tt.ctx, c1, &act)
297 assert.Success(t, err)
298 assert.Equal(t, "read msg", exp, act)
299
300 select {
301 case err := <-werr:
302 assert.Success(t, err)
303 case <-tt.ctx.Done():
304 t.Fatal(tt.ctx.Err())
305 }
306
307 err = c1.Close(websocket.StatusNormalClosure, "")
308 assert.Success(t, err)
309 })
310
311 t.Run("CloseNow", func(t *testing.T) {
312 _, c1, c2 := newConnTest(t, nil, nil)
313
314 err1 := c1.CloseNow()
315 err2 := c2.CloseNow()
316 assert.Success(t, err1)
317 assert.Success(t, err2)
318 err1 = c1.CloseNow()
319 err2 = c2.CloseNow()
320 assert.ErrorIs(t, websocket.ErrClosed, err1)
321 assert.ErrorIs(t, websocket.ErrClosed, err2)
322 })
323
324 t.Run("MidReadClose", func(t *testing.T) {
325 tt, c1, c2 := newConnTest(t, nil, nil)
326
327 tt.goEchoLoop(c2)
328
329 c1.SetReadLimit(131072)
330
331 for i := 0; i < 5; i++ {
332 err := wstest.Echo(tt.ctx, c1, 131072)
333 assert.Success(t, err)
334 }
335
336 err := wsjson.Write(tt.ctx, c1, "four")
337 assert.Success(t, err)
338 _, _, err = c1.Reader(tt.ctx)
339 assert.Success(t, err)
340
341 err = c1.Close(websocket.StatusNormalClosure, "")
342 assert.Success(t, err)
343 })
344 }
345
346 func TestWasm(t *testing.T) {
347 t.Parallel()
348
349 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
350 err := echoServer(w, r, &websocket.AcceptOptions{
351 Subprotocols: []string{"echo"},
352 InsecureSkipVerify: true,
353 })
354 if err != nil {
355 t.Error(err)
356 }
357 }))
358 defer s.Close()
359
360 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
361 defer cancel()
362
363 cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
364 cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
365
366 b, err := cmd.CombinedOutput()
367 if err != nil {
368 t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
369 }
370 }
371
372 func assertCloseStatus(exp websocket.StatusCode, err error) error {
373 if websocket.CloseStatus(err) == -1 {
374 return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
375 }
376 if websocket.CloseStatus(err) != exp {
377 return fmt.Errorf("expected close status %v but got %v", exp, err)
378 }
379 return nil
380 }
381
// connTest holds per-test state shared by the goEchoLoop and goDiscardLoop
// helpers. newConnTest creates it.
type connTest struct {
	// t is used to register cleanups and to report errors from the
	// background loops.
	t testing.TB
	// ctx bounds operations in the test. newConnTest gives it a timeout and
	// cancels it during cleanup.
	ctx context.Context
}
386
387 func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
388 if t, ok := t.(*testing.T); ok {
389 t.Parallel()
390 }
391 t.Helper()
392
393 ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
394 tt = &connTest{t: t, ctx: ctx}
395 t.Cleanup(cancel)
396
397 c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
398 if xrand.Bool() {
399 c1, c2 = c2, c1
400 }
401 t.Cleanup(func() {
402 c2.CloseNow()
403 c1.CloseNow()
404 })
405
406 return tt, c1, c2
407 }
408
409 func (tt *connTest) goEchoLoop(c *websocket.Conn) {
410 ctx, cancel := context.WithCancel(tt.ctx)
411
412 echoLoopErr := xsync.Go(func() error {
413 err := wstest.EchoLoop(ctx, c)
414 return assertCloseStatus(websocket.StatusNormalClosure, err)
415 })
416 tt.t.Cleanup(func() {
417 cancel()
418 err := <-echoLoopErr
419 if err != nil {
420 tt.t.Errorf("echo loop error: %v", err)
421 }
422 })
423 }
424
425 func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
426 ctx, cancel := context.WithCancel(tt.ctx)
427
428 discardLoopErr := xsync.Go(func() error {
429 defer c.Close(websocket.StatusInternalError, "")
430
431 for {
432 _, _, err := c.Read(ctx)
433 if err != nil {
434 return assertCloseStatus(websocket.StatusNormalClosure, err)
435 }
436 }
437 })
438 tt.t.Cleanup(func() {
439 cancel()
440 err := <-discardLoopErr
441 if err != nil {
442 tt.t.Errorf("discard loop error: %v", err)
443 }
444 })
445 }
446
447 func BenchmarkConn(b *testing.B) {
448 var benchCases = []struct {
449 name string
450 mode websocket.CompressionMode
451 }{
452 {
453 name: "disabledCompress",
454 mode: websocket.CompressionDisabled,
455 },
456 {
457 name: "compressContextTakeover",
458 mode: websocket.CompressionContextTakeover,
459 },
460 {
461 name: "compressNoContext",
462 mode: websocket.CompressionNoContextTakeover,
463 },
464 }
465 for _, bc := range benchCases {
466 b.Run(bc.name, func(b *testing.B) {
467 bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
468 CompressionMode: bc.mode,
469 }, &websocket.AcceptOptions{
470 CompressionMode: bc.mode,
471 })
472
473 bb.goEchoLoop(c2)
474
475 bytesWritten := c1.RecordBytesWritten()
476 bytesRead := c1.RecordBytesRead()
477
478 msg := []byte(strings.Repeat("1234", 128))
479 readBuf := make([]byte, len(msg))
480 writes := make(chan struct{})
481 defer close(writes)
482 werrs := make(chan error)
483
484 go func() {
485 for range writes {
486 select {
487 case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
488 case <-bb.ctx.Done():
489 return
490 }
491 }
492 }()
493 b.SetBytes(int64(len(msg)))
494 b.ReportAllocs()
495 b.ResetTimer()
496 for i := 0; i < b.N; i++ {
497 select {
498 case writes <- struct{}{}:
499 case <-bb.ctx.Done():
500 b.Fatal(bb.ctx.Err())
501 }
502
503 typ, r, err := c1.Reader(bb.ctx)
504 if err != nil {
505 b.Fatal(i, err)
506 }
507 if websocket.MessageText != typ {
508 assert.Equal(b, "data type", websocket.MessageText, typ)
509 }
510
511 _, err = io.ReadFull(r, readBuf)
512 if err != nil {
513 b.Fatal(err)
514 }
515
516 n2, err := r.Read(readBuf)
517 if err != io.EOF {
518 assert.Equal(b, "read err", io.EOF, err)
519 }
520 if n2 != 0 {
521 assert.Equal(b, "n2", 0, n2)
522 }
523
524 if !bytes.Equal(msg, readBuf) {
525 assert.Equal(b, "msg", msg, readBuf)
526 }
527
528 select {
529 case err = <-werrs:
530 case <-bb.ctx.Done():
531 b.Fatal(bb.ctx.Err())
532 }
533 if err != nil {
534 b.Fatal(err)
535 }
536 }
537 b.StopTimer()
538
539 b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
540 b.ReportMetric(float64(*bytesRead/b.N), "read/op")
541
542 err := c1.Close(websocket.StatusNormalClosure, "")
543 assert.Success(b, err)
544 })
545 }
546 }
547
548 func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) {
549 defer errd.Wrap(&err, "echo server failed")
550
551 c, err := websocket.Accept(w, r, opts)
552 if err != nil {
553 return err
554 }
555 defer c.Close(websocket.StatusInternalError, "")
556
557 err = wstest.EchoLoop(r.Context(), c)
558 return assertCloseStatus(websocket.StatusNormalClosure, err)
559 }
560
561 func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
562 exp := xrand.String(xrand.Int(131072))
563
564 werr := xsync.Go(func() error {
565 return wsjson.Write(ctx, c, exp)
566 })
567
568 var act interface{}
569 c.SetReadLimit(1 << 30)
570 err := wsjson.Read(ctx, c, &act)
571 assert.Success(tb, err)
572 assert.Equal(tb, "read msg", exp, act)
573
574 select {
575 case err := <-werr:
576 assert.Success(tb, err)
577 case <-ctx.Done():
578 tb.Fatal(ctx.Err())
579 }
580 }
581
582 func assertClose(tb testing.TB, c *websocket.Conn) {
583 tb.Helper()
584 err := c.Close(websocket.StatusNormalClosure, "")
585 assert.Success(tb, err)
586 }
587
588 func TestConcurrentClosePing(t *testing.T) {
589 t.Parallel()
590 for i := 0; i < 64; i++ {
591 func() {
592 c1, c2 := wstest.Pipe(nil, nil)
593 defer c1.CloseNow()
594 defer c2.CloseNow()
595 c1.CloseRead(context.Background())
596 c2.CloseRead(context.Background())
597 errc := xsync.Go(func() error {
598 for range time.Tick(time.Millisecond) {
599 err := c1.Ping(context.Background())
600 if err != nil {
601 return err
602 }
603 }
604 panic("unreachable")
605 })
606
607 time.Sleep(10 * time.Millisecond)
608 assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
609 <-errc
610 }()
611 }
612 }
613
View as plain text