1
2
3
4
5 package websocket
6
7
8
9
10 import (
11 "bufio"
12 "bytes"
13 "crypto/rand"
14 "crypto/sha1"
15 "encoding/base64"
16 "encoding/binary"
17 "fmt"
18 "io"
19 "net/http"
20 "net/url"
21 "strings"
22 )
23
// websocketGUID is the fixed GUID that RFC 6455 §1.3 appends to the client's
// Sec-WebSocket-Key before hashing it into Sec-WebSocket-Accept.
const websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// Close status codes, numbered consecutively from 1000 (RFC 6455 §7.4.1).
const (
	closeStatusNormal            = 1000 + iota // 1000
	closeStatusGoingAway                       // 1001
	closeStatusProtocolError                   // 1002
	closeStatusUnsupportedData                 // 1003
	closeStatusFrameTooLarge                   // 1004 (reserved in RFC 6455)
	closeStatusNoStatusRcvd                    // 1005
	closeStatusAbnormalClosure                 // 1006
	closeStatusBadMessageData                  // 1007
	closeStatusPolicyViolation                 // 1008
	closeStatusTooBigData                      // 1009
	closeStatusExtensionMismatch               // 1010
)

// maxControlFramePayloadLength is the largest payload a control frame
// (close, ping, pong) may carry.
const maxControlFramePayloadLength = 125
41
42 var (
43 ErrBadMaskingKey = &ProtocolError{"bad masking key"}
44 ErrBadPongMessage = &ProtocolError{"bad pong message"}
45 ErrBadClosingStatus = &ProtocolError{"bad closing status"}
46 ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
47 ErrNotImplemented = &ProtocolError{"not implemented"}
48
49 handshakeHeader = map[string]bool{
50 "Host": true,
51 "Upgrade": true,
52 "Connection": true,
53 "Sec-Websocket-Key": true,
54 "Sec-Websocket-Origin": true,
55 "Sec-Websocket-Version": true,
56 "Sec-Websocket-Protocol": true,
57 "Sec-Websocket-Accept": true,
58 }
59 )
60
61
// hybiFrameHeader holds the decoded fields of one RFC 6455 frame header.
type hybiFrameHeader struct {
	Fin        bool    // set on the final fragment of a message
	Rsv        [3]bool // RSV1, RSV2, RSV3 extension bits, in that order
	OpCode     byte    // 4-bit frame opcode
	Length     int64   // payload length in bytes
	MaskingKey []byte  // masking key (4 bytes), or nil for an unmasked frame

	// data holds the header bytes recorded by the frame reader; writers
	// leave it nil.
	data *bytes.Buffer
}
71
72
73 type hybiFrameReader struct {
74 reader io.Reader
75
76 header hybiFrameHeader
77 pos int64
78 length int
79 }
80
81 func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
82 n, err = frame.reader.Read(msg)
83 if frame.header.MaskingKey != nil {
84 for i := 0; i < n; i++ {
85 msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
86 frame.pos++
87 }
88 }
89 return n, err
90 }
91
92 func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
93
94 func (frame *hybiFrameReader) HeaderReader() io.Reader {
95 if frame.header.data == nil {
96 return nil
97 }
98 if frame.header.data.Len() == 0 {
99 return nil
100 }
101 return frame.header.data
102 }
103
104 func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
105
106 func (frame *hybiFrameReader) Len() (n int) { return frame.length }
107
108
// hybiFrameReaderFactory parses successive frames from one buffered reader.
type hybiFrameReaderFactory struct {
	*bufio.Reader // shared connection reader; every frame is parsed from it
}
112
113
114
115
116 func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
117 hybiFrame := new(hybiFrameReader)
118 frame = hybiFrame
119 var header []byte
120 var b byte
121
122 b, err = buf.ReadByte()
123 if err != nil {
124 return
125 }
126 header = append(header, b)
127 hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
128 for i := 0; i < 3; i++ {
129 j := uint(6 - i)
130 hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
131 }
132 hybiFrame.header.OpCode = header[0] & 0x0f
133
134
135 b, err = buf.ReadByte()
136 if err != nil {
137 return
138 }
139 header = append(header, b)
140 mask := (b & 0x80) != 0
141 b &= 0x7f
142 lengthFields := 0
143 switch {
144 case b <= 125:
145 hybiFrame.header.Length = int64(b)
146 case b == 126:
147 lengthFields = 2
148 case b == 127:
149 lengthFields = 8
150 }
151 for i := 0; i < lengthFields; i++ {
152 b, err = buf.ReadByte()
153 if err != nil {
154 return
155 }
156 if lengthFields == 8 && i == 0 {
157 b &= 0x7f
158 }
159 header = append(header, b)
160 hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
161 }
162 if mask {
163
164 for i := 0; i < 4; i++ {
165 b, err = buf.ReadByte()
166 if err != nil {
167 return
168 }
169 header = append(header, b)
170 hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
171 }
172 }
173 hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
174 hybiFrame.header.data = bytes.NewBuffer(header)
175 hybiFrame.length = len(header) + int(hybiFrame.header.Length)
176 return
177 }
178
179
180 type hybiFrameWriter struct {
181 writer *bufio.Writer
182
183 header *hybiFrameHeader
184 }
185
186 func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
187 var header []byte
188 var b byte
189 if frame.header.Fin {
190 b |= 0x80
191 }
192 for i := 0; i < 3; i++ {
193 if frame.header.Rsv[i] {
194 j := uint(6 - i)
195 b |= 1 << j
196 }
197 }
198 b |= frame.header.OpCode
199 header = append(header, b)
200 if frame.header.MaskingKey != nil {
201 b = 0x80
202 } else {
203 b = 0
204 }
205 lengthFields := 0
206 length := len(msg)
207 switch {
208 case length <= 125:
209 b |= byte(length)
210 case length < 65536:
211 b |= 126
212 lengthFields = 2
213 default:
214 b |= 127
215 lengthFields = 8
216 }
217 header = append(header, b)
218 for i := 0; i < lengthFields; i++ {
219 j := uint((lengthFields - i - 1) * 8)
220 b = byte((length >> j) & 0xff)
221 header = append(header, b)
222 }
223 if frame.header.MaskingKey != nil {
224 if len(frame.header.MaskingKey) != 4 {
225 return 0, ErrBadMaskingKey
226 }
227 header = append(header, frame.header.MaskingKey...)
228 frame.writer.Write(header)
229 data := make([]byte, length)
230 for i := range data {
231 data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
232 }
233 frame.writer.Write(data)
234 err = frame.writer.Flush()
235 return length, err
236 }
237 frame.writer.Write(header)
238 frame.writer.Write(msg)
239 err = frame.writer.Flush()
240 return length, err
241 }
242
243 func (frame *hybiFrameWriter) Close() error { return nil }
244
// hybiFrameWriterFactory creates frame writers over one shared buffered writer.
type hybiFrameWriterFactory struct {
	*bufio.Writer
	needMaskingKey bool // set for client connections, whose frames must be masked
}
249
250 func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
251 frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
252 if buf.needMaskingKey {
253 frameHeader.MaskingKey, err = generateMaskingKey()
254 if err != nil {
255 return nil, err
256 }
257 }
258 return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
259 }
260
261 type hybiFrameHandler struct {
262 conn *Conn
263 payloadType byte
264 }
265
266 func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
267 if handler.conn.IsServerConn() {
268
269 if frame.(*hybiFrameReader).header.MaskingKey == nil {
270 handler.WriteClose(closeStatusProtocolError)
271 return nil, io.EOF
272 }
273 } else {
274
275 if frame.(*hybiFrameReader).header.MaskingKey != nil {
276 handler.WriteClose(closeStatusProtocolError)
277 return nil, io.EOF
278 }
279 }
280 if header := frame.HeaderReader(); header != nil {
281 io.Copy(io.Discard, header)
282 }
283 switch frame.PayloadType() {
284 case ContinuationFrame:
285 frame.(*hybiFrameReader).header.OpCode = handler.payloadType
286 case TextFrame, BinaryFrame:
287 handler.payloadType = frame.PayloadType()
288 case CloseFrame:
289 return nil, io.EOF
290 case PingFrame, PongFrame:
291 b := make([]byte, maxControlFramePayloadLength)
292 n, err := io.ReadFull(frame, b)
293 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
294 return nil, err
295 }
296 io.Copy(io.Discard, frame)
297 if frame.PayloadType() == PingFrame {
298 if _, err := handler.WritePong(b[:n]); err != nil {
299 return nil, err
300 }
301 }
302 return nil, nil
303 }
304 return frame, nil
305 }
306
307 func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
308 handler.conn.wio.Lock()
309 defer handler.conn.wio.Unlock()
310 w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
311 if err != nil {
312 return err
313 }
314 msg := make([]byte, 2)
315 binary.BigEndian.PutUint16(msg, uint16(status))
316 _, err = w.Write(msg)
317 w.Close()
318 return err
319 }
320
321 func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
322 handler.conn.wio.Lock()
323 defer handler.conn.wio.Unlock()
324 w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
325 if err != nil {
326 return 0, err
327 }
328 n, err = w.Write(msg)
329 w.Close()
330 return n, err
331 }
332
333
334 func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
335 if buf == nil {
336 br := bufio.NewReader(rwc)
337 bw := bufio.NewWriter(rwc)
338 buf = bufio.NewReadWriter(br, bw)
339 }
340 ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
341 frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
342 frameWriterFactory: hybiFrameWriterFactory{
343 buf.Writer, request == nil},
344 PayloadType: TextFrame,
345 defaultCloseStatus: closeStatusNormal}
346 ws.frameHandler = &hybiFrameHandler{conn: ws}
347 return ws
348 }
349
350
// generateMaskingKey returns 4 bytes from crypto/rand for masking one frame.
// On error the (possibly partially filled) key is returned with the error.
func generateMaskingKey() (maskingKey []byte, err error) {
	key := make([]byte, 4)
	if _, err := io.ReadFull(rand.Reader, key); err != nil {
		return key, err
	}
	return key, nil
}
358
359
360
// generateNonce returns a Sec-WebSocket-Key value: 16 random bytes encoded
// as standard base64 (24 characters). It panics if crypto/rand fails.
func generateNonce() (nonce []byte) {
	var raw [16]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		panic(err)
	}
	nonce = make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
	base64.StdEncoding.Encode(nonce, raw[:])
	return nonce
}
370
371
372
// removeZone strips an IPv6 zone identifier from a bracketed host, e.g.
// "[fe80::1%en0]:8080" becomes "[fe80::1]:8080". Hosts that are not
// bracketed, or that carry no zone, are returned unchanged.
func removeZone(host string) string {
	closing := strings.LastIndexByte(host, ']')
	if !strings.HasPrefix(host, "[") || closing < 0 {
		return host
	}
	if zone := strings.LastIndexByte(host[:closing], '%'); zone >= 0 {
		return host[:zone] + host[closing:]
	}
	return host
}
387
388
389
390 func getNonceAccept(nonce []byte) (expected []byte, err error) {
391 h := sha1.New()
392 if _, err = h.Write(nonce); err != nil {
393 return
394 }
395 if _, err = h.Write([]byte(websocketGUID)); err != nil {
396 return
397 }
398 expected = make([]byte, 28)
399 base64.StdEncoding.Encode(expected, h.Sum(nil))
400 return
401 }
402
403
404 func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
405 bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
406
407
408
409
410 bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
411 bw.WriteString("Upgrade: websocket\r\n")
412 bw.WriteString("Connection: Upgrade\r\n")
413 nonce := generateNonce()
414 if config.handshakeData != nil {
415 nonce = []byte(config.handshakeData["key"])
416 }
417 bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
418 bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
419
420 if config.Version != ProtocolVersionHybi13 {
421 return ErrBadProtocolVersion
422 }
423
424 bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
425 if len(config.Protocol) > 0 {
426 bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
427 }
428
429 err = config.Header.WriteSubset(bw, handshakeHeader)
430 if err != nil {
431 return err
432 }
433
434 bw.WriteString("\r\n")
435 if err = bw.Flush(); err != nil {
436 return err
437 }
438
439 resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
440 if err != nil {
441 return err
442 }
443 if resp.StatusCode != 101 {
444 return ErrBadStatus
445 }
446 if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
447 strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
448 return ErrBadUpgrade
449 }
450 expectedAccept, err := getNonceAccept(nonce)
451 if err != nil {
452 return err
453 }
454 if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
455 return ErrChallengeResponse
456 }
457 if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
458 return ErrUnsupportedExtensions
459 }
460 offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
461 if offeredProtocol != "" {
462 protocolMatched := false
463 for i := 0; i < len(config.Protocol); i++ {
464 if config.Protocol[i] == offeredProtocol {
465 protocolMatched = true
466 break
467 }
468 }
469 if !protocolMatched {
470 return ErrBadWebSocketProtocol
471 }
472 config.Protocol = []string{offeredProtocol}
473 }
474
475 return nil
476 }
477
478
479 func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
480 return newHybiConn(config, buf, rwc, nil)
481 }
482
483
484 type hybiServerHandshaker struct {
485 *Config
486 accept []byte
487 }
488
489 func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
490 c.Version = ProtocolVersionHybi13
491 if req.Method != "GET" {
492 return http.StatusMethodNotAllowed, ErrBadRequestMethod
493 }
494
495
496 if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
497 !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
498 return http.StatusBadRequest, ErrNotWebSocket
499 }
500
501 key := req.Header.Get("Sec-Websocket-Key")
502 if key == "" {
503 return http.StatusBadRequest, ErrChallengeResponse
504 }
505 version := req.Header.Get("Sec-Websocket-Version")
506 switch version {
507 case "13":
508 c.Version = ProtocolVersionHybi13
509 default:
510 return http.StatusBadRequest, ErrBadWebSocketVersion
511 }
512 var scheme string
513 if req.TLS != nil {
514 scheme = "wss"
515 } else {
516 scheme = "ws"
517 }
518 c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
519 if err != nil {
520 return http.StatusBadRequest, err
521 }
522 protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
523 if protocol != "" {
524 protocols := strings.Split(protocol, ",")
525 for i := 0; i < len(protocols); i++ {
526 c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
527 }
528 }
529 c.accept, err = getNonceAccept([]byte(key))
530 if err != nil {
531 return http.StatusInternalServerError, err
532 }
533 return http.StatusSwitchingProtocols, nil
534 }
535
536
537
538 func Origin(config *Config, req *http.Request) (*url.URL, error) {
539 var origin string
540 switch config.Version {
541 case ProtocolVersionHybi13:
542 origin = req.Header.Get("Origin")
543 }
544 if origin == "" {
545 return nil, nil
546 }
547 return url.ParseRequestURI(origin)
548 }
549
550 func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
551 if len(c.Protocol) > 0 {
552 if len(c.Protocol) != 1 {
553
554 return ErrBadWebSocketProtocol
555 }
556 }
557 buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
558 buf.WriteString("Upgrade: websocket\r\n")
559 buf.WriteString("Connection: Upgrade\r\n")
560 buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
561 if len(c.Protocol) > 0 {
562 buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
563 }
564
565 if c.Header != nil {
566 err := c.Header.WriteSubset(buf, handshakeHeader)
567 if err != nil {
568 return err
569 }
570 }
571 buf.WriteString("\r\n")
572 return buf.Flush()
573 }
574
575 func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
576 return newHybiServerConn(c.Config, buf, rwc, request)
577 }
578
579
580 func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
581 return newHybiConn(config, buf, rwc, request)
582 }
583
View as plain text