1
2
3
4
5 package websocket
6
7 import (
8 "bytes"
9 "context"
10 "crypto/tls"
11 "errors"
12 "fmt"
13 "io"
14 "io/ioutil"
15 "net"
16 "net/http"
17 "net/http/httptrace"
18 "net/url"
19 "strings"
20 "time"
21 )
22
23
24
var (
	// ErrBadHandshake is returned by DialContext when the server's reply to the
	// opening handshake is not a valid WebSocket upgrade: the status is not 101,
	// the Upgrade or Connection headers are missing, or Sec-Websocket-Accept does
	// not match the challenge key.
	ErrBadHandshake = errors.New("websocket: bad handshake")

	// errInvalidCompression is returned by DialContext when the server selects
	// permessage-deflate without both server_no_context_takeover and
	// client_no_context_takeover.
	errInvalidCompression = errors.New("websocket: invalid compression negotiation")
)
28
29
30
31
32
33
34
35
36
37
38
39
40 func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
41 d := Dialer{
42 ReadBufferSize: readBufSize,
43 WriteBufferSize: writeBufSize,
44 NetDial: func(net, addr string) (net.Conn, error) {
45 return netConn, nil
46 },
47 }
48 return d.Dial(u.String(), requestHeader)
49 }
50
51
52
53
// Dialer contains options for connecting to a WebSocket server.
//
// NOTE(review): DialContext only reads these fields and never writes them, so
// a fully configured Dialer is presumably safe to share between goroutines —
// confirm against the rest of the package.
type Dialer struct {
	// NetDial, if non-nil, dials the network connection when NetDialContext is
	// nil. For wss URLs it also requires NetDialTLSContext to be nil. It is
	// called without the dial context. The context deadline, if any, is still
	// applied to the returned connection.
	NetDial func(network, addr string) (net.Conn, error)

	// NetDialContext, if non-nil, dials the network connection for ws URLs. It
	// is also used for wss URLs when NetDialTLSContext is nil. It takes
	// precedence over NetDial.
	//
	// If NetDialContext, NetDialTLSContext and NetDial are all nil, a zero
	// net.Dialer is used with the dial context.
	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)

	// NetDialTLSContext, if non-nil, dials the connection for wss URLs. The
	// returned connection is used as-is: DialContext performs no TLS handshake
	// of its own in this case, so TLSClientConfig is not applied.
	NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

	// Proxy, if non-nil, is called with the handshake request and returns the
	// URL of the proxy to dial through. A nil URL, or a nil Proxy function,
	// means no proxy is used. An error returned by Proxy aborts the dial.
	Proxy func(*http.Request) (*url.URL, error)

	// TLSClientConfig is cloned and used for the TLS handshake on wss
	// connections. If its ServerName is empty, the URL host without the port is
	// used. A nil value is treated as an empty tls.Config.
	TLSClientConfig *tls.Config

	// HandshakeTimeout, if non-zero, bounds the dial plus the opening handshake
	// through a context timeout. Any context deadline is also set on the
	// connection for the duration of the handshake and cleared on success.
	HandshakeTimeout time.Duration

	// ReadBufferSize and WriteBufferSize are passed to newConn to size the
	// connection's I/O buffers.
	// NOTE(review): zero presumably selects a default size inside newConn —
	// confirm.
	ReadBufferSize, WriteBufferSize int

	// WriteBufferPool is passed to newConn.
	// NOTE(review): presumably a source of reusable write buffers, with nil
	// meaning "no pool" — confirm against newConn and BufferPool.
	WriteBufferPool BufferPool

	// Subprotocols, if non-empty, is sent as a single comma-separated
	// Sec-WebSocket-Protocol request header. When it is set, the requestHeader
	// passed to Dial or DialContext must not also contain
	// Sec-Websocket-Protocol.
	Subprotocols []string

	// EnableCompression, when true, requests the permessage-deflate extension
	// with server_no_context_takeover and client_no_context_takeover.
	EnableCompression bool

	// Jar, if non-nil, supplies cookies for the handshake request. It also
	// receives any cookies set by the handshake response.
	Jar http.CookieJar
}
114
115
116 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
117 return d.DialContext(context.Background(), urlStr, requestHeader)
118 }
119
var (
	// errMalformedURL is returned by DialContext when the URL scheme is not ws
	// or wss, or when the URL carries user information.
	errMalformedURL = errors.New("malformed ws or wss URL")
)
121
// hostPortNoPort splits u.Host into an address suitable for dialing and a bare
// host name.
//
// hostPort always carries a port. If u.Host has no port, or an explicitly
// empty one ("host:"), the scheme's default is appended: 443 for wss and
// https, 80 otherwise. hostNoPort is u.Host with any ":port" suffix removed.
// IPv6 literals keep their brackets, e.g. "[::1]".
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
	hostPort = u.Host
	hostNoPort = u.Host

	// A colon after the last ']' separates the port. This distinguishes
	// "[::1]:80" from the bare IPv6 literal "[::1]".
	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
		hostNoPort = u.Host[:i]
		if i < len(u.Host)-1 {
			// An explicit, non-empty port: use the address as given.
			return hostPort, hostNoPort
		}
		// An empty port ("host:") means the scheme default (RFC 3986 §3.2.3).
		// net/http treats it the same way. Dialing "host:" would fail.
		hostPort = hostNoPort
	}

	switch u.Scheme {
	case "wss", "https":
		hostPort += ":443"
	default:
		hostPort += ":80"
	}
	return hostPort, hostNoPort
}
139
140
141 var DefaultDialer = &Dialer{
142 Proxy: http.ProxyFromEnvironment,
143 HandshakeTimeout: 45 * time.Second,
144 }
145
146
147 var nilDialer = *DefaultDialer
148
149
150
151
152
153
154
155
156
157
158
159
160 func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
161 if d == nil {
162 d = &nilDialer
163 }
164
165 challengeKey, err := generateChallengeKey()
166 if err != nil {
167 return nil, nil, err
168 }
169
170 u, err := url.Parse(urlStr)
171 if err != nil {
172 return nil, nil, err
173 }
174
175 switch u.Scheme {
176 case "ws":
177 u.Scheme = "http"
178 case "wss":
179 u.Scheme = "https"
180 default:
181 return nil, nil, errMalformedURL
182 }
183
184 if u.User != nil {
185
186 return nil, nil, errMalformedURL
187 }
188
189 req := &http.Request{
190 Method: http.MethodGet,
191 URL: u,
192 Proto: "HTTP/1.1",
193 ProtoMajor: 1,
194 ProtoMinor: 1,
195 Header: make(http.Header),
196 Host: u.Host,
197 }
198 req = req.WithContext(ctx)
199
200
201 if d.Jar != nil {
202 for _, cookie := range d.Jar.Cookies(u) {
203 req.AddCookie(cookie)
204 }
205 }
206
207
208
209
210
211 req.Header["Upgrade"] = []string{"websocket"}
212 req.Header["Connection"] = []string{"Upgrade"}
213 req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
214 req.Header["Sec-WebSocket-Version"] = []string{"13"}
215 if len(d.Subprotocols) > 0 {
216 req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
217 }
218 for k, vs := range requestHeader {
219 switch {
220 case k == "Host":
221 if len(vs) > 0 {
222 req.Host = vs[0]
223 }
224 case k == "Upgrade" ||
225 k == "Connection" ||
226 k == "Sec-Websocket-Key" ||
227 k == "Sec-Websocket-Version" ||
228 k == "Sec-Websocket-Extensions" ||
229 (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
230 return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
231 case k == "Sec-Websocket-Protocol":
232 req.Header["Sec-WebSocket-Protocol"] = vs
233 default:
234 req.Header[k] = vs
235 }
236 }
237
238 if d.EnableCompression {
239 req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
240 }
241
242 if d.HandshakeTimeout != 0 {
243 var cancel func()
244 ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
245 defer cancel()
246 }
247
248
249 var netDial func(network, add string) (net.Conn, error)
250
251 switch u.Scheme {
252 case "http":
253 if d.NetDialContext != nil {
254 netDial = func(network, addr string) (net.Conn, error) {
255 return d.NetDialContext(ctx, network, addr)
256 }
257 } else if d.NetDial != nil {
258 netDial = d.NetDial
259 }
260 case "https":
261 if d.NetDialTLSContext != nil {
262 netDial = func(network, addr string) (net.Conn, error) {
263 return d.NetDialTLSContext(ctx, network, addr)
264 }
265 } else if d.NetDialContext != nil {
266 netDial = func(network, addr string) (net.Conn, error) {
267 return d.NetDialContext(ctx, network, addr)
268 }
269 } else if d.NetDial != nil {
270 netDial = d.NetDial
271 }
272 default:
273 return nil, nil, errMalformedURL
274 }
275
276 if netDial == nil {
277 netDialer := &net.Dialer{}
278 netDial = func(network, addr string) (net.Conn, error) {
279 return netDialer.DialContext(ctx, network, addr)
280 }
281 }
282
283
284 if deadline, ok := ctx.Deadline(); ok {
285 forwardDial := netDial
286 netDial = func(network, addr string) (net.Conn, error) {
287 c, err := forwardDial(network, addr)
288 if err != nil {
289 return nil, err
290 }
291 err = c.SetDeadline(deadline)
292 if err != nil {
293 c.Close()
294 return nil, err
295 }
296 return c, nil
297 }
298 }
299
300
301 if d.Proxy != nil {
302 proxyURL, err := d.Proxy(req)
303 if err != nil {
304 return nil, nil, err
305 }
306 if proxyURL != nil {
307 dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
308 if err != nil {
309 return nil, nil, err
310 }
311 netDial = dialer.Dial
312 }
313 }
314
315 hostPort, hostNoPort := hostPortNoPort(u)
316 trace := httptrace.ContextClientTrace(ctx)
317 if trace != nil && trace.GetConn != nil {
318 trace.GetConn(hostPort)
319 }
320
321 netConn, err := netDial("tcp", hostPort)
322 if err != nil {
323 return nil, nil, err
324 }
325 if trace != nil && trace.GotConn != nil {
326 trace.GotConn(httptrace.GotConnInfo{
327 Conn: netConn,
328 })
329 }
330
331 defer func() {
332 if netConn != nil {
333 netConn.Close()
334 }
335 }()
336
337 if u.Scheme == "https" && d.NetDialTLSContext == nil {
338
339
340 cfg := cloneTLSConfig(d.TLSClientConfig)
341 if cfg.ServerName == "" {
342 cfg.ServerName = hostNoPort
343 }
344 tlsConn := tls.Client(netConn, cfg)
345 netConn = tlsConn
346
347 if trace != nil && trace.TLSHandshakeStart != nil {
348 trace.TLSHandshakeStart()
349 }
350 err := doHandshake(ctx, tlsConn, cfg)
351 if trace != nil && trace.TLSHandshakeDone != nil {
352 trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
353 }
354
355 if err != nil {
356 return nil, nil, err
357 }
358 }
359
360 conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
361
362 if err := req.Write(netConn); err != nil {
363 return nil, nil, err
364 }
365
366 if trace != nil && trace.GotFirstResponseByte != nil {
367 if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
368 trace.GotFirstResponseByte()
369 }
370 }
371
372 resp, err := http.ReadResponse(conn.br, req)
373 if err != nil {
374 if d.TLSClientConfig != nil {
375 for _, proto := range d.TLSClientConfig.NextProtos {
376 if proto != "http/1.1" {
377 return nil, nil, fmt.Errorf(
378 "websocket: protocol %q was given but is not supported;"+
379 "sharing tls.Config with net/http Transport can cause this error: %w",
380 proto, err,
381 )
382 }
383 }
384 }
385 return nil, nil, err
386 }
387
388 if d.Jar != nil {
389 if rc := resp.Cookies(); len(rc) > 0 {
390 d.Jar.SetCookies(u, rc)
391 }
392 }
393
394 if resp.StatusCode != 101 ||
395 !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
396 !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
397 resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
398
399
400
401 buf := make([]byte, 1024)
402 n, _ := io.ReadFull(resp.Body, buf)
403 resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
404 return nil, resp, ErrBadHandshake
405 }
406
407 for _, ext := range parseExtensions(resp.Header) {
408 if ext[""] != "permessage-deflate" {
409 continue
410 }
411 _, snct := ext["server_no_context_takeover"]
412 _, cnct := ext["client_no_context_takeover"]
413 if !snct || !cnct {
414 return nil, resp, errInvalidCompression
415 }
416 conn.newCompressionWriter = compressNoContextTakeover
417 conn.newDecompressionReader = decompressNoContextTakeover
418 break
419 }
420
421 resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
422 conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
423
424 netConn.SetDeadline(time.Time{})
425 netConn = nil
426 return conn, resp, nil
427 }
428
// cloneTLSConfig returns a copy of cfg that the caller may modify freely, such
// as by setting ServerName, without affecting cfg. A nil cfg yields a fresh,
// empty tls.Config.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	if cfg != nil {
		return cfg.Clone()
	}
	return &tls.Config{}
}
435
View as plain text