...

Source file src/github.com/gorilla/websocket/conn_test.go

Documentation: github.com/gorilla/websocket

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"net"
    15  	"reflect"
    16  	"sync"
    17  	"testing"
    18  	"testing/iotest"
    19  	"time"
    20  )
    21  
    22  var _ net.Error = errWriteTimeout
    23  
    24  type fakeNetConn struct {
    25  	io.Reader
    26  	io.Writer
    27  }
    28  
    29  func (c fakeNetConn) Close() error                       { return nil }
    30  func (c fakeNetConn) LocalAddr() net.Addr                { return localAddr }
    31  func (c fakeNetConn) RemoteAddr() net.Addr               { return remoteAddr }
    32  func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
    33  func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
    34  func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
    35  
    36  type fakeAddr int
    37  
    38  var (
    39  	localAddr  = fakeAddr(1)
    40  	remoteAddr = fakeAddr(2)
    41  )
    42  
    43  func (a fakeAddr) Network() string {
    44  	return "net"
    45  }
    46  
    47  func (a fakeAddr) String() string {
    48  	return "str"
    49  }
    50  
    51  // newTestConn creates a connnection backed by a fake network connection using
    52  // default values for buffering.
    53  func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
    54  	return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
    55  }
    56  
    57  func TestFraming(t *testing.T) {
    58  	frameSizes := []int{
    59  		0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
    60  		// 65536, 65537
    61  	}
    62  	var readChunkers = []struct {
    63  		name string
    64  		f    func(io.Reader) io.Reader
    65  	}{
    66  		{"half", iotest.HalfReader},
    67  		{"one", iotest.OneByteReader},
    68  		{"asis", func(r io.Reader) io.Reader { return r }},
    69  	}
    70  	writeBuf := make([]byte, 65537)
    71  	for i := range writeBuf {
    72  		writeBuf[i] = byte(i)
    73  	}
    74  	var writers = []struct {
    75  		name string
    76  		f    func(w io.Writer, n int) (int, error)
    77  	}{
    78  		{"iocopy", func(w io.Writer, n int) (int, error) {
    79  			nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
    80  			return int(nn), err
    81  		}},
    82  		{"write", func(w io.Writer, n int) (int, error) {
    83  			return w.Write(writeBuf[:n])
    84  		}},
    85  		{"string", func(w io.Writer, n int) (int, error) {
    86  			return io.WriteString(w, string(writeBuf[:n]))
    87  		}},
    88  	}
    89  
    90  	for _, compress := range []bool{false, true} {
    91  		for _, isServer := range []bool{true, false} {
    92  			for _, chunker := range readChunkers {
    93  
    94  				var connBuf bytes.Buffer
    95  				wc := newTestConn(nil, &connBuf, isServer)
    96  				rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
    97  				if compress {
    98  					wc.newCompressionWriter = compressNoContextTakeover
    99  					rc.newDecompressionReader = decompressNoContextTakeover
   100  				}
   101  				for _, n := range frameSizes {
   102  					for _, writer := range writers {
   103  						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
   104  
   105  						w, err := wc.NextWriter(TextMessage)
   106  						if err != nil {
   107  							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
   108  							continue
   109  						}
   110  						nn, err := writer.f(w, n)
   111  						if err != nil || nn != n {
   112  							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
   113  							continue
   114  						}
   115  						err = w.Close()
   116  						if err != nil {
   117  							t.Errorf("%s: w.Close() returned %v", name, err)
   118  							continue
   119  						}
   120  
   121  						opCode, r, err := rc.NextReader()
   122  						if err != nil || opCode != TextMessage {
   123  							t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
   124  							continue
   125  						}
   126  
   127  						t.Logf("frame size: %d", n)
   128  						rbuf, err := ioutil.ReadAll(r)
   129  						if err != nil {
   130  							t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
   131  							continue
   132  						}
   133  
   134  						if len(rbuf) != n {
   135  							t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
   136  							continue
   137  						}
   138  
   139  						for i, b := range rbuf {
   140  							if byte(i) != b {
   141  								t.Errorf("%s: bad byte at offset %d", name, i)
   142  								break
   143  							}
   144  						}
   145  					}
   146  				}
   147  			}
   148  		}
   149  	}
   150  }
   151  
   152  func TestControl(t *testing.T) {
   153  	const message = "this is a ping/pong messsage"
   154  	for _, isServer := range []bool{true, false} {
   155  		for _, isWriteControl := range []bool{true, false} {
   156  			name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
   157  			var connBuf bytes.Buffer
   158  			wc := newTestConn(nil, &connBuf, isServer)
   159  			rc := newTestConn(&connBuf, nil, !isServer)
   160  			if isWriteControl {
   161  				wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
   162  			} else {
   163  				w, err := wc.NextWriter(PongMessage)
   164  				if err != nil {
   165  					t.Errorf("%s: wc.NextWriter() returned %v", name, err)
   166  					continue
   167  				}
   168  				if _, err := w.Write([]byte(message)); err != nil {
   169  					t.Errorf("%s: w.Write() returned %v", name, err)
   170  					continue
   171  				}
   172  				if err := w.Close(); err != nil {
   173  					t.Errorf("%s: w.Close() returned %v", name, err)
   174  					continue
   175  				}
   176  				var actualMessage string
   177  				rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
   178  				rc.NextReader()
   179  				if actualMessage != message {
   180  					t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
   181  					continue
   182  				}
   183  			}
   184  		}
   185  	}
   186  }
   187  
   188  // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
   189  type simpleBufferPool struct {
   190  	v interface{}
   191  }
   192  
   193  func (p *simpleBufferPool) Get() interface{} {
   194  	v := p.v
   195  	p.v = nil
   196  	return v
   197  }
   198  
   199  func (p *simpleBufferPool) Put(v interface{}) {
   200  	p.v = v
   201  }
   202  
   203  func TestWriteBufferPool(t *testing.T) {
   204  	const message = "Now is the time for all good people to come to the aid of the party."
   205  
   206  	var buf bytes.Buffer
   207  	var pool simpleBufferPool
   208  	rc := newTestConn(&buf, nil, false)
   209  
   210  	// Specify writeBufferSize smaller than message size to ensure that pooling
   211  	// works with fragmented messages.
   212  	wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
   213  
   214  	if wc.writeBuf != nil {
   215  		t.Fatal("writeBuf not nil after create")
   216  	}
   217  
   218  	// Part 1: test NextWriter/Write/Close
   219  
   220  	w, err := wc.NextWriter(TextMessage)
   221  	if err != nil {
   222  		t.Fatalf("wc.NextWriter() returned %v", err)
   223  	}
   224  
   225  	if wc.writeBuf == nil {
   226  		t.Fatal("writeBuf is nil after NextWriter")
   227  	}
   228  
   229  	writeBufAddr := &wc.writeBuf[0]
   230  
   231  	if _, err := io.WriteString(w, message); err != nil {
   232  		t.Fatalf("io.WriteString(w, message) returned %v", err)
   233  	}
   234  
   235  	if err := w.Close(); err != nil {
   236  		t.Fatalf("w.Close() returned %v", err)
   237  	}
   238  
   239  	if wc.writeBuf != nil {
   240  		t.Fatal("writeBuf not nil after w.Close()")
   241  	}
   242  
   243  	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
   244  		t.Fatal("writeBuf not returned to pool")
   245  	}
   246  
   247  	opCode, p, err := rc.ReadMessage()
   248  	if opCode != TextMessage || err != nil {
   249  		t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
   250  	}
   251  
   252  	if s := string(p); s != message {
   253  		t.Fatalf("message is %s, want %s", s, message)
   254  	}
   255  
   256  	// Part 2: Test WriteMessage.
   257  
   258  	if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
   259  		t.Fatalf("wc.WriteMessage() returned %v", err)
   260  	}
   261  
   262  	if wc.writeBuf != nil {
   263  		t.Fatal("writeBuf not nil after wc.WriteMessage()")
   264  	}
   265  
   266  	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
   267  		t.Fatal("writeBuf not returned to pool after WriteMessage")
   268  	}
   269  
   270  	opCode, p, err = rc.ReadMessage()
   271  	if opCode != TextMessage || err != nil {
   272  		t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
   273  	}
   274  
   275  	if s := string(p); s != message {
   276  		t.Fatalf("message is %s, want %s", s, message)
   277  	}
   278  }
   279  
   280  // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
   281  func TestWriteBufferPoolSync(t *testing.T) {
   282  	var buf bytes.Buffer
   283  	var pool sync.Pool
   284  	wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
   285  	rc := newTestConn(&buf, nil, false)
   286  
   287  	const message = "Hello World!"
   288  	for i := 0; i < 3; i++ {
   289  		if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
   290  			t.Fatalf("wc.WriteMessage() returned %v", err)
   291  		}
   292  		opCode, p, err := rc.ReadMessage()
   293  		if opCode != TextMessage || err != nil {
   294  			t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
   295  		}
   296  		if s := string(p); s != message {
   297  			t.Fatalf("message is %s, want %s", s, message)
   298  		}
   299  	}
   300  }
   301  
   302  // errorWriter is an io.Writer than returns an error on all writes.
   303  type errorWriter struct{}
   304  
   305  func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") }
   306  
   307  // TestWriteBufferPoolError ensures that buffer is returned to pool after error
   308  // on write.
   309  func TestWriteBufferPoolError(t *testing.T) {
   310  
   311  	// Part 1: Test NextWriter/Write/Close
   312  
   313  	var pool simpleBufferPool
   314  	wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
   315  
   316  	w, err := wc.NextWriter(TextMessage)
   317  	if err != nil {
   318  		t.Fatalf("wc.NextWriter() returned %v", err)
   319  	}
   320  
   321  	if wc.writeBuf == nil {
   322  		t.Fatal("writeBuf is nil after NextWriter")
   323  	}
   324  
   325  	writeBufAddr := &wc.writeBuf[0]
   326  
   327  	if _, err := io.WriteString(w, "Hello"); err != nil {
   328  		t.Fatalf("io.WriteString(w, message) returned %v", err)
   329  	}
   330  
   331  	if err := w.Close(); err == nil {
   332  		t.Fatalf("w.Close() did not return error")
   333  	}
   334  
   335  	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
   336  		t.Fatal("writeBuf not returned to pool")
   337  	}
   338  
   339  	// Part 2: Test WriteMessage
   340  
   341  	wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
   342  
   343  	if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
   344  		t.Fatalf("wc.WriteMessage did not return error")
   345  	}
   346  
   347  	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
   348  		t.Fatal("writeBuf not returned to pool")
   349  	}
   350  }
   351  
   352  func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
   353  	const bufSize = 512
   354  
   355  	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
   356  
   357  	var b1, b2 bytes.Buffer
   358  	wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
   359  	rc := newTestConn(&b1, &b2, true)
   360  
   361  	w, _ := wc.NextWriter(BinaryMessage)
   362  	w.Write(make([]byte, bufSize+bufSize/2))
   363  	wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
   364  	w.Close()
   365  
   366  	op, r, err := rc.NextReader()
   367  	if op != BinaryMessage || err != nil {
   368  		t.Fatalf("NextReader() returned %d, %v", op, err)
   369  	}
   370  	_, err = io.Copy(ioutil.Discard, r)
   371  	if !reflect.DeepEqual(err, expectedErr) {
   372  		t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
   373  	}
   374  	_, _, err = rc.NextReader()
   375  	if !reflect.DeepEqual(err, expectedErr) {
   376  		t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
   377  	}
   378  }
   379  
   380  func TestEOFWithinFrame(t *testing.T) {
   381  	const bufSize = 64
   382  
   383  	for n := 0; ; n++ {
   384  		var b bytes.Buffer
   385  		wc := newTestConn(nil, &b, false)
   386  		rc := newTestConn(&b, nil, true)
   387  
   388  		w, _ := wc.NextWriter(BinaryMessage)
   389  		w.Write(make([]byte, bufSize))
   390  		w.Close()
   391  
   392  		if n >= b.Len() {
   393  			break
   394  		}
   395  		b.Truncate(n)
   396  
   397  		op, r, err := rc.NextReader()
   398  		if err == errUnexpectedEOF {
   399  			continue
   400  		}
   401  		if op != BinaryMessage || err != nil {
   402  			t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
   403  		}
   404  		_, err = io.Copy(ioutil.Discard, r)
   405  		if err != errUnexpectedEOF {
   406  			t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
   407  		}
   408  		_, _, err = rc.NextReader()
   409  		if err != errUnexpectedEOF {
   410  			t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
   411  		}
   412  	}
   413  }
   414  
   415  func TestEOFBeforeFinalFrame(t *testing.T) {
   416  	const bufSize = 512
   417  
   418  	var b1, b2 bytes.Buffer
   419  	wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
   420  	rc := newTestConn(&b1, &b2, true)
   421  
   422  	w, _ := wc.NextWriter(BinaryMessage)
   423  	w.Write(make([]byte, bufSize+bufSize/2))
   424  
   425  	op, r, err := rc.NextReader()
   426  	if op != BinaryMessage || err != nil {
   427  		t.Fatalf("NextReader() returned %d, %v", op, err)
   428  	}
   429  	_, err = io.Copy(ioutil.Discard, r)
   430  	if err != errUnexpectedEOF {
   431  		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
   432  	}
   433  	_, _, err = rc.NextReader()
   434  	if err != errUnexpectedEOF {
   435  		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
   436  	}
   437  }
   438  
   439  func TestWriteAfterMessageWriterClose(t *testing.T) {
   440  	wc := newTestConn(nil, &bytes.Buffer{}, false)
   441  	w, _ := wc.NextWriter(BinaryMessage)
   442  	io.WriteString(w, "hello")
   443  	if err := w.Close(); err != nil {
   444  		t.Fatalf("unxpected error closing message writer, %v", err)
   445  	}
   446  
   447  	if _, err := io.WriteString(w, "world"); err == nil {
   448  		t.Fatalf("no error writing after close")
   449  	}
   450  
   451  	w, _ = wc.NextWriter(BinaryMessage)
   452  	io.WriteString(w, "hello")
   453  
   454  	// close w by getting next writer
   455  	_, err := wc.NextWriter(BinaryMessage)
   456  	if err != nil {
   457  		t.Fatalf("unexpected error getting next writer, %v", err)
   458  	}
   459  
   460  	if _, err := io.WriteString(w, "world"); err == nil {
   461  		t.Fatalf("no error writing after close")
   462  	}
   463  }
   464  
   465  func TestReadLimit(t *testing.T) {
   466  	t.Run("Test ReadLimit is enforced", func(t *testing.T) {
   467  		const readLimit = 512
   468  		message := make([]byte, readLimit+1)
   469  
   470  		var b1, b2 bytes.Buffer
   471  		wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
   472  		rc := newTestConn(&b1, &b2, true)
   473  		rc.SetReadLimit(readLimit)
   474  
   475  		// Send message at the limit with interleaved pong.
   476  		w, _ := wc.NextWriter(BinaryMessage)
   477  		w.Write(message[:readLimit-1])
   478  		wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
   479  		w.Write(message[:1])
   480  		w.Close()
   481  
   482  		// Send message larger than the limit.
   483  		wc.WriteMessage(BinaryMessage, message[:readLimit+1])
   484  
   485  		op, _, err := rc.NextReader()
   486  		if op != BinaryMessage || err != nil {
   487  			t.Fatalf("1: NextReader() returned %d, %v", op, err)
   488  		}
   489  		op, r, err := rc.NextReader()
   490  		if op != BinaryMessage || err != nil {
   491  			t.Fatalf("2: NextReader() returned %d, %v", op, err)
   492  		}
   493  		_, err = io.Copy(ioutil.Discard, r)
   494  		if err != ErrReadLimit {
   495  			t.Fatalf("io.Copy() returned %v", err)
   496  		}
   497  	})
   498  
   499  	t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
   500  		const readLimit = 1
   501  
   502  		var b1, b2 bytes.Buffer
   503  		rc := newTestConn(&b1, &b2, true)
   504  		rc.SetReadLimit(readLimit)
   505  
   506  		// First, send a non-final binary message
   507  		b1.Write([]byte("\x02\x81"))
   508  
   509  		// Mask key
   510  		b1.Write([]byte("\x00\x00\x00\x00"))
   511  
   512  		// First payload
   513  		b1.Write([]byte("A"))
   514  
   515  		// Next, send a negative-length, non-final continuation frame
   516  		b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
   517  
   518  		// Mask key
   519  		b1.Write([]byte("\x00\x00\x00\x00"))
   520  
   521  		// Next, send a too long, final continuation frame
   522  		b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
   523  
   524  		// Mask key
   525  		b1.Write([]byte("\x00\x00\x00\x00"))
   526  
   527  		// Too-long payload
   528  		b1.Write([]byte("BCDEF"))
   529  
   530  		op, r, err := rc.NextReader()
   531  		if op != BinaryMessage || err != nil {
   532  			t.Fatalf("1: NextReader() returned %d, %v", op, err)
   533  		}
   534  
   535  		var buf [10]byte
   536  		var read int
   537  		n, err := r.Read(buf[:])
   538  		if err != nil && err != ErrReadLimit {
   539  			t.Fatalf("unexpected error testing read limit: %v", err)
   540  		}
   541  		read += n
   542  
   543  		n, err = r.Read(buf[:])
   544  		if err != nil && err != ErrReadLimit {
   545  			t.Fatalf("unexpected error testing read limit: %v", err)
   546  		}
   547  		read += n
   548  
   549  		if err == nil && read > readLimit {
   550  			t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
   551  		}
   552  	})
   553  }
   554  
   555  func TestAddrs(t *testing.T) {
   556  	c := newTestConn(nil, nil, true)
   557  	if c.LocalAddr() != localAddr {
   558  		t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
   559  	}
   560  	if c.RemoteAddr() != remoteAddr {
   561  		t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
   562  	}
   563  }
   564  
   565  func TestDeprecatedUnderlyingConn(t *testing.T) {
   566  	var b1, b2 bytes.Buffer
   567  	fc := fakeNetConn{Reader: &b1, Writer: &b2}
   568  	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
   569  	ul := c.UnderlyingConn()
   570  	if ul != fc {
   571  		t.Fatalf("Underlying conn is not what it should be.")
   572  	}
   573  }
   574  
   575  func TestNetConn(t *testing.T) {
   576  	var b1, b2 bytes.Buffer
   577  	fc := fakeNetConn{Reader: &b1, Writer: &b2}
   578  	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
   579  	ul := c.NetConn()
   580  	if ul != fc {
   581  		t.Fatalf("Underlying conn is not what it should be.")
   582  	}
   583  }
   584  
   585  func TestBufioReadBytes(t *testing.T) {
   586  	// Test calling bufio.ReadBytes for value longer than read buffer size.
   587  
   588  	m := make([]byte, 512)
   589  	m[len(m)-1] = '\n'
   590  
   591  	var b1, b2 bytes.Buffer
   592  	wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
   593  	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
   594  
   595  	w, _ := wc.NextWriter(BinaryMessage)
   596  	w.Write(m)
   597  	w.Close()
   598  
   599  	op, r, err := rc.NextReader()
   600  	if op != BinaryMessage || err != nil {
   601  		t.Fatalf("NextReader() returned %d, %v", op, err)
   602  	}
   603  
   604  	br := bufio.NewReader(r)
   605  	p, err := br.ReadBytes('\n')
   606  	if err != nil {
   607  		t.Fatalf("ReadBytes() returned %v", err)
   608  	}
   609  	if len(p) != len(m) {
   610  		t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
   611  	}
   612  }
   613  
   614  var closeErrorTests = []struct {
   615  	err   error
   616  	codes []int
   617  	ok    bool
   618  }{
   619  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
   620  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
   621  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
   622  	{errors.New("hello"), []int{CloseNormalClosure}, false},
   623  }
   624  
   625  func TestCloseError(t *testing.T) {
   626  	for _, tt := range closeErrorTests {
   627  		ok := IsCloseError(tt.err, tt.codes...)
   628  		if ok != tt.ok {
   629  			t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
   630  		}
   631  	}
   632  }
   633  
   634  var unexpectedCloseErrorTests = []struct {
   635  	err   error
   636  	codes []int
   637  	ok    bool
   638  }{
   639  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
   640  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
   641  	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
   642  	{errors.New("hello"), []int{CloseNormalClosure}, false},
   643  }
   644  
   645  func TestUnexpectedCloseErrors(t *testing.T) {
   646  	for _, tt := range unexpectedCloseErrorTests {
   647  		ok := IsUnexpectedCloseError(tt.err, tt.codes...)
   648  		if ok != tt.ok {
   649  			t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
   650  		}
   651  	}
   652  }
   653  
   654  type blockingWriter struct {
   655  	c1, c2 chan struct{}
   656  }
   657  
   658  func (w blockingWriter) Write(p []byte) (int, error) {
   659  	// Allow main to continue
   660  	close(w.c1)
   661  	// Wait for panic in main
   662  	<-w.c2
   663  	return len(p), nil
   664  }
   665  
   666  func TestConcurrentWritePanic(t *testing.T) {
   667  	w := blockingWriter{make(chan struct{}), make(chan struct{})}
   668  	c := newTestConn(nil, w, false)
   669  	go func() {
   670  		c.WriteMessage(TextMessage, []byte{})
   671  	}()
   672  
   673  	// wait for goroutine to block in write.
   674  	<-w.c1
   675  
   676  	defer func() {
   677  		close(w.c2)
   678  		if v := recover(); v != nil {
   679  			return
   680  		}
   681  	}()
   682  
   683  	c.WriteMessage(TextMessage, []byte{})
   684  	t.Fatal("should not get here")
   685  }
   686  
   687  type failingReader struct{}
   688  
   689  func (r failingReader) Read(p []byte) (int, error) {
   690  	return 0, io.EOF
   691  }
   692  
   693  func TestFailedConnectionReadPanic(t *testing.T) {
   694  	c := newTestConn(failingReader{}, nil, false)
   695  
   696  	defer func() {
   697  		if v := recover(); v != nil {
   698  			return
   699  		}
   700  	}()
   701  
   702  	for i := 0; i < 20000; i++ {
   703  		c.ReadMessage()
   704  	}
   705  	t.Fatal("should not get here")
   706  }
   707  

View as plain text