...

Source file src/nhooyr.io/websocket/accept.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  // +build !js
     3  
     4  package websocket
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/sha1"
     9  	"encoding/base64"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"net/http"
    15  	"net/textproto"
    16  	"net/url"
    17  	"path/filepath"
    18  	"strings"
    19  
    20  	"nhooyr.io/websocket/internal/errd"
    21  )
    22  
// AcceptOptions represents Accept's options.
//
// A nil *AcceptOptions may be passed to Accept; it behaves like the zero value.
type AcceptOptions struct {
	// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
	// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
	// reject it, close the connection when c.Subprotocol() == "".
	//
	// Entries are tried in order and compared case-insensitively against the client's
	// offer, so list them from most to least preferred.
	Subprotocols []string

	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
	// When set, the Origin header is not checked at all and OriginPatterns is ignored.
	//
	// You probably want to use OriginPatterns instead.
	InsecureSkipVerify bool

	// OriginPatterns lists the host patterns for authorized origins.
	// The request host is always authorized, and requests that carry no Origin
	// header are not checked.
	// Use this to enable cross origin WebSockets.
	//
	// For example, JavaScript running on example.com wants to access a WebSocket
	// server at chat.example.com. In such a case, example.com is the origin and
	// chat.example.com is the request host. One would set this field to
	// []string{"example.com"} to authorize example.com to connect.
	//
	// Each pattern is matched case insensitively against the request origin host
	// with filepath.Match.
	// See https://golang.org/pkg/path/filepath/#Match
	// A malformed pattern causes Accept to log the error and reject the request
	// with 403 Forbidden.
	//
	// Please ensure you understand the ramifications of enabling this.
	// If used incorrectly your WebSocket server will be open to CSRF attacks.
	//
	// Do not use * as a pattern to allow any origin; prefer InsecureSkipVerify instead
	// to bring attention to the danger of such a setting.
	OriginPatterns []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionDisabled.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
    66  
    67  func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
    68  	var o AcceptOptions
    69  	if opts != nil {
    70  		o = *opts
    71  	}
    72  	return &o
    73  }
    74  
    75  // Accept accepts a WebSocket handshake from a client and upgrades the
    76  // the connection to a WebSocket.
    77  //
    78  // Accept will not allow cross origin requests by default.
    79  // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
    80  //
    81  // Accept will write a response to w on all errors.
    82  func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
    83  	return accept(w, r, opts)
    84  }
    85  
