...
1 package websocket
2
3 import (
4 "context"
5 "fmt"
6 "io"
7 "math"
8 "net"
9 "sync/atomic"
10 "time"
11 )
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
49 c.SetReadLimit(-1)
50
51 nc := &netConn{
52 c: c,
53 msgType: msgType,
54 readMu: newMu(c),
55 writeMu: newMu(c),
56 }
57
58 nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
59 nc.readCtx, nc.readCancel = context.WithCancel(ctx)
60
61 nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
62 if !nc.writeMu.tryLock() {
63
64
65 nc.writeCancel()
66 return
67 }
68 defer nc.writeMu.unlock()
69
70
71 atomic.StoreInt64(&nc.writeExpired, 1)
72 })
73 if !nc.writeTimer.Stop() {
74 <-nc.writeTimer.C
75 }
76
77 nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
78 if !nc.readMu.tryLock() {
79
80
81 nc.readCancel()
82 return
83 }
84 defer nc.readMu.unlock()
85
86
87 atomic.StoreInt64(&nc.readExpired, 1)
88 })
89 if !nc.readTimer.Stop() {
90 <-nc.readTimer.C
91 }
92
93 return nc
94 }
95
96 type netConn struct {
97 c *Conn
98 msgType MessageType
99
100 writeTimer *time.Timer
101 writeMu *mu
102 writeExpired int64
103 writeCtx context.Context
104 writeCancel context.CancelFunc
105
106 readTimer *time.Timer
107 readMu *mu
108 readExpired int64
109 readCtx context.Context
110 readCancel context.CancelFunc
111 readEOFed bool
112 reader io.Reader
113 }
114
115 var _ net.Conn = &netConn{}
116
117 func (nc *netConn) Close() error {
118 nc.writeTimer.Stop()
119 nc.writeCancel()
120 nc.readTimer.Stop()
121 nc.readCancel()
122 return nc.c.Close(StatusNormalClosure, "")
123 }
124
125 func (nc *netConn) Write(p []byte) (int, error) {
126 nc.writeMu.forceLock()
127 defer nc.writeMu.unlock()
128
129 if atomic.LoadInt64(&nc.writeExpired) == 1 {
130 return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
131 }
132
133 err := nc.c.Write(nc.writeCtx, nc.msgType, p)
134 if err != nil {
135 return 0, err
136 }
137 return len(p), nil
138 }
139
140 func (nc *netConn) Read(p []byte) (int, error) {
141 nc.readMu.forceLock()
142 defer nc.readMu.unlock()
143
144 for {
145 n, err := nc.read(p)
146 if err != nil {
147 return n, err
148 }
149 if n == 0 {
150 continue
151 }
152 return n, nil
153 }
154 }
155
156 func (nc *netConn) read(p []byte) (int, error) {
157 if atomic.LoadInt64(&nc.readExpired) == 1 {
158 return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
159 }
160
161 if nc.readEOFed {
162 return 0, io.EOF
163 }
164
165 if nc.reader == nil {
166 typ, r, err := nc.c.Reader(nc.readCtx)
167 if err != nil {
168 switch CloseStatus(err) {
169 case StatusNormalClosure, StatusGoingAway:
170 nc.readEOFed = true
171 return 0, io.EOF
172 }
173 return 0, err
174 }
175 if typ != nc.msgType {
176 err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
177 nc.c.Close(StatusUnsupportedData, err.Error())
178 return 0, err
179 }
180 nc.reader = r
181 }
182
183 n, err := nc.reader.Read(p)
184 if err == io.EOF {
185 nc.reader = nil
186 err = nil
187 }
188 return n, err
189 }
190
// websocketAddr provides the net.Addr method set (Network and String) with
// fixed values. It carries no actual address information.
type websocketAddr struct{}

// Network always reports "websocket".
func (websocketAddr) Network() string { return "websocket" }

// String always reports "websocket/unknown-addr".
func (websocketAddr) String() string { return "websocket/unknown-addr" }
201
202 func (nc *netConn) SetDeadline(t time.Time) error {
203 nc.SetWriteDeadline(t)
204 nc.SetReadDeadline(t)
205 return nil
206 }
207
208 func (nc *netConn) SetWriteDeadline(t time.Time) error {
209 atomic.StoreInt64(&nc.writeExpired, 0)
210 if t.IsZero() {
211 nc.writeTimer.Stop()
212 } else {
213 dur := time.Until(t)
214 if dur <= 0 {
215 dur = 1
216 }
217 nc.writeTimer.Reset(dur)
218 }
219 return nil
220 }
221
222 func (nc *netConn) SetReadDeadline(t time.Time) error {
223 atomic.StoreInt64(&nc.readExpired, 0)
224 if t.IsZero() {
225 nc.readTimer.Stop()
226 } else {
227 dur := time.Until(t)
228 if dur <= 0 {
229 dur = 1
230 }
231 nc.readTimer.Reset(dur)
232 }
233 return nil
234 }
235
View as plain text