...

Source file src/nhooyr.io/websocket/write.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"crypto/rand"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"time"
    16  
    17  	"compress/flate"
    18  
    19  	"nhooyr.io/websocket/internal/errd"
    20  	"nhooyr.io/websocket/internal/util"
    21  )
    22  
    23  // Writer returns a writer bounded by the context that will write
    24  // a WebSocket message of type dataType to the connection.
    25  //
    26  // You must close the writer once you have written the entire message.
    27  //
    28  // Only one writer can be open at a time, multiple calls will block until the previous writer
    29  // is closed.
    30  func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
    31  	w, err := c.writer(ctx, typ)
    32  	if err != nil {
    33  		return nil, fmt.Errorf("failed to get writer: %w", err)
    34  	}
    35  	return w, nil
    36  }
    37  
    38  // Write writes a message to the connection.
    39  //
    40  // See the Writer method if you want to stream a message.
    41  //
    42  // If compression is disabled or the compression threshold is not met, then it
    43  // will write the message in a single frame.
    44  func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
    45  	_, err := c.write(ctx, typ, p)
    46  	if err != nil {
    47  		return fmt.Errorf("failed to write msg: %w", err)
    48  	}
    49  	return nil
    50  }
    51  
// msgWriter writes one WebSocket message at a time to its Conn, frame by
// frame, optionally compressing the payload with flate. The same msgWriter
// is handed out for every message on the connection (Conn.writer resets it
// and returns c.msgWriter).
type msgWriter struct {
	c *Conn

	// mu is held for the whole lifetime of a message: it is acquired in
	// reset and released by a successful Close (or deferred-unlocked by
	// Conn.write's single-frame path).
	// writeMu serializes Write and Close on the current message; close
	// force-locks it.
	// closed is set by Close; Write refuses to run once it is true.
	mu      *mu
	writeMu *mu
	closed  bool

	// ctx bounds every frame write of the current message.
	// opcode is the opcode of the next frame to write; it is switched to
	// opContinuation after the first data frame is written.
	// flate reports whether the current message is being compressed; it is
	// cleared in reset and set by ensureFlate.
	ctx    context.Context
	opcode opcode
	flate  bool

	// trimWriter wraps mw.write and is the output of flateWriter. Both are
	// allocated lazily by ensureFlate; flateWriter is released again by
	// putFlateWriter.
	trimWriter  *trimLastFourBytesWriter
	flateWriter *flate.Writer
}
    66  
    67  func newMsgWriter(c *Conn) *msgWriter {
    68  	mw := &msgWriter{
    69  		c:       c,
    70  		mu:      newMu(c),
    71  		writeMu: newMu(c),
    72  	}
    73  	return mw
    74  }
    75  
    76  func (mw *msgWriter) ensureFlate() {
    77  	if mw.trimWriter == nil {
    78  		mw.trimWriter = &trimLastFourBytesWriter{
    79  			w: util.WriterFunc(mw.write),
    80  		}
    81  	}
    82  
    83  	if mw.flateWriter == nil {
    84  		mw.flateWriter = getFlateWriter(mw.trimWriter)
    85  	}
    86  	mw.flate = true
    87  }
    88  
    89  func (mw *msgWriter) flateContextTakeover() bool {
    90  	if mw.c.client {
    91  		return !mw.c.copts.clientNoContextTakeover
    92  	}
    93  	return !mw.c.copts.serverNoContextTakeover
    94  }
    95  
    96  func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
    97  	err := c.msgWriter.reset(ctx, typ)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return c.msgWriter, nil
   102  }
   103  
   104  func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
   105  	mw, err := c.writer(ctx, typ)
   106  	if err != nil {
   107  		return 0, err
   108  	}
   109  
   110  	if !c.flate() {
   111  		defer c.msgWriter.mu.unlock()
   112  		return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
   113  	}
   114  
   115  	n, err := mw.Write(p)
   116  	if err != nil {
   117  		return n, err
   118  	}
   119  
   120  	err = mw.Close()
   121  	return n, err
   122  }
   123  
   124  func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
   125  	err := mw.mu.lock(ctx)
   126  	if err != nil {
   127  		return err
   128  	}
   129  
   130  	mw.ctx = ctx
   131  	mw.opcode = opcode(typ)
   132  	mw.flate = false
   133  	mw.closed = false
   134  
   135  	mw.trimWriter.reset()
   136  
   137  	return nil
   138  }
   139  
   140  func (mw *msgWriter) putFlateWriter() {
   141  	if mw.flateWriter != nil {
   142  		putFlateWriter(mw.flateWriter)
   143  		mw.flateWriter = nil
   144  	}
   145  }
   146  
   147  // Write writes the given bytes to the WebSocket connection.
   148  func (mw *msgWriter) Write(p []byte) (_ int, err error) {
   149  	err = mw.writeMu.lock(mw.ctx)
   150  	if err != nil {
   151  		return 0, fmt.Errorf("failed to write: %w", err)
   152  	}
   153  	defer mw.writeMu.unlock()
   154  
   155  	if mw.closed {
   156  		return 0, errors.New("cannot use closed writer")
   157  	}
   158  
   159  	defer func() {
   160  		if err != nil {
   161  			err = fmt.Errorf("failed to write: %w", err)
   162  			mw.c.close(err)
   163  		}
   164  	}()
   165  
   166  	if mw.c.flate() {
   167  		// Only enables flate if the length crosses the
   168  		// threshold on the first frame
   169  		if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
   170  			mw.ensureFlate()
   171  		}
   172  	}
   173  
   174  	if mw.flate {
   175  		return mw.flateWriter.Write(p)
   176  	}
   177  
   178  	return mw.write(p)
   179  }
   180  
   181  func (mw *msgWriter) write(p []byte) (int, error) {
   182  	n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
   183  	if err != nil {
   184  		return n, fmt.Errorf("failed to write data frame: %w", err)
   185  	}
   186  	mw.opcode = opContinuation
   187  	return n, nil
   188  }
   189  
   190  // Close flushes the frame to the connection.
   191  func (mw *msgWriter) Close() (err error) {
   192  	defer errd.Wrap(&err, "failed to close writer")
   193  
   194  	err = mw.writeMu.lock(mw.ctx)
   195  	if err != nil {
   196  		return err
   197  	}
   198  	defer mw.writeMu.unlock()
   199  
   200  	if mw.closed {
   201  		return errors.New("writer already closed")
   202  	}
   203  	mw.closed = true
   204  
   205  	if mw.flate {
   206  		err = mw.flateWriter.Flush()
   207  		if err != nil {
   208  			return fmt.Errorf("failed to flush flate: %w", err)
   209  		}
   210  	}
   211  
   212  	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
   213  	if err != nil {
   214  		return fmt.Errorf("failed to write fin frame: %w", err)
   215  	}
   216  
   217  	if mw.flate && !mw.flateContextTakeover() {
   218  		mw.putFlateWriter()
   219  	}
   220  	mw.mu.unlock()
   221  	return nil
   222  }
   223  
   224  func (mw *msgWriter) close() {
   225  	if mw.c.client {
   226  		mw.c.writeFrameMu.forceLock()
   227  		putBufioWriter(mw.c.bw)
   228  	}
   229  
   230  	mw.writeMu.forceLock()
   231  	mw.putFlateWriter()
   232  }
   233  
   234  func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
   235  	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
   236  	defer cancel()
   237  
   238  	_, err := c.writeFrame(ctx, true, false, opcode, p)
   239  	if err != nil {
   240  		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
   241  	}
   242  	return nil
   243  }
   244  
