...
1 package transport
2
3 import (
4 "encoding/json"
5 "fmt"
6
7 "github.com/gorilla/websocket"
8 )
9
10
11 const (
12 graphqltransportwsSubprotocol = "graphql-transport-ws"
13
14 graphqltransportwsConnectionInitMsg = graphqltransportwsMessageType("connection_init")
15 graphqltransportwsConnectionAckMsg = graphqltransportwsMessageType("connection_ack")
16 graphqltransportwsSubscribeMsg = graphqltransportwsMessageType("subscribe")
17 graphqltransportwsNextMsg = graphqltransportwsMessageType("next")
18 graphqltransportwsErrorMsg = graphqltransportwsMessageType("error")
19 graphqltransportwsCompleteMsg = graphqltransportwsMessageType("complete")
20 graphqltransportwsPingMsg = graphqltransportwsMessageType("ping")
21 graphqltransportwsPongMsg = graphqltransportwsMessageType("pong")
22 )
23
24 var allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{
25 graphqltransportwsConnectionInitMsg,
26 graphqltransportwsConnectionAckMsg,
27 graphqltransportwsSubscribeMsg,
28 graphqltransportwsNextMsg,
29 graphqltransportwsErrorMsg,
30 graphqltransportwsCompleteMsg,
31 graphqltransportwsPingMsg,
32 graphqltransportwsPongMsg,
33 }
34
35 type (
36 graphqltransportwsMessageExchanger struct {
37 c *websocket.Conn
38 }
39
40 graphqltransportwsMessage struct {
41 Payload json.RawMessage `json:"payload,omitempty"`
42 ID string `json:"id,omitempty"`
43 Type graphqltransportwsMessageType `json:"type"`
44 noOp bool
45 }
46
47 graphqltransportwsMessageType string
48 )
49
50 func (me graphqltransportwsMessageExchanger) NextMessage() (message, error) {
51 _, r, err := me.c.NextReader()
52 if err != nil {
53 return message{}, handleNextReaderError(err)
54 }
55
56 var graphqltransportwsMessage graphqltransportwsMessage
57 if err := jsonDecode(r, &graphqltransportwsMessage); err != nil {
58 return message{}, errInvalidMsg
59 }
60
61 return graphqltransportwsMessage.toMessage()
62 }
63
64 func (me graphqltransportwsMessageExchanger) Send(m *message) error {
65 msg := &graphqltransportwsMessage{}
66 if err := msg.fromMessage(m); err != nil {
67 return err
68 }
69
70 if msg.noOp {
71 return nil
72 }
73
74 return me.c.WriteJSON(msg)
75 }
76
77 func (t *graphqltransportwsMessageType) UnmarshalText(text []byte) (err error) {
78 var found bool
79 for _, candidate := range allGraphqltransportwsMessageTypes {
80 if string(candidate) == string(text) {
81 *t = candidate
82 found = true
83 break
84 }
85 }
86
87 if !found {
88 err = fmt.Errorf("invalid message type %s", string(text))
89 }
90
91 return err
92 }
93
94 func (t graphqltransportwsMessageType) MarshalText() ([]byte, error) {
95 return []byte(string(t)), nil
96 }
97
// toMessage converts a decoded client frame into the transport-agnostic
// message used by the websocket handler.
//
// Only frame types a client may send are accepted: connection_init,
// subscribe, complete, ping and pong. Any other type (for example next or
// connection_ack) yields an "invalid client->server message type" error and
// the zero message. Callers therefore never see a half-filled value
// alongside an error.
func (m graphqltransportwsMessage) toMessage() (message, error) {
	var t messageType
	switch m.Type {
	case graphqltransportwsConnectionInitMsg:
		t = initMessageType
	case graphqltransportwsSubscribeMsg:
		t = startMessageType
	case graphqltransportwsCompleteMsg:
		t = stopMessageType
	case graphqltransportwsPingMsg:
		t = pingMessageType
	case graphqltransportwsPongMsg:
		t = pongMessageType
	default:
		return message{}, fmt.Errorf("invalid client->server message type %s", m.Type)
	}

	return message{
		payload: m.Payload,
		id:      m.ID,
		t:       t,
	}, nil
}
122
// fromMessage fills m from the transport-agnostic server->client message msg.
//
// ID and Payload are always copied from msg, even when an error is returned.
// Keep-alive and connection-error messages have no counterpart among the
// graphqltransportws*Msg constants. They are flagged with noOp, and Send
// drops them without writing anything. Any other unsupported message type
// results in an "invalid server->client message type" error.
func (m *graphqltransportwsMessage) fromMessage(msg *message) (err error) {
	m.ID, m.Payload = msg.id, msg.payload

	switch msg.t {
	case keepAliveMessageType, connectionErrorMessageType:
		m.noOp = true
	case connectionAckMessageType:
		m.Type = graphqltransportwsConnectionAckMsg
	case dataMessageType:
		m.Type = graphqltransportwsNextMsg
	case completeMessageType:
		m.Type = graphqltransportwsCompleteMsg
	case errorMessageType:
		m.Type = graphqltransportwsErrorMsg
	case pingMessageType:
		m.Type = graphqltransportwsPingMsg
	case pongMessageType:
		m.Type = graphqltransportwsPongMsg
	default:
		return fmt.Errorf("invalid server->client message type %s", msg.t)
	}
	return nil
}
150
View as plain text