1 package transport
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log"
10 "net"
11 "net/http"
12 "sync"
13 "time"
14
15 "github.com/gorilla/websocket"
16 "github.com/vektah/gqlparser/v2/gqlerror"
17
18 "github.com/99designs/gqlgen/graphql"
19 "github.com/99designs/gqlgen/graphql/errcode"
20 )
21
22 type (
23 Websocket struct {
24 Upgrader websocket.Upgrader
25 InitFunc WebsocketInitFunc
26 InitTimeout time.Duration
27 ErrorFunc WebsocketErrorFunc
28 CloseFunc WebsocketCloseFunc
29 KeepAlivePingInterval time.Duration
30 PongOnlyInterval time.Duration
31 PingPongInterval time.Duration
32
40 MissingPongOk bool
41
42 didInjectSubprotocols bool
43 }
44 wsConnection struct {
45 Websocket
46 ctx context.Context
47 conn *websocket.Conn
48 me messageExchanger
49 active map[string]context.CancelFunc
50 mu sync.Mutex
51 keepAliveTicker *time.Ticker
52 pongOnlyTicker *time.Ticker
53 pingPongTicker *time.Ticker
54 receivedPong bool
55 exec graphql.GraphExecutor
56 closed bool
57
58 initPayload InitPayload
59 }
60
61 WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error)
62 WebsocketErrorFunc func(ctx context.Context, err error)
63
64
65 WebsocketCloseFunc func(ctx context.Context, closeCode int)
66 )
67
// errReadTimeout is returned by nextMessageWithTimeout when no message
// arrives before the timeout elapses.
var errReadTimeout = errors.New("read timeout")

// WebsocketError wraps an error that occurred while reading from or writing
// to the websocket. It is the error value handed to ErrorFunc.
type WebsocketError struct {
	// Err is the underlying error.
	Err error

	// IsReadError is true when the error occurred on read and false when it
	// occurred on write.
	IsReadError bool
}