// accept implements Accept.
//
// It validates the opening handshake request, checks the Origin (unless
// opts.InsecureSkipVerify is set), writes the 101 Switching Protocols response
// headers, hijacks the underlying connection and wraps it in a *Conn.
//
// Every failure path writes an HTTP error response to w before returning.
// Returned errors pass through errd.Wrap with "failed to accept WebSocket
// connection" (presumably prefixing the message; see internal/errd).
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
	defer errd.Wrap(&err, "failed to accept WebSocket connection")

	errCode, err := verifyClientRequest(w, r)
	if err != nil {
		http.Error(w, err.Error(), errCode)
		return nil, err
	}

	// Work on a copy (the zero value when opts is nil) rather than the caller's struct.
	opts = opts.cloneWithDefaults()
	if !opts.InsecureSkipVerify {
		err = authenticateOrigin(r, opts.OriginPatterns)
		if err != nil {
			// A malformed OriginPatterns entry is a server-side misconfiguration:
			// log the details and send the client only a generic 403 message.
			if errors.Is(err, filepath.ErrBadPattern) {
				log.Printf("websocket: %v", err)
				err = errors.New(http.StatusText(http.StatusForbidden))
			}
			http.Error(w, err.Error(), http.StatusForbidden)
			return nil, err
		}
	}

	// The raw connection is needed after the 101 response, so w must support hijacking.
	hj, ok := w.(http.Hijacker)
	if !ok {
		err = errors.New("http.ResponseWriter does not implement http.Hijacker")
		http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
		return nil, err
	}

	w.Header().Set("Upgrade", "websocket")
	w.Header().Set("Connection", "Upgrade")

	// NOTE(review): verifyClientRequest validated strings.TrimSpace of this value;
	// make sure the accept hash is computed over that same trimmed key.
	key := r.Header.Get("Sec-WebSocket-Key")
	w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))

	subproto := selectSubprotocol(r, opts.Subprotocols)
	if subproto != "" {
		w.Header().Set("Sec-WebSocket-Protocol", subproto)
	}

	copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
	if ok {
		w.Header().Set("Sec-WebSocket-Extensions", copts.String())
	}

	w.WriteHeader(http.StatusSwitchingProtocols)
	// Some ResponseWriter wrappers (the name suggests gin's) expose WriteHeaderNow;
	// call it so the 101 status is presumably emitted before hijacking.
	// See https://github.com/nhooyr/websocket/issues/166
	if ginWriter, ok := w.(interface {
		WriteHeaderNow()
	}); ok {
		ginWriter.WriteHeaderNow()
	}

	netConn, brw, err := hj.Hijack()
	if err != nil {
		err = fmt.Errorf("failed to hijack connection: %w", err)
		// NOTE(review): WriteHeader(101) has already been called above, so this
		// http.Error cannot change the status and is likely a superfluous
		// WriteHeader — confirm the intended behavior on this path.
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return nil, err
	}

	// Replay any bytes already buffered in brw.Reader ahead of further reads
	// from netConn, instead of reading through the server's original source.
	// https://github.com/golang/go/issues/32314
	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))

	return newConn(connConfig{
		// Read back from the response headers rather than using subproto; the two
		// differ only if Sec-WebSocket-Protocol was already set on w by the caller.
		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
		rwc:            netConn,
		client:         false,
		copts:          copts,
		flateThreshold: opts.CompressionThreshold,

		br: brw.Reader,
		bw: brw.Writer,
	}), nil
}
   161  
   162  func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
   163  	if !r.ProtoAtLeast(1, 1) {
   164  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
   165  	}
   166  
   167  	if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
   168  		w.Header().Set("Connection", "Upgrade")
   169  		w.Header().Set("Upgrade", "websocket")
   170  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
   171  	}
   172  
   173  	if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
   174  		w.Header().Set("Connection", "Upgrade")
   175  		w.Header().Set("Upgrade", "websocket")
   176  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
   177  	}
   178  
   179  	if r.Method != "GET" {
   180  		return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
   181  	}
   182  
   183  	if r.Header.Get("Sec-WebSocket-Version") != "13" {
   184  		w.Header().Set("Sec-WebSocket-Version", "13")
   185  		return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
   186  	}
   187  
   188  	websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
   189  	if len(websocketSecKeys) == 0 {
   190  		return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
   191  	}
   192  
   193  	if len(websocketSecKeys) > 1 {
   194  		return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
   195  	}
   196  
   197  	// The RFC states to remove any leading or trailing whitespace.
   198  	websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
   199  	if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
   200  		return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
   201  	}
   202  
   203  	return 0, nil
   204  }
   205  
   206  func authenticateOrigin(r *http.Request, originHosts []string) error {
   207  	origin := r.Header.Get("Origin")
   208  	if origin == "" {
   209  		return nil
   210  	}
   211  
   212  	u, err := url.Parse(origin)
   213  	if err != nil {
   214  		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
   215  	}
   216  
   217  	if strings.EqualFold(r.Host, u.Host) {
   218  		return nil
   219  	}
   220  
   221  	for _, hostPattern := range originHosts {
   222  		matched, err := match(hostPattern, u.Host)
   223  		if err != nil {
   224  			return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
   225  		}
   226  		if matched {
   227  			return nil
   228  		}
   229  	}
   230  	if u.Host == "" {
   231  		return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
   232  	}
   233  	return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
   234  }
   235  
   236  func match(pattern, s string) (bool, error) {
   237  	return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
   238  }
   239  
   240  func selectSubprotocol(r *http.Request, subprotocols []string) string {
   241  	cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
   242  	for _, sp := range subprotocols {
   243  		for _, cp := range cps {
   244  			if strings.EqualFold(sp, cp) {
   245  				return cp
   246  			}
   247  		}
   248  	}
   249  	return ""
   250  }
   251  
   252  func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
   253  	if mode == CompressionDisabled {
   254  		return nil, false
   255  	}
   256  	for _, ext := range extensions {
   257  		switch ext.name {
   258  		// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
   259  		// See https://github.com/nhooyr/websocket/issues/218
   260  		case "permessage-deflate":
   261  			copts, ok := acceptDeflate(ext, mode)
   262  			if ok {
   263  				return copts, true
   264  			}
   265  		}
   266  	}
   267  	return nil, false
   268  }
   269  
   270  func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
   271  	copts := mode.opts()
   272  	for _, p := range ext.params {
   273  		switch p {
   274  		case "client_no_context_takeover":
   275  			copts.clientNoContextTakeover = true
   276  			continue
   277  		case "server_no_context_takeover":
   278  			copts.serverNoContextTakeover = true
   279  			continue
   280  		case "client_max_window_bits",
   281  			"server_max_window_bits=15":
   282  			continue
   283  		}
   284  
   285  		if strings.HasPrefix(p, "client_max_window_bits=") {
   286  			// We can't adjust the deflate window, but decoding with a larger window is acceptable.
   287  			continue
   288  		}
   289  		return nil, false
   290  	}
   291  	return copts, true
   292  }
   293  
   294  func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
   295  	for _, t := range headerTokens(h, key) {
   296  		if strings.EqualFold(t, token) {
   297  			return true
   298  		}
   299  	}
   300  	return false
   301  }
   302  
// websocketExtension is one entry of a Sec-WebSocket-Extensions header, as
// produced by websocketExtensions. name is the text before the first ';'
// (e.g. "permessage-deflate"). params holds the remaining ';'-separated pieces,
// whitespace-trimmed but otherwise unparsed (e.g. "server_max_window_bits=15").
type websocketExtension struct {
	name   string
	params []string
}
   307  
   308  func websocketExtensions(h http.Header) []websocketExtension {
   309  	var exts []websocketExtension
   310  	extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
   311  	for _, extStr := range extStrs {
   312  		if extStr == "" {
   313  			continue
   314  		}
   315  
   316  		vals := strings.Split(extStr, ";")
   317  		for i := range vals {
   318  			vals[i] = strings.TrimSpace(vals[i])
   319  		}
   320  
   321  		e := websocketExtension{
   322  			name:   vals[0],
   323  			params: vals[1:],
   324  		}
   325  
   326  		exts = append(exts, e)
   327  	}
   328  	return exts
   329  }
   330  
   331  func headerTokens(h http.Header, key string) []string {
   332  	key = textproto.CanonicalMIMEHeaderKey(key)
   333  	var tokens []string
   334  	for _, v := range h[key] {
   335  		v = strings.TrimSpace(v)
   336  		for _, t := range strings.Split(v, ",") {
   337  			t = strings.TrimSpace(t)
   338  			tokens = append(tokens, t)
   339  		}
   340  	}
   341  	return tokens
   342  }
   343  
   344  var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
   345  
   346  func secWebSocketAccept(secWebSocketKey string) string {
   347  	h := sha1.New()
   348  	h.Write([]byte(secWebSocketKey))
   349  	h.Write(keyGUID)
   350  
   351  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
   352  }
   353  

View as plain text