// writeFrame writes a single WebSocket frame (header plus payload p) to the
// connection and returns the number of payload bytes written. All frame
// writes go through it and are serialized by c.writeFrameMu.
//
// fin sets the frame's FIN bit. flate requests RSV1 (the permessage-deflate
// "compressed" bit), which is only applied to text and binary frames. On
// client connections the frame is masked with a freshly generated random key.
// The buffered writer is flushed only when fin is set.
//
// On any error after the write-timeout hand-off, the connection is closed and
// the error is wrapped. If the connection was already closed or ctx is done,
// that cause (net.ErrClosed or ctx.Err()) replaces the original error first.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
	err = c.writeFrameMu.lock(ctx)
	if err != nil {
		return 0, err
	}

	// If the state says a close has already been written, we wait until
	// the connection is closed and return that error.
	//
	// However, if the frame being written is a close, that means it's the close from
	// the state being set so we let it go through.
	c.closeMu.Lock()
	wroteClose := c.wroteClose
	c.closeMu.Unlock()
	if wroteClose && opcode != opClose {
		// Release the frame lock before blocking so it is not held while we
		// wait for the connection to finish closing.
		c.writeFrameMu.unlock()
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-c.closed:
			return 0, net.ErrClosed
		}
	}
	defer c.writeFrameMu.unlock()

	// Hand ctx to whatever consumes c.writeTimeout. Presumably that code
	// applies ctx's deadline/cancellation to the underlying write.
	// NOTE(review): the consumer is defined elsewhere; confirm.
	select {
	case <-c.closed:
		return 0, net.ErrClosed
	case c.writeTimeout <- ctx:
	}

	// Any failure from here on closes the connection. Prefer the
	// close/cancellation cause when one is observable.
	defer func() {
		if err != nil {
			select {
			case <-c.closed:
				err = net.ErrClosed
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}
			c.close(err)
			err = fmt.Errorf("failed to write frame: %w", err)
		}
	}()

	c.writeHeader.fin = fin
	c.writeHeader.opcode = opcode
	c.writeHeader.payloadLength = int64(len(p))

	if c.client {
		// Frames sent by a client are masked with a new random key per frame.
		c.writeHeader.masked = true
		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
		if err != nil {
			return 0, fmt.Errorf("failed to generate masking key: %w", err)
		}
		// Uint32 reads only the first 4 bytes, i.e. the random bytes just read.
		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
	}

	// RSV1 marks a compressed message. Continuation and control frames never
	// carry it, so it is only set for text/binary opcodes.
	c.writeHeader.rsv1 = false
	if flate && (opcode == opText || opcode == opBinary) {
		c.writeHeader.rsv1 = true
	}

	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
	if err != nil {
		return 0, err
	}

	n, err := c.writeFramePayload(p)
	if err != nil {
		return n, err
	}

	// Frames without FIN (intermediate data frames) stay buffered. Final
	// frames, including control frames sent by writeControl, are flushed.
	if c.writeHeader.fin {
		err = c.bw.Flush()
		if err != nil {
			return n, fmt.Errorf("failed to flush: %w", err)
		}
	}

	// Hand back a context that never expires, ending the bound on this write.
	// If the connection closed meanwhile, report net.ErrClosed. The exception
	// is the close frame itself, whose write is the expected final act.
	select {
	case <-c.closed:
		if opcode == opClose {
			return n, nil
		}
		return n, net.ErrClosed
	case c.writeTimeout <- context.Background():
	}

	return n, nil
}
   337  
   338  func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
   339  	defer errd.Wrap(&err, "failed to write frame payload")
   340  
   341  	if !c.writeHeader.masked {
   342  		return c.bw.Write(p)
   343  	}
   344  
   345  	maskKey := c.writeHeader.maskKey
   346  	for len(p) > 0 {
   347  		// If the buffer is full, we need to flush.
   348  		if c.bw.Available() == 0 {
   349  			err = c.bw.Flush()
   350  			if err != nil {
   351  				return n, err
   352  			}
   353  		}
   354  
   355  		// Start of next write in the buffer.
   356  		i := c.bw.Buffered()
   357  
   358  		j := len(p)
   359  		if j > c.bw.Available() {
   360  			j = c.bw.Available()
   361  		}
   362  
   363  		_, err := c.bw.Write(p[:j])
   364  		if err != nil {
   365  			return n, err
   366  		}
   367  
   368  		maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
   369  
   370  		p = p[j:]
   371  		n += j
   372  	}
   373  
   374  	return n, nil
   375  }
   376  
   377  // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
   378  // and returns it.
   379  func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
   380  	var writeBuf []byte
   381  	bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
   382  		writeBuf = p2[:cap(p2)]
   383  		return len(p2), nil
   384  	}))
   385  
   386  	bw.WriteByte(0)
   387  	bw.Flush()
   388  
   389  	bw.Reset(w)
   390  
   391  	return writeBuf
   392  }
   393  
   394  func (c *Conn) writeError(code StatusCode, err error) {
   395  	c.setCloseErr(err)
   396  	c.writeClose(code, err.Error())
   397  	c.close(nil)
   398  }
   399  

View as plain text