1
2
3
4
5 package websocket
6
7 import (
8 "bufio"
9 "bytes"
10 "errors"
11 "fmt"
12 "io"
13 "io/ioutil"
14 "net"
15 "reflect"
16 "sync"
17 "testing"
18 "testing/iotest"
19 "time"
20 )
21
22 var _ net.Error = errWriteTimeout
23
24 type fakeNetConn struct {
25 io.Reader
26 io.Writer
27 }
28
29 func (c fakeNetConn) Close() error { return nil }
30 func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
31 func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
32 func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
33 func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
34 func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
35
// fakeAddr is a stub network address. Every value reports the network "net"
// and the address "str", so two fakeAddr values differ only by their integer.
type fakeAddr int

// Fixed endpoints handed out by the test connection helpers.
var (
	localAddr  = fakeAddr(1)
	remoteAddr = fakeAddr(2)
)

// Network returns the constant network name "net".
func (fakeAddr) Network() string { return "net" }

// String returns the constant address text "str".
func (fakeAddr) String() string { return "str" }
50
51
52
53 func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
54 return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
55 }
56
57 func TestFraming(t *testing.T) {
58 frameSizes := []int{
59 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
60
61 }
62 var readChunkers = []struct {
63 name string
64 f func(io.Reader) io.Reader
65 }{
66 {"half", iotest.HalfReader},
67 {"one", iotest.OneByteReader},
68 {"asis", func(r io.Reader) io.Reader { return r }},
69 }
70 writeBuf := make([]byte, 65537)
71 for i := range writeBuf {
72 writeBuf[i] = byte(i)
73 }
74 var writers = []struct {
75 name string
76 f func(w io.Writer, n int) (int, error)
77 }{
78 {"iocopy", func(w io.Writer, n int) (int, error) {
79 nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
80 return int(nn), err
81 }},
82 {"write", func(w io.Writer, n int) (int, error) {
83 return w.Write(writeBuf[:n])
84 }},
85 {"string", func(w io.Writer, n int) (int, error) {
86 return io.WriteString(w, string(writeBuf[:n]))
87 }},
88 }
89
90 for _, compress := range []bool{false, true} {
91 for _, isServer := range []bool{true, false} {
92 for _, chunker := range readChunkers {
93
94 var connBuf bytes.Buffer
95 wc := newTestConn(nil, &connBuf, isServer)
96 rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
97 if compress {
98 wc.newCompressionWriter = compressNoContextTakeover
99 rc.newDecompressionReader = decompressNoContextTakeover
100 }
101 for _, n := range frameSizes {
102 for _, writer := range writers {
103 name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
104
105 w, err := wc.NextWriter(TextMessage)
106 if err != nil {
107 t.Errorf("%s: wc.NextWriter() returned %v", name, err)
108 continue
109 }
110 nn, err := writer.f(w, n)
111 if err != nil || nn != n {
112 t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
113 continue
114 }
115 err = w.Close()
116 if err != nil {
117 t.Errorf("%s: w.Close() returned %v", name, err)
118 continue
119 }
120
121 opCode, r, err := rc.NextReader()
122 if err != nil || opCode != TextMessage {
123 t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
124 continue
125 }
126
127 t.Logf("frame size: %d", n)
128 rbuf, err := ioutil.ReadAll(r)
129 if err != nil {
130 t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
131 continue
132 }
133
134 if len(rbuf) != n {
135 t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
136 continue
137 }
138
139 for i, b := range rbuf {
140 if byte(i) != b {
141 t.Errorf("%s: bad byte at offset %d", name, i)
142 break
143 }
144 }
145 }
146 }
147 }
148 }
149 }
150 }
151
152 func TestControl(t *testing.T) {
153 const message = "this is a ping/pong messsage"
154 for _, isServer := range []bool{true, false} {
155 for _, isWriteControl := range []bool{true, false} {
156 name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
157 var connBuf bytes.Buffer
158 wc := newTestConn(nil, &connBuf, isServer)
159 rc := newTestConn(&connBuf, nil, !isServer)
160 if isWriteControl {
161 wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
162 } else {
163 w, err := wc.NextWriter(PongMessage)
164 if err != nil {
165 t.Errorf("%s: wc.NextWriter() returned %v", name, err)
166 continue
167 }
168 if _, err := w.Write([]byte(message)); err != nil {
169 t.Errorf("%s: w.Write() returned %v", name, err)
170 continue
171 }
172 if err := w.Close(); err != nil {
173 t.Errorf("%s: w.Close() returned %v", name, err)
174 continue
175 }
176 var actualMessage string
177 rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
178 rc.NextReader()
179 if actualMessage != message {
180 t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
181 continue
182 }
183 }
184 }
185 }
186 }
187
188
// simpleBufferPool is a single-slot buffer pool for tests. Put stores one
// value, replacing whatever was there. Get hands the stored value back once
// and returns nil while the slot is empty. The slot is the exported-to-tests
// field v, so tests can see exactly what was returned to the pool.
type simpleBufferPool struct {
	v interface{}
}

