...

Source file src/nhooyr.io/websocket/read.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"strings"
    14  	"time"
    15  
    16  	"nhooyr.io/websocket/internal/errd"
    17  	"nhooyr.io/websocket/internal/util"
    18  	"nhooyr.io/websocket/internal/xsync"
    19  )
    20  
    21  // Reader reads from the connection until there is a WebSocket
    22  // data message to be read. It will handle ping, pong and close frames as appropriate.
    23  //
    24  // It returns the type of the message and an io.Reader to read it.
    25  // The passed context will also bound the reader.
    26  // Ensure you read to EOF otherwise the connection will hang.
    27  //
    28  // Call CloseRead if you do not expect any data messages from the peer.
    29  //
    30  // Only one Reader may be open at a time.
    31  //
    32  // If you need a separate timeout on the Reader call and the Read itself,
    33  // use time.AfterFunc to cancel the context passed in.
    34  // See https://github.com/nhooyr/websocket/issues/87#issue-451703332
    35  // Most users should not need this.
    36  func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
    37  	return c.reader(ctx)
    38  }
    39  
    40  // Read is a convenience method around Reader to read a single message
    41  // from the connection.
    42  func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
    43  	typ, r, err := c.Reader(ctx)
    44  	if err != nil {
    45  		return 0, nil, err
    46  	}
    47  
    48  	b, err := io.ReadAll(r)
    49  	return typ, b, err
    50  }
    51  
    52  // CloseRead starts a goroutine to read from the connection until it is closed
    53  // or a data message is received.
    54  //
    55  // Once CloseRead is called you cannot read any messages from the connection.
    56  // The returned context will be cancelled when the connection is closed.
    57  //
    58  // If a data message is received, the connection will be closed with StatusPolicyViolation.
    59  //
    60  // Call CloseRead when you do not expect to read any more messages.
    61  // Since it actively reads from the connection, it will ensure that ping, pong and close
    62  // frames are responded to. This means c.Ping and c.Close will still work as expected.
    63  func (c *Conn) CloseRead(ctx context.Context) context.Context {
    64  	ctx, cancel := context.WithCancel(ctx)
    65  
    66  	c.wg.Add(1)
    67  	go func() {
    68  		defer c.CloseNow()
    69  		defer c.wg.Done()
    70  		defer cancel()
    71  		_, _, err := c.Reader(ctx)
    72  		if err == nil {
    73  			c.Close(StatusPolicyViolation, "unexpected data message")
    74  		}
    75  	}()
    76  	return ctx
    77  }
    78  
    79  // SetReadLimit sets the max number of bytes to read for a single message.
    80  // It applies to the Reader and Read methods.
    81  //
    82  // By default, the connection has a message read limit of 32768 bytes.
    83  //
    84  // When the limit is hit, the connection will be closed with StatusMessageTooBig.
    85  //
    86  // Set to -1 to disable.
    87  func (c *Conn) SetReadLimit(n int64) {
    88  	if n >= 0 {
    89  		// We read one more byte than the limit in case
    90  		// there is a fin frame that needs to be read.
    91  		n++
    92  	}
    93  
    94  	c.msgReader.limitReader.limit.Store(n)
    95  }
    96  
    97  const defaultReadLimit = 32768
    98  
    99  func newMsgReader(c *Conn) *msgReader {
   100  	mr := &msgReader{
   101  		c:   c,
   102  		fin: true,
   103  	}
   104  	mr.readFunc = mr.read
   105  
   106  	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
   107  	return mr
   108  }
   109  
   110  func (mr *msgReader) resetFlate() {
   111  	if mr.flateContextTakeover() {
   112  		if mr.dict == nil {
   113  			mr.dict = &slidingWindow{}
   114  		}
   115  		mr.dict.init(32768)
   116  	}
   117  	if mr.flateBufio == nil {
   118  		mr.flateBufio = getBufioReader(mr.readFunc)
   119  	}
   120  
   121  	if mr.flateContextTakeover() {
   122  		mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
   123  	} else {
   124  		mr.flateReader = getFlateReader(mr.flateBufio, nil)
   125  	}
   126  	mr.limitReader.r = mr.flateReader
   127  	mr.flateTail.Reset(deflateMessageTail)
   128  }
   129  
   130  func (mr *msgReader) putFlateReader() {
   131  	if mr.flateReader != nil {
   132  		putFlateReader(mr.flateReader)
   133  		mr.flateReader = nil
   134  	}
   135  }
   136  
   137  func (mr *msgReader) close() {
   138  	mr.c.readMu.forceLock()
   139  	mr.putFlateReader()
   140  	if mr.dict != nil {
   141  		mr.dict.close()
   142  		mr.dict = nil
   143  	}
   144  	if mr.flateBufio != nil {
   145  		putBufioReader(mr.flateBufio)
   146  	}
   147  
   148  	if mr.c.client {
   149  		putBufioReader(mr.c.br)
   150  		mr.c.br = nil
   151  	}
   152  }
   153  
   154  func (mr *msgReader) flateContextTakeover() bool {
   155  	if mr.c.client {
   156  		return !mr.c.copts.serverNoContextTakeover
   157  	}
   158  	return !mr.c.copts.clientNoContextTakeover
   159  }
   160  
   161  func (c *Conn) readRSV1Illegal(h header) bool {
   162  	// If compression is disabled, rsv1 is illegal.
   163  	if !c.flate() {
   164  		return true
   165  	}
   166  	// rsv1 is only allowed on data frames beginning messages.
   167  	if h.opcode != opText && h.opcode != opBinary {
   168  		return true
   169  	}
   170  	return false
   171  }
   172  
   173  func (c *Conn) readLoop(ctx context.Context) (header, error) {
   174  	for {
   175  		h, err := c.readFrameHeader(ctx)
   176  		if err != nil {
   177  			return header{}, err
   178  		}
   179  
   180  		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
   181  			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
   182  			c.writeError(StatusProtocolError, err)
   183  			return header{}, err
   184  		}
   185  
   186  		if !c.client && !h.masked {
   187  			return header{}, errors.New("received unmasked frame from client")
   188  		}
   189  
   190  		switch h.opcode {
   191  		case opClose, opPing, opPong:
   192  			err = c.handleControl(ctx, h)
   193  			if err != nil {
   194  				// Pass through CloseErrors when receiving a close frame.
   195  				if h.opcode == opClose && CloseStatus(err) != -1 {
   196  					return header{}, err
   197  				}
   198  				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
   199  			}
   200  		case opContinuation, opText, opBinary:
   201  			return h, nil
   202  		default:
   203  			err := fmt.Errorf("received unknown opcode %v", h.opcode)
   204  			c.writeError(StatusProtocolError, err)
   205  			return header{}, err
   206  		}
   207  	}
   208  }
   209  
   210  func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
   211  	select {
   212  	case <-c.closed:
   213  		return header{}, net.ErrClosed
   214  	case c.readTimeout <- ctx:
   215  	}
   216  
   217  	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
   218  	if err != nil {
   219  		select {
   220  		case <-c.closed:
   221  			return header{}, net.ErrClosed
   222  		case <-ctx.Done():
   223  			return header{}, ctx.Err()
   224  		default:
   225  			c.close(err)
   226  			return header{}, err
   227  		}
   228  	}
   229  
   230  	select {
   231  	case <-c.closed:
   232  		return header{}, net.ErrClosed
   233  	case c.readTimeout <- context.Background():
   234  	}
   235  
   236  	return h, nil
   237  }
   238  
   239  func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
   240  	select {
   241  	case <-c.closed:
   242  		return 0, net.ErrClosed
   243  	case c.readTimeout <- ctx:
   244  	}
   245  
   246  	n, err := io.ReadFull(c.br, p)
   247  	if err != nil {
   248  		select {
   249  		case <-c.closed:
   250  			return n, net.ErrClosed
   251  		case <-ctx.Done():
   252  			return n, ctx.Err()
   253  		default:
   254  			err = fmt.Errorf("failed to read frame payload: %w", err)
   255  			c.close(err)
   256  			return n, err
   257  		}
   258  	}
   259  
   260  	select {
   261  	case <-c.closed:
   262  		return n, net.ErrClosed
   263  	case c.readTimeout <- context.Background():
   264  	}
   265  
   266  	return n, err
   267  }
   268  
