...

Source file src/nhooyr.io/websocket/frame.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  
     3  package websocket
     4  
     5  import (
     6  	"bufio"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"math"
    11  	"math/bits"
    12  
    13  	"nhooyr.io/websocket/internal/errd"
    14  )
    15  
    16  // opcode represents a WebSocket opcode.
    17  type opcode int
    18  
    19  // https://tools.ietf.org/html/rfc6455#section-11.8.
    20  const (
    21  	opContinuation opcode = iota
    22  	opText
    23  	opBinary
    24  	// 3 - 7 are reserved for further non-control frames.
    25  	_
    26  	_
    27  	_
    28  	_
    29  	_
    30  	opClose
    31  	opPing
    32  	opPong
    33  	// 11-16 are reserved for further control frames.
    34  )
    35  
    36  // header represents a WebSocket frame header.
    37  // See https://tools.ietf.org/html/rfc6455#section-5.2.
    38  type header struct {
    39  	fin    bool
    40  	rsv1   bool
    41  	rsv2   bool
    42  	rsv3   bool
    43  	opcode opcode
    44  
    45  	payloadLength int64
    46  
    47  	masked  bool
    48  	maskKey uint32
    49  }
    50  
    51  // readFrameHeader reads a header from the reader.
    52  // See https://tools.ietf.org/html/rfc6455#section-5.2.
    53  func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
    54  	defer errd.Wrap(&err, "failed to read frame header")
    55  
    56  	b, err := r.ReadByte()
    57  	if err != nil {
    58  		return header{}, err
    59  	}
    60  
    61  	h.fin = b&(1<<7) != 0
    62  	h.rsv1 = b&(1<<6) != 0
    63  	h.rsv2 = b&(1<<5) != 0
    64  	h.rsv3 = b&(1<<4) != 0
    65  
    66  	h.opcode = opcode(b & 0xf)
    67  
    68  	b, err = r.ReadByte()
    69  	if err != nil {
    70  		return header{}, err
    71  	}
    72  
    73  	h.masked = b&(1<<7) != 0
    74  
    75  	payloadLength := b &^ (1 << 7)
    76  	switch {
    77  	case payloadLength < 126:
    78  		h.payloadLength = int64(payloadLength)
    79  	case payloadLength == 126:
    80  		_, err = io.ReadFull(r, readBuf[:2])
    81  		h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
    82  	case payloadLength == 127:
    83  		_, err = io.ReadFull(r, readBuf)
    84  		h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
    85  	}
    86  	if err != nil {
    87  		return header{}, err
    88  	}
    89  
    90  	if h.payloadLength < 0 {
    91  		return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
    92  	}
    93  
    94  	if h.masked {
    95  		_, err = io.ReadFull(r, readBuf[:4])
    96  		if err != nil {
    97  			return header{}, err
    98  		}
    99  		h.maskKey = binary.LittleEndian.Uint32(readBuf)
   100  	}
   101  
   102  	return h, nil
   103  }
   104  
