...

Source file src/github.com/gorilla/websocket/client.go

Documentation: github.com/gorilla/websocket

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/tls"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"net"
    16  	"net/http"
    17  	"net/http/httptrace"
    18  	"net/url"
    19  	"strings"
    20  	"time"
    21  )
    22  
    23  // ErrBadHandshake is returned when the server response to opening handshake is
    24  // invalid.
    25  var ErrBadHandshake = errors.New("websocket: bad handshake")
    26  
    27  var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
    28  
    29  // NewClient creates a new client connection using the given net connection.
    30  // The URL u specifies the host and request URI. Use requestHeader to specify
    31  // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
    32  // (Cookie). Use the response.Header to get the selected subprotocol
    33  // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
    34  //
    35  // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
    36  // non-nil *http.Response so that callers can handle redirects, authentication,
    37  // etc.
    38  //
    39  // Deprecated: Use Dialer instead.
    40  func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
    41  	d := Dialer{
    42  		ReadBufferSize:  readBufSize,
    43  		WriteBufferSize: writeBufSize,
    44  		NetDial: func(net, addr string) (net.Conn, error) {
    45  			return netConn, nil
    46  		},
    47  	}
    48  	return d.Dial(u.String(), requestHeader)
    49  }
    50  
// A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
	// NetDial specifies the dial function for creating TCP connections. It is
	// consulted only when NetDialContext is nil (and, for wss URLs, when
	// NetDialTLSContext is nil as well). If no dial function is set, a
	// zero-value net.Dialer is used with the dial context.
	NetDial func(network, addr string) (net.Conn, error)

	// NetDialContext specifies the dial function for creating TCP connections. If
	// NetDialContext is nil, NetDial is used. The context it receives carries
	// the HandshakeTimeout deadline when HandshakeTimeout is non-zero.
	NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)

	// NetDialTLSContext specifies the dial function for creating TLS/TCP
	// connections for wss URLs. If NetDialTLSContext is nil, NetDialContext is used.
	// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
	// TLSClientConfig is not used for the handshake.
	NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

	// Proxy specifies a function to return a proxy for a given
	// Request. If the function returns a non-nil error, the
	// request is aborted with the provided error.
	// If Proxy is nil or returns a nil *URL, no proxy is used.
	// When Proxy is called, the request URL scheme has already been mapped
	// from ws/wss to http/https.
	Proxy func(*http.Request) (*url.URL, error)

	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
	// If nil, a zero-value tls.Config is used. The configuration is cloned
	// before use, and an empty ServerName is filled in from the URL host.
	// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there
	// and TLSClientConfig is not used for the handshake.
	TLSClientConfig *tls.Config

	// HandshakeTimeout specifies the duration for the handshake to complete.
	// If zero, no extra timeout is added; any deadline already present on the
	// context passed to DialContext still applies.
	HandshakeTimeout time.Duration

	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
	// size is zero, then a useful default size is used. The I/O buffer sizes
	// do not limit the size of the messages that can be sent or received.
	ReadBufferSize, WriteBufferSize int

	// WriteBufferPool is a pool of buffers for write operations. If the value
	// is not set, then write buffers are allocated to the connection for the
	// lifetime of the connection.
	//
	// A pool is most useful when the application has a modest volume of writes
	// across a large number of connections.
	//
	// Applications should use a single pool for each unique value of
	// WriteBufferSize.
	WriteBufferPool BufferPool

	// Subprotocols specifies the client's requested subprotocols. When
	// non-empty, they are sent as a single comma-separated
	// Sec-WebSocket-Protocol header, and supplying that header through the
	// request header as well is an error.
	Subprotocols []string

	// EnableCompression specifies if the client should attempt to negotiate
	// per message compression (RFC 7692). Setting this value to true does not
	// guarantee that compression will be supported. Currently only "no context
	// takeover" modes are supported.
	EnableCompression bool

	// Jar specifies the cookie jar.
	// If Jar is nil, cookies are not sent in requests and ignored
	// in responses.
	Jar http.CookieJar
}
   114  
   115  // Dial creates a new client connection by calling DialContext with a background context.
   116  func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
   117  	return d.DialContext(context.Background(), urlStr, requestHeader)
   118  }
   119  
   120  var errMalformedURL = errors.New("malformed ws or wss URL")
   121  
   122  func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
   123  	hostPort = u.Host
   124  	hostNoPort = u.Host
   125  	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
   126  		hostNoPort = hostNoPort[:i]
   127  	} else {
   128  		switch u.Scheme {
   129  		case "wss":
   130  			hostPort += ":443"
   131  		case "https":
   132  			hostPort += ":443"
   133  		default:
   134  			hostPort += ":80"
   135  		}
   136  	}
   137  	return hostPort, hostNoPort
   138  }
   139  
   140  // DefaultDialer is a dialer with all fields set to the default values.
   141  var DefaultDialer = &Dialer{
   142  	Proxy:            http.ProxyFromEnvironment,
   143  	HandshakeTimeout: 45 * time.Second,
   144  }
   145  
   146  // nilDialer is dialer to use when receiver is nil.
   147  var nilDialer = *DefaultDialer
   148  
   149  // DialContext creates a new client connection. Use requestHeader to specify the
   150  // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
   151  // Use the response.Header to get the selected subprotocol
   152  // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
   153  //
   154  // The context will be used in the request and in the Dialer.
   155  //
   156  // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
   157  // non-nil *http.Response so that callers can handle redirects, authentication,
   158  // etcetera. The response body may not contain the entire response and does not
   159  // need to be closed by the application.
   160  func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
   161  	if d == nil {
   162  		d = &nilDialer
   163  	}
   164  
   165  	challengeKey, err := generateChallengeKey()
   166  	if err != nil {
   167  		return nil, nil, err
   168  	}
   169  
   170  	u, err := url.Parse(urlStr)
   171  	if err != nil {
   172  		return nil, nil, err
   173  	}
   174  
   175  	switch u.Scheme {
   176  	case "ws":
   177  		u.Scheme = "http"
   178  	case "wss":
   179  		u.Scheme = "https"
   180  	default:
   181  		return nil, nil, errMalformedURL
   182  	}
   183  
   184  	if u.User != nil {
   185  		// User name and password are not allowed in websocket URIs.
   186  		return nil, nil, errMalformedURL
   187  	}
   188  
   189  	req := &http.Request{
   190  		Method:     http.MethodGet,
   191  		URL:        u,
   192  		Proto:      "HTTP/1.1",
   193  		ProtoMajor: 1,
   194  		ProtoMinor: 1,
   195  		Header:     make(http.Header),
   196  		Host:       u.Host,
   197  	}
   198  	req = req.WithContext(ctx)
   199  
   200  	// Set the cookies present in the cookie jar of the dialer
   201  	if d.Jar != nil {
   202  		for _, cookie := range d.Jar.Cookies(u) {
   203  			req.AddCookie(cookie)
   204  		}
   205  	}
   206  
   207  	// Set the request headers using the capitalization for names and values in
   208  	// RFC examples. Although the capitalization shouldn't matter, there are
   209  	// servers that depend on it. The Header.Set method is not used because the
   210  	// method canonicalizes the header names.
   211  	req.Header["Upgrade"] = []string{"websocket"}
   212  	req.Header["Connection"] = []string{"Upgrade"}
   213  	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
   214  	req.Header["Sec-WebSocket-Version"] = []string{"13"}
   215  	if len(d.Subprotocols) > 0 {
   216  		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
   217  	}
   218  	for k, vs := range requestHeader {
   219  		switch {
   220  		case k == "Host":
   221  			if len(vs) > 0 {
   222  				req.Host = vs[0]
   223  			}
   224  		case k == "Upgrade" ||
   225  			k == "Connection" ||
   226  			k == "Sec-Websocket-Key" ||
   227  			k == "Sec-Websocket-Version" ||
   228  			k == "Sec-Websocket-Extensions" ||
   229  			(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
   230  			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
   231  		case k == "Sec-Websocket-Protocol":
   232  			req.Header["Sec-WebSocket-Protocol"] = vs
   233  		default:
   234  			req.Header[k] = vs
   235  		}
   236  	}
   237  
   238  	if d.EnableCompression {
   239  		req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
   240  	}
   241  
   242  	if d.HandshakeTimeout != 0 {
   243  		var cancel func()
   244  		ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
   245  		defer cancel()
   246  	}
   247  
   248  	// Get network dial function.
   249  	var netDial func(network, add string) (net.Conn, error)
   250  
   251  	switch u.Scheme {
   252  	case "http":
   253  		if d.NetDialContext != nil {
   254  			netDial = func(network, addr string) (net.Conn, error) {
   255  				return d.NetDialContext(ctx, network, addr)
   256  			}
   257  		} else if d.NetDial != nil {
   258  			netDial = d.NetDial
   259  		}
   260  	case "https":
   261  		if d.NetDialTLSContext != nil {
   262  			netDial = func(network, addr string) (net.Conn, error) {
   263  				return d.NetDialTLSContext(ctx, network, addr)
   264  			}
   265  		} else if d.NetDialContext != nil {
   266  			netDial = func(network, addr string) (net.Conn, error) {
   267  				return d.NetDialContext(ctx, network, addr)
   268  			}
   269  		} else if d.NetDial != nil {
   270  			netDial = d.NetDial
   271  		}
   272  	default:
   273  		return nil, nil, errMalformedURL
   274  	}
   275  
   276  	if netDial == nil {
   277  		netDialer := &net.Dialer{}
   278  		netDial = func(network, addr string) (net.Conn, error) {
   279  			return netDialer.DialContext(ctx, network, addr)
   280  		}
   281  	}
   282  
   283  	// If needed, wrap the dial function to set the connection deadline.
   284  	if deadline, ok := ctx.Deadline(); ok {
   285  		forwardDial := netDial
   286  		netDial = func(network, addr string) (net.Conn, error) {
   287  			c, err := forwardDial(network, addr)
   288  			if err != nil {
   289  				return nil, err
   290  			}
   291  			err = c.SetDeadline(deadline)
   292  			if err != nil {
   293  				c.Close()
   294  				return nil, err
   295  			}
   296  			return c, nil
   297  		}
   298  	}
   299  
   300  	// If needed, wrap the dial function to connect through a proxy.
   301  	if d.Proxy != nil {
   302  		proxyURL, err := d.Proxy(req)
   303  		if err != nil {
   304  			return nil, nil, err
   305  		}
   306  		if proxyURL != nil {
   307  			dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
   308  			if err != nil {
   309  				return nil, nil, err
   310  			}
   311  			netDial = dialer.Dial
   312  		}
   313  	}
   314  
   315  	hostPort, hostNoPort := hostPortNoPort(u)
   316  	trace := httptrace.ContextClientTrace(ctx)
   317  	if trace != nil && trace.GetConn != nil {
   318  		trace.GetConn(hostPort)
   319  	}
   320  
   321  	netConn, err := netDial("tcp", hostPort)
   322  	if err != nil {
   323  		return nil, nil, err
   324  	}
   325  	if trace != nil && trace.GotConn != nil {
   326  		trace.GotConn(httptrace.GotConnInfo{
   327  			Conn: netConn,
   328  		})
   329  	}
   330  
   331  	defer func() {
   332  		if netConn != nil {
   333  			netConn.Close()
   334  		}
   335  	}()
   336  
   337  	if u.Scheme == "https" && d.NetDialTLSContext == nil {
   338  		// If NetDialTLSContext is set, assume that the TLS handshake has already been done
   339  
   340  		cfg := cloneTLSConfig(d.TLSClientConfig)
   341  		if cfg.ServerName == "" {
   342  			cfg.ServerName = hostNoPort
   343  		}
   344  		tlsConn := tls.Client(netConn, cfg)
   345  		netConn = tlsConn
   346  
   347  		if trace != nil && trace.TLSHandshakeStart != nil {
   348  			trace.TLSHandshakeStart()
   349  		}
   350  		err := doHandshake(ctx, tlsConn, cfg)
   351  		if trace != nil && trace.TLSHandshakeDone != nil {
   352  			trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
   353  		}
   354  
   355  		if err != nil {
   356  			return nil, nil, err
   357  		}
   358  	}
   359  
   360  	conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
   361  
   362  	if err := req.Write(netConn); err != nil {
   363  		return nil, nil, err
   364  	}
   365  
   366  	if trace != nil && trace.GotFirstResponseByte != nil {
   367  		if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
   368  			trace.GotFirstResponseByte()
   369  		}
   370  	}
   371  
   372  	resp, err := http.ReadResponse(conn.br, req)
   373  	if err != nil {
   374  		if d.TLSClientConfig != nil {
   375  			for _, proto := range d.TLSClientConfig.NextProtos {
   376  				if proto != "http/1.1" {
   377  					return nil, nil, fmt.Errorf(
   378  						"websocket: protocol %q was given but is not supported;"+
   379  							"sharing tls.Config with net/http Transport can cause this error: %w",
   380  						proto, err,
   381  					)
   382  				}
   383  			}
   384  		}
   385  		return nil, nil, err
   386  	}
   387  
   388  	if d.Jar != nil {
   389  		if rc := resp.Cookies(); len(rc) > 0 {
   390  			d.Jar.SetCookies(u, rc)
   391  		}
   392  	}
   393  
   394  	if resp.StatusCode != 101 ||
   395  		!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
   396  		!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
   397  		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
   398  		// Before closing the network connection on return from this
   399  		// function, slurp up some of the response to aid application
   400  		// debugging.
   401  		buf := make([]byte, 1024)
   402  		n, _ := io.ReadFull(resp.Body, buf)
   403  		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
   404  		return nil, resp, ErrBadHandshake
   405  	}
   406  
   407  	for _, ext := range parseExtensions(resp.Header) {
   408  		if ext[""] != "permessage-deflate" {
   409  			continue
   410  		}
   411  		_, snct := ext["server_no_context_takeover"]
   412  		_, cnct := ext["client_no_context_takeover"]
   413  		if !snct || !cnct {
   414  			return nil, resp, errInvalidCompression
   415  		}
   416  		conn.newCompressionWriter = compressNoContextTakeover
   417  		conn.newDecompressionReader = decompressNoContextTakeover
   418  		break
   419  	}
   420  
   421  	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
   422  	conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
   423  
   424  	netConn.SetDeadline(time.Time{})
   425  	netConn = nil // to avoid close in defer.
   426  	return conn, resp, nil
   427  }
   428  
   429  func cloneTLSConfig(cfg *tls.Config) *tls.Config {
   430  	if cfg == nil {
   431  		return &tls.Config{}
   432  	}
   433  	return cfg.Clone()
   434  }
   435  

View as plain text