1
2
3
4
5 package websocket
6
import (
	"bufio"
	crand "crypto/rand"
	"encoding/binary"
	"errors"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
)
21
const (
	// Bits of the first frame-header byte (FIN and the three reserved bits);
	// the low four bits of that byte hold the opcode.
	finalBit = 1 << 7
	rsv1Bit  = 1 << 6
	rsv2Bit  = 1 << 5
	rsv3Bit  = 1 << 4

	// MASK bit of the second frame-header byte; the low seven bits hold the
	// (possibly extended) payload length indicator.
	maskBit = 1 << 7

	// Largest possible frame header: 2 fixed bytes + 8 bytes of extended
	// length + 4 bytes of masking key.
	maxFrameHeaderSize         = 2 + 8 + 4
	maxControlFramePayloadSize = 125

	// writeWait is added to time.Now() to form the deadline for control
	// frames this file writes on its own (close replies, pongs,
	// protocol-error closes).
	writeWait = time.Second

	defaultReadBufferSize  = 4096
	defaultWriteBufferSize = 4096

	// continuationFrame is the opcode of a continuation frame; noFrame is the
	// frame type returned alongside an error.
	continuationFrame = 0
	noFrame           = -1
)
43
44
// Close status codes: 1000-1015 from RFC 6455 section 7.4.1, plus 1012 and
// 1013 from the IANA WebSocket close code registry.
const (
	CloseNormalClosure           = 1000
	CloseGoingAway               = 1001
	CloseProtocolError           = 1002
	CloseUnsupportedData         = 1003
	CloseNoStatusReceived        = 1005
	CloseAbnormalClosure         = 1006
	CloseInvalidFramePayloadData = 1007
	ClosePolicyViolation         = 1008
	CloseMessageTooBig           = 1009
	CloseMandatoryExtension      = 1010
	CloseInternalServerErr       = 1011
	CloseServiceRestart          = 1012
	CloseTryAgainLater           = 1013
	CloseTLSHandshake            = 1015
)
61
62
// Message types. Each value equals the frame opcode written on the wire.
const (
	// TextMessage denotes a text data message.
	TextMessage = 1

	// BinaryMessage denotes a binary data message.
	BinaryMessage = 2

	// CloseMessage denotes a close control message. Its optional payload is
	// a 2-byte big-endian close code followed by reason text (see
	// FormatCloseMessage).
	CloseMessage = 8

	// PingMessage denotes a ping control message. Its optional payload is
	// passed as a string to the ping handler.
	PingMessage = 9

	// PongMessage denotes a pong control message. Its optional payload is
	// passed as a string to the pong handler.
	PongMessage = 10
)
84
85
86
// ErrCloseSent is returned by the write methods once a close message has
// been written successfully; it becomes the connection's sticky write error.
var ErrCloseSent = errors.New("websocket: close sent")

// ErrReadLimit is returned when a message exceeds the limit set by
// SetReadLimit, or when a frame length cannot be represented as a
// non-negative int64.
var ErrReadLimit = errors.New("websocket: read limit exceeded")
92
93
// netError is a small net.Error implementation. Its message and its
// Temporary/Timeout answers are fixed when the value is built.
type netError struct {
	msg       string
	temporary bool
	timeout   bool
}

// Error returns the stored message text.
func (e *netError) Error() string {
	msg := e.msg
	return msg
}

// Temporary reports the temporary flag given at construction.
func (e *netError) Temporary() bool {
	isTemp := e.temporary
	return isTemp
}