// maxControlPayload is the maximum length, in bytes, of a control frame
// payload. It is also the largest length that fits in the 7-bit length
// field without an extended length.
// See https://tools.ietf.org/html/rfc6455#section-5.5.
const maxControlPayload = 125
   108  
   109  // writeFrameHeader writes the bytes of the header to w.
   110  // See https://tools.ietf.org/html/rfc6455#section-5.2
   111  func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
   112  	defer errd.Wrap(&err, "failed to write frame header")
   113  
   114  	var b byte
   115  	if h.fin {
   116  		b |= 1 << 7
   117  	}
   118  	if h.rsv1 {
   119  		b |= 1 << 6
   120  	}
   121  	if h.rsv2 {
   122  		b |= 1 << 5
   123  	}
   124  	if h.rsv3 {
   125  		b |= 1 << 4
   126  	}
   127  
   128  	b |= byte(h.opcode)
   129  
   130  	err = w.WriteByte(b)
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	lengthByte := byte(0)
   136  	if h.masked {
   137  		lengthByte |= 1 << 7
   138  	}
   139  
   140  	switch {
   141  	case h.payloadLength > math.MaxUint16:
   142  		lengthByte |= 127
   143  	case h.payloadLength > 125:
   144  		lengthByte |= 126
   145  	case h.payloadLength >= 0:
   146  		lengthByte |= byte(h.payloadLength)
   147  	}
   148  	err = w.WriteByte(lengthByte)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	switch {
   154  	case h.payloadLength > math.MaxUint16:
   155  		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
   156  		_, err = w.Write(buf)
   157  	case h.payloadLength > 125:
   158  		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
   159  		_, err = w.Write(buf[:2])
   160  	}
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	if h.masked {
   166  		binary.LittleEndian.PutUint32(buf, h.maskKey)
   167  		_, err = w.Write(buf[:4])
   168  		if err != nil {
   169  			return err
   170  		}
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  // mask applies the WebSocket masking algorithm to p
   177  // with the given key.
   178  // See https://tools.ietf.org/html/rfc6455#section-5.3
   179  //
   180  // The returned value is the correctly rotated key to
   181  // to continue to mask/unmask the message.
   182  //
   183  // It is optimized for LittleEndian and expects the key
   184  // to be in little endian.
   185  //
   186  // See https://github.com/golang/go/issues/31586
   187  func mask(key uint32, b []byte) uint32 {
   188  	if len(b) >= 8 {
   189  		key64 := uint64(key)<<32 | uint64(key)
   190  
   191  		// At some point in the future we can clean these unrolled loops up.
   192  		// See https://github.com/golang/go/issues/31586#issuecomment-487436401
   193  
   194  		// Then we xor until b is less than 128 bytes.
   195  		for len(b) >= 128 {
   196  			v := binary.LittleEndian.Uint64(b)
   197  			binary.LittleEndian.PutUint64(b, v^key64)
   198  			v = binary.LittleEndian.Uint64(b[8:16])
   199  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   200  			v = binary.LittleEndian.Uint64(b[16:24])
   201  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   202  			v = binary.LittleEndian.Uint64(b[24:32])
   203  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   204  			v = binary.LittleEndian.Uint64(b[32:40])
   205  			binary.LittleEndian.PutUint64(b[32:40], v^key64)
   206  			v = binary.LittleEndian.Uint64(b[40:48])
   207  			binary.LittleEndian.PutUint64(b[40:48], v^key64)
   208  			v = binary.LittleEndian.Uint64(b[48:56])
   209  			binary.LittleEndian.PutUint64(b[48:56], v^key64)
   210  			v = binary.LittleEndian.Uint64(b[56:64])
   211  			binary.LittleEndian.PutUint64(b[56:64], v^key64)
   212  			v = binary.LittleEndian.Uint64(b[64:72])
   213  			binary.LittleEndian.PutUint64(b[64:72], v^key64)
   214  			v = binary.LittleEndian.Uint64(b[72:80])
   215  			binary.LittleEndian.PutUint64(b[72:80], v^key64)
   216  			v = binary.LittleEndian.Uint64(b[80:88])
   217  			binary.LittleEndian.PutUint64(b[80:88], v^key64)
   218  			v = binary.LittleEndian.Uint64(b[88:96])
   219  			binary.LittleEndian.PutUint64(b[88:96], v^key64)
   220  			v = binary.LittleEndian.Uint64(b[96:104])
   221  			binary.LittleEndian.PutUint64(b[96:104], v^key64)
   222  			v = binary.LittleEndian.Uint64(b[104:112])
   223  			binary.LittleEndian.PutUint64(b[104:112], v^key64)
   224  			v = binary.LittleEndian.Uint64(b[112:120])
   225  			binary.LittleEndian.PutUint64(b[112:120], v^key64)
   226  			v = binary.LittleEndian.Uint64(b[120:128])
   227  			binary.LittleEndian.PutUint64(b[120:128], v^key64)
   228  			b = b[128:]
   229  		}
   230  
   231  		// Then we xor until b is less than 64 bytes.
   232  		for len(b) >= 64 {
   233  			v := binary.LittleEndian.Uint64(b)
   234  			binary.LittleEndian.PutUint64(b, v^key64)
   235  			v = binary.LittleEndian.Uint64(b[8:16])
   236  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   237  			v = binary.LittleEndian.Uint64(b[16:24])
   238  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   239  			v = binary.LittleEndian.Uint64(b[24:32])
   240  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   241  			v = binary.LittleEndian.Uint64(b[32:40])
   242  			binary.LittleEndian.PutUint64(b[32:40], v^key64)
   243  			v = binary.LittleEndian.Uint64(b[40:48])
   244  			binary.LittleEndian.PutUint64(b[40:48], v^key64)
   245  			v = binary.LittleEndian.Uint64(b[48:56])
   246  			binary.LittleEndian.PutUint64(b[48:56], v^key64)
   247  			v = binary.LittleEndian.Uint64(b[56:64])
   248  			binary.LittleEndian.PutUint64(b[56:64], v^key64)
   249  			b = b[64:]
   250  		}
   251  
   252  		// Then we xor until b is less than 32 bytes.
   253  		for len(b) >= 32 {
   254  			v := binary.LittleEndian.Uint64(b)
   255  			binary.LittleEndian.PutUint64(b, v^key64)
   256  			v = binary.LittleEndian.Uint64(b[8:16])
   257  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   258  			v = binary.LittleEndian.Uint64(b[16:24])
   259  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   260  			v = binary.LittleEndian.Uint64(b[24:32])
   261  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   262  			b = b[32:]
   263  		}
   264  
   265  		// Then we xor until b is less than 16 bytes.
   266  		for len(b) >= 16 {
   267  			v := binary.LittleEndian.Uint64(b)
   268  			binary.LittleEndian.PutUint64(b, v^key64)
   269  			v = binary.LittleEndian.Uint64(b[8:16])
   270  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   271  			b = b[16:]
   272  		}
   273  
   274  		// Then we xor until b is less than 8 bytes.
   275  		for len(b) >= 8 {
   276  			v := binary.LittleEndian.Uint64(b)
   277  			binary.LittleEndian.PutUint64(b, v^key64)
   278  			b = b[8:]
   279  		}
   280  	}
   281  
   282  	// Then we xor until b is less than 4 bytes.
   283  	for len(b) >= 4 {
   284  		v := binary.LittleEndian.Uint32(b)
   285  		binary.LittleEndian.PutUint32(b, v^key)
   286  		b = b[4:]
   287  	}
   288  
   289  	// xor remaining bytes.
   290  	for i := range b {
   291  		b[i] ^= byte(key)
   292  		key = bits.RotateLeft32(key, -8)
   293  	}
   294  
   295  	return key
   296  }
   297  

View as plain text