// Get empties the slot and returns its previous contents (nil if empty).
func (p *simpleBufferPool) Get() interface{} {
	var out interface{}
	out, p.v = p.v, nil
	return out
}

// Put stores v in the slot, discarding any value already held.
func (p *simpleBufferPool) Put(v interface{}) {
	p.v = v
}
202
203 func TestWriteBufferPool(t *testing.T) {
204 const message = "Now is the time for all good people to come to the aid of the party."
205
206 var buf bytes.Buffer
207 var pool simpleBufferPool
208 rc := newTestConn(&buf, nil, false)
209
210
211
212 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
213
214 if wc.writeBuf != nil {
215 t.Fatal("writeBuf not nil after create")
216 }
217
218
219
220 w, err := wc.NextWriter(TextMessage)
221 if err != nil {
222 t.Fatalf("wc.NextWriter() returned %v", err)
223 }
224
225 if wc.writeBuf == nil {
226 t.Fatal("writeBuf is nil after NextWriter")
227 }
228
229 writeBufAddr := &wc.writeBuf[0]
230
231 if _, err := io.WriteString(w, message); err != nil {
232 t.Fatalf("io.WriteString(w, message) returned %v", err)
233 }
234
235 if err := w.Close(); err != nil {
236 t.Fatalf("w.Close() returned %v", err)
237 }
238
239 if wc.writeBuf != nil {
240 t.Fatal("writeBuf not nil after w.Close()")
241 }
242
243 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
244 t.Fatal("writeBuf not returned to pool")
245 }
246
247 opCode, p, err := rc.ReadMessage()
248 if opCode != TextMessage || err != nil {
249 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
250 }
251
252 if s := string(p); s != message {
253 t.Fatalf("message is %s, want %s", s, message)
254 }
255
256
257
258 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
259 t.Fatalf("wc.WriteMessage() returned %v", err)
260 }
261
262 if wc.writeBuf != nil {
263 t.Fatal("writeBuf not nil after wc.WriteMessage()")
264 }
265
266 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
267 t.Fatal("writeBuf not returned to pool after WriteMessage")
268 }
269
270 opCode, p, err = rc.ReadMessage()
271 if opCode != TextMessage || err != nil {
272 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
273 }
274
275 if s := string(p); s != message {
276 t.Fatalf("message is %s, want %s", s, message)
277 }
278 }
279
280
281 func TestWriteBufferPoolSync(t *testing.T) {
282 var buf bytes.Buffer
283 var pool sync.Pool
284 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
285 rc := newTestConn(&buf, nil, false)
286
287 const message = "Hello World!"
288 for i := 0; i < 3; i++ {
289 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
290 t.Fatalf("wc.WriteMessage() returned %v", err)
291 }
292 opCode, p, err := rc.ReadMessage()
293 if opCode != TextMessage || err != nil {
294 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
295 }
296 if s := string(p); s != message {
297 t.Fatalf("message is %s, want %s", s, message)
298 }
299 }
300 }
301
302
// errorWriter is an io.Writer that always fails. Every Write consumes nothing
// and returns an error, which makes it useful for exercising write-error
// paths.
type errorWriter struct{}

