1
2
3
4
5
6
7 package quic
8
9 import (
10 "encoding/binary"
11 "net"
12 "net/netip"
13 "sync"
14 "unsafe"
15
16 "golang.org/x/sys/unix"
17 )
18
19
20
// netUDPConn is a UDP socket that exchanges datagrams together with
// per-packet metadata (ECN bits, local destination address) carried in
// socket control messages.
type netUDPConn struct {
	// c is the underlying socket.
	c *net.UDPConn

	// localAddr is the address the socket is bound to. For a wildcard
	// (unspecified) bind, newNetUDPConn stores an invalid IP with only the
	// port set, and the per-datagram destination address is recovered from
	// PKTINFO control messages instead.
	localAddr netip.AddrPort
}
25
26 func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
27 a, _ := uc.LocalAddr().(*net.UDPAddr)
28 localAddr := a.AddrPort()
29 if localAddr.Addr().IsUnspecified() {
30
31
32
33
34 localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port())
35 }
36
37 sc, err := uc.SyscallConn()
38 if err != nil {
39 return nil, err
40 }
41 sc.Control(func(fd uintptr) {
42
43
44
45
46
47 unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
48 unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
49 if !localAddr.IsValid() {
50 unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
51 unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
52 }
53 })
54
55 return &netUDPConn{
56 c: uc,
57 localAddr: localAddr,
58 }, nil
59 }
60
61 func (c *netUDPConn) Close() error { return c.c.Close() }
62
63 func (c *netUDPConn) LocalAddr() netip.AddrPort {
64 a, _ := c.c.LocalAddr().(*net.UDPAddr)
65 return a.AddrPort()
66 }
67
68 func (c *netUDPConn) Read(f func(*datagram)) {
69
70
71 const (
72 inPktinfoSize = 12
73 in6PktinfoSize = 20
74 ipTOSSize = 4
75 ipv6TclassSize = 4
76 )
77 control := make([]byte, 0+
78 unix.CmsgSpace(inPktinfoSize)+
79 unix.CmsgSpace(in6PktinfoSize)+
80 unix.CmsgSpace(ipTOSSize)+
81 unix.CmsgSpace(ipv6TclassSize))
82
83 for {
84 d := newDatagram()
85 n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control)
86 if err != nil {
87 return
88 }
89 if n == 0 {
90 continue
91 }
92 d.localAddr = c.localAddr
93 d.peerAddr = unmapAddrPort(peerAddr)
94 d.b = d.b[:n]
95 parseControl(d, control[:controlLen])
96 f(d)
97 }
98 }
99
// cmsgPool recycles control-message buffers used by Write. It hands out
// *[]byte (initially pointing at a nil slice) rather than []byte, so that
// putting a buffer back does not allocate to box the slice header.
var cmsgPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
105
106 func (c *netUDPConn) Write(dgram datagram) error {
107 controlp := cmsgPool.Get().(*[]byte)
108 control := *controlp
109 defer func() {
110 *controlp = control[:0]
111 cmsgPool.Put(controlp)
112 }()
113
114 localIP := dgram.localAddr.Addr()
115 if localIP.IsValid() {
116 if localIP.Is4() {
117 control = appendCmsgIPSourceAddrV4(control, localIP)
118 } else {
119 control = appendCmsgIPSourceAddrV6(control, localIP)
120 }
121 }
122 if dgram.ecn != ecnNotECT {
123 if dgram.peerAddr.Addr().Is4() {
124 control = appendCmsgECNv4(control, dgram.ecn)
125 } else {
126 control = appendCmsgECNv6(control, dgram.ecn)
127 }
128 }
129
130 _, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr)
131 return err
132 }
133
134 func parseControl(d *datagram, control []byte) {
135 for len(control) > 0 {
136 hdr, data, remainder, err := unix.ParseOneSocketControlMessage(control)
137 if err != nil {
138 return
139 }
140 control = remainder
141 switch hdr.Level {
142 case unix.IPPROTO_IP:
143 switch hdr.Type {
144 case unix.IP_TOS, unix.IP_RECVTOS:
145
146
147 if ecn, ok := parseIPTOS(data); ok {
148 d.ecn = ecn
149 }
150 case unix.IP_PKTINFO:
151 if a, ok := parseInPktinfo(data); ok {
152 d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
153 }
154 }
155 case unix.IPPROTO_IPV6:
156 switch hdr.Type {
157 case unix.IPV6_TCLASS:
158
159
160 if ecn, ok := parseIPv6TCLASS(data); ok {
161 d.ecn = ecn
162 }
163 case unix.IPV6_PKTINFO:
164 if a, ok := parseIn6Pktinfo(data); ok {
165 d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
166 }
167 }
168 }
169 }
170 }
171
172
173
174 func parseIPv6TCLASS(b []byte) (ecnBits, bool) {
175 if len(b) != 4 {
176 return 0, false
177 }
178 return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true
179 }
180
181 func appendCmsgECNv6(b []byte, ecn ecnBits) []byte {
182 b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4)
183 binary.NativeEndian.PutUint32(data, uint32(ecn))
184 return b
185 }
186
187
188
189
190
191
192
193
// parseInPktinfo returns the destination address carried in an IP_PKTINFO
// control-message payload.
//
// The payload is a struct in_pktinfo: ipi_ifindex (4 bytes), ipi_spec_dst
// (4 bytes), then ipi_addr (4 bytes), the header destination address. It
// reports false, with a zero Addr, if b is not exactly 12 bytes long.
func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) {
	const (
		inPktinfoLen = 12 // sizeof(struct in_pktinfo)
		ipiAddrOff   = 8  // offset of ipi_addr
	)
	if len(b) != inPktinfoLen {
		return dst, false
	}
	var ip4 [4]byte
	copy(ip4[:], b[ipiAddrOff:])
	return netip.AddrFrom4(ip4), true
}
200
201
202
203 func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte {
204
205
206
207
208
209 b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, 12)
210 ip := src.As4()
211 copy(data[4:], ip[:])
212 return b
213 }
214
215
216
217
218
219
220
// parseIn6Pktinfo returns the destination address carried in an IPV6_PKTINFO
// control-message payload.
//
// The payload is a struct in6_pktinfo: ipi6_addr (16 bytes) followed by
// ipi6_ifindex (4 bytes). An IPv4-mapped address is returned in plain IPv4
// form. It reports false, with a zero Addr, if b is not exactly 20 bytes long.
func parseIn6Pktinfo(b []byte) (netip.Addr, bool) {
	const in6PktinfoLen = 20 // sizeof(struct in6_pktinfo)
	if len(b) != in6PktinfoLen {
		return netip.Addr{}, false
	}
	addr, _ := netip.AddrFromSlice(b[:16]) // 16 bytes always yields an IPv6 Addr
	return addr.Unmap(), true
}
227
228
229
230 func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte {
231 b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20)
232 ip := src.As16()
233 copy(data[0:], ip[:])
234 return b
235 }
236
237
238
239 func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) {
240 off := len(b)
241 b = append(b, make([]byte, unix.CmsgSpace(size))...)
242 h := (*unix.Cmsghdr)(unsafe.Pointer(&b[off]))
243 h.Level = level
244 h.Type = typ
245 h.SetLen(unix.CmsgLen(size))
246 return b, b[off+unix.CmsgSpace(0):][:size]
247 }
248
View as plain text