// Error implements the error interface, prefixing the underlying error with
// the direction (read or write) in which it occurred.
func (e WebsocketError) Error() string {
	if e.IsReadError {
		return fmt.Sprintf("websocket read: %v", e.Err)
	}
	return fmt.Sprintf("websocket write: %v", e.Err)
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// inspect the cause of a WebsocketError.
func (e WebsocketError) Unwrap() error {
	return e.Err
}
83
// Compile-time assertions that Websocket satisfies graphql.Transport and that
// WebsocketError satisfies error.
var (
	_ graphql.Transport = Websocket{}
	_ error             = WebsocketError{}
)
88
89 func (t Websocket) Supports(r *http.Request) bool {
90 return r.Header.Get("Upgrade") != ""
91 }
92
93 func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
94 t.injectGraphQLWSSubprotocols()
95 ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
96 if err != nil {
97 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
98 SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
99 return
100 }
101
102 var me messageExchanger
103 switch ws.Subprotocol() {
104 default:
105 msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
106 ws.WriteMessage(websocket.CloseMessage, msg)
107 return
108 case graphqlwsSubprotocol, "":
109
110
111 me = graphqlwsMessageExchanger{c: ws}
112 case graphqltransportwsSubprotocol:
113 me = graphqltransportwsMessageExchanger{c: ws}
114 }
115
116 conn := wsConnection{
117 active: map[string]context.CancelFunc{},
118 conn: ws,
119 ctx: r.Context(),
120 exec: exec,
121 me: me,
122 Websocket: t,
123 }
124
125 if !conn.init() {
126 return
127 }
128
129 conn.run()
130 }
131
132 func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
133 if c.ErrorFunc != nil && err != nil {
134 c.ErrorFunc(c.ctx, WebsocketError{
135 Err: err,
136 IsReadError: isReadError,
137 })
138 }
139 }
140
141 func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
142 messages, errs := make(chan message, 1), make(chan error, 1)
143
144 go func() {
145 if m, err := c.me.NextMessage(); err != nil {
146 errs <- err
147 } else {
148 messages <- m
149 }
150 }()
151
152 select {
153 case m := <-messages:
154 return m, nil
155 case err := <-errs:
156 return message{}, err
157 case <-time.After(timeout):
158 return message{}, errReadTimeout
159 }
160 }
161
162 func (c *wsConnection) init() bool {
163 var m message
164 var err error
165
166 if c.InitTimeout != 0 {
167 m, err = c.nextMessageWithTimeout(c.InitTimeout)
168 } else {
169 m, err = c.me.NextMessage()
170 }
171
172 if err != nil {
173 if err == errReadTimeout {
174 c.close(websocket.CloseProtocolError, "connection initialisation timeout")
175 return false
176 }
177
178 if err == errInvalidMsg {
179 c.sendConnectionError("invalid json")
180 }
181
182 c.close(websocket.CloseProtocolError, "decoding error")
183 return false
184 }
185
186 switch m.t {
187 case initMessageType:
188 if len(m.payload) > 0 {
189 c.initPayload = make(InitPayload)
190 err := json.Unmarshal(m.payload, &c.initPayload)
191 if err != nil {
192 return false
193 }
194 }
195
196 var initAckPayload *InitPayload = nil
197 if c.InitFunc != nil {
198 var ctx context.Context
199 ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload)
200 if err != nil {
201 c.sendConnectionError(err.Error())
202 c.close(websocket.CloseNormalClosure, "terminated")
203 return false
204 }
205 c.ctx = ctx
206 }
207
208 if initAckPayload != nil {
209 initJsonAckPayload, err := json.Marshal(*initAckPayload)
210 if err != nil {
211 panic(err)
212 }
213 c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload})
214 } else {
215 c.write(&message{t: connectionAckMessageType})
216 }
217 c.write(&message{t: keepAliveMessageType})
218 case connectionCloseMessageType:
219 c.close(websocket.CloseNormalClosure, "terminated")
220 return false
221 default:
222 c.sendConnectionError("unexpected message %s", m.t)
223 c.close(websocket.CloseProtocolError, "unexpected message")
224 return false
225 }
226
227 return true
228 }
229
230 func (c *wsConnection) write(msg *message) {
231 c.mu.Lock()
232 c.handlePossibleError(c.me.Send(msg), false)
233 c.mu.Unlock()
234 }
235
236 func (c *wsConnection) run() {
237
238
239 ctx, cancel := context.WithCancel(c.ctx)
240 defer func() {
241 cancel()
242 c.close(websocket.CloseAbnormalClosure, "unexpected closure")
243 }()
244
245
246
247 if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
248 c.mu.Lock()
249 c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
250 c.mu.Unlock()
251
252 go c.keepAlive(ctx)
253 }
254
255
256
257 if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PongOnlyInterval != 0 {
258 c.mu.Lock()
259 c.pongOnlyTicker = time.NewTicker(c.PongOnlyInterval)
260 c.mu.Unlock()
261
262 go c.keepAlivePongOnly(ctx)
263 }
264
265
266
267 if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
268 c.mu.Lock()
269 c.pingPongTicker = time.NewTicker(c.PingPongInterval)
270 c.mu.Unlock()
271
272 if !c.MissingPongOk {
273
274
275 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
276 }
277 go c.ping(ctx)
278 }
279
280
281
282 go c.closeOnCancel(ctx)
283
284 for {
285 start := graphql.Now()
286 m, err := c.me.NextMessage()
287 if err != nil {
288
289 if !errors.Is(err, net.ErrClosed) {
290 c.handlePossibleError(err, true)
291 }
292 return
293 }
294
295 switch m.t {
296 case startMessageType:
297 c.subscribe(start, &m)
298 case stopMessageType:
299 c.mu.Lock()
300 closer := c.active[m.id]
301 c.mu.Unlock()
302 if closer != nil {
303 closer()
304 }
305 case connectionCloseMessageType:
306 c.close(websocket.CloseNormalClosure, "terminated")
307 return
308 case pingMessageType:
309 c.write(&message{t: pongMessageType, payload: m.payload})
310 case pongMessageType:
311 c.mu.Lock()
312 c.receivedPong = true
313 c.mu.Unlock()
314
315 c.conn.SetReadDeadline(time.Time{})
316 default:
317 c.sendConnectionError("unexpected message %s", m.t)
318 c.close(websocket.CloseProtocolError, "unexpected message")
319 return
320 }
321 }
322 }
323
324 func (c *wsConnection) keepAlivePongOnly(ctx context.Context) {
325 for {
326 select {
327 case <-ctx.Done():
328 c.pongOnlyTicker.Stop()
329 return
330 case <-c.pongOnlyTicker.C:
331 c.write(&message{t: pongMessageType, payload: json.RawMessage{}})
332 }
333 }
334 }
335
336 func (c *wsConnection) keepAlive(ctx context.Context) {
337 for {
338 select {
339 case <-ctx.Done():
340 c.keepAliveTicker.Stop()
341 return
342 case <-c.keepAliveTicker.C:
343 c.write(&message{t: keepAliveMessageType})
344 }
345 }
346 }
347
348 func (c *wsConnection) ping(ctx context.Context) {
349 for {
350 select {
351 case <-ctx.Done():
352 c.pingPongTicker.Stop()
353 return
354 case <-c.pingPongTicker.C:
355 c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
356
357
358 c.mu.Lock()
359 if !c.MissingPongOk && c.receivedPong {
360 c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
361 }
362 c.receivedPong = false
363 c.mu.Unlock()
364 }
365 }
366 }
367
368 func (c *wsConnection) closeOnCancel(ctx context.Context) {
369 <-ctx.Done()
370
371 if r := closeReasonForContext(ctx); r != "" {
372 c.sendConnectionError(r)
373 }
374 c.close(websocket.CloseNormalClosure, "terminated")
375 }
376
377 func (c *wsConnection) subscribe(start time.Time, msg *message) {
378 ctx := graphql.StartOperationTrace(c.ctx)
379 var params *graphql.RawParams
380 if err := jsonDecode(bytes.NewReader(msg.payload), ¶ms); err != nil {
381 c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
382 c.complete(msg.id)
383 return
384 }
385
386 params.ReadTime = graphql.TraceTiming{
387 Start: start,
388 End: graphql.Now(),
389 }
390
391 rc, err := c.exec.CreateOperationContext(ctx, params)
392 if err != nil {
393 resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
394 switch errcode.GetErrorKind(err) {
395 case errcode.KindProtocol:
396 c.sendError(msg.id, resp.Errors...)
397 default:
398 c.sendResponse(msg.id, &graphql.Response{Errors: err})
399 }
400
401 c.complete(msg.id)
402 return
403 }
404
405 ctx = graphql.WithOperationContext(ctx, rc)
406
407 if c.initPayload != nil {
408 ctx = withInitPayload(ctx, c.initPayload)
409 }
410
411 ctx, cancel := context.WithCancel(ctx)
412 c.mu.Lock()
413 c.active[msg.id] = cancel
414 c.mu.Unlock()
415
416 go func() {
417 ctx = withSubscriptionErrorContext(ctx)
418 defer func() {
419 if r := recover(); r != nil {
420 err := rc.Recover(ctx, r)
421 var gqlerr *gqlerror.Error
422 if !errors.As(err, &gqlerr) {
423 gqlerr = &gqlerror.Error{}
424 if err != nil {
425 gqlerr.Message = err.Error()
426 }
427 }
428 c.sendError(msg.id, gqlerr)
429 }
430 if errs := getSubscriptionError(ctx); len(errs) != 0 {
431 c.sendError(msg.id, errs...)
432 } else {
433 c.complete(msg.id)
434 }
435 c.mu.Lock()
436 delete(c.active, msg.id)
437 c.mu.Unlock()
438 cancel()
439 }()
440
441 responses, ctx := c.exec.DispatchOperation(ctx, rc)
442 for {
443 response := responses(ctx)
444 if response == nil {
445 break
446 }
447
448 c.sendResponse(msg.id, response)
449 }
450
451
452 }()
453 }
454
455 func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
456 b, err := json.Marshal(response)
457 if err != nil {
458 panic(err)
459 }
460 c.write(&message{
461 payload: b,
462 id: id,
463 t: dataMessageType,
464 })
465 }
466
467 func (c *wsConnection) complete(id string) {
468 c.write(&message{id: id, t: completeMessageType})
469 }
470
471 func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
472 errs := make([]error, len(errors))
473 for i, err := range errors {
474 errs[i] = err
475 }
476 b, err := json.Marshal(errs)
477 if err != nil {
478 panic(err)
479 }
480 c.write(&message{t: errorMessageType, id: id, payload: b})
481 }
482
483 func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
484 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
485 if err != nil {
486 panic(err)
487 }
488
489 c.write(&message{t: connectionErrorMessageType, payload: b})
490 }
491
492 func (c *wsConnection) close(closeCode int, message string) {
493 c.mu.Lock()
494 if c.closed {
495 c.mu.Unlock()
496 return
497 }
498 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
499 for _, closer := range c.active {
500 closer()
501 }
502 c.closed = true
503 c.mu.Unlock()
504 _ = c.conn.Close()
505
506 if c.CloseFunc != nil {
507 c.CloseFunc(c.ctx, closeCode)
508 }
509 }
510
View as plain text