...

Source file src/github.com/99designs/gqlgen/graphql/handler/transport/websocket_test.go

Documentation: github.com/99designs/gqlgen/graphql/handler/transport

     1  package transport_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  	"github.com/vektah/gqlparser/v2"
    17  	"github.com/vektah/gqlparser/v2/ast"
    18  
    19  	"github.com/99designs/gqlgen/client"
    20  	"github.com/99designs/gqlgen/graphql"
    21  	"github.com/99designs/gqlgen/graphql/handler"
    22  	"github.com/99designs/gqlgen/graphql/handler/testserver"
    23  	"github.com/99designs/gqlgen/graphql/handler/transport"
    24  )
    25  
// ckey is a private context-key type used by the InitFunc tests when they
// attach values with context.WithValue, so those values cannot collide with
// plain string keys set by other packages.
type ckey string
    27  
    28  func TestWebsocket(t *testing.T) {
    29  	handler := testserver.New()
    30  	handler.AddTransport(transport.Websocket{})
    31  
    32  	srv := httptest.NewServer(handler)
    33  	defer srv.Close()
    34  
    35  	t.Run("client must send valid json", func(t *testing.T) {
    36  		c := wsConnect(srv.URL)
    37  		defer c.Close()
    38  
    39  		writeRaw(c, "hello")
    40  
    41  		msg := readOp(c)
    42  		assert.Equal(t, "connection_error", msg.Type)
    43  		assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    44  	})
    45  
    46  	t.Run("client can terminate before init", func(t *testing.T) {
    47  		c := wsConnect(srv.URL)
    48  		defer c.Close()
    49  
    50  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    51  
    52  		_, _, err := c.ReadMessage()
    53  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    54  	})
    55  
    56  	t.Run("client must send init first", func(t *testing.T) {
    57  		c := wsConnect(srv.URL)
    58  		defer c.Close()
    59  
    60  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    61  
    62  		msg := readOp(c)
    63  		assert.Equal(t, connectionErrorMsg, msg.Type)
    64  		assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    65  	})
    66  
    67  	t.Run("server acks init", func(t *testing.T) {
    68  		c := wsConnect(srv.URL)
    69  		defer c.Close()
    70  
    71  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    72  
    73  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    74  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    75  	})
    76  
    77  	t.Run("client can terminate before run", func(t *testing.T) {
    78  		c := wsConnect(srv.URL)
    79  		defer c.Close()
    80  
    81  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    82  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    83  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    84  
    85  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    86  
    87  		_, _, err := c.ReadMessage()
    88  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    89  	})
    90  
    91  	t.Run("client gets parse errors", func(t *testing.T) {
    92  		c := wsConnect(srv.URL)
    93  		defer c.Close()
    94  
    95  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    96  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    97  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    98  
    99  		require.NoError(t, c.WriteJSON(&operationMessage{
   100  			Type:    startMsg,
   101  			ID:      "test_1",
   102  			Payload: json.RawMessage(`{"query": "!"}`),
   103  		}))
   104  
   105  		msg := readOp(c)
   106  		assert.Equal(t, errorMsg, msg.Type)
   107  		assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
   108  	})
   109  
   110  	t.Run("client can receive data", func(t *testing.T) {
   111  		c := wsConnect(srv.URL)
   112  		defer c.Close()
   113  
   114  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   115  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   116  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   117  
   118  		require.NoError(t, c.WriteJSON(&operationMessage{
   119  			Type:    startMsg,
   120  			ID:      "test_1",
   121  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   122  		}))
   123  
   124  		handler.SendNextSubscriptionMessage()
   125  		msg := readOp(c)
   126  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   127  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   128  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   129  
   130  		handler.SendNextSubscriptionMessage()
   131  		msg = readOp(c)
   132  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   133  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   134  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   135  
   136  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   137  
   138  		msg = readOp(c)
   139  		require.Equal(t, completeMsg, msg.Type)
   140  		require.Equal(t, "test_1", msg.ID)
   141  
   142  		// At this point we should be done and should not receive another message.
   143  		c.SetReadDeadline(time.Now().UTC().Add(1 * time.Millisecond))
   144  
   145  		err := c.ReadJSON(&msg)
   146  		if err == nil {
   147  			// This should not send a second close message for the same id.
   148  			require.NotEqual(t, completeMsg, msg.Type)
   149  			require.NotEqual(t, "test_1", msg.ID)
   150  		} else {
   151  			assert.Contains(t, err.Error(), "timeout")
   152  		}
   153  	})
   154  }
   155  
   156  func TestWebsocketWithKeepAlive(t *testing.T) {
   157  	h := testserver.New()
   158  	h.AddTransport(transport.Websocket{
   159  		KeepAlivePingInterval: 100 * time.Millisecond,
   160  	})
   161  
   162  	srv := httptest.NewServer(h)
   163  	defer srv.Close()
   164  
   165  	c := wsConnect(srv.URL)
   166  	defer c.Close()
   167  
   168  	require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   169  	assert.Equal(t, connectionAckMsg, readOp(c).Type)
   170  	assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   171  
   172  	require.NoError(t, c.WriteJSON(&operationMessage{
   173  		Type:    startMsg,
   174  		ID:      "test_1",
   175  		Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   176  	}))
   177  
   178  	// keepalive
   179  	msg := readOp(c)
   180  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   181  
   182  	// server message
   183  	h.SendNextSubscriptionMessage()
   184  	msg = readOp(c)
   185  	assert.Equal(t, dataMsg, msg.Type)
   186  
   187  	// keepalive
   188  	msg = readOp(c)
   189  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   190  }
   191  
   192  func TestWebsocketInitFunc(t *testing.T) {
   193  	t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
   194  		h := testserver.New()
   195  		h.AddTransport(transport.Websocket{})
   196  		srv := httptest.NewServer(h)
   197  		defer srv.Close()
   198  
   199  		c := wsConnect(srv.URL)
   200  		defer c.Close()
   201  
   202  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   203  
   204  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   205  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   206  	})
   207  
   208  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   209  		h := testserver.New()
   210  		h.AddTransport(transport.Websocket{
   211  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   212  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
   213  			},
   214  		})
   215  		srv := httptest.NewServer(h)
   216  		defer srv.Close()
   217  
   218  		c := wsConnect(srv.URL)
   219  		defer c.Close()
   220  
   221  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   222  
   223  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   224  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   225  	})
   226  
   227  	t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   228  		h := testserver.New()
   229  		h.AddTransport(transport.Websocket{
   230  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   231  				return ctx, nil, errors.New("invalid init payload")
   232  			},
   233  		})
   234  		srv := httptest.NewServer(h)
   235  		defer srv.Close()
   236  
   237  		c := wsConnect(srv.URL)
   238  		defer c.Close()
   239  
   240  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   241  
   242  		msg := readOp(c)
   243  		assert.Equal(t, connectionErrorMsg, msg.Type)
   244  		assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
   245  	})
   246  
   247  	t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
   248  		es := &graphql.ExecutableSchemaMock{
   249  			ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
   250  				assert.Equal(t, "newvalue", ctx.Value(ckey("newkey")))
   251  				return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
   252  			},
   253  			SchemaFunc: func() *ast.Schema {
   254  				return gqlparser.MustLoadSchema(&ast.Source{Input: `
   255  				schema { query: Query }
   256  				type Query {
   257  					empty: String
   258  				}
   259  			`})
   260  			},
   261  		}
   262  		h := handler.New(es)
   263  
   264  		h.AddTransport(transport.Websocket{
   265  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   266  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
   267  			},
   268  		})
   269  
   270  		c := client.New(h)
   271  
   272  		socket := c.Websocket("{ empty } ")
   273  		defer socket.Close()
   274  		var resp struct {
   275  			Empty string
   276  		}
   277  		err := socket.Next(&resp)
   278  		require.NoError(t, err)
   279  		assert.Equal(t, "ok", resp.Empty)
   280  	})
   281  
   282  	t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) {
   283  		h := testserver.New()
   284  		var cancel func()
   285  		h.AddTransport(transport.Websocket{
   286  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
   287  				newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
   288  				return
   289  			},
   290  		})
   291  		srv := httptest.NewServer(h)
   292  		defer srv.Close()
   293  
   294  		c := wsConnect(srv.URL)
   295  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   296  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   297  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   298  
   299  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   300  		defer cancel()
   301  
   302  		time.Sleep(time.Millisecond * 10)
   303  		m := readOp(c)
   304  		assert.Equal(t, m.Type, connectionErrorMsg)
   305  		assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
   306  	})
   307  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   308  		h := testserver.New()
   309  		h.AddTransport(transport.Websocket{
   310  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   311  				initResponsePayload := transport.InitPayload{"trackingId": "123-456"}
   312  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), &initResponsePayload, nil
   313  			},
   314  		})
   315  		srv := httptest.NewServer(h)
   316  		defer srv.Close()
   317  
   318  		c := wsConnect(srv.URL)
   319  		defer c.Close()
   320  
   321  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   322  
   323  		connAck := readOp(c)
   324  		assert.Equal(t, connectionAckMsg, connAck.Type)
   325  
   326  		var payload map[string]interface{}
   327  		err := json.Unmarshal(connAck.Payload, &payload)
   328  		if err != nil {
   329  			t.Fatal("Unexpected Error", err)
   330  		}
   331  		assert.EqualValues(t, "123-456", payload["trackingId"])
   332  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   333  	})
   334  }
   335  
   336  func TestWebSocketInitTimeout(t *testing.T) {
   337  	t.Run("times out if no init message is received within the configured duration", func(t *testing.T) {
   338  		h := testserver.New()
   339  		h.AddTransport(transport.Websocket{
   340  			InitTimeout: 5 * time.Millisecond,
   341  		})
   342  		srv := httptest.NewServer(h)
   343  		defer srv.Close()
   344  
   345  		c := wsConnect(srv.URL)
   346  		defer c.Close()
   347  
   348  		var msg operationMessage
   349  		err := c.ReadJSON(&msg)
   350  		assert.Error(t, err)
   351  		assert.Contains(t, err.Error(), "timeout")
   352  	})
   353  
   354  	t.Run("keeps waiting for an init message if no time out is configured", func(t *testing.T) {
   355  		h := testserver.New()
   356  		h.AddTransport(transport.Websocket{})
   357  		srv := httptest.NewServer(h)
   358  		defer srv.Close()
   359  
   360  		c := wsConnect(srv.URL)
   361  		defer c.Close()
   362  
   363  		done := make(chan interface{}, 1)
   364  		go func() {
   365  			var msg operationMessage
   366  			_ = c.ReadJSON(&msg)
   367  			done <- 1
   368  		}()
   369  
   370  		select {
   371  		case <-done:
   372  			assert.Fail(t, "web socket read operation finished while it shouldn't have")
   373  		case <-time.After(100 * time.Millisecond):
   374  			// Success! I guess? Can't really wait forever to see if the read waits forever...
   375  		}
   376  	})
   377  }
   378  
   379  func TestWebSocketErrorFunc(t *testing.T) {
   380  	t.Run("the error handler gets called when an error occurs", func(t *testing.T) {
   381  		errFuncCalled := make(chan bool, 1)
   382  		h := testserver.New()
   383  		h.AddTransport(transport.Websocket{
   384  			ErrorFunc: func(_ context.Context, err error) {
   385  				require.Error(t, err)
   386  				assert.Equal(t, err.Error(), "websocket read: invalid message received")
   387  				assert.IsType(t, transport.WebsocketError{}, err)
   388  				assert.True(t, err.(transport.WebsocketError).IsReadError)
   389  				errFuncCalled <- true
   390  			},
   391  		})
   392  
   393  		srv := httptest.NewServer(h)
   394  		defer srv.Close()
   395  
   396  		c := wsConnect(srv.URL)
   397  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   398  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   399  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   400  		require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte("mark my words, you will regret this")))
   401  
   402  		select {
   403  		case res := <-errFuncCalled:
   404  			assert.True(t, res)
   405  		case <-time.NewTimer(time.Millisecond * 20).C:
   406  			assert.Fail(t, "The fail handler was not called in time")
   407  		}
   408  	})
   409  
   410  	t.Run("init func errors do not call the error handler", func(t *testing.T) {
   411  		h := testserver.New()
   412  		h.AddTransport(transport.Websocket{
   413  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   414  				return ctx, nil, errors.New("this is not what we agreed upon")
   415  			},
   416  			ErrorFunc: func(_ context.Context, err error) {
   417  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   418  			},
   419  		})
   420  		srv := httptest.NewServer(h)
   421  		defer srv.Close()
   422  
   423  		c := wsConnect(srv.URL)
   424  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   425  		time.Sleep(time.Millisecond * 20)
   426  	})
   427  
   428  	t.Run("init func context closes do not call the error handler", func(t *testing.T) {
   429  		h := testserver.New()
   430  		h.AddTransport(transport.Websocket{
   431  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   432  				newCtx, cancel := context.WithCancel(ctx)
   433  				time.AfterFunc(time.Millisecond*5, cancel)
   434  				return newCtx, nil, nil
   435  			},
   436  			ErrorFunc: func(_ context.Context, err error) {
   437  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   438  			},
   439  		})
   440  		srv := httptest.NewServer(h)
   441  		defer srv.Close()
   442  
   443  		c := wsConnect(srv.URL)
   444  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   445  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   446  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   447  		time.Sleep(time.Millisecond * 20)
   448  	})
   449  
   450  	t.Run("init func context deadlines do not call the error handler", func(t *testing.T) {
   451  		h := testserver.New()
   452  		var cancel func()
   453  		h.AddTransport(transport.Websocket{
   454  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
   455  				newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
   456  				return newCtx, nil, nil
   457  			},
   458  			ErrorFunc: func(_ context.Context, err error) {
   459  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   460  			},
   461  		})
   462  		srv := httptest.NewServer(h)
   463  		defer srv.Close()
   464  
   465  		c := wsConnect(srv.URL)
   466  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   467  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   468  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   469  
   470  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   471  		defer cancel()
   472  
   473  		time.Sleep(time.Millisecond * 20)
   474  	})
   475  }
   476  
   477  func TestWebSocketCloseFunc(t *testing.T) {
   478  	t.Run("the on close handler gets called when the websocket is closed", func(t *testing.T) {
   479  		closeFuncCalled := make(chan bool, 1)
   480  		h := testserver.New()
   481  		h.AddTransport(transport.Websocket{
   482  			CloseFunc: func(_ context.Context, _closeCode int) {
   483  				closeFuncCalled <- true
   484  			},
   485  		})
   486  
   487  		srv := httptest.NewServer(h)
   488  		defer srv.Close()
   489  
   490  		c := wsConnect(srv.URL)
   491  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   492  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   493  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   494  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
   495  
   496  		select {
   497  		case res := <-closeFuncCalled:
   498  			assert.True(t, res)
   499  		case <-time.NewTimer(time.Millisecond * 20).C:
   500  			assert.Fail(t, "The close handler was not called in time")
   501  		}
   502  	})
   503  
   504  	t.Run("the on close handler gets called only once when the websocket is closed", func(t *testing.T) {
   505  		closeFuncCalled := make(chan bool, 1)
   506  		h := testserver.New()
   507  		h.AddTransport(transport.Websocket{
   508  			CloseFunc: func(_ context.Context, _closeCode int) {
   509  				closeFuncCalled <- true
   510  			},
   511  		})
   512  
   513  		srv := httptest.NewServer(h)
   514  		defer srv.Close()
   515  
   516  		c := wsConnect(srv.URL)
   517  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   518  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   519  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   520  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
   521  
   522  		select {
   523  		case res := <-closeFuncCalled:
   524  			assert.True(t, res)
   525  		case <-time.NewTimer(time.Millisecond * 20).C:
   526  			assert.Fail(t, "The close handler was not called in time")
   527  		}
   528  
   529  		select {
   530  		case <-closeFuncCalled:
   531  			assert.Fail(t, "The close handler was called more than once")
   532  		case <-time.NewTimer(time.Millisecond * 20).C:
   533  			// ok
   534  		}
   535  	})
   536  
   537  	t.Run("init func errors call the close handler", func(t *testing.T) {
   538  		h := testserver.New()
   539  		closeFuncCalled := make(chan bool, 1)
   540  		h.AddTransport(transport.Websocket{
   541  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
   542  				return ctx, nil, errors.New("error during init")
   543  			},
   544  			CloseFunc: func(_ context.Context, _closeCode int) {
   545  				closeFuncCalled <- true
   546  			},
   547  		})
   548  		srv := httptest.NewServer(h)
   549  		defer srv.Close()
   550  
   551  		c := wsConnect(srv.URL)
   552  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   553  		select {
   554  		case res := <-closeFuncCalled:
   555  			assert.True(t, res)
   556  		case <-time.NewTimer(time.Millisecond * 20).C:
   557  			assert.Fail(t, "The close handler was not called in time")
   558  		}
   559  	})
   560  }
   561  
   562  func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
   563  	initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
   564  		h := testserver.New()
   565  		h.AddTransport(ws)
   566  		return h, httptest.NewServer(h)
   567  	}
   568  
   569  	t.Run("server acks init", func(t *testing.T) {
   570  		_, srv := initialize(transport.Websocket{})
   571  		defer srv.Close()
   572  
   573  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   574  		defer c.Close()
   575  
   576  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   577  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   578  	})
   579  
   580  	t.Run("client can receive data", func(t *testing.T) {
   581  		handler, srv := initialize(transport.Websocket{})
   582  		defer srv.Close()
   583  
   584  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   585  		defer c.Close()
   586  
   587  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   588  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   589  
   590  		require.NoError(t, c.WriteJSON(&operationMessage{
   591  			Type:    graphqltransportwsSubscribeMsg,
   592  			ID:      "test_1",
   593  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   594  		}))
   595  
   596  		handler.SendNextSubscriptionMessage()
   597  		msg := readOp(c)
   598  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   599  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   600  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   601  
   602  		handler.SendNextSubscriptionMessage()
   603  		msg = readOp(c)
   604  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   605  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   606  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   607  
   608  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"}))
   609  
   610  		msg = readOp(c)
   611  		require.Equal(t, graphqltransportwsCompleteMsg, msg.Type)
   612  		require.Equal(t, "test_1", msg.ID)
   613  	})
   614  
   615  	t.Run("receives no graphql-ws keep alive messages", func(t *testing.T) {
   616  		_, srv := initialize(transport.Websocket{KeepAlivePingInterval: 5 * time.Millisecond})
   617  		defer srv.Close()
   618  
   619  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   620  		defer c.Close()
   621  
   622  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   623  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   624  
   625  		// If the keep-alives are sent, this deadline will not be used, and no timeout error will be found
   626  		c.SetReadDeadline(time.Now().UTC().Add(50 * time.Millisecond))
   627  		var msg operationMessage
   628  		err := c.ReadJSON(&msg)
   629  		require.Error(t, err)
   630  		assert.Contains(t, err.Error(), "timeout")
   631  	})
   632  }
   633  
   634  func TestWebsocketWithPingPongInterval(t *testing.T) {
   635  	initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
   636  		h := testserver.New()
   637  		h.AddTransport(ws)
   638  		return h, httptest.NewServer(h)
   639  	}
   640  
   641  	t.Run("client receives ping and responds with pong", func(t *testing.T) {
   642  		_, srv := initialize(transport.Websocket{PingPongInterval: 20 * time.Millisecond})
   643  		defer srv.Close()
   644  
   645  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   646  		defer c.Close()
   647  
   648  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   649  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   650  
   651  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   652  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg}))
   653  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   654  	})
   655  
   656  	t.Run("client sends ping and expects pong", func(t *testing.T) {
   657  		_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
   658  		defer srv.Close()
   659  	})
   660  
   661  	t.Run("client sends ping and expects pong", func(t *testing.T) {
   662  		_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
   663  		defer srv.Close()
   664  
   665  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   666  		defer c.Close()
   667  
   668  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   669  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   670  
   671  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg}))
   672  		assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
   673  	})
   674  
   675  	t.Run("server closes with error if client does not pong and !MissingPongOk", func(t *testing.T) {
   676  		h := testserver.New()
   677  		closeFuncCalled := make(chan bool, 1)
   678  		h.AddTransport(transport.Websocket{
   679  			MissingPongOk:    false, // default value but beign explicit for test clarity.
   680  			PingPongInterval: 5 * time.Millisecond,
   681  			CloseFunc: func(_ context.Context, _closeCode int) {
   682  				closeFuncCalled <- true
   683  			},
   684  		})
   685  
   686  		srv := httptest.NewServer(h)
   687  		defer srv.Close()
   688  
   689  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   690  		defer c.Close()
   691  
   692  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   693  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   694  
   695  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   696  
   697  		select {
   698  		case res := <-closeFuncCalled:
   699  			assert.True(t, res)
   700  		case <-time.NewTimer(time.Millisecond * 20).C:
   701  			// with a 5ms interval 10ms should be the timeout, double that to make the test less likely to flake under load
   702  			assert.Fail(t, "The close handler was not called in time")
   703  		}
   704  	})
   705  
   706  	t.Run("server does not close with error if client does not pong and MissingPongOk", func(t *testing.T) {
   707  		h := testserver.New()
   708  		closeFuncCalled := make(chan bool, 1)
   709  		h.AddTransport(transport.Websocket{
   710  			MissingPongOk:    true,
   711  			PingPongInterval: 10 * time.Millisecond,
   712  			CloseFunc: func(_ context.Context, _closeCode int) {
   713  				closeFuncCalled <- true
   714  			},
   715  		})
   716  
   717  		srv := httptest.NewServer(h)
   718  		defer srv.Close()
   719  
   720  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   721  		defer c.Close()
   722  
   723  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   724  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   725  
   726  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   727  
   728  		select {
   729  		case <-closeFuncCalled:
   730  			assert.Fail(t, "The close handler was called even with MissingPongOk = true")
   731  		case _, ok := <-time.NewTimer(time.Millisecond * 20).C:
   732  			assert.True(t, ok)
   733  		}
   734  	})
   735  
   736  	t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) {
   737  		// Regression test
   738  		// ---
   739  		// Before the refactor, the code would try to convert a ping message to a graphql-ws message type
   740  		// But since this message type does not exist in the graphql-ws sub protocol, it would fail
   741  
   742  		_, srv := initialize(transport.Websocket{
   743  			PingPongInterval:      5 * time.Millisecond,
   744  			KeepAlivePingInterval: 10 * time.Millisecond,
   745  		})
   746  		defer srv.Close()
   747  
   748  		// Create connection
   749  		c := wsConnect(srv.URL)
   750  		defer c.Close()
   751  
   752  		// Initialize connection
   753  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   754  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   755  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   756  
   757  		// Wait for a few more keep alives to be sure nothing goes wrong
   758  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   759  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   760  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   761  	})
   762  	t.Run("pong only messages are sent when configured with graphql-transport-ws", func(t *testing.T) {
   763  
   764  		h, srv := initialize(transport.Websocket{PongOnlyInterval: 10 * time.Millisecond})
   765  		defer srv.Close()
   766  
   767  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   768  		defer c.Close()
   769  
   770  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   771  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   772  
   773  		assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
   774  
   775  		require.NoError(t, c.WriteJSON(&operationMessage{
   776  			Type:    graphqltransportwsSubscribeMsg,
   777  			ID:      "test_1",
   778  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   779  		}))
   780  
   781  		// pong
   782  		msg := readOp(c)
   783  		assert.Equal(t, graphqltransportwsPongMsg, msg.Type)
   784  
   785  		// server message
   786  		h.SendNextSubscriptionMessage()
   787  		msg = readOp(c)
   788  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   789  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   790  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   791  
   792  		// keepalive
   793  		msg = readOp(c)
   794  		assert.Equal(t, graphqltransportwsPongMsg, msg.Type)
   795  	})
   796  
   797  }
   798  
   799  func wsConnect(url string) *websocket.Conn {
   800  	return wsConnectWithSubprocotol(url, "")
   801  }
   802  
   803  func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn {
   804  	h := make(http.Header)
   805  	if subprocotol != "" {
   806  		h.Add("Sec-WebSocket-Protocol", subprocotol)
   807  	}
   808  
   809  	c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h)
   810  	if err != nil {
   811  		panic(err)
   812  	}
   813  	_ = resp.Body.Close()
   814  
   815  	return c
   816  }
   817  
   818  func writeRaw(conn *websocket.Conn, msg string) {
   819  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   820  		panic(err)
   821  	}
   822  }
   823  
   824  func readOp(conn *websocket.Conn) operationMessage {
   825  	var msg operationMessage
   826  	if err := conn.ReadJSON(&msg); err != nil {
   827  		panic(err)
   828  	}
   829  	return msg
   830  }
   831  
   832  // copied out from websocket_graphqlws.go to keep these private
   833  
   834  const (
   835  	connectionInitMsg      = "connection_init"      // Client -> Server
   836  	connectionTerminateMsg = "connection_terminate" // Client -> Server
   837  	startMsg               = "start"                // Client -> Server
   838  	stopMsg                = "stop"                 // Client -> Server
   839  	connectionAckMsg       = "connection_ack"       // Server -> Client
   840  	connectionErrorMsg     = "connection_error"     // Server -> Client
   841  	dataMsg                = "data"                 // Server -> Client
   842  	errorMsg               = "error"                // Server -> Client
   843  	completeMsg            = "complete"             // Server -> Client
   844  	connectionKeepAliveMsg = "ka"                   // Server -> Client
   845  )
   846  
   847  // copied out from websocket_graphql_transport_ws.go to keep these private
   848  
   849  const (
   850  	graphqltransportwsSubprotocol = "graphql-transport-ws"
   851  
   852  	graphqltransportwsConnectionInitMsg = "connection_init"
   853  	graphqltransportwsConnectionAckMsg  = "connection_ack"
   854  	graphqltransportwsSubscribeMsg      = "subscribe"
   855  	graphqltransportwsNextMsg           = "next"
   856  	graphqltransportwsCompleteMsg       = "complete"
   857  	graphqltransportwsPingMsg           = "ping"
   858  	graphqltransportwsPongMsg           = "pong"
   859  )
   860  
// operationMessage is the message envelope the tests send and receive for
// both the graphql-ws and graphql-transport-ws subprotocols. encoding/json
// writes fields in declaration order, so this order is also the key order of
// messages sent with WriteJSON.
type operationMessage struct {
	// Payload is kept raw so tests can compare the exact bytes sent by the server.
	Payload json.RawMessage `json:"payload,omitempty"`
	ID      string          `json:"id,omitempty"`
	Type    string          `json:"type"`
}
   866  

View as plain text