1
2
3
4
5
6
7 package quic
8
9 import (
10 "context"
11 "crypto/rand"
12 "errors"
13 "net"
14 "net/netip"
15 "sync"
16 "sync/atomic"
17 "time"
18 )
19
20
21
22
23
// An Endpoint handles QUIC traffic on a network address.
// It accepts inbound connections when created with a listen Config, and
// creates outbound connections with Dial.
type Endpoint struct {
	listenConfig *Config                      // nil if the endpoint does not accept inbound connections
	packetConn   packetConn                   // datagram transport the endpoint reads from and writes to
	testHooks    endpointTestHooks            // nil for endpoints created by Listen
	resetGen     statelessResetTokenGenerator // generates tokens for the stateless resets this endpoint sends
	retry        retryState                   // initialized only when listenConfig.RequireAddressValidation is set

	acceptQueue queue[*Conn] // established server connections waiting to be returned by Accept
	connsMap    connsMap     // routing tables, read on the listen goroutine; other goroutines change it via updateConnIDs

	connsMu sync.Mutex
	conns   map[*Conn]struct{} // all live connections; guarded by connsMu
	closing bool               // set once Close has been called; guarded by connsMu
	closec  chan struct{}      // closed when the listen goroutine exits
}
39
40 type endpointTestHooks interface {
41 timeNow() time.Time
42 newConn(c *Conn)
43 }
44
45
46 type packetConn interface {
47 Close() error
48 LocalAddr() netip.AddrPort
49 Read(f func(*datagram))
50 Write(datagram) error
51 }
52
53
54
55
56
57 func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
58 if listenConfig != nil && listenConfig.TLSConfig == nil {
59 return nil, errors.New("TLSConfig is not set")
60 }
61 a, err := net.ResolveUDPAddr(network, address)
62 if err != nil {
63 return nil, err
64 }
65 udpConn, err := net.ListenUDP(network, a)
66 if err != nil {
67 return nil, err
68 }
69 pc, err := newNetUDPConn(udpConn)
70 if err != nil {
71 return nil, err
72 }
73 return newEndpoint(pc, listenConfig, nil)
74 }
75
76 func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
77 e := &Endpoint{
78 listenConfig: config,
79 packetConn: pc,
80 testHooks: hooks,
81 conns: make(map[*Conn]struct{}),
82 acceptQueue: newQueue[*Conn](),
83 closec: make(chan struct{}),
84 }
85 var statelessResetKey [32]byte
86 if config != nil {
87 statelessResetKey = config.StatelessResetKey
88 }
89 e.resetGen.init(statelessResetKey)
90 e.connsMap.init()
91 if config != nil && config.RequireAddressValidation {
92 if err := e.retry.init(); err != nil {
93 return nil, err
94 }
95 }
96 go e.listen()
97 return e, nil
98 }
99
100
101 func (e *Endpoint) LocalAddr() netip.AddrPort {
102 return e.packetConn.LocalAddr()
103 }
104
105
106
107
108
109
110
111
112 func (e *Endpoint) Close(ctx context.Context) error {
113 e.acceptQueue.close(errors.New("endpoint closed"))
114
115
116
117 var conns []*Conn
118 e.connsMu.Lock()
119 if !e.closing {
120 e.closing = true
121 for c := range e.conns {
122 conns = append(conns, c)
123 }
124 if len(e.conns) == 0 {
125 e.packetConn.Close()
126 }
127 }
128 e.connsMu.Unlock()
129
130 for _, c := range conns {
131 c.Abort(localTransportError{code: errNo})
132 }
133 select {
134 case <-e.closec:
135 case <-ctx.Done():
136 for _, c := range conns {
137 c.exit()
138 }
139 return ctx.Err()
140 }
141 return nil
142 }
143
144
145 func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) {
146 return e.acceptQueue.get(ctx, nil)
147 }
148
149
150
151 func (e *Endpoint) Dial(ctx context.Context, network, address string, config *Config) (*Conn, error) {
152 u, err := net.ResolveUDPAddr(network, address)
153 if err != nil {
154 return nil, err
155 }
156 addr := u.AddrPort()
157 addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
158 c, err := e.newConn(time.Now(), config, clientSide, newServerConnIDs{}, address, addr)
159 if err != nil {
160 return nil, err
161 }
162 if err := c.waitReady(ctx); err != nil {
163 c.Abort(nil)
164 return nil, err
165 }
166 return c, nil
167 }
168
169 func (e *Endpoint) newConn(now time.Time, config *Config, side connSide, cids newServerConnIDs, peerHostname string, peerAddr netip.AddrPort) (*Conn, error) {
170 e.connsMu.Lock()
171 defer e.connsMu.Unlock()
172 if e.closing {
173 return nil, errors.New("endpoint closed")
174 }
175 c, err := newConn(now, side, cids, peerHostname, peerAddr, config, e)
176 if err != nil {
177 return nil, err
178 }
179 e.conns[c] = struct{}{}
180 return c, nil
181 }
182
183
184
185 func (e *Endpoint) serverConnEstablished(c *Conn) {
186 e.acceptQueue.put(c)
187 }
188
189
190
191 func (e *Endpoint) connDrained(c *Conn) {
192 var cids [][]byte
193 for i := range c.connIDState.local {
194 cids = append(cids, c.connIDState.local[i].cid)
195 }
196 var tokens []statelessResetToken
197 for i := range c.connIDState.remote {
198 tokens = append(tokens, c.connIDState.remote[i].resetToken)
199 }
200 e.connsMap.updateConnIDs(func(conns *connsMap) {
201 for _, cid := range cids {
202 conns.retireConnID(c, cid)
203 }
204 for _, token := range tokens {
205 conns.retireResetToken(c, token)
206 }
207 })
208 e.connsMu.Lock()
209 defer e.connsMu.Unlock()
210 delete(e.conns, c)
211 if e.closing && len(e.conns) == 0 {
212 e.packetConn.Close()
213 }
214 }
215
216 func (e *Endpoint) listen() {
217 defer close(e.closec)
218 e.packetConn.Read(func(m *datagram) {
219 if e.connsMap.updateNeeded.Load() {
220 e.connsMap.applyUpdates()
221 }
222 e.handleDatagram(m)
223 })
224 }
225
226 func (e *Endpoint) handleDatagram(m *datagram) {
227 dstConnID, ok := dstConnIDForDatagram(m.b)
228 if !ok {
229 m.recycle()
230 return
231 }
232 c := e.connsMap.byConnID[string(dstConnID)]
233 if c == nil {
234
235
236 e.handleUnknownDestinationDatagram(m)
237 return
238 }
239
240
241
242 c.sendMsg(m)
243 }
244
245 func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) {
246 defer func() {
247 if m != nil {
248 m.recycle()
249 }
250 }()
251 const minimumValidPacketSize = 21
252 if len(m.b) < minimumValidPacketSize {
253 return
254 }
255 var now time.Time
256 if e.testHooks != nil {
257 now = e.testHooks.timeNow()
258 } else {
259 now = time.Now()
260 }
261
262 var token statelessResetToken
263 copy(token[:], m.b[len(m.b)-len(token):])
264 if c := e.connsMap.byResetToken[token]; c != nil {
265 c.sendMsg(func(now time.Time, c *Conn) {
266 c.handleStatelessReset(now, token)
267 })
268 return
269 }
270
271
272 if !isLongHeader(m.b[0]) {
273 e.maybeSendStatelessReset(m.b, m.peerAddr)
274 return
275 }
276 p, ok := parseGenericLongHeaderPacket(m.b)
277 if !ok || len(m.b) < paddedInitialDatagramSize {
278 return
279 }
280 switch p.version {
281 case quicVersion1:
282 case 0:
283
284 return
285 default:
286
287 e.sendVersionNegotiation(p, m.peerAddr)
288 return
289 }
290 if getPacketType(m.b) != packetTypeInitial {
291
292
293
294
295
296 return
297 }
298 if e.listenConfig == nil {
299
300 return
301 }
302 cids := newServerConnIDs{
303 srcConnID: p.srcConnID,
304 dstConnID: p.dstConnID,
305 }
306 if e.listenConfig.RequireAddressValidation {
307 var ok bool
308 cids.retrySrcConnID = p.dstConnID
309 cids.originalDstConnID, ok = e.validateInitialAddress(now, p, m.peerAddr)
310 if !ok {
311 return
312 }
313 } else {
314 cids.originalDstConnID = p.dstConnID
315 }
316 var err error
317 c, err := e.newConn(now, e.listenConfig, serverSide, cids, "", m.peerAddr)
318 if err != nil {
319
320
321
322
323 return
324 }
325 c.sendMsg(m)
326 m = nil
327 }
328
// maybeSendStatelessReset replies to b, a short-header datagram that matched no
// known connection, with a stateless reset. Nothing is sent if the endpoint
// cannot generate resets or b is too short to be a valid packet.
//
// The reset is built in b's backing storage, so b's contents are overwritten;
// callers must not rely on them afterwards. Errors from rand.Read and from
// sending are ignored (best effort).
func (e *Endpoint) maybeSendStatelessReset(b []byte, peerAddr netip.AddrPort) {
	if !e.resetGen.canReset {
		// Stateless resets are disabled for this endpoint.
		return
	}
	// The smallest valid packet a peer can send us is:
	//   1 byte of header
	//   connIDLen bytes of destination connection ID
	//   1 byte of packet number
	//   1 byte of payload
	//   16 bytes of AEAD expansion
	if len(b) < 1+connIDLen+1+1+16 {
		return
	}
	// cid aliases b. The token must be derived before b is overwritten below.
	cid := b[1:][:connIDLen]
	token := e.resetGen.tokenForConnID(cid)
	// Make the reset as short as possible while still being hard to tell
	// apart from a 1-RTT packet. 42 = 1 header + 20 (max connection ID)
	// + 4 (max packet number) + 1 payload + 16 AEAD expansion.
	//
	// It must also be shorter than the datagram it responds to, so that a
	// reset loop between two endpoints terminates (RFC 9000 §10.3).
	size := min(len(b)-1, 42)
	b = b[:size]
	// Fill everything before the token with unpredictable bytes.
	rand.Read(b[:len(b)-statelessResetTokenLen])
	b[0] &^= headerFormLong // clear the long-header bit: look like a short header
	b[0] |= fixedBit        // set the fixed bit, as a real 1-RTT packet would
	copy(b[len(b)-statelessResetTokenLen:], token[:])
	e.sendDatagram(datagram{
		b:        b,
		peerAddr: peerAddr,
	})
}
375
376 func (e *Endpoint) sendVersionNegotiation(p genericLongPacket, peerAddr netip.AddrPort) {
377 m := newDatagram()
378 m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
379 m.peerAddr = peerAddr
380 e.sendDatagram(*m)
381 m.recycle()
382 }
383
384 func (e *Endpoint) sendConnectionClose(in genericLongPacket, peerAddr netip.AddrPort, code transportError) {
385 keys := initialKeys(in.dstConnID, serverSide)
386 var w packetWriter
387 p := longPacket{
388 ptype: packetTypeInitial,
389 version: quicVersion1,
390 num: 0,
391 dstConnID: in.srcConnID,
392 srcConnID: in.dstConnID,
393 }
394 const pnumMaxAcked = 0
395 w.reset(paddedInitialDatagramSize)
396 w.startProtectedLongHeaderPacket(pnumMaxAcked, p)
397 w.appendConnectionCloseTransportFrame(code, 0, "")
398 w.finishProtectedLongHeaderPacket(pnumMaxAcked, keys.w, p)
399 buf := w.datagram()
400 if len(buf) == 0 {
401 return
402 }
403 e.sendDatagram(datagram{
404 b: buf,
405 peerAddr: peerAddr,
406 })
407 }
408
409 func (e *Endpoint) sendDatagram(dgram datagram) error {
410 return e.packetConn.Write(dgram)
411 }
412
413
// A connsMap routes incoming datagrams to connections, by connection ID and
// by stateless reset token.
//
// The maps are not guarded by a lock. In this file they are read on the
// endpoint's listen goroutine. Other goroutines change them only through
// updateConnIDs, whose queued updates applyUpdates runs on the listen
// goroutine before each datagram is handled.
type connsMap struct {
	byConnID     map[string]*Conn              // keyed by connection ID bytes
	byResetToken map[statelessResetToken]*Conn // keyed by stateless reset token

	updateMu     sync.Mutex        // guards updates
	updateNeeded atomic.Bool       // true while updates is non-empty; lets the reader skip locking
	updates      []func(*connsMap) // pending changes queued by updateConnIDs
}
422
423 func (m *connsMap) init() {
424 m.byConnID = map[string]*Conn{}
425 m.byResetToken = map[statelessResetToken]*Conn{}
426 }
427
428 func (m *connsMap) addConnID(c *Conn, cid []byte) {
429 m.byConnID[string(cid)] = c
430 }
431
432 func (m *connsMap) retireConnID(c *Conn, cid []byte) {
433 delete(m.byConnID, string(cid))
434 }
435
436 func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
437 m.byResetToken[token] = c
438 }
439
440 func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
441 delete(m.byResetToken, token)
442 }
443
444 func (m *connsMap) updateConnIDs(f func(*connsMap)) {
445 m.updateMu.Lock()
446 defer m.updateMu.Unlock()
447 m.updates = append(m.updates, f)
448 m.updateNeeded.Store(true)
449 }
450
451
452 func (m *connsMap) applyUpdates() {
453 m.updateMu.Lock()
454 defer m.updateMu.Unlock()
455 for _, f := range m.updates {
456 f(m)
457 }
458 clear(m.updates)
459 m.updates = m.updates[:0]
460 m.updateNeeded.Store(false)
461 }
462
View as plain text