...

Source file src/github.com/gorilla/websocket/client_server_test.go

Documentation: github.com/gorilla/websocket

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"encoding/base64"
    13  	"encoding/binary"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"io/ioutil"
    18  	"log"
    19  	"net"
    20  	"net/http"
    21  	"net/http/cookiejar"
    22  	"net/http/httptest"
    23  	"net/http/httptrace"
    24  	"net/url"
    25  	"reflect"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  )
    30  
    31  var cstUpgrader = Upgrader{
    32  	Subprotocols:      []string{"p0", "p1"},
    33  	ReadBufferSize:    1024,
    34  	WriteBufferSize:   1024,
    35  	EnableCompression: true,
    36  	Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
    37  		http.Error(w, reason.Error(), status)
    38  	},
    39  }
    40  
    41  var cstDialer = Dialer{
    42  	Subprotocols:     []string{"p1", "p2"},
    43  	ReadBufferSize:   1024,
    44  	WriteBufferSize:  1024,
    45  	HandshakeTimeout: 30 * time.Second,
    46  }
    47  
    48  type cstHandler struct{ *testing.T }
    49  
    50  type cstServer struct {
    51  	*httptest.Server
    52  	URL string
    53  	t   *testing.T
    54  }
    55  
    56  const (
    57  	cstPath       = "/a/b"
    58  	cstRawQuery   = "x=y"
    59  	cstRequestURI = cstPath + "?" + cstRawQuery
    60  )
    61  
    62  func newServer(t *testing.T) *cstServer {
    63  	var s cstServer
    64  	s.Server = httptest.NewServer(cstHandler{t})
    65  	s.Server.URL += cstRequestURI
    66  	s.URL = makeWsProto(s.Server.URL)
    67  	return &s
    68  }
    69  
    70  func newTLSServer(t *testing.T) *cstServer {
    71  	var s cstServer
    72  	s.Server = httptest.NewTLSServer(cstHandler{t})
    73  	s.Server.URL += cstRequestURI
    74  	s.URL = makeWsProto(s.Server.URL)
    75  	return &s
    76  }
    77  
    78  func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    79  	if r.URL.Path != cstPath {
    80  		t.Logf("path=%v, want %v", r.URL.Path, cstPath)
    81  		http.Error(w, "bad path", http.StatusBadRequest)
    82  		return
    83  	}
    84  	if r.URL.RawQuery != cstRawQuery {
    85  		t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
    86  		http.Error(w, "bad path", http.StatusBadRequest)
    87  		return
    88  	}
    89  	subprotos := Subprotocols(r)
    90  	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
    91  		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
    92  		http.Error(w, "bad protocol", http.StatusBadRequest)
    93  		return
    94  	}
    95  	ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
    96  	if err != nil {
    97  		t.Logf("Upgrade: %v", err)
    98  		return
    99  	}
   100  	defer ws.Close()
   101  
   102  	if ws.Subprotocol() != "p1" {
   103  		t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
   104  		ws.Close()
   105  		return
   106  	}
   107  	op, rd, err := ws.NextReader()
   108  	if err != nil {
   109  		t.Logf("NextReader: %v", err)
   110  		return
   111  	}
   112  	wr, err := ws.NextWriter(op)
   113  	if err != nil {
   114  		t.Logf("NextWriter: %v", err)
   115  		return
   116  	}
   117  	if _, err = io.Copy(wr, rd); err != nil {
   118  		t.Logf("NextWriter: %v", err)
   119  		return
   120  	}
   121  	if err := wr.Close(); err != nil {
   122  		t.Logf("Close: %v", err)
   123  		return
   124  	}
   125  }
   126  
   127  func makeWsProto(s string) string {
   128  	return "ws" + strings.TrimPrefix(s, "http")
   129  }
   130  
   131  func sendRecv(t *testing.T, ws *Conn) {
   132  	const message = "Hello World!"
   133  	if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
   134  		t.Fatalf("SetWriteDeadline: %v", err)
   135  	}
   136  	if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
   137  		t.Fatalf("WriteMessage: %v", err)
   138  	}
   139  	if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
   140  		t.Fatalf("SetReadDeadline: %v", err)
   141  	}
   142  	_, p, err := ws.ReadMessage()
   143  	if err != nil {
   144  		t.Fatalf("ReadMessage: %v", err)
   145  	}
   146  	if string(p) != message {
   147  		t.Fatalf("message=%s, want %s", p, message)
   148  	}
   149  }
   150  
   151  func TestProxyDial(t *testing.T) {
   152  
   153  	s := newServer(t)
   154  	defer s.Close()
   155  
   156  	surl, _ := url.Parse(s.Server.URL)
   157  
   158  	cstDialer := cstDialer // make local copy for modification on next line.
   159  	cstDialer.Proxy = http.ProxyURL(surl)
   160  
   161  	connect := false
   162  	origHandler := s.Server.Config.Handler
   163  
   164  	// Capture the request Host header.
   165  	s.Server.Config.Handler = http.HandlerFunc(
   166  		func(w http.ResponseWriter, r *http.Request) {
   167  			if r.Method == http.MethodConnect {
   168  				connect = true
   169  				w.WriteHeader(http.StatusOK)
   170  				return
   171  			}
   172  
   173  			if !connect {
   174  				t.Log("connect not received")
   175  				http.Error(w, "connect not received", http.StatusMethodNotAllowed)
   176  				return
   177  			}
   178  			origHandler.ServeHTTP(w, r)
   179  		})
   180  
   181  	ws, _, err := cstDialer.Dial(s.URL, nil)
   182  	if err != nil {
   183  		t.Fatalf("Dial: %v", err)
   184  	}
   185  	defer ws.Close()
   186  	sendRecv(t, ws)
   187  }
   188  
   189  func TestProxyAuthorizationDial(t *testing.T) {
   190  	s := newServer(t)
   191  	defer s.Close()
   192  
   193  	surl, _ := url.Parse(s.Server.URL)
   194  	surl.User = url.UserPassword("username", "password")
   195  
   196  	cstDialer := cstDialer // make local copy for modification on next line.
   197  	cstDialer.Proxy = http.ProxyURL(surl)
   198  
   199  	connect := false
   200  	origHandler := s.Server.Config.Handler
   201  
   202  	// Capture the request Host header.
   203  	s.Server.Config.Handler = http.HandlerFunc(
   204  		func(w http.ResponseWriter, r *http.Request) {
   205  			proxyAuth := r.Header.Get("Proxy-Authorization")
   206  			expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
   207  			if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
   208  				connect = true
   209  				w.WriteHeader(http.StatusOK)
   210  				return
   211  			}
   212  
   213  			if !connect {
   214  				t.Log("connect with proxy authorization not received")
   215  				http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
   216  				return
   217  			}
   218  			origHandler.ServeHTTP(w, r)
   219  		})
   220  
   221  	ws, _, err := cstDialer.Dial(s.URL, nil)
   222  	if err != nil {
   223  		t.Fatalf("Dial: %v", err)
   224  	}
   225  	defer ws.Close()
   226  	sendRecv(t, ws)
   227  }
   228  
   229  func TestDial(t *testing.T) {
   230  	s := newServer(t)
   231  	defer s.Close()
   232  
   233  	ws, _, err := cstDialer.Dial(s.URL, nil)
   234  	if err != nil {
   235  		t.Fatalf("Dial: %v", err)
   236  	}
   237  	defer ws.Close()
   238  	sendRecv(t, ws)
   239  }
   240  
   241  func TestDialCookieJar(t *testing.T) {
   242  	s := newServer(t)
   243  	defer s.Close()
   244  
   245  	jar, _ := cookiejar.New(nil)
   246  	d := cstDialer
   247  	d.Jar = jar
   248  
   249  	u, _ := url.Parse(s.URL)
   250  
   251  	switch u.Scheme {
   252  	case "ws":
   253  		u.Scheme = "http"
   254  	case "wss":
   255  		u.Scheme = "https"
   256  	}
   257  
   258  	cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
   259  	d.Jar.SetCookies(u, cookies)
   260  
   261  	ws, _, err := d.Dial(s.URL, nil)
   262  	if err != nil {
   263  		t.Fatalf("Dial: %v", err)
   264  	}
   265  	defer ws.Close()
   266  
   267  	var gorilla string
   268  	var sessionID string
   269  	for _, c := range d.Jar.Cookies(u) {
   270  		if c.Name == "gorilla" {
   271  			gorilla = c.Value
   272  		}
   273  
   274  		if c.Name == "sessionID" {
   275  			sessionID = c.Value
   276  		}
   277  	}
   278  	if gorilla != "ws" {
   279  		t.Error("Cookie not present in jar.")
   280  	}
   281  
   282  	if sessionID != "1234" {
   283  		t.Error("Set-Cookie not received from the server.")
   284  	}
   285  
   286  	sendRecv(t, ws)
   287  }
   288  
   289  func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
   290  	certs := x509.NewCertPool()
   291  	for _, c := range s.TLS.Certificates {
   292  		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
   293  		if err != nil {
   294  			t.Fatalf("error parsing server's root cert: %v", err)
   295  		}
   296  		for _, root := range roots {
   297  			certs.AddCert(root)
   298  		}
   299  	}
   300  	return certs
   301  }
   302  
   303  func TestDialTLS(t *testing.T) {
   304  	s := newTLSServer(t)
   305  	defer s.Close()
   306  
   307  	d := cstDialer
   308  	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
   309  	ws, _, err := d.Dial(s.URL, nil)
   310  	if err != nil {
   311  		t.Fatalf("Dial: %v", err)
   312  	}
   313  	defer ws.Close()
   314  	sendRecv(t, ws)
   315  }
   316  
   317  func TestDialTimeout(t *testing.T) {
   318  	s := newServer(t)
   319  	defer s.Close()
   320  
   321  	d := cstDialer
   322  	d.HandshakeTimeout = -1
   323  	ws, _, err := d.Dial(s.URL, nil)
   324  	if err == nil {
   325  		ws.Close()
   326  		t.Fatalf("Dial: nil")
   327  	}
   328  }
   329  
   330  // requireDeadlineNetConn fails the current test when Read or Write are called
   331  // with no deadline.
   332  type requireDeadlineNetConn struct {
   333  	t                  *testing.T
   334  	c                  net.Conn
   335  	readDeadlineIsSet  bool
   336  	writeDeadlineIsSet bool
   337  }
   338  
   339  func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
   340  	c.writeDeadlineIsSet = !t.Equal(time.Time{})
   341  	c.readDeadlineIsSet = c.writeDeadlineIsSet
   342  	return c.c.SetDeadline(t)
   343  }
   344  
   345  func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
   346  	c.readDeadlineIsSet = !t.Equal(time.Time{})
   347  	return c.c.SetDeadline(t)
   348  }
   349  
   350  func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
   351  	c.writeDeadlineIsSet = !t.Equal(time.Time{})
   352  	return c.c.SetDeadline(t)
   353  }
   354  
   355  func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
   356  	if !c.writeDeadlineIsSet {
   357  		c.t.Fatalf("write with no deadline")
   358  	}
   359  	return c.c.Write(p)
   360  }
   361  
   362  func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
   363  	if !c.readDeadlineIsSet {
   364  		c.t.Fatalf("read with no deadline")
   365  	}
   366  	return c.c.Read(p)
   367  }
   368  
   369  func (c *requireDeadlineNetConn) Close() error         { return c.c.Close() }
   370  func (c *requireDeadlineNetConn) LocalAddr() net.Addr  { return c.c.LocalAddr() }
   371  func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
   372  
   373  func TestHandshakeTimeout(t *testing.T) {
   374  	s := newServer(t)
   375  	defer s.Close()
   376  
   377  	d := cstDialer
   378  	d.NetDial = func(n, a string) (net.Conn, error) {
   379  		c, err := net.Dial(n, a)
   380  		return &requireDeadlineNetConn{c: c, t: t}, err
   381  	}
   382  	ws, _, err := d.Dial(s.URL, nil)
   383  	if err != nil {
   384  		t.Fatal("Dial:", err)
   385  	}
   386  	ws.Close()
   387  }
   388  
   389  func TestHandshakeTimeoutInContext(t *testing.T) {
   390  	s := newServer(t)
   391  	defer s.Close()
   392  
   393  	d := cstDialer
   394  	d.HandshakeTimeout = 0
   395  	d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
   396  		netDialer := &net.Dialer{}
   397  		c, err := netDialer.DialContext(ctx, n, a)
   398  		return &requireDeadlineNetConn{c: c, t: t}, err
   399  	}
   400  
   401  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
   402  	defer cancel()
   403  	ws, _, err := d.DialContext(ctx, s.URL, nil)
   404  	if err != nil {
   405  		t.Fatal("Dial:", err)
   406  	}
   407  	ws.Close()
   408  }
   409  
   410  func TestDialBadScheme(t *testing.T) {
   411  	s := newServer(t)
   412  	defer s.Close()
   413  
   414  	ws, _, err := cstDialer.Dial(s.Server.URL, nil)
   415  	if err == nil {
   416  		ws.Close()
   417  		t.Fatalf("Dial: nil")
   418  	}
   419  }
   420  
   421  func TestDialBadOrigin(t *testing.T) {
   422  	s := newServer(t)
   423  	defer s.Close()
   424  
   425  	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
   426  	if err == nil {
   427  		ws.Close()
   428  		t.Fatalf("Dial: nil")
   429  	}
   430  	if resp == nil {
   431  		t.Fatalf("resp=nil, err=%v", err)
   432  	}
   433  	if resp.StatusCode != http.StatusForbidden {
   434  		t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
   435  	}
   436  }
   437  
   438  func TestDialBadHeader(t *testing.T) {
   439  	s := newServer(t)
   440  	defer s.Close()
   441  
   442  	for _, k := range []string{"Upgrade",
   443  		"Connection",
   444  		"Sec-Websocket-Key",
   445  		"Sec-Websocket-Version",
   446  		"Sec-Websocket-Protocol"} {
   447  		h := http.Header{}
   448  		h.Set(k, "bad")
   449  		ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
   450  		if err == nil {
   451  			ws.Close()
   452  			t.Errorf("Dial with header %s returned nil", k)
   453  		}
   454  	}
   455  }
   456  
   457  func TestBadMethod(t *testing.T) {
   458  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   459  		ws, err := cstUpgrader.Upgrade(w, r, nil)
   460  		if err == nil {
   461  			t.Errorf("handshake succeeded, expect fail")
   462  			ws.Close()
   463  		}
   464  	}))
   465  	defer s.Close()
   466  
   467  	req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
   468  	if err != nil {
   469  		t.Fatalf("NewRequest returned error %v", err)
   470  	}
   471  	req.Header.Set("Connection", "upgrade")
   472  	req.Header.Set("Upgrade", "websocket")
   473  	req.Header.Set("Sec-Websocket-Version", "13")
   474  
   475  	resp, err := http.DefaultClient.Do(req)
   476  	if err != nil {
   477  		t.Fatalf("Do returned error %v", err)
   478  	}
   479  	resp.Body.Close()
   480  	if resp.StatusCode != http.StatusMethodNotAllowed {
   481  		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
   482  	}
   483  }
   484  
   485  func TestDialExtraTokensInRespHeaders(t *testing.T) {
   486  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   487  		challengeKey := r.Header.Get("Sec-Websocket-Key")
   488  		w.Header().Set("Upgrade", "foo, websocket")
   489  		w.Header().Set("Connection", "upgrade, keep-alive")
   490  		w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
   491  		w.WriteHeader(101)
   492  	}))
   493  	defer s.Close()
   494  
   495  	ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
   496  	if err != nil {
   497  		t.Fatalf("Dial: %v", err)
   498  	}
   499  	defer ws.Close()
   500  }
   501  
   502  func TestHandshake(t *testing.T) {
   503  	s := newServer(t)
   504  	defer s.Close()
   505  
   506  	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
   507  	if err != nil {
   508  		t.Fatalf("Dial: %v", err)
   509  	}
   510  	defer ws.Close()
   511  
   512  	var sessionID string
   513  	for _, c := range resp.Cookies() {
   514  		if c.Name == "sessionID" {
   515  			sessionID = c.Value
   516  		}
   517  	}
   518  	if sessionID != "1234" {
   519  		t.Error("Set-Cookie not received from the server.")
   520  	}
   521  
   522  	if ws.Subprotocol() != "p1" {
   523  		t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
   524  	}
   525  	sendRecv(t, ws)
   526  }
   527  
   528  func TestRespOnBadHandshake(t *testing.T) {
   529  	const expectedStatus = http.StatusGone
   530  	const expectedBody = "This is the response body."
   531  
   532  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   533  		w.WriteHeader(expectedStatus)
   534  		io.WriteString(w, expectedBody)
   535  	}))
   536  	defer s.Close()
   537  
   538  	ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
   539  	if err == nil {
   540  		ws.Close()
   541  		t.Fatalf("Dial: nil")
   542  	}
   543  
   544  	if resp == nil {
   545  		t.Fatalf("resp=nil, err=%v", err)
   546  	}
   547  
   548  	if resp.StatusCode != expectedStatus {
   549  		t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
   550  	}
   551  
   552  	p, err := ioutil.ReadAll(resp.Body)
   553  	if err != nil {
   554  		t.Fatalf("ReadFull(resp.Body) returned error %v", err)
   555  	}
   556  
   557  	if string(p) != expectedBody {
   558  		t.Errorf("resp.Body=%s, want %s", p, expectedBody)
   559  	}
   560  }
   561  
   562  type testLogWriter struct {
   563  	t *testing.T
   564  }
   565  
   566  func (w testLogWriter) Write(p []byte) (int, error) {
   567  	w.t.Logf("%s", p)
   568  	return len(p), nil
   569  }
   570  
   571  // TestHost tests handling of host names and confirms that it matches net/http.
   572  func TestHost(t *testing.T) {
   573  
   574  	upgrader := Upgrader{}
   575  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   576  		if IsWebSocketUpgrade(r) {
   577  			c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
   578  			if err != nil {
   579  				t.Fatal(err)
   580  			}
   581  			c.Close()
   582  		} else {
   583  			w.Header().Set("X-Test-Host", r.Host)
   584  		}
   585  	})
   586  
   587  	server := httptest.NewServer(handler)
   588  	defer server.Close()
   589  
   590  	tlsServer := httptest.NewTLSServer(handler)
   591  	defer tlsServer.Close()
   592  
   593  	addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
   594  	wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
   595  	httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
   596  
   597  	// Avoid log noise from net/http server by logging to testing.T
   598  	server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
   599  	tlsServer.Config.ErrorLog = server.Config.ErrorLog
   600  
   601  	cas := rootCAs(t, tlsServer)
   602  
   603  	tests := []struct {
   604  		fail               bool             // true if dial / get should fail
   605  		server             *httptest.Server // server to use
   606  		url                string           // host for request URI
   607  		header             string           // optional request host header
   608  		tls                string           // optional host for tls ServerName
   609  		wantAddr           string           // expected host for dial
   610  		wantHeader         string           // expected request header on server
   611  		insecureSkipVerify bool
   612  	}{
   613  		{
   614  			server:     server,
   615  			url:        addrs[server],
   616  			wantAddr:   addrs[server],
   617  			wantHeader: addrs[server],
   618  		},
   619  		{
   620  			server:     tlsServer,
   621  			url:        addrs[tlsServer],
   622  			wantAddr:   addrs[tlsServer],
   623  			wantHeader: addrs[tlsServer],
   624  		},
   625  
   626  		{
   627  			server:     server,
   628  			url:        addrs[server],
   629  			header:     "badhost.com",
   630  			wantAddr:   addrs[server],
   631  			wantHeader: "badhost.com",
   632  		},
   633  		{
   634  			server:     tlsServer,
   635  			url:        addrs[tlsServer],
   636  			header:     "badhost.com",
   637  			wantAddr:   addrs[tlsServer],
   638  			wantHeader: "badhost.com",
   639  		},
   640  
   641  		{
   642  			server:     server,
   643  			url:        "example.com",
   644  			header:     "badhost.com",
   645  			wantAddr:   "example.com:80",
   646  			wantHeader: "badhost.com",
   647  		},
   648  		{
   649  			server:     tlsServer,
   650  			url:        "example.com",
   651  			header:     "badhost.com",
   652  			wantAddr:   "example.com:443",
   653  			wantHeader: "badhost.com",
   654  		},
   655  
   656  		{
   657  			server:     server,
   658  			url:        "badhost.com",
   659  			header:     "example.com",
   660  			wantAddr:   "badhost.com:80",
   661  			wantHeader: "example.com",
   662  		},
   663  		{
   664  			fail:     true,
   665  			server:   tlsServer,
   666  			url:      "badhost.com",
   667  			header:   "example.com",
   668  			wantAddr: "badhost.com:443",
   669  		},
   670  		{
   671  			server:             tlsServer,
   672  			url:                "badhost.com",
   673  			insecureSkipVerify: true,
   674  			wantAddr:           "badhost.com:443",
   675  			wantHeader:         "badhost.com",
   676  		},
   677  		{
   678  			server:     tlsServer,
   679  			url:        "badhost.com",
   680  			tls:        "example.com",
   681  			wantAddr:   "badhost.com:443",
   682  			wantHeader: "badhost.com",
   683  		},
   684  	}
   685  
   686  	for i, tt := range tests {
   687  
   688  		tls := &tls.Config{
   689  			RootCAs:            cas,
   690  			ServerName:         tt.tls,
   691  			InsecureSkipVerify: tt.insecureSkipVerify,
   692  		}
   693  
   694  		var gotAddr string
   695  		dialer := Dialer{
   696  			NetDial: func(network, addr string) (net.Conn, error) {
   697  				gotAddr = addr
   698  				return net.Dial(network, addrs[tt.server])
   699  			},
   700  			TLSClientConfig: tls,
   701  		}
   702  
   703  		// Test websocket dial
   704  
   705  		h := http.Header{}
   706  		if tt.header != "" {
   707  			h.Set("Host", tt.header)
   708  		}
   709  		c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
   710  		if err == nil {
   711  			c.Close()
   712  		}
   713  
   714  		check := func(protos map[*httptest.Server]string) {
   715  			name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
   716  			if gotAddr != tt.wantAddr {
   717  				t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
   718  			}
   719  			switch {
   720  			case tt.fail && err == nil:
   721  				t.Errorf("%s: unexpected success", name)
   722  			case !tt.fail && err != nil:
   723  				t.Errorf("%s: unexpected error %v", name, err)
   724  			case !tt.fail && err == nil:
   725  				if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
   726  					t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
   727  				}
   728  			}
   729  		}
   730  
   731  		check(wsProtos)
   732  
   733  		// Confirm that net/http has same result
   734  
   735  		transport := &http.Transport{
   736  			Dial:            dialer.NetDial,
   737  			TLSClientConfig: dialer.TLSClientConfig,
   738  		}
   739  		req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
   740  		if tt.header != "" {
   741  			req.Host = tt.header
   742  		}
   743  		client := &http.Client{Transport: transport}
   744  		resp, err = client.Do(req)
   745  		if err == nil {
   746  			resp.Body.Close()
   747  		}
   748  		transport.CloseIdleConnections()
   749  		check(httpProtos)
   750  	}
   751  }
   752  
   753  func TestDialCompression(t *testing.T) {
   754  	s := newServer(t)
   755  	defer s.Close()
   756  
   757  	dialer := cstDialer
   758  	dialer.EnableCompression = true
   759  	ws, _, err := dialer.Dial(s.URL, nil)
   760  	if err != nil {
   761  		t.Fatalf("Dial: %v", err)
   762  	}
   763  	defer ws.Close()
   764  	sendRecv(t, ws)
   765  }
   766  
   767  func TestSocksProxyDial(t *testing.T) {
   768  	s := newServer(t)
   769  	defer s.Close()
   770  
   771  	proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
   772  	if err != nil {
   773  		t.Fatalf("listen failed: %v", err)
   774  	}
   775  	defer proxyListener.Close()
   776  	go func() {
   777  		c1, err := proxyListener.Accept()
   778  		if err != nil {
   779  			t.Errorf("proxy accept failed: %v", err)
   780  			return
   781  		}
   782  		defer c1.Close()
   783  
   784  		c1.SetDeadline(time.Now().Add(30 * time.Second))
   785  
   786  		buf := make([]byte, 32)
   787  		if _, err := io.ReadFull(c1, buf[:3]); err != nil {
   788  			t.Errorf("read failed: %v", err)
   789  			return
   790  		}
   791  		if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
   792  			t.Errorf("read %x, want %x", buf[:len(want)], want)
   793  		}
   794  		if _, err := c1.Write([]byte{5, 0}); err != nil {
   795  			t.Errorf("write failed: %v", err)
   796  			return
   797  		}
   798  		if _, err := io.ReadFull(c1, buf[:10]); err != nil {
   799  			t.Errorf("read failed: %v", err)
   800  			return
   801  		}
   802  		if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
   803  			t.Errorf("read %x, want %x", buf[:len(want)], want)
   804  			return
   805  		}
   806  		buf[1] = 0
   807  		if _, err := c1.Write(buf[:10]); err != nil {
   808  			t.Errorf("write failed: %v", err)
   809  			return
   810  		}
   811  
   812  		ip := net.IP(buf[4:8])
   813  		port := binary.BigEndian.Uint16(buf[8:10])
   814  
   815  		c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
   816  		if err != nil {
   817  			t.Errorf("dial failed; %v", err)
   818  			return
   819  		}
   820  		defer c2.Close()
   821  		done := make(chan struct{})
   822  		go func() {
   823  			io.Copy(c1, c2)
   824  			close(done)
   825  		}()
   826  		io.Copy(c2, c1)
   827  		<-done
   828  	}()
   829  
   830  	purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
   831  	if err != nil {
   832  		t.Fatalf("parse failed: %v", err)
   833  	}
   834  
   835  	cstDialer := cstDialer // make local copy for modification on next line.
   836  	cstDialer.Proxy = http.ProxyURL(purl)
   837  
   838  	ws, _, err := cstDialer.Dial(s.URL, nil)
   839  	if err != nil {
   840  		t.Fatalf("Dial: %v", err)
   841  	}
   842  	defer ws.Close()
   843  	sendRecv(t, ws)
   844  }
   845  
   846  func TestTracingDialWithContext(t *testing.T) {
   847  
   848  	var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
   849  	trace := &httptrace.ClientTrace{
   850  		WroteHeaders: func() {
   851  			headersWrote = true
   852  		},
   853  		WroteRequest: func(httptrace.WroteRequestInfo) {
   854  			requestWrote = true
   855  		},
   856  		GetConn: func(hostPort string) {
   857  			getConn = true
   858  		},
   859  		GotConn: func(info httptrace.GotConnInfo) {
   860  			gotConn = true
   861  		},
   862  		ConnectDone: func(network, addr string, err error) {
   863  			connectDone = true
   864  		},
   865  		GotFirstResponseByte: func() {
   866  			gotFirstResponseByte = true
   867  		},
   868  	}
   869  	ctx := httptrace.WithClientTrace(context.Background(), trace)
   870  
   871  	s := newTLSServer(t)
   872  	defer s.Close()
   873  
   874  	d := cstDialer
   875  	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
   876  
   877  	ws, _, err := d.DialContext(ctx, s.URL, nil)
   878  	if err != nil {
   879  		t.Fatalf("Dial: %v", err)
   880  	}
   881  
   882  	if !headersWrote {
   883  		t.Fatal("Headers was not written")
   884  	}
   885  	if !requestWrote {
   886  		t.Fatal("Request was not written")
   887  	}
   888  	if !getConn {
   889  		t.Fatal("getConn was not called")
   890  	}
   891  	if !gotConn {
   892  		t.Fatal("gotConn was not called")
   893  	}
   894  	if !connectDone {
   895  		t.Fatal("connectDone was not called")
   896  	}
   897  	if !gotFirstResponseByte {
   898  		t.Fatal("GotFirstResponseByte was not called")
   899  	}
   900  
   901  	defer ws.Close()
   902  	sendRecv(t, ws)
   903  }
   904  
   905  func TestEmptyTracingDialWithContext(t *testing.T) {
   906  
   907  	trace := &httptrace.ClientTrace{}
   908  	ctx := httptrace.WithClientTrace(context.Background(), trace)
   909  
   910  	s := newTLSServer(t)
   911  	defer s.Close()
   912  
   913  	d := cstDialer
   914  	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
   915  
   916  	ws, _, err := d.DialContext(ctx, s.URL, nil)
   917  	if err != nil {
   918  		t.Fatalf("Dial: %v", err)
   919  	}
   920  
   921  	defer ws.Close()
   922  	sendRecv(t, ws)
   923  }
   924  
   925  // TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
   926  func TestNetDialConnect(t *testing.T) {
   927  
   928  	upgrader := Upgrader{}
   929  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   930  		if IsWebSocketUpgrade(r) {
   931  			c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
   932  			if err != nil {
   933  				t.Fatal(err)
   934  			}
   935  			c.Close()
   936  		} else {
   937  			w.Header().Set("X-Test-Host", r.Host)
   938  		}
   939  	})
   940  
   941  	server := httptest.NewServer(handler)
   942  	defer server.Close()
   943  
   944  	tlsServer := httptest.NewTLSServer(handler)
   945  	defer tlsServer.Close()
   946  
   947  	testUrls := map[*httptest.Server]string{
   948  		server:    "ws://" + server.Listener.Addr().String() + "/",
   949  		tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
   950  	}
   951  
   952  	cas := rootCAs(t, tlsServer)
   953  	tlsConfig := &tls.Config{
   954  		RootCAs:            cas,
   955  		ServerName:         "example.com",
   956  		InsecureSkipVerify: false,
   957  	}
   958  
   959  	tests := []struct {
   960  		name              string
   961  		server            *httptest.Server // server to use
   962  		netDial           func(network, addr string) (net.Conn, error)
   963  		netDialContext    func(ctx context.Context, network, addr string) (net.Conn, error)
   964  		netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
   965  		tlsClientConfig   *tls.Config
   966  	}{
   967  
   968  		{
   969  			name:   "HTTP server, all NetDial* defined, shall use NetDialContext",
   970  			server: server,
   971  			netDial: func(network, addr string) (net.Conn, error) {
   972  				return nil, errors.New("NetDial should not be called")
   973  			},
   974  			netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
   975  				return net.Dial(network, addr)
   976  			},
   977  			netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
   978  				return nil, errors.New("NetDialTLSContext should not be called")
   979  			},
   980  			tlsClientConfig: nil,
   981  		},
   982  		{
   983  			name:              "HTTP server, all NetDial* undefined",
   984  			server:            server,
   985  			netDial:           nil,
   986  			netDialContext:    nil,
   987  			netDialTLSContext: nil,
   988  			tlsClientConfig:   nil,
   989  		},
   990  		{
   991  			name:   "HTTP server, NetDialContext undefined, shall fallback to NetDial",
   992  			server: server,
   993  			netDial: func(network, addr string) (net.Conn, error) {
   994  				return net.Dial(network, addr)
   995  			},
   996  			netDialContext: nil,
   997  			netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   998  				return nil, errors.New("NetDialTLSContext should not be called")
   999  			},
  1000  			tlsClientConfig: nil,
  1001  		},
  1002  		{
  1003  			name:   "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
  1004  			server: tlsServer,
  1005  			netDial: func(network, addr string) (net.Conn, error) {
  1006  				return nil, errors.New("NetDial should not be called")
  1007  			},
  1008  			netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  1009  				return nil, errors.New("NetDialContext should not be called")
  1010  			},
  1011  			netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  1012  				netConn, err := net.Dial(network, addr)
  1013  				if err != nil {
  1014  					return nil, err
  1015  				}
  1016  				tlsConn := tls.Client(netConn, tlsConfig)
  1017  				err = tlsConn.Handshake()
  1018  				if err != nil {
  1019  					return nil, err
  1020  				}
  1021  				return tlsConn, nil
  1022  			},
  1023  			tlsClientConfig: nil,
  1024  		},
  1025  		{
  1026  			name:   "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
  1027  			server: tlsServer,
  1028  			netDial: func(network, addr string) (net.Conn, error) {
  1029  				return nil, errors.New("NetDial should not be called")
  1030  			},
  1031  			netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  1032  				return net.Dial(network, addr)
  1033  			},
  1034  			netDialTLSContext: nil,
  1035  			tlsClientConfig:   tlsConfig,
  1036  		},
  1037  		{
  1038  			name:   "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
  1039  			server: tlsServer,
  1040  			netDial: func(network, addr string) (net.Conn, error) {
  1041  				return net.Dial(network, addr)
  1042  			},
  1043  			netDialContext:    nil,
  1044  			netDialTLSContext: nil,
  1045  			tlsClientConfig:   tlsConfig,
  1046  		},
  1047  		{
  1048  			name:              "HTTPS server, all NetDial* undefined",
  1049  			server:            tlsServer,
  1050  			netDial:           nil,
  1051  			netDialContext:    nil,
  1052  			netDialTLSContext: nil,
  1053  			tlsClientConfig:   tlsConfig,
  1054  		},
  1055  		{
  1056  			name:   "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
  1057  			server: tlsServer,
  1058  			netDial: func(network, addr string) (net.Conn, error) {
  1059  				return nil, errors.New("NetDial should not be called")
  1060  			},
  1061  			netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  1062  				return nil, errors.New("NetDialContext should not be called")
  1063  			},
  1064  			netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
  1065  				netConn, err := net.Dial(network, addr)
  1066  				if err != nil {
  1067  					return nil, err
  1068  				}
  1069  				tlsConn := tls.Client(netConn, tlsConfig)
  1070  				err = tlsConn.Handshake()
  1071  				if err != nil {
  1072  					return nil, err
  1073  				}
  1074  				return tlsConn, nil
  1075  			},
  1076  			tlsClientConfig: &tls.Config{
  1077  				RootCAs:            nil,
  1078  				ServerName:         "badserver.com",
  1079  				InsecureSkipVerify: false,
  1080  			},
  1081  		},
  1082  	}
  1083  
  1084  	for _, tc := range tests {
  1085  		dialer := Dialer{
  1086  			NetDial:           tc.netDial,
  1087  			NetDialContext:    tc.netDialContext,
  1088  			NetDialTLSContext: tc.netDialTLSContext,
  1089  			TLSClientConfig:   tc.tlsClientConfig,
  1090  		}
  1091  
  1092  		// Test websocket dial
  1093  		c, _, err := dialer.Dial(testUrls[tc.server], nil)
  1094  		if err != nil {
  1095  			t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
  1096  		} else {
  1097  			c.Close()
  1098  		}
  1099  	}
  1100  }
  1101  func TestNextProtos(t *testing.T) {
  1102  	ts := httptest.NewUnstartedServer(
  1103  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
  1104  	)
  1105  	ts.EnableHTTP2 = true
  1106  	ts.StartTLS()
  1107  	defer ts.Close()
  1108  
  1109  	d := Dialer{
  1110  		TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
  1111  	}
  1112  
  1113  	r, err := ts.Client().Get(ts.URL)
  1114  	if err != nil {
  1115  		t.Fatalf("Get: %v", err)
  1116  	}
  1117  	r.Body.Close()
  1118  
  1119  	// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
  1120  	// after the Client.Get call from net/http above.
  1121  	var containsHTTP2 bool = false
  1122  	for _, proto := range d.TLSClientConfig.NextProtos {
  1123  		if proto == "h2" {
  1124  			containsHTTP2 = true
  1125  		}
  1126  	}
  1127  	if !containsHTTP2 {
  1128  		t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
  1129  	}
  1130  
  1131  	_, _, err = d.Dial(makeWsProto(ts.URL), nil)
  1132  	if err == nil {
  1133  		t.Fatalf("Dial succeeded, expect fail ")
  1134  	}
  1135  }
  1136  

View as plain text