1
2
3
4
5
6
7
8
9
10
11
12 package websocket
13
14 import (
15 "bufio"
16 "crypto/tls"
17 "encoding/json"
18 "errors"
19 "io"
20 "net"
21 "net/http"
22 "net/url"
23 "sync"
24 "time"
25 )
26
// Protocol version identifiers.
const (
	// ProtocolVersionHybi13 is protocol version 13 as a number.
	ProtocolVersionHybi13 = 13
	// ProtocolVersionHybi is an alias for ProtocolVersionHybi13.
	ProtocolVersionHybi = ProtocolVersionHybi13
	// SupportedProtocolVersion is protocol version 13 written as a string.
	SupportedProtocolVersion = "13"
)

// Frame payload types. Codecs report one of these alongside marshaled data;
// UnknownFrame accompanies an unsupported value.
const (
	ContinuationFrame = 0
	TextFrame         = 1
	BinaryFrame       = 2
	CloseFrame        = 8
	PingFrame         = 9
	PongFrame         = 10
	UnknownFrame      = 255
)

// DefaultMaxPayloadBytes (32 MiB) is the payload limit Codec.Receive applies
// when Conn.MaxPayloadBytes is zero.
const DefaultMaxPayloadBytes = 32 << 20
42
43
// ProtocolError is the error type behind this package's sentinel protocol
// and handshake errors.
type ProtocolError struct {
	ErrorString string
}

// Error returns ErrorString unchanged, satisfying the error interface.
func (e *ProtocolError) Error() string {
	msg := e.ErrorString
	return msg
}
49
50 var (
51 ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
52 ErrBadScheme = &ProtocolError{"bad scheme"}
53 ErrBadStatus = &ProtocolError{"bad status"}
54 ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
55 ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
56 ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
57 ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
58 ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
59 ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
60 ErrBadFrame = &ProtocolError{"bad frame"}
61 ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
62 ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
63 ErrBadRequestMethod = &ProtocolError{"bad method"}
64 ErrNotSupported = &ProtocolError{"not supported"}
65 )
66
67
68
var (
	// ErrFrameTooLarge is returned by Codec.Receive when a frame's declared
	// payload length exceeds Conn.MaxPayloadBytes (DefaultMaxPayloadBytes
	// when that field is zero).
	ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
)
70
71
// Addr is a net.Addr whose address is a URL. The embedded *url.URL provides
// String; Network below completes the net.Addr method set.
type Addr struct {
	*url.URL
}

