1
2
3
4 package websocket
5
6 import (
7 "bufio"
8 "context"
9 "errors"
10 "fmt"
11 "io"
12 "net"
13 "strings"
14 "time"
15
16 "nhooyr.io/websocket/internal/errd"
17 "nhooyr.io/websocket/internal/util"
18 "nhooyr.io/websocket/internal/xsync"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
37 return c.reader(ctx)
38 }
39
40
41
42 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
43 typ, r, err := c.Reader(ctx)
44 if err != nil {
45 return 0, nil, err
46 }
47
48 b, err := io.ReadAll(r)
49 return typ, b, err
50 }
51
52
53
54
55
56
57
58
59
60
61
62
63 func (c *Conn) CloseRead(ctx context.Context) context.Context {
64 ctx, cancel := context.WithCancel(ctx)
65
66 c.wg.Add(1)
67 go func() {
68 defer c.CloseNow()
69 defer c.wg.Done()
70 defer cancel()
71 _, _, err := c.Reader(ctx)
72 if err == nil {
73 c.Close(StatusPolicyViolation, "unexpected data message")
74 }
75 }()
76 return ctx
77 }
78
79
80
81
82
83
84
85
86
87 func (c *Conn) SetReadLimit(n int64) {
88 if n >= 0 {
89
90
91 n++
92 }
93
94 c.msgReader.limitReader.limit.Store(n)
95 }
96
// defaultReadLimit is the message read limit, in bytes, that newMsgReader
// installs on a fresh msgReader (it passes defaultReadLimit+1 to
// newLimitReader).
const defaultReadLimit = 32768
98
99 func newMsgReader(c *Conn) *msgReader {
100 mr := &msgReader{
101 c: c,
102 fin: true,
103 }
104 mr.readFunc = mr.read
105
106 mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
107 return mr
108 }
109
110 func (mr *msgReader) resetFlate() {
111 if mr.flateContextTakeover() {
112 if mr.dict == nil {
113 mr.dict = &slidingWindow{}
114 }
115 mr.dict.init(32768)
116 }
117 if mr.flateBufio == nil {
118 mr.flateBufio = getBufioReader(mr.readFunc)
119 }
120
121 if mr.flateContextTakeover() {
122 mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
123 } else {
124 mr.flateReader = getFlateReader(mr.flateBufio, nil)
125 }
126 mr.limitReader.r = mr.flateReader
127 mr.flateTail.Reset(deflateMessageTail)
128 }
129
130 func (mr *msgReader) putFlateReader() {
131 if mr.flateReader != nil {
132 putFlateReader(mr.flateReader)
133 mr.flateReader = nil
134 }
135 }
136
137 func (mr *msgReader) close() {
138 mr.c.readMu.forceLock()
139 mr.putFlateReader()
140 if mr.dict != nil {
141 mr.dict.close()
142 mr.dict = nil
143 }
144 if mr.flateBufio != nil {
145 putBufioReader(mr.flateBufio)
146 }
147
148 if mr.c.client {
149 putBufioReader(mr.c.br)
150 mr.c.br = nil
151 }
152 }
153
154 func (mr *msgReader) flateContextTakeover() bool {
155 if mr.c.client {
156 return !mr.c.copts.serverNoContextTakeover
157 }
158 return !mr.c.copts.clientNoContextTakeover
159 }
160
161 func (c *Conn) readRSV1Illegal(h header) bool {
162
163 if !c.flate() {
164 return true
165 }
166
167 if h.opcode != opText && h.opcode != opBinary {
168 return true
169 }
170 return false
171 }
172
173 func (c *Conn) readLoop(ctx context.Context) (header, error) {
174 for {
175 h, err := c.readFrameHeader(ctx)
176 if err != nil {
177 return header{}, err
178 }
179
180 if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
181 err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
182 c.writeError(StatusProtocolError, err)
183 return header{}, err
184 }
185
186 if !c.client && !h.masked {
187 return header{}, errors.New("received unmasked frame from client")
188 }
189
190 switch h.opcode {
191 case opClose, opPing, opPong:
192 err = c.handleControl(ctx, h)
193 if err != nil {
194
195 if h.opcode == opClose && CloseStatus(err) != -1 {
196 return header{}, err
197 }
198 return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
199 }
200 case opContinuation, opText, opBinary:
201 return h, nil
202 default:
203 err := fmt.Errorf("received unknown opcode %v", h.opcode)
204 c.writeError(StatusProtocolError, err)
205 return header{}, err
206 }
207 }
208 }
209
210 func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
211 select {
212 case <-c.closed:
213 return header{}, net.ErrClosed
214 case c.readTimeout <- ctx:
215 }
216
217 h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
218 if err != nil {
219 select {
220 case <-c.closed:
221 return header{}, net.ErrClosed
222 case <-ctx.Done():
223 return header{}, ctx.Err()
224 default:
225 c.close(err)
226 return header{}, err
227 }
228 }
229
230 select {
231 case <-c.closed:
232 return header{}, net.ErrClosed
233 case c.readTimeout <- context.Background():
234 }
235
236 return h, nil
237 }
238
239 func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
240 select {
241 case <-c.closed:
242 return 0, net.ErrClosed
243 case c.readTimeout <- ctx:
244 }
245
246 n, err := io.ReadFull(c.br, p)
247 if err != nil {
248 select {
249 case <-c.closed:
250 return n, net.ErrClosed
251 case <-ctx.Done():
252 return n, ctx.Err()
253 default:
254 err = fmt.Errorf("failed to read frame payload: %w", err)
255 c.close(err)
256 return n, err
257 }
258 }
259
260 select {
261 case <-c.closed:
262 return n, net.ErrClosed
263 case c.readTimeout <- context.Background():
264 }
265
266 return n, err
267 }
268
269 func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
270 if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
271 err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
272 c.writeError(StatusProtocolError, err)
273 return err
274 }
275
276 if !h.fin {
277 err := errors.New("received fragmented control frame")
278 c.writeError(StatusProtocolError, err)
279 return err
280 }
281
282 ctx, cancel := context.WithTimeout(ctx, time.Second*5)
283 defer cancel()
284
285 b := c.readControlBuf[:h.payloadLength]
286 _, err = c.readFramePayload(ctx, b)
287 if err != nil {
288 return err
289 }
290
291 if h.masked {
292 mask(h.maskKey, b)
293 }
294
295 switch h.opcode {
296 case opPing:
297 return c.writeControl(ctx, opPong, b)
298 case opPong:
299 c.activePingsMu.Lock()
300 pong, ok := c.activePings[string(b)]
301 c.activePingsMu.Unlock()
302 if ok {
303 select {
304 case pong <- struct{}{}:
305 default:
306 }
307 }
308 return nil
309 }
310
311 defer func() {
312 c.readCloseFrameErr = err
313 }()
314
315 ce, err := parseClosePayload(b)
316 if err != nil {
317 err = fmt.Errorf("received invalid close payload: %w", err)
318 c.writeError(StatusProtocolError, err)
319 return err
320 }
321
322 err = fmt.Errorf("received close frame: %w", ce)
323 c.setCloseErr(err)
324 c.writeClose(ce.Code, ce.Reason)
325 c.close(err)
326 return err
327 }
328
329 func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
330 defer errd.Wrap(&err, "failed to get reader")
331
332 err = c.readMu.lock(ctx)
333 if err != nil {
334 return 0, nil, err
335 }
336 defer c.readMu.unlock()
337
338 if !c.msgReader.fin {
339 err = errors.New("previous message not read to completion")
340 c.close(fmt.Errorf("failed to get reader: %w", err))
341 return 0, nil, err
342 }
343
344 h, err := c.readLoop(ctx)
345 if err != nil {
346 return 0, nil, err
347 }
348
349 if h.opcode == opContinuation {
350 err := errors.New("received continuation frame without text or binary frame")
351 c.writeError(StatusProtocolError, err)
352 return 0, nil, err
353 }
354
355 c.msgReader.reset(ctx, h)
356
357 return MessageType(h.opcode), c.msgReader, nil
358 }
359
// msgReader is the io.Reader that Conn.reader returns for the message
// currently being read. A single instance is created per connection by
// newMsgReader and re-armed for each message by reset.
type msgReader struct {
	// c is the connection this reader belongs to.
	c *Conn

	// ctx is the context given to reset; it bounds the lock acquisition and
	// frame reads performed while reading the current message.
	ctx context.Context
	// flate is true when the first frame of the current message had rsv1 set.
	flate bool
	// flateReader is the reader obtained from getFlateReader; it is nil once
	// handed back by putFlateReader.
	flateReader io.Reader
	// flateBufio wraps readFunc and feeds flateReader; it is created lazily
	// by resetFlate.
	flateBufio *bufio.Reader
	// flateTail is reset over deflateMessageTail for each compressed message
	// and is what read serves once the final frame has been consumed.
	flateTail strings.Reader
	// limitReader enforces the read limit; Read goes through it.
	limitReader *limitReader
	// dict is the sliding window used when context takeover is enabled; Read
	// writes the bytes it returns into it.
	dict *slidingWindow

	// fin, payloadLength and maskKey are copied from the most recent data
	// frame header by setFrame. read decrements payloadLength as payload is
	// consumed and, on server connections, replaces maskKey with the value
	// returned by mask.
	fin           bool
	payloadLength int64
	maskKey       uint32

	// readFunc is mr.read stored once by newMsgReader (presumably so the
	// method value does not have to be re-created on every use).
	readFunc util.ReaderFunc
}
378
379 func (mr *msgReader) reset(ctx context.Context, h header) {
380 mr.ctx = ctx
381 mr.flate = h.rsv1
382 mr.limitReader.reset(mr.readFunc)
383
384 if mr.flate {
385 mr.resetFlate()
386 }
387
388 mr.setFrame(h)
389 }
390
391 func (mr *msgReader) setFrame(h header) {
392 mr.fin = h.fin
393 mr.payloadLength = h.payloadLength
394 mr.maskKey = h.maskKey
395 }
396
397 func (mr *msgReader) Read(p []byte) (n int, err error) {
398 err = mr.c.readMu.lock(mr.ctx)
399 if err != nil {
400 return 0, fmt.Errorf("failed to read: %w", err)
401 }
402 defer mr.c.readMu.unlock()
403
404 n, err = mr.limitReader.Read(p)
405 if mr.flate && mr.flateContextTakeover() {
406 p = p[:n]
407 mr.dict.write(p)
408 }
409 if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
410 mr.putFlateReader()
411 return n, io.EOF
412 }
413 if err != nil {
414 err = fmt.Errorf("failed to read: %w", err)
415 mr.c.close(err)
416 }
417 return n, err
418 }
419
420 func (mr *msgReader) read(p []byte) (int, error) {
421 for {
422 if mr.payloadLength == 0 {
423 if mr.fin {
424 if mr.flate {
425 return mr.flateTail.Read(p)
426 }
427 return 0, io.EOF
428 }
429
430 h, err := mr.c.readLoop(mr.ctx)
431 if err != nil {
432 return 0, err
433 }
434 if h.opcode != opContinuation {
435 err := errors.New("received new data message without finishing the previous message")
436 mr.c.writeError(StatusProtocolError, err)
437 return 0, err
438 }
439 mr.setFrame(h)
440
441 continue
442 }
443
444 if int64(len(p)) > mr.payloadLength {
445 p = p[:mr.payloadLength]
446 }
447
448 n, err := mr.c.readFramePayload(mr.ctx, p)
449 if err != nil {
450 return n, err
451 }
452
453 mr.payloadLength -= int64(n)
454
455 if !mr.c.client {
456 mr.maskKey = mask(mr.maskKey, p)
457 }
458
459 return n, nil
460 }
461 }
462
// limitReader wraps an io.Reader and caps how many bytes may be read from it
// between calls to reset.
type limitReader struct {
	// c is the connection that receives StatusMessageTooBig via writeError
	// when the cap is hit.
	c *Conn
	// r is the underlying reader; reset installs it and msgReader.resetFlate
	// swaps in the flate reader for compressed messages.
	r io.Reader
	// limit is the configured budget that reset copies into n. It is written
	// by newLimitReader and Conn.SetReadLimit. A negative value disables the
	// cap.
	limit xsync.Int64
	// n is the number of bytes still allowed before Read fails; negative
	// means unlimited.
	n int64
}
469
470 func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
471 lr := &limitReader{
472 c: c,
473 }
474 lr.limit.Store(limit)
475 lr.reset(r)
476 return lr
477 }
478
479 func (lr *limitReader) reset(r io.Reader) {
480 lr.n = lr.limit.Load()
481 lr.r = r
482 }
483
484 func (lr *limitReader) Read(p []byte) (int, error) {
485 if lr.n < 0 {
486 return lr.r.Read(p)
487 }
488
489 if lr.n == 0 {
490 err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
491 lr.c.writeError(StatusMessageTooBig, err)
492 return 0, err
493 }
494
495 if int64(len(p)) > lr.n {
496 p = p[:lr.n]
497 }
498 n, err := lr.r.Read(p)
499 lr.n -= int64(n)
500 if lr.n < 0 {
501 lr.n = 0
502 }
503 return n, err
504 }
505
View as plain text