...

Source file src/github.com/gorilla/websocket/conn.go

Documentation: github.com/gorilla/websocket

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"encoding/binary"
    10  	"errors"
    11  	"io"
    12  	"io/ioutil"
    13  	"math/rand"
    14  	"net"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  	"unicode/utf8"
    20  )
    21  
    22  const (
    23  	// Frame header byte 0 bits from Section 5.2 of RFC 6455
    24  	finalBit = 1 << 7
    25  	rsv1Bit  = 1 << 6
    26  	rsv2Bit  = 1 << 5
    27  	rsv3Bit  = 1 << 4
    28  
    29  	// Frame header byte 1 bits from Section 5.2 of RFC 6455
    30  	maskBit = 1 << 7
    31  
    32  	maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
    33  	maxControlFramePayloadSize = 125
    34  
    35  	writeWait = time.Second
    36  
    37  	defaultReadBufferSize  = 4096
    38  	defaultWriteBufferSize = 4096
    39  
    40  	continuationFrame = 0
    41  	noFrame           = -1
    42  )
    43  
    44  // Close codes defined in RFC 6455, section 11.7.
    45  const (
    46  	CloseNormalClosure           = 1000
    47  	CloseGoingAway               = 1001
    48  	CloseProtocolError           = 1002
    49  	CloseUnsupportedData         = 1003
    50  	CloseNoStatusReceived        = 1005
    51  	CloseAbnormalClosure         = 1006
    52  	CloseInvalidFramePayloadData = 1007
    53  	ClosePolicyViolation         = 1008
    54  	CloseMessageTooBig           = 1009
    55  	CloseMandatoryExtension      = 1010
    56  	CloseInternalServerErr       = 1011
    57  	CloseServiceRestart          = 1012
    58  	CloseTryAgainLater           = 1013
    59  	CloseTLSHandshake            = 1015
    60  )
    61  
    62  // The message types are defined in RFC 6455, section 11.8.
    63  const (
    64  	// TextMessage denotes a text data message. The text message payload is
    65  	// interpreted as UTF-8 encoded text data.
    66  	TextMessage = 1
    67  
    68  	// BinaryMessage denotes a binary data message.
    69  	BinaryMessage = 2
    70  
    71  	// CloseMessage denotes a close control message. The optional message
    72  	// payload contains a numeric code and text. Use the FormatCloseMessage
    73  	// function to format a close message payload.
    74  	CloseMessage = 8
    75  
    76  	// PingMessage denotes a ping control message. The optional message payload
    77  	// is UTF-8 encoded text.
    78  	PingMessage = 9
    79  
    80  	// PongMessage denotes a pong control message. The optional message payload
    81  	// is UTF-8 encoded text.
    82  	PongMessage = 10
    83  )
    84  
    85  // ErrCloseSent is returned when the application writes a message to the
    86  // connection after sending a close message.
    87  var ErrCloseSent = errors.New("websocket: close sent")
    88  
    89  // ErrReadLimit is returned when reading a message that is larger than the
    90  // read limit set for the connection.
    91  var ErrReadLimit = errors.New("websocket: read limit exceeded")
    92  
    93  // netError satisfies the net Error interface.
    94  type netError struct {
    95  	msg       string
    96  	temporary bool
    97  	timeout   bool
    98  }
    99  
   100  func (e *netError) Error() string   { return e.msg }
   101  func (e *netError) Temporary() bool { return e.temporary }
   102  func (e *netError) Timeout() bool   { return e.timeout }
   103  
   104  // CloseError represents a close message.
   105  type CloseError struct {
   106  	// Code is defined in RFC 6455, section 11.7.
   107  	Code int
   108  
   109  	// Text is the optional text payload.
   110  	Text string
   111  }
   112  
   113  func (e *CloseError) Error() string {
   114  	s := []byte("websocket: close ")
   115  	s = strconv.AppendInt(s, int64(e.Code), 10)
   116  	switch e.Code {
   117  	case CloseNormalClosure:
   118  		s = append(s, " (normal)"...)
   119  	case CloseGoingAway:
   120  		s = append(s, " (going away)"...)
   121  	case CloseProtocolError:
   122  		s = append(s, " (protocol error)"...)
   123  	case CloseUnsupportedData:
   124  		s = append(s, " (unsupported data)"...)
   125  	case CloseNoStatusReceived:
   126  		s = append(s, " (no status)"...)
   127  	case CloseAbnormalClosure:
   128  		s = append(s, " (abnormal closure)"...)
   129  	case CloseInvalidFramePayloadData:
   130  		s = append(s, " (invalid payload data)"...)
   131  	case ClosePolicyViolation:
   132  		s = append(s, " (policy violation)"...)
   133  	case CloseMessageTooBig:
   134  		s = append(s, " (message too big)"...)
   135  	case CloseMandatoryExtension:
   136  		s = append(s, " (mandatory extension missing)"...)
   137  	case CloseInternalServerErr:
   138  		s = append(s, " (internal server error)"...)
   139  	case CloseTLSHandshake:
   140  		s = append(s, " (TLS handshake error)"...)
   141  	}
   142  	if e.Text != "" {
   143  		s = append(s, ": "...)
   144  		s = append(s, e.Text...)
   145  	}
   146  	return string(s)
   147  }
   148  
   149  // IsCloseError returns boolean indicating whether the error is a *CloseError
   150  // with one of the specified codes.
   151  func IsCloseError(err error, codes ...int) bool {
   152  	if e, ok := err.(*CloseError); ok {
   153  		for _, code := range codes {
   154  			if e.Code == code {
   155  				return true
   156  			}
   157  		}
   158  	}
   159  	return false
   160  }
   161  
   162  // IsUnexpectedCloseError returns boolean indicating whether the error is a
   163  // *CloseError with a code not in the list of expected codes.
   164  func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
   165  	if e, ok := err.(*CloseError); ok {
   166  		for _, code := range expectedCodes {
   167  			if e.Code == code {
   168  				return false
   169  			}
   170  		}
   171  		return true
   172  	}
   173  	return false
   174  }
   175  
   176  var (
   177  	errWriteTimeout        = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
   178  	errUnexpectedEOF       = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
   179  	errBadWriteOpCode      = errors.New("websocket: bad write message type")
   180  	errWriteClosed         = errors.New("websocket: write closed")
   181  	errInvalidControlFrame = errors.New("websocket: invalid control frame")
   182  )
   183  
   184  func newMaskKey() [4]byte {
   185  	n := rand.Uint32()
   186  	return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
   187  }
   188  
   189  func hideTempErr(err error) error {
   190  	if e, ok := err.(net.Error); ok && e.Temporary() {
   191  		err = &netError{msg: e.Error(), timeout: e.Timeout()}
   192  	}
   193  	return err
   194  }
   195  
   196  func isControl(frameType int) bool {
   197  	return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
   198  }
   199  
   200  func isData(frameType int) bool {
   201  	return frameType == TextMessage || frameType == BinaryMessage
   202  }
   203  
   204  var validReceivedCloseCodes = map[int]bool{
   205  	// see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
   206  
   207  	CloseNormalClosure:           true,
   208  	CloseGoingAway:               true,
   209  	CloseProtocolError:           true,
   210  	CloseUnsupportedData:         true,
   211  	CloseNoStatusReceived:        false,
   212  	CloseAbnormalClosure:         false,
   213  	CloseInvalidFramePayloadData: true,
   214  	ClosePolicyViolation:         true,
   215  	CloseMessageTooBig:           true,
   216  	CloseMandatoryExtension:      true,
   217  	CloseInternalServerErr:       true,
   218  	CloseServiceRestart:          true,
   219  	CloseTryAgainLater:           true,
   220  	CloseTLSHandshake:            false,
   221  }
   222  
   223  func isValidReceivedCloseCode(code int) bool {
   224  	return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
   225  }
   226  
   227  // BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
   228  // interface.  The type of the value stored in a pool is not specified.
   229  type BufferPool interface {
   230  	// Get gets a value from the pool or returns nil if the pool is empty.
   231  	Get() interface{}
   232  	// Put adds a value to the pool.
   233  	Put(interface{})
   234  }
   235  
   236  // writePoolData is the type added to the write buffer pool. This wrapper is
   237  // used to prevent applications from peeking at and depending on the values
   238  // added to the pool.
   239  type writePoolData struct{ buf []byte }
   240  
   241  // The Conn type represents a WebSocket connection.
   242  type Conn struct {
   243  	conn        net.Conn
   244  	isServer    bool
   245  	subprotocol string
   246  
   247  	// Write fields
   248  	mu            chan struct{} // used as mutex to protect write to conn
   249  	writeBuf      []byte        // frame is constructed in this buffer.
   250  	writePool     BufferPool
   251  	writeBufSize  int
   252  	writeDeadline time.Time
   253  	writer        io.WriteCloser // the current writer returned to the application
   254  	isWriting     bool           // for best-effort concurrent write detection
   255  
   256  	writeErrMu sync.Mutex
   257  	writeErr   error
   258  
   259  	enableWriteCompression bool
   260  	compressionLevel       int
   261  	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser
   262  
   263  	// Read fields
   264  	reader  io.ReadCloser // the current reader returned to the application
   265  	readErr error
   266  	br      *bufio.Reader
   267  	// bytes remaining in current frame.
   268  	// set setReadRemaining to safely update this value and prevent overflow
   269  	readRemaining int64
   270  	readFinal     bool  // true the current message has more frames.
   271  	readLength    int64 // Message size.
   272  	readLimit     int64 // Maximum message size.
   273  	readMaskPos   int
   274  	readMaskKey   [4]byte
   275  	handlePong    func(string) error
   276  	handlePing    func(string) error
   277  	handleClose   func(int, string) error
   278  	readErrCount  int
   279  	messageReader *messageReader // the current low-level reader
   280  
   281  	readDecompress         bool // whether last read frame had RSV1 set
   282  	newDecompressionReader func(io.Reader) io.ReadCloser
   283  }
   284  
   285  func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
   286  
   287  	if br == nil {
   288  		if readBufferSize == 0 {
   289  			readBufferSize = defaultReadBufferSize
   290  		} else if readBufferSize < maxControlFramePayloadSize {
   291  			// must be large enough for control frame
   292  			readBufferSize = maxControlFramePayloadSize
   293  		}
   294  		br = bufio.NewReaderSize(conn, readBufferSize)
   295  	}
   296  
   297  	if writeBufferSize <= 0 {
   298  		writeBufferSize = defaultWriteBufferSize
   299  	}
   300  	writeBufferSize += maxFrameHeaderSize
   301  
   302  	if writeBuf == nil && writeBufferPool == nil {
   303  		writeBuf = make([]byte, writeBufferSize)
   304  	}
   305  
   306  	mu := make(chan struct{}, 1)
   307  	mu <- struct{}{}
   308  	c := &Conn{
   309  		isServer:               isServer,
   310  		br:                     br,
   311  		conn:                   conn,
   312  		mu:                     mu,
   313  		readFinal:              true,
   314  		writeBuf:               writeBuf,
   315  		writePool:              writeBufferPool,
   316  		writeBufSize:           writeBufferSize,
   317  		enableWriteCompression: true,
   318  		compressionLevel:       defaultCompressionLevel,
   319  	}
   320  	c.SetCloseHandler(nil)
   321  	c.SetPingHandler(nil)
   322  	c.SetPongHandler(nil)
   323  	return c
   324  }
   325  
   326  // setReadRemaining tracks the number of bytes remaining on the connection. If n
   327  // overflows, an ErrReadLimit is returned.
   328  func (c *Conn) setReadRemaining(n int64) error {
   329  	if n < 0 {
   330  		return ErrReadLimit
   331  	}
   332  
   333  	c.readRemaining = n
   334  	return nil
   335  }
   336  
   337  // Subprotocol returns the negotiated protocol for the connection.
   338  func (c *Conn) Subprotocol() string {
   339  	return c.subprotocol
   340  }
   341  
   342  // Close closes the underlying network connection without sending or waiting
   343  // for a close message.
   344  func (c *Conn) Close() error {
   345  	return c.conn.Close()
   346  }
   347  
   348  // LocalAddr returns the local network address.
   349  func (c *Conn) LocalAddr() net.Addr {
   350  	return c.conn.LocalAddr()
   351  }
   352  
   353  // RemoteAddr returns the remote network address.
   354  func (c *Conn) RemoteAddr() net.Addr {
   355  	return c.conn.RemoteAddr()
   356  }
   357  
   358  // Write methods
   359  
   360  func (c *Conn) writeFatal(err error) error {
   361  	err = hideTempErr(err)
   362  	c.writeErrMu.Lock()
   363  	if c.writeErr == nil {
   364  		c.writeErr = err
   365  	}
   366  	c.writeErrMu.Unlock()
   367  	return err
   368  }
   369  
   370  func (c *Conn) read(n int) ([]byte, error) {
   371  	p, err := c.br.Peek(n)
   372  	if err == io.EOF {
   373  		err = errUnexpectedEOF
   374  	}
   375  	c.br.Discard(len(p))
   376  	return p, err
   377  }
   378  
   379  func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
   380  	<-c.mu
   381  	defer func() { c.mu <- struct{}{} }()
   382  
   383  	c.writeErrMu.Lock()
   384  	err := c.writeErr
   385  	c.writeErrMu.Unlock()
   386  	if err != nil {
   387  		return err
   388  	}
   389  
   390  	c.conn.SetWriteDeadline(deadline)
   391  	if len(buf1) == 0 {
   392  		_, err = c.conn.Write(buf0)
   393  	} else {
   394  		err = c.writeBufs(buf0, buf1)
   395  	}
   396  	if err != nil {
   397  		return c.writeFatal(err)
   398  	}
   399  	if frameType == CloseMessage {
   400  		c.writeFatal(ErrCloseSent)
   401  	}
   402  	return nil
   403  }
   404  
   405  func (c *Conn) writeBufs(bufs ...[]byte) error {
   406  	b := net.Buffers(bufs)
   407  	_, err := b.WriteTo(c.conn)
   408  	return err
   409  }
   410  
   411  // WriteControl writes a control message with the given deadline. The allowed
   412  // message types are CloseMessage, PingMessage and PongMessage.
   413  func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
   414  	if !isControl(messageType) {
   415  		return errBadWriteOpCode
   416  	}
   417  	if len(data) > maxControlFramePayloadSize {
   418  		return errInvalidControlFrame
   419  	}
   420  
   421  	b0 := byte(messageType) | finalBit
   422  	b1 := byte(len(data))
   423  	if !c.isServer {
   424  		b1 |= maskBit
   425  	}
   426  
   427  	buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
   428  	buf = append(buf, b0, b1)
   429  
   430  	if c.isServer {
   431  		buf = append(buf, data...)
   432  	} else {
   433  		key := newMaskKey()
   434  		buf = append(buf, key[:]...)
   435  		buf = append(buf, data...)
   436  		maskBytes(key, 0, buf[6:])
   437  	}
   438  
   439  	d := 1000 * time.Hour
   440  	if !deadline.IsZero() {
   441  		d = deadline.Sub(time.Now())
   442  		if d < 0 {
   443  			return errWriteTimeout
   444  		}
   445  	}
   446  
   447  	timer := time.NewTimer(d)
   448  	select {
   449  	case <-c.mu:
   450  		timer.Stop()
   451  	case <-timer.C:
   452  		return errWriteTimeout
   453  	}
   454  	defer func() { c.mu <- struct{}{} }()
   455  
   456  	c.writeErrMu.Lock()
   457  	err := c.writeErr
   458  	c.writeErrMu.Unlock()
   459  	if err != nil {
   460  		return err
   461  	}
   462  
   463  	c.conn.SetWriteDeadline(deadline)
   464  	_, err = c.conn.Write(buf)
   465  	if err != nil {
   466  		return c.writeFatal(err)
   467  	}
   468  	if messageType == CloseMessage {
   469  		c.writeFatal(ErrCloseSent)
   470  	}
   471  	return err
   472  }
   473  
   474  // beginMessage prepares a connection and message writer for a new message.
   475  func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
   476  	// Close previous writer if not already closed by the application. It's
   477  	// probably better to return an error in this situation, but we cannot
   478  	// change this without breaking existing applications.
   479  	if c.writer != nil {
   480  		c.writer.Close()
   481  		c.writer = nil
   482  	}
   483  
   484  	if !isControl(messageType) && !isData(messageType) {
   485  		return errBadWriteOpCode
   486  	}
   487  
   488  	c.writeErrMu.Lock()
   489  	err := c.writeErr
   490  	c.writeErrMu.Unlock()
   491  	if err != nil {
   492  		return err
   493  	}
   494  
   495  	mw.c = c
   496  	mw.frameType = messageType
   497  	mw.pos = maxFrameHeaderSize
   498  
   499  	if c.writeBuf == nil {
   500  		wpd, ok := c.writePool.Get().(writePoolData)
   501  		if ok {
   502  			c.writeBuf = wpd.buf
   503  		} else {
   504  			c.writeBuf = make([]byte, c.writeBufSize)
   505  		}
   506  	}
   507  	return nil
   508  }
   509  
   510  // NextWriter returns a writer for the next message to send. The writer's Close
   511  // method flushes the complete message to the network.
   512  //
   513  // There can be at most one open writer on a connection. NextWriter closes the
   514  // previous writer if the application has not already done so.
   515  //
   516  // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
   517  // PongMessage) are supported.
   518  func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
   519  	var mw messageWriter
   520  	if err := c.beginMessage(&mw, messageType); err != nil {
   521  		return nil, err
   522  	}
   523  	c.writer = &mw
   524  	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
   525  		w := c.newCompressionWriter(c.writer, c.compressionLevel)
   526  		mw.compress = true
   527  		c.writer = w
   528  	}
   529  	return c.writer, nil
   530  }
   531  
   532  type messageWriter struct {
   533  	c         *Conn
   534  	compress  bool // whether next call to flushFrame should set RSV1
   535  	pos       int  // end of data in writeBuf.
   536  	frameType int  // type of the current frame.
   537  	err       error
   538  }
   539  
   540  func (w *messageWriter) endMessage(err error) error {
   541  	if w.err != nil {
   542  		return err
   543  	}
   544  	c := w.c
   545  	w.err = err
   546  	c.writer = nil
   547  	if c.writePool != nil {
   548  		c.writePool.Put(writePoolData{buf: c.writeBuf})
   549  		c.writeBuf = nil
   550  	}
   551  	return err
   552  }
   553  
   554  // flushFrame writes buffered data and extra as a frame to the network. The
   555  // final argument indicates that this is the last frame in the message.
        //
        // The header is written right-aligned into the first maxFrameHeaderSize
        // bytes of c.writeBuf, so header, mask key (clients only) and buffered
        // payload form the single slice c.writeBuf[framePos:w.pos]. extra is sent
        // after that slice without being copied. It is only allowed for servers,
        // because client payloads are masked in place inside c.writeBuf.
        //
        // Any error ends the message with that error. After a successful final
        // frame the writer is ended with errWriteClosed, so further Write or Close
        // calls on it fail. Panics if a concurrent write is detected.
   556  func (w *messageWriter) flushFrame(final bool, extra []byte) error {
   557  	c := w.c
        	// Payload size of this frame: data buffered after the reserved header
        	// area plus extra.
   558  	length := w.pos - maxFrameHeaderSize + len(extra)
   559  
   560  	// Check for invalid control frames.
        	// Control frames must not be fragmented, and their payload is limited to
        	// maxControlFramePayloadSize bytes.
   561  	if isControl(w.frameType) &&
   562  		(!final || length > maxControlFramePayloadSize) {
   563  		return w.endMessage(errInvalidControlFrame)
   564  	}
   565  
   566  	b0 := byte(w.frameType)
   567  	if final {
   568  		b0 |= finalBit
   569  	}
   570  	if w.compress {
   571  		b0 |= rsv1Bit
   572  	}
        	// Only the first frame of a compressed message carries RSV1.
   573  	w.compress = false
   574  
   575  	b1 := byte(0)
   576  	if !c.isServer {
   577  		b1 |= maskBit
   578  	}
   579  
   580  	// Assume that the frame starts at beginning of c.writeBuf.
   581  	framePos := 0
   582  	if c.isServer {
   583  		// Adjust up if mask not included in the header.
   584  		framePos = 4
   585  	}
   586  
        	// Pick the length encoding and shift framePos so the header ends at
        	// offset 10 for clients (the mask key follows in 10..14) or at offset 14
        	// for servers:
        	//   length >= 65536: 2 bytes + 8-byte length; framePos 0 (client) / 4 (server)
        	//   length > 125:    2 bytes + 2-byte length; framePos 6 / 10
        	//   otherwise:       2 bytes, length in b1;   framePos 8 / 12
   587  	switch {
   588  	case length >= 65536:
   589  		c.writeBuf[framePos] = b0
   590  		c.writeBuf[framePos+1] = b1 | 127
   591  		binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
   592  	case length > 125:
   593  		framePos += 6
   594  		c.writeBuf[framePos] = b0
   595  		c.writeBuf[framePos+1] = b1 | 126
   596  		binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
   597  	default:
   598  		framePos += 8
   599  		c.writeBuf[framePos] = b0
   600  		c.writeBuf[framePos+1] = b1 | byte(length)
   601  	}
   602  
   603  	if !c.isServer {
        		// Put the key in the last four header bytes (c.writeBuf[10:14]) and
        		// mask the buffered payload in place.
   604  		key := newMaskKey()
   605  		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
   606  		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
        		// extra lives outside writeBuf and cannot be masked here. This is
        		// treated as a fatal internal error for the connection.
   607  		if len(extra) > 0 {
   608  			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
   609  		}
   610  	}
   611  
   612  	// Write the buffers to the connection with best-effort detection of
   613  	// concurrent writes. See the concurrency section in the package
   614  	// documentation for more info.
   615  
   616  	if c.isWriting {
   617  		panic("concurrent write to websocket connection")
   618  	}
   619  	c.isWriting = true
   620  
   621  	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
   622  
   623  	if !c.isWriting {
   624  		panic("concurrent write to websocket connection")
   625  	}
   626  	c.isWriting = false
   627  
   628  	if err != nil {
   629  		return w.endMessage(err)
   630  	}
   631  
   632  	if final {
        		// The frame went out. Ending the writer makes later use return
        		// errWriteClosed, but this call itself succeeded.
   633  		w.endMessage(errWriteClosed)
   634  		return nil
   635  	}
   636  
   637  	// Setup for next frame.
        	// Later frames of the same message are continuation frames.
   638  	w.pos = maxFrameHeaderSize
   639  	w.frameType = continuationFrame
   640  	return nil
   641  }
   642  
   643  func (w *messageWriter) ncopy(max int) (int, error) {
   644  	n := len(w.c.writeBuf) - w.pos
   645  	if n <= 0 {
   646  		if err := w.flushFrame(false, nil); err != nil {
   647  			return 0, err
   648  		}
   649  		n = len(w.c.writeBuf) - w.pos
   650  	}
   651  	if n > max {
   652  		n = max
   653  	}
   654  	return n, nil
   655  }
   656  
   657  func (w *messageWriter) Write(p []byte) (int, error) {
   658  	if w.err != nil {
   659  		return 0, w.err
   660  	}
   661  
   662  	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
   663  		// Don't buffer large messages.
   664  		err := w.flushFrame(false, p)
   665  		if err != nil {
   666  			return 0, err
   667  		}
   668  		return len(p), nil
   669  	}
   670  
   671  	nn := len(p)
   672  	for len(p) > 0 {
   673  		n, err := w.ncopy(len(p))
   674  		if err != nil {
   675  			return 0, err
   676  		}
   677  		copy(w.c.writeBuf[w.pos:], p[:n])
   678  		w.pos += n
   679  		p = p[n:]
   680  	}
   681  	return nn, nil
   682  }
   683  
   684  func (w *messageWriter) WriteString(p string) (int, error) {
   685  	if w.err != nil {
   686  		return 0, w.err
   687  	}
   688  
   689  	nn := len(p)
   690  	for len(p) > 0 {
   691  		n, err := w.ncopy(len(p))
   692  		if err != nil {
   693  			return 0, err
   694  		}
   695  		copy(w.c.writeBuf[w.pos:], p[:n])
   696  		w.pos += n
   697  		p = p[n:]
   698  	}
   699  	return nn, nil
   700  }
   701  
   702  func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
   703  	if w.err != nil {
   704  		return 0, w.err
   705  	}
   706  	for {
   707  		if w.pos == len(w.c.writeBuf) {
   708  			err = w.flushFrame(false, nil)
   709  			if err != nil {
   710  				break
   711  			}
   712  		}
   713  		var n int
   714  		n, err = r.Read(w.c.writeBuf[w.pos:])
   715  		w.pos += n
   716  		nn += int64(n)
   717  		if err != nil {
   718  			if err == io.EOF {
   719  				err = nil
   720  			}
   721  			break
   722  		}
   723  	}
   724  	return nn, err
   725  }
   726  
   727  func (w *messageWriter) Close() error {
   728  	if w.err != nil {
   729  		return w.err
   730  	}
   731  	return w.flushFrame(true, nil)
   732  }
   733  
   734  // WritePreparedMessage writes prepared message into connection.
   735  func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
   736  	frameType, frameData, err := pm.frame(prepareKey{
   737  		isServer:         c.isServer,
   738  		compress:         c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
   739  		compressionLevel: c.compressionLevel,
   740  	})
   741  	if err != nil {
   742  		return err
   743  	}
   744  	if c.isWriting {
   745  		panic("concurrent write to websocket connection")
   746  	}
   747  	c.isWriting = true
   748  	err = c.write(frameType, c.writeDeadline, frameData, nil)
   749  	if !c.isWriting {
   750  		panic("concurrent write to websocket connection")
   751  	}
   752  	c.isWriting = false
   753  	return err
   754  }
   755  
   756  // WriteMessage is a helper method for getting a writer using NextWriter,
   757  // writing the message and closing the writer.
   758  func (c *Conn) WriteMessage(messageType int, data []byte) error {
   759  
   760  	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
   761  		// Fast path with no allocations and single frame.
   762  
   763  		var mw messageWriter
   764  		if err := c.beginMessage(&mw, messageType); err != nil {
   765  			return err
   766  		}
   767  		n := copy(c.writeBuf[mw.pos:], data)
   768  		mw.pos += n
   769  		data = data[n:]
   770  		return mw.flushFrame(true, data)
   771  	}
   772  
   773  	w, err := c.NextWriter(messageType)
   774  	if err != nil {
   775  		return err
   776  	}
   777  	if _, err = w.Write(data); err != nil {
   778  		return err
   779  	}
   780  	return w.Close()
   781  }
   782  
   783  // SetWriteDeadline sets the write deadline on the underlying network
   784  // connection. After a write has timed out, the websocket state is corrupt and
   785  // all future writes will return an error. A zero value for t means writes will
   786  // not time out.
   787  func (c *Conn) SetWriteDeadline(t time.Time) error {
   788  	c.writeDeadline = t
   789  	return nil
   790  }
   791  
   792  // Read methods
   793  
        // advanceFrame reads the next frame header from c.br and prepares the
        // connection to consume that frame.
        //
        // For data and continuation frames, it records the payload length in
        // c.readRemaining and the masking key (if any), enforces the read limit,
        // and returns the frame type; the payload itself is left in c.br. For ping
        // and pong frames, the payload is read here and passed to the ping or pong
        // handler, then the frame type is returned. For a close frame, the payload
        // is validated and passed to the close handler, and a *CloseError carrying
        // the received code and text is returned.
        //
        // Every error is returned together with noFrame. Header violations are
        // reported through handleProtocolError, which also sends a close frame.
   794  func (c *Conn) advanceFrame() (int, error) {
   795  	// 1. Skip remainder of previous frame.
        	// Payload the application did not read is discarded so that c.br is
        	// positioned at the next frame header.
   796  
   797  	if c.readRemaining > 0 {
   798  		if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
   799  			return noFrame, err
   800  		}
   801  	}
   802  
   803  	// 2. Read and parse first two bytes of frame header.
   804  	// To aid debugging, collect and report all errors in the first two bytes
   805  	// of the header.
   806  
        	// NOTE: this local shadows the errors package for the rest of the function.
   807  	var errors []string
   808  
   809  	p, err := c.read(2)
   810  	if err != nil {
   811  		return noFrame, err
   812  	}
   813  
   814  	frameType := int(p[0] & 0xf)
   815  	final := p[0]&finalBit != 0
   816  	rsv1 := p[0]&rsv1Bit != 0
   817  	rsv2 := p[0]&rsv2Bit != 0
   818  	rsv3 := p[0]&rsv3Bit != 0
   819  	mask := p[1]&maskBit != 0
        	// Provisional 7-bit length; 126 and 127 select the extended encodings
        	// read in step 3. The error is ignored because a 7-bit value is never
        	// negative.
   820  	c.setReadRemaining(int64(p[1] & 0x7f))
   821  
        	// RSV1 marks a compressed message and is only accepted when a
        	// decompression reader is configured.
        	// NOTE(review): RFC 7692 also forbids RSV1 on control frames and on
        	// continuation frames; neither is rejected here. Confirm this is intended.
   822  	c.readDecompress = false
   823  	if rsv1 {
   824  		if c.newDecompressionReader != nil {
   825  			c.readDecompress = true
   826  		} else {
   827  			errors = append(errors, "RSV1 set")
   828  		}
   829  	}
   830  
   831  	if rsv2 {
   832  		errors = append(errors, "RSV2 set")
   833  	}
   834  
   835  	if rsv3 {
   836  		errors = append(errors, "RSV3 set")
   837  	}
   838  
   839  	switch frameType {
   840  	case CloseMessage, PingMessage, PongMessage:
        		// Checked against the 7-bit length, so the extended-length markers
        		// 126 and 127 are rejected as well.
   841  		if c.readRemaining > maxControlFramePayloadSize {
   842  			errors = append(errors, "len > 125 for control")
   843  		}
   844  		if !final {
   845  			errors = append(errors, "FIN not set on control")
   846  		}
   847  	case TextMessage, BinaryMessage:
        		// A new data message may only start once the previous one has ended.
   848  		if !c.readFinal {
   849  			errors = append(errors, "data before FIN")
   850  		}
   851  		c.readFinal = final
   852  	case continuationFrame:
        		// A continuation frame is only valid inside a fragmented message.
   853  		if c.readFinal {
   854  			errors = append(errors, "continuation after FIN")
   855  		}
   856  		c.readFinal = final
   857  	default:
   858  		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
   859  	}
   860  
        	// A server requires masked frames, and a client requires unmasked ones.
   861  	if mask != c.isServer {
   862  		errors = append(errors, "bad MASK")
   863  	}
   864  
   865  	if len(errors) > 0 {
   866  		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
   867  	}
   868  
   869  	// 3. Read and parse frame length as per
   870  	// https://tools.ietf.org/html/rfc6455#section-5.2
   871  	//
   872  	// The length of the "Payload data", in bytes: if 0-125, that is the payload
   873  	// length.
   874  	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
   875  	// integer are the payload length.
   876  	// - If 127, the following 8 bytes interpreted as
   877  	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
   878  	// payload length. Multibyte length quantities are expressed in network byte
   879  	// order.
   880  
   881  	switch c.readRemaining {
   882  	case 126:
   883  		p, err := c.read(2)
   884  		if err != nil {
   885  			return noFrame, err
   886  		}
   887  
   888  		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
   889  			return noFrame, err
   890  		}
   891  	case 127:
   892  		p, err := c.read(8)
   893  		if err != nil {
   894  			return noFrame, err
   895  		}
   896  
        		// A length with the most significant bit set becomes negative as
        		// int64 and is rejected by setReadRemaining with ErrReadLimit.
   897  		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
   898  			return noFrame, err
   899  		}
   900  	}
   901  
   902  	// 4. Handle frame masking.
        	// Store this frame's masking key and restart the key position for its
        	// payload.
   903  
   904  	if mask {
   905  		c.readMaskPos = 0
   906  		p, err := c.read(len(c.readMaskKey))
   907  		if err != nil {
   908  			return noFrame, err
   909  		}
   910  		copy(c.readMaskKey[:], p)
   911  	}
   912  
   913  	// 5. For text and binary messages, enforce read limit and return.
   914  
   915  	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
   916  
        		// readLength accumulates the payload size over all frames of the
        		// current message; NextReader resets it to 0.
   917  		c.readLength += c.readRemaining
   918  		// Don't allow readLength to overflow in the presence of a large readRemaining
   919  		// counter.
   920  		if c.readLength < 0 {
   921  			return noFrame, ErrReadLimit
   922  		}
   923  
        		// A readLimit of zero or less disables the limit. The close frame is
        		// sent best-effort; its error is ignored.
   924  		if c.readLimit > 0 && c.readLength > c.readLimit {
   925  			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
   926  			return noFrame, ErrReadLimit
   927  		}
   928  
   929  		return frameType, nil
   930  	}
   931  
   932  	// 6. Read control frame payload.
        	// The whole payload is read with one Peek. This assumes c.br buffers at
        	// least maxControlFramePayloadSize bytes, which newConn ensures for the
        	// readers it creates. Since MASK was checked against c.isServer above,
        	// only a server has to unmask here.
   933  
   934  	var payload []byte
   935  	if c.readRemaining > 0 {
   936  		payload, err = c.read(int(c.readRemaining))
   937  		c.setReadRemaining(0)
   938  		if err != nil {
   939  			return noFrame, err
   940  		}
   941  		if c.isServer {
   942  			maskBytes(c.readMaskKey, 0, payload)
   943  		}
   944  	}
   945  
   946  	// 7. Process control frame payload.
   947  
   948  	switch frameType {
   949  	case PongMessage:
   950  		if err := c.handlePong(string(payload)); err != nil {
   951  			return noFrame, err
   952  		}
   953  	case PingMessage:
   954  		if err := c.handlePing(string(payload)); err != nil {
   955  			return noFrame, err
   956  		}
   957  	case CloseMessage:
   958  		closeCode := CloseNoStatusReceived
   959  		closeText := ""
        		// NOTE(review): a 1-byte close payload is not rejected; it is treated
        		// like an empty one (CloseNoStatusReceived). RFC 6455 section 5.5.1
        		// requires a 2-byte status code when a body is present. Confirm.
   960  		if len(payload) >= 2 {
   961  			closeCode = int(binary.BigEndian.Uint16(payload))
   962  			if !isValidReceivedCloseCode(closeCode) {
   963  				return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
   964  			}
   965  			closeText = string(payload[2:])
   966  			if !utf8.ValidString(closeText) {
   967  				return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
   968  			}
   969  		}
   970  		if err := c.handleClose(closeCode, closeText); err != nil {
   971  			return noFrame, err
   972  		}
        		// Even when the handler succeeds, a close frame ends reading and is
        		// reported to the caller as a *CloseError.
   973  		return noFrame, &CloseError{Code: closeCode, Text: closeText}
   974  	}
   975  
   976  	return frameType, nil
   977  }
   978  
   979  func (c *Conn) handleProtocolError(message string) error {
   980  	data := FormatCloseMessage(CloseProtocolError, message)
   981  	if len(data) > maxControlFramePayloadSize {
   982  		data = data[:maxControlFramePayloadSize]
   983  	}
   984  	c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
   985  	return errors.New("websocket: " + message)
   986  }
   987  
   988  // NextReader returns the next data message received from the peer. The
   989  // returned messageType is either TextMessage or BinaryMessage.
   990  //
   991  // There can be at most one open reader on a connection. NextReader discards
   992  // the previous message if the application has not already consumed it.
   993  //
   994  // Applications must break out of the application's read loop when this method
   995  // returns a non-nil error value. Errors returned from this method are
   996  // permanent. Once this method returns a non-nil error, all subsequent calls to
   997  // this method return the same error.
   998  func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
   999  	// Close previous reader, only relevant for decompression.
  1000  	if c.reader != nil {
  1001  		c.reader.Close()
  1002  		c.reader = nil
  1003  	}
  1004  
  1005  	c.messageReader = nil
  1006  	c.readLength = 0
  1007  
  1008  	for c.readErr == nil {
  1009  		frameType, err := c.advanceFrame()
  1010  		if err != nil {
  1011  			c.readErr = hideTempErr(err)
  1012  			break
  1013  		}
  1014  
  1015  		if frameType == TextMessage || frameType == BinaryMessage {
  1016  			c.messageReader = &messageReader{c}
  1017  			c.reader = c.messageReader
  1018  			if c.readDecompress {
  1019  				c.reader = c.newDecompressionReader(c.reader)
  1020  			}
  1021  			return frameType, c.reader, nil
  1022  		}
  1023  	}
  1024  
  1025  	// Applications that do handle the error returned from this method spin in
  1026  	// tight loop on connection failure. To help application developers detect
  1027  	// this error, panic on repeated reads to the failed connection.
  1028  	c.readErrCount++
  1029  	if c.readErrCount >= 1000 {
  1030  		panic("repeated read on failed websocket connection")
  1031  	}
  1032  
  1033  	return noFrame, nil, c.readErr
  1034  }
  1035  
  1036  type messageReader struct{ c *Conn }
  1037  
