1
2
3
4
5 package websocket
6
7 import (
8 "bytes"
9 "context"
10 "crypto/tls"
11 "crypto/x509"
12 "encoding/base64"
13 "encoding/binary"
14 "errors"
15 "fmt"
16 "io"
17 "io/ioutil"
18 "log"
19 "net"
20 "net/http"
21 "net/http/cookiejar"
22 "net/http/httptest"
23 "net/http/httptrace"
24 "net/url"
25 "reflect"
26 "strings"
27 "testing"
28 "time"
29 )
30
31 var cstUpgrader = Upgrader{
32 Subprotocols: []string{"p0", "p1"},
33 ReadBufferSize: 1024,
34 WriteBufferSize: 1024,
35 EnableCompression: true,
36 Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
37 http.Error(w, reason.Error(), status)
38 },
39 }
40
41 var cstDialer = Dialer{
42 Subprotocols: []string{"p1", "p2"},
43 ReadBufferSize: 1024,
44 WriteBufferSize: 1024,
45 HandshakeTimeout: 30 * time.Second,
46 }
47
type (
	// cstHandler is the http.Handler behind the test servers; it carries the
	// test so that problems on the server side can be logged.
	cstHandler struct{ *testing.T }

	// cstServer wraps an httptest.Server. Server.URL is the http(s) address
	// with cstRequestURI appended, and URL is the same address using the
	// ws(s) scheme.
	cstServer struct {
		*httptest.Server
		URL string
		t   *testing.T
	}
)

