...

Source file src/github.com/gorilla/websocket/server.go

Documentation: github.com/gorilla/websocket

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"errors"
    10  	"io"
    11  	"net/http"
    12  	"net/url"
    13  	"strings"
    14  	"time"
    15  )
    16  
    17  // HandshakeError describes an error with the handshake from the peer.
    18  type HandshakeError struct {
    19  	message string
    20  }
    21  
    22  func (e HandshakeError) Error() string { return e.message }
    23  
    24  // Upgrader specifies parameters for upgrading an HTTP connection to a
    25  // WebSocket connection.
    26  //
    27  // It is safe to call Upgrader's methods concurrently.
    28  type Upgrader struct {
    29  	// HandshakeTimeout specifies the duration for the handshake to complete.
    30  	HandshakeTimeout time.Duration
    31  
    32  	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
    33  	// size is zero, then buffers allocated by the HTTP server are used. The
    34  	// I/O buffer sizes do not limit the size of the messages that can be sent
    35  	// or received.
    36  	ReadBufferSize, WriteBufferSize int
    37  
    38  	// WriteBufferPool is a pool of buffers for write operations. If the value
    39  	// is not set, then write buffers are allocated to the connection for the
    40  	// lifetime of the connection.
    41  	//
    42  	// A pool is most useful when the application has a modest volume of writes
    43  	// across a large number of connections.
    44  	//
    45  	// Applications should use a single pool for each unique value of
    46  	// WriteBufferSize.
    47  	WriteBufferPool BufferPool
    48  
    49  	// Subprotocols specifies the server's supported protocols in order of
    50  	// preference. If this field is not nil, then the Upgrade method negotiates a
    51  	// subprotocol by selecting the first match in this list with a protocol
    52  	// requested by the client. If there's no match, then no protocol is
    53  	// negotiated (the Sec-Websocket-Protocol header is not included in the
    54  	// handshake response).
    55  	Subprotocols []string
    56  
    57  	// Error specifies the function for generating HTTP error responses. If Error
    58  	// is nil, then http.Error is used to generate the HTTP response.
    59  	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
    60  
    61  	// CheckOrigin returns true if the request Origin header is acceptable. If
    62  	// CheckOrigin is nil, then a safe default is used: return false if the
    63  	// Origin request header is present and the origin host is not equal to
    64  	// request Host header.
    65  	//
    66  	// A CheckOrigin function should carefully validate the request origin to
    67  	// prevent cross-site request forgery.
    68  	CheckOrigin func(r *http.Request) bool
    69  
    70  	// EnableCompression specify if the server should attempt to negotiate per
    71  	// message compression (RFC 7692). Setting this value to true does not
    72  	// guarantee that compression will be supported. Currently only "no context
    73  	// takeover" modes are supported.
    74  	EnableCompression bool
    75  }
    76  
    77  func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
    78  	err := HandshakeError{reason}
    79  	if u.Error != nil {
    80  		u.Error(w, r, status, err)
    81  	} else {
    82  		w.Header().Set("Sec-Websocket-Version", "13")
    83  		http.Error(w, http.StatusText(status), status)
    84  	}
    85  	return nil, err
    86  }
    87  
    88  // checkSameOrigin returns true if the origin is not set or is equal to the request host.
    89  func checkSameOrigin(r *http.Request) bool {
    90  	origin := r.Header["Origin"]
    91  	if len(origin) == 0 {
    92  		return true
    93  	}
    94  	u, err := url.Parse(origin[0])
    95  	if err != nil {
    96  		return false
    97  	}
    98  	return equalASCIIFold(u.Host, r.Host)
    99  }
   100  
   101  func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
   102  	if u.Subprotocols != nil {
   103  		clientProtocols := Subprotocols(r)
   104  		for _, serverProtocol := range u.Subprotocols {
   105  			for _, clientProtocol := range clientProtocols {
   106  				if clientProtocol == serverProtocol {
   107  					return clientProtocol
   108  				}
   109  			}
   110  		}
   111  	} else if responseHeader != nil {
   112  		return responseHeader.Get("Sec-Websocket-Protocol")
   113  	}
   114  	return ""
   115  }
   116  
   117  // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
   118  //
   119  // The responseHeader is included in the response to the client's upgrade
   120  // request. Use the responseHeader to specify cookies (Set-Cookie). To specify
   121  // subprotocols supported by the server, set Upgrader.Subprotocols directly.
   122  //
   123  // If the upgrade fails, then Upgrade replies to the client with an HTTP error
   124  // response.
   125  func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
   126  	const badHandshake = "websocket: the client is not using the websocket protocol: "
   127  
   128  	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
   129  		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
   130  	}
   131  
   132  	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
   133  		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
   134  	}
   135  
   136  	if r.Method != http.MethodGet {
   137  		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
   138  	}
   139  
   140  	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
   141  		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
   142  	}
   143  
   144  	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
   145  		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
   146  	}
   147  
   148  	checkOrigin := u.CheckOrigin
   149  	if checkOrigin == nil {
   150  		checkOrigin = checkSameOrigin
   151  	}
   152  	if !checkOrigin(r) {
   153  		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
   154  	}
   155  
   156  	challengeKey := r.Header.Get("Sec-Websocket-Key")
   157  	if !isValidChallengeKey(challengeKey) {
   158  		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
   159  	}
   160  
   161  	subprotocol := u.selectSubprotocol(r, responseHeader)
   162  
   163  	// Negotiate PMCE
   164  	var compress bool
   165  	if u.EnableCompression {
   166  		for _, ext := range parseExtensions(r.Header) {
   167  			if ext[""] != "permessage-deflate" {
   168  				continue
   169  			}
   170  			compress = true
   171  			break
   172  		}
   173  	}
   174  
   175  	h, ok := w.(http.Hijacker)
   176  	if !ok {
   177  		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
   178  	}
   179  	var brw *bufio.ReadWriter
   180  	netConn, brw, err := h.Hijack()
   181  	if err != nil {
   182  		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
   183  	}
   184  
   185  	if brw.Reader.Buffered() > 0 {
   186  		netConn.Close()
   187  		return nil, errors.New("websocket: client sent data before handshake is complete")
   188  	}
   189  
   190  	var br *bufio.Reader
   191  	if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
   192  		// Reuse hijacked buffered reader as connection reader.
   193  		br = brw.Reader
   194  	}
   195  
   196  	buf := bufioWriterBuffer(netConn, brw.Writer)
   197  
   198  	var writeBuf []byte
   199  	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
   200  		// Reuse hijacked write buffer as connection buffer.
   201  		writeBuf = buf
   202  	}
   203  
   204  	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
   205  	c.subprotocol = subprotocol
   206  
   207  	if compress {
   208  		c.newCompressionWriter = compressNoContextTakeover
   209  		c.newDecompressionReader = decompressNoContextTakeover
   210  	}
   211  
   212  	// Use larger of hijacked buffer and connection write buffer for header.
   213  	p := buf
   214  	if len(c.writeBuf) > len(p) {
   215  		p = c.writeBuf
   216  	}
   217  	p = p[:0]
   218  
   219  	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
   220  	p = append(p, computeAcceptKey(challengeKey)...)
   221  	p = append(p, "\r\n"...)
   222  	if c.subprotocol != "" {
   223  		p = append(p, "Sec-WebSocket-Protocol: "...)
   224  		p = append(p, c.subprotocol...)
   225  		p = append(p, "\r\n"...)
   226  	}
   227  	if compress {
   228  		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
   229  	}
   230  	for k, vs := range responseHeader {
   231  		if k == "Sec-Websocket-Protocol" {
   232  			continue
   233  		}
   234  		for _, v := range vs {
   235  			p = append(p, k...)
   236  			p = append(p, ": "...)
   237  			for i := 0; i < len(v); i++ {
   238  				b := v[i]
   239  				if b <= 31 {
   240  					// prevent response splitting.
   241  					b = ' '
   242  				}
   243  				p = append(p, b)
   244  			}
   245  			p = append(p, "\r\n"...)
   246  		}
   247  	}
   248  	p = append(p, "\r\n"...)
   249  
   250  	// Clear deadlines set by HTTP server.
   251  	netConn.SetDeadline(time.Time{})
   252  
   253  	if u.HandshakeTimeout > 0 {
   254  		netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
   255  	}
   256  	if _, err = netConn.Write(p); err != nil {
   257  		netConn.Close()
   258  		return nil, err
   259  	}
   260  	if u.HandshakeTimeout > 0 {
   261  		netConn.SetWriteDeadline(time.Time{})
   262  	}
   263  
   264  	return c, nil
   265  }
   266  
   267  // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
   268  //
   269  // Deprecated: Use websocket.Upgrader instead.
   270  //
   271  // Upgrade does not perform origin checking. The application is responsible for
   272  // checking the Origin header before calling Upgrade. An example implementation
   273  // of the same origin policy check is:
   274  //
   275  //	if req.Header.Get("Origin") != "http://"+req.Host {
   276  //		http.Error(w, "Origin not allowed", http.StatusForbidden)
   277  //		return
   278  //	}
   279  //
   280  // If the endpoint supports subprotocols, then the application is responsible
   281  // for negotiating the protocol used on the connection. Use the Subprotocols()
   282  // function to get the subprotocols requested by the client. Use the
   283  // Sec-Websocket-Protocol response header to specify the subprotocol selected
   284  // by the application.
   285  //
   286  // The responseHeader is included in the response to the client's upgrade
   287  // request. Use the responseHeader to specify cookies (Set-Cookie) and the
   288  // negotiated subprotocol (Sec-Websocket-Protocol).
   289  //
   290  // The connection buffers IO to the underlying network connection. The
   291  // readBufSize and writeBufSize parameters specify the size of the buffers to
   292  // use. Messages can be larger than the buffers.
   293  //
   294  // If the request is not a valid WebSocket handshake, then Upgrade returns an
   295  // error of type HandshakeError. Applications should handle this error by
   296  // replying to the client with an HTTP error response.
   297  func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
   298  	u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
   299  	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
   300  		// don't return errors to maintain backwards compatibility
   301  	}
   302  	u.CheckOrigin = func(r *http.Request) bool {
   303  		// allow all connections by default
   304  		return true
   305  	}
   306  	return u.Upgrade(w, r, responseHeader)
   307  }
   308  
   309  // Subprotocols returns the subprotocols requested by the client in the
   310  // Sec-Websocket-Protocol header.
   311  func Subprotocols(r *http.Request) []string {
   312  	h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
   313  	if h == "" {
   314  		return nil
   315  	}
   316  	protocols := strings.Split(h, ",")
   317  	for i := range protocols {
   318  		protocols[i] = strings.TrimSpace(protocols[i])
   319  	}
   320  	return protocols
   321  }
   322  
   323  // IsWebSocketUpgrade returns true if the client requested upgrade to the
   324  // WebSocket protocol.
   325  func IsWebSocketUpgrade(r *http.Request) bool {
   326  	return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
   327  		tokenListContainsValue(r.Header, "Upgrade", "websocket")
   328  }
   329  
   330  // bufioReaderSize size returns the size of a bufio.Reader.
   331  func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
   332  	// This code assumes that peek on a reset reader returns
   333  	// bufio.Reader.buf[:0].
   334  	// TODO: Use bufio.Reader.Size() after Go 1.10
   335  	br.Reset(originalReader)
   336  	if p, err := br.Peek(0); err == nil {
   337  		return cap(p)
   338  	}
   339  	return 0
   340  }
   341  
   342  // writeHook is an io.Writer that records the last slice passed to it vio
   343  // io.Writer.Write.
   344  type writeHook struct {
   345  	p []byte
   346  }
   347  
   348  func (wh *writeHook) Write(p []byte) (int, error) {
   349  	wh.p = p
   350  	return len(p), nil
   351  }
   352  
   353  // bufioWriterBuffer grabs the buffer from a bufio.Writer.
   354  func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
   355  	// This code assumes that bufio.Writer.buf[:1] is passed to the
   356  	// bufio.Writer's underlying writer.
   357  	var wh writeHook
   358  	bw.Reset(&wh)
   359  	bw.WriteByte(0)
   360  	bw.Flush()
   361  
   362  	bw.Reset(originalWriter)
   363  
   364  	return wh.p[:cap(wh.p)]
   365  }
   366  

View as plain text