// Read implements io.Reader over the payload of the current data message. It
// reads frame payload bytes from c.br and moves on to following frames until
// the final frame of the message has been consumed.
//
// It returns io.EOF once the message is complete. It also returns io.EOF
// immediately if this reader has been superseded, i.e. c.messageReader no
// longer refers to r (NextReader clears it before starting a new message).
// Errors stored in c.readErr are sticky: once set, every call returns them.
// If an io.EOF is hit while this reader's message is still unfinished, it is
// reported as errUnexpectedEOF.
func (r *messageReader) Read(b []byte) (int, error) {
	c := r.c
	// A stale reader (a newer message has been started) reads as empty.
	if c.messageReader != r {
		return 0, io.EOF
	}

	for c.readErr == nil {

		// Payload bytes remain in the current frame: read at most that many so
		// the read never crosses into the next frame header.
		if c.readRemaining > 0 {
			if int64(len(b)) > c.readRemaining {
				b = b[:c.readRemaining]
			}
			n, err := c.br.Read(b)
			c.readErr = hideTempErr(err)
			// On the server side, unmask the bytes just read in place. The mask
			// position is carried across calls via c.readMaskPos.
			if c.isServer {
				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
			}
			rem := c.readRemaining
			rem -= int64(n)
			c.setReadRemaining(rem)
			// EOF while the frame still has unread payload means the message
			// was truncated.
			if c.readRemaining > 0 && c.readErr == io.EOF {
				c.readErr = errUnexpectedEOF
			}
			return n, c.readErr
		}

		// The current frame is exhausted and it was the final frame, so the
		// message is complete.
		if c.readFinal {
			c.messageReader = nil
			return 0, io.EOF
		}

		// Advance to the next frame. Control frames are processed inside
		// advanceFrame. A frame starting a new text/binary message here is
		// treated as an internal error.
		frameType, err := c.advanceFrame()
		switch {
		case err != nil:
			c.readErr = hideTempErr(err)
		case frameType == TextMessage || frameType == BinaryMessage:
			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
		}
	}

	// The connection has a recorded error. If this reader was still the
	// active one, a plain EOF means the message was cut short.
	err := c.readErr
	if err == io.EOF && c.messageReader == r {
		err = errUnexpectedEOF
	}
	return 0, err
}
  1084  
  1085  func (r *messageReader) Close() error {
  1086  	return nil
  1087  }
  1088  
  1089  // ReadMessage is a helper method for getting a reader using NextReader and
  1090  // reading from that reader to a buffer.
  1091  func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
  1092  	var r io.Reader
  1093  	messageType, r, err = c.NextReader()
  1094  	if err != nil {
  1095  		return messageType, nil, err
  1096  	}
  1097  	p, err = ioutil.ReadAll(r)
  1098  	return messageType, p, err
  1099  }
  1100  
  1101  // SetReadDeadline sets the read deadline on the underlying network connection.
  1102  // After a read has timed out, the websocket connection state is corrupt and
  1103  // all future reads will return an error. A zero value for t means reads will
  1104  // not time out.
  1105  func (c *Conn) SetReadDeadline(t time.Time) error {
  1106  	return c.conn.SetReadDeadline(t)
  1107  }
  1108  
  1109  // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
  1110  // message exceeds the limit, the connection sends a close message to the peer
  1111  // and returns ErrReadLimit to the application.
  1112  func (c *Conn) SetReadLimit(limit int64) {
  1113  	c.readLimit = limit
  1114  }
  1115  
  1116  // CloseHandler returns the current close handler
  1117  func (c *Conn) CloseHandler() func(code int, text string) error {
  1118  	return c.handleClose
  1119  }
  1120  
  1121  // SetCloseHandler sets the handler for close messages received from the peer.
  1122  // The code argument to h is the received close code or CloseNoStatusReceived
  1123  // if the close message is empty. The default close handler sends a close
  1124  // message back to the peer.
  1125  //
  1126  // The handler function is called from the NextReader, ReadMessage and message
  1127  // reader Read methods. The application must read the connection to process
  1128  // close messages as described in the section on Control Messages above.
  1129  //
  1130  // The connection read methods return a CloseError when a close message is
  1131  // received. Most applications should handle close messages as part of their
  1132  // normal error handling. Applications should only set a close handler when the
  1133  // application must perform some action before sending a close message back to
  1134  // the peer.
  1135  func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
  1136  	if h == nil {
  1137  		h = func(code int, text string) error {
  1138  			message := FormatCloseMessage(code, "")
  1139  			c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
  1140  			return nil
  1141  		}
  1142  	}
  1143  	c.handleClose = h
  1144  }
  1145  
  1146  // PingHandler returns the current ping handler
  1147  func (c *Conn) PingHandler() func(appData string) error {
  1148  	return c.handlePing
  1149  }
  1150  
  1151  // SetPingHandler sets the handler for ping messages received from the peer.
  1152  // The appData argument to h is the PING message application data. The default
  1153  // ping handler sends a pong to the peer.
  1154  //
  1155  // The handler function is called from the NextReader, ReadMessage and message
  1156  // reader Read methods. The application must read the connection to process
  1157  // ping messages as described in the section on Control Messages above.
  1158  func (c *Conn) SetPingHandler(h func(appData string) error) {
  1159  	if h == nil {
  1160  		h = func(message string) error {
  1161  			err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
  1162  			if err == ErrCloseSent {
  1163  				return nil
  1164  			} else if e, ok := err.(net.Error); ok && e.Temporary() {
  1165  				return nil
  1166  			}
  1167  			return err
  1168  		}
  1169  	}
  1170  	c.handlePing = h
  1171  }
  1172  
  1173  // PongHandler returns the current pong handler
  1174  func (c *Conn) PongHandler() func(appData string) error {
  1175  	return c.handlePong
  1176  }
  1177  
  1178  // SetPongHandler sets the handler for pong messages received from the peer.
  1179  // The appData argument to h is the PONG message application data. The default
  1180  // pong handler does nothing.
  1181  //
  1182  // The handler function is called from the NextReader, ReadMessage and message
  1183  // reader Read methods. The application must read the connection to process
  1184  // pong messages as described in the section on Control Messages above.
  1185  func (c *Conn) SetPongHandler(h func(appData string) error) {
  1186  	if h == nil {
  1187  		h = func(string) error { return nil }
  1188  	}
  1189  	c.handlePong = h
  1190  }
  1191  
  1192  // NetConn returns the underlying connection that is wrapped by c.
  1193  // Note that writing to or reading from this connection directly will corrupt the
  1194  // WebSocket connection.
  1195  func (c *Conn) NetConn() net.Conn {
  1196  	return c.conn
  1197  }
  1198  
  1199  // UnderlyingConn returns the internal net.Conn. This can be used to further
  1200  // modifications to connection specific flags.
  1201  // Deprecated: Use the NetConn method.
  1202  func (c *Conn) UnderlyingConn() net.Conn {
  1203  	return c.conn
  1204  }
  1205  
  1206  // EnableWriteCompression enables and disables write compression of
  1207  // subsequent text and binary messages. This function is a noop if
  1208  // compression was not negotiated with the peer.
  1209  func (c *Conn) EnableWriteCompression(enable bool) {
  1210  	c.enableWriteCompression = enable
  1211  }
  1212  
  1213  // SetCompressionLevel sets the flate compression level for subsequent text and
  1214  // binary messages. This function is a noop if compression was not negotiated
  1215  // with the peer. See the compress/flate package for a description of
  1216  // compression levels.
  1217  func (c *Conn) SetCompressionLevel(level int) error {
  1218  	if !isValidCompressionLevel(level) {
  1219  		return errors.New("websocket: invalid compression level")
  1220  	}
  1221  	c.compressionLevel = level
  1222  	return nil
  1223  }
  1224  
  1225  // FormatCloseMessage formats closeCode and text as a WebSocket close message.
  1226  // An empty message is returned for code CloseNoStatusReceived.
  1227  func FormatCloseMessage(closeCode int, text string) []byte {
  1228  	if closeCode == CloseNoStatusReceived {
  1229  		// Return empty message because it's illegal to send
  1230  		// CloseNoStatusReceived. Return non-nil value in case application
  1231  		// checks for nil.
  1232  		return []byte{}
  1233  	}
  1234  	buf := make([]byte, 2+len(text))
  1235  	binary.BigEndian.PutUint16(buf, uint16(closeCode))
  1236  	copy(buf[2:], text)
  1237  	return buf
  1238  }
  1239  

View as plain text