...
1
2
3
4 package websocket
5
6 import (
7 "bufio"
8 "context"
9 "errors"
10 "fmt"
11 "io"
12 "net"
13 "runtime"
14 "strconv"
15 "sync"
16 "sync/atomic"
17 )
18
19
20
// MessageType identifies the kind of a WebSocket data message.
type MessageType int

// Data message kinds. Numbering starts at 1, so the zero MessageType does
// not name either kind.
const (
	// MessageText identifies a text message.
	MessageText MessageType = 1
	// MessageBinary identifies a binary message.
	MessageBinary MessageType = 2
)
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Conn is a WebSocket connection layered over an io.ReadWriteCloser.
// Construct it with newConn; the zero value is not usable (channels and maps
// are allocated there).
type Conn struct {
	// noCopy is a marker field; presumably present so go vet flags copies
	// of Conn — confirm that noCopy satisfies vet's lock detection.
	noCopy noCopy

	// Configuration copied from connConfig by newConn.
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions // nil means compression is off (see flate)
	flateThreshold int
	br             *bufio.Reader
	bw             *bufio.Writer

	// Unbuffered channels on which the current read/write deadline
	// contexts are handed to timeoutLoop.
	readTimeout  chan context.Context
	writeTimeout chan context.Context

	// Read state. readMu is created by newMu in newConn.
	readMu            *mu
	readHeaderBuf     [8]byte
	readControlBuf    [maxControlPayload]byte
	msgReader         *msgReader
	readCloseFrameErr error

	// Write state. writeFrameMu is created by newMu in newConn; writeBuf
	// is only set for client connections.
	msgWriter      *msgWriter
	writeFrameMu   *mu
	writeBuf       []byte
	writeHeaderBuf [8]byte
	writeHeader    header

	// Lifecycle. wg tracks goroutines started by the Conn (timeoutLoop and
	// the helpers it and close spawn). closed is closed exactly once, by
	// close, under closeMu. closeErr is the recorded close error (set via
	// setCloseErr/setCloseErrLocked, defined elsewhere). wroteClose is not
	// used in the visible code; presumably it records whether a close
	// frame was written — confirm.
	wg         sync.WaitGroup
	closed     chan struct{}
	closeMu    sync.Mutex
	closeErr   error
	wroteClose bool

	// Ping bookkeeping. pingCounter is incremented atomically by Ping to
	// build unique payloads; activePings maps a ping payload to the channel
	// that is signalled when its pong arrives, guarded by activePingsMu.
	pingCounter   int32
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}
83
// connConfig holds the parameters that newConn copies into a new Conn.
type connConfig struct {
	subprotocol    string
	rwc            io.ReadWriteCloser // underlying transport; closed by Conn.close
	client         bool               // true for the client side of the connection
	copts          *compressionOptions
	flateThreshold int // 0 lets newConn choose a default when compression is on

	br *bufio.Reader
	bw *bufio.Writer
}
94
// newConn builds a Conn from cfg, allocates its channels and locks, creates
// its message reader and writer, and starts the timeoutLoop goroutine
// (tracked by c.wg). Conn.close later closes cfg.rwc.
func newConn(cfg connConfig) *Conn {
	c := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,

		br: cfg.br,
		bw: cfg.bw,

		readTimeout:  make(chan context.Context),
		writeTimeout: make(chan context.Context),

		closed:      make(chan struct{}),
		activePings: make(map[string]chan<- struct{}),
	}

	c.readMu = newMu(c)
	c.writeFrameMu = newMu(c)

	c.msgReader = newMsgReader(c)

	// The writer must exist before the flate defaults below, which query it.
	c.msgWriter = newMsgWriter(c)
	if c.client {
		// NOTE(review): presumably reuses bw's internal buffer for client
		// frame writes — confirm against extractBufioWriterBuf.
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	// With compression on and no explicit threshold, default to 128, or to
	// 512 when the writer does not keep a flate context across messages.
	if c.flate() && c.flateThreshold == 0 {
		c.flateThreshold = 128
		if !c.msgWriter.flateContextTakeover() {
			c.flateThreshold = 512
		}
	}

	// Close the connection if the Conn is garbage collected while open.
	// NOTE(review): the timeoutLoop goroutine started below captures c and
	// runs until the connection closes (or a write times out), so c remains
	// reachable while open and this finalizer may never fire — confirm.
	runtime.SetFinalizer(c, func(c *Conn) {
		c.close(errors.New("connection garbage collected"))
	})

	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		c.timeoutLoop()
	}()

	return c
}
142
143
144
145 func (c *Conn) Subprotocol() string {
146 return c.subprotocol
147 }
148
// close tears the connection down. Under closeMu it returns early if
// isClosed reports the Conn already closed. Otherwise it does the following:
//   - records err as the close error, or, when err is nil, the error
//     returned by closing rwc;
//   - closes c.closed and clears the finalizer;
//   - closes rwc;
//   - closes the message writer and reader on a goroutine tracked by c.wg.
func (c *Conn) close(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return
	}
	if err == nil {
		err = c.rwc.Close()
	}
	c.setCloseErrLocked(err)

	close(c.closed)
	runtime.SetFinalizer(c, nil)

	// rwc is closed only after c.closed, presumably so that any goroutine
	// woken by the transport failing already observes c.closed — confirm.
	// When err was nil, rwc was already closed above; this second Close's
	// result is intentionally discarded.
	c.rwc.Close()

	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		c.msgWriter.close()
		c.msgReader.close()
	}()
}
176
177 func (c *Conn) timeoutLoop() {
178 readCtx := context.Background()
179 writeCtx := context.Background()
180
181 for {
182 select {
183 case <-c.closed:
184 return
185
186 case writeCtx = <-c.writeTimeout:
187 case readCtx = <-c.readTimeout:
188
189 case <-readCtx.Done():
190 c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
191 c.wg.Add(1)
192 go func() {
193 defer c.wg.Done()
194 c.writeError(StatusPolicyViolation, errors.New("read timed out"))
195 }()
196 case <-writeCtx.Done():
197 c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
198 return
199 }
200 }
201 }
202
203 func (c *Conn) flate() bool {
204 return c.copts != nil
205 }
206
207
208
209
210
211
212
213
214 func (c *Conn) Ping(ctx context.Context) error {
215 p := atomic.AddInt32(&c.pingCounter, 1)
216
217 err := c.ping(ctx, strconv.Itoa(int(p)))
218 if err != nil {
219 return fmt.Errorf("failed to ping: %w", err)
220 }
221 return nil
222 }
223
224 func (c *Conn) ping(ctx context.Context, p string) error {
225 pong := make(chan struct{}, 1)
226
227 c.activePingsMu.Lock()
228 c.activePings[p] = pong
229 c.activePingsMu.Unlock()
230
231 defer func() {
232 c.activePingsMu.Lock()
233 delete(c.activePings, p)
234 c.activePingsMu.Unlock()
235 }()
236
237 err := c.writeControl(ctx, opPing, []byte(p))
238 if err != nil {
239 return err
240 }
241
242 select {
243 case <-c.closed:
244 return net.ErrClosed
245 case <-ctx.Done():
246 err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
247 c.close(err)
248 return err
249 case <-pong:
250 return nil
251 }
252 }
253
// mu is a lock bound to a Conn. It is a channel with a single buffer slot:
// the lock is held while the slot is full. Unlike sync.Mutex, lock can be
// abandoned through a context and fails once the owning Conn is closed.
type mu struct {
	c  *Conn         // owning connection; its closed channel aborts lock
	ch chan struct{} // capacity-1 slot; full means locked
}
258
259 func newMu(c *Conn) *mu {
260 return &mu{
261 c: c,
262 ch: make(chan struct{}, 1),
263 }
264 }
265
266 func (m *mu) forceLock() {
267 m.ch <- struct{}{}
268 }
269
270 func (m *mu) tryLock() bool {
271 select {
272 case m.ch <- struct{}{}:
273 return true
274 default:
275 return false
276 }
277 }
278
279 func (m *mu) lock(ctx context.Context) error {
280 select {
281 case <-m.c.closed:
282 return net.ErrClosed
283 case <-ctx.Done():
284 err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
285 m.c.close(err)
286 return err
287 case m.ch <- struct{}{}:
288
289
290
291 select {
292 case <-m.c.closed:
293
294 m.unlock()
295 return net.ErrClosed
296 default:
297 }
298 return nil
299 }
300 }
301
// unlock releases m by draining its slot. The receive is non-blocking, so
// calling unlock on an mu that is not held does nothing instead of hanging.
func (m *mu) unlock() {
	select {
	case <-m.ch:
	default:
	}
}
308
// noCopy is a marker for structs that must not be copied after first use.
// go vet's copylocks analyzer treats a type T as a lock only when *T
// implements sync.Locker (both Lock and Unlock) and T does not. Both no-op
// methods below are therefore needed for the marker to have any effect.
type noCopy struct{}

// Lock is a no-op; it exists only so go vet's copylocks check sees a lock.
func (*noCopy) Lock() {}

// Unlock is a no-op; together with Lock it makes *noCopy a sync.Locker.
func (*noCopy) Unlock() {}
312
View as plain text