1
2
3
4 package websocket
5
6 import (
7 "bytes"
8 "crypto/sha1"
9 "encoding/base64"
10 "errors"
11 "fmt"
12 "io"
13 "log"
14 "net/http"
15 "net/textproto"
16 "net/url"
17 "path/filepath"
18 "strings"
19
20 "nhooyr.io/websocket/internal/errd"
21 )
22
23
24 type AcceptOptions struct {
25
26
27
28 Subprotocols []string
29
30
31
32
33 InsecureSkipVerify bool
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 OriginPatterns []string
53
54
55
56
57
58 CompressionMode CompressionMode
59
60
61
62
63
64 CompressionThreshold int
65 }
66
67 func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
68 var o AcceptOptions
69 if opts != nil {
70 o = *opts
71 }
72 return &o
73 }
74
75
76
77
78
79
80
81
82 func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
83 return accept(w, r, opts)
84 }
85
// accept implements Accept. The deferred errd.Wrap decorates every returned
// error with "failed to accept WebSocket connection".
//
// Each pre-upgrade failure (malformed handshake, rejected Origin, writer that
// cannot be hijacked) is reported to the client with http.Error before
// returning. On success it sends 101 Switching Protocols, hijacks the
// connection and wraps it in a server-side *Conn.
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
	defer errd.Wrap(&err, "failed to accept WebSocket connection")

	// Validate the handshake request itself. verifyClientRequest picks the
	// status code and may already have set advisory headers on w.
	errCode, err := verifyClientRequest(w, r)
	if err != nil {
		http.Error(w, err.Error(), errCode)
		return nil, err
	}

	// opts may be nil; from here on only a private copy is read.
	opts = opts.cloneWithDefaults()
	if !opts.InsecureSkipVerify {
		err = authenticateOrigin(r, opts.OriginPatterns)
		if err != nil {
			if errors.Is(err, filepath.ErrBadPattern) {
				// A malformed OriginPatterns entry is a server-side
				// misconfiguration: log the details and send the client only
				// a generic "Forbidden" message.
				log.Printf("websocket: %v", err)
				err = errors.New(http.StatusText(http.StatusForbidden))
			}
			http.Error(w, err.Error(), http.StatusForbidden)
			return nil, err
		}
	}

	// Check hijackability before writing any success headers so the client
	// can still receive a proper 501.
	hj, ok := w.(http.Hijacker)
	if !ok {
		err = errors.New("http.ResponseWriter does not implement http.Hijacker")
		http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
		return nil, err
	}

	w.Header().Set("Upgrade", "websocket")
	w.Header().Set("Connection", "Upgrade")

	// verifyClientRequest validated the key (after trimming); the header
	// value is hashed here to prove to the client that the handshake was read.
	key := r.Header.Get("Sec-WebSocket-Key")
	w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))

	subproto := selectSubprotocol(r, opts.Subprotocols)
	if subproto != "" {
		w.Header().Set("Sec-WebSocket-Protocol", subproto)
	}

	// copts is nil when no permessage-deflate offer was acceptable (or
	// compression is disabled); it is passed to newConn either way.
	copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
	if ok {
		w.Header().Set("Sec-WebSocket-Extensions", copts.String())
	}

	w.WriteHeader(http.StatusSwitchingProtocols)

	// Some ResponseWriter wrappers buffer WriteHeader; if w exposes
	// WriteHeaderNow, call it so the 101 status is committed before hijacking.
	// NOTE(review): the variable name suggests this targets gin's writer —
	// confirm which wrappers need it.
	if ginWriter, ok := w.(interface {
		WriteHeaderNow()
	}); ok {
		ginWriter.WriteHeaderNow()
	}

	netConn, brw, err := hj.Hijack()
	if err != nil {
		// NOTE(review): the 101 status has already been written at this
		// point, so this http.Error presumably cannot change the status the
		// client sees — confirm.
		err = fmt.Errorf("failed to hijack connection: %w", err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return nil, err
	}

	// Rebuild brw.Reader so it first yields the bytes the HTTP server had
	// already buffered and then reads straight from netConn. Peeking exactly
	// Buffered() bytes needs no further read, hence the ignored error.
	// NOTE(review): the motivation (not reading through the server's own
	// reader after hijack) is inferred — confirm against upstream notes.
	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))

	return newConn(connConfig{
		// Read back from the response headers so the Conn reports whatever
		// subprotocol value is actually in w's header map.
		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
		rwc:            netConn,
		client:         false,
		copts:          copts,
		flateThreshold: opts.CompressionThreshold,

		br: brw.Reader,
		bw: brw.Writer,
	}), nil
}
161
162 func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
163 if !r.ProtoAtLeast(1, 1) {
164 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
165 }
166
167 if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
168 w.Header().Set("Connection", "Upgrade")
169 w.Header().Set("Upgrade", "websocket")
170 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
171 }
172
173 if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
174 w.Header().Set("Connection", "Upgrade")
175 w.Header().Set("Upgrade", "websocket")
176 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
177 }
178
179 if r.Method != "GET" {
180 return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
181 }
182
183 if r.Header.Get("Sec-WebSocket-Version") != "13" {
184 w.Header().Set("Sec-WebSocket-Version", "13")
185 return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
186 }
187
188 websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
189 if len(websocketSecKeys) == 0 {
190 return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
191 }
192
193 if len(websocketSecKeys) > 1 {
194 return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
195 }
196
197
198 websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
199 if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
200 return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
201 }
202
203 return 0, nil
204 }
205
206 func authenticateOrigin(r *http.Request, originHosts []string) error {
207 origin := r.Header.Get("Origin")
208 if origin == "" {
209 return nil
210 }
211
212 u, err := url.Parse(origin)
213 if err != nil {
214 return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
215 }
216
217 if strings.EqualFold(r.Host, u.Host) {
218 return nil
219 }
220
221 for _, hostPattern := range originHosts {
222 matched, err := match(hostPattern, u.Host)
223 if err != nil {
224 return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
225 }
226 if matched {
227 return nil
228 }
229 }
230 if u.Host == "" {
231 return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
232 }
233 return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
234 }
235
// match reports whether the host s matches the shell-style pattern
// (filepath.Match syntax), ignoring case on both sides. It propagates
// filepath.Match's error, which is filepath.ErrBadPattern for a malformed
// pattern.
func match(pattern, s string) (bool, error) {
	lowerPattern := strings.ToLower(pattern)
	lowerHost := strings.ToLower(s)
	return filepath.Match(lowerPattern, lowerHost)
}
239
240 func selectSubprotocol(r *http.Request, subprotocols []string) string {
241 cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
242 for _, sp := range subprotocols {
243 for _, cp := range cps {
244 if strings.EqualFold(sp, cp) {
245 return cp
246 }
247 }
248 }
249 return ""
250 }
251
252 func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
253 if mode == CompressionDisabled {
254 return nil, false
255 }
256 for _, ext := range extensions {
257 switch ext.name {
258
259
260 case "permessage-deflate":
261 copts, ok := acceptDeflate(ext, mode)
262 if ok {
263 return copts, true
264 }
265 }
266 }
267 return nil, false
268 }
269
270 func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
271 copts := mode.opts()
272 for _, p := range ext.params {
273 switch p {
274 case "client_no_context_takeover":
275 copts.clientNoContextTakeover = true
276 continue
277 case "server_no_context_takeover":
278 copts.serverNoContextTakeover = true
279 continue
280 case "client_max_window_bits",
281 "server_max_window_bits=15":
282 continue
283 }
284
285 if strings.HasPrefix(p, "client_max_window_bits=") {
286
287 continue
288 }
289 return nil, false
290 }
291 return copts, true
292 }
293
294 func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
295 for _, t := range headerTokens(h, key) {
296 if strings.EqualFold(t, token) {
297 return true
298 }
299 }
300 return false
301 }
302
// websocketExtension is one parsed entry of a Sec-WebSocket-Extensions header.
type websocketExtension struct {
	// name is the extension token, e.g. "permessage-deflate".
	name string
	// params holds the raw, whitespace-trimmed ';'-separated parameters that
	// followed the name, e.g. "client_max_window_bits=10".
	params []string
}
307
308 func websocketExtensions(h http.Header) []websocketExtension {
309 var exts []websocketExtension
310 extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
311 for _, extStr := range extStrs {
312 if extStr == "" {
313 continue
314 }
315
316 vals := strings.Split(extStr, ";")
317 for i := range vals {
318 vals[i] = strings.TrimSpace(vals[i])
319 }
320
321 e := websocketExtension{
322 name: vals[0],
323 params: vals[1:],
324 }
325
326 exts = append(exts, e)
327 }
328 return exts
329 }
330
// headerTokens returns the comma-separated elements of every value of header
// key in h, in order, each with surrounding whitespace trimmed. key is
// canonicalized first, so it may be given in any case.
//
// Empty list elements, as in "a,,b", a trailing comma, or a blank value, are
// skipped, as RFC 9110 section 5.6.1 requires of list-field recipients.
// The result is nil when the header is absent or has no non-empty elements.
func headerTokens(h http.Header, key string) []string {
	key = textproto.CanonicalMIMEHeaderKey(key)
	var tokens []string
	for _, v := range h[key] {
		for _, t := range strings.Split(v, ",") {
			if t = strings.TrimSpace(t); t != "" {
				tokens = append(tokens, t)
			}
		}
	}
	return tokens
}
343
344 var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
345
346 func secWebSocketAccept(secWebSocketKey string) string {
347 h := sha1.New()
348 h.Write([]byte(secWebSocketKey))
349 h.Write(keyGUID)
350
351 return base64.StdEncoding.EncodeToString(h.Sum(nil))
352 }
353
View as plain text