// Path and query that cstHandler insists on for every request.
const (
	cstPath       = "/a/b"
	cstRawQuery   = "x=y"
	cstRequestURI = cstPath + "?" + cstRawQuery
)
61
62 func newServer(t *testing.T) *cstServer {
63 var s cstServer
64 s.Server = httptest.NewServer(cstHandler{t})
65 s.Server.URL += cstRequestURI
66 s.URL = makeWsProto(s.Server.URL)
67 return &s
68 }
69
70 func newTLSServer(t *testing.T) *cstServer {
71 var s cstServer
72 s.Server = httptest.NewTLSServer(cstHandler{t})
73 s.Server.URL += cstRequestURI
74 s.URL = makeWsProto(s.Server.URL)
75 return &s
76 }
77
78 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
79 if r.URL.Path != cstPath {
80 t.Logf("path=%v, want %v", r.URL.Path, cstPath)
81 http.Error(w, "bad path", http.StatusBadRequest)
82 return
83 }
84 if r.URL.RawQuery != cstRawQuery {
85 t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
86 http.Error(w, "bad path", http.StatusBadRequest)
87 return
88 }
89 subprotos := Subprotocols(r)
90 if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
91 t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
92 http.Error(w, "bad protocol", http.StatusBadRequest)
93 return
94 }
95 ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
96 if err != nil {
97 t.Logf("Upgrade: %v", err)
98 return
99 }
100 defer ws.Close()
101
102 if ws.Subprotocol() != "p1" {
103 t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
104 ws.Close()
105 return
106 }
107 op, rd, err := ws.NextReader()
108 if err != nil {
109 t.Logf("NextReader: %v", err)
110 return
111 }
112 wr, err := ws.NextWriter(op)
113 if err != nil {
114 t.Logf("NextWriter: %v", err)
115 return
116 }
117 if _, err = io.Copy(wr, rd); err != nil {
118 t.Logf("NextWriter: %v", err)
119 return
120 }
121 if err := wr.Close(); err != nil {
122 t.Logf("Close: %v", err)
123 return
124 }
125 }
126
// makeWsProto turns an http(s) URL into the matching ws(s) URL by replacing a
// leading "http" with "ws". Input without that prefix just gets "ws" prepended.
func makeWsProto(s string) string {
	const httpPrefix = "http"
	rest := s
	if strings.HasPrefix(rest, httpPrefix) {
		rest = rest[len(httpPrefix):]
	}
	return "ws" + rest
}
130
131 func sendRecv(t *testing.T, ws *Conn) {
132 const message = "Hello World!"
133 if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
134 t.Fatalf("SetWriteDeadline: %v", err)
135 }
136 if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
137 t.Fatalf("WriteMessage: %v", err)
138 }
139 if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
140 t.Fatalf("SetReadDeadline: %v", err)
141 }
142 _, p, err := ws.ReadMessage()
143 if err != nil {
144 t.Fatalf("ReadMessage: %v", err)
145 }
146 if string(p) != message {
147 t.Fatalf("message=%s, want %s", p, message)
148 }
149 }
150
151 func TestProxyDial(t *testing.T) {
152
153 s := newServer(t)
154 defer s.Close()
155
156 surl, _ := url.Parse(s.Server.URL)
157
158 cstDialer := cstDialer
159 cstDialer.Proxy = http.ProxyURL(surl)
160
161 connect := false
162 origHandler := s.Server.Config.Handler
163
164
165 s.Server.Config.Handler = http.HandlerFunc(
166 func(w http.ResponseWriter, r *http.Request) {
167 if r.Method == http.MethodConnect {
168 connect = true
169 w.WriteHeader(http.StatusOK)
170 return
171 }
172
173 if !connect {
174 t.Log("connect not received")
175 http.Error(w, "connect not received", http.StatusMethodNotAllowed)
176 return
177 }
178 origHandler.ServeHTTP(w, r)
179 })
180
181 ws, _, err := cstDialer.Dial(s.URL, nil)
182 if err != nil {
183 t.Fatalf("Dial: %v", err)
184 }
185 defer ws.Close()
186 sendRecv(t, ws)
187 }
188
189 func TestProxyAuthorizationDial(t *testing.T) {
190 s := newServer(t)
191 defer s.Close()
192
193 surl, _ := url.Parse(s.Server.URL)
194 surl.User = url.UserPassword("username", "password")
195
196 cstDialer := cstDialer
197 cstDialer.Proxy = http.ProxyURL(surl)
198
199 connect := false
200 origHandler := s.Server.Config.Handler
201
202
203 s.Server.Config.Handler = http.HandlerFunc(
204 func(w http.ResponseWriter, r *http.Request) {
205 proxyAuth := r.Header.Get("Proxy-Authorization")
206 expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
207 if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
208 connect = true
209 w.WriteHeader(http.StatusOK)
210 return
211 }
212
213 if !connect {
214 t.Log("connect with proxy authorization not received")
215 http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
216 return
217 }
218 origHandler.ServeHTTP(w, r)
219 })
220
221 ws, _, err := cstDialer.Dial(s.URL, nil)
222 if err != nil {
223 t.Fatalf("Dial: %v", err)
224 }
225 defer ws.Close()
226 sendRecv(t, ws)
227 }
228
229 func TestDial(t *testing.T) {
230 s := newServer(t)
231 defer s.Close()
232
233 ws, _, err := cstDialer.Dial(s.URL, nil)
234 if err != nil {
235 t.Fatalf("Dial: %v", err)
236 }
237 defer ws.Close()
238 sendRecv(t, ws)
239 }
240
241 func TestDialCookieJar(t *testing.T) {
242 s := newServer(t)
243 defer s.Close()
244
245 jar, _ := cookiejar.New(nil)
246 d := cstDialer
247 d.Jar = jar
248
249 u, _ := url.Parse(s.URL)
250
251 switch u.Scheme {
252 case "ws":
253 u.Scheme = "http"
254 case "wss":
255 u.Scheme = "https"
256 }
257
258 cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
259 d.Jar.SetCookies(u, cookies)
260
261 ws, _, err := d.Dial(s.URL, nil)
262 if err != nil {
263 t.Fatalf("Dial: %v", err)
264 }
265 defer ws.Close()
266
267 var gorilla string
268 var sessionID string
269 for _, c := range d.Jar.Cookies(u) {
270 if c.Name == "gorilla" {
271 gorilla = c.Value
272 }
273
274 if c.Name == "sessionID" {
275 sessionID = c.Value
276 }
277 }
278 if gorilla != "ws" {
279 t.Error("Cookie not present in jar.")
280 }
281
282 if sessionID != "1234" {
283 t.Error("Set-Cookie not received from the server.")
284 }
285
286 sendRecv(t, ws)
287 }
288
// rootCAs builds a certificate pool from the last certificate of every chain
// the TLS test server s presents, so clients can verify s without
// InsecureSkipVerify. It fails the test if a certificate cannot be parsed.
func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
	pool := x509.NewCertPool()
	for i := range s.TLS.Certificates {
		chain := s.TLS.Certificates[i].Certificate
		parsed, err := x509.ParseCertificates(chain[len(chain)-1])
		if err != nil {
			t.Fatalf("error parsing server's root cert: %v", err)
		}
		for _, cert := range parsed {
			pool.AddCert(cert)
		}
	}
	return pool
}
302
303 func TestDialTLS(t *testing.T) {
304 s := newTLSServer(t)
305 defer s.Close()
306
307 d := cstDialer
308 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
309 ws, _, err := d.Dial(s.URL, nil)
310 if err != nil {
311 t.Fatalf("Dial: %v", err)
312 }
313 defer ws.Close()
314 sendRecv(t, ws)
315 }
316
317 func TestDialTimeout(t *testing.T) {
318 s := newServer(t)
319 defer s.Close()
320
321 d := cstDialer
322 d.HandshakeTimeout = -1
323 ws, _, err := d.Dial(s.URL, nil)
324 if err == nil {
325 ws.Close()
326 t.Fatalf("Dial: nil")
327 }
328 }
329
330
331
// requireDeadlineNetConn wraps a net.Conn and fails the test if Read or Write
// is called while the deadline for that direction is unset. It lets tests
// assert that every handshake I/O operation is bounded by a deadline.
type requireDeadlineNetConn struct {
	t                  *testing.T
	c                  net.Conn
	readDeadlineIsSet  bool
	writeDeadlineIsSet bool
}