// handleControl processes one control frame (close, ping or pong) whose header
// h has already been read, and reads the frame's payload from the connection.
//
// The frame is rejected with StatusProtocolError if its payload length is
// negative or larger than maxControlPayload, or if it is fragmented (FIN not
// set). Reading the payload and writing any reply share a 5 second timeout
// derived from ctx.
//
//   - ping: the payload is echoed back in a pong frame.
//   - pong: if c.activePings has an entry keyed by the payload, that waiter is
//     signalled without blocking; otherwise the pong is ignored.
//   - close: the payload is parsed with parseClosePayload. An invalid payload
//     is a protocol error. Otherwise the error is recorded with setCloseErr,
//     the peer's code and reason are echoed with writeClose, and the
//     connection is closed. On either path the returned error is also stored
//     in c.readCloseFrameErr.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}

	// Control frames must not be fragmented.
	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	// Bound the payload read and any reply written below to 5 seconds.
	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	// The payload is read into the connection's reusable control buffer.
	b := c.readControlBuf[:h.payloadLength]
	_, err = c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
		return c.writeControl(ctx, opPong, b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			// Non-blocking send: if nobody is receiving, drop the signal.
			select {
			case pong <- struct{}{}:
			default:
			}
		}
		return nil
	}

	// Only close frames reach this point: readLoop dispatches just opClose,
	// opPing and opPong here, and the latter two returned above.

	// Whatever the outcome, remember the result in c.readCloseFrameErr.
	defer func() {
		c.readCloseFrameErr = err
	}()

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	c.setCloseErr(err)
	c.writeClose(ce.Code, ce.Reason)
	c.close(err)
	return err
}
   328  
   329  func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
   330  	defer errd.Wrap(&err, "failed to get reader")
   331  
   332  	err = c.readMu.lock(ctx)
   333  	if err != nil {
   334  		return 0, nil, err
   335  	}
   336  	defer c.readMu.unlock()
   337  
   338  	if !c.msgReader.fin {
   339  		err = errors.New("previous message not read to completion")
   340  		c.close(fmt.Errorf("failed to get reader: %w", err))
   341  		return 0, nil, err
   342  	}
   343  
   344  	h, err := c.readLoop(ctx)
   345  	if err != nil {
   346  		return 0, nil, err
   347  	}
   348  
   349  	if h.opcode == opContinuation {
   350  		err := errors.New("received continuation frame without text or binary frame")
   351  		c.writeError(StatusProtocolError, err)
   352  		return 0, nil, err
   353  	}
   354  
   355  	c.msgReader.reset(ctx, h)
   356  
   357  	return MessageType(h.opcode), c.msgReader, nil
   358  }
   359  
