1 package socket_test
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "math"
10 "net"
11 "net/netip"
12 "os"
13 "runtime"
14 "sync"
15 "testing"
16 "time"
17
18 "github.com/google/go-cmp/cmp"
19 "github.com/google/go-cmp/cmp/cmpopts"
20 "github.com/mdlayher/socket/internal/sockettest"
21 "golang.org/x/net/nettest"
22 "golang.org/x/sync/errgroup"
23 "golang.org/x/sys/unix"
24 )
25
26 func TestConn(t *testing.T) {
27 t.Parallel()
28
29 tests := []struct {
30 name string
31 pipe nettest.MakePipe
32 }{
33
34 {
35 name: "basic",
36 pipe: makePipe(
37 func() (net.Listener, error) {
38 return sockettest.Listen(0, nil)
39 },
40 func(addr net.Addr) (net.Conn, error) {
41 return sockettest.Dial(context.Background(), addr, nil)
42 },
43 ),
44 },
45
46 {
47 name: "context",
48 pipe: makePipe(
49 func() (net.Listener, error) {
50 l, err := sockettest.Listen(0, nil)
51 if err != nil {
52 return nil, err
53 }
54
55 return l.Context(context.Background()), nil
56 },
57 func(addr net.Addr) (net.Conn, error) {
58 ctx := context.Background()
59
60 c, err := sockettest.Dial(ctx, addr, nil)
61 if err != nil {
62 return nil, err
63 }
64
65 return c.Context(ctx), nil
66 },
67 ),
68 },
69 }
70
71 for _, tt := range tests {
72 tt := tt
73 t.Run(tt.name, func(t *testing.T) {
74 t.Parallel()
75
76 nettest.TestConn(t, tt.pipe)
77
78
79 t.Run("CloseReadWrite", func(t *testing.T) { timeoutWrapper(t, tt.pipe, testCloseReadWrite) })
80 })
81 }
82 }
83
84 func TestDialTCPNoListener(t *testing.T) {
85 t.Parallel()
86
87
88
89
90
91
92
93 _, err := sockettest.Dial(context.Background(), &net.TCPAddr{
94 IP: net.IPv6loopback,
95 Port: math.MaxUint16,
96 }, nil)
97
98 want := os.NewSyscallError("connect", unix.ECONNREFUSED)
99 if diff := cmp.Diff(want, err); diff != "" {
100 t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
101 }
102 }
103
104 func TestDialTCPContextCanceledBefore(t *testing.T) {
105 t.Parallel()
106
107
108 ctx, cancel := context.WithCancel(context.Background())
109 cancel()
110
111 _, err := sockettest.Dial(ctx, &net.TCPAddr{
112 IP: net.IPv6loopback,
113 Port: math.MaxUint16,
114 }, nil)
115
116 if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
117 t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
118 }
119 }
120
// ipTests holds one documentation-range address per IP family. The dial tests
// that range over it expect connects to these addresses to stay pending until
// their context ends, and skip when the host reports the network or host as
// unreachable (i.e. no outbound connectivity for that family).
var ipTests = []struct {
	name string
	ip   netip.Addr
}{
	{name: "IPv4", ip: netip.MustParseAddr("192.0.2.1")},
	{name: "IPv6", ip: netip.MustParseAddr("2001:db8::1")},
}
136
137 func TestDialTCPContextCanceledDuring(t *testing.T) {
138 t.Parallel()
139
140 for _, tt := range ipTests {
141 tt := tt
142 t.Run(tt.name, func(t *testing.T) {
143 t.Parallel()
144
145
146
147 ctx, cancel := context.WithCancel(context.Background())
148 defer cancel()
149
150 go func() {
151 time.Sleep(1 * time.Second)
152 cancel()
153 }()
154
155 _, err := sockettest.Dial(ctx, &net.TCPAddr{
156 IP: tt.ip.AsSlice(),
157 Port: math.MaxUint16,
158 }, nil)
159 if errors.Is(err, unix.ENETUNREACH) || errors.Is(err, unix.EHOSTUNREACH) {
160 t.Skipf("skipping, no outbound %s connectivity: %v", tt.name, err)
161 }
162
163 if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
164 t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
165 }
166 })
167 }
168 }
169
170 func TestDialTCPContextDeadlineExceeded(t *testing.T) {
171 t.Parallel()
172
173 for _, tt := range ipTests {
174 tt := tt
175 t.Run(tt.name, func(t *testing.T) {
176 t.Parallel()
177
178
179 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
180 defer cancel()
181
182 _, err := sockettest.Dial(ctx, &net.TCPAddr{
183 IP: tt.ip.AsSlice(),
184 Port: math.MaxUint16,
185 }, nil)
186 if errors.Is(err, unix.ENETUNREACH) || errors.Is(err, unix.EHOSTUNREACH) {
187 t.Skipf("skipping, no outbound %s connectivity: %v", tt.name, err)
188 }
189
190 if diff := cmp.Diff(context.DeadlineExceeded, err, cmpopts.EquateErrors()); diff != "" {
191 t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
192 }
193 })
194 }
195 }
196
197 func TestListenerAcceptTCPContextCanceledBefore(t *testing.T) {
198 t.Parallel()
199
200 l, err := sockettest.Listen(0, nil)
201 if err != nil {
202 t.Fatalf("failed to listen: %v", err)
203 }
204 defer l.Close()
205
206
207 ctx, cancel := context.WithCancel(context.Background())
208 cancel()
209
210 _, err = l.Context(ctx).Accept()
211 if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
212 t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
213 }
214 }
215
216 func TestListenerAcceptTCPContextCanceledDuring(t *testing.T) {
217 t.Parallel()
218
219 l, err := sockettest.Listen(0, nil)
220 if err != nil {
221 t.Fatalf("failed to listen: %v", err)
222 }
223 defer l.Close()
224
225
226
227 ctx, cancel := context.WithCancel(context.Background())
228 defer cancel()
229
230 go func() {
231 time.Sleep(1 * time.Second)
232 cancel()
233 }()
234
235 _, err = l.Context(ctx).Accept()
236 if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
237 t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
238 }
239 }
240
241 func TestListenerAcceptTCPContextDeadlineExceeded(t *testing.T) {
242 t.Parallel()
243
244 l, err := sockettest.Listen(0, nil)
245 if err != nil {
246 t.Fatalf("failed to listen: %v", err)
247 }
248 defer l.Close()
249
250
251 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
252 defer cancel()
253
254 _, err = l.Context(ctx).Accept()
255 if diff := cmp.Diff(context.DeadlineExceeded, err, cmpopts.EquateErrors()); diff != "" {
256 t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
257 }
258 }
259
260 func TestListenerConnTCPContextCanceled(t *testing.T) {
261 t.Parallel()
262
263 l, err := sockettest.Listen(0, nil)
264 if err != nil {
265 t.Fatalf("failed to open listener: %v", err)
266 }
267 defer l.Close()
268
269
270 var eg errgroup.Group
271 eg.Go(func() error {
272 c, err := l.Accept()
273 if err != nil {
274 return fmt.Errorf("failed to accept: %v", err)
275 }
276 defer c.Close()
277
278
279 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
280 defer cancel()
281
282 b := make([]byte, 1024)
283 _, _, err = c.(*sockettest.Conn).Conn.Recvfrom(ctx, b, 0)
284 return err
285 })
286
287 c, err := net.Dial(l.Addr().Network(), l.Addr().String())
288 if err != nil {
289 t.Fatalf("failed to dial listener: %v", err)
290 }
291 defer c.Close()
292
293
294 if diff := cmp.Diff(context.DeadlineExceeded, eg.Wait(), cmpopts.EquateErrors()); diff != "" {
295 t.Fatalf("unexpected recvfrom error (-want +got):\n%s", diff)
296 }
297 }
298
299 func TestListenerConnTCPContextDeadlineExceeded(t *testing.T) {
300 t.Parallel()
301
302 l, err := sockettest.Listen(0, nil)
303 if err != nil {
304 t.Fatalf("failed to open listener: %v", err)
305 }
306 defer l.Close()
307
308
309 var eg errgroup.Group
310 eg.Go(func() error {
311 c, err := l.Accept()
312 if err != nil {
313 return fmt.Errorf("failed to accept: %v", err)
314 }
315 defer c.Close()
316
317
318 ctx, cancel := context.WithCancel(context.Background())
319 cancel()
320
321 b := make([]byte, 1024)
322 _, _, err = c.(*sockettest.Conn).Conn.Recvfrom(ctx, b, 0)
323 return err
324 })
325
326 c, err := net.Dial(l.Addr().Network(), l.Addr().String())
327 if err != nil {
328 t.Fatalf("failed to dial listener: %v", err)
329 }
330 defer c.Close()
331
332
333 if diff := cmp.Diff(context.Canceled, eg.Wait(), cmpopts.EquateErrors()); diff != "" {
334 t.Fatalf("unexpected recvfrom error (-want +got):\n%s", diff)
335 }
336 }
337
338 func TestFileConn(t *testing.T) {
339 t.Parallel()
340
341
342
343
344 fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0)
345 if err != nil {
346 t.Fatalf("failed to open socket: %v", err)
347 }
348
349
350 sa := &unix.SockaddrInet6{Addr: [16]byte{15: 0x01}}
351 if err := unix.Bind(fd, sa); err != nil {
352 t.Fatalf("failed to bind: %v", err)
353 }
354
355 if err := unix.Listen(fd, unix.SOMAXCONN); err != nil {
356 t.Fatalf("failed to listen: %v", err)
357 }
358
359
360
361 f := os.NewFile(uintptr(fd), "tcpv6-listener")
362 defer f.Close()
363
364 l, err := sockettest.FileListener(f)
365 if err != nil {
366 t.Fatalf("failed to open file listener: %v", err)
367 }
368 defer l.Close()
369
370
371
372
373 var eg errgroup.Group
374 eg.Go(func() error {
375 c, err := l.Accept()
376 if err != nil {
377 return fmt.Errorf("failed to accept: %v", err)
378 }
379
380 _ = c.Close()
381 return nil
382 })
383
384 c, err := net.Dial(l.Addr().Network(), l.Addr().String())
385 if err != nil {
386 t.Fatalf("failed to dial listener: %v", err)
387 }
388 _ = c.Close()
389
390 if err := eg.Wait(); err != nil {
391 t.Fatalf("failed to wait for listener goroutine: %v", err)
392 }
393 }
394
395
396
397
398
399
400
401
402
403
404 func makePipe(
405 listen func() (net.Listener, error),
406 dial func(addr net.Addr) (net.Conn, error),
407 ) nettest.MakePipe {
408 return func() (c1, c2 net.Conn, stop func(), err error) {
409 ln, err := listen()
410 if err != nil {
411 return nil, nil, nil, err
412 }
413
414
415 var err1, err2 error
416 done := make(chan bool)
417 go func() {
418 c2, err2 = ln.Accept()
419 close(done)
420 }()
421 c1, err1 = dial(ln.Addr())
422 <-done
423
424 stop = func() {
425 if err1 == nil {
426 c1.Close()
427 }
428 if err2 == nil {
429 c2.Close()
430 }
431 ln.Close()
432 }
433
434 switch {
435 case err1 != nil:
436 stop()
437 return nil, nil, nil, err1
438 case err2 != nil:
439 stop()
440 return nil, nil, nil, err2
441 default:
442 return c1, c2, stop, nil
443 }
444 }
445 }
446
447
448
// connTester is a test body that runs against both ends of an already-built
// connection pair, as produced by a nettest.MakePipe (see timeoutWrapper).
type connTester func(t *testing.T, first, second net.Conn)
450
451 func timeoutWrapper(t *testing.T, mp nettest.MakePipe, f connTester) {
452 t.Helper()
453 c1, c2, stop, err := mp()
454 if err != nil {
455 t.Fatalf("unable to make pipe: %v", err)
456 }
457 var once sync.Once
458 defer once.Do(func() { stop() })
459 timer := time.AfterFunc(time.Minute, func() {
460 once.Do(func() {
461 t.Error("test timed out; terminating pipe")
462 stop()
463 })
464 })
465 defer timer.Stop()
466 f(t, c1, c2)
467 }
468
469
470
471 func testCloseReadWrite(t *testing.T, c1, c2 net.Conn) {
472
473 if runtime.GOOS != "linux" {
474 t.Skip("skipping, not supported on non-Linux platforms")
475 }
476
477 type closerConn interface {
478 net.Conn
479 CloseRead() error
480 CloseWrite() error
481 }
482
483 cc1, ok1 := c1.(closerConn)
484 cc2, ok2 := c2.(closerConn)
485 if !ok1 || !ok2 {
486
487 return
488 }
489
490 var wg sync.WaitGroup
491 wg.Add(2)
492 defer wg.Wait()
493
494 go func() {
495 defer wg.Done()
496
497
498
499 b := make([]byte, 64)
500 if err := chunkedCopy(cc1, bytes.NewReader(b)); err != nil {
501 t.Errorf("unexpected initial cc1.Write error: %v", err)
502 }
503 if err := cc1.CloseWrite(); err != nil {
504 t.Errorf("unexpected cc1.CloseWrite error: %v", err)
505 }
506 _, err := cc1.Write(b)
507 if nerr, ok := err.(net.Error); !ok || nerr.Timeout() {
508 t.Errorf("unexpected final cc1.Write error: %v", err)
509 }
510 }()
511
512 go func() {
513 defer wg.Done()
514
515
516
517 if err := chunkedCopy(io.Discard, cc2); err != nil {
518 t.Errorf("unexpected initial cc2.Read error: %v", err)
519 }
520 if err := cc2.CloseRead(); err != nil {
521 t.Errorf("unexpected cc2.CloseRead error: %v", err)
522 }
523 if _, err := cc2.Read(make([]byte, 64)); err != io.EOF {
524 t.Errorf("unexpected final cc2.Read error: %v", err)
525 }
526 }()
527 }
528
529
530
531
532
// chunkedCopy copies r into w through a 1024-byte buffer and returns any error
// io.CopyBuffer reports (reaching io.EOF is not an error).
//
// w and r are wrapped in anonymous structs that expose only Write and Read.
// This hides any io.ReaderFrom/io.WriterTo fast paths, so the data really
// moves in buffer-sized chunks through plain Read/Write calls.
func chunkedCopy(w io.Writer, r io.Reader) error {
	var (
		dst = struct{ io.Writer }{w}
		src = struct{ io.Reader }{r}
		buf = make([]byte, 1024)
	)
	if _, err := io.CopyBuffer(dst, src, buf); err != nil {
		return err
	}
	return nil
}
538
View as plain text