...

Source file src/github.com/mdlayher/socket/conn.go

Documentation: github.com/mdlayher/socket

     1  package socket
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"os"
     8  	"sync"
     9  	"sync/atomic"
    10  	"syscall"
    11  	"time"
    12  
    13  	"golang.org/x/sys/unix"
    14  )
    15  
    16  // Lock in an expected public interface for convenience.
    17  var _ interface {
    18  	io.ReadWriteCloser
    19  	syscall.Conn
    20  	SetDeadline(t time.Time) error
    21  	SetReadDeadline(t time.Time) error
    22  	SetWriteDeadline(t time.Time) error
    23  } = &Conn{}
    24  
// A Conn is a low-level network connection which integrates with Go's runtime
// network poller to provide asynchronous I/O and deadline support.
//
// Many of a Conn's blocking methods support net.Conn deadlines as well as
// cancelation via context. Note that passing a context with a deadline set will
// override any of the previous deadlines set by calls to the SetDeadline family
// of methods.
type Conn struct {
	// Indicates whether or not Conn.Close has been called. Must be accessed
	// atomically. Atomics definitions must come first in the Conn struct.
	//
	// Close increments this counter; the raw read/write/control helpers and
	// SyscallConn report EBADF whenever it is non-zero.
	closed uint32

	// A unique name for the Conn which is also associated with derived file
	// descriptors such as those created by accept(2). It is also the name
	// given to the *os.File wrapping the descriptor.
	name string

	// facts contains information we have determined about Conn to trigger
	// alternate behavior in certain functions.
	facts facts

	// Provides access to the underlying file registered with the runtime
	// network poller, and arbitrary raw I/O calls. rc is obtained from fd's
	// SyscallConn method when the Conn is constructed.
	fd *os.File
	rc syscall.RawConn
}
    50  
// facts contains facts about a Conn. The values are determined once, when the
// Conn is created, by probing the descriptor's SO_TYPE socket option.
type facts struct {
	// isStream reports whether this is a streaming descriptor, as opposed to a
	// packet-based descriptor like a UDP socket. Stream descriptors have their
	// individual reads and writes capped at maxRW bytes.
	isStream bool

	// zeroReadIsEOF reports whether a zero byte read indicates EOF. This is
	// false for a message based socket connection.
	zeroReadIsEOF bool
}
    61  
