...
1 package transport
2
3 import (
4 "encoding/json"
5 "fmt"
6
7 "github.com/gorilla/websocket"
8 )
9
10
11 const (
12 graphqlwsSubprotocol = "graphql-ws"
13
14 graphqlwsConnectionInitMsg = graphqlwsMessageType("connection_init")
15 graphqlwsConnectionTerminateMsg = graphqlwsMessageType("connection_terminate")
16 graphqlwsStartMsg = graphqlwsMessageType("start")
17 graphqlwsStopMsg = graphqlwsMessageType("stop")
18 graphqlwsConnectionAckMsg = graphqlwsMessageType("connection_ack")
19 graphqlwsConnectionErrorMsg = graphqlwsMessageType("connection_error")
20 graphqlwsDataMsg = graphqlwsMessageType("data")
21 graphqlwsErrorMsg = graphqlwsMessageType("error")
22 graphqlwsCompleteMsg = graphqlwsMessageType("complete")
23 graphqlwsConnectionKeepAliveMsg = graphqlwsMessageType("ka")
24 )
25
26 var allGraphqlwsMessageTypes = []graphqlwsMessageType{
27 graphqlwsConnectionInitMsg,
28 graphqlwsConnectionTerminateMsg,
29 graphqlwsStartMsg,
30 graphqlwsStopMsg,
31 graphqlwsConnectionAckMsg,
32 graphqlwsConnectionErrorMsg,
33 graphqlwsDataMsg,
34 graphqlwsErrorMsg,
35 graphqlwsCompleteMsg,
36 graphqlwsConnectionKeepAliveMsg,
37 }
38
39 type (
40 graphqlwsMessageExchanger struct {
41 c *websocket.Conn
42 }
43
44 graphqlwsMessage struct {
45 Payload json.RawMessage `json:"payload,omitempty"`
46 ID string `json:"id,omitempty"`
47 Type graphqlwsMessageType `json:"type"`
48 noOp bool
49 }
50
51 graphqlwsMessageType string
52 )
53
54 func (me graphqlwsMessageExchanger) NextMessage() (message, error) {
55 _, r, err := me.c.NextReader()
56 if err != nil {
57 return message{}, handleNextReaderError(err)
58 }
59
60 var graphqlwsMessage graphqlwsMessage
61 if err := jsonDecode(r, &graphqlwsMessage); err != nil {
62 return message{}, errInvalidMsg
63 }
64
65 return graphqlwsMessage.toMessage()
66 }
67
68 func (me graphqlwsMessageExchanger) Send(m *message) error {
69 msg := &graphqlwsMessage{}
70 if err := msg.fromMessage(m); err != nil {
71 return err
72 }
73
74 if msg.noOp {
75 return nil
76 }
77
78 return me.c.WriteJSON(msg)
79 }
80
81 func (t *graphqlwsMessageType) UnmarshalText(text []byte) (err error) {
82 var found bool
83 for _, candidate := range allGraphqlwsMessageTypes {
84 if string(candidate) == string(text) {
85 *t = candidate
86 found = true
87 break
88 }
89 }
90
91 if !found {
92 err = fmt.Errorf("invalid message type %s", string(text))
93 }
94
95 return err
96 }
97
98 func (t graphqlwsMessageType) MarshalText() ([]byte, error) {
99 return []byte(string(t)), nil
100 }
101
102 func (m graphqlwsMessage) toMessage() (message, error) {
103 var t messageType
104 var err error
105 switch m.Type {
106 default:
107 err = fmt.Errorf("invalid client->server message type %s", m.Type)
108 case graphqlwsConnectionInitMsg:
109 t = initMessageType
110 case graphqlwsConnectionTerminateMsg:
111 t = connectionCloseMessageType
112 case graphqlwsStartMsg:
113 t = startMessageType
114 case graphqlwsStopMsg:
115 t = stopMessageType
116 case graphqlwsConnectionAckMsg:
117 t = connectionAckMessageType
118 case graphqlwsConnectionErrorMsg:
119 t = connectionErrorMessageType
120 case graphqlwsDataMsg:
121 t = dataMessageType
122 case graphqlwsErrorMsg:
123 t = errorMessageType
124 case graphqlwsCompleteMsg:
125 t = completeMessageType
126 case graphqlwsConnectionKeepAliveMsg:
127 t = keepAliveMessageType
128 }
129
130 return message{
131 payload: m.Payload,
132 id: m.ID,
133 t: t,
134 }, err
135 }
136
137 func (m *graphqlwsMessage) fromMessage(msg *message) (err error) {
138 m.ID = msg.id
139 m.Payload = msg.payload
140
141 switch msg.t {
142 default:
143 err = fmt.Errorf("invalid server->client message type %s", msg.t)
144 case initMessageType:
145 m.Type = graphqlwsConnectionInitMsg
146 case connectionAckMessageType:
147 m.Type = graphqlwsConnectionAckMsg
148 case keepAliveMessageType:
149 m.Type = graphqlwsConnectionKeepAliveMsg
150 case connectionErrorMessageType:
151 m.Type = graphqlwsConnectionErrorMsg
152 case connectionCloseMessageType:
153 m.Type = graphqlwsConnectionTerminateMsg
154 case startMessageType:
155 m.Type = graphqlwsStartMsg
156 case stopMessageType:
157 m.Type = graphqlwsStopMsg
158 case dataMessageType:
159 m.Type = graphqlwsDataMsg
160 case completeMessageType:
161 m.Type = graphqlwsCompleteMsg
162 case errorMessageType:
163 m.Type = graphqlwsErrorMsg
164 case pingMessageType:
165 m.noOp = true
166 case pongMessageType:
167 m.noOp = true
168 }
169
170 return err
171 }
172
View as plain text