...

Source file src/github.com/mdlayher/socket/conn_test.go

Documentation: github.com/mdlayher/socket

     1  package socket_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math"
    10  	"net"
    11  	"net/netip"
    12  	"os"
    13  	"runtime"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/google/go-cmp/cmp"
    19  	"github.com/google/go-cmp/cmp/cmpopts"
    20  	"github.com/mdlayher/socket/internal/sockettest"
    21  	"golang.org/x/net/nettest"
    22  	"golang.org/x/sync/errgroup"
    23  	"golang.org/x/sys/unix"
    24  )
    25  
    26  func TestConn(t *testing.T) {
    27  	t.Parallel()
    28  
    29  	tests := []struct {
    30  		name string
    31  		pipe nettest.MakePipe
    32  	}{
    33  		// Standard library plumbing.
    34  		{
    35  			name: "basic",
    36  			pipe: makePipe(
    37  				func() (net.Listener, error) {
    38  					return sockettest.Listen(0, nil)
    39  				},
    40  				func(addr net.Addr) (net.Conn, error) {
    41  					return sockettest.Dial(context.Background(), addr, nil)
    42  				},
    43  			),
    44  		},
    45  		// Our own implementations which have context cancelation support.
    46  		{
    47  			name: "context",
    48  			pipe: makePipe(
    49  				func() (net.Listener, error) {
    50  					l, err := sockettest.Listen(0, nil)
    51  					if err != nil {
    52  						return nil, err
    53  					}
    54  
    55  					return l.Context(context.Background()), nil
    56  				},
    57  				func(addr net.Addr) (net.Conn, error) {
    58  					ctx := context.Background()
    59  
    60  					c, err := sockettest.Dial(ctx, addr, nil)
    61  					if err != nil {
    62  						return nil, err
    63  					}
    64  
    65  					return c.Context(ctx), nil
    66  				},
    67  			),
    68  		},
    69  	}
    70  
    71  	for _, tt := range tests {
    72  		tt := tt
    73  		t.Run(tt.name, func(t *testing.T) {
    74  			t.Parallel()
    75  
    76  			nettest.TestConn(t, tt.pipe)
    77  
    78  			// Our own extensions to TestConn.
    79  			t.Run("CloseReadWrite", func(t *testing.T) { timeoutWrapper(t, tt.pipe, testCloseReadWrite) })
    80  		})
    81  	}
    82  }
    83  
    84  func TestDialTCPNoListener(t *testing.T) {
    85  	t.Parallel()
    86  
    87  	// See https://github.com/mdlayher/vsock/issues/47 and
    88  	// https://github.com/lxc/lxd/pull/9894 for context on this test.
    89  	//
    90  	//
    91  	// Given a (hopefully) non-existent listener on localhost, expect
    92  	// ECONNREFUSED.
    93  	_, err := sockettest.Dial(context.Background(), &net.TCPAddr{
    94  		IP:   net.IPv6loopback,
    95  		Port: math.MaxUint16,
    96  	}, nil)
    97  
    98  	want := os.NewSyscallError("connect", unix.ECONNREFUSED)
    99  	if diff := cmp.Diff(want, err); diff != "" {
   100  		t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
   101  	}
   102  }
   103  
   104  func TestDialTCPContextCanceledBefore(t *testing.T) {
   105  	t.Parallel()
   106  
   107  	// Context is canceled before any dialing can take place.
   108  	ctx, cancel := context.WithCancel(context.Background())
   109  	cancel()
   110  
   111  	_, err := sockettest.Dial(ctx, &net.TCPAddr{
   112  		IP:   net.IPv6loopback,
   113  		Port: math.MaxUint16,
   114  	}, nil)
   115  
   116  	if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
   117  		t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
   118  	}
   119  }
   120  
   121  var ipTests = []struct {
   122  	name string
   123  	ip   netip.Addr
   124  }{
   125  	// It appears we can dial addresses in the documentation range and
   126  	// connect will hang, which is perfect for this test case.
   127  	{
   128  		name: "IPv4",
   129  		ip:   netip.MustParseAddr("192.0.2.1"),
   130  	},
   131  	{
   132  		name: "IPv6",
   133  		ip:   netip.MustParseAddr("2001:db8::1"),
   134  	},
   135  }
   136  
   137  func TestDialTCPContextCanceledDuring(t *testing.T) {
   138  	t.Parallel()
   139  
   140  	for _, tt := range ipTests {
   141  		tt := tt
   142  		t.Run(tt.name, func(t *testing.T) {
   143  			t.Parallel()
   144  
   145  			// Context is canceled during a blocking operation but without an
   146  			// explicit deadline passed on the context.
   147  			ctx, cancel := context.WithCancel(context.Background())
   148  			defer cancel()
   149  
   150  			go func() {
   151  				time.Sleep(1 * time.Second)
   152  				cancel()
   153  			}()
   154  
   155  			_, err := sockettest.Dial(ctx, &net.TCPAddr{
   156  				IP:   tt.ip.AsSlice(),
   157  				Port: math.MaxUint16,
   158  			}, nil)
   159  			if errors.Is(err, unix.ENETUNREACH) || errors.Is(err, unix.EHOSTUNREACH) {
   160  				t.Skipf("skipping, no outbound %s connectivity: %v", tt.name, err)
   161  			}
   162  
   163  			if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
   164  				t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
   165  			}
   166  		})
   167  	}
   168  }
   169  
   170  func TestDialTCPContextDeadlineExceeded(t *testing.T) {
   171  	t.Parallel()
   172  
   173  	for _, tt := range ipTests {
   174  		tt := tt
   175  		t.Run(tt.name, func(t *testing.T) {
   176  			t.Parallel()
   177  
   178  			// Dialing is canceled after the deadline passes.
   179  			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   180  			defer cancel()
   181  
   182  			_, err := sockettest.Dial(ctx, &net.TCPAddr{
   183  				IP:   tt.ip.AsSlice(),
   184  				Port: math.MaxUint16,
   185  			}, nil)
   186  			if errors.Is(err, unix.ENETUNREACH) || errors.Is(err, unix.EHOSTUNREACH) {
   187  				t.Skipf("skipping, no outbound %s connectivity: %v", tt.name, err)
   188  			}
   189  
   190  			if diff := cmp.Diff(context.DeadlineExceeded, err, cmpopts.EquateErrors()); diff != "" {
   191  				t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
   192  			}
   193  		})
   194  	}
   195  }
   196  
   197  func TestListenerAcceptTCPContextCanceledBefore(t *testing.T) {
   198  	t.Parallel()
   199  
   200  	l, err := sockettest.Listen(0, nil)
   201  	if err != nil {
   202  		t.Fatalf("failed to listen: %v", err)
   203  	}
   204  	defer l.Close()
   205  
   206  	// Context is canceled before accept can take place.
   207  	ctx, cancel := context.WithCancel(context.Background())
   208  	cancel()
   209  
   210  	_, err = l.Context(ctx).Accept()
   211  	if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
   212  		t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
   213  	}
   214  }
   215  
   216  func TestListenerAcceptTCPContextCanceledDuring(t *testing.T) {
   217  	t.Parallel()
   218  
   219  	l, err := sockettest.Listen(0, nil)
   220  	if err != nil {
   221  		t.Fatalf("failed to listen: %v", err)
   222  	}
   223  	defer l.Close()
   224  
   225  	// Context is canceled during a blocking operation but without an
   226  	// explicit deadline passed on the context.
   227  	ctx, cancel := context.WithCancel(context.Background())
   228  	defer cancel()
   229  
   230  	go func() {
   231  		time.Sleep(1 * time.Second)
   232  		cancel()
   233  	}()
   234  
   235  	_, err = l.Context(ctx).Accept()
   236  	if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
   237  		t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
   238  	}
   239  }
   240  
   241  func TestListenerAcceptTCPContextDeadlineExceeded(t *testing.T) {
   242  	t.Parallel()
   243  
   244  	l, err := sockettest.Listen(0, nil)
   245  	if err != nil {
   246  		t.Fatalf("failed to listen: %v", err)
   247  	}
   248  	defer l.Close()
   249  
   250  	// Accept is canceled after the deadline passes.
   251  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   252  	defer cancel()
   253  
   254  	_, err = l.Context(ctx).Accept()
   255  	if diff := cmp.Diff(context.DeadlineExceeded, err, cmpopts.EquateErrors()); diff != "" {
   256  		t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
   257  	}
   258  }
   259  
   260  func TestListenerConnTCPContextCanceled(t *testing.T) {
   261  	t.Parallel()
   262  
   263  	l, err := sockettest.Listen(0, nil)
   264  	if err != nil {
   265  		t.Fatalf("failed to open listener: %v", err)
   266  	}
   267  	defer l.Close()
   268  
   269  	// Accept a single connection.
   270  	var eg errgroup.Group
   271  	eg.Go(func() error {
   272  		c, err := l.Accept()
   273  		if err != nil {
   274  			return fmt.Errorf("failed to accept: %v", err)
   275  		}
   276  		defer c.Close()
   277  
   278  		// Context is canceled during recvfrom.
   279  		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   280  		defer cancel()
   281  
   282  		b := make([]byte, 1024)
   283  		_, _, err = c.(*sockettest.Conn).Conn.Recvfrom(ctx, b, 0)
   284  		return err
   285  	})
   286  
   287  	c, err := net.Dial(l.Addr().Network(), l.Addr().String())
   288  	if err != nil {
   289  		t.Fatalf("failed to dial listener: %v", err)
   290  	}
   291  	defer c.Close()
   292  
   293  	// Client never sends data, so we wait until ctx cancel and errgroup return.
   294  	if diff := cmp.Diff(context.DeadlineExceeded, eg.Wait(), cmpopts.EquateErrors()); diff != "" {
   295  		t.Fatalf("unexpected recvfrom error (-want +got):\n%s", diff)
   296  	}
   297  }
   298  
   299  func TestListenerConnTCPContextDeadlineExceeded(t *testing.T) {
   300  	t.Parallel()
   301  
   302  	l, err := sockettest.Listen(0, nil)
   303  	if err != nil {
   304  		t.Fatalf("failed to open listener: %v", err)
   305  	}
   306  	defer l.Close()
   307  
   308  	// Accept a single connection.
   309  	var eg errgroup.Group
   310  	eg.Go(func() error {
   311  		c, err := l.Accept()
   312  		if err != nil {
   313  			return fmt.Errorf("failed to accept: %v", err)
   314  		}
   315  		defer c.Close()
   316  
   317  		// Context is canceled before recvfrom can take place.
   318  		ctx, cancel := context.WithCancel(context.Background())
   319  		cancel()
   320  
   321  		b := make([]byte, 1024)
   322  		_, _, err = c.(*sockettest.Conn).Conn.Recvfrom(ctx, b, 0)
   323  		return err
   324  	})
   325  
   326  	c, err := net.Dial(l.Addr().Network(), l.Addr().String())
   327  	if err != nil {
   328  		t.Fatalf("failed to dial listener: %v", err)
   329  	}
   330  	defer c.Close()
   331  
   332  	// Client never sends data, so we wait until ctx cancel and errgroup return.
   333  	if diff := cmp.Diff(context.Canceled, eg.Wait(), cmpopts.EquateErrors()); diff != "" {
   334  		t.Fatalf("unexpected recvfrom error (-want +got):\n%s", diff)
   335  	}
   336  }
   337  
   338  func TestFileConn(t *testing.T) {
   339  	t.Parallel()
   340  
   341  	// Use raw system calls to set up the socket since we assume anything being
   342  	// passed into a FileConn is set up by another system, such as systemd's
   343  	// socket activation.
   344  	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0)
   345  	if err != nil {
   346  		t.Fatalf("failed to open socket: %v", err)
   347  	}
   348  
   349  	// Bind to loopback, any available port.
   350  	sa := &unix.SockaddrInet6{Addr: [16]byte{15: 0x01}}
   351  	if err := unix.Bind(fd, sa); err != nil {
   352  		t.Fatalf("failed to bind: %v", err)
   353  	}
   354  
   355  	if err := unix.Listen(fd, unix.SOMAXCONN); err != nil {
   356  		t.Fatalf("failed to listen: %v", err)
   357  	}
   358  
   359  	// The socket should be ready, create a blocking file which is ready to be
   360  	// passed into FileConn via the FileListener helper.
   361  	f := os.NewFile(uintptr(fd), "tcpv6-listener")
   362  	defer f.Close()
   363  
   364  	l, err := sockettest.FileListener(f)
   365  	if err != nil {
   366  		t.Fatalf("failed to open file listener: %v", err)
   367  	}
   368  	defer l.Close()
   369  
   370  	// To exercise the listener, attempt to accept and then immediately close a
   371  	// single TCPv6 connection. Dial to the listener from the main goroutine and
   372  	// wait for everything to finish.
   373  	var eg errgroup.Group
   374  	eg.Go(func() error {
   375  		c, err := l.Accept()
   376  		if err != nil {
   377  			return fmt.Errorf("failed to accept: %v", err)
   378  		}
   379  
   380  		_ = c.Close()
   381  		return nil
   382  	})
   383  
   384  	c, err := net.Dial(l.Addr().Network(), l.Addr().String())
   385  	if err != nil {
   386  		t.Fatalf("failed to dial listener: %v", err)
   387  	}
   388  	_ = c.Close()
   389  
   390  	if err := eg.Wait(); err != nil {
   391  		t.Fatalf("failed to wait for listener goroutine: %v", err)
   392  	}
   393  }
   394  
   395  // Use our TCP net.Listener and net.Conn implementations backed by *socket.Conn
   396  // and run compliance tests with nettest.TestConn.
   397  //
   398  // This nettest.MakePipe function is adapted from nettest's own tests:
   399  // https://github.com/golang/net/blob/master/nettest/conntest_test.go
   400  //
   401  // Copyright 2016 The Go Authors. All rights reserved. Use of this source
   402  // code is governed by a BSD-style license that can be found in the LICENSE
   403  // file.
   404  func makePipe(
   405  	listen func() (net.Listener, error),
   406  	dial func(addr net.Addr) (net.Conn, error),
   407  ) nettest.MakePipe {
   408  	return func() (c1, c2 net.Conn, stop func(), err error) {
   409  		ln, err := listen()
   410  		if err != nil {
   411  			return nil, nil, nil, err
   412  		}
   413  
   414  		// Start a connection between two endpoints.
   415  		var err1, err2 error
   416  		done := make(chan bool)
   417  		go func() {
   418  			c2, err2 = ln.Accept()
   419  			close(done)
   420  		}()
   421  		c1, err1 = dial(ln.Addr())
   422  		<-done
   423  
   424  		stop = func() {
   425  			if err1 == nil {
   426  				c1.Close()
   427  			}
   428  			if err2 == nil {
   429  				c2.Close()
   430  			}
   431  			ln.Close()
   432  		}
   433  
   434  		switch {
   435  		case err1 != nil:
   436  			stop()
   437  			return nil, nil, nil, err1
   438  		case err2 != nil:
   439  			stop()
   440  			return nil, nil, nil, err2
   441  		default:
   442  			return c1, c2, stop, nil
   443  		}
   444  	}
   445  }
   446  
   447  // Copied from x/net/nettest, pending acceptance of:
   448  // https://go-review.googlesource.com/c/net/+/372815
   449  type connTester func(t *testing.T, c1, c2 net.Conn)
   450  
   451  func timeoutWrapper(t *testing.T, mp nettest.MakePipe, f connTester) {
   452  	t.Helper()
   453  	c1, c2, stop, err := mp()
   454  	if err != nil {
   455  		t.Fatalf("unable to make pipe: %v", err)
   456  	}
   457  	var once sync.Once
   458  	defer once.Do(func() { stop() })
   459  	timer := time.AfterFunc(time.Minute, func() {
   460  		once.Do(func() {
   461  			t.Error("test timed out; terminating pipe")
   462  			stop()
   463  		})
   464  	})
   465  	defer timer.Stop()
   466  	f(t, c1, c2)
   467  }
   468  
   469  // testCloseReadWrite tests that net.Conns which also implement the optional
   470  // CloseRead and CloseWrite methods can be half-closed correctly.
   471  func testCloseReadWrite(t *testing.T, c1, c2 net.Conn) {
   472  	// TODO(mdlayher): investigate why Mac/Windows errors are so different.
   473  	if runtime.GOOS != "linux" {
   474  		t.Skip("skipping, not supported on non-Linux platforms")
   475  	}
   476  
   477  	type closerConn interface {
   478  		net.Conn
   479  		CloseRead() error
   480  		CloseWrite() error
   481  	}
   482  
   483  	cc1, ok1 := c1.(closerConn)
   484  	cc2, ok2 := c2.(closerConn)
   485  	if !ok1 || !ok2 {
   486  		// Both c1 and c2 must implement closerConn to proceed.
   487  		return
   488  	}
   489  
   490  	var wg sync.WaitGroup
   491  	wg.Add(2)
   492  	defer wg.Wait()
   493  
   494  	go func() {
   495  		defer wg.Done()
   496  
   497  		// Writing succeeds at first but should result in a permanent "broken
   498  		// pipe" error after closing the write side of the net.Conn.
   499  		b := make([]byte, 64)
   500  		if err := chunkedCopy(cc1, bytes.NewReader(b)); err != nil {
   501  			t.Errorf("unexpected initial cc1.Write error: %v", err)
   502  		}
   503  		if err := cc1.CloseWrite(); err != nil {
   504  			t.Errorf("unexpected cc1.CloseWrite error: %v", err)
   505  		}
   506  		_, err := cc1.Write(b)
   507  		if nerr, ok := err.(net.Error); !ok || nerr.Timeout() {
   508  			t.Errorf("unexpected final cc1.Write error: %v", err)
   509  		}
   510  	}()
   511  
   512  	go func() {
   513  		defer wg.Done()
   514  
   515  		// Reading succeeds at first but should result in an EOF error after
   516  		// closing the read side of the net.Conn.
   517  		if err := chunkedCopy(io.Discard, cc2); err != nil {
   518  			t.Errorf("unexpected initial cc2.Read error: %v", err)
   519  		}
   520  		if err := cc2.CloseRead(); err != nil {
   521  			t.Errorf("unexpected cc2.CloseRead error: %v", err)
   522  		}
   523  		if _, err := cc2.Read(make([]byte, 64)); err != io.EOF {
   524  			t.Errorf("unexpected final cc2.Read error: %v", err)
   525  		}
   526  	}()
   527  }
   528  
   529  // chunkedCopy copies from r to w in fixed-width chunks to avoid
   530  // causing a Write that exceeds the maximum packet size for packet-based
   531  // connections like "unixpacket".
   532  // We assume that the maximum packet size is at least 1024.
   533  func chunkedCopy(w io.Writer, r io.Reader) error {
   534  	b := make([]byte, 1024)
   535  	_, err := io.CopyBuffer(struct{ io.Writer }{w}, struct{ io.Reader }{r}, b)
   536  	return err
   537  }
   538  

View as plain text