...

Source file src/nhooyr.io/websocket/close.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket
     5  
     6  import (
     7  	"context"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"time"
    13  
    14  	"nhooyr.io/websocket/internal/errd"
    15  )
    16  
    17  // StatusCode represents a WebSocket status code.
    18  // https://tools.ietf.org/html/rfc6455#section-7.4
    19  type StatusCode int
    20  
    21  // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
    22  //
    23  // These are only the status codes defined by the protocol.
    24  //
    25  // You can define custom codes in the 3000-4999 range.
    26  // The 3000-3999 range is reserved for use by libraries, frameworks and applications.
    27  // The 4000-4999 range is reserved for private use.
    28  const (
    29  	StatusNormalClosure   StatusCode = 1000
    30  	StatusGoingAway       StatusCode = 1001
    31  	StatusProtocolError   StatusCode = 1002
    32  	StatusUnsupportedData StatusCode = 1003
    33  
    34  	// 1004 is reserved and so unexported.
    35  	statusReserved StatusCode = 1004
    36  
    37  	// StatusNoStatusRcvd cannot be sent in a close message.
    38  	// It is reserved for when a close message is received without
    39  	// a status code.
    40  	StatusNoStatusRcvd StatusCode = 1005
    41  
    42  	// StatusAbnormalClosure is exported for use only with Wasm.
    43  	// In non Wasm Go, the returned error will indicate whether the
    44  	// connection was closed abnormally.
    45  	StatusAbnormalClosure StatusCode = 1006
    46  
    47  	StatusInvalidFramePayloadData StatusCode = 1007
    48  	StatusPolicyViolation         StatusCode = 1008
    49  	StatusMessageTooBig           StatusCode = 1009
    50  	StatusMandatoryExtension      StatusCode = 1010
    51  	StatusInternalError           StatusCode = 1011
    52  	StatusServiceRestart          StatusCode = 1012
    53  	StatusTryAgainLater           StatusCode = 1013
    54  	StatusBadGateway              StatusCode = 1014
    55  
    56  	// StatusTLSHandshake is only exported for use with Wasm.
    57  	// In non Wasm Go, the returned error will indicate whether there was
    58  	// a TLS handshake failure.
    59  	StatusTLSHandshake StatusCode = 1015
    60  )
    61  
    62  // CloseError is returned when the connection is closed with a status and reason.
    63  //
    64  // Use Go 1.13's errors.As to check for this error.
    65  // Also see the CloseStatus helper.
    66  type CloseError struct {
    67  	Code   StatusCode
    68  	Reason string
    69  }
    70  
    71  func (ce CloseError) Error() string {
    72  	return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
    73  }
    74  
    75  // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
    76  // the status code from a CloseError.
    77  //
    78  // -1 will be returned if the passed error is nil or not a CloseError.
    79  func CloseStatus(err error) StatusCode {
    80  	var ce CloseError
    81  	if errors.As(err, &ce) {
    82  		return ce.Code
    83  	}
    84  	return -1
    85  }
    86  
    87  // Close performs the WebSocket close handshake with the given status code and reason.
    88  //
    89  // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
    90  // the peer to send a close frame.
    91  // All data messages received from the peer during the close handshake will be discarded.
    92  //
    93  // The connection can only be closed once. Additional calls to Close
    94  // are no-ops.
    95  //
    96  // The maximum length of reason must be 125 bytes. Avoid
    97  // sending a dynamic reason.
    98  //
    99  // Close will unblock all goroutines interacting with the connection once
   100  // complete.
   101  func (c *Conn) Close(code StatusCode, reason string) error {
   102  	defer c.wg.Wait()
   103  	return c.closeHandshake(code, reason)
   104  }
   105  
   106  // CloseNow closes the WebSocket connection without attempting a close handshake.
   107  // Use when you do not want the overhead of the close handshake.
   108  func (c *Conn) CloseNow() (err error) {
   109  	defer c.wg.Wait()
   110  	defer errd.Wrap(&err, "failed to close WebSocket")
   111  
   112  	if c.isClosed() {
   113  		return net.ErrClosed
   114  	}
   115  
   116  	c.close(nil)
   117  	return c.closeErr
   118  }
   119  
   120  func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
   121  	defer errd.Wrap(&err, "failed to close WebSocket")
   122  
   123  	writeErr := c.writeClose(code, reason)
   124  	closeHandshakeErr := c.waitCloseHandshake()
   125  
   126  	if writeErr != nil {
   127  		return writeErr
   128  	}
   129  
   130  	if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
   131  		return closeHandshakeErr
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  func (c *Conn) writeClose(code StatusCode, reason string) error {
   138  	c.closeMu.Lock()
   139  	wroteClose := c.wroteClose
   140  	c.wroteClose = true
   141  	c.closeMu.Unlock()
   142  	if wroteClose {
   143  		return net.ErrClosed
   144  	}
   145  
   146  	ce := CloseError{
   147  		Code:   code,
   148  		Reason: reason,
   149  	}
   150  
   151  	var p []byte
   152  	var marshalErr error
   153  	if ce.Code != StatusNoStatusRcvd {
   154  		p, marshalErr = ce.bytes()
   155  	}
   156  
   157  	writeErr := c.writeControl(context.Background(), opClose, p)
   158  	if CloseStatus(writeErr) != -1 {
   159  		// Not a real error if it's due to a close frame being received.
   160  		writeErr = nil
   161  	}
   162  
   163  	// We do this after in case there was an error writing the close frame.
   164  	c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
   165  
   166  	if marshalErr != nil {
   167  		return marshalErr
   168  	}
   169  	return writeErr
   170  }
   171  
   172  func (c *Conn) waitCloseHandshake() error {
   173  	defer c.close(nil)
   174  
   175  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   176  	defer cancel()
   177  
   178  	err := c.readMu.lock(ctx)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	defer c.readMu.unlock()
   183  
   184  	if c.readCloseFrameErr != nil {
   185  		return c.readCloseFrameErr
   186  	}
   187  
   188  	for i := int64(0); i < c.msgReader.payloadLength; i++ {
   189  		_, err := c.br.ReadByte()
   190  		if err != nil {
   191  			return err
   192  		}
   193  	}
   194  
   195  	for {
   196  		h, err := c.readLoop(ctx)
   197  		if err != nil {
   198  			return err
   199  		}
   200  
   201  		for i := int64(0); i < h.payloadLength; i++ {
   202  			_, err := c.br.ReadByte()
   203  			if err != nil {
   204  				return err
   205  			}
   206  		}
   207  	}
   208  }
   209  
   210  func parseClosePayload(p []byte) (CloseError, error) {
   211  	if len(p) == 0 {
   212  		return CloseError{
   213  			Code: StatusNoStatusRcvd,
   214  		}, nil
   215  	}
   216  
   217  	if len(p) < 2 {
   218  		return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
   219  	}
   220  
   221  	ce := CloseError{
   222  		Code:   StatusCode(binary.BigEndian.Uint16(p)),
   223  		Reason: string(p[2:]),
   224  	}
   225  
   226  	if !validWireCloseCode(ce.Code) {
   227  		return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
   228  	}
   229  
   230  	return ce, nil
   231  }
   232  
   233  // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
   234  // and https://tools.ietf.org/html/rfc6455#section-7.4.1
   235  func validWireCloseCode(code StatusCode) bool {
   236  	switch code {
   237  	case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
   238  		return false
   239  	}
   240  
   241  	if code >= StatusNormalClosure && code <= StatusBadGateway {
   242  		return true
   243  	}
   244  	if code >= 3000 && code <= 4999 {
   245  		return true
   246  	}
   247  
   248  	return false
   249  }
   250  
   251  func (ce CloseError) bytes() ([]byte, error) {
   252  	p, err := ce.bytesErr()
   253  	if err != nil {
   254  		err = fmt.Errorf("failed to marshal close frame: %w", err)
   255  		ce = CloseError{
   256  			Code: StatusInternalError,
   257  		}
   258  		p, _ = ce.bytesErr()
   259  	}
   260  	return p, err
   261  }
   262  
   263  const maxCloseReason = maxControlPayload - 2
   264  
   265  func (ce CloseError) bytesErr() ([]byte, error) {
   266  	if len(ce.Reason) > maxCloseReason {
   267  		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
   268  	}
   269  
   270  	if !validWireCloseCode(ce.Code) {
   271  		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
   272  	}
   273  
   274  	buf := make([]byte, 2+len(ce.Reason))
   275  	binary.BigEndian.PutUint16(buf, uint16(ce.Code))
   276  	copy(buf[2:], ce.Reason)
   277  	return buf, nil
   278  }
   279  
   280  func (c *Conn) setCloseErr(err error) {
   281  	c.closeMu.Lock()
   282  	c.setCloseErrLocked(err)
   283  	c.closeMu.Unlock()
   284  }
   285  
   286  func (c *Conn) setCloseErrLocked(err error) {
   287  	if c.closeErr == nil && err != nil {
   288  		c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
   289  	}
   290  }
   291  
   292  func (c *Conn) isClosed() bool {
   293  	select {
   294  	case <-c.closed:
   295  		return true
   296  	default:
   297  		return false
   298  	}
   299  }
   300  

View as plain text