...

Source file src/github.com/mdlayher/socket/internal/sockettest/sockettest.go

Documentation: github.com/mdlayher/socket/internal/sockettest

     1  // Package sockettest implements net.Listener and net.Conn types based on
     2  // *socket.Conn for use in the package's tests.
     3  package sockettest
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"time"
    12  
    13  	"github.com/mdlayher/socket"
    14  	"golang.org/x/sys/unix"
    15  )
    16  
    17  // A Listener is a net.Listener which can be extended with context support.
    18  type Listener struct {
    19  	addr *net.TCPAddr
    20  	c    *socket.Conn
    21  	ctx  context.Context
    22  }
    23  
    24  func (l *Listener) Context(ctx context.Context) *Listener {
    25  	l.ctx = ctx
    26  	return l
    27  }
    28  
    29  // Listen creates an IPv6 TCP net.Listener backed by a *socket.Conn on the
    30  // specified port with optional configuration. Context ctx will be passed
    31  // to accept and accepted connections.
    32  func Listen(port int, cfg *socket.Config) (*Listener, error) {
    33  	c, err := socket.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0, "tcpv6-server", cfg)
    34  	if err != nil {
    35  		return nil, fmt.Errorf("failed to open socket: %v", err)
    36  	}
    37  
    38  	// Be sure to close the Conn if any of the system calls fail before we
    39  	// return the Conn to the caller.
    40  
    41  	if err := c.Bind(&unix.SockaddrInet6{Port: port}); err != nil {
    42  		_ = c.Close()
    43  		return nil, fmt.Errorf("failed to bind: %v", err)
    44  	}
    45  
    46  	if err := c.Listen(unix.SOMAXCONN); err != nil {
    47  		_ = c.Close()
    48  		return nil, fmt.Errorf("failed to listen: %v", err)
    49  	}
    50  
    51  	sa, err := c.Getsockname()
    52  	if err != nil {
    53  		_ = c.Close()
    54  		return nil, fmt.Errorf("failed to getsockname: %v", err)
    55  	}
    56  
    57  	return &Listener{
    58  		addr: newTCPAddr(sa),
    59  		c:    c,
    60  	}, nil
    61  }
    62  
    63  // FileListener creates an IPv6 TCP net.Listener backed by a *socket.Conn from
    64  // the input file.
    65  func FileListener(f *os.File) (*Listener, error) {
    66  	c, err := socket.FileConn(f, "tcpv6-server")
    67  	if err != nil {
    68  		return nil, fmt.Errorf("failed to open file conn: %v", err)
    69  	}
    70  
    71  	sa, err := c.Getsockname()
    72  	if err != nil {
    73  		_ = c.Close()
    74  		return nil, fmt.Errorf("failed to getsockname: %v", err)
    75  	}
    76  
    77  	return &Listener{
    78  		addr: newTCPAddr(sa),
    79  		c:    c,
    80  	}, nil
    81  }
    82  
    83  func (l *Listener) Addr() net.Addr { return l.addr }
    84  func (l *Listener) Close() error   { return l.c.Close() }
    85  func (l *Listener) Accept() (net.Conn, error) {
    86  	ctx := context.Background()
    87  	if l.ctx != nil {
    88  		ctx = l.ctx
    89  	}
    90  
    91  	// SOCK_CLOEXEC and SOCK_NONBLOCK set automatically by Accept when possible.
    92  	conn, rsa, err := l.c.Accept(ctx, 0)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	lsa, err := conn.Getsockname()
    98  	if err != nil {
    99  		// Don't leak the Conn if the system call fails.
   100  		_ = conn.Close()
   101  		return nil, err
   102  	}
   103  
   104  	c := &Conn{
   105  		Conn:   conn,
   106  		local:  newTCPAddr(lsa),
   107  		remote: newTCPAddr(rsa),
   108  	}
   109  
   110  	if l.ctx != nil {
   111  		return c.Context(l.ctx), nil
   112  	}
   113  
   114  	return c, nil
   115  }
   116  
   117  // A Conn is a net.Conn which can be extended with context support.
   118  type Conn struct {
   119  	Conn          *socket.Conn
   120  	local, remote *net.TCPAddr
   121  	ctx           context.Context
   122  }
   123  
   124  func (c *Conn) Context(ctx context.Context) *Conn {
   125  	c.ctx = ctx
   126  	return c
   127  }
   128  
   129  // Dial creates an IPv4 or IPv6 TCP net.Conn backed by a *socket.Conn with
   130  // optional configuration.
   131  func Dial(ctx context.Context, addr net.Addr, cfg *socket.Config) (*Conn, error) {
   132  	ta, ok := addr.(*net.TCPAddr)
   133  	if !ok {
   134  		return nil, fmt.Errorf("expected *net.TCPAddr, but got: %T", addr)
   135  	}
   136  
   137  	var (
   138  		family int
   139  		name   string
   140  		sa     unix.Sockaddr
   141  	)
   142  
   143  	if ta.IP.To16() != nil && ta.IP.To4() == nil {
   144  		// IPv6.
   145  		family = unix.AF_INET6
   146  		name = "tcpv6-client"
   147  
   148  		var sa6 unix.SockaddrInet6
   149  		copy(sa6.Addr[:], ta.IP)
   150  		sa6.Port = ta.Port
   151  
   152  		sa = &sa6
   153  	} else {
   154  		// IPv4.
   155  		family = unix.AF_INET
   156  		name = "tcpv4-client"
   157  
   158  		var sa4 unix.SockaddrInet4
   159  		copy(sa4.Addr[:], ta.IP.To4())
   160  		sa4.Port = ta.Port
   161  
   162  		sa = &sa4
   163  	}
   164  
   165  	c, err := socket.Socket(family, unix.SOCK_STREAM, 0, name, cfg)
   166  	if err != nil {
   167  		return nil, fmt.Errorf("failed to open socket: %v", err)
   168  	}
   169  
   170  	// Be sure to close the Conn if any of the system calls fail before we
   171  	// return the Conn to the caller.
   172  
   173  	rsa, err := c.Connect(ctx, sa)
   174  	if err != nil {
   175  		_ = c.Close()
   176  		// Don't wrap, we want the raw error for tests.
   177  		return nil, err
   178  	}
   179  
   180  	lsa, err := c.Getsockname()
   181  	if err != nil {
   182  		_ = c.Close()
   183  		return nil, err
   184  	}
   185  
   186  	return &Conn{
   187  		Conn:   c,
   188  		local:  newTCPAddr(lsa),
   189  		remote: newTCPAddr(rsa),
   190  	}, nil
   191  }
   192  
   193  func (c *Conn) Close() error                       { return c.Conn.Close() }
   194  func (c *Conn) CloseRead() error                   { return c.Conn.CloseRead() }
   195  func (c *Conn) CloseWrite() error                  { return c.Conn.CloseWrite() }
   196  func (c *Conn) LocalAddr() net.Addr                { return c.local }
   197  func (c *Conn) RemoteAddr() net.Addr               { return c.remote }
   198  func (c *Conn) SetDeadline(t time.Time) error      { return c.Conn.SetDeadline(t) }
   199  func (c *Conn) SetReadDeadline(t time.Time) error  { return c.Conn.SetReadDeadline(t) }
   200  func (c *Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) }
   201  
   202  func (c *Conn) Read(b []byte) (int, error) {
   203  	var (
   204  		n   int
   205  		err error
   206  	)
   207  
   208  	if c.ctx != nil {
   209  		n, err = c.Conn.ReadContext(c.ctx, b)
   210  	} else {
   211  		n, err = c.Conn.Read(b)
   212  	}
   213  
   214  	return n, opError("read", err)
   215  }
   216  
   217  func (c *Conn) Write(b []byte) (int, error) {
   218  	var (
   219  		n   int
   220  		err error
   221  	)
   222  
   223  	if c.ctx != nil {
   224  		n, err = c.Conn.WriteContext(c.ctx, b)
   225  	} else {
   226  		n, err = c.Conn.Write(b)
   227  	}
   228  
   229  	return n, opError("write", err)
   230  }
   231  
   232  func opError(op string, err error) error {
   233  	// This is still a bit simplistic but sufficient for nettest.TestConn.
   234  	switch err {
   235  	case nil:
   236  		return nil
   237  	case io.EOF:
   238  		return io.EOF
   239  	default:
   240  		return &net.OpError{Op: op, Err: err}
   241  	}
   242  }
   243  
   244  func newTCPAddr(sa unix.Sockaddr) *net.TCPAddr {
   245  	switch sa := sa.(type) {
   246  	case *unix.SockaddrInet4:
   247  		return &net.TCPAddr{
   248  			IP:   sa.Addr[:],
   249  			Port: sa.Port,
   250  		}
   251  	case *unix.SockaddrInet6:
   252  		return &net.TCPAddr{
   253  			IP:   sa.Addr[:],
   254  			Port: sa.Port,
   255  		}
   256  	}
   257  
   258  	panic("unknown address family")
   259  }
   260  

View as plain text