1
2
3 package websocket
4
5 import (
6 "bufio"
7 "encoding/binary"
8 "fmt"
9 "io"
10 "math"
11 "math/bits"
12
13 "nhooyr.io/websocket/internal/errd"
14 )
15
16
// opcode identifies the type of a WebSocket frame. It occupies the low four
// bits of the first header byte (RFC 6455, section 5.2).
type opcode int

// Opcodes defined by RFC 6455, section 5.2. Values 0x3-0x7 are reserved for
// future data frames and 0xB-0xF for future control frames; they are not
// named here.
const (
	opContinuation opcode = 0x0
	opText         opcode = 0x1
	opBinary       opcode = 0x2

	// Control frames.
	opClose opcode = 0x8
	opPing  opcode = 0x9
	opPong  opcode = 0xA
)
35
36
37
38 type header struct {
39 fin bool
40 rsv1 bool
41 rsv2 bool
42 rsv3 bool
43 opcode opcode
44
45 payloadLength int64
46
47 masked bool
48 maskKey uint32
49 }
50
51
52
53 func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
54 defer errd.Wrap(&err, "failed to read frame header")
55
56 b, err := r.ReadByte()
57 if err != nil {
58 return header{}, err
59 }
60
61 h.fin = b&(1<<7) != 0
62 h.rsv1 = b&(1<<6) != 0
63 h.rsv2 = b&(1<<5) != 0
64 h.rsv3 = b&(1<<4) != 0
65
66 h.opcode = opcode(b & 0xf)
67
68 b, err = r.ReadByte()
69 if err != nil {
70 return header{}, err
71 }
72
73 h.masked = b&(1<<7) != 0
74
75 payloadLength := b &^ (1 << 7)
76 switch {
77 case payloadLength < 126:
78 h.payloadLength = int64(payloadLength)
79 case payloadLength == 126:
80 _, err = io.ReadFull(r, readBuf[:2])
81 h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
82 case payloadLength == 127:
83 _, err = io.ReadFull(r, readBuf)
84 h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
85 }
86 if err != nil {
87 return header{}, err
88 }
89
90 if h.payloadLength < 0 {
91 return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
92 }
93
94 if h.masked {
95 _, err = io.ReadFull(r, readBuf[:4])
96 if err != nil {
97 return header{}, err
98 }
99 h.maskKey = binary.LittleEndian.Uint32(readBuf)
100 }
101
102 return h, nil
103 }
104
105
106
// maxControlPayload is the largest payload, in bytes, that RFC 6455
// (section 5.5) permits on a control frame (close, ping, pong). It is left
// untyped so it compares directly against any integer length type.
const (
	maxControlPayload = 125
)
108
109
110
111 func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
112 defer errd.Wrap(&err, "failed to write frame header")
113
114 var b byte
115 if h.fin {
116 b |= 1 << 7
117 }
118 if h.rsv1 {
119 b |= 1 << 6
120 }
121 if h.rsv2 {
122 b |= 1 << 5
123 }
124 if h.rsv3 {
125 b |= 1 << 4
126 }
127
128 b |= byte(h.opcode)
129
130 err = w.WriteByte(b)
131 if err != nil {
132 return err
133 }
134
135 lengthByte := byte(0)
136 if h.masked {
137 lengthByte |= 1 << 7
138 }
139
140 switch {
141 case h.payloadLength > math.MaxUint16:
142 lengthByte |= 127
143 case h.payloadLength > 125:
144 lengthByte |= 126
145 case h.payloadLength >= 0:
146 lengthByte |= byte(h.payloadLength)
147 }
148 err = w.WriteByte(lengthByte)
149 if err != nil {
150 return err
151 }
152
153 switch {
154 case h.payloadLength > math.MaxUint16:
155 binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
156 _, err = w.Write(buf)
157 case h.payloadLength > 125:
158 binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
159 _, err = w.Write(buf[:2])
160 }
161 if err != nil {
162 return err
163 }
164
165 if h.masked {
166 binary.LittleEndian.PutUint32(buf, h.maskKey)
167 _, err = w.Write(buf[:4])
168 if err != nil {
169 return err
170 }
171 }
172
173 return nil
174 }
175
176
177
178
179
180
181
182
183
184
185
186
// mask XORs b in place with the WebSocket masking key (RFC 6455, section 5.3)
// and returns the key rotated so that masking can resume on the bytes that
// follow b in the same payload.
//
// key is applied in little-endian byte order: b[0] is XORed with the low byte
// of key, b[1] with the next byte, and so on, repeating every four bytes.
// Whole 4-byte groups leave the key unchanged. Each of the up to three
// trailing bytes rotates it right by 8 bits, and the return value reflects
// that rotation.
//
// The bulk XOR is hand-unrolled over 8-byte words. The widest stride is
// applied first, then the stride steps down 64 -> 32 -> 16 -> 8 -> 4 -> 1.
func mask(key uint32, b []byte) uint32 {
	le := binary.LittleEndian

	// Two copies of the key side by side mask 8 bytes per XOR.
	wide := uint64(key)<<32 | uint64(key)

	for len(b) >= 128 {
		le.PutUint64(b[0:8], le.Uint64(b[0:8])^wide)
		le.PutUint64(b[8:16], le.Uint64(b[8:16])^wide)
		le.PutUint64(b[16:24], le.Uint64(b[16:24])^wide)
		le.PutUint64(b[24:32], le.Uint64(b[24:32])^wide)
		le.PutUint64(b[32:40], le.Uint64(b[32:40])^wide)
		le.PutUint64(b[40:48], le.Uint64(b[40:48])^wide)
		le.PutUint64(b[48:56], le.Uint64(b[48:56])^wide)
		le.PutUint64(b[56:64], le.Uint64(b[56:64])^wide)
		le.PutUint64(b[64:72], le.Uint64(b[64:72])^wide)
		le.PutUint64(b[72:80], le.Uint64(b[72:80])^wide)
		le.PutUint64(b[80:88], le.Uint64(b[80:88])^wide)
		le.PutUint64(b[88:96], le.Uint64(b[88:96])^wide)
		le.PutUint64(b[96:104], le.Uint64(b[96:104])^wide)
		le.PutUint64(b[104:112], le.Uint64(b[104:112])^wide)
		le.PutUint64(b[112:120], le.Uint64(b[112:120])^wide)
		le.PutUint64(b[120:128], le.Uint64(b[120:128])^wide)
		b = b[128:]
	}

	// Fewer than 128 bytes remain, so each halving step below can apply at
	// most once before the remainder drops under its threshold.
	if len(b) >= 64 {
		le.PutUint64(b[0:8], le.Uint64(b[0:8])^wide)
		le.PutUint64(b[8:16], le.Uint64(b[8:16])^wide)
		le.PutUint64(b[16:24], le.Uint64(b[16:24])^wide)
		le.PutUint64(b[24:32], le.Uint64(b[24:32])^wide)
		le.PutUint64(b[32:40], le.Uint64(b[32:40])^wide)
		le.PutUint64(b[40:48], le.Uint64(b[40:48])^wide)
		le.PutUint64(b[48:56], le.Uint64(b[48:56])^wide)
		le.PutUint64(b[56:64], le.Uint64(b[56:64])^wide)
		b = b[64:]
	}

	if len(b) >= 32 {
		le.PutUint64(b[0:8], le.Uint64(b[0:8])^wide)
		le.PutUint64(b[8:16], le.Uint64(b[8:16])^wide)
		le.PutUint64(b[16:24], le.Uint64(b[16:24])^wide)
		le.PutUint64(b[24:32], le.Uint64(b[24:32])^wide)
		b = b[32:]
	}

	if len(b) >= 16 {
		le.PutUint64(b[0:8], le.Uint64(b[0:8])^wide)
		le.PutUint64(b[8:16], le.Uint64(b[8:16])^wide)
		b = b[16:]
	}

	if len(b) >= 8 {
		le.PutUint64(b[0:8], le.Uint64(b[0:8])^wide)
		b = b[8:]
	}

	if len(b) >= 4 {
		le.PutUint32(b[0:4], le.Uint32(b[0:4])^key)
		b = b[4:]
	}

	// At most three bytes are left; these consume key bytes one at a time,
	// so the key is rotated to stay in step with the payload position.
	for i := 0; i < len(b); i++ {
		b[i] ^= byte(key)
		key = bits.RotateLeft32(key, -8)
	}

	return key
}
297
View as plain text