// A Config contains options for a Conn. Passing a nil *Config to Socket is
// equivalent to passing a zero-value Config.
type Config struct {
	// NetNS specifies the Linux network namespace the Conn will operate in.
	// This option is unsupported on other operating systems.
	//
	// If set (non-zero), Conn will enter the specified network namespace and an
	// error will occur in Socket if the operation fails.
	//
	// If not set (zero), a best-effort attempt will be made to enter the
	// network namespace of the calling thread: this means that any changes made
	// to the calling thread's network namespace will also be reflected in Conn.
	// If this operation fails (due to lack of permissions or because network
	// namespaces are disabled by kernel configuration), Socket will not return
	// an error, and the Conn will operate in the default network namespace of
	// the process. This enables non-privileged use of Conn in applications
	// which do not require elevated privileges.
	//
	// Entering a network namespace is a privileged operation (root or
	// CAP_SYS_ADMIN are required), and most applications should leave this set
	// to 0.
	NetNS int
}
    84  
    85  // High-level methods which provide convenience over raw system calls.
    86  
    87  // Close closes the underlying file descriptor for the Conn, which also causes
    88  // all in-flight I/O operations to immediately unblock and return errors. Any
    89  // subsequent uses of Conn will result in EBADF.
    90  func (c *Conn) Close() error {
    91  	// The caller has expressed an intent to close the socket, so immediately
    92  	// increment s.closed to force further calls to result in EBADF before also
    93  	// closing the file descriptor to unblock any outstanding operations.
    94  	//
    95  	// Because other operations simply check for s.closed != 0, we will permit
    96  	// double Close, which would increment s.closed beyond 1.
    97  	if atomic.AddUint32(&c.closed, 1) != 1 {
    98  		// Multiple Close calls.
    99  		return nil
   100  	}
   101  
   102  	return os.NewSyscallError("close", c.fd.Close())
   103  }
   104  
   105  // CloseRead shuts down the reading side of the Conn. Most callers should just
   106  // use Close.
   107  func (c *Conn) CloseRead() error { return c.Shutdown(unix.SHUT_RD) }
   108  
   109  // CloseWrite shuts down the writing side of the Conn. Most callers should just
   110  // use Close.
   111  func (c *Conn) CloseWrite() error { return c.Shutdown(unix.SHUT_WR) }
   112  
   113  // Read reads directly from the underlying file descriptor.
   114  func (c *Conn) Read(b []byte) (int, error) { return c.fd.Read(b) }
   115  
   116  // ReadContext reads from the underlying file descriptor with added support for
   117  // context cancelation.
   118  func (c *Conn) ReadContext(ctx context.Context, b []byte) (int, error) {
   119  	if c.facts.isStream && len(b) > maxRW {
   120  		b = b[:maxRW]
   121  	}
   122  
   123  	n, err := readT(c, ctx, "read", func(fd int) (int, error) {
   124  		return unix.Read(fd, b)
   125  	})
   126  	if n == 0 && err == nil && c.facts.zeroReadIsEOF {
   127  		return 0, io.EOF
   128  	}
   129  
   130  	return n, os.NewSyscallError("read", err)
   131  }
   132  
   133  // Write writes directly to the underlying file descriptor.
   134  func (c *Conn) Write(b []byte) (int, error) { return c.fd.Write(b) }
   135  
   136  // WriteContext writes to the underlying file descriptor with added support for
   137  // context cancelation.
   138  func (c *Conn) WriteContext(ctx context.Context, b []byte) (int, error) {
   139  	var (
   140  		n, nn int
   141  		err   error
   142  	)
   143  
   144  	doErr := c.write(ctx, "write", func(fd int) error {
   145  		max := len(b)
   146  		if c.facts.isStream && max-nn > maxRW {
   147  			max = nn + maxRW
   148  		}
   149  
   150  		n, err = unix.Write(fd, b[nn:max])
   151  		if n > 0 {
   152  			nn += n
   153  		}
   154  		if nn == len(b) {
   155  			return err
   156  		}
   157  		if n == 0 && err == nil {
   158  			err = io.ErrUnexpectedEOF
   159  			return nil
   160  		}
   161  
   162  		return err
   163  	})
   164  	if doErr != nil {
   165  		return 0, doErr
   166  	}
   167  
   168  	return nn, os.NewSyscallError("write", err)
   169  }
   170  
   171  // SetDeadline sets both the read and write deadlines associated with the Conn.
   172  func (c *Conn) SetDeadline(t time.Time) error { return c.fd.SetDeadline(t) }
   173  
   174  // SetReadDeadline sets the read deadline associated with the Conn.
   175  func (c *Conn) SetReadDeadline(t time.Time) error { return c.fd.SetReadDeadline(t) }
   176  
   177  // SetWriteDeadline sets the write deadline associated with the Conn.
   178  func (c *Conn) SetWriteDeadline(t time.Time) error { return c.fd.SetWriteDeadline(t) }
   179  
   180  // ReadBuffer gets the size of the operating system's receive buffer associated
   181  // with the Conn.
   182  func (c *Conn) ReadBuffer() (int, error) {
   183  	return c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF)
   184  }
   185  
   186  // WriteBuffer gets the size of the operating system's transmit buffer
   187  // associated with the Conn.
   188  func (c *Conn) WriteBuffer() (int, error) {
   189  	return c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF)
   190  }
   191  
   192  // SetReadBuffer sets the size of the operating system's receive buffer
   193  // associated with the Conn.
   194  //
   195  // When called with elevated privileges on Linux, the SO_RCVBUFFORCE option will
   196  // be used to override operating system limits. Otherwise SO_RCVBUF is used
   197  // (which obeys operating system limits).
   198  func (c *Conn) SetReadBuffer(bytes int) error { return c.setReadBuffer(bytes) }
   199  
   200  // SetWriteBuffer sets the size of the operating system's transmit buffer
   201  // associated with the Conn.
   202  //
   203  // When called with elevated privileges on Linux, the SO_SNDBUFFORCE option will
   204  // be used to override operating system limits. Otherwise SO_SNDBUF is used
   205  // (which obeys operating system limits).
   206  func (c *Conn) SetWriteBuffer(bytes int) error { return c.setWriteBuffer(bytes) }
   207  
   208  // SyscallConn returns a raw network connection. This implements the
   209  // syscall.Conn interface.
   210  //
   211  // SyscallConn is intended for advanced use cases, such as getting and setting
   212  // arbitrary socket options using the socket's file descriptor. If possible,
   213  // those operations should be performed using methods on Conn instead.
   214  //
   215  // Once invoked, it is the caller's responsibility to ensure that operations
   216  // performed using Conn and the syscall.RawConn do not conflict with each other.
   217  func (c *Conn) SyscallConn() (syscall.RawConn, error) {
   218  	if atomic.LoadUint32(&c.closed) != 0 {
   219  		return nil, os.NewSyscallError("syscallconn", unix.EBADF)
   220  	}
   221  
   222  	// TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of
   223  	// FD remaining valid for duration of calls?
   224  	return c.rc, nil
   225  }
   226  
   227  // Socket wraps the socket(2) system call to produce a Conn. domain, typ, and
   228  // proto are passed directly to socket(2), and name should be a unique name for
   229  // the socket type such as "netlink" or "vsock".
   230  //
   231  // The cfg parameter specifies optional configuration for the Conn. If nil, no
   232  // additional configuration will be applied.
   233  //
   234  // If the operating system supports SOCK_CLOEXEC and SOCK_NONBLOCK, they are
   235  // automatically applied to typ to mirror the standard library's socket flag
   236  // behaviors.
   237  func Socket(domain, typ, proto int, name string, cfg *Config) (*Conn, error) {
   238  	if cfg == nil {
   239  		cfg = &Config{}
   240  	}
   241  
   242  	if cfg.NetNS == 0 {
   243  		// Non-Linux or no network namespace.
   244  		return socket(domain, typ, proto, name)
   245  	}
   246  
   247  	// Linux only: create Conn in the specified network namespace.
   248  	return withNetNS(cfg.NetNS, func() (*Conn, error) {
   249  		return socket(domain, typ, proto, name)
   250  	})
   251  }
   252  
   253  // socket is the internal, cross-platform entry point for socket(2).
   254  func socket(domain, typ, proto int, name string) (*Conn, error) {
   255  	var (
   256  		fd  int
   257  		err error
   258  	)
   259  
   260  	for {
   261  		fd, err = unix.Socket(domain, typ|socketFlags, proto)
   262  		switch {
   263  		case err == nil:
   264  			// Some OSes already set CLOEXEC with typ.
   265  			if !flagCLOEXEC {
   266  				unix.CloseOnExec(fd)
   267  			}
   268  
   269  			// No error, prepare the Conn.
   270  			return New(fd, name)
   271  		case !ready(err):
   272  			// System call interrupted or not ready, try again.
   273  			continue
   274  		case err == unix.EINVAL, err == unix.EPROTONOSUPPORT:
   275  			// On Linux, SOCK_NONBLOCK and SOCK_CLOEXEC were introduced in
   276  			// 2.6.27. On FreeBSD, both flags were introduced in FreeBSD 10.
   277  			// EINVAL and EPROTONOSUPPORT check for earlier versions of these
   278  			// OSes respectively.
   279  			//
   280  			// Mirror what the standard library does when creating file
   281  			// descriptors: avoid racing a fork/exec with the creation of new
   282  			// file descriptors, so that child processes do not inherit socket
   283  			// file descriptors unexpectedly.
   284  			//
   285  			// For a more thorough explanation, see similar work in the Go tree:
   286  			// func sysSocket in net/sock_cloexec.go, as well as the detailed
   287  			// comment in syscall/exec_unix.go.
   288  			syscall.ForkLock.RLock()
   289  			fd, err = unix.Socket(domain, typ, proto)
   290  			if err != nil {
   291  				syscall.ForkLock.RUnlock()
   292  				return nil, os.NewSyscallError("socket", err)
   293  			}
   294  			unix.CloseOnExec(fd)
   295  			syscall.ForkLock.RUnlock()
   296  
   297  			return New(fd, name)
   298  		default:
   299  			// Unhandled error.
   300  			return nil, os.NewSyscallError("socket", err)
   301  		}
   302  	}
   303  }
   304  
   305  // FileConn returns a copy of the network connection corresponding to the open
   306  // file. It is the caller's responsibility to close the file when finished.
   307  // Closing the Conn does not affect the File, and closing the File does not
   308  // affect the Conn.
   309  func FileConn(f *os.File, name string) (*Conn, error) {
   310  	// First we'll try to do fctnl(2) with F_DUPFD_CLOEXEC because we can dup
   311  	// the file descriptor and set the flag in one syscall.
   312  	fd, err := unix.FcntlInt(f.Fd(), unix.F_DUPFD_CLOEXEC, 0)
   313  	switch err {
   314  	case nil:
   315  		// OK, ready to set up non-blocking I/O.
   316  		return New(fd, name)
   317  	case unix.EINVAL:
   318  		// The kernel rejected our fcntl(2), fall back to separate dup(2) and
   319  		// setting close on exec.
   320  		//
   321  		// Mirror what the standard library does when creating file descriptors:
   322  		// avoid racing a fork/exec with the creation of new file descriptors,
   323  		// so that child processes do not inherit socket file descriptors
   324  		// unexpectedly.
   325  		syscall.ForkLock.RLock()
   326  		fd, err := unix.Dup(fd)
   327  		if err != nil {
   328  			syscall.ForkLock.RUnlock()
   329  			return nil, os.NewSyscallError("dup", err)
   330  		}
   331  		unix.CloseOnExec(fd)
   332  		syscall.ForkLock.RUnlock()
   333  
   334  		return New(fd, name)
   335  	default:
   336  		// Any other errors.
   337  		return nil, os.NewSyscallError("fcntl", err)
   338  	}
   339  }
   340  
   341  // New wraps an existing file descriptor to create a Conn. name should be a
   342  // unique name for the socket type such as "netlink" or "vsock".
   343  //
   344  // Most callers should use Socket or FileConn to construct a Conn. New is
   345  // intended for integrating with specific system calls which provide a file
   346  // descriptor that supports asynchronous I/O. The file descriptor is immediately
   347  // set to nonblocking mode and registered with Go's runtime network poller for
   348  // future I/O operations.
   349  //
   350  // Unlike FileConn, New does not duplicate the existing file descriptor in any
   351  // way. The returned Conn takes ownership of the underlying file descriptor.
   352  func New(fd int, name string) (*Conn, error) {
   353  	// All Conn I/O is nonblocking for integration with Go's runtime network
   354  	// poller. Depending on the OS this might already be set but it can't hurt
   355  	// to set it again.
   356  	if err := unix.SetNonblock(fd, true); err != nil {
   357  		return nil, os.NewSyscallError("setnonblock", err)
   358  	}
   359  
   360  	// os.NewFile registers the non-blocking file descriptor with the runtime
   361  	// poller, which is then used for most subsequent operations except those
   362  	// that require raw I/O via SyscallConn.
   363  	//
   364  	// See also: https://golang.org/pkg/os/#NewFile
   365  	f := os.NewFile(uintptr(fd), name)
   366  	rc, err := f.SyscallConn()
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	c := &Conn{
   372  		name: name,
   373  		fd:   f,
   374  		rc:   rc,
   375  	}
   376  
   377  	// Probe the file descriptor for socket settings.
   378  	sotype, err := c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_TYPE)
   379  	switch {
   380  	case err == nil:
   381  		// File is a socket, check its properties.
   382  		c.facts = facts{
   383  			isStream:      sotype == unix.SOCK_STREAM,
   384  			zeroReadIsEOF: sotype != unix.SOCK_DGRAM && sotype != unix.SOCK_RAW,
   385  		}
   386  	case errors.Is(err, unix.ENOTSOCK):
   387  		// File is not a socket, treat it as a regular file.
   388  		c.facts = facts{
   389  			isStream:      true,
   390  			zeroReadIsEOF: true,
   391  		}
   392  	default:
   393  		return nil, err
   394  	}
   395  
   396  	return c, nil
   397  }
   398  
   399  // Low-level methods which provide raw system call access.
   400  
   401  // Accept wraps accept(2) or accept4(2) depending on the operating system, but
   402  // returns a Conn for the accepted connection rather than a raw file descriptor.
   403  //
   404  // If the operating system supports accept4(2) (which allows flags),
   405  // SOCK_CLOEXEC and SOCK_NONBLOCK are automatically applied to flags to mirror
   406  // the standard library's socket flag behaviors.
   407  //
   408  // If the operating system only supports accept(2) (which does not allow flags)
   409  // and flags is not zero, an error will be returned.
   410  //
   411  // Accept obeys context cancelation and uses the deadline set on the context to
   412  // cancel accepting the next connection. If a deadline is set on ctx, this
   413  // deadline will override any previous deadlines set using SetDeadline or
   414  // SetReadDeadline. Upon return, the read deadline is cleared.
   415  func (c *Conn) Accept(ctx context.Context, flags int) (*Conn, unix.Sockaddr, error) {
   416  	type ret struct {
   417  		nfd int
   418  		sa  unix.Sockaddr
   419  	}
   420  
   421  	r, err := readT(c, ctx, sysAccept, func(fd int) (ret, error) {
   422  		// Either accept(2) or accept4(2) depending on the OS.
   423  		nfd, sa, err := accept(fd, flags|socketFlags)
   424  		return ret{nfd, sa}, err
   425  	})
   426  	if err != nil {
   427  		// internal/poll, context error, or user function error.
   428  		return nil, nil, err
   429  	}
   430  
   431  	// Successfully accepted a connection, wrap it in a Conn for use by the
   432  	// caller.
   433  	ac, err := New(r.nfd, c.name)
   434  	if err != nil {
   435  		return nil, nil, err
   436  	}
   437  
   438  	return ac, r.sa, nil
   439  }
   440  
   441  // Bind wraps bind(2).
   442  func (c *Conn) Bind(sa unix.Sockaddr) error {
   443  	return c.control(context.Background(), "bind", func(fd int) error {
   444  		return unix.Bind(fd, sa)
   445  	})
   446  }
   447  
   448  // Connect wraps connect(2). In order to verify that the underlying socket is
   449  // connected to a remote peer, Connect calls getpeername(2) and returns the
   450  // unix.Sockaddr from that call.
   451  //
   452  // Connect obeys context cancelation and uses the deadline set on the context to
   453  // cancel connecting to a remote peer. If a deadline is set on ctx, this
   454  // deadline will override any previous deadlines set using SetDeadline or
   455  // SetWriteDeadline. Upon return, the write deadline is cleared.
   456  func (c *Conn) Connect(ctx context.Context, sa unix.Sockaddr) (unix.Sockaddr, error) {
   457  	const op = "connect"
   458  
   459  	// TODO(mdlayher): it would seem that trying to connect to unbound vsock
   460  	// listeners by calling Connect multiple times results in ECONNRESET for the
   461  	// first and nil error for subsequent calls. Do we need to memoize the
   462  	// error? Check what the stdlib behavior is.
   463  
   464  	var (
   465  		// Track progress between invocations of the write closure. We don't
   466  		// have an explicit WaitWrite call like internal/poll does, so we have
   467  		// to wait until the runtime calls the closure again to indicate we can
   468  		// write.
   469  		progress uint32
   470  
   471  		// Capture closure sockaddr and error.
   472  		rsa unix.Sockaddr
   473  		err error
   474  	)
   475  
   476  	doErr := c.write(ctx, op, func(fd int) error {
   477  		if atomic.AddUint32(&progress, 1) == 1 {
   478  			// First call: initiate connect.
   479  			return unix.Connect(fd, sa)
   480  		}
   481  
   482  		// Subsequent calls: the runtime network poller indicates fd is
   483  		// writable. Check for errno.
   484  		errno, gerr := c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_ERROR)
   485  		if gerr != nil {
   486  			return gerr
   487  		}
   488  		if errno != 0 {
   489  			// Connection is still not ready or failed. If errno indicates
   490  			// the socket is not ready, we will wait for the next write
   491  			// event. Otherwise we propagate this errno back to the as a
   492  			// permanent error.
   493  			uerr := unix.Errno(errno)
   494  			err = uerr
   495  			return uerr
   496  		}
   497  
   498  		// According to internal/poll, it's possible for the runtime network
   499  		// poller to spuriously wake us and return errno 0 for SO_ERROR.
   500  		// Make sure we are actually connected to a peer.
   501  		peer, err := c.Getpeername()
   502  		if err != nil {
   503  			// internal/poll unconditionally goes back to WaitWrite.
   504  			// Synthesize an error that will do the same for us.
   505  			return unix.EAGAIN
   506  		}
   507  
   508  		// Connection complete.
   509  		rsa = peer
   510  		return nil
   511  	})
   512  	if doErr != nil {
   513  		// internal/poll or context error.
   514  		return nil, doErr
   515  	}
   516  
   517  	if err == unix.EISCONN {
   518  		// TODO(mdlayher): is this block obsolete with the addition of the
   519  		// getsockopt SO_ERROR check above?
   520  		//
   521  		// EISCONN is reported if the socket is already established and should
   522  		// not be treated as an error.
   523  		//  - Darwin reports this for at least TCP sockets
   524  		//  - Linux reports this for at least AF_VSOCK sockets
   525  		return rsa, nil
   526  	}
   527  
   528  	return rsa, os.NewSyscallError(op, err)
   529  }
   530  
   531  // Getsockname wraps getsockname(2).
   532  func (c *Conn) Getsockname() (unix.Sockaddr, error) {
   533  	return controlT(c, context.Background(), "getsockname", unix.Getsockname)
   534  }
   535  
   536  // Getpeername wraps getpeername(2).
   537  func (c *Conn) Getpeername() (unix.Sockaddr, error) {
   538  	return controlT(c, context.Background(), "getpeername", unix.Getpeername)
   539  }
   540  
   541  // GetsockoptInt wraps getsockopt(2) for integer values.
   542  func (c *Conn) GetsockoptInt(level, opt int) (int, error) {
   543  	return controlT(c, context.Background(), "getsockopt", func(fd int) (int, error) {
   544  		return unix.GetsockoptInt(fd, level, opt)
   545  	})
   546  }
   547  
   548  // Listen wraps listen(2).
   549  func (c *Conn) Listen(n int) error {
   550  	return c.control(context.Background(), "listen", func(fd int) error {
   551  		return unix.Listen(fd, n)
   552  	})
   553  }
   554  
   555  // Recvmsg wraps recvmsg(2).
   556  func (c *Conn) Recvmsg(ctx context.Context, p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) {
   557  	type ret struct {
   558  		n, oobn, recvflags int
   559  		from               unix.Sockaddr
   560  	}
   561  
   562  	r, err := readT(c, ctx, "recvmsg", func(fd int) (ret, error) {
   563  		n, oobn, recvflags, from, err := unix.Recvmsg(fd, p, oob, flags)
   564  		return ret{n, oobn, recvflags, from}, err
   565  	})
   566  	if r.n == 0 && err == nil && c.facts.zeroReadIsEOF {
   567  		return 0, 0, 0, nil, io.EOF
   568  	}
   569  
   570  	return r.n, r.oobn, r.recvflags, r.from, err
   571  }
   572  
   573  // Recvfrom wraps recvfrom(2).
   574  func (c *Conn) Recvfrom(ctx context.Context, p []byte, flags int) (int, unix.Sockaddr, error) {
   575  	type ret struct {
   576  		n    int
   577  		addr unix.Sockaddr
   578  	}
   579  
   580  	out, err := readT(c, ctx, "recvfrom", func(fd int) (ret, error) {
   581  		n, addr, err := unix.Recvfrom(fd, p, flags)
   582  		return ret{n, addr}, err
   583  	})
   584  	if out.n == 0 && err == nil && c.facts.zeroReadIsEOF {
   585  		return 0, nil, io.EOF
   586  	}
   587  
   588  	return out.n, out.addr, err
   589  }
   590  
   591  // Sendmsg wraps sendmsg(2).
   592  func (c *Conn) Sendmsg(ctx context.Context, p, oob []byte, to unix.Sockaddr, flags int) (int, error) {
   593  	return writeT(c, ctx, "sendmsg", func(fd int) (int, error) {
   594  		return unix.SendmsgN(fd, p, oob, to, flags)
   595  	})
   596  }
   597  
   598  // Sendto wraps sendto(2).
   599  func (c *Conn) Sendto(ctx context.Context, p []byte, flags int, to unix.Sockaddr) error {
   600  	return c.write(ctx, "sendto", func(fd int) error {
   601  		return unix.Sendto(fd, p, flags, to)
   602  	})
   603  }
   604  
   605  // SetsockoptInt wraps setsockopt(2) for integer values.
   606  func (c *Conn) SetsockoptInt(level, opt, value int) error {
   607  	return c.control(context.Background(), "setsockopt", func(fd int) error {
   608  		return unix.SetsockoptInt(fd, level, opt, value)
   609  	})
   610  }
   611  
   612  // Shutdown wraps shutdown(2).
   613  func (c *Conn) Shutdown(how int) error {
   614  	return c.control(context.Background(), "shutdown", func(fd int) error {
   615  		return unix.Shutdown(fd, how)
   616  	})
   617  }
   618  
   619  // Conn low-level read/write/control functions. These functions mirror the
   620  // syscall.RawConn APIs but the input closures return errors rather than
   621  // booleans.
   622  
   623  // read wraps readT to execute a function and capture its error result. This is
   624  // a convenience wrapper for functions which don't return any extra values.
   625  func (c *Conn) read(ctx context.Context, op string, f func(fd int) error) error {
   626  	_, err := readT(c, ctx, op, func(fd int) (struct{}, error) {
   627  		return struct{}{}, f(fd)
   628  	})
   629  	return err
   630  }
   631  
   632  // write executes f, a write function, against the associated file descriptor.
   633  // op is used to create an *os.SyscallError if the file descriptor is closed.
   634  func (c *Conn) write(ctx context.Context, op string, f func(fd int) error) error {
   635  	_, err := writeT(c, ctx, op, func(fd int) (struct{}, error) {
   636  		return struct{}{}, f(fd)
   637  	})
   638  	return err
   639  }
   640  
   641  // readT executes c.rc.Read for op using the input function, returning a newly
   642  // allocated result T.
   643  func readT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T, error)) (T, error) {
   644  	return rwT(c, rwContext[T]{
   645  		Context: ctx,
   646  		Type:    read,
   647  		Op:      op,
   648  		Do:      f,
   649  	})
   650  }
   651  
   652  // writeT executes c.rc.Write for op using the input function, returning a newly
   653  // allocated result T.
   654  func writeT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T, error)) (T, error) {
   655  	return rwT(c, rwContext[T]{
   656  		Context: ctx,
   657  		Type:    write,
   658  		Op:      op,
   659  		Do:      f,
   660  	})
   661  }
   662  
