1 package websocket
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "net"
10 "net/http"
11 "reflect"
12 "runtime"
13 "strings"
14 "sync"
15 "syscall/js"
16
17 "nhooyr.io/websocket/internal/bpool"
18 "nhooyr.io/websocket/internal/wsjs"
19 "nhooyr.io/websocket/internal/xsync"
20 )
21
22
// opcode is the numeric identifier of a WebSocket frame type.
type opcode int

// Frame opcodes. The values 3 through 7 are deliberately left unassigned,
// so the close/ping/pong opcodes start at 8.
const (
	opContinuation opcode = 0
	opText         opcode = 1
	opBinary       opcode = 2

	// 3 - 7 are skipped.

	opClose opcode = 8
	opPing  opcode = 9
	opPong  opcode = 10
)
41
42
43 type Conn struct {
44 noCopy noCopy
45 ws wsjs.WebSocket
46
47
48 msgReadLimit xsync.Int64
49
50 wg sync.WaitGroup
51 closingMu sync.Mutex
52 isReadClosed xsync.Int64
53 closeOnce sync.Once
54 closed chan struct{}
55 closeErrOnce sync.Once
56 closeErr error
57 closeWasClean bool
58
59 releaseOnClose func()
60 releaseOnError func()
61 releaseOnMessage func()
62
63 readSignal chan struct{}
64 readBufMu sync.Mutex
65 readBuf []wsjs.MessageEvent
66 }
67
68 func (c *Conn) close(err error, wasClean bool) {
69 c.closeOnce.Do(func() {
70 runtime.SetFinalizer(c, nil)
71
72 if !wasClean {
73 err = fmt.Errorf("unclean connection close: %w", err)
74 }
75 c.setCloseErr(err)
76 c.closeWasClean = wasClean
77 close(c.closed)
78 })
79 }
80
81 func (c *Conn) init() {
82 c.closed = make(chan struct{})
83 c.readSignal = make(chan struct{}, 1)
84
85 c.msgReadLimit.Store(32768)
86
87 c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
88 err := CloseError{
89 Code: StatusCode(e.Code),
90 Reason: e.Reason,
91 }
92
93
94
95 c.close(err, e.WasClean)
96
97 c.releaseOnClose()
98 c.releaseOnError()
99 c.releaseOnMessage()
100 })
101
102 c.releaseOnError = c.ws.OnError(func(v js.Value) {
103 c.setCloseErr(errors.New(v.Get("message").String()))
104 c.closeWithInternal()
105 })
106
107 c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
108 c.readBufMu.Lock()
109 defer c.readBufMu.Unlock()
110
111 c.readBuf = append(c.readBuf, e)
112
113
114 select {
115 case c.readSignal <- struct{}{}:
116 default:
117 }
118 })
119
120 runtime.SetFinalizer(c, func(c *Conn) {
121 c.setCloseErr(errors.New("connection garbage collected"))
122 c.closeWithInternal()
123 })
124 }
125
126 func (c *Conn) closeWithInternal() {
127 c.Close(StatusInternalError, "something went wrong")
128 }
129
130
131
132 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
133 if c.isReadClosed.Load() == 1 {
134 return 0, nil, errors.New("WebSocket connection read closed")
135 }
136
137 typ, p, err := c.read(ctx)
138 if err != nil {
139 return 0, nil, fmt.Errorf("failed to read: %w", err)
140 }
141 readLimit := c.msgReadLimit.Load()
142 if readLimit >= 0 && int64(len(p)) > readLimit {
143 err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
144 c.Close(StatusMessageTooBig, err.Error())
145 return 0, nil, err
146 }
147 return typ, p, nil
148 }
149
150 func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
151 select {
152 case <-ctx.Done():
153 c.Close(StatusPolicyViolation, "read timed out")
154 return 0, nil, ctx.Err()
155 case <-c.readSignal:
156 case <-c.closed:
157 return 0, nil, net.ErrClosed
158 }
159
160 c.readBufMu.Lock()
161 defer c.readBufMu.Unlock()
162
163 me := c.readBuf[0]
164
165
166 copy(c.readBuf, c.readBuf[1:])
167 c.readBuf = c.readBuf[:len(c.readBuf)-1]
168
169 if len(c.readBuf) > 0 {
170
171 select {
172 case c.readSignal <- struct{}{}:
173 default:
174 }
175 }
176
177 switch p := me.Data.(type) {
178 case string:
179 return MessageText, []byte(p), nil
180 case []byte:
181 return MessageBinary, p, nil
182 default:
183 panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
184 }
185 }
186
187
// Ping performs no work in this implementation: ctx is ignored and the
// returned error is always nil.
func (c *Conn) Ping(ctx context.Context) error {
	return nil
}
191
192
193
194 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
195 err := c.write(ctx, typ, p)
196 if err != nil {
197
198
199
200
201 err := fmt.Errorf("failed to write: %w", err)
202 c.setCloseErr(err)
203 c.closeWithInternal()
204 return err
205 }
206 return nil
207 }
208
209 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
210 if c.isClosed() {
211 return net.ErrClosed
212 }
213 switch typ {
214 case MessageBinary:
215 return c.ws.SendBytes(p)
216 case MessageText:
217 return c.ws.SendText(string(p))
218 default:
219 return fmt.Errorf("unexpected message type: %v", typ)
220 }
221 }
222
223
224
225
226
227 func (c *Conn) Close(code StatusCode, reason string) error {
228 defer c.wg.Wait()
229 err := c.exportedClose(code, reason)
230 if err != nil {
231 return fmt.Errorf("failed to close WebSocket: %w", err)
232 }
233 return nil
234 }
235
236
237
238
239
240
241 func (c *Conn) CloseNow() error {
242 defer c.wg.Wait()
243 return c.Close(StatusGoingAway, "")
244 }
245
246 func (c *Conn) exportedClose(code StatusCode, reason string) error {
247 c.closingMu.Lock()
248 defer c.closingMu.Unlock()
249
250 if c.isClosed() {
251 return net.ErrClosed
252 }
253
254 ce := fmt.Errorf("sent close: %w", CloseError{
255 Code: code,
256 Reason: reason,
257 })
258
259 c.setCloseErr(ce)
260 err := c.ws.Close(int(code), reason)
261 if err != nil {
262 return err
263 }
264
265 <-c.closed
266 if !c.closeWasClean {
267 return c.closeErr
268 }
269 return nil
270 }
271
272
273
274 func (c *Conn) Subprotocol() string {
275 return c.ws.Subprotocol()
276 }
277
278
// DialOptions holds the options accepted by Dial.
type DialOptions struct {
	// Subprotocols is passed to wsjs.New when the connection is created.
	Subprotocols []string
}
283
284
285
286
287
288 func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
289 c, resp, err := dial(ctx, url, opts)
290 if err != nil {
291 return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
292 }
293 return c, resp, nil
294 }
295
296 func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
297 if opts == nil {
298 opts = &DialOptions{}
299 }
300
301 url = strings.Replace(url, "http://", "ws://", 1)
302 url = strings.Replace(url, "https://", "wss://", 1)
303
304 ws, err := wsjs.New(url, opts.Subprotocols)
305 if err != nil {
306 return nil, nil, err
307 }
308
309 c := &Conn{
310 ws: ws,
311 }
312 c.init()
313
314 opench := make(chan struct{})
315 releaseOpen := ws.OnOpen(func(e js.Value) {
316 close(opench)
317 })
318 defer releaseOpen()
319
320 select {
321 case <-ctx.Done():
322 c.Close(StatusPolicyViolation, "dial timed out")
323 return nil, nil, ctx.Err()
324 case <-opench:
325 return c, &http.Response{
326 StatusCode: http.StatusSwitchingProtocols,
327 }, nil
328 case <-c.closed:
329 return nil, nil, net.ErrClosed
330 }
331 }
332
333
334
335 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
336 typ, p, err := c.Read(ctx)
337 if err != nil {
338 return 0, nil, err
339 }
340 return typ, bytes.NewReader(p), nil
341 }
342
343
344
345
346 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
347 return &writer{
348 c: c,
349 ctx: ctx,
350 typ: typ,
351 b: bpool.Get(),
352 }, nil
353 }
354
355 type writer struct {
356 closed bool
357
358 c *Conn
359 ctx context.Context
360 typ MessageType
361
362 b *bytes.Buffer
363 }
364
365 func (w *writer) Write(p []byte) (int, error) {
366 if w.closed {
367 return 0, errors.New("cannot write to closed writer")
368 }
369 n, err := w.b.Write(p)
370 if err != nil {
371 return n, fmt.Errorf("failed to write message: %w", err)
372 }
373 return n, nil
374 }
375
376 func (w *writer) Close() error {
377 if w.closed {
378 return errors.New("cannot close closed writer")
379 }
380 w.closed = true
381 defer bpool.Put(w.b)
382
383 err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
384 if err != nil {
385 return fmt.Errorf("failed to close writer: %w", err)
386 }
387 return nil
388 }
389
390
391 func (c *Conn) CloseRead(ctx context.Context) context.Context {
392 c.isReadClosed.Store(1)
393
394 ctx, cancel := context.WithCancel(ctx)
395 c.wg.Add(1)
396 go func() {
397 defer c.CloseNow()
398 defer c.wg.Done()
399 defer cancel()
400 _, _, err := c.read(ctx)
401 if err != nil {
402 c.Close(StatusPolicyViolation, "unexpected data message")
403 }
404 }()
405 return ctx
406 }
407
408
// SetReadLimit stores n as the maximum message size, in bytes, that Read
// accepts. Read skips the size check when the stored limit is negative.
func (c *Conn) SetReadLimit(n int64) {
	c.msgReadLimit.Store(n)
}
412
413 func (c *Conn) setCloseErr(err error) {
414 c.closeErrOnce.Do(func() {
415 c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
416 })
417 }
418
419 func (c *Conn) isClosed() bool {
420 select {
421 case <-c.closed:
422 return true
423 default:
424 return false
425 }
426 }
427
428
// AcceptOptions holds the options accepted by Accept.
//
// NOTE(review): Accept in this file always returns an "unimplemented"
// error, so none of these fields are read here; the type presumably
// exists to keep the API compatible with other builds — confirm.
type AcceptOptions struct {
	Subprotocols         []string
	InsecureSkipVerify   bool
	OriginPatterns       []string
	CompressionMode      CompressionMode
	CompressionThreshold int
}
436
437
438 func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
439 return nil, errors.New("unimplemented")
440 }
441
442
443
// StatusCode is the numeric status code carried by a WebSocket close.
type StatusCode int