// SetDeadline sets both the read and write deadline on the wrapped
// connection; a zero t clears them.
func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
	c.readDeadlineIsSet = !t.IsZero()
	c.writeDeadlineIsSet = c.readDeadlineIsSet
	return c.c.SetDeadline(t)
}

// SetReadDeadline sets only the read deadline on the wrapped connection,
// leaving the write deadline untouched; a zero t clears it.
func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
	c.readDeadlineIsSet = !t.IsZero()
	return c.c.SetReadDeadline(t)
}

// SetWriteDeadline sets only the write deadline on the wrapped connection,
// leaving the read deadline untouched; a zero t clears it.
func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
	c.writeDeadlineIsSet = !t.IsZero()
	return c.c.SetWriteDeadline(t)
}

// Write fails the test if no write deadline is set, then writes to the
// wrapped connection.
func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
	if !c.writeDeadlineIsSet {
		c.t.Fatalf("write with no deadline")
	}
	return c.c.Write(p)
}

// Read fails the test if no read deadline is set, then reads from the
// wrapped connection.
func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
	if !c.readDeadlineIsSet {
		c.t.Fatalf("read with no deadline")
	}
	return c.c.Read(p)
}

func (c *requireDeadlineNetConn) Close() error         { return c.c.Close() }
func (c *requireDeadlineNetConn) LocalAddr() net.Addr  { return c.c.LocalAddr() }
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
372
373 func TestHandshakeTimeout(t *testing.T) {
374 s := newServer(t)
375 defer s.Close()
376
377 d := cstDialer
378 d.NetDial = func(n, a string) (net.Conn, error) {
379 c, err := net.Dial(n, a)
380 return &requireDeadlineNetConn{c: c, t: t}, err
381 }
382 ws, _, err := d.Dial(s.URL, nil)
383 if err != nil {
384 t.Fatal("Dial:", err)
385 }
386 ws.Close()
387 }
388
389 func TestHandshakeTimeoutInContext(t *testing.T) {
390 s := newServer(t)
391 defer s.Close()
392
393 d := cstDialer
394 d.HandshakeTimeout = 0
395 d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
396 netDialer := &net.Dialer{}
397 c, err := netDialer.DialContext(ctx, n, a)
398 return &requireDeadlineNetConn{c: c, t: t}, err
399 }
400
401 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
402 defer cancel()
403 ws, _, err := d.DialContext(ctx, s.URL, nil)
404 if err != nil {
405 t.Fatal("Dial:", err)
406 }
407 ws.Close()
408 }
409
410 func TestDialBadScheme(t *testing.T) {
411 s := newServer(t)
412 defer s.Close()
413
414 ws, _, err := cstDialer.Dial(s.Server.URL, nil)
415 if err == nil {
416 ws.Close()
417 t.Fatalf("Dial: nil")
418 }
419 }
420
421 func TestDialBadOrigin(t *testing.T) {
422 s := newServer(t)
423 defer s.Close()
424
425 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
426 if err == nil {
427 ws.Close()
428 t.Fatalf("Dial: nil")
429 }
430 if resp == nil {
431 t.Fatalf("resp=nil, err=%v", err)
432 }
433 if resp.StatusCode != http.StatusForbidden {
434 t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
435 }
436 }
437
438 func TestDialBadHeader(t *testing.T) {
439 s := newServer(t)
440 defer s.Close()
441
442 for _, k := range []string{"Upgrade",
443 "Connection",
444 "Sec-Websocket-Key",
445 "Sec-Websocket-Version",
446 "Sec-Websocket-Protocol"} {
447 h := http.Header{}
448 h.Set(k, "bad")
449 ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
450 if err == nil {
451 ws.Close()
452 t.Errorf("Dial with header %s returned nil", k)
453 }
454 }
455 }
456
457 func TestBadMethod(t *testing.T) {
458 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
459 ws, err := cstUpgrader.Upgrade(w, r, nil)
460 if err == nil {
461 t.Errorf("handshake succeeded, expect fail")
462 ws.Close()
463 }
464 }))
465 defer s.Close()
466
467 req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
468 if err != nil {
469 t.Fatalf("NewRequest returned error %v", err)
470 }
471 req.Header.Set("Connection", "upgrade")
472 req.Header.Set("Upgrade", "websocket")
473 req.Header.Set("Sec-Websocket-Version", "13")
474
475 resp, err := http.DefaultClient.Do(req)
476 if err != nil {
477 t.Fatalf("Do returned error %v", err)
478 }
479 resp.Body.Close()
480 if resp.StatusCode != http.StatusMethodNotAllowed {
481 t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
482 }
483 }
484
485 func TestDialExtraTokensInRespHeaders(t *testing.T) {
486 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
487 challengeKey := r.Header.Get("Sec-Websocket-Key")
488 w.Header().Set("Upgrade", "foo, websocket")
489 w.Header().Set("Connection", "upgrade, keep-alive")
490 w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
491 w.WriteHeader(101)
492 }))
493 defer s.Close()
494
495 ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
496 if err != nil {
497 t.Fatalf("Dial: %v", err)
498 }
499 defer ws.Close()
500 }
501
502 func TestHandshake(t *testing.T) {
503 s := newServer(t)
504 defer s.Close()
505
506 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
507 if err != nil {
508 t.Fatalf("Dial: %v", err)
509 }
510 defer ws.Close()
511
512 var sessionID string
513 for _, c := range resp.Cookies() {
514 if c.Name == "sessionID" {
515 sessionID = c.Value
516 }
517 }
518 if sessionID != "1234" {
519 t.Error("Set-Cookie not received from the server.")
520 }
521
522 if ws.Subprotocol() != "p1" {
523 t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
524 }
525 sendRecv(t, ws)
526 }
527
528 func TestRespOnBadHandshake(t *testing.T) {
529 const expectedStatus = http.StatusGone
530 const expectedBody = "This is the response body."
531
532 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
533 w.WriteHeader(expectedStatus)
534 io.WriteString(w, expectedBody)
535 }))
536 defer s.Close()
537
538 ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
539 if err == nil {
540 ws.Close()
541 t.Fatalf("Dial: nil")
542 }
543
544 if resp == nil {
545 t.Fatalf("resp=nil, err=%v", err)
546 }
547
548 if resp.StatusCode != expectedStatus {
549 t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
550 }
551
552 p, err := ioutil.ReadAll(resp.Body)
553 if err != nil {
554 t.Fatalf("ReadFull(resp.Body) returned error %v", err)
555 }
556
557 if string(p) != expectedBody {
558 t.Errorf("resp.Body=%s, want %s", p, expectedBody)
559 }
560 }
561
// testLogWriter is an io.Writer that forwards everything written to it to
// t.Logf, so a log.Logger (for example an http.Server's ErrorLog) can print
// into the test output.
type testLogWriter struct {
	t *testing.T
}

