1 package transport_test
2
3 import (
4 "context"
5 "encoding/json"
6 "errors"
7 "net/http"
8 "net/http/httptest"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/gorilla/websocket"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 "github.com/vektah/gqlparser/v2"
17 "github.com/vektah/gqlparser/v2/ast"
18
19 "github.com/99designs/gqlgen/client"
20 "github.com/99designs/gqlgen/graphql"
21 "github.com/99designs/gqlgen/graphql/handler"
22 "github.com/99designs/gqlgen/graphql/handler/testserver"
23 "github.com/99designs/gqlgen/graphql/handler/transport"
24 )
25
// ckey is a dedicated string type for context keys in these tests; the InitFunc
// callbacks store values under ckey("newkey") via context.WithValue.
type ckey string
27
28 func TestWebsocket(t *testing.T) {
29 handler := testserver.New()
30 handler.AddTransport(transport.Websocket{})
31
32 srv := httptest.NewServer(handler)
33 defer srv.Close()
34
35 t.Run("client must send valid json", func(t *testing.T) {
36 c := wsConnect(srv.URL)
37 defer c.Close()
38
39 writeRaw(c, "hello")
40
41 msg := readOp(c)
42 assert.Equal(t, "connection_error", msg.Type)
43 assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
44 })
45
46 t.Run("client can terminate before init", func(t *testing.T) {
47 c := wsConnect(srv.URL)
48 defer c.Close()
49
50 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
51
52 _, _, err := c.ReadMessage()
53 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
54 })
55
56 t.Run("client must send init first", func(t *testing.T) {
57 c := wsConnect(srv.URL)
58 defer c.Close()
59
60 require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
61
62 msg := readOp(c)
63 assert.Equal(t, connectionErrorMsg, msg.Type)
64 assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
65 })
66
67 t.Run("server acks init", func(t *testing.T) {
68 c := wsConnect(srv.URL)
69 defer c.Close()
70
71 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
72
73 assert.Equal(t, connectionAckMsg, readOp(c).Type)
74 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
75 })
76
77 t.Run("client can terminate before run", func(t *testing.T) {
78 c := wsConnect(srv.URL)
79 defer c.Close()
80
81 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
82 assert.Equal(t, connectionAckMsg, readOp(c).Type)
83 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
84
85 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
86
87 _, _, err := c.ReadMessage()
88 assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
89 })
90
91 t.Run("client gets parse errors", func(t *testing.T) {
92 c := wsConnect(srv.URL)
93 defer c.Close()
94
95 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
96 assert.Equal(t, connectionAckMsg, readOp(c).Type)
97 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
98
99 require.NoError(t, c.WriteJSON(&operationMessage{
100 Type: startMsg,
101 ID: "test_1",
102 Payload: json.RawMessage(`{"query": "!"}`),
103 }))
104
105 msg := readOp(c)
106 assert.Equal(t, errorMsg, msg.Type)
107 assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
108 })
109
110 t.Run("client can receive data", func(t *testing.T) {
111 c := wsConnect(srv.URL)
112 defer c.Close()
113
114 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
115 assert.Equal(t, connectionAckMsg, readOp(c).Type)
116 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
117
118 require.NoError(t, c.WriteJSON(&operationMessage{
119 Type: startMsg,
120 ID: "test_1",
121 Payload: json.RawMessage(`{"query": "subscription { name }"}`),
122 }))
123
124 handler.SendNextSubscriptionMessage()
125 msg := readOp(c)
126 require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
127 require.Equal(t, "test_1", msg.ID, string(msg.Payload))
128 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
129
130 handler.SendNextSubscriptionMessage()
131 msg = readOp(c)
132 require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
133 require.Equal(t, "test_1", msg.ID, string(msg.Payload))
134 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
135
136 require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
137
138 msg = readOp(c)
139 require.Equal(t, completeMsg, msg.Type)
140 require.Equal(t, "test_1", msg.ID)
141
142
143 c.SetReadDeadline(time.Now().UTC().Add(1 * time.Millisecond))
144
145 err := c.ReadJSON(&msg)
146 if err == nil {
147
148 require.NotEqual(t, completeMsg, msg.Type)
149 require.NotEqual(t, "test_1", msg.ID)
150 } else {
151 assert.Contains(t, err.Error(), "timeout")
152 }
153 })
154 }
155
156 func TestWebsocketWithKeepAlive(t *testing.T) {
157 h := testserver.New()
158 h.AddTransport(transport.Websocket{
159 KeepAlivePingInterval: 100 * time.Millisecond,
160 })
161
162 srv := httptest.NewServer(h)
163 defer srv.Close()
164
165 c := wsConnect(srv.URL)
166 defer c.Close()
167
168 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
169 assert.Equal(t, connectionAckMsg, readOp(c).Type)
170 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
171
172 require.NoError(t, c.WriteJSON(&operationMessage{
173 Type: startMsg,
174 ID: "test_1",
175 Payload: json.RawMessage(`{"query": "subscription { name }"}`),
176 }))
177
178
179 msg := readOp(c)
180 assert.Equal(t, connectionKeepAliveMsg, msg.Type)
181
182
183 h.SendNextSubscriptionMessage()
184 msg = readOp(c)
185 assert.Equal(t, dataMsg, msg.Type)
186
187
188 msg = readOp(c)
189 assert.Equal(t, connectionKeepAliveMsg, msg.Type)
190 }
191
192 func TestWebsocketInitFunc(t *testing.T) {
193 t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
194 h := testserver.New()
195 h.AddTransport(transport.Websocket{})
196 srv := httptest.NewServer(h)
197 defer srv.Close()
198
199 c := wsConnect(srv.URL)
200 defer c.Close()
201
202 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
203
204 assert.Equal(t, connectionAckMsg, readOp(c).Type)
205 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
206 })
207
208 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
209 h := testserver.New()
210 h.AddTransport(transport.Websocket{
211 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
212 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
213 },
214 })
215 srv := httptest.NewServer(h)
216 defer srv.Close()
217
218 c := wsConnect(srv.URL)
219 defer c.Close()
220
221 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
222
223 assert.Equal(t, connectionAckMsg, readOp(c).Type)
224 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
225 })
226
227 t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
228 h := testserver.New()
229 h.AddTransport(transport.Websocket{
230 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
231 return ctx, nil, errors.New("invalid init payload")
232 },
233 })
234 srv := httptest.NewServer(h)
235 defer srv.Close()
236
237 c := wsConnect(srv.URL)
238 defer c.Close()
239
240 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
241
242 msg := readOp(c)
243 assert.Equal(t, connectionErrorMsg, msg.Type)
244 assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
245 })
246
247 t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
248 es := &graphql.ExecutableSchemaMock{
249 ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
250 assert.Equal(t, "newvalue", ctx.Value(ckey("newkey")))
251 return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
252 },
253 SchemaFunc: func() *ast.Schema {
254 return gqlparser.MustLoadSchema(&ast.Source{Input: `
255 schema { query: Query }
256 type Query {
257 empty: String
258 }
259 `})
260 },
261 }
262 h := handler.New(es)
263
264 h.AddTransport(transport.Websocket{
265 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
266 return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
267 },
268 })
269
270 c := client.New(h)
271
272 socket := c.Websocket("{ empty } ")
273 defer socket.Close()
274 var resp struct {
275 Empty string
276 }
277 err := socket.Next(&resp)
278 require.NoError(t, err)
279 assert.Equal(t, "ok", resp.Empty)
280 })
281
282 t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) {
283 h := testserver.New()
284 var cancel func()
285 h.AddTransport(transport.Websocket{
286 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
287 newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
288 return
289 },
290 })
291 srv := httptest.NewServer(h)
292 defer srv.Close()
293
294 c := wsConnect(srv.URL)
295 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
296 assert.Equal(t, connectionAckMsg, readOp(c).Type)
297 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
298
299
300 defer cancel()
301
302 time.Sleep(time.Millisecond * 10)
303 m := readOp(c)
304 assert.Equal(t, m.Type, connectionErrorMsg)
305 assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
306 })
307 t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
308 h := testserver.New()
309 h.AddTransport(transport.Websocket{
310 InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
311 initResponsePayload := transport.InitPayload{"trackingId": "123-456"}
312 return context.WithValue(ctx, ckey("newkey"), "newvalue"), &initResponsePayload, nil
313 },
314 })
315 srv := httptest.NewServer(h)
316 defer srv.Close()
317
318 c := wsConnect(srv.URL)
319 defer c.Close()
320
321 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
322
323 connAck := readOp(c)
324 assert.Equal(t, connectionAckMsg, connAck.Type)
325
326 var payload map[string]interface{}
327 err := json.Unmarshal(connAck.Payload, &payload)
328 if err != nil {
329 t.Fatal("Unexpected Error", err)
330 }
331 assert.EqualValues(t, "123-456", payload["trackingId"])
332 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
333 })
334 }
335
336 func TestWebSocketInitTimeout(t *testing.T) {
337 t.Run("times out if no init message is received within the configured duration", func(t *testing.T) {
338 h := testserver.New()
339 h.AddTransport(transport.Websocket{
340 InitTimeout: 5 * time.Millisecond,
341 })
342 srv := httptest.NewServer(h)
343 defer srv.Close()
344
345 c := wsConnect(srv.URL)
346 defer c.Close()
347
348 var msg operationMessage
349 err := c.ReadJSON(&msg)
350 assert.Error(t, err)
351 assert.Contains(t, err.Error(), "timeout")
352 })
353
354 t.Run("keeps waiting for an init message if no time out is configured", func(t *testing.T) {
355 h := testserver.New()
356 h.AddTransport(transport.Websocket{})
357 srv := httptest.NewServer(h)
358 defer srv.Close()
359
360 c := wsConnect(srv.URL)
361 defer c.Close()
362
363 done := make(chan interface{}, 1)
364 go func() {
365 var msg operationMessage
366 _ = c.ReadJSON(&msg)
367 done <- 1
368 }()
369
370 select {
371 case <-done:
372 assert.Fail(t, "web socket read operation finished while it shouldn't have")
373 case <-time.After(100 * time.Millisecond):
374
375 }
376 })
377 }
378
379 func TestWebSocketErrorFunc(t *testing.T) {
380 t.Run("the error handler gets called when an error occurs", func(t *testing.T) {
381 errFuncCalled := make(chan bool, 1)
382 h := testserver.New()
383 h.AddTransport(transport.Websocket{
384 ErrorFunc: func(_ context.Context, err error) {
385 require.Error(t, err)
386 assert.Equal(t, err.Error(), "websocket read: invalid message received")
387 assert.IsType(t, transport.WebsocketError{}, err)
388 assert.True(t, err.(transport.WebsocketError).IsReadError)
389 errFuncCalled <- true
390 },
391 })
392
393 srv := httptest.NewServer(h)
394 defer srv.Close()
395
396 c := wsConnect(srv.URL)
397 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
398 assert.Equal(t, connectionAckMsg, readOp(c).Type)
399 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
400 require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte("mark my words, you will regret this")))
401
402 select {
403 case res := <-errFuncCalled:
404 assert.True(t, res)
405 case <-time.NewTimer(time.Millisecond * 20).C:
406 assert.Fail(t, "The fail handler was not called in time")
407 }
408 })
409
410 t.Run("init func errors do not call the error handler", func(t *testing.T) {
411 h := testserver.New()
412 h.AddTransport(transport.Websocket{
413 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
414 return ctx, nil, errors.New("this is not what we agreed upon")
415 },
416 ErrorFunc: func(_ context.Context, err error) {
417 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
418 },
419 })
420 srv := httptest.NewServer(h)
421 defer srv.Close()
422
423 c := wsConnect(srv.URL)
424 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
425 time.Sleep(time.Millisecond * 20)
426 })
427
428 t.Run("init func context closes do not call the error handler", func(t *testing.T) {
429 h := testserver.New()
430 h.AddTransport(transport.Websocket{
431 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
432 newCtx, cancel := context.WithCancel(ctx)
433 time.AfterFunc(time.Millisecond*5, cancel)
434 return newCtx, nil, nil
435 },
436 ErrorFunc: func(_ context.Context, err error) {
437 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
438 },
439 })
440 srv := httptest.NewServer(h)
441 defer srv.Close()
442
443 c := wsConnect(srv.URL)
444 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
445 assert.Equal(t, connectionAckMsg, readOp(c).Type)
446 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
447 time.Sleep(time.Millisecond * 20)
448 })
449
450 t.Run("init func context deadlines do not call the error handler", func(t *testing.T) {
451 h := testserver.New()
452 var cancel func()
453 h.AddTransport(transport.Websocket{
454 InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
455 newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
456 return newCtx, nil, nil
457 },
458 ErrorFunc: func(_ context.Context, err error) {
459 assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
460 },
461 })
462 srv := httptest.NewServer(h)
463 defer srv.Close()
464
465 c := wsConnect(srv.URL)
466 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
467 assert.Equal(t, connectionAckMsg, readOp(c).Type)
468 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
469
470
471 defer cancel()
472
473 time.Sleep(time.Millisecond * 20)
474 })
475 }
476
477 func TestWebSocketCloseFunc(t *testing.T) {
478 t.Run("the on close handler gets called when the websocket is closed", func(t *testing.T) {
479 closeFuncCalled := make(chan bool, 1)
480 h := testserver.New()
481 h.AddTransport(transport.Websocket{
482 CloseFunc: func(_ context.Context, _closeCode int) {
483 closeFuncCalled <- true
484 },
485 })
486
487 srv := httptest.NewServer(h)
488 defer srv.Close()
489
490 c := wsConnect(srv.URL)
491 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
492 assert.Equal(t, connectionAckMsg, readOp(c).Type)
493 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
494 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
495
496 select {
497 case res := <-closeFuncCalled:
498 assert.True(t, res)
499 case <-time.NewTimer(time.Millisecond * 20).C:
500 assert.Fail(t, "The close handler was not called in time")
501 }
502 })
503
504 t.Run("the on close handler gets called only once when the websocket is closed", func(t *testing.T) {
505 closeFuncCalled := make(chan bool, 1)
506 h := testserver.New()
507 h.AddTransport(transport.Websocket{
508 CloseFunc: func(_ context.Context, _closeCode int) {
509 closeFuncCalled <- true
510 },
511 })
512
513 srv := httptest.NewServer(h)
514 defer srv.Close()
515
516 c := wsConnect(srv.URL)
517 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
518 assert.Equal(t, connectionAckMsg, readOp(c).Type)
519 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
520 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
521
522 select {
523 case res := <-closeFuncCalled:
524 assert.True(t, res)
525 case <-time.NewTimer(time.Millisecond * 20).C:
526 assert.Fail(t, "The close handler was not called in time")
527 }
528
529 select {
530 case <-closeFuncCalled:
531 assert.Fail(t, "The close handler was called more than once")
532 case <-time.NewTimer(time.Millisecond * 20).C:
533
534 }
535 })
536
537 t.Run("init func errors call the close handler", func(t *testing.T) {
538 h := testserver.New()
539 closeFuncCalled := make(chan bool, 1)
540 h.AddTransport(transport.Websocket{
541 InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
542 return ctx, nil, errors.New("error during init")
543 },
544 CloseFunc: func(_ context.Context, _closeCode int) {
545 closeFuncCalled <- true
546 },
547 })
548 srv := httptest.NewServer(h)
549 defer srv.Close()
550
551 c := wsConnect(srv.URL)
552 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
553 select {
554 case res := <-closeFuncCalled:
555 assert.True(t, res)
556 case <-time.NewTimer(time.Millisecond * 20).C:
557 assert.Fail(t, "The close handler was not called in time")
558 }
559 })
560 }
561
562 func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
563 initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
564 h := testserver.New()
565 h.AddTransport(ws)
566 return h, httptest.NewServer(h)
567 }
568
569 t.Run("server acks init", func(t *testing.T) {
570 _, srv := initialize(transport.Websocket{})
571 defer srv.Close()
572
573 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
574 defer c.Close()
575
576 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
577 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
578 })
579
580 t.Run("client can receive data", func(t *testing.T) {
581 handler, srv := initialize(transport.Websocket{})
582 defer srv.Close()
583
584 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
585 defer c.Close()
586
587 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
588 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
589
590 require.NoError(t, c.WriteJSON(&operationMessage{
591 Type: graphqltransportwsSubscribeMsg,
592 ID: "test_1",
593 Payload: json.RawMessage(`{"query": "subscription { name }"}`),
594 }))
595
596 handler.SendNextSubscriptionMessage()
597 msg := readOp(c)
598 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
599 require.Equal(t, "test_1", msg.ID, string(msg.Payload))
600 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
601
602 handler.SendNextSubscriptionMessage()
603 msg = readOp(c)
604 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
605 require.Equal(t, "test_1", msg.ID, string(msg.Payload))
606 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
607
608 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"}))
609
610 msg = readOp(c)
611 require.Equal(t, graphqltransportwsCompleteMsg, msg.Type)
612 require.Equal(t, "test_1", msg.ID)
613 })
614
615 t.Run("receives no graphql-ws keep alive messages", func(t *testing.T) {
616 _, srv := initialize(transport.Websocket{KeepAlivePingInterval: 5 * time.Millisecond})
617 defer srv.Close()
618
619 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
620 defer c.Close()
621
622 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
623 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
624
625
626 c.SetReadDeadline(time.Now().UTC().Add(50 * time.Millisecond))
627 var msg operationMessage
628 err := c.ReadJSON(&msg)
629 require.Error(t, err)
630 assert.Contains(t, err.Error(), "timeout")
631 })
632 }
633
634 func TestWebsocketWithPingPongInterval(t *testing.T) {
635 initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
636 h := testserver.New()
637 h.AddTransport(ws)
638 return h, httptest.NewServer(h)
639 }
640
641 t.Run("client receives ping and responds with pong", func(t *testing.T) {
642 _, srv := initialize(transport.Websocket{PingPongInterval: 20 * time.Millisecond})
643 defer srv.Close()
644
645 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
646 defer c.Close()
647
648 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
649 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
650
651 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
652 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg}))
653 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
654 })
655
656 t.Run("client sends ping and expects pong", func(t *testing.T) {
657 _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
658 defer srv.Close()
659 })
660
661 t.Run("client sends ping and expects pong", func(t *testing.T) {
662 _, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
663 defer srv.Close()
664
665 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
666 defer c.Close()
667
668 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
669 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
670
671 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg}))
672 assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
673 })
674
675 t.Run("server closes with error if client does not pong and !MissingPongOk", func(t *testing.T) {
676 h := testserver.New()
677 closeFuncCalled := make(chan bool, 1)
678 h.AddTransport(transport.Websocket{
679 MissingPongOk: false,
680 PingPongInterval: 5 * time.Millisecond,
681 CloseFunc: func(_ context.Context, _closeCode int) {
682 closeFuncCalled <- true
683 },
684 })
685
686 srv := httptest.NewServer(h)
687 defer srv.Close()
688
689 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
690 defer c.Close()
691
692 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
693 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
694
695 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
696
697 select {
698 case res := <-closeFuncCalled:
699 assert.True(t, res)
700 case <-time.NewTimer(time.Millisecond * 20).C:
701
702 assert.Fail(t, "The close handler was not called in time")
703 }
704 })
705
706 t.Run("server does not close with error if client does not pong and MissingPongOk", func(t *testing.T) {
707 h := testserver.New()
708 closeFuncCalled := make(chan bool, 1)
709 h.AddTransport(transport.Websocket{
710 MissingPongOk: true,
711 PingPongInterval: 10 * time.Millisecond,
712 CloseFunc: func(_ context.Context, _closeCode int) {
713 closeFuncCalled <- true
714 },
715 })
716
717 srv := httptest.NewServer(h)
718 defer srv.Close()
719
720 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
721 defer c.Close()
722
723 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
724 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
725
726 assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
727
728 select {
729 case <-closeFuncCalled:
730 assert.Fail(t, "The close handler was called even with MissingPongOk = true")
731 case _, ok := <-time.NewTimer(time.Millisecond * 20).C:
732 assert.True(t, ok)
733 }
734 })
735
736 t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) {
737
738
739
740
741
742 _, srv := initialize(transport.Websocket{
743 PingPongInterval: 5 * time.Millisecond,
744 KeepAlivePingInterval: 10 * time.Millisecond,
745 })
746 defer srv.Close()
747
748
749 c := wsConnect(srv.URL)
750 defer c.Close()
751
752
753 require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
754 assert.Equal(t, connectionAckMsg, readOp(c).Type)
755 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
756
757
758 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
759 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
760 assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
761 })
762 t.Run("pong only messages are sent when configured with graphql-transport-ws", func(t *testing.T) {
763
764 h, srv := initialize(transport.Websocket{PongOnlyInterval: 10 * time.Millisecond})
765 defer srv.Close()
766
767 c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
768 defer c.Close()
769
770 require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
771 assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
772
773 assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
774
775 require.NoError(t, c.WriteJSON(&operationMessage{
776 Type: graphqltransportwsSubscribeMsg,
777 ID: "test_1",
778 Payload: json.RawMessage(`{"query": "subscription { name }"}`),
779 }))
780
781
782 msg := readOp(c)
783 assert.Equal(t, graphqltransportwsPongMsg, msg.Type)
784
785
786 h.SendNextSubscriptionMessage()
787 msg = readOp(c)
788 require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
789 require.Equal(t, "test_1", msg.ID, string(msg.Payload))
790 require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
791
792
793 msg = readOp(c)
794 assert.Equal(t, graphqltransportwsPongMsg, msg.Type)
795 })
796
797 }
798
799 func wsConnect(url string) *websocket.Conn {
800 return wsConnectWithSubprocotol(url, "")
801 }
802
803 func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn {
804 h := make(http.Header)
805 if subprocotol != "" {
806 h.Add("Sec-WebSocket-Protocol", subprocotol)
807 }
808
809 c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h)
810 if err != nil {
811 panic(err)
812 }
813 _ = resp.Body.Close()
814
815 return c
816 }
817
818 func writeRaw(conn *websocket.Conn, msg string) {
819 if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
820 panic(err)
821 }
822 }
823
824 func readOp(conn *websocket.Conn) operationMessage {
825 var msg operationMessage
826 if err := conn.ReadJSON(&msg); err != nil {
827 panic(err)
828 }
829 return msg
830 }
831
832
833
// Message type strings used by the tests on connections opened without a
// subprotocol header (wsConnect).
const (
	// Types the tests write to the server.
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Types the tests expect to read back from the server.
	connectionAckMsg       = "connection_ack"
	connectionErrorMsg     = "connection_error"
	connectionKeepAliveMsg = "ka"
	dataMsg                = "data"
	errorMsg               = "error"
	completeMsg            = "complete"
)
846
847
848
// graphqltransportwsSubprotocol is the value the tests pass to
// wsConnectWithSubprocotol, which sends it in the Sec-WebSocket-Protocol header.
const graphqltransportwsSubprotocol = "graphql-transport-ws"

// Message type strings used by the tests on connections that request the
// graphql-transport-ws subprotocol.
const (
	graphqltransportwsConnectionInitMsg = "connection_init"
	graphqltransportwsConnectionAckMsg  = "connection_ack"

	graphqltransportwsSubscribeMsg = "subscribe"
	graphqltransportwsNextMsg      = "next"
	graphqltransportwsCompleteMsg  = "complete"

	graphqltransportwsPingMsg = "ping"
	graphqltransportwsPongMsg = "pong"
)
860
// operationMessage is the JSON envelope the tests write to and read from the
// websocket connection (see readOp and the c.WriteJSON calls).
type operationMessage struct {
	// Payload is kept as raw JSON so tests can compare it as a string; omitted when empty.
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID is the operation id (the tests use "test_1"); omitted when empty.
	ID string `json:"id,omitempty"`
	// Type is the message type string, e.g. one of the *Msg constants in this file.
	Type string `json:"type"`
}
866
View as plain text