
Source file src/nhooyr.io/websocket/compress.go

Documentation: nhooyr.io/websocket

//go:build !js
// +build !js

package websocket

import (
	"compress/flate"
	"io"
	"sync"
)

// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all modern browsers except Safari, which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int

const (
	// CompressionDisabled disables the negotiation of the permessage-deflate extension.
	//
	// This is the default. Do not enable compression without benchmarking for your particular use case first.
	CompressionDisabled CompressionMode = iota

	// CompressionContextTakeover compresses each message greater than 128 bytes, reusing the 32 KB sliding window from
	// previous messages; i.e., the compression context is preserved across messages.
	//
	// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
	//
	// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer, and a sync.Pool of 40 KB flate.Readers
	// that are used when reading and then returned.
	//
	// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
	//
	// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
	CompressionContextTakeover

	// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
	// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
	// a sync.Pool.
	//
	// This means less efficient compression, as the sliding window from previous messages will not be used, but the
	// memory overhead will be lower, as there is no fixed cost for the flate.Writer nor the 32 KB sliding window.
	// This is especially beneficial for connections that are long lived and seldom written to.
	//
	// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
	//
	// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
	CompressionNoContextTakeover
)
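
// A minimal usage sketch, assuming the CompressionMode fields on this
// package's DialOptions and AcceptOptions:
//
//	c, _, err := websocket.Dial(ctx, "wss://example.com", &websocket.DialOptions{
//		CompressionMode: websocket.CompressionContextTakeover,
//	})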

func (m CompressionMode) opts() *compressionOptions {
	return &compressionOptions{
		clientNoContextTakeover: m == CompressionNoContextTakeover,
		serverNoContextTakeover: m == CompressionNoContextTakeover,
	}
}

type compressionOptions struct {
	clientNoContextTakeover bool
	serverNoContextTakeover bool
}

func (copts *compressionOptions) String() string {
	s := "permessage-deflate"
	if copts.clientNoContextTakeover {
		s += "; client_no_context_takeover"
	}
	if copts.serverNoContextTakeover {
		s += "; server_no_context_takeover"
	}
	return s
}
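
// For example, with both flags set, String yields the value offered in the
// Sec-WebSocket-Extensions header during negotiation (see RFC 7692):
//
//	permessage-deflate; client_no_context_takeover; server_no_context_takeover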

// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead, as
// WebSocket framing tells us when the message has ended, but
// we need to add them back when reading, otherwise flate.Reader keeps
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
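
// A short sketch of why this works: flate.Writer.Flush performs a sync
// flush, which always ends the output with exactly these four bytes, so
// the sender can strip them and the reader can append them back before
// decompressing:
//
//	var buf bytes.Buffer
//	fw, _ := flate.NewWriter(&buf, flate.BestSpeed)
//	fw.Write([]byte("hello"))
//	fw.Flush()
//	compressed := buf.Bytes()[:buf.Len()-4] // drop the trailing 00 00 ff ff
//
//	fr := flate.NewReader(io.MultiReader(
//		bytes.NewReader(compressed),
//		strings.NewReader(deflateMessageTail),
//	))
//	msg := make([]byte, 5)
//	io.ReadFull(fr, msg) // msg == []byte("hello")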

type trimLastFourBytesWriter struct {
	w    io.Writer
	tail []byte
}

func (tw *trimLastFourBytesWriter) reset() {
	if tw != nil && tw.tail != nil {
		tw.tail = tw.tail[:0]
	}
}

func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
	if tw.tail == nil {
		tw.tail = make([]byte, 0, 4)
	}

	extra := len(tw.tail) + len(p) - 4

	if extra <= 0 {
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Write as many of the extra bytes as we can from the buffered tail.
	if extra > len(tw.tail) {
		extra = len(tw.tail)
	}
	if extra > 0 {
		_, err := tw.w.Write(tw.tail[:extra])
		if err != nil {
			return 0, err
		}

		// Shift remaining bytes in tail over.
		n := copy(tw.tail, tw.tail[extra:])
		tw.tail = tw.tail[:n]
	}

	// If p is less than or equal to 4 bytes,
	// all of it is part of the tail.
	if len(p) <= 4 {
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Otherwise, only the last 4 bytes are.
	tw.tail = append(tw.tail, p[len(p)-4:]...)

	p = p[:len(p)-4]
	n, err := tw.w.Write(p)
	// The 4 bytes buffered in the tail count as consumed too.
	return n + 4, err
}
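
// Placed beneath a flate.Writer, this writer holds back the last four
// bytes seen so far, so that when the sync-flush marker (00 00 ff ff)
// arrives last it is buffered rather than sent. A sketch of the intended
// composition (frameWriter is a placeholder for the connection's frame
// writer):
//
//	tw := &trimLastFourBytesWriter{w: frameWriter}
//	fw := getFlateWriter(tw)
//	fw.Write(msg)
//	fw.Flush() // the trailing 00 00 ff ff stays in tw.tail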

var flateReaderPool sync.Pool

// getFlateReader returns a flate.Reader from the pool, primed with dict,
// creating a new one if the pool is empty.
func getFlateReader(r io.Reader, dict []byte) io.Reader {
	fr, ok := flateReaderPool.Get().(io.Reader)
	if !ok {
		return flate.NewReaderDict(r, dict)
	}
	// Readers from flate.NewReaderDict implement flate.Resetter;
	// in the standard library their Reset always returns nil.
	fr.(flate.Resetter).Reset(r, dict)
	return fr
}

func putFlateReader(fr io.Reader) {
	flateReaderPool.Put(fr)
}

var flateWriterPool sync.Pool

// getFlateWriter returns a flate.Writer from the pool, reset to w,
// creating a new one if the pool is empty.
func getFlateWriter(w io.Writer) *flate.Writer {
	fw, ok := flateWriterPool.Get().(*flate.Writer)
	if !ok {
		// NewWriter only errors on an invalid level; flate.BestSpeed is valid.
		fw, _ = flate.NewWriter(w, flate.BestSpeed)
		return fw
	}
	fw.Reset(w)
	return fw
}

func putFlateWriter(w *flate.Writer) {
	flateWriterPool.Put(w)
}
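
// A sketch of the read path through these pools (payload and dict are
// placeholders; dict is the sliding window from prior messages, or nil
// when context takeover is off):
//
//	fr := getFlateReader(io.MultiReader(payload, strings.NewReader(deflateMessageTail)), dict)
//	// ... read the decompressed message from fr ...
//	putFlateReader(fr)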

type slidingWindow struct {
	buf []byte
}

var swPoolMu sync.RWMutex
var swPool = map[int]*sync.Pool{}

func slidingWindowPool(n int) *sync.Pool {
	swPoolMu.RLock()
	p, ok := swPool[n]
	swPoolMu.RUnlock()
	if ok {
		return p
	}

	p = &sync.Pool{}

	swPoolMu.Lock()
	swPool[n] = p
	swPoolMu.Unlock()

	return p
}

func (sw *slidingWindow) init(n int) {
	if sw.buf != nil {
		return
	}

	if n == 0 {
		n = 32768
	}

	p := slidingWindowPool(n)
	sw2, ok := p.Get().(*slidingWindow)
	if ok {
		*sw = *sw2
	} else {
		sw.buf = make([]byte, 0, n)
	}
}

func (sw *slidingWindow) close() {
	sw.buf = sw.buf[:0]
	swPoolMu.Lock()
	swPool[cap(sw.buf)].Put(sw)
	swPoolMu.Unlock()
}

func (sw *slidingWindow) write(p []byte) {
	if len(p) >= cap(sw.buf) {
		sw.buf = sw.buf[:cap(sw.buf)]
		p = p[len(p)-cap(sw.buf):]
		copy(sw.buf, p)
		return
	}

	left := cap(sw.buf) - len(sw.buf)
	if left < len(p) {
		// Shift the buffer left by spaceNeeded bytes, dropping the oldest
		// bytes, to make room for p at the end.
		spaceNeeded := len(p) - left
		copy(sw.buf, sw.buf[spaceNeeded:])
		sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
	}

	sw.buf = append(sw.buf, p...)
}
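
// A worked example of the semantics, assuming a 4-byte window for brevity:
// after write([]byte("ab")) the buffer holds "ab"; after write([]byte("cde"))
// the oldest byte is shifted out, leaving "bcde"; after write([]byte("wxyz"))
// or anything longer, only the last 4 bytes ("wxyz") are kept. The window
// therefore always holds the most recent cap(buf) bytes written, which is
// exactly the dictionary flate needs for context takeover.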