// readWrite indicates if an operation intends to read or write. rwT uses it to
// choose between the read and write sides of the raw connection and between
// the read and write deadline setters.
type readWrite bool

// Possible readWrite values. read is the zero value.
const (
	read  readWrite = false
	write readWrite = true
)
   671  
// An rwContext provides arguments to rwT. T is the type of the value produced
// by Do.
type rwContext[T any] struct {
	// The caller's context passed for cancelation. It must not be nil.
	Context context.Context

	// The type of an operation: read or write.
	Type readWrite

	// The name of the operation used in errors.
	Op string

	// The actual function to perform. It receives the raw file descriptor and
	// may be invoked more than once: it is retried for as long as it returns
	// an error that ready treats as "not ready".
	Do func(fd int) (T, error)
}
   686  
// rwT executes c.rc.Read or c.rc.Write (depending on the value of rw.Type) for
// rw.Op using the input function, returning a newly allocated result T.
//
// It obeys context cancelation and the rw.Context must not be nil.
//
// Cancelation is implemented with the Conn's read or write deadline:
//   - If rw.Context carries a deadline, it is installed as the I/O deadline
//     before polling.
//   - Otherwise a watcher goroutine waits for the context to be done and then
//     sets a deadline in the distant past to unblock the poller.
//
// In both cases the deadline is cleared again after the poll call returns, so
// any deadline previously configured via the SetDeadline family is lost.
//
// On failure the zero value of T is returned. The error is EBADF wrapped in an
// *os.SyscallError if the Conn is closed, the context error wrapped in an
// *os.SyscallError if the operation was canceled or timed out, the unwrapped
// error from the raw connection otherwise, or the error returned by rw.Do
// wrapped in an *os.SyscallError.
func rwT[T any](c *Conn, rw rwContext[T]) (T, error) {
	if atomic.LoadUint32(&c.closed) != 0 {
		// If the file descriptor is already closed, do nothing.
		return *new(T), os.NewSyscallError(rw.Op, unix.EBADF)
	}

	if err := rw.Context.Err(); err != nil {
		// Early exit due to context cancel.
		return *new(T), os.NewSyscallError(rw.Op, err)
	}

	var (
		// The read or write function used to access the runtime network poller.
		poll func(func(uintptr) bool) error

		// The read or write function used to set the matching deadline.
		deadline func(time.Time) error
	)

	if rw.Type == write {
		poll = c.rc.Write
		deadline = c.SetWriteDeadline
	} else {
		poll = c.rc.Read
		deadline = c.SetReadDeadline
	}

	var (
		// Whether or not the context carried a deadline we are actively using
		// for cancelation.
		setDeadline bool

		// Signals for the cancelation watcher goroutine.
		wg    sync.WaitGroup
		doneC = make(chan struct{})

		// Atomic: reports whether we have to disarm the deadline.
		//
		// TODO(mdlayher): switch back to atomic.Bool when we drop support for
		// Go 1.18.
		needDisarm int64
	)

	// On cancel, clean up the watcher. Closing doneC releases the watcher
	// goroutine (if one was started) and wg.Wait ensures it has exited before
	// rwT returns.
	defer func() {
		close(doneC)
		wg.Wait()
	}()

	if d, ok := rw.Context.Deadline(); ok {
		// The context has an explicit deadline. We will use it for cancelation
		// but disarm it after poll for the next call.
		if err := deadline(d); err != nil {
			return *new(T), err
		}
		setDeadline = true
		atomic.AddInt64(&needDisarm, 1)
	} else {
		// The context does not have an explicit deadline. We have to watch for
		// cancelation so we can propagate that signal to immediately unblock
		// the runtime network poller.
		//
		// TODO(mdlayher): is it possible to detect a background context vs a
		// context with possible future cancel?
		wg.Add(1)
		go func() {
			defer wg.Done()

			select {
			case <-rw.Context.Done():
				// Cancel the operation. Make the caller disarm after poll
				// returns. The deadline used here is a fixed instant in the
				// past, so any pending poll wakes up immediately.
				atomic.AddInt64(&needDisarm, 1)
				_ = deadline(time.Unix(0, 1))
			case <-doneC:
				// Nothing to do.
			}
		}()
	}

	var (
		t   T
		err error
	)

	// The closure's boolean result tells the raw connection whether the
	// operation is complete; a false result makes it wait for readiness and
	// call the closure again.
	pollErr := poll(func(fd uintptr) bool {
		t, err = rw.Do(int(fd))
		return ready(err)
	})

	if atomic.LoadInt64(&needDisarm) > 0 {
		// Clear the deadline we (or the watcher) installed. The error is
		// deliberately ignored: this is best-effort cleanup.
		_ = deadline(time.Time{})
	}

	if pollErr != nil {
		if rw.Context.Err() != nil || (setDeadline && errors.Is(pollErr, os.ErrDeadlineExceeded)) {
			// The caller canceled the operation or we set a deadline internally
			// and it was reached.
			//
			// Unpack a plain context error. We wait for the context to be done
			// to synchronize state externally. Otherwise we have noticed I/O
			// timeout wakeups when we set a deadline but the context was not
			// yet marked done.
			<-rw.Context.Done()
			return *new(T), os.NewSyscallError(rw.Op, rw.Context.Err())
		}

		// Error from syscall.RawConn methods. Conventionally the standard
		// library does not wrap internal/poll errors in os.NewSyscallError.
		return *new(T), pollErr
	}

	// Result from user function.
	return t, os.NewSyscallError(rw.Op, err)
}
   806  
   807  // control executes Conn.control for op using the input function.
   808  func (c *Conn) control(ctx context.Context, op string, f func(fd int) error) error {
   809  	_, err := controlT(c, ctx, op, func(fd int) (struct{}, error) {
   810  		return struct{}{}, f(fd)
   811  	})
   812  	return err
   813  }
   814  
   815  // controlT executes c.rc.Control for op using the input function, returning a
   816  // newly allocated result T.
   817  func controlT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T, error)) (T, error) {
   818  	if atomic.LoadUint32(&c.closed) != 0 {
   819  		// If the file descriptor is already closed, do nothing.
   820  		return *new(T), os.NewSyscallError(op, unix.EBADF)
   821  	}
   822  
   823  	var (
   824  		t   T
   825  		err error
   826  	)
   827  
   828  	doErr := c.rc.Control(func(fd uintptr) {
   829  		// Repeatedly attempt the syscall(s) invoked by f until completion is
   830  		// indicated by the return value of ready or the context is canceled.
   831  		//
   832  		// The last values for t and err are captured outside of the closure for
   833  		// use when the loop breaks.
   834  		for {
   835  			if err = ctx.Err(); err != nil {
   836  				// Early exit due to context cancel.
   837  				return
   838  			}
   839  
   840  			t, err = f(int(fd))
   841  			if ready(err) {
   842  				return
   843  			}
   844  		}
   845  	})
   846  	if doErr != nil {
   847  		// Error from syscall.RawConn methods. Conventionally the standard
   848  		// library does not wrap internal/poll errors in os.NewSyscallError.
   849  		return *new(T), doErr
   850  	}
   851  
   852  	// Result from user function.
   853  	return t, os.NewSyscallError(op, err)
   854  }
   855  
   856  // ready indicates readiness based on the value of err.
   857  func ready(err error) bool {
   858  	switch err {
   859  	case unix.EAGAIN, unix.EINPROGRESS, unix.EINTR:
   860  		// When a socket is in non-blocking mode, we might see a variety of errors:
   861  		//  - EAGAIN: most common case for a socket read not being ready
   862  		//  - EINPROGRESS: reported by some sockets when first calling connect
   863  		//  - EINTR: system call interrupted, more frequently occurs in Go 1.14+
   864  		//    because goroutines can be asynchronously preempted
   865  		//
   866  		// Return false to let the poller wait for readiness. See the source code
   867  		// for internal/poll.FD.RawRead for more details.
   868  		return false
   869  	default:
   870  		// Ready regardless of whether there was an error or no error.
   871  		return true
   872  	}
   873  }
   874  
// maxRW is the largest number of bytes passed to a single read or write system
// call on a stream descriptor by ReadContext and WriteContext.
//
// Darwin and FreeBSD can't read or write 2GB+ files at a time,
// even on 64-bit systems.
// The same is true of socket implementations on many systems.
// See golang.org/issue/7812 and golang.org/issue/16266.
// Use 1GB instead of, say, 2GB-1, to keep subsequent reads aligned.
const maxRW = 1 << 30
   881  

View as plain text