...

Source file src/nhooyr.io/websocket/conn.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"runtime"
    14  	"strconv"
    15  	"sync"
    16  	"sync/atomic"
    17  )
    18  
    19  // MessageType represents the type of a WebSocket message.
    20  // See https://tools.ietf.org/html/rfc6455#section-5.6
    21  type MessageType int
    22  
    23  // MessageType constants.
    24  const (
    25  	// MessageText is for UTF-8 encoded text messages like JSON.
    26  	MessageText MessageType = iota + 1
    27  	// MessageBinary is for binary messages like protobufs.
    28  	MessageBinary
    29  )
    30  
    31  // Conn represents a WebSocket connection.
    32  // All methods may be called concurrently except for Reader and Read.
    33  //
    34  // You must always read from the connection. Otherwise control
    35  // frames will not be handled. See Reader and CloseRead.
    36  //
    37  // Be sure to call Close on the connection when you
    38  // are finished with it to release associated resources.
    39  //
    40  // On any error from any method, the connection is closed
    41  // with an appropriate reason.
    42  //
    43  // This applies to context expirations as well unfortunately.
    44  // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
    45  type Conn struct {
    46  	noCopy noCopy
    47  
    48  	subprotocol    string
    49  	rwc            io.ReadWriteCloser
    50  	client         bool
    51  	copts          *compressionOptions
    52  	flateThreshold int
    53  	br             *bufio.Reader
    54  	bw             *bufio.Writer
    55  
    56  	readTimeout  chan context.Context
    57  	writeTimeout chan context.Context
    58  
    59  	// Read state.
    60  	readMu            *mu
    61  	readHeaderBuf     [8]byte
    62  	readControlBuf    [maxControlPayload]byte
    63  	msgReader         *msgReader
    64  	readCloseFrameErr error
    65  
    66  	// Write state.
    67  	msgWriter      *msgWriter
    68  	writeFrameMu   *mu
    69  	writeBuf       []byte
    70  	writeHeaderBuf [8]byte
    71  	writeHeader    header
    72  
    73  	wg         sync.WaitGroup
    74  	closed     chan struct{}
    75  	closeMu    sync.Mutex
    76  	closeErr   error
    77  	wroteClose bool
    78  
    79  	pingCounter   int32
    80  	activePingsMu sync.Mutex
    81  	activePings   map[string]chan<- struct{}
    82  }
    83  
    84  type connConfig struct {
    85  	subprotocol    string
    86  	rwc            io.ReadWriteCloser
    87  	client         bool
    88  	copts          *compressionOptions
    89  	flateThreshold int
    90  
    91  	br *bufio.Reader
    92  	bw *bufio.Writer
    93  }
    94  
    95  func newConn(cfg connConfig) *Conn {
    96  	c := &Conn{
    97  		subprotocol:    cfg.subprotocol,
    98  		rwc:            cfg.rwc,
    99  		client:         cfg.client,
   100  		copts:          cfg.copts,
   101  		flateThreshold: cfg.flateThreshold,
   102  
   103  		br: cfg.br,
   104  		bw: cfg.bw,
   105  
   106  		readTimeout:  make(chan context.Context),
   107  		writeTimeout: make(chan context.Context),
   108  
   109  		closed:      make(chan struct{}),
   110  		activePings: make(map[string]chan<- struct{}),
   111  	}
   112  
   113  	c.readMu = newMu(c)
   114  	c.writeFrameMu = newMu(c)
   115  
   116  	c.msgReader = newMsgReader(c)
   117  
   118  	c.msgWriter = newMsgWriter(c)
   119  	if c.client {
   120  		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
   121  	}
   122  
   123  	if c.flate() && c.flateThreshold == 0 {
   124  		c.flateThreshold = 128
   125  		if !c.msgWriter.flateContextTakeover() {
   126  			c.flateThreshold = 512
   127  		}
   128  	}
   129  
   130  	runtime.SetFinalizer(c, func(c *Conn) {
   131  		c.close(errors.New("connection garbage collected"))
   132  	})
   133  
   134  	c.wg.Add(1)
   135  	go func() {
   136  		defer c.wg.Done()
   137  		c.timeoutLoop()
   138  	}()
   139  
   140  	return c
   141  }
   142  
   143  // Subprotocol returns the negotiated subprotocol.
   144  // An empty string means the default protocol.
   145  func (c *Conn) Subprotocol() string {
   146  	return c.subprotocol
   147  }
   148  
   149  func (c *Conn) close(err error) {
   150  	c.closeMu.Lock()
   151  	defer c.closeMu.Unlock()
   152  
   153  	if c.isClosed() {
   154  		return
   155  	}
   156  	if err == nil {
   157  		err = c.rwc.Close()
   158  	}
   159  	c.setCloseErrLocked(err)
   160  
   161  	close(c.closed)
   162  	runtime.SetFinalizer(c, nil)
   163  
   164  	// Have to close after c.closed is closed to ensure any goroutine that wakes up
   165  	// from the connection being closed also sees that c.closed is closed and returns
   166  	// closeErr.
   167  	c.rwc.Close()
   168  
   169  	c.wg.Add(1)
   170  	go func() {
   171  		defer c.wg.Done()
   172  		c.msgWriter.close()
   173  		c.msgReader.close()
   174  	}()
   175  }
   176  
   177  func (c *Conn) timeoutLoop() {
   178  	readCtx := context.Background()
   179  	writeCtx := context.Background()
   180  
   181  	for {
   182  		select {
   183  		case <-c.closed:
   184  			return
   185  
   186  		case writeCtx = <-c.writeTimeout:
   187  		case readCtx = <-c.readTimeout:
   188  
   189  		case <-readCtx.Done():
   190  			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
   191  			c.wg.Add(1)
   192  			go func() {
   193  				defer c.wg.Done()
   194  				c.writeError(StatusPolicyViolation, errors.New("read timed out"))
   195  			}()
   196  		case <-writeCtx.Done():
   197  			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
   198  			return
   199  		}
   200  	}
   201  }
   202  
   203  func (c *Conn) flate() bool {
   204  	return c.copts != nil
   205  }
   206  
   207  // Ping sends a ping to the peer and waits for a pong.
   208  // Use this to measure latency or ensure the peer is responsive.
   209  // Ping must be called concurrently with Reader as it does
   210  // not read from the connection but instead waits for a Reader call
   211  // to read the pong.
   212  //
   213  // TCP Keepalives should suffice for most use cases.
   214  func (c *Conn) Ping(ctx context.Context) error {
   215  	p := atomic.AddInt32(&c.pingCounter, 1)
   216  
   217  	err := c.ping(ctx, strconv.Itoa(int(p)))
   218  	if err != nil {
   219  		return fmt.Errorf("failed to ping: %w", err)
   220  	}
   221  	return nil
   222  }
   223  
   224  func (c *Conn) ping(ctx context.Context, p string) error {
   225  	pong := make(chan struct{}, 1)
   226  
   227  	c.activePingsMu.Lock()
   228  	c.activePings[p] = pong
   229  	c.activePingsMu.Unlock()
   230  
   231  	defer func() {
   232  		c.activePingsMu.Lock()
   233  		delete(c.activePings, p)
   234  		c.activePingsMu.Unlock()
   235  	}()
   236  
   237  	err := c.writeControl(ctx, opPing, []byte(p))
   238  	if err != nil {
   239  		return err
   240  	}
   241  
   242  	select {
   243  	case <-c.closed:
   244  		return net.ErrClosed
   245  	case <-ctx.Done():
   246  		err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
   247  		c.close(err)
   248  		return err
   249  	case <-pong:
   250  		return nil
   251  	}
   252  }
   253  
   254  type mu struct {
   255  	c  *Conn
   256  	ch chan struct{}
   257  }
   258  
   259  func newMu(c *Conn) *mu {
   260  	return &mu{
   261  		c:  c,
   262  		ch: make(chan struct{}, 1),
   263  	}
   264  }
   265  
   266  func (m *mu) forceLock() {
   267  	m.ch <- struct{}{}
   268  }
   269  
   270  func (m *mu) tryLock() bool {
   271  	select {
   272  	case m.ch <- struct{}{}:
   273  		return true
   274  	default:
   275  		return false
   276  	}
   277  }
   278  
   279  func (m *mu) lock(ctx context.Context) error {
   280  	select {
   281  	case <-m.c.closed:
   282  		return net.ErrClosed
   283  	case <-ctx.Done():
   284  		err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
   285  		m.c.close(err)
   286  		return err
   287  	case m.ch <- struct{}{}:
   288  		// To make sure the connection is certainly alive.
   289  		// As it's possible the send on m.ch was selected
   290  		// over the receive on closed.
   291  		select {
   292  		case <-m.c.closed:
   293  			// Make sure to release.
   294  			m.unlock()
   295  			return net.ErrClosed
   296  		default:
   297  		}
   298  		return nil
   299  	}
   300  }
   301  
   302  func (m *mu) unlock() {
   303  	select {
   304  	case <-m.ch:
   305  	default:
   306  	}
   307  }
   308  
   309  type noCopy struct{}
   310  
   311  func (*noCopy) Lock() {}
   312  

View as plain text