...

Source file src/github.com/99designs/gqlgen/graphql/handler/transport/websocket.go

Documentation: github.com/99designs/gqlgen/graphql/handler/transport

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/gorilla/websocket"
    16  	"github.com/vektah/gqlparser/v2/gqlerror"
    17  
    18  	"github.com/99designs/gqlgen/graphql"
    19  	"github.com/99designs/gqlgen/graphql/errcode"
    20  )
    21  
    22  type (
    23  	Websocket struct {
    24  		Upgrader              websocket.Upgrader
    25  		InitFunc              WebsocketInitFunc
    26  		InitTimeout           time.Duration
    27  		ErrorFunc             WebsocketErrorFunc
    28  		CloseFunc             WebsocketCloseFunc
    29  		KeepAlivePingInterval time.Duration
    30  		PongOnlyInterval      time.Duration
    31  		PingPongInterval      time.Duration
    32  		/* If PingPongInterval has a non-0 duration, then when the server sends a ping
    33  		 * it sets a ReadDeadline of PingPongInterval*2 and if the client doesn't respond
    34  		 * with pong before that deadline is reached then the connection will die with a
    35  		 * 1006 error code.
    36  		 *
    37  		 * MissingPongOk if true, tells the server to not use a ReadDeadline such that a
    38  		 * missing/slow pong response from the client doesn't kill the connection.
    39  		 */
    40  		MissingPongOk bool
    41  
    42  		didInjectSubprotocols bool
    43  	}
    44  	wsConnection struct {
    45  		Websocket
    46  		ctx             context.Context
    47  		conn            *websocket.Conn
    48  		me              messageExchanger
    49  		active          map[string]context.CancelFunc
    50  		mu              sync.Mutex
    51  		keepAliveTicker *time.Ticker
    52  		pongOnlyTicker  *time.Ticker
    53  		pingPongTicker  *time.Ticker
    54  		receivedPong    bool
    55  		exec            graphql.GraphExecutor
    56  		closed          bool
    57  
    58  		initPayload InitPayload
    59  	}
    60  
    61  	WebsocketInitFunc  func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error)
    62  	WebsocketErrorFunc func(ctx context.Context, err error)
    63  
    64  	// Callback called when websocket is closed.
    65  	WebsocketCloseFunc func(ctx context.Context, closeCode int)
    66  )
    67  
    68  var errReadTimeout = errors.New("read timeout")
    69  
    70  type WebsocketError struct {
    71  	Err error
    72  
    73  	// IsReadError flags whether the error occurred on read or write to the websocket
    74  	IsReadError bool
    75  }
    76  
    77  func (e WebsocketError) Error() string {
    78  	if e.IsReadError {
    79  		return fmt.Sprintf("websocket read: %v", e.Err)
    80  	}
    81  	return fmt.Sprintf("websocket write: %v", e.Err)
    82  }
    83  
    84  var (
    85  	_ graphql.Transport = Websocket{}
    86  	_ error             = WebsocketError{}
    87  )
    88  
    89  func (t Websocket) Supports(r *http.Request) bool {
    90  	return r.Header.Get("Upgrade") != ""
    91  }
    92  
    93  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    94  	t.injectGraphQLWSSubprotocols()
    95  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
    96  	if err != nil {
    97  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    98  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    99  		return
   100  	}
   101  
   102  	var me messageExchanger
   103  	switch ws.Subprotocol() {
   104  	default:
   105  		msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
   106  		ws.WriteMessage(websocket.CloseMessage, msg)
   107  		return
   108  	case graphqlwsSubprotocol, "":
   109  		// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
   110  		// "graphql-ws" by default
   111  		me = graphqlwsMessageExchanger{c: ws}
   112  	case graphqltransportwsSubprotocol:
   113  		me = graphqltransportwsMessageExchanger{c: ws}
   114  	}
   115  
   116  	conn := wsConnection{
   117  		active:    map[string]context.CancelFunc{},
   118  		conn:      ws,
   119  		ctx:       r.Context(),
   120  		exec:      exec,
   121  		me:        me,
   122  		Websocket: t,
   123  	}
   124  
   125  	if !conn.init() {
   126  		return
   127  	}
   128  
   129  	conn.run()
   130  }
   131  
   132  func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
   133  	if c.ErrorFunc != nil && err != nil {
   134  		c.ErrorFunc(c.ctx, WebsocketError{
   135  			Err:         err,
   136  			IsReadError: isReadError,
   137  		})
   138  	}
   139  }
   140  
   141  func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
   142  	messages, errs := make(chan message, 1), make(chan error, 1)
   143  
   144  	go func() {
   145  		if m, err := c.me.NextMessage(); err != nil {
   146  			errs <- err
   147  		} else {
   148  			messages <- m
   149  		}
   150  	}()
   151  
   152  	select {
   153  	case m := <-messages:
   154  		return m, nil
   155  	case err := <-errs:
   156  		return message{}, err
   157  	case <-time.After(timeout):
   158  		return message{}, errReadTimeout
   159  	}
   160  }
   161  
   162  func (c *wsConnection) init() bool {
   163  	var m message
   164  	var err error
   165  
   166  	if c.InitTimeout != 0 {
   167  		m, err = c.nextMessageWithTimeout(c.InitTimeout)
   168  	} else {
   169  		m, err = c.me.NextMessage()
   170  	}
   171  
   172  	if err != nil {
   173  		if err == errReadTimeout {
   174  			c.close(websocket.CloseProtocolError, "connection initialisation timeout")
   175  			return false
   176  		}
   177  
   178  		if err == errInvalidMsg {
   179  			c.sendConnectionError("invalid json")
   180  		}
   181  
   182  		c.close(websocket.CloseProtocolError, "decoding error")
   183  		return false
   184  	}
   185  
   186  	switch m.t {
   187  	case initMessageType:
   188  		if len(m.payload) > 0 {
   189  			c.initPayload = make(InitPayload)
   190  			err := json.Unmarshal(m.payload, &c.initPayload)
   191  			if err != nil {
   192  				return false
   193  			}
   194  		}
   195  
   196  		var initAckPayload *InitPayload = nil
   197  		if c.InitFunc != nil {
   198  			var ctx context.Context
   199  			ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload)
   200  			if err != nil {
   201  				c.sendConnectionError(err.Error())
   202  				c.close(websocket.CloseNormalClosure, "terminated")
   203  				return false
   204  			}
   205  			c.ctx = ctx
   206  		}
   207  
   208  		if initAckPayload != nil {
   209  			initJsonAckPayload, err := json.Marshal(*initAckPayload)
   210  			if err != nil {
   211  				panic(err)
   212  			}
   213  			c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload})
   214  		} else {
   215  			c.write(&message{t: connectionAckMessageType})
   216  		}
   217  		c.write(&message{t: keepAliveMessageType})
   218  	case connectionCloseMessageType:
   219  		c.close(websocket.CloseNormalClosure, "terminated")
   220  		return false
   221  	default:
   222  		c.sendConnectionError("unexpected message %s", m.t)
   223  		c.close(websocket.CloseProtocolError, "unexpected message")
   224  		return false
   225  	}
   226  
   227  	return true
   228  }
   229  
   230  func (c *wsConnection) write(msg *message) {
   231  	c.mu.Lock()
   232  	c.handlePossibleError(c.me.Send(msg), false)
   233  	c.mu.Unlock()
   234  }
   235  