// Timeout reports the timeout flag given at construction.
func (e *netError) Timeout() bool {
	isTimeout := e.timeout
	return isTimeout
}
103
104
105 type CloseError struct {
106
107 Code int
108
109
110 Text string
111 }
112
113 func (e *CloseError) Error() string {
114 s := []byte("websocket: close ")
115 s = strconv.AppendInt(s, int64(e.Code), 10)
116 switch e.Code {
117 case CloseNormalClosure:
118 s = append(s, " (normal)"...)
119 case CloseGoingAway:
120 s = append(s, " (going away)"...)
121 case CloseProtocolError:
122 s = append(s, " (protocol error)"...)
123 case CloseUnsupportedData:
124 s = append(s, " (unsupported data)"...)
125 case CloseNoStatusReceived:
126 s = append(s, " (no status)"...)
127 case CloseAbnormalClosure:
128 s = append(s, " (abnormal closure)"...)
129 case CloseInvalidFramePayloadData:
130 s = append(s, " (invalid payload data)"...)
131 case ClosePolicyViolation:
132 s = append(s, " (policy violation)"...)
133 case CloseMessageTooBig:
134 s = append(s, " (message too big)"...)
135 case CloseMandatoryExtension:
136 s = append(s, " (mandatory extension missing)"...)
137 case CloseInternalServerErr:
138 s = append(s, " (internal server error)"...)
139 case CloseTLSHandshake:
140 s = append(s, " (TLS handshake error)"...)
141 }
142 if e.Text != "" {
143 s = append(s, ": "...)
144 s = append(s, e.Text...)
145 }
146 return string(s)
147 }
148
149
150
151 func IsCloseError(err error, codes ...int) bool {
152 if e, ok := err.(*CloseError); ok {
153 for _, code := range codes {
154 if e.Code == code {
155 return true
156 }
157 }
158 }
159 return false
160 }
161
162
163
164 func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
165 if e, ok := err.(*CloseError); ok {
166 for _, code := range expectedCodes {
167 if e.Code == code {
168 return false
169 }
170 }
171 return true
172 }
173 return false
174 }
175
var (
	// errWriteTimeout is returned by WriteControl when its deadline has
	// already passed or the write lock could not be acquired before it. It
	// reports itself as a temporary timeout.
	errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
	// errUnexpectedEOF replaces an io.EOF from the network while reading
	// frames. It is reported as an abnormal closure (code 1006).
	errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
	// errBadWriteOpCode is returned when asked to write an unknown message type.
	errBadWriteOpCode = errors.New("websocket: bad write message type")
	// errWriteClosed is stored in a messageWriter after its final frame has
	// been sent, so further writes to that writer fail.
	errWriteClosed = errors.New("websocket: write closed")
	// errInvalidControlFrame is returned for control frames whose payload is
	// too large or (when written through a messageWriter) not final.
	errInvalidControlFrame = errors.New("websocket: invalid control frame")
)
183
// newMaskKey returns a fresh 4-byte masking key for a client-to-server frame.
//
// RFC 6455 section 5.3 requires the masking key to be unpredictable and
// derived from a strong source of entropy, so the key is read from
// crypto/rand. If the system entropy source fails, math/rand is used as a
// fallback. Writes then keep working, with a predictable key, instead of
// failing outright.
func newMaskKey() [4]byte {
	var key [4]byte
	if _, err := crand.Read(key[:]); err != nil {
		n := rand.Uint32()
		key = [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
	}
	return key
}
188
189 func hideTempErr(err error) error {
190 if e, ok := err.(net.Error); ok && e.Temporary() {
191 err = &netError{msg: e.Error(), timeout: e.Timeout()}
192 }
193 return err
194 }
195
196 func isControl(frameType int) bool {
197 return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
198 }
199
200 func isData(frameType int) bool {
201 return frameType == TextMessage || frameType == BinaryMessage
202 }
203
// validReceivedCloseCodes records, for each known close code, whether it may
// appear in a close frame received from the peer. Codes mapped to false
// (1005, 1006, 1015) are reserved by RFC 6455 section 7.4.1 for local use and
// must not be sent in a close frame. Unknown codes are absent and therefore
// look up as false.
var validReceivedCloseCodes = map[int]bool{
	// see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number

	CloseNormalClosure:           true,
	CloseGoingAway:               true,
	CloseProtocolError:           true,
	CloseUnsupportedData:         true,
	CloseNoStatusReceived:        false,
	CloseAbnormalClosure:         false,
	CloseInvalidFramePayloadData: true,
	ClosePolicyViolation:         true,
	CloseMessageTooBig:           true,
	CloseMandatoryExtension:      true,
	CloseInternalServerErr:       true,
	CloseServiceRestart:          true,
	CloseTryAgainLater:           true,
	CloseTLSHandshake:            false,
}
222
223 func isValidReceivedCloseCode(code int) bool {
224 return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
225 }
226
227
228
// BufferPool represents a pool of buffers. When one is supplied, Conn takes
// a write buffer from it at the start of each message (beginMessage) and
// returns the buffer when the message ends (endMessage). *sync.Pool
// satisfies this interface.
type BufferPool interface {
	// Get returns a value from the pool. Conn treats any value that is not
	// its own pooled buffer type (for example nil) as a miss and allocates
	// a new buffer.
	Get() interface{}
	// Put adds a value to the pool.
	Put(interface{})
}
235
236
237
238
// writePoolData is the value Conn stores in a BufferPool: a wrapper around
// one write buffer. NOTE(review): presumably wrapped so that pooled values
// have a package-private type that applications cannot depend on — confirm.
type writePoolData struct{ buf []byte }
240
241
// Conn represents a WebSocket connection layered over a net.Conn.
type Conn struct {
	conn        net.Conn
	isServer    bool
	subprotocol string

	// Write fields

	// mu is a one-slot channel used as the write lock: a token is received
	// to acquire it and sent back to release it. A channel (rather than a
	// sync.Mutex) lets WriteControl stop waiting when its deadline expires.
	mu            chan struct{}
	writeBuf      []byte     // frame is assembled in this buffer (header space first)
	writePool     BufferPool // optional source of write buffers
	writeBufSize  int        // size of a write buffer, including header space
	writeDeadline time.Time  // deadline applied to message writes
	writer        io.WriteCloser // current writer returned by NextWriter, if any
	isWriting     bool           // best-effort detection of concurrent writes

	// writeErr is the sticky write error; once set, all writes fail with it.
	writeErrMu sync.Mutex
	writeErr   error

	enableWriteCompression bool
	compressionLevel       int
	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser

	// Read fields
	reader  io.ReadCloser // current reader returned by NextReader, if any
	readErr error         // sticky read error
	br      *bufio.Reader

	// readRemaining is the number of unread payload bytes in the current
	// frame. Update it with setReadRemaining, which rejects negative values.
	readRemaining int64
	readFinal     bool  // whether the latest data frame had FIN set (true initially)
	readLength    int64 // payload bytes seen so far in the current message
	readLimit     int64 // maximum message size; 0 or less means no limit
	readMaskPos   int   // position within readMaskKey for the next payload byte
	readMaskKey   [4]byte
	handlePong    func(string) error
	handlePing    func(string) error
	handleClose   func(int, string) error
	readErrCount  int            // number of NextReader calls that returned an error
	messageReader *messageReader // current low-level message reader

	readDecompress         bool // whether the last data frame header had RSV1 set
	newDecompressionReader func(io.Reader) io.ReadCloser
}
284
285 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
286
287 if br == nil {
288 if readBufferSize == 0 {
289 readBufferSize = defaultReadBufferSize
290 } else if readBufferSize < maxControlFramePayloadSize {
291
292 readBufferSize = maxControlFramePayloadSize
293 }
294 br = bufio.NewReaderSize(conn, readBufferSize)
295 }
296
297 if writeBufferSize <= 0 {
298 writeBufferSize = defaultWriteBufferSize
299 }
300 writeBufferSize += maxFrameHeaderSize
301
302 if writeBuf == nil && writeBufferPool == nil {
303 writeBuf = make([]byte, writeBufferSize)
304 }
305
306 mu := make(chan struct{}, 1)
307 mu <- struct{}{}
308 c := &Conn{
309 isServer: isServer,
310 br: br,
311 conn: conn,
312 mu: mu,
313 readFinal: true,
314 writeBuf: writeBuf,
315 writePool: writeBufferPool,
316 writeBufSize: writeBufferSize,
317 enableWriteCompression: true,
318 compressionLevel: defaultCompressionLevel,
319 }
320 c.SetCloseHandler(nil)
321 c.SetPingHandler(nil)
322 c.SetPongHandler(nil)
323 return c
324 }
325
326
327
328 func (c *Conn) setReadRemaining(n int64) error {
329 if n < 0 {
330 return ErrReadLimit
331 }
332
333 c.readRemaining = n
334 return nil
335 }
336
337
338 func (c *Conn) Subprotocol() string {
339 return c.subprotocol
340 }
341
342
343
344 func (c *Conn) Close() error {
345 return c.conn.Close()
346 }
347
348
349 func (c *Conn) LocalAddr() net.Addr {
350 return c.conn.LocalAddr()
351 }
352
353
354 func (c *Conn) RemoteAddr() net.Addr {
355 return c.conn.RemoteAddr()
356 }
357
358
359
360 func (c *Conn) writeFatal(err error) error {
361 err = hideTempErr(err)
362 c.writeErrMu.Lock()
363 if c.writeErr == nil {
364 c.writeErr = err
365 }
366 c.writeErrMu.Unlock()
367 return err
368 }
369
370 func (c *Conn) read(n int) ([]byte, error) {
371 p, err := c.br.Peek(n)
372 if err == io.EOF {
373 err = errUnexpectedEOF
374 }
375 c.br.Discard(len(p))
376 return p, err
377 }
378
379 func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
380 <-c.mu
381 defer func() { c.mu <- struct{}{} }()
382
383 c.writeErrMu.Lock()
384 err := c.writeErr
385 c.writeErrMu.Unlock()
386 if err != nil {
387 return err
388 }
389
390 c.conn.SetWriteDeadline(deadline)
391 if len(buf1) == 0 {
392 _, err = c.conn.Write(buf0)
393 } else {
394 err = c.writeBufs(buf0, buf1)
395 }
396 if err != nil {
397 return c.writeFatal(err)
398 }
399 if frameType == CloseMessage {
400 c.writeFatal(ErrCloseSent)
401 }
402 return nil
403 }
404
405 func (c *Conn) writeBufs(bufs ...[]byte) error {
406 b := net.Buffers(bufs)
407 _, err := b.WriteTo(c.conn)
408 return err
409 }
410
411
412
413 func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
414 if !isControl(messageType) {
415 return errBadWriteOpCode
416 }
417 if len(data) > maxControlFramePayloadSize {
418 return errInvalidControlFrame
419 }
420
421 b0 := byte(messageType) | finalBit
422 b1 := byte(len(data))
423 if !c.isServer {
424 b1 |= maskBit
425 }
426
427 buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
428 buf = append(buf, b0, b1)
429
430 if c.isServer {
431 buf = append(buf, data...)
432 } else {
433 key := newMaskKey()
434 buf = append(buf, key[:]...)
435 buf = append(buf, data...)
436 maskBytes(key, 0, buf[6:])
437 }
438
439 d := 1000 * time.Hour
440 if !deadline.IsZero() {
441 d = deadline.Sub(time.Now())
442 if d < 0 {
443 return errWriteTimeout
444 }
445 }
446
447 timer := time.NewTimer(d)
448 select {
449 case <-c.mu:
450 timer.Stop()
451 case <-timer.C:
452 return errWriteTimeout
453 }
454 defer func() { c.mu <- struct{}{} }()
455
456 c.writeErrMu.Lock()
457 err := c.writeErr
458 c.writeErrMu.Unlock()
459 if err != nil {
460 return err
461 }
462
463 c.conn.SetWriteDeadline(deadline)
464 _, err = c.conn.Write(buf)
465 if err != nil {
466 return c.writeFatal(err)
467 }
468 if messageType == CloseMessage {
469 c.writeFatal(ErrCloseSent)
470 }
471 return err
472 }
473
474
475 func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
476
477
478
479 if c.writer != nil {
480 c.writer.Close()
481 c.writer = nil
482 }
483
484 if !isControl(messageType) && !isData(messageType) {
485 return errBadWriteOpCode
486 }
487
488 c.writeErrMu.Lock()
489 err := c.writeErr
490 c.writeErrMu.Unlock()
491 if err != nil {
492 return err
493 }
494
495 mw.c = c
496 mw.frameType = messageType
497 mw.pos = maxFrameHeaderSize
498
499 if c.writeBuf == nil {
500 wpd, ok := c.writePool.Get().(writePoolData)
501 if ok {
502 c.writeBuf = wpd.buf
503 } else {
504 c.writeBuf = make([]byte, c.writeBufSize)
505 }
506 }
507 return nil
508 }
509
510
511
512
513
514
515
516
517
518 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
519 var mw messageWriter
520 if err := c.beginMessage(&mw, messageType); err != nil {
521 return nil, err
522 }
523 c.writer = &mw
524 if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
525 w := c.newCompressionWriter(c.writer, c.compressionLevel)
526 mw.compress = true
527 c.writer = w
528 }
529 return c.writer, nil
530 }
531
// messageWriter is the writer handed out by NextWriter (possibly wrapped by
// a compression writer). It buffers payload in c.writeBuf and emits a frame
// whenever the buffer fills or the writer is closed.
type messageWriter struct {
	c         *Conn
	compress  bool  // whether the next flushFrame sets RSV1
	pos       int   // end of buffered data in c.writeBuf
	frameType int   // opcode of the next frame (continuationFrame after the first)
	err       error // terminal error; once set, Write/WriteString/ReadFrom/Close return it
}
539
540 func (w *messageWriter) endMessage(err error) error {
541 if w.err != nil {
542 return err
543 }
544 c := w.c
545 w.err = err
546 c.writer = nil
547 if c.writePool != nil {
548 c.writePool.Put(writePoolData{buf: c.writeBuf})
549 c.writeBuf = nil
550 }
551 return err
552 }
553
554
555
// flushFrame writes the buffered payload, followed by extra if any, as one
// frame.
//
// The buffered payload sits in c.writeBuf[maxFrameHeaderSize:w.pos]. The
// header is written so that it ends right before the payload (or before the
// 4-byte mask key on client connections), which makes header and payload
// contiguous from framePos. extra is sent after the buffered payload without
// being copied; this is only allowed on server connections. When final is
// true, the frame carries FIN and the writer is ended with errWriteClosed.
// Otherwise the writer is reset for a continuation frame.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
	c := w.c
	length := w.pos - maxFrameHeaderSize + len(extra)

	// Check for invalid control frames.
	if isControl(w.frameType) &&
		(!final || length > maxControlFramePayloadSize) {
		return w.endMessage(errInvalidControlFrame)
	}

	b0 := byte(w.frameType)
	if final {
		b0 |= finalBit
	}
	if w.compress {
		b0 |= rsv1Bit
	}
	// Only the first frame of a compressed message carries RSV1.
	w.compress = false

	b1 := byte(0)
	if !c.isServer {
		b1 |= maskBit
	}

	// Assume that the frame starts at beginning of c.writeBuf.
	framePos := 0
	if c.isServer {
		// Adjust up if mask not included in the header.
		framePos = 4
	}

	// Shift framePos so the header ends at maxFrameHeaderSize-4 (client,
	// mask key follows) or at maxFrameHeaderSize (server). The header is
	// 2+8 bytes for 64-bit lengths, 2+2 bytes for 16-bit lengths and
	// 2 bytes for 7-bit lengths.
	switch {
	case length >= 65536:
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 127
		binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
	case length > 125:
		framePos += 6
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 126
		binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
	default:
		framePos += 8
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | byte(length)
	}

	if !c.isServer {
		key := newMaskKey()
		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
		if len(extra) > 0 {
			// extra is never masked, so it cannot be used on a client connection.
			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
		}
	}

	// Write the buffers to the connection with best-effort detection of
	// concurrent writes: isWriting is a plain flag, not a lock, so it only
	// catches some overlapping calls.

	if c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = true

	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)

	if !c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = false

	if err != nil {
		return w.endMessage(err)
	}

	if final {
		w.endMessage(errWriteClosed)
		return nil
	}

	// Setup for next frame.
	w.pos = maxFrameHeaderSize
	w.frameType = continuationFrame
	return nil
}
642
643 func (w *messageWriter) ncopy(max int) (int, error) {
644 n := len(w.c.writeBuf) - w.pos
645 if n <= 0 {
646 if err := w.flushFrame(false, nil); err != nil {
647 return 0, err
648 }
649 n = len(w.c.writeBuf) - w.pos
650 }
651 if n > max {
652 n = max
653 }
654 return n, nil
655 }
656
657 func (w *messageWriter) Write(p []byte) (int, error) {
658 if w.err != nil {
659 return 0, w.err
660 }
661
662 if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
663
664 err := w.flushFrame(false, p)
665 if err != nil {
666 return 0, err
667 }
668 return len(p), nil
669 }
670
671 nn := len(p)
672 for len(p) > 0 {
673 n, err := w.ncopy(len(p))
674 if err != nil {
675 return 0, err
676 }
677 copy(w.c.writeBuf[w.pos:], p[:n])
678 w.pos += n
679 p = p[n:]
680 }
681 return nn, nil
682 }
683
684 func (w *messageWriter) WriteString(p string) (int, error) {
685 if w.err != nil {
686 return 0, w.err
687 }
688
689 nn := len(p)
690 for len(p) > 0 {
691 n, err := w.ncopy(len(p))
692 if err != nil {
693 return 0, err
694 }
695 copy(w.c.writeBuf[w.pos:], p[:n])
696 w.pos += n
697 p = p[n:]
698 }
699 return nn, nil
700 }
701
702 func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
703 if w.err != nil {
704 return 0, w.err
705 }
706 for {
707 if w.pos == len(w.c.writeBuf) {
708 err = w.flushFrame(false, nil)
709 if err != nil {
710 break
711 }
712 }
713 var n int
714 n, err = r.Read(w.c.writeBuf[w.pos:])
715 w.pos += n
716 nn += int64(n)
717 if err != nil {
718 if err == io.EOF {
719 err = nil
720 }
721 break
722 }
723 }
724 return nn, err
725 }
726
727 func (w *messageWriter) Close() error {
728 if w.err != nil {
729 return w.err
730 }
731 return w.flushFrame(true, nil)
732 }
733
734
735 func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
736 frameType, frameData, err := pm.frame(prepareKey{
737 isServer: c.isServer,
738 compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
739 compressionLevel: c.compressionLevel,
740 })
741 if err != nil {
742 return err
743 }
744 if c.isWriting {
745 panic("concurrent write to websocket connection")
746 }
747 c.isWriting = true
748 err = c.write(frameType, c.writeDeadline, frameData, nil)
749 if !c.isWriting {
750 panic("concurrent write to websocket connection")
751 }
752 c.isWriting = false
753 return err
754 }
755
756
757
758 func (c *Conn) WriteMessage(messageType int, data []byte) error {
759
760 if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
761
762
763 var mw messageWriter
764 if err := c.beginMessage(&mw, messageType); err != nil {
765 return err
766 }
767 n := copy(c.writeBuf[mw.pos:], data)
768 mw.pos += n
769 data = data[n:]
770 return mw.flushFrame(true, data)
771 }
772
773 w, err := c.NextWriter(messageType)
774 if err != nil {
775 return err
776 }
777 if _, err = w.Write(data); err != nil {
778 return err
779 }
780 return w.Close()
781 }
782
783
784
785
786
787 func (c *Conn) SetWriteDeadline(t time.Time) error {
788 c.writeDeadline = t
789 return nil
790 }
791
792
793
// advanceFrame reads the next frame header from br and returns the frame's
// opcode.
//
// For data frames (text, binary, continuation), it returns once the header
// is parsed and the read limit is enforced. The payload stays in br, and
// c.readRemaining holds its length. Control frames are consumed entirely
// here. Ping and pong payloads go to the ping/pong handlers. A close frame
// goes to the close handler, after which the handler's error or a
// *CloseError is returned. On any error the returned type is noFrame.
// Header violations send a protocol-error close frame via
// handleProtocolError.
func (c *Conn) advanceFrame() (int, error) {
	// 1. Skip remainder of previous frame.

	if c.readRemaining > 0 {
		if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
			return noFrame, err
		}
	}

	// 2. Read and parse first two bytes of frame header.
	// To aid debugging, collect and report all errors in the first two bytes
	// of the header.
	// NOTE: this local slice shadows the errors package within this function.

	var errors []string

	p, err := c.read(2)
	if err != nil {
		return noFrame, err
	}

	frameType := int(p[0] & 0xf)
	final := p[0]&finalBit != 0
	rsv1 := p[0]&rsv1Bit != 0
	rsv2 := p[0]&rsv2Bit != 0
	rsv3 := p[0]&rsv3Bit != 0
	mask := p[1]&maskBit != 0
	// A 7-bit value cannot be negative, so the error is safely ignored.
	c.setReadRemaining(int64(p[1] & 0x7f))

	c.readDecompress = false
	if rsv1 {
		if c.newDecompressionReader != nil {
			c.readDecompress = true
		} else {
			errors = append(errors, "RSV1 set")
		}
	}

	if rsv2 {
		errors = append(errors, "RSV2 set")
	}

	if rsv3 {
		errors = append(errors, "RSV3 set")
	}

	switch frameType {
	case CloseMessage, PingMessage, PongMessage:
		if c.readRemaining > maxControlFramePayloadSize {
			errors = append(errors, "len > 125 for control")
		}
		if !final {
			errors = append(errors, "FIN not set on control")
		}
	case TextMessage, BinaryMessage:
		if !c.readFinal {
			errors = append(errors, "data before FIN")
		}
		c.readFinal = final
	case continuationFrame:
		if c.readFinal {
			errors = append(errors, "continuation after FIN")
		}
		c.readFinal = final
	default:
		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
	}

	// Clients must mask frames sent to servers; servers must not mask.
	if mask != c.isServer {
		errors = append(errors, "bad MASK")
	}

	if len(errors) > 0 {
		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
	}

	// 3. Read and parse the frame length (RFC 6455 section 5.2).
	//
	// A 7-bit value of 0-125 is the payload length itself. A value of 126
	// means the next 2 bytes hold a 16-bit unsigned length. A value of 127
	// means the next 8 bytes hold a 64-bit unsigned length; setReadRemaining
	// rejects values whose high bit is set. Lengths are in network byte
	// order.

	switch c.readRemaining {
	case 126:
		p, err := c.read(2)
		if err != nil {
			return noFrame, err
		}

		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
			return noFrame, err
		}
	case 127:
		p, err := c.read(8)
		if err != nil {
			return noFrame, err
		}

		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
			return noFrame, err
		}
	}

	// 4. Handle frame masking: read the key and restart the mask position.

	if mask {
		c.readMaskPos = 0
		p, err := c.read(len(c.readMaskKey))
		if err != nil {
			return noFrame, err
		}
		copy(c.readMaskKey[:], p)
	}

	// 5. For text and binary messages, enforce read limit and return.

	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {

		c.readLength += c.readRemaining
		// Don't allow readLength to overflow in the presence of a large
		// readRemaining counter.
		if c.readLength < 0 {
			return noFrame, ErrReadLimit
		}

		if c.readLimit > 0 && c.readLength > c.readLimit {
			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
			return noFrame, ErrReadLimit
		}

		return frameType, nil
	}

	// 6. Read control frame payload (at most 125 bytes, checked above) and
	// unmask it on the server side.

	var payload []byte
	if c.readRemaining > 0 {
		payload, err = c.read(int(c.readRemaining))
		c.setReadRemaining(0)
		if err != nil {
			return noFrame, err
		}
		if c.isServer {
			maskBytes(c.readMaskKey, 0, payload)
		}
	}

	// 7. Process control frame payload.

	switch frameType {
	case PongMessage:
		if err := c.handlePong(string(payload)); err != nil {
			return noFrame, err
		}
	case PingMessage:
		if err := c.handlePing(string(payload)); err != nil {
			return noFrame, err
		}
	case CloseMessage:
		closeCode := CloseNoStatusReceived
		closeText := ""
		// NOTE(review): a 1-byte close payload is treated as "no status".
		// RFC 6455 section 5.5.1 requires a close body to begin with a 2-byte
		// code, so a stricter reader would reject it as a protocol error —
		// confirm this leniency is intended.
		if len(payload) >= 2 {
			closeCode = int(binary.BigEndian.Uint16(payload))
			if !isValidReceivedCloseCode(closeCode) {
				return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
			}
			closeText = string(payload[2:])
			if !utf8.ValidString(closeText) {
				return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
			}
		}
		if err := c.handleClose(closeCode, closeText); err != nil {
			return noFrame, err
		}
		return noFrame, &CloseError{Code: closeCode, Text: closeText}
	}

	return frameType, nil
}
978
979 func (c *Conn) handleProtocolError(message string) error {
980 data := FormatCloseMessage(CloseProtocolError, message)
981 if len(data) > maxControlFramePayloadSize {
982 data = data[:maxControlFramePayloadSize]
983 }
984 c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
985 return errors.New("websocket: " + message)
986 }
987
988
989
990
991
992
993
994
995
996
997
998 func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
999
1000 if c.reader != nil {
1001 c.reader.Close()
1002 c.reader = nil
1003 }
1004
1005 c.messageReader = nil
1006 c.readLength = 0
1007
1008 for c.readErr == nil {
1009 frameType, err := c.advanceFrame()
1010 if err != nil {
1011 c.readErr = hideTempErr(err)
1012 break
1013 }
1014
1015 if frameType == TextMessage || frameType == BinaryMessage {
1016 c.messageReader = &messageReader{c}
1017 c.reader = c.messageReader
1018 if c.readDecompress {
1019 c.reader = c.newDecompressionReader(c.reader)
1020 }
1021 return frameType, c.reader, nil
1022 }
1023 }
1024
1025
1026
1027
1028 c.readErrCount++
1029 if c.readErrCount >= 1000 {
1030 panic("repeated read on failed websocket connection")
1031 }
1032
1033 return noFrame, nil, c.readErr
1034 }
1035
// messageReader is the low-level reader for a single data message, created
// by NextReader. It stays valid only while it is the connection's current
// messageReader.
type messageReader struct{ c *Conn }
1037
// Read reads payload bytes of the current message into b.
//
// If r is no longer the connection's current message reader (NextReader was
// called again, or the message was fully read), Read returns io.EOF. Each
// successful call returns bytes from a single frame. When the current frame
// is exhausted and the message is not final, the next frame header is read
// via advanceFrame, which also processes interleaved control frames. Errors
// are sticky in c.readErr. An io.EOF from the network before the message is
// complete is reported as errUnexpectedEOF.
func (r *messageReader) Read(b []byte) (int, error) {
	c := r.c
	if c.messageReader != r {
		return 0, io.EOF
	}

	for c.readErr == nil {

		if c.readRemaining > 0 {
			// Never read past the end of the current frame.
			if int64(len(b)) > c.readRemaining {
				b = b[:c.readRemaining]
			}
			n, err := c.br.Read(b)
			c.readErr = hideTempErr(err)
			if c.isServer {
				// readMaskPos carries the key offset across partial reads.
				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
			}
			rem := c.readRemaining
			rem -= int64(n)
			c.setReadRemaining(rem)
			if c.readRemaining > 0 && c.readErr == io.EOF {
				c.readErr = errUnexpectedEOF
			}
			return n, c.readErr
		}

		// Frame exhausted: the message ends here if this frame had FIN.
		if c.readFinal {
			c.messageReader = nil
			return 0, io.EOF
		}

		frameType, err := c.advanceFrame()
		switch {
		case err != nil:
			c.readErr = hideTempErr(err)
		case frameType == TextMessage || frameType == BinaryMessage:
			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
		}
	}

	err := c.readErr
	if err == io.EOF && c.messageReader == r {
		err = errUnexpectedEOF
	}
	return 0, err
}
1084
1085 func (r *messageReader) Close() error {
1086 return nil
1087 }
1088
1089
1090
1091 func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
1092 var r io.Reader
1093 messageType, r, err = c.NextReader()
1094 if err != nil {
1095 return messageType, nil, err
1096 }
1097 p, err = ioutil.ReadAll(r)
1098 return messageType, p, err
1099 }
1100
1101
1102
1103
1104
1105 func (c *Conn) SetReadDeadline(t time.Time) error {
1106 return c.conn.SetReadDeadline(t)
1107 }
1108
1109
1110
1111
1112 func (c *Conn) SetReadLimit(limit int64) {
1113 c.readLimit = limit
1114 }
1115
1116
1117 func (c *Conn) CloseHandler() func(code int, text string) error {
1118 return c.handleClose
1119 }
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135 func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
1136 if h == nil {
1137 h = func(code int, text string) error {
1138 message := FormatCloseMessage(code, "")
1139 c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
1140 return nil
1141 }
1142 }
1143 c.handleClose = h
1144 }
1145
1146
1147 func (c *Conn) PingHandler() func(appData string) error {
1148 return c.handlePing
1149 }
1150
1151
1152
1153
1154
1155
1156
1157
1158 func (c *Conn) SetPingHandler(h func(appData string) error) {
1159 if h == nil {
1160 h = func(message string) error {
1161 err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
1162 if err == ErrCloseSent {
1163 return nil
1164 } else if e, ok := err.(net.Error); ok && e.Temporary() {
1165 return nil
1166 }
1167 return err
1168 }
1169 }
1170 c.handlePing = h
1171 }
1172
1173
1174 func (c *Conn) PongHandler() func(appData string) error {
1175 return c.handlePong
1176 }
1177
1178
1179
1180
1181
1182
1183
1184
1185 func (c *Conn) SetPongHandler(h func(appData string) error) {
1186 if h == nil {
1187 h = func(string) error { return nil }
1188 }
1189 c.handlePong = h
1190 }
1191
1192
1193
1194
1195 func (c *Conn) NetConn() net.Conn {
1196 return c.conn
1197 }
1198
1199
1200
1201
1202 func (c *Conn) UnderlyingConn() net.Conn {
1203 return c.conn
1204 }
1205
1206
1207
1208
1209 func (c *Conn) EnableWriteCompression(enable bool) {
1210 c.enableWriteCompression = enable
1211 }
1212
1213
1214
1215
1216
1217 func (c *Conn) SetCompressionLevel(level int) error {
1218 if !isValidCompressionLevel(level) {
1219 return errors.New("websocket: invalid compression level")
1220 }
1221 c.compressionLevel = level
1222 return nil
1223 }
1224
1225
1226
1227 func FormatCloseMessage(closeCode int, text string) []byte {
1228 if closeCode == CloseNoStatusReceived {
1229
1230
1231
1232 return []byte{}
1233 }
1234 buf := make([]byte, 2+len(text))
1235 binary.BigEndian.PutUint16(buf, uint16(closeCode))
1236 copy(buf[2:], text)
1237 return buf
1238 }
1239
View as plain text