// msgReader is the io.Reader returned by Conn.Reader for the current data
// message. It reads frame payloads, follows continuation frames and, for
// compressed messages, routes the payload through a pooled flate reader.
type msgReader struct {
	c *Conn

	// Per-message state.
	//
	// reset sets ctx (bounds all reads of the message) and flate (the
	// message has RSV1 set, i.e. is compressed). For compressed messages,
	// resetFlate (re)initialises the remaining fields:
	//   - flateReader: the pooled decompressor, nil once returned to the pool;
	//   - flateBufio: buffers the raw payload read through readFunc;
	//   - flateTail: supplies deflateMessageTail after the final frame;
	//   - limitReader: enforces the read limit over whichever source is active;
	//   - dict: the sliding window used when context takeover is enabled.
	ctx         context.Context
	flate       bool
	flateReader io.Reader
	flateBufio  *bufio.Reader
	flateTail   strings.Reader
	limitReader *limitReader
	dict        *slidingWindow

	// State of the frame currently being read: whether it is the final frame
	// of the message, how many payload bytes remain unread, and the mask key
	// as returned by the last call to mask. fin starts out true, meaning no
	// message is in progress.
	fin           bool
	payloadLength int64
	maskKey       uint32

	// util.ReaderFunc(mr.read), stored once by newMsgReader so the same value
	// can be reused as an io.Reader without continuous allocations.
	readFunc util.ReaderFunc
}
   378  
   379  func (mr *msgReader) reset(ctx context.Context, h header) {
   380  	mr.ctx = ctx
   381  	mr.flate = h.rsv1
   382  	mr.limitReader.reset(mr.readFunc)
   383  
   384  	if mr.flate {
   385  		mr.resetFlate()
   386  	}
   387  
   388  	mr.setFrame(h)
   389  }
   390  
   391  func (mr *msgReader) setFrame(h header) {
   392  	mr.fin = h.fin
   393  	mr.payloadLength = h.payloadLength
   394  	mr.maskKey = h.maskKey
   395  }
   396  
   397  func (mr *msgReader) Read(p []byte) (n int, err error) {
   398  	err = mr.c.readMu.lock(mr.ctx)
   399  	if err != nil {
   400  		return 0, fmt.Errorf("failed to read: %w", err)
   401  	}
   402  	defer mr.c.readMu.unlock()
   403  
   404  	n, err = mr.limitReader.Read(p)
   405  	if mr.flate && mr.flateContextTakeover() {
   406  		p = p[:n]
   407  		mr.dict.write(p)
   408  	}
   409  	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
   410  		mr.putFlateReader()
   411  		return n, io.EOF
   412  	}
   413  	if err != nil {
   414  		err = fmt.Errorf("failed to read: %w", err)
   415  		mr.c.close(err)
   416  	}
   417  	return n, err
   418  }
   419  
   420  func (mr *msgReader) read(p []byte) (int, error) {
   421  	for {
   422  		if mr.payloadLength == 0 {
   423  			if mr.fin {
   424  				if mr.flate {
   425  					return mr.flateTail.Read(p)
   426  				}
   427  				return 0, io.EOF
   428  			}
   429  
   430  			h, err := mr.c.readLoop(mr.ctx)
   431  			if err != nil {
   432  				return 0, err
   433  			}
   434  			if h.opcode != opContinuation {
   435  				err := errors.New("received new data message without finishing the previous message")
   436  				mr.c.writeError(StatusProtocolError, err)
   437  				return 0, err
   438  			}
   439  			mr.setFrame(h)
   440  
   441  			continue
   442  		}
   443  
   444  		if int64(len(p)) > mr.payloadLength {
   445  			p = p[:mr.payloadLength]
   446  		}
   447  
   448  		n, err := mr.c.readFramePayload(mr.ctx, p)
   449  		if err != nil {
   450  			return n, err
   451  		}
   452  
   453  		mr.payloadLength -= int64(n)
   454  
   455  		if !mr.c.client {
   456  			mr.maskKey = mask(mr.maskKey, p)
   457  		}
   458  
   459  		return n, nil
   460  	}
   461  }
   462  