// Write logs p through t.Logf and always reports that all of p was written.
func (w testLogWriter) Write(p []byte) (int, error) {
	written := len(p)
	w.t.Logf("%s", p)
	return written, nil
}
570
571
572 func TestHost(t *testing.T) {
573
574 upgrader := Upgrader{}
575 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
576 if IsWebSocketUpgrade(r) {
577 c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
578 if err != nil {
579 t.Fatal(err)
580 }
581 c.Close()
582 } else {
583 w.Header().Set("X-Test-Host", r.Host)
584 }
585 })
586
587 server := httptest.NewServer(handler)
588 defer server.Close()
589
590 tlsServer := httptest.NewTLSServer(handler)
591 defer tlsServer.Close()
592
593 addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
594 wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
595 httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
596
597
598 server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
599 tlsServer.Config.ErrorLog = server.Config.ErrorLog
600
601 cas := rootCAs(t, tlsServer)
602
603 tests := []struct {
604 fail bool
605 server *httptest.Server
606 url string
607 header string
608 tls string
609 wantAddr string
610 wantHeader string
611 insecureSkipVerify bool
612 }{
613 {
614 server: server,
615 url: addrs[server],
616 wantAddr: addrs[server],
617 wantHeader: addrs[server],
618 },
619 {
620 server: tlsServer,
621 url: addrs[tlsServer],
622 wantAddr: addrs[tlsServer],
623 wantHeader: addrs[tlsServer],
624 },
625
626 {
627 server: server,
628 url: addrs[server],
629 header: "badhost.com",
630 wantAddr: addrs[server],
631 wantHeader: "badhost.com",
632 },
633 {
634 server: tlsServer,
635 url: addrs[tlsServer],
636 header: "badhost.com",
637 wantAddr: addrs[tlsServer],
638 wantHeader: "badhost.com",
639 },
640
641 {
642 server: server,
643 url: "example.com",
644 header: "badhost.com",
645 wantAddr: "example.com:80",
646 wantHeader: "badhost.com",
647 },
648 {
649 server: tlsServer,
650 url: "example.com",
651 header: "badhost.com",
652 wantAddr: "example.com:443",
653 wantHeader: "badhost.com",
654 },
655
656 {
657 server: server,
658 url: "badhost.com",
659 header: "example.com",
660 wantAddr: "badhost.com:80",
661 wantHeader: "example.com",
662 },
663 {
664 fail: true,
665 server: tlsServer,
666 url: "badhost.com",
667 header: "example.com",
668 wantAddr: "badhost.com:443",
669 },
670 {
671 server: tlsServer,
672 url: "badhost.com",
673 insecureSkipVerify: true,
674 wantAddr: "badhost.com:443",
675 wantHeader: "badhost.com",
676 },
677 {
678 server: tlsServer,
679 url: "badhost.com",
680 tls: "example.com",
681 wantAddr: "badhost.com:443",
682 wantHeader: "badhost.com",
683 },
684 }
685
686 for i, tt := range tests {
687
688 tls := &tls.Config{
689 RootCAs: cas,
690 ServerName: tt.tls,
691 InsecureSkipVerify: tt.insecureSkipVerify,
692 }
693
694 var gotAddr string
695 dialer := Dialer{
696 NetDial: func(network, addr string) (net.Conn, error) {
697 gotAddr = addr
698 return net.Dial(network, addrs[tt.server])
699 },
700 TLSClientConfig: tls,
701 }
702
703
704
705 h := http.Header{}
706 if tt.header != "" {
707 h.Set("Host", tt.header)
708 }
709 c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
710 if err == nil {
711 c.Close()
712 }
713
714 check := func(protos map[*httptest.Server]string) {
715 name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
716 if gotAddr != tt.wantAddr {
717 t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
718 }
719 switch {
720 case tt.fail && err == nil:
721 t.Errorf("%s: unexpected success", name)
722 case !tt.fail && err != nil:
723 t.Errorf("%s: unexpected error %v", name, err)
724 case !tt.fail && err == nil:
725 if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
726 t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
727 }
728 }
729 }
730
731 check(wsProtos)
732
733
734
735 transport := &http.Transport{
736 Dial: dialer.NetDial,
737 TLSClientConfig: dialer.TLSClientConfig,
738 }
739 req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
740 if tt.header != "" {
741 req.Host = tt.header
742 }
743 client := &http.Client{Transport: transport}
744 resp, err = client.Do(req)
745 if err == nil {
746 resp.Body.Close()
747 }
748 transport.CloseIdleConnections()
749 check(httpProtos)
750 }
751 }
752
753 func TestDialCompression(t *testing.T) {
754 s := newServer(t)
755 defer s.Close()
756
757 dialer := cstDialer
758 dialer.EnableCompression = true
759 ws, _, err := dialer.Dial(s.URL, nil)
760 if err != nil {
761 t.Fatalf("Dial: %v", err)
762 }
763 defer ws.Close()
764 sendRecv(t, ws)
765 }
766
767 func TestSocksProxyDial(t *testing.T) {
768 s := newServer(t)
769 defer s.Close()
770
771 proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
772 if err != nil {
773 t.Fatalf("listen failed: %v", err)
774 }
775 defer proxyListener.Close()
776 go func() {
777 c1, err := proxyListener.Accept()
778 if err != nil {
779 t.Errorf("proxy accept failed: %v", err)
780 return
781 }
782 defer c1.Close()
783
784 c1.SetDeadline(time.Now().Add(30 * time.Second))
785
786 buf := make([]byte, 32)
787 if _, err := io.ReadFull(c1, buf[:3]); err != nil {
788 t.Errorf("read failed: %v", err)
789 return
790 }
791 if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
792 t.Errorf("read %x, want %x", buf[:len(want)], want)
793 }
794 if _, err := c1.Write([]byte{5, 0}); err != nil {
795 t.Errorf("write failed: %v", err)
796 return
797 }
798 if _, err := io.ReadFull(c1, buf[:10]); err != nil {
799 t.Errorf("read failed: %v", err)
800 return
801 }
802 if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
803 t.Errorf("read %x, want %x", buf[:len(want)], want)
804 return
805 }
806 buf[1] = 0
807 if _, err := c1.Write(buf[:10]); err != nil {
808 t.Errorf("write failed: %v", err)
809 return
810 }
811
812 ip := net.IP(buf[4:8])
813 port := binary.BigEndian.Uint16(buf[8:10])
814
815 c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
816 if err != nil {
817 t.Errorf("dial failed; %v", err)
818 return
819 }
820 defer c2.Close()
821 done := make(chan struct{})
822 go func() {
823 io.Copy(c1, c2)
824 close(done)
825 }()
826 io.Copy(c2, c1)
827 <-done
828 }()
829
830 purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
831 if err != nil {
832 t.Fatalf("parse failed: %v", err)
833 }
834
835 cstDialer := cstDialer
836 cstDialer.Proxy = http.ProxyURL(purl)
837
838 ws, _, err := cstDialer.Dial(s.URL, nil)
839 if err != nil {
840 t.Fatalf("Dial: %v", err)
841 }
842 defer ws.Close()
843 sendRecv(t, ws)
844 }
845
846 func TestTracingDialWithContext(t *testing.T) {
847
848 var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
849 trace := &httptrace.ClientTrace{
850 WroteHeaders: func() {
851 headersWrote = true
852 },
853 WroteRequest: func(httptrace.WroteRequestInfo) {
854 requestWrote = true
855 },
856 GetConn: func(hostPort string) {
857 getConn = true
858 },
859 GotConn: func(info httptrace.GotConnInfo) {
860 gotConn = true
861 },
862 ConnectDone: func(network, addr string, err error) {
863 connectDone = true
864 },
865 GotFirstResponseByte: func() {
866 gotFirstResponseByte = true
867 },
868 }
869 ctx := httptrace.WithClientTrace(context.Background(), trace)
870
871 s := newTLSServer(t)
872 defer s.Close()
873
874 d := cstDialer
875 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
876
877 ws, _, err := d.DialContext(ctx, s.URL, nil)
878 if err != nil {
879 t.Fatalf("Dial: %v", err)
880 }
881
882 if !headersWrote {
883 t.Fatal("Headers was not written")
884 }
885 if !requestWrote {
886 t.Fatal("Request was not written")
887 }
888 if !getConn {
889 t.Fatal("getConn was not called")
890 }
891 if !gotConn {
892 t.Fatal("gotConn was not called")
893 }
894 if !connectDone {
895 t.Fatal("connectDone was not called")
896 }
897 if !gotFirstResponseByte {
898 t.Fatal("GotFirstResponseByte was not called")
899 }
900
901 defer ws.Close()
902 sendRecv(t, ws)
903 }
904
905 func TestEmptyTracingDialWithContext(t *testing.T) {
906
907 trace := &httptrace.ClientTrace{}
908 ctx := httptrace.WithClientTrace(context.Background(), trace)
909
910 s := newTLSServer(t)
911 defer s.Close()
912
913 d := cstDialer
914 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
915
916 ws, _, err := d.DialContext(ctx, s.URL, nil)
917 if err != nil {
918 t.Fatalf("Dial: %v", err)
919 }
920
921 defer ws.Close()
922 sendRecv(t, ws)
923 }
924
925
926 func TestNetDialConnect(t *testing.T) {
927
928 upgrader := Upgrader{}
929 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
930 if IsWebSocketUpgrade(r) {
931 c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
932 if err != nil {
933 t.Fatal(err)
934 }
935 c.Close()
936 } else {
937 w.Header().Set("X-Test-Host", r.Host)
938 }
939 })
940
941 server := httptest.NewServer(handler)
942 defer server.Close()
943
944 tlsServer := httptest.NewTLSServer(handler)
945 defer tlsServer.Close()
946
947 testUrls := map[*httptest.Server]string{
948 server: "ws://" + server.Listener.Addr().String() + "/",
949 tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
950 }
951
952 cas := rootCAs(t, tlsServer)
953 tlsConfig := &tls.Config{
954 RootCAs: cas,
955 ServerName: "example.com",
956 InsecureSkipVerify: false,
957 }
958
959 tests := []struct {
960 name string
961 server *httptest.Server
962 netDial func(network, addr string) (net.Conn, error)
963 netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
964 netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
965 tlsClientConfig *tls.Config
966 }{
967
968 {
969 name: "HTTP server, all NetDial* defined, shall use NetDialContext",
970 server: server,
971 netDial: func(network, addr string) (net.Conn, error) {
972 return nil, errors.New("NetDial should not be called")
973 },
974 netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
975 return net.Dial(network, addr)
976 },
977 netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
978 return nil, errors.New("NetDialTLSContext should not be called")
979 },
980 tlsClientConfig: nil,
981 },
982 {
983 name: "HTTP server, all NetDial* undefined",
984 server: server,
985 netDial: nil,
986 netDialContext: nil,
987 netDialTLSContext: nil,
988 tlsClientConfig: nil,
989 },
990 {
991 name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
992 server: server,
993 netDial: func(network, addr string) (net.Conn, error) {
994 return net.Dial(network, addr)
995 },
996 netDialContext: nil,
997 netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
998 return nil, errors.New("NetDialTLSContext should not be called")
999 },
1000 tlsClientConfig: nil,
1001 },
1002 {
1003 name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
1004 server: tlsServer,
1005 netDial: func(network, addr string) (net.Conn, error) {
1006 return nil, errors.New("NetDial should not be called")
1007 },
1008 netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1009 return nil, errors.New("NetDialContext should not be called")
1010 },
1011 netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1012 netConn, err := net.Dial(network, addr)
1013 if err != nil {
1014 return nil, err
1015 }
1016 tlsConn := tls.Client(netConn, tlsConfig)
1017 err = tlsConn.Handshake()
1018 if err != nil {
1019 return nil, err
1020 }
1021 return tlsConn, nil
1022 },
1023 tlsClientConfig: nil,
1024 },
1025 {
1026 name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
1027 server: tlsServer,
1028 netDial: func(network, addr string) (net.Conn, error) {
1029 return nil, errors.New("NetDial should not be called")
1030 },
1031 netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1032 return net.Dial(network, addr)
1033 },
1034 netDialTLSContext: nil,
1035 tlsClientConfig: tlsConfig,
1036 },
1037 {
1038 name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
1039 server: tlsServer,
1040 netDial: func(network, addr string) (net.Conn, error) {
1041 return net.Dial(network, addr)
1042 },
1043 netDialContext: nil,
1044 netDialTLSContext: nil,
1045 tlsClientConfig: tlsConfig,
1046 },
1047 {
1048 name: "HTTPS server, all NetDial* undefined",
1049 server: tlsServer,
1050 netDial: nil,
1051 netDialContext: nil,
1052 netDialTLSContext: nil,
1053 tlsClientConfig: tlsConfig,
1054 },
1055 {
1056 name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
1057 server: tlsServer,
1058 netDial: func(network, addr string) (net.Conn, error) {
1059 return nil, errors.New("NetDial should not be called")
1060 },
1061 netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1062 return nil, errors.New("NetDialContext should not be called")
1063 },
1064 netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1065 netConn, err := net.Dial(network, addr)
1066 if err != nil {
1067 return nil, err
1068 }
1069 tlsConn := tls.Client(netConn, tlsConfig)
1070 err = tlsConn.Handshake()
1071 if err != nil {
1072 return nil, err
1073 }
1074 return tlsConn, nil
1075 },
1076 tlsClientConfig: &tls.Config{
1077 RootCAs: nil,
1078 ServerName: "badserver.com",
1079 InsecureSkipVerify: false,
1080 },
1081 },
1082 }
1083
1084 for _, tc := range tests {
1085 dialer := Dialer{
1086 NetDial: tc.netDial,
1087 NetDialContext: tc.netDialContext,
1088 NetDialTLSContext: tc.netDialTLSContext,
1089 TLSClientConfig: tc.tlsClientConfig,
1090 }
1091
1092
1093 c, _, err := dialer.Dial(testUrls[tc.server], nil)
1094 if err != nil {
1095 t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
1096 } else {
1097 c.Close()
1098 }
1099 }
1100 }
1101 func TestNextProtos(t *testing.T) {
1102 ts := httptest.NewUnstartedServer(
1103 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
1104 )
1105 ts.EnableHTTP2 = true
1106 ts.StartTLS()
1107 defer ts.Close()
1108
1109 d := Dialer{
1110 TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
1111 }
1112
1113 r, err := ts.Client().Get(ts.URL)
1114 if err != nil {
1115 t.Fatalf("Get: %v", err)
1116 }
1117 r.Body.Close()
1118
1119
1120
1121 var containsHTTP2 bool = false
1122 for _, proto := range d.TLSClientConfig.NextProtos {
1123 if proto == "h2" {
1124 containsHTTP2 = true
1125 }
1126 }
1127 if !containsHTTP2 {
1128 t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
1129 }
1130
1131 _, _, err = d.Dial(makeWsProto(ts.URL), nil)
1132 if err == nil {
1133 t.Fatalf("Dial succeeded, expect fail ")
1134 }
1135 }
1136
View as plain text