1
2
3
4 package websocket
5
6 import (
7 "context"
8 "encoding/binary"
9 "errors"
10 "fmt"
11 "net"
12 "time"
13
14 "nhooyr.io/websocket/internal/errd"
15 )
16
17
18
// StatusCode is the status code carried in a WebSocket close frame.
//
// The named values below follow the close-code list of RFC 6455 section 7.4
// and the IANA WebSocket Close Code Number registry. Only codes accepted by
// validWireCloseCode may be written to, or accepted from, the wire.
type StatusCode int

// Close status codes, numbered consecutively from 1000 to 1015.
const (
	StatusNormalClosure   = StatusCode(1000)
	StatusGoingAway       = StatusCode(1001)
	StatusProtocolError   = StatusCode(1002)
	StatusUnsupportedData = StatusCode(1003)

	// statusReserved is unexported because validWireCloseCode rejects it: it
	// can be neither sent nor accepted in a close frame.
	statusReserved = StatusCode(1004)

	// StatusNoStatusRcvd is never written as a code on the wire.
	// parseClosePayload reports it for an empty close payload, and writeClose
	// sends an empty payload when asked to close with it.
	StatusNoStatusRcvd = StatusCode(1005)

	// StatusAbnormalClosure is rejected by validWireCloseCode, so it never
	// appears in a close frame that is sent or accepted.
	StatusAbnormalClosure = StatusCode(1006)

	StatusInvalidFramePayloadData = StatusCode(1007)
	StatusPolicyViolation         = StatusCode(1008)
	StatusMessageTooBig           = StatusCode(1009)
	StatusMandatoryExtension      = StatusCode(1010)
	StatusInternalError           = StatusCode(1011)
	StatusServiceRestart          = StatusCode(1012)
	StatusTryAgainLater           = StatusCode(1013)
	StatusBadGateway              = StatusCode(1014)

	// StatusTLSHandshake is rejected by validWireCloseCode, so it never
	// appears in a close frame that is sent or accepted.
	StatusTLSHandshake = StatusCode(1015)
)
61
62
63
64
65
66 type CloseError struct {
67 Code StatusCode
68 Reason string
69 }
70
71 func (ce CloseError) Error() string {
72 return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
73 }
74
75
76
77
78
79 func CloseStatus(err error) StatusCode {
80 var ce CloseError
81 if errors.As(err, &ce) {
82 return ce.Code
83 }
84 return -1
85 }
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101 func (c *Conn) Close(code StatusCode, reason string) error {
102 defer c.wg.Wait()
103 return c.closeHandshake(code, reason)
104 }
105
106
107
108 func (c *Conn) CloseNow() (err error) {
109 defer c.wg.Wait()
110 defer errd.Wrap(&err, "failed to close WebSocket")
111
112 if c.isClosed() {
113 return net.ErrClosed
114 }
115
116 c.close(nil)
117 return c.closeErr
118 }
119
120 func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
121 defer errd.Wrap(&err, "failed to close WebSocket")
122
123 writeErr := c.writeClose(code, reason)
124 closeHandshakeErr := c.waitCloseHandshake()
125
126 if writeErr != nil {
127 return writeErr
128 }
129
130 if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
131 return closeHandshakeErr
132 }
133
134 return nil
135 }
136
137 func (c *Conn) writeClose(code StatusCode, reason string) error {
138 c.closeMu.Lock()
139 wroteClose := c.wroteClose
140 c.wroteClose = true
141 c.closeMu.Unlock()
142 if wroteClose {
143 return net.ErrClosed
144 }
145
146 ce := CloseError{
147 Code: code,
148 Reason: reason,
149 }
150
151 var p []byte
152 var marshalErr error
153 if ce.Code != StatusNoStatusRcvd {
154 p, marshalErr = ce.bytes()
155 }
156
157 writeErr := c.writeControl(context.Background(), opClose, p)
158 if CloseStatus(writeErr) != -1 {
159
160 writeErr = nil
161 }
162
163
164 c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
165
166 if marshalErr != nil {
167 return marshalErr
168 }
169 return writeErr
170 }
171
172 func (c *Conn) waitCloseHandshake() error {
173 defer c.close(nil)
174
175 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
176 defer cancel()
177
178 err := c.readMu.lock(ctx)
179 if err != nil {
180 return err
181 }
182 defer c.readMu.unlock()
183
184 if c.readCloseFrameErr != nil {
185 return c.readCloseFrameErr
186 }
187
188 for i := int64(0); i < c.msgReader.payloadLength; i++ {
189 _, err := c.br.ReadByte()
190 if err != nil {
191 return err
192 }
193 }
194
195 for {
196 h, err := c.readLoop(ctx)
197 if err != nil {
198 return err
199 }
200
201 for i := int64(0); i < h.payloadLength; i++ {
202 _, err := c.br.ReadByte()
203 if err != nil {
204 return err
205 }
206 }
207 }
208 }
209
210 func parseClosePayload(p []byte) (CloseError, error) {
211 if len(p) == 0 {
212 return CloseError{
213 Code: StatusNoStatusRcvd,
214 }, nil
215 }
216
217 if len(p) < 2 {
218 return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
219 }
220
221 ce := CloseError{
222 Code: StatusCode(binary.BigEndian.Uint16(p)),
223 Reason: string(p[2:]),
224 }
225
226 if !validWireCloseCode(ce.Code) {
227 return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
228 }
229
230 return ce, nil
231 }
232
233
234
235 func validWireCloseCode(code StatusCode) bool {
236 switch code {
237 case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
238 return false
239 }
240
241 if code >= StatusNormalClosure && code <= StatusBadGateway {
242 return true
243 }
244 if code >= 3000 && code <= 4999 {
245 return true
246 }
247
248 return false
249 }
250
251 func (ce CloseError) bytes() ([]byte, error) {
252 p, err := ce.bytesErr()
253 if err != nil {
254 err = fmt.Errorf("failed to marshal close frame: %w", err)
255 ce = CloseError{
256 Code: StatusInternalError,
257 }
258 p, _ = ce.bytesErr()
259 }
260 return p, err
261 }
262
263 const maxCloseReason = maxControlPayload - 2
264
265 func (ce CloseError) bytesErr() ([]byte, error) {
266 if len(ce.Reason) > maxCloseReason {
267 return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
268 }
269
270 if !validWireCloseCode(ce.Code) {
271 return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
272 }
273
274 buf := make([]byte, 2+len(ce.Reason))
275 binary.BigEndian.PutUint16(buf, uint16(ce.Code))
276 copy(buf[2:], ce.Reason)
277 return buf, nil
278 }
279
280 func (c *Conn) setCloseErr(err error) {
281 c.closeMu.Lock()
282 c.setCloseErrLocked(err)
283 c.closeMu.Unlock()
284 }
285
286 func (c *Conn) setCloseErrLocked(err error) {
287 if c.closeErr == nil && err != nil {
288 c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
289 }
290 }
291
292 func (c *Conn) isClosed() bool {
293 select {
294 case <-c.closed:
295 return true
296 default:
297 return false
298 }
299 }
300
View as plain text