// limitReader wraps r with a per-message byte budget. Once the budget is
// used up, the next Read fails the connection with StatusMessageTooBig.
type limitReader struct {
	// c is the connection failed when the limit is hit.
	c     *Conn
	// r is the underlying source: the raw payload reader or a flate reader.
	r     io.Reader
	// limit is the configured budget. SetReadLimit stores the user's value
	// plus one; a negative value disables limiting.
	limit xsync.Int64
	// n is the remaining budget for the current message, copied from limit
	// by reset.
	n     int64
}
   469  
   470  func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
   471  	lr := &limitReader{
   472  		c: c,
   473  	}
   474  	lr.limit.Store(limit)
   475  	lr.reset(r)
   476  	return lr
   477  }
   478  
   479  func (lr *limitReader) reset(r io.Reader) {
   480  	lr.n = lr.limit.Load()
   481  	lr.r = r
   482  }
   483  
   484  func (lr *limitReader) Read(p []byte) (int, error) {
   485  	if lr.n < 0 {
   486  		return lr.r.Read(p)
   487  	}
   488  
   489  	if lr.n == 0 {
   490  		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
   491  		lr.c.writeError(StatusMessageTooBig, err)
   492  		return 0, err
   493  	}
   494  
   495  	if int64(len(p)) > lr.n {
   496  		p = p[:lr.n]
   497  	}
   498  	n, err := lr.r.Read(p)
   499  	lr.n -= int64(n)
   500  	if lr.n < 0 {
   501  		lr.n = 0
   502  	}
   503  	return n, err
   504  }
   505  

View as plain text