1
2
3 package sockettest
4
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"time"

	"github.com/mdlayher/socket"
	"golang.org/x/sys/unix"
)
16
17
18 type Listener struct {
19 addr *net.TCPAddr
20 c *socket.Conn
21 ctx context.Context
22 }
23
24 func (l *Listener) Context(ctx context.Context) *Listener {
25 l.ctx = ctx
26 return l
27 }
28
29
30
31
32 func Listen(port int, cfg *socket.Config) (*Listener, error) {
33 c, err := socket.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0, "tcpv6-server", cfg)
34 if err != nil {
35 return nil, fmt.Errorf("failed to open socket: %v", err)
36 }
37
38
39
40
41 if err := c.Bind(&unix.SockaddrInet6{Port: port}); err != nil {
42 _ = c.Close()
43 return nil, fmt.Errorf("failed to bind: %v", err)
44 }
45
46 if err := c.Listen(unix.SOMAXCONN); err != nil {
47 _ = c.Close()
48 return nil, fmt.Errorf("failed to listen: %v", err)
49 }
50
51 sa, err := c.Getsockname()
52 if err != nil {
53 _ = c.Close()
54 return nil, fmt.Errorf("failed to getsockname: %v", err)
55 }
56
57 return &Listener{
58 addr: newTCPAddr(sa),
59 c: c,
60 }, nil
61 }
62
63
64
65 func FileListener(f *os.File) (*Listener, error) {
66 c, err := socket.FileConn(f, "tcpv6-server")
67 if err != nil {
68 return nil, fmt.Errorf("failed to open file conn: %v", err)
69 }
70
71 sa, err := c.Getsockname()
72 if err != nil {
73 _ = c.Close()
74 return nil, fmt.Errorf("failed to getsockname: %v", err)
75 }
76
77 return &Listener{
78 addr: newTCPAddr(sa),
79 c: c,
80 }, nil
81 }
82
83 func (l *Listener) Addr() net.Addr { return l.addr }
84 func (l *Listener) Close() error { return l.c.Close() }
85 func (l *Listener) Accept() (net.Conn, error) {
86 ctx := context.Background()
87 if l.ctx != nil {
88 ctx = l.ctx
89 }
90
91
92 conn, rsa, err := l.c.Accept(ctx, 0)
93 if err != nil {
94 return nil, err
95 }
96
97 lsa, err := conn.Getsockname()
98 if err != nil {
99
100 _ = conn.Close()
101 return nil, err
102 }
103
104 c := &Conn{
105 Conn: conn,
106 local: newTCPAddr(lsa),
107 remote: newTCPAddr(rsa),
108 }
109
110 if l.ctx != nil {
111 return c.Context(l.ctx), nil
112 }
113
114 return c, nil
115 }
116
117
118 type Conn struct {
119 Conn *socket.Conn
120 local, remote *net.TCPAddr
121 ctx context.Context
122 }
123
124 func (c *Conn) Context(ctx context.Context) *Conn {
125 c.ctx = ctx
126 return c
127 }
128
129
130
131 func Dial(ctx context.Context, addr net.Addr, cfg *socket.Config) (*Conn, error) {
132 ta, ok := addr.(*net.TCPAddr)
133 if !ok {
134 return nil, fmt.Errorf("expected *net.TCPAddr, but got: %T", addr)
135 }
136
137 var (
138 family int
139 name string
140 sa unix.Sockaddr
141 )
142
143 if ta.IP.To16() != nil && ta.IP.To4() == nil {
144
145 family = unix.AF_INET6
146 name = "tcpv6-client"
147
148 var sa6 unix.SockaddrInet6
149 copy(sa6.Addr[:], ta.IP)
150 sa6.Port = ta.Port
151
152 sa = &sa6
153 } else {
154
155 family = unix.AF_INET
156 name = "tcpv4-client"
157
158 var sa4 unix.SockaddrInet4
159 copy(sa4.Addr[:], ta.IP.To4())
160 sa4.Port = ta.Port
161
162 sa = &sa4
163 }
164
165 c, err := socket.Socket(family, unix.SOCK_STREAM, 0, name, cfg)
166 if err != nil {
167 return nil, fmt.Errorf("failed to open socket: %v", err)
168 }
169
170
171
172
173 rsa, err := c.Connect(ctx, sa)
174 if err != nil {
175 _ = c.Close()
176
177 return nil, err
178 }
179
180 lsa, err := c.Getsockname()
181 if err != nil {
182 _ = c.Close()
183 return nil, err
184 }
185
186 return &Conn{
187 Conn: c,
188 local: newTCPAddr(lsa),
189 remote: newTCPAddr(rsa),
190 }, nil
191 }
192
193 func (c *Conn) Close() error { return c.Conn.Close() }
194 func (c *Conn) CloseRead() error { return c.Conn.CloseRead() }
195 func (c *Conn) CloseWrite() error { return c.Conn.CloseWrite() }
196 func (c *Conn) LocalAddr() net.Addr { return c.local }
197 func (c *Conn) RemoteAddr() net.Addr { return c.remote }
198 func (c *Conn) SetDeadline(t time.Time) error { return c.Conn.SetDeadline(t) }
199 func (c *Conn) SetReadDeadline(t time.Time) error { return c.Conn.SetReadDeadline(t) }
200 func (c *Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) }
201
202 func (c *Conn) Read(b []byte) (int, error) {
203 var (
204 n int
205 err error
206 )
207
208 if c.ctx != nil {
209 n, err = c.Conn.ReadContext(c.ctx, b)
210 } else {
211 n, err = c.Conn.Read(b)
212 }
213
214 return n, opError("read", err)
215 }
216
217 func (c *Conn) Write(b []byte) (int, error) {
218 var (
219 n int
220 err error
221 )
222
223 if c.ctx != nil {
224 n, err = c.Conn.WriteContext(c.ctx, b)
225 } else {
226 n, err = c.Conn.Write(b)
227 }
228
229 return n, opError("write", err)
230 }
231
// opError normalizes an error from a socket read or write to the conventions
// net.Conn callers expect. op names the operation ("read" or "write").
//
//   - nil stays nil.
//   - End of stream is returned as the bare io.EOF sentinel, even if the
//     socket layer wrapped it, because io.Reader callers (io.Copy, io.ReadAll,
//     bufio) test for it with == and would otherwise treat a clean close as a
//     failure.
//   - Any other error is wrapped in a *net.OpError tagged with op.
func opError(op string, err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, io.EOF):
		return io.EOF
	default:
		return &net.OpError{Op: op, Err: err}
	}
}
243
244 func newTCPAddr(sa unix.Sockaddr) *net.TCPAddr {
245 switch sa := sa.(type) {
246 case *unix.SockaddrInet4:
247 return &net.TCPAddr{
248 IP: sa.Addr[:],
249 Port: sa.Port,
250 }
251 case *unix.SockaddrInet6:
252 return &net.TCPAddr{
253 IP: sa.Addr[:],
254 Port: sa.Port,
255 }
256 }
257
258 panic("unknown address family")
259 }
260
View as plain text