1
2
3
4 package websocket
5
6 import (
7 "bufio"
8 "context"
9 "crypto/rand"
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 "net"
15 "time"
16
17 "compress/flate"
18
19 "nhooyr.io/websocket/internal/errd"
20 "nhooyr.io/websocket/internal/util"
21 )
22
23
24
25
26
27
28
29
30 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
31 w, err := c.writer(ctx, typ)
32 if err != nil {
33 return nil, fmt.Errorf("failed to get writer: %w", err)
34 }
35 return w, nil
36 }
37
38
39
40
41
42
43
44 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
45 _, err := c.write(ctx, typ, p)
46 if err != nil {
47 return fmt.Errorf("failed to write msg: %w", err)
48 }
49 return nil
50 }
51
// msgWriter writes one WebSocket message at a time on behalf of a Conn,
// optionally compressing the message data with compress/flate.
type msgWriter struct {
	c *Conn

	// mu is held for the whole lifetime of a message: it is acquired in
	// reset and released when the message is finished (in Close, or by
	// Conn.write's single-frame path).
	mu *mu
	// writeMu serializes calls to Write and Close on this writer.
	writeMu *mu
	// closed reports whether Close has already been called for the current
	// message.
	closed bool

	// ctx bounds every frame write of the current message.
	ctx context.Context
	// opcode is used for the next frame; it becomes opContinuation after the
	// first data frame has been written.
	opcode opcode
	// flate reports whether the current message is being compressed.
	flate bool

	// trimWriter receives the flate output and forwards it to mw.write; by
	// its name it trims the last four bytes of that output — confirm against
	// its definition.
	trimWriter *trimLastFourBytesWriter
	// flateWriter compresses message data into trimWriter. It is kept across
	// messages when context takeover is enabled and handed back via
	// putFlateWriter otherwise.
	flateWriter *flate.Writer
}
66
67 func newMsgWriter(c *Conn) *msgWriter {
68 mw := &msgWriter{
69 c: c,
70 mu: newMu(c),
71 writeMu: newMu(c),
72 }
73 return mw
74 }
75
76 func (mw *msgWriter) ensureFlate() {
77 if mw.trimWriter == nil {
78 mw.trimWriter = &trimLastFourBytesWriter{
79 w: util.WriterFunc(mw.write),
80 }
81 }
82
83 if mw.flateWriter == nil {
84 mw.flateWriter = getFlateWriter(mw.trimWriter)
85 }
86 mw.flate = true
87 }
88
89 func (mw *msgWriter) flateContextTakeover() bool {
90 if mw.c.client {
91 return !mw.c.copts.clientNoContextTakeover
92 }
93 return !mw.c.copts.serverNoContextTakeover
94 }
95
96 func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
97 err := c.msgWriter.reset(ctx, typ)
98 if err != nil {
99 return nil, err
100 }
101 return c.msgWriter, nil
102 }
103
104 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
105 mw, err := c.writer(ctx, typ)
106 if err != nil {
107 return 0, err
108 }
109
110 if !c.flate() {
111 defer c.msgWriter.mu.unlock()
112 return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
113 }
114
115 n, err := mw.Write(p)
116 if err != nil {
117 return n, err
118 }
119
120 err = mw.Close()
121 return n, err
122 }
123
124 func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
125 err := mw.mu.lock(ctx)
126 if err != nil {
127 return err
128 }
129
130 mw.ctx = ctx
131 mw.opcode = opcode(typ)
132 mw.flate = false
133 mw.closed = false
134
135 mw.trimWriter.reset()
136
137 return nil
138 }
139
140 func (mw *msgWriter) putFlateWriter() {
141 if mw.flateWriter != nil {
142 putFlateWriter(mw.flateWriter)
143 mw.flateWriter = nil
144 }
145 }
146
147
148 func (mw *msgWriter) Write(p []byte) (_ int, err error) {
149 err = mw.writeMu.lock(mw.ctx)
150 if err != nil {
151 return 0, fmt.Errorf("failed to write: %w", err)
152 }
153 defer mw.writeMu.unlock()
154
155 if mw.closed {
156 return 0, errors.New("cannot use closed writer")
157 }
158
159 defer func() {
160 if err != nil {
161 err = fmt.Errorf("failed to write: %w", err)
162 mw.c.close(err)
163 }
164 }()
165
166 if mw.c.flate() {
167
168
169 if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
170 mw.ensureFlate()
171 }
172 }
173
174 if mw.flate {
175 return mw.flateWriter.Write(p)
176 }
177
178 return mw.write(p)
179 }
180
181 func (mw *msgWriter) write(p []byte) (int, error) {
182 n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
183 if err != nil {
184 return n, fmt.Errorf("failed to write data frame: %w", err)
185 }
186 mw.opcode = opContinuation
187 return n, nil
188 }
189
190
191 func (mw *msgWriter) Close() (err error) {
192 defer errd.Wrap(&err, "failed to close writer")
193
194 err = mw.writeMu.lock(mw.ctx)
195 if err != nil {
196 return err
197 }
198 defer mw.writeMu.unlock()
199
200 if mw.closed {
201 return errors.New("writer already closed")
202 }
203 mw.closed = true
204
205 if mw.flate {
206 err = mw.flateWriter.Flush()
207 if err != nil {
208 return fmt.Errorf("failed to flush flate: %w", err)
209 }
210 }
211
212 _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
213 if err != nil {
214 return fmt.Errorf("failed to write fin frame: %w", err)
215 }
216
217 if mw.flate && !mw.flateContextTakeover() {
218 mw.putFlateWriter()
219 }
220 mw.mu.unlock()
221 return nil
222 }
223
224 func (mw *msgWriter) close() {
225 if mw.c.client {
226 mw.c.writeFrameMu.forceLock()
227 putBufioWriter(mw.c.bw)
228 }
229
230 mw.writeMu.forceLock()
231 mw.putFlateWriter()
232 }
233
234 func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
235 ctx, cancel := context.WithTimeout(ctx, time.Second*5)
236 defer cancel()
237
238 _, err := c.writeFrame(ctx, true, false, opcode, p)
239 if err != nil {
240 return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
241 }
242 return nil
243 }
244
245
// writeFrame writes a single WebSocket frame (header plus payload p) into the
// connection's buffered writer c.bw. Every data, FIN and control frame in this
// file goes through it. It returns the number of payload bytes written.
//
// fin sets the FIN bit. flate requests the RSV1 bit, which is only applied to
// opText and opBinary frames. Client frames are masked with a fresh key read
// from crypto/rand. c.bw is flushed only when fin is set, so non-final frames
// may remain buffered.
//
// Once the frame is being written (after ctx has been handed to
// c.writeTimeout), any error closes the connection. The error is reported as
// net.ErrClosed or ctx.Err() when the connection is closed or ctx is done,
// and is wrapped as "failed to write frame: ...".
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
	err = c.writeFrameMu.lock(ctx)
	if err != nil {
		return 0, err
	}

	// Once a close frame has been written, only a close frame may still pass.
	// Any other frame releases the lock and waits for the connection to close
	// (or for ctx to end) instead of writing.
	c.closeMu.Lock()
	wroteClose := c.wroteClose
	c.closeMu.Unlock()
	if wroteClose && opcode != opClose {
		c.writeFrameMu.unlock()
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-c.closed:
			return 0, net.ErrClosed
		}
	}
	defer c.writeFrameMu.unlock()

	// Hand ctx to the receiver of c.writeTimeout (not visible here) —
	// presumably something that enforces ctx's deadline on the write.
	select {
	case <-c.closed:
		return 0, net.ErrClosed
	case c.writeTimeout <- ctx:
	}

	// On failure, prefer reporting a closed connection or an ended ctx over
	// the raw I/O error, then tear the connection down and wrap the error.
	defer func() {
		if err != nil {
			select {
			case <-c.closed:
				err = net.ErrClosed
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}
			c.close(err)
			err = fmt.Errorf("failed to write frame: %w", err)
		}
	}()

	c.writeHeader.fin = fin
	c.writeHeader.opcode = opcode
	c.writeHeader.payloadLength = int64(len(p))

	if c.client {
		c.writeHeader.masked = true
		// writeHeaderBuf doubles as scratch space for the 4 random key bytes
		// before it is used to encode the header below.
		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
		if err != nil {
			return 0, fmt.Errorf("failed to generate masking key: %w", err)
		}
		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
	}

	// RSV1 marks a compressed frame; it is never set on continuation or
	// control frames.
	c.writeHeader.rsv1 = false
	if flate && (opcode == opText || opcode == opBinary) {
		c.writeHeader.rsv1 = true
	}

	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
	if err != nil {
		return 0, err
	}

	n, err := c.writeFramePayload(p)
	if err != nil {
		return n, err
	}

	// Only final frames are pushed out immediately; earlier frames of a
	// message may sit in the buffer.
	if c.writeHeader.fin {
		err = c.bw.Flush()
		if err != nil {
			return n, fmt.Errorf("failed to flush: %w", err)
		}
	}

	// Replace ctx on c.writeTimeout with context.Background(), presumably to
	// disarm the timeout. If the connection closed meanwhile, a close frame
	// still counts as successfully written.
	select {
	case <-c.closed:
		if opcode == opClose {
			return n, nil
		}
		return n, net.ErrClosed
	case c.writeTimeout <- context.Background():
	}

	return n, nil
}
337
338 func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
339 defer errd.Wrap(&err, "failed to write frame payload")
340
341 if !c.writeHeader.masked {
342 return c.bw.Write(p)
343 }
344
345 maskKey := c.writeHeader.maskKey
346 for len(p) > 0 {
347
348 if c.bw.Available() == 0 {
349 err = c.bw.Flush()
350 if err != nil {
351 return n, err
352 }
353 }
354
355
356 i := c.bw.Buffered()
357
358 j := len(p)
359 if j > c.bw.Available() {
360 j = c.bw.Available()
361 }
362
363 _, err := c.bw.Write(p[:j])
364 if err != nil {
365 return n, err
366 }
367
368 maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
369
370 p = p[j:]
371 n += j
372 }
373
374 return n, nil
375 }
376
377
378
379 func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
380 var writeBuf []byte
381 bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
382 writeBuf = p2[:cap(p2)]
383 return len(p2), nil
384 }))
385
386 bw.WriteByte(0)
387 bw.Flush()
388
389 bw.Reset(w)
390
391 return writeBuf
392 }
393
394 func (c *Conn) writeError(code StatusCode, err error) {
395 c.setCloseErr(err)
396 c.writeClose(code, err.Error())
397 c.close(nil)
398 }
399
View as plain text