1
2
3
4
5 package websocket
6
7 import (
8 "bufio"
9 "errors"
10 "io"
11 "net/http"
12 "net/url"
13 "strings"
14 "time"
15 )
16
17
// HandshakeError is the error returned when a WebSocket handshake request is
// rejected. Its message explains why the request was not accepted.
type HandshakeError struct {
	message string
}

// Error returns the rejection reason, satisfying the error interface.
func (he HandshakeError) Error() string {
	return he.message
}
23
24
25
26
27
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// The methods in this file only read an Upgrader's fields; they never modify
// them.
type Upgrader struct {
	// HandshakeTimeout, when positive, is applied as a write deadline on the
	// hijacked connection while the handshake response is written. The
	// deadline is cleared again after a successful write.
	HandshakeTimeout time.Duration

	// ReadBufferSize and WriteBufferSize are passed to newConn as the
	// connection's I/O buffer sizes.
	//
	// When ReadBufferSize is zero, Upgrade reuses the HTTP server's hijacked
	// read buffer if it is larger than 256 bytes. When WriteBufferSize is
	// zero and WriteBufferPool is nil, Upgrade reuses the hijacked write
	// buffer if it holds at least maxFrameHeaderSize+256 bytes.
	// NOTE(review): how newConn sizes buffers when a size is zero and no
	// hijacked buffer is reused is not visible here — confirm in newConn.
	ReadBufferSize, WriteBufferSize int

	// WriteBufferPool, if non-nil, is passed to newConn as the source of
	// write buffers. Setting it disables reuse of the hijacked write buffer.
	// NOTE(review): presumably buffers are held only while writing — confirm
	// against newConn and the BufferPool interface.
	WriteBufferPool BufferPool

	// Subprotocols lists the server's supported subprotocols in order of
	// preference. When non-nil, the first entry that the client also
	// requested is selected; if none match, no subprotocol is selected.
	// When nil, the Sec-Websocket-Protocol value from the responseHeader
	// passed to Upgrade (if any) is used instead.
	Subprotocols []string

	// Error, if non-nil, writes the HTTP error response when Upgrade rejects
	// a request before hijacking the connection. The reason argument is a
	// HandshakeError.
	//
	// If Error is nil, Upgrade calls http.Error with the standard status
	// text and sets a "Sec-Websocket-Version: 13" response header.
	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)

	// CheckOrigin returns true to accept the request's origin. A false result
	// rejects the handshake with 403 Forbidden.
	//
	// If CheckOrigin is nil, checkSameOrigin is used. It accepts requests
	// without an Origin header. Otherwise it requires the Origin URL's host to
	// match r.Host, compared with equalASCIIFold.
	CheckOrigin func(r *http.Request) bool

	// EnableCompression makes Upgrade accept a permessage-deflate extension
	// offered by the client. The response always requests no context
	// takeover in both directions.
	EnableCompression bool
}
76
77 func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
78 err := HandshakeError{reason}
79 if u.Error != nil {
80 u.Error(w, r, status, err)
81 } else {
82 w.Header().Set("Sec-Websocket-Version", "13")
83 http.Error(w, http.StatusText(status), status)
84 }
85 return nil, err
86 }
87
88
89 func checkSameOrigin(r *http.Request) bool {
90 origin := r.Header["Origin"]
91 if len(origin) == 0 {
92 return true
93 }
94 u, err := url.Parse(origin[0])
95 if err != nil {
96 return false
97 }
98 return equalASCIIFold(u.Host, r.Host)
99 }
100
101 func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
102 if u.Subprotocols != nil {
103 clientProtocols := Subprotocols(r)
104 for _, serverProtocol := range u.Subprotocols {
105 for _, clientProtocol := range clientProtocols {
106 if clientProtocol == serverProtocol {
107 return clientProtocol
108 }
109 }
110 }
111 } else if responseHeader != nil {
112 return responseHeader.Get("Sec-Websocket-Protocol")
113 }
114 return ""
115 }
116
117
118
119
120
121
122
123
124
125 func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
126 const badHandshake = "websocket: the client is not using the websocket protocol: "
127
128 if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
129 return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
130 }
131
132 if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
133 return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
134 }
135
136 if r.Method != http.MethodGet {
137 return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
138 }
139
140 if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
141 return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
142 }
143
144 if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
145 return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
146 }
147
148 checkOrigin := u.CheckOrigin
149 if checkOrigin == nil {
150 checkOrigin = checkSameOrigin
151 }
152 if !checkOrigin(r) {
153 return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
154 }
155
156 challengeKey := r.Header.Get("Sec-Websocket-Key")
157 if !isValidChallengeKey(challengeKey) {
158 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
159 }
160
161 subprotocol := u.selectSubprotocol(r, responseHeader)
162
163
164 var compress bool
165 if u.EnableCompression {
166 for _, ext := range parseExtensions(r.Header) {
167 if ext[""] != "permessage-deflate" {
168 continue
169 }
170 compress = true
171 break
172 }
173 }
174
175 h, ok := w.(http.Hijacker)
176 if !ok {
177 return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
178 }
179 var brw *bufio.ReadWriter
180 netConn, brw, err := h.Hijack()
181 if err != nil {
182 return u.returnError(w, r, http.StatusInternalServerError, err.Error())
183 }
184
185 if brw.Reader.Buffered() > 0 {
186 netConn.Close()
187 return nil, errors.New("websocket: client sent data before handshake is complete")
188 }
189
190 var br *bufio.Reader
191 if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
192
193 br = brw.Reader
194 }
195
196 buf := bufioWriterBuffer(netConn, brw.Writer)
197
198 var writeBuf []byte
199 if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
200
201 writeBuf = buf
202 }
203
204 c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
205 c.subprotocol = subprotocol
206
207 if compress {
208 c.newCompressionWriter = compressNoContextTakeover
209 c.newDecompressionReader = decompressNoContextTakeover
210 }
211
212
213 p := buf
214 if len(c.writeBuf) > len(p) {
215 p = c.writeBuf
216 }
217 p = p[:0]
218
219 p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
220 p = append(p, computeAcceptKey(challengeKey)...)
221 p = append(p, "\r\n"...)
222 if c.subprotocol != "" {
223 p = append(p, "Sec-WebSocket-Protocol: "...)
224 p = append(p, c.subprotocol...)
225 p = append(p, "\r\n"...)
226 }
227 if compress {
228 p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
229 }
230 for k, vs := range responseHeader {
231 if k == "Sec-Websocket-Protocol" {
232 continue
233 }
234 for _, v := range vs {
235 p = append(p, k...)
236 p = append(p, ": "...)
237 for i := 0; i < len(v); i++ {
238 b := v[i]
239 if b <= 31 {
240
241 b = ' '
242 }
243 p = append(p, b)
244 }
245 p = append(p, "\r\n"...)
246 }
247 }
248 p = append(p, "\r\n"...)
249
250
251 netConn.SetDeadline(time.Time{})
252
253 if u.HandshakeTimeout > 0 {
254 netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
255 }
256 if _, err = netConn.Write(p); err != nil {
257 netConn.Close()
258 return nil, err
259 }
260 if u.HandshakeTimeout > 0 {
261 netConn.SetWriteDeadline(time.Time{})
262 }
263
264 return c, nil
265 }
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297 func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
298 u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
299 u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
300
301 }
302 u.CheckOrigin = func(r *http.Request) bool {
303
304 return true
305 }
306 return u.Upgrade(w, r, responseHeader)
307 }
308
309
310
// Subprotocols returns the subprotocols requested by the client in its
// Sec-Websocket-Protocol header, in the order the client listed them.
//
// RFC 6455 allows the header to appear more than once, which means the same
// as one comma-separated list, so all occurrences are combined. Whitespace
// around each token is trimmed and empty tokens are dropped. Subprotocols
// returns nil when the client requested no subprotocol.
func Subprotocols(r *http.Request) []string {
	var protocols []string
	// Direct map access with the canonical key sees every occurrence;
	// Header.Get would return only the first one.
	for _, value := range r.Header["Sec-Websocket-Protocol"] {
		for _, token := range strings.Split(value, ",") {
			if token = strings.TrimSpace(token); token != "" {
				protocols = append(protocols, token)
			}
		}
	}
	return protocols
}
322
323
324
325 func IsWebSocketUpgrade(r *http.Request) bool {
326 return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
327 tokenListContainsValue(r.Header, "Upgrade", "websocket")
328 }
329
330
// bufioReaderSize points br at originalReader and returns the size of br's
// internal buffer, or 0 if it cannot be determined.
//
// The Reset discards anything br had buffered. Upgrade relies on this side
// effect so that the hijacked reader reads straight from the network
// connection.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
	br.Reset(originalReader)
	// Right after Reset the read offset is zero, so a zero-length Peek is a
	// view at the start of the buffer whose capacity is the buffer size.
	view, err := br.Peek(0)
	if err != nil {
		return 0
	}
	return cap(view)
}
341
342
343
// writeHook is an io.Writer that records the slice passed to its most recent
// Write call instead of copying the bytes. bufioWriterBuffer uses it to get
// hold of a bufio.Writer's internal buffer.
type writeHook struct {
	p []byte
}

// Write stores p itself, deliberately aliasing the caller's memory, and
// reports the whole slice as written with no error.
func (hook *writeHook) Write(p []byte) (int, error) {
	hook.p = p
	n := len(p)
	return n, nil
}
352
353
354 func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
355
356
357 var wh writeHook
358 bw.Reset(&wh)
359 bw.WriteByte(0)
360 bw.Flush()
361
362 bw.Reset(originalWriter)
363
364 return wh.p[:cap(wh.p)]
365 }
366
View as plain text