// Network always reports the network name "websocket".
func (a *Addr) Network() string {
	const network = "websocket"
	return network
}
78
79
// Config holds the settings of a WebSocket connection.
type Config struct {
	// Location is the WebSocket URL. A client connection reports it as its
	// RemoteAddr; a server connection reports it as its LocalAddr.
	Location *url.URL

	// Origin is the origin URL. A client connection reports it as its
	// LocalAddr; a server connection reports it as its RemoteAddr.
	Origin *url.URL

	// Protocol lists subprotocol names. Presumably these are offered or
	// negotiated during the handshake; TODO confirm against the handshake code.
	Protocol []string

	// Version is the protocol version. It is presumably one of the
	// ProtocolVersion* constants (e.g. ProtocolVersionHybi13); verify.
	Version int

	// TlsConfig is a TLS configuration. It is presumably used for secure
	// endpoints and is not referenced in this part of the file.
	TlsConfig *tls.Config

	// Header holds additional HTTP header fields. They are presumably sent
	// with the handshake; not referenced in this part of the file.
	Header http.Header

	// Dialer is an optional net.Dialer. It presumably opens the client's
	// transport; not referenced in this part of the file.
	Dialer *net.Dialer

	// handshakeData is internal handshake state. Its keys are not visible
	// from this part of the file.
	handshakeData map[string]string
}
104
105
106 type serverHandshaker interface {
107
108
109 ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
110
111
112
113 AcceptHandshake(buf *bufio.Writer) (err error)
114
115
116 NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
117 }
118
119
// frameReader reads the payload of one incoming frame.
type frameReader interface {
	// Read reads payload bytes. Conn.Read treats io.EOF as the end of the
	// frame.
	Read(p []byte) (n int, err error)

	// PayloadType reports the frame's payload type (e.g. TextFrame).
	PayloadType() byte

	// HeaderReader returns a reader over the frame header, if available.
	HeaderReader() io.Reader

	// TrailerReader returns a reader over data that follows the payload, or
	// nil. Conn.Read drains it when it is non-nil.
	TrailerReader() io.Reader

	// Len reports a length for the frame. NOTE(review): whether this counts
	// header bytes in addition to the payload is not visible here.
	Len() int
}
137
138
139 type frameReaderFactory interface {
140 NewFrameReader() (r frameReader, err error)
141 }
142
143
// frameWriter writes one outgoing frame. Callers in this file call Close
// after writing the payload.
type frameWriter interface {
	io.Writer
	io.Closer
}
148
149
150 type frameWriterFactory interface {
151 NewFrameWriter(payloadType byte) (w frameWriter, err error)
152 }
153
154 type frameHandler interface {
155 HandleFrame(frame frameReader) (r frameReader, err error)
156 WriteClose(status int) (err error)
157 }
158
159
160
161
// Conn is a WebSocket connection. Its methods (Read, Write, Close, LocalAddr,
// RemoteAddr and the three deadline setters) cover the net.Conn method set.
//
// Read and Codec.Receive hold rio; Write and Codec.Send hold wio. Because
// these are separate mutexes, one reader and one writer can proceed at the
// same time. Close takes neither mutex.
type Conn struct {
	// config supplies the URLs reported by LocalAddr/RemoteAddr; Config
	// returns it.
	config *Config
	// request is the server-side HTTP request. It is nil on a client
	// connection, which is what IsClientConn/IsServerConn test.
	request *http.Request

	// buf is a buffered reader/writer, presumably layered over rwc. It is not
	// used in this part of the file; TODO confirm.
	buf *bufio.ReadWriter
	// rwc is the underlying transport. Close closes it, and the deadline
	// setters forward to it when it is a net.Conn.
	rwc io.ReadWriteCloser

	// rio guards frameReader and the reading side.
	rio sync.Mutex
	// frameReaderFactory produces a reader for each new incoming frame.
	frameReaderFactory
	// frameReader is the partially consumed current frame, or nil between
	// frames. Read resumes it and Receive drains it first. Receive also parks
	// an oversized frame here so the next call can drain it.
	frameReader

	// wio serializes Write and Codec.Send.
	wio sync.Mutex
	// frameWriterFactory opens a frameWriter for each outgoing frame.
	frameWriterFactory

	// frameHandler decides which incoming frames are delivered and writes
	// close frames.
	frameHandler
	// PayloadType is the payload type Write uses for every frame it sends.
	PayloadType byte
	// defaultCloseStatus is the status Close passes to WriteClose.
	defaultCloseStatus int

	// MaxPayloadBytes caps the declared payload length Codec.Receive accepts
	// for a hybi frame; larger frames yield ErrFrameTooLarge. Zero means
	// DefaultMaxPayloadBytes. Conn.Read does not consult this limit.
	MaxPayloadBytes int
}
184
185
186
187
188
189
190 func (ws *Conn) Read(msg []byte) (n int, err error) {
191 ws.rio.Lock()
192 defer ws.rio.Unlock()
193 again:
194 if ws.frameReader == nil {
195 frame, err := ws.frameReaderFactory.NewFrameReader()
196 if err != nil {
197 return 0, err
198 }
199 ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
200 if err != nil {
201 return 0, err
202 }
203 if ws.frameReader == nil {
204 goto again
205 }
206 }
207 n, err = ws.frameReader.Read(msg)
208 if err == io.EOF {
209 if trailer := ws.frameReader.TrailerReader(); trailer != nil {
210 io.Copy(io.Discard, trailer)
211 }
212 ws.frameReader = nil
213 goto again
214 }
215 return n, err
216 }
217
218
219
220 func (ws *Conn) Write(msg []byte) (n int, err error) {
221 ws.wio.Lock()
222 defer ws.wio.Unlock()
223 w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
224 if err != nil {
225 return 0, err
226 }
227 n, err = w.Write(msg)
228 w.Close()
229 return n, err
230 }
231
232
233 func (ws *Conn) Close() error {
234 err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
235 err1 := ws.rwc.Close()
236 if err != nil {
237 return err
238 }
239 return err1
240 }
241
242
243 func (ws *Conn) IsClientConn() bool { return ws.request == nil }
244
245
246 func (ws *Conn) IsServerConn() bool { return ws.request != nil }
247
248
249
250 func (ws *Conn) LocalAddr() net.Addr {
251 if ws.IsClientConn() {
252 return &Addr{ws.config.Origin}
253 }
254 return &Addr{ws.config.Location}
255 }
256
257
258
259 func (ws *Conn) RemoteAddr() net.Addr {
260 if ws.IsClientConn() {
261 return &Addr{ws.config.Location}
262 }
263 return &Addr{ws.config.Origin}
264 }
265
var (
	// errSetDeadline is returned by the deadline setters when the underlying
	// transport is not a net.Conn.
	errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
)
267
268
269 func (ws *Conn) SetDeadline(t time.Time) error {
270 if conn, ok := ws.rwc.(net.Conn); ok {
271 return conn.SetDeadline(t)
272 }
273 return errSetDeadline
274 }
275
276
277 func (ws *Conn) SetReadDeadline(t time.Time) error {
278 if conn, ok := ws.rwc.(net.Conn); ok {
279 return conn.SetReadDeadline(t)
280 }
281 return errSetDeadline
282 }
283
284
285 func (ws *Conn) SetWriteDeadline(t time.Time) error {
286 if conn, ok := ws.rwc.(net.Conn); ok {
287 return conn.SetWriteDeadline(t)
288 }
289 return errSetDeadline
290 }
291
292
293 func (ws *Conn) Config() *Config { return ws.config }
294
295
296
297 func (ws *Conn) Request() *http.Request { return ws.request }
298
299
// Codec pairs an encoder with its decoder so that whole values can be sent
// and received as single frames via Send and Receive.
//
// Marshal turns a value into payload bytes plus the payload type to send
// them with. Unmarshal decodes a received payload, given its payload type,
// into v.
type Codec struct {
	Marshal   func(v interface{}) ([]byte, byte, error)
	Unmarshal func(data []byte, payloadType byte, v interface{}) error
}
304
305
306 func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
307 data, payloadType, err := cd.Marshal(v)
308 if err != nil {
309 return err
310 }
311 ws.wio.Lock()
312 defer ws.wio.Unlock()
313 w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
314 if err != nil {
315 return err
316 }
317 _, err = w.Write(data)
318 w.Close()
319 return err
320 }
321
322
323
324
325
326
327
328 func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
329 ws.rio.Lock()
330 defer ws.rio.Unlock()
331 if ws.frameReader != nil {
332 _, err = io.Copy(io.Discard, ws.frameReader)
333 if err != nil {
334 return err
335 }
336 ws.frameReader = nil
337 }
338 again:
339 frame, err := ws.frameReaderFactory.NewFrameReader()
340 if err != nil {
341 return err
342 }
343 frame, err = ws.frameHandler.HandleFrame(frame)
344 if err != nil {
345 return err
346 }
347 if frame == nil {
348 goto again
349 }
350 maxPayloadBytes := ws.MaxPayloadBytes
351 if maxPayloadBytes == 0 {
352 maxPayloadBytes = DefaultMaxPayloadBytes
353 }
354 if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
355
356
357
358
359
360 ws.frameReader = frame
361 return ErrFrameTooLarge
362 }
363 payloadType := frame.PayloadType()
364 data, err := io.ReadAll(frame)
365 if err != nil {
366 return err
367 }
368 return cd.Unmarshal(data, payloadType, v)
369 }
370
371 func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
372 switch data := v.(type) {
373 case string:
374 return []byte(data), TextFrame, nil
375 case []byte:
376 return data, BinaryFrame, nil
377 }
378 return nil, UnknownFrame, ErrNotSupported
379 }
380
381 func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
382 switch data := v.(type) {
383 case *string:
384 *data = string(msg)
385 return nil
386 case *[]byte:
387 *data = msg
388 return nil
389 }
390 return ErrNotSupported
391 }
392
393
418 var Message = Codec{marshal, unmarshal}
419
420 func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
421 msg, err = json.Marshal(v)
422 return msg, TextFrame, err
423 }
424
// jsonUnmarshal is JSON's decoder. It decodes msg into v with encoding/json.
// The payload type is not consulted, so binary frames are decoded too.
func jsonUnmarshal(msg []byte, _ byte, v interface{}) error {
	decodeErr := json.Unmarshal(msg, v)
	return decodeErr
}
428
429
448 var JSON = Codec{jsonMarshal, jsonUnmarshal}
449
View as plain text