// Write reports zero bytes written and a fresh error whose text is "error".
func (errorWriter) Write([]byte) (int, error) {
	err := errors.New("error")
	return 0, err
}
306
307
308
309 func TestWriteBufferPoolError(t *testing.T) {
310
311
312
313 var pool simpleBufferPool
314 wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
315
316 w, err := wc.NextWriter(TextMessage)
317 if err != nil {
318 t.Fatalf("wc.NextWriter() returned %v", err)
319 }
320
321 if wc.writeBuf == nil {
322 t.Fatal("writeBuf is nil after NextWriter")
323 }
324
325 writeBufAddr := &wc.writeBuf[0]
326
327 if _, err := io.WriteString(w, "Hello"); err != nil {
328 t.Fatalf("io.WriteString(w, message) returned %v", err)
329 }
330
331 if err := w.Close(); err == nil {
332 t.Fatalf("w.Close() did not return error")
333 }
334
335 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
336 t.Fatal("writeBuf not returned to pool")
337 }
338
339
340
341 wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
342
343 if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
344 t.Fatalf("wc.WriteMessage did not return error")
345 }
346
347 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
348 t.Fatal("writeBuf not returned to pool")
349 }
350 }
351
352 func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
353 const bufSize = 512
354
355 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
356
357 var b1, b2 bytes.Buffer
358 wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
359 rc := newTestConn(&b1, &b2, true)
360
361 w, _ := wc.NextWriter(BinaryMessage)
362 w.Write(make([]byte, bufSize+bufSize/2))
363 wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
364 w.Close()
365
366 op, r, err := rc.NextReader()
367 if op != BinaryMessage || err != nil {
368 t.Fatalf("NextReader() returned %d, %v", op, err)
369 }
370 _, err = io.Copy(ioutil.Discard, r)
371 if !reflect.DeepEqual(err, expectedErr) {
372 t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
373 }
374 _, _, err = rc.NextReader()
375 if !reflect.DeepEqual(err, expectedErr) {
376 t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
377 }
378 }
379
380 func TestEOFWithinFrame(t *testing.T) {
381 const bufSize = 64
382
383 for n := 0; ; n++ {
384 var b bytes.Buffer
385 wc := newTestConn(nil, &b, false)
386 rc := newTestConn(&b, nil, true)
387
388 w, _ := wc.NextWriter(BinaryMessage)
389 w.Write(make([]byte, bufSize))
390 w.Close()
391
392 if n >= b.Len() {
393 break
394 }
395 b.Truncate(n)
396
397 op, r, err := rc.NextReader()
398 if err == errUnexpectedEOF {
399 continue
400 }
401 if op != BinaryMessage || err != nil {
402 t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
403 }
404 _, err = io.Copy(ioutil.Discard, r)
405 if err != errUnexpectedEOF {
406 t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
407 }
408 _, _, err = rc.NextReader()
409 if err != errUnexpectedEOF {
410 t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
411 }
412 }
413 }
414
415 func TestEOFBeforeFinalFrame(t *testing.T) {
416 const bufSize = 512
417
418 var b1, b2 bytes.Buffer
419 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
420 rc := newTestConn(&b1, &b2, true)
421
422 w, _ := wc.NextWriter(BinaryMessage)
423 w.Write(make([]byte, bufSize+bufSize/2))
424
425 op, r, err := rc.NextReader()
426 if op != BinaryMessage || err != nil {
427 t.Fatalf("NextReader() returned %d, %v", op, err)
428 }
429 _, err = io.Copy(ioutil.Discard, r)
430 if err != errUnexpectedEOF {
431 t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
432 }
433 _, _, err = rc.NextReader()
434 if err != errUnexpectedEOF {
435 t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
436 }
437 }
438
439 func TestWriteAfterMessageWriterClose(t *testing.T) {
440 wc := newTestConn(nil, &bytes.Buffer{}, false)
441 w, _ := wc.NextWriter(BinaryMessage)
442 io.WriteString(w, "hello")
443 if err := w.Close(); err != nil {
444 t.Fatalf("unxpected error closing message writer, %v", err)
445 }
446
447 if _, err := io.WriteString(w, "world"); err == nil {
448 t.Fatalf("no error writing after close")
449 }
450
451 w, _ = wc.NextWriter(BinaryMessage)
452 io.WriteString(w, "hello")
453
454
455 _, err := wc.NextWriter(BinaryMessage)
456 if err != nil {
457 t.Fatalf("unexpected error getting next writer, %v", err)
458 }
459
460 if _, err := io.WriteString(w, "world"); err == nil {
461 t.Fatalf("no error writing after close")
462 }
463 }
464
465 func TestReadLimit(t *testing.T) {
466 t.Run("Test ReadLimit is enforced", func(t *testing.T) {
467 const readLimit = 512
468 message := make([]byte, readLimit+1)
469
470 var b1, b2 bytes.Buffer
471 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
472 rc := newTestConn(&b1, &b2, true)
473 rc.SetReadLimit(readLimit)
474
475
476 w, _ := wc.NextWriter(BinaryMessage)
477 w.Write(message[:readLimit-1])
478 wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
479 w.Write(message[:1])
480 w.Close()
481
482
483 wc.WriteMessage(BinaryMessage, message[:readLimit+1])
484
485 op, _, err := rc.NextReader()
486 if op != BinaryMessage || err != nil {
487 t.Fatalf("1: NextReader() returned %d, %v", op, err)
488 }
489 op, r, err := rc.NextReader()
490 if op != BinaryMessage || err != nil {
491 t.Fatalf("2: NextReader() returned %d, %v", op, err)
492 }
493 _, err = io.Copy(ioutil.Discard, r)
494 if err != ErrReadLimit {
495 t.Fatalf("io.Copy() returned %v", err)
496 }
497 })
498
499 t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
500 const readLimit = 1
501
502 var b1, b2 bytes.Buffer
503 rc := newTestConn(&b1, &b2, true)
504 rc.SetReadLimit(readLimit)
505
506
507 b1.Write([]byte("\x02\x81"))
508
509
510 b1.Write([]byte("\x00\x00\x00\x00"))
511
512
513 b1.Write([]byte("A"))
514
515
516 b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
517
518
519 b1.Write([]byte("\x00\x00\x00\x00"))
520
521
522 b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
523
524
525 b1.Write([]byte("\x00\x00\x00\x00"))
526
527
528 b1.Write([]byte("BCDEF"))
529
530 op, r, err := rc.NextReader()
531 if op != BinaryMessage || err != nil {
532 t.Fatalf("1: NextReader() returned %d, %v", op, err)
533 }
534
535 var buf [10]byte
536 var read int
537 n, err := r.Read(buf[:])
538 if err != nil && err != ErrReadLimit {
539 t.Fatalf("unexpected error testing read limit: %v", err)
540 }
541 read += n
542
543 n, err = r.Read(buf[:])
544 if err != nil && err != ErrReadLimit {
545 t.Fatalf("unexpected error testing read limit: %v", err)
546 }
547 read += n
548
549 if err == nil && read > readLimit {
550 t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
551 }
552 })
553 }
554
555 func TestAddrs(t *testing.T) {
556 c := newTestConn(nil, nil, true)
557 if c.LocalAddr() != localAddr {
558 t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
559 }
560 if c.RemoteAddr() != remoteAddr {
561 t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
562 }
563 }
564
565 func TestDeprecatedUnderlyingConn(t *testing.T) {
566 var b1, b2 bytes.Buffer
567 fc := fakeNetConn{Reader: &b1, Writer: &b2}
568 c := newConn(fc, true, 1024, 1024, nil, nil, nil)
569 ul := c.UnderlyingConn()
570 if ul != fc {
571 t.Fatalf("Underlying conn is not what it should be.")
572 }
573 }
574
575 func TestNetConn(t *testing.T) {
576 var b1, b2 bytes.Buffer
577 fc := fakeNetConn{Reader: &b1, Writer: &b2}
578 c := newConn(fc, true, 1024, 1024, nil, nil, nil)
579 ul := c.NetConn()
580 if ul != fc {
581 t.Fatalf("Underlying conn is not what it should be.")
582 }
583 }
584
585 func TestBufioReadBytes(t *testing.T) {
586
587
588 m := make([]byte, 512)
589 m[len(m)-1] = '\n'
590
591 var b1, b2 bytes.Buffer
592 wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
593 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
594
595 w, _ := wc.NextWriter(BinaryMessage)
596 w.Write(m)
597 w.Close()
598
599 op, r, err := rc.NextReader()
600 if op != BinaryMessage || err != nil {
601 t.Fatalf("NextReader() returned %d, %v", op, err)
602 }
603
604 br := bufio.NewReader(r)
605 p, err := br.ReadBytes('\n')
606 if err != nil {
607 t.Fatalf("ReadBytes() returned %v", err)
608 }
609 if len(p) != len(m) {
610 t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
611 }
612 }
613
614 var closeErrorTests = []struct {
615 err error
616 codes []int
617 ok bool
618 }{
619 {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
620 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
621 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
622 {errors.New("hello"), []int{CloseNormalClosure}, false},
623 }
624
625 func TestCloseError(t *testing.T) {
626 for _, tt := range closeErrorTests {
627 ok := IsCloseError(tt.err, tt.codes...)
628 if ok != tt.ok {
629 t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
630 }
631 }
632 }
633
634 var unexpectedCloseErrorTests = []struct {
635 err error
636 codes []int
637 ok bool
638 }{
639 {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
640 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
641 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
642 {errors.New("hello"), []int{CloseNormalClosure}, false},
643 }
644
645 func TestUnexpectedCloseErrors(t *testing.T) {
646 for _, tt := range unexpectedCloseErrorTests {
647 ok := IsUnexpectedCloseError(tt.err, tt.codes...)
648 if ok != tt.ok {
649 t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
650 }
651 }
652 }
653
// blockingWriter is an io.Writer that parks inside Write. It signals that a
// Write has started by closing c1, then waits until c2 is closed before
// reporting the whole slice as written.
//
// NOTE(review): c1 is closed on every call, so a second Write on the same
// value panics with "close of closed channel".
type blockingWriter struct {
	c1, c2 chan struct{}
}

// Write closes c1, blocks until c2 is closed, and returns len(data), nil.
func (bw blockingWriter) Write(data []byte) (n int, err error) {
	n = len(data)
	close(bw.c1) // announce that a write is in progress
	<-bw.c2      // wait for the test to release us
	return n, nil
}
665
666 func TestConcurrentWritePanic(t *testing.T) {
667 w := blockingWriter{make(chan struct{}), make(chan struct{})}
668 c := newTestConn(nil, w, false)
669 go func() {
670 c.WriteMessage(TextMessage, []byte{})
671 }()
672
673
674 <-w.c1
675
676 defer func() {
677 close(w.c2)
678 if v := recover(); v != nil {
679 return
680 }
681 }()
682
683 c.WriteMessage(TextMessage, []byte{})
684 t.Fatal("should not get here")
685 }
686
// failingReader is an io.Reader whose every Read immediately reports io.EOF
// without producing any data.
type failingReader struct{}

// Read returns 0 and io.EOF regardless of p.
func (failingReader) Read([]byte) (n int, err error) {
	err = io.EOF
	return
}
692
693 func TestFailedConnectionReadPanic(t *testing.T) {
694 c := newTestConn(failingReader{}, nil, false)
695
696 defer func() {
697 if v := recover(); v != nil {
698 return
699 }
700 }()
701
702 for i := 0; i < 20000; i++ {
703 c.ReadMessage()
704 }
705 t.Fatal("should not get here")
706 }
707
View as plain text