// Close status codes. The values are consecutive, starting at 1000.
const (
	StatusNormalClosure StatusCode = 1000 + iota
	StatusGoingAway
	StatusProtocolError
	StatusUnsupportedData

	// statusReserved (1004) is kept unexported.
	statusReserved

	StatusNoStatusRcvd

	StatusAbnormalClosure

	StatusInvalidFramePayloadData
	StatusPolicyViolation
	StatusMessageTooBig
	StatusMandatoryExtension
	StatusInternalError
	StatusServiceRestart
	StatusTryAgainLater
	StatusBadGateway

	StatusTLSHandshake
)
486
487
488
489
490
491 type CloseError struct {
492 Code StatusCode
493 Reason string
494 }
495
496 func (ce CloseError) Error() string {
497 return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
498 }
499
500
501
502
503
504 func CloseStatus(err error) StatusCode {
505 var ce CloseError
506 if errors.As(err, &ce) {
507 return ce.Code
508 }
509 return -1
510 }
511
512
513
514
// CompressionMode enumerates the compression modes referenced by
// AcceptOptions.
type CompressionMode int

// Compression modes. CompressionNoContextTakeover is the zero value.
const (
	CompressionNoContextTakeover CompressionMode = 0

	CompressionContextTakeover CompressionMode = 1

	CompressionDisabled CompressionMode = 2
)
544
545
546
// MessageType identifies whether a message is text or binary.
type MessageType int