// run serves the connection after a successful init.
//
// It first starts the background goroutines for the negotiated subprotocol:
//   - graphql-ws: periodic keep-alive messages;
//   - graphql-transport-ws: periodic unsolicited pongs and/or pings.
//
// It also starts a goroutine that closes the connection when the run context
// is cancelled. It then reads and dispatches client messages until a read
// fails, the client asks to close, or an unexpected message type arrives.
// Returning cancels the run context, which stops all of those goroutines.
func (c *wsConnection) run() {
	// We create a cancellation that will shutdown the keep-alive when we leave
	// this function.
	ctx, cancel := context.WithCancel(c.ctx)
	defer func() {
		// cancel() also wakes closeOnCancel, which races this close call for
		// c.mu. close is idempotent, so whichever runs first picks the close code.
		// NOTE(review): RFC 6455 reserves 1006 (CloseAbnormalClosure) and says it
		// must not be sent in a close frame — confirm whether a different code is
		// intended here.
		cancel()
		c.close(websocket.CloseAbnormalClosure, "unexpected closure")
	}()

	// If we're running in graphql-ws mode, create a timer that will trigger a
	// keep alive message every interval
	if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
		c.mu.Lock()
		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
		c.mu.Unlock()

		go c.keepAlive(ctx)
	}

	// If we're running in graphql-transport-ws mode, create a timer that will trigger a
	// just a pong message every interval
	if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PongOnlyInterval != 0 {
		c.mu.Lock()
		c.pongOnlyTicker = time.NewTicker(c.PongOnlyInterval)
		c.mu.Unlock()

		go c.keepAlivePongOnly(ctx)
	}

	// If we're running in graphql-transport-ws mode, create a timer that will
	// trigger a ping message every interval and expect a pong!
	if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
		c.mu.Lock()
		c.pingPongTicker = time.NewTicker(c.PingPongInterval)
		c.mu.Unlock()

		if !c.MissingPongOk {
			// Initial deadline: reads fail unless something (a pong clears it,
			// see below) arrives within 2*PingPongInterval. The ping goroutine
			// extends it only after a pong has been seen.
			// Note: when the connection is closed by this deadline, the client
			// will receive an "invalid close code"
			c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
		}
		go c.ping(ctx)
	}

	// Close the connection when the context is cancelled.
	// Will optionally send a "close reason" that is retrieved from the context.
	go c.closeOnCancel(ctx)

	for {
		// Taken before the blocking read; subscribe records it as ReadTime.Start.
		start := graphql.Now()
		m, err := c.me.NextMessage()
		if err != nil {
			// If the connection got closed by us, don't report the error
			if !errors.Is(err, net.ErrClosed) {
				c.handlePossibleError(err, true)
			}
			return
		}

		switch m.t {
		case startMessageType:
			c.subscribe(start, &m)
		case stopMessageType:
			// Cancel the operation registered under this id, if it is still active.
			c.mu.Lock()
			closer := c.active[m.id]
			c.mu.Unlock()
			if closer != nil {
				closer()
			}
		case connectionCloseMessageType:
			c.close(websocket.CloseNormalClosure, "terminated")
			return
		case pingMessageType:
			// Client-initiated ping: echo its payload back in a pong.
			c.write(&message{t: pongMessageType, payload: m.payload})
		case pongMessageType:
			// receivedPong is read and reset by the ping goroutine under c.mu.
			c.mu.Lock()
			c.receivedPong = true
			c.mu.Unlock()
			// Clear ReadTimeout -- 0 time val clears.
			c.conn.SetReadDeadline(time.Time{})
		default:
			c.sendConnectionError("unexpected message %s", m.t)
			c.close(websocket.CloseProtocolError, "unexpected message")
			return
		}
	}
}
   323  
   324  func (c *wsConnection) keepAlivePongOnly(ctx context.Context) {
   325  	for {
   326  		select {
   327  		case <-ctx.Done():
   328  			c.pongOnlyTicker.Stop()
   329  			return
   330  		case <-c.pongOnlyTicker.C:
   331  			c.write(&message{t: pongMessageType, payload: json.RawMessage{}})
   332  		}
   333  	}
   334  }
   335  
   336  func (c *wsConnection) keepAlive(ctx context.Context) {
   337  	for {
   338  		select {
   339  		case <-ctx.Done():
   340  			c.keepAliveTicker.Stop()
   341  			return
   342  		case <-c.keepAliveTicker.C:
   343  			c.write(&message{t: keepAliveMessageType})
   344  		}
   345  	}
   346  }
   347  
   348  func (c *wsConnection) ping(ctx context.Context) {
   349  	for {
   350  		select {
   351  		case <-ctx.Done():
   352  			c.pingPongTicker.Stop()
   353  			return
   354  		case <-c.pingPongTicker.C:
   355  			c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
   356  			// The initial deadline for this method is set in run()
   357  			// if we have not yet received a pong, don't reset the deadline.
   358  			c.mu.Lock()
   359  			if !c.MissingPongOk && c.receivedPong {
   360  				c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   361  			}
   362  			c.receivedPong = false
   363  			c.mu.Unlock()
   364  		}
   365  	}
   366  }
   367  
   368  func (c *wsConnection) closeOnCancel(ctx context.Context) {
   369  	<-ctx.Done()
   370  
   371  	if r := closeReasonForContext(ctx); r != "" {
   372  		c.sendConnectionError(r)
   373  	}
   374  	c.close(websocket.CloseNormalClosure, "terminated")
   375  }
   376  
   377  func (c *wsConnection) subscribe(start time.Time, msg *message) {
   378  	ctx := graphql.StartOperationTrace(c.ctx)
   379  	var params *graphql.RawParams
   380  	if err := jsonDecode(bytes.NewReader(msg.payload), &params); err != nil {
   381  		c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
   382  		c.complete(msg.id)
   383  		return
   384  	}
   385  
   386  	params.ReadTime = graphql.TraceTiming{
   387  		Start: start,
   388  		End:   graphql.Now(),
   389  	}
   390  
   391  	rc, err := c.exec.CreateOperationContext(ctx, params)
   392  	if err != nil {
   393  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   394  		switch errcode.GetErrorKind(err) {
   395  		case errcode.KindProtocol:
   396  			c.sendError(msg.id, resp.Errors...)
   397  		default:
   398  			c.sendResponse(msg.id, &graphql.Response{Errors: err})
   399  		}
   400  
   401  		c.complete(msg.id)
   402  		return
   403  	}
   404  
   405  	ctx = graphql.WithOperationContext(ctx, rc)
   406  
   407  	if c.initPayload != nil {
   408  		ctx = withInitPayload(ctx, c.initPayload)
   409  	}
   410  
   411  	ctx, cancel := context.WithCancel(ctx)
   412  	c.mu.Lock()
   413  	c.active[msg.id] = cancel
   414  	c.mu.Unlock()
   415  
   416  	go func() {
   417  		ctx = withSubscriptionErrorContext(ctx)
   418  		defer func() {
   419  			if r := recover(); r != nil {
   420  				err := rc.Recover(ctx, r)
   421  				var gqlerr *gqlerror.Error
   422  				if !errors.As(err, &gqlerr) {
   423  					gqlerr = &gqlerror.Error{}
   424  					if err != nil {
   425  						gqlerr.Message = err.Error()
   426  					}
   427  				}
   428  				c.sendError(msg.id, gqlerr)
   429  			}
   430  			if errs := getSubscriptionError(ctx); len(errs) != 0 {
   431  				c.sendError(msg.id, errs...)
   432  			} else {
   433  				c.complete(msg.id)
   434  			}
   435  			c.mu.Lock()
   436  			delete(c.active, msg.id)
   437  			c.mu.Unlock()
   438  			cancel()
   439  		}()
   440  
   441  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   442  		for {
   443  			response := responses(ctx)
   444  			if response == nil {
   445  				break
   446  			}
   447  
   448  			c.sendResponse(msg.id, response)
   449  		}
   450  
   451  		// complete and context cancel comes from the defer
   452  	}()
   453  }
   454  
   455  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   456  	b, err := json.Marshal(response)
   457  	if err != nil {
   458  		panic(err)
   459  	}
   460  	c.write(&message{
   461  		payload: b,
   462  		id:      id,
   463  		t:       dataMessageType,
   464  	})
   465  }
   466  
   467  func (c *wsConnection) complete(id string) {
   468  	c.write(&message{id: id, t: completeMessageType})
   469  }
   470  
   471  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   472  	errs := make([]error, len(errors))
   473  	for i, err := range errors {
   474  		errs[i] = err
   475  	}
   476  	b, err := json.Marshal(errs)
   477  	if err != nil {
   478  		panic(err)
   479  	}
   480  	c.write(&message{t: errorMessageType, id: id, payload: b})
   481  }
   482  
   483  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   484  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   485  	if err != nil {
   486  		panic(err)
   487  	}
   488  
   489  	c.write(&message{t: connectionErrorMessageType, payload: b})
   490  }
   491  
   492  func (c *wsConnection) close(closeCode int, message string) {
   493  	c.mu.Lock()
   494  	if c.closed {
   495  		c.mu.Unlock()
   496  		return
   497  	}
   498  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   499  	for _, closer := range c.active {
   500  		closer()
   501  	}
   502  	c.closed = true
   503  	c.mu.Unlock()
   504  	_ = c.conn.Close()
   505  
   506  	if c.CloseFunc != nil {
   507  		c.CloseFunc(c.ctx, closeCode)
   508  	}
   509  }
   510  

View as plain text