...

Source file src/github.com/tmc/grpc-websocket-proxy/wsproxy/websocket_proxy.go

Documentation: github.com/tmc/grpc-websocket-proxy/wsproxy

     1  package wsproxy
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/gorilla/websocket"
    12  	"github.com/sirupsen/logrus"
    13  	"golang.org/x/net/context"
    14  )
    15  
    16  // MethodOverrideParam defines the special URL parameter that is translated into the subsequent proxied streaming http request's method.
    17  //
    18  // Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
    19  var MethodOverrideParam = "method"
    20  
    21  // TokenCookieName defines the cookie name that is translated to an 'Authorization: Bearer' header in the streaming http request's headers.
    22  //
    23  // Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
    24  var TokenCookieName = "token"
    25  
    26  // RequestMutatorFunc can supply an alternate outgoing request.
    27  type RequestMutatorFunc func(incoming *http.Request, outgoing *http.Request) *http.Request
    28  
// Proxy provides websocket transport upgrade to compatible endpoints.
//
// A Proxy is built by WebsocketProxy, which fills these fields from its
// defaults and then applies the supplied Options.
type Proxy struct {
	// h is the wrapped handler. It serves non-upgrade requests directly and
	// also serves the synthesized streaming request for each upgraded socket.
	h http.Handler

	// logger receives warning and debug output.
	logger Logger

	// maxRespBodyBufferBytes, when > 0, is passed as the max argument of the
	// bufio.Scanner reading the handler's response body.
	maxRespBodyBufferBytes int

	// methodOverrideParam names the URL query parameter whose value, if
	// non-empty, replaces the upstream request's method.
	methodOverrideParam string

	// tokenCookieName names the cookie whose value is sent upstream as
	// "Authorization: Bearer <value>".
	tokenCookieName string

	// requestMutator, if non-nil, may adjust or replace the upstream request.
	requestMutator RequestMutatorFunc

	// headerForwarder reports whether an incoming header should be copied
	// onto the upstream request.
	headerForwarder func(header string) bool

	// pingInterval is the period between pings. Ping control is only active
	// when pingInterval, pingWait and pongWait are all positive.
	pingInterval time.Duration

	// pingWait bounds how long writing a single ping may take.
	pingWait time.Duration

	// pongWait is the read deadline set initially and extended on every pong.
	pongWait time.Duration
}
    42  
    43  // Logger collects log messages.
    44  type Logger interface {
    45  	Warnln(...interface{})
    46  	Debugln(...interface{})
    47  }
    48  
    49  func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    50  	if !websocket.IsWebSocketUpgrade(r) {
    51  		p.h.ServeHTTP(w, r)
    52  		return
    53  	}
    54  	p.proxy(w, r)
    55  }
    56  
    57  // Option allows customization of the proxy.
    58  type Option func(*Proxy)
    59  
    60  // WithMaxRespBodyBufferSize allows specification of a custom size for the
    61  // buffer used while reading the response body. By default, the bufio.Scanner
    62  // used to read the response body sets the maximum token size to MaxScanTokenSize.
    63  func WithMaxRespBodyBufferSize(nBytes int) Option {
    64  	return func(p *Proxy) {
    65  		p.maxRespBodyBufferBytes = nBytes
    66  	}
    67  }
    68  
    69  // WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request.
    70  func WithMethodParamOverride(param string) Option {
    71  	return func(p *Proxy) {
    72  		p.methodOverrideParam = param
    73  	}
    74  }
    75  
    76  // WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header.
    77  func WithTokenCookieName(param string) Option {
    78  	return func(p *Proxy) {
    79  		p.tokenCookieName = param
    80  	}
    81  }
    82  
    83  // WithRequestMutator allows a custom RequestMutatorFunc to be supplied.
    84  func WithRequestMutator(fn RequestMutatorFunc) Option {
    85  	return func(p *Proxy) {
    86  		p.requestMutator = fn
    87  	}
    88  }
    89  
    90  // WithForwardedHeaders allows controlling which headers are forwarded.
    91  func WithForwardedHeaders(fn func(header string) bool) Option {
    92  	return func(p *Proxy) {
    93  		p.headerForwarder = fn
    94  	}
    95  }
    96  
    97  // WithLogger allows a custom FieldLogger to be supplied
    98  func WithLogger(logger Logger) Option {
    99  	return func(p *Proxy) {
   100  		p.logger = logger
   101  	}
   102  }
   103  
   104  // WithPingControl allows specification of ping pong control. The interval
   105  // parameter specifies the pingInterval between pings. The allowed wait time
   106  // for a pong response is (pingInterval * 10) / 9.
   107  func WithPingControl(interval time.Duration) Option {
   108  	return func(proxy *Proxy) {
   109  		proxy.pingInterval = interval
   110  		proxy.pongWait = (interval * 10) / 9
   111  		proxy.pingWait = proxy.pongWait / 6
   112  	}
   113  }
   114  
   115  var defaultHeadersToForward = map[string]bool{
   116  	"Origin":  true,
   117  	"origin":  true,
   118  	"Referer": true,
   119  	"referer": true,
   120  }
   121  
   122  func defaultHeaderForwarder(header string) bool {
   123  	return defaultHeadersToForward[header]
   124  }
   125  
   126  // WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited
   127  // JSON as the content encoding.
   128  //
   129  // The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie.
   130  // The cookie name is specified by the TokenCookieName value.
   131  //
   132  // example:
   133  //   Sec-Websocket-Protocol: Bearer, foobar
   134  // is converted to:
   135  //   Authorization: Bearer foobar
   136  //
   137  // Method can be overwritten with the MethodOverrideParam get parameter in the requested URL
   138  func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
   139  	p := &Proxy{
   140  		h:                   h,
   141  		logger:              logrus.New(),
   142  		methodOverrideParam: MethodOverrideParam,
   143  		tokenCookieName:     TokenCookieName,
   144  		headerForwarder:     defaultHeaderForwarder,
   145  	}
   146  	for _, o := range opts {
   147  		o(p)
   148  	}
   149  	return p
   150  }
   151  
   152  // TODO(tmc): allow modification of upgrader settings?
   153  var upgrader = websocket.Upgrader{
   154  	ReadBufferSize:  1024,
   155  	WriteBufferSize: 1024,
   156  	CheckOrigin:     func(r *http.Request) bool { return true },
   157  }
   158  
   159  func isClosedConnError(err error) bool {
   160  	str := err.Error()
   161  	if strings.Contains(str, "use of closed network connection") {
   162  		return true
   163  	}
   164  	return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
   165  }
   166  
   167  func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
   168  	var responseHeader http.Header
   169  	// If Sec-WebSocket-Protocol starts with "Bearer", respond in kind.
   170  	// TODO(tmc): consider customizability/extension point here.
   171  	if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
   172  		responseHeader = http.Header{
   173  			"Sec-WebSocket-Protocol": []string{"Bearer"},
   174  		}
   175  	}
   176  	conn, err := upgrader.Upgrade(w, r, responseHeader)
   177  	if err != nil {
   178  		p.logger.Warnln("error upgrading websocket:", err)
   179  		return
   180  	}
   181  	defer conn.Close()
   182  
   183  	ctx, cancelFn := context.WithCancel(context.Background())
   184  	defer cancelFn()
   185  
   186  	requestBodyR, requestBodyW := io.Pipe()
   187  	request, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), requestBodyR)
   188  	if err != nil {
   189  		p.logger.Warnln("error preparing request:", err)
   190  		return
   191  	}
   192  	if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" {
   193  		request.Header.Set("Authorization", transformSubProtocolHeader(swsp))
   194  	}
   195  	for header := range r.Header {
   196  		if p.headerForwarder(header) {
   197  			request.Header.Set(header, r.Header.Get(header))
   198  		}
   199  	}
   200  	// If token cookie is present, populate Authorization header from the cookie instead.
   201  	if cookie, err := r.Cookie(p.tokenCookieName); err == nil {
   202  		request.Header.Set("Authorization", "Bearer "+cookie.Value)
   203  	}
   204  	if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
   205  		request.Method = m
   206  	}
   207  
   208  	if p.requestMutator != nil {
   209  		request = p.requestMutator(r, request)
   210  	}
   211  
   212  	responseBodyR, responseBodyW := io.Pipe()
   213  	response := newInMemoryResponseWriter(responseBodyW)
   214  	go func() {
   215  		<-ctx.Done()
   216  		p.logger.Debugln("closing pipes")
   217  		requestBodyW.CloseWithError(io.EOF)
   218  		responseBodyW.CloseWithError(io.EOF)
   219  		response.closed <- true
   220  	}()
   221  
   222  	go func() {
   223  		defer cancelFn()
   224  		p.h.ServeHTTP(response, request)
   225  	}()
   226  
   227  	// read loop -- take messages from websocket and write to http request
   228  	go func() {
   229  		if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
   230  			conn.SetReadDeadline(time.Now().Add(p.pongWait))
   231  			conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil })
   232  		}
   233  		defer func() {
   234  			cancelFn()
   235  		}()
   236  		for {
   237  			select {
   238  			case <-ctx.Done():
   239  				p.logger.Debugln("read loop done")
   240  				return
   241  			default:
   242  			}
   243  			p.logger.Debugln("[read] reading from socket.")
   244  			_, payload, err := conn.ReadMessage()
   245  			if err != nil {
   246  				if isClosedConnError(err) {
   247  					p.logger.Debugln("[read] websocket closed:", err)
   248  					return
   249  				}
   250  				p.logger.Warnln("error reading websocket message:", err)
   251  				return
   252  			}
   253  			p.logger.Debugln("[read] read payload:", string(payload))
   254  			p.logger.Debugln("[read] writing to requestBody:")
   255  			n, err := requestBodyW.Write(payload)
   256  			requestBodyW.Write([]byte("\n"))
   257  			p.logger.Debugln("[read] wrote to requestBody", n)
   258  			if err != nil {
   259  				p.logger.Warnln("[read] error writing message to upstream http server:", err)
   260  				return
   261  			}
   262  		}
   263  	}()
   264  	// ping write loop
   265  	if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
   266  		go func() {
   267  			ticker := time.NewTicker(p.pingInterval)
   268  			defer func() {
   269  				ticker.Stop()
   270  				conn.Close()
   271  			}()
   272  			for {
   273  				select {
   274  				case <-ctx.Done():
   275  					p.logger.Debugln("ping loop done")
   276  					return
   277  				case <-ticker.C:
   278  					conn.SetWriteDeadline(time.Now().Add(p.pingWait))
   279  					if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
   280  						return
   281  					}
   282  				}
   283  			}
   284  		}()
   285  	}
   286  	// write loop -- take messages from response and write to websocket
   287  	scanner := bufio.NewScanner(responseBodyR)
   288  
   289  	// if maxRespBodyBufferSize has been specified, use custom buffer for scanner
   290  	var scannerBuf []byte
   291  	if p.maxRespBodyBufferBytes > 0 {
   292  		scannerBuf = make([]byte, 0, 64*1024)
   293  		scanner.Buffer(scannerBuf, p.maxRespBodyBufferBytes)
   294  	}
   295  
   296  	for scanner.Scan() {
   297  		if len(scanner.Bytes()) == 0 {
   298  			p.logger.Warnln("[write] empty scan", scanner.Err())
   299  			continue
   300  		}
   301  		p.logger.Debugln("[write] scanned", scanner.Text())
   302  		if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil {
   303  			p.logger.Warnln("[write] error writing websocket message:", err)
   304  			return
   305  		}
   306  	}
   307  	if err := scanner.Err(); err != nil {
   308  		p.logger.Warnln("scanner err:", err)
   309  	}
   310  }
   311  
   312  type inMemoryResponseWriter struct {
   313  	io.Writer
   314  	header http.Header
   315  	code   int
   316  	closed chan bool
   317  }
   318  
   319  func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter {
   320  	return &inMemoryResponseWriter{
   321  		Writer: w,
   322  		header: http.Header{},
   323  		closed: make(chan bool, 1),
   324  	}
   325  }
   326  
   327  // IE and Edge do not delimit Sec-WebSocket-Protocol strings with spaces
   328  func transformSubProtocolHeader(header string) string {
   329  	tokens := strings.SplitN(header, "Bearer,", 2)
   330  
   331  	if len(tokens) < 2 {
   332  		return ""
   333  	}
   334  
   335  	return fmt.Sprintf("Bearer %v", strings.Trim(tokens[1], " "))
   336  }
   337  
   338  func (w *inMemoryResponseWriter) Write(b []byte) (int, error) {
   339  	return w.Writer.Write(b)
   340  }
   341  func (w *inMemoryResponseWriter) Header() http.Header {
   342  	return w.header
   343  }
   344  func (w *inMemoryResponseWriter) WriteHeader(code int) {
   345  	w.code = code
   346  }
   347  func (w *inMemoryResponseWriter) CloseNotify() <-chan bool {
   348  	return w.closed
   349  }
   350  func (w *inMemoryResponseWriter) Flush() {}
   351  

View as plain text