// Message types. Zero is not a valid MessageType.
const (
	// MessageText marks a text message.
	MessageText MessageType = 1

	// MessageBinary marks a binary message.
	MessageBinary MessageType = 2
)
556
557 type mu struct {
558 c *Conn
559 ch chan struct{}
560 }
561
562 func newMu(c *Conn) *mu {
563 return &mu{
564 c: c,
565 ch: make(chan struct{}, 1),
566 }
567 }
568
569 func (m *mu) forceLock() {
570 m.ch <- struct{}{}
571 }
572
573 func (m *mu) tryLock() bool {
574 select {
575 case m.ch <- struct{}{}:
576 return true
577 default:
578 return false
579 }
580 }
581
582 func (m *mu) unlock() {
583 select {
584 case <-m.ch:
585 default:
586 }
587 }
588
// noCopy may be embedded in a struct that must not be copied after first
// use. It has no size and its methods do nothing; it exists only so that
// the copylocks check of go vet flags copies of the enclosing struct.
type noCopy struct{}

// Lock is a no-op used by the copylocks checker of go vet.
func (*noCopy) Lock() {}

// Unlock is a no-op. The copylocks checker only recognises types whose
// pointer implements sync.Locker, which requires both Lock and Unlock.
func (*noCopy) Unlock() {}
592
View as plain text