...

Source file src/github.com/Microsoft/go-winio/internal/socket/socket.go

Documentation: github.com/Microsoft/go-winio/internal/socket

     1  //go:build windows
     2  
     3  package socket
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sync"
    10  	"syscall"
    11  	"unsafe"
    12  
    13  	"github.com/Microsoft/go-winio/pkg/guid"
    14  	"golang.org/x/sys/windows"
    15  )
    16  
    17  //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
    18  
    19  //sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
    20  //sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
    21  //sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
    22  
    23  const socketError = uintptr(^uint32(0))
    24  
    25  var (
    26  	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
    27  
    28  	ErrBufferSize     = errors.New("buffer size")
    29  	ErrAddrFamily     = errors.New("address family")
    30  	ErrInvalidPointer = errors.New("invalid pointer")
    31  	ErrSocketClosed   = fmt.Errorf("socket closed: %w", net.ErrClosed)
    32  )
    33  
    34  // todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
    35  
    36  // GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
    37  // If rsa is not large enough, the [windows.WSAEFAULT] is returned.
    38  func GetSockName(s windows.Handle, rsa RawSockaddr) error {
    39  	ptr, l, err := rsa.Sockaddr()
    40  	if err != nil {
    41  		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
    42  	}
    43  
    44  	// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
    45  	// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
    46  	return getsockname(s, ptr, &l)
    47  }
    48  
    49  // GetPeerName returns the remote address the socket is connected to.
    50  //
    51  // See [GetSockName] for more information.
    52  func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
    53  	ptr, l, err := rsa.Sockaddr()
    54  	if err != nil {
    55  		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
    56  	}
    57  
    58  	return getpeername(s, ptr, &l)
    59  }
    60  
    61  func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
    62  	ptr, l, err := rsa.Sockaddr()
    63  	if err != nil {
    64  		return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
    65  	}
    66  
    67  	return bind(s, ptr, l)
    68  }
    69  
    70  // "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
    71  // their sockaddr interface, so they cannot be used with HvsockAddr.
    72  // Replicate functionality here from
    73  // https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
    74  
    75  // The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
    76  // runtime via a WSAIoctl call:
    77  // https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
    78  
    79  type runtimeFunc struct {
    80  	id   guid.GUID
    81  	once sync.Once
    82  	addr uintptr
    83  	err  error
    84  }
    85  
    86  func (f *runtimeFunc) Load() error {
    87  	f.once.Do(func() {
    88  		var s windows.Handle
    89  		s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
    90  		if f.err != nil {
    91  			return
    92  		}
    93  		defer windows.CloseHandle(s) //nolint:errcheck
    94  
    95  		var n uint32
    96  		f.err = windows.WSAIoctl(s,
    97  			windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
    98  			(*byte)(unsafe.Pointer(&f.id)),
    99  			uint32(unsafe.Sizeof(f.id)),
   100  			(*byte)(unsafe.Pointer(&f.addr)),
   101  			uint32(unsafe.Sizeof(f.addr)),
   102  			&n,
   103  			nil, // overlapped
   104  			0,   // completionRoutine
   105  		)
   106  	})
   107  	return f.err
   108  }
   109  
   110  var (
   111  	// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
   112  	WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
   113  		Data1: 0x25a207b9,
   114  		Data2: 0xddf3,
   115  		Data3: 0x4660,
   116  		Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
   117  	}
   118  
   119  	connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
   120  )
   121  
   122  func ConnectEx(
   123  	fd windows.Handle,
   124  	rsa RawSockaddr,
   125  	sendBuf *byte,
   126  	sendDataLen uint32,
   127  	bytesSent *uint32,
   128  	overlapped *windows.Overlapped,
   129  ) error {
   130  	if err := connectExFunc.Load(); err != nil {
   131  		return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
   132  	}
   133  	ptr, n, err := rsa.Sockaddr()
   134  	if err != nil {
   135  		return err
   136  	}
   137  	return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
   138  }
   139  
   140  // BOOL LpfnConnectex(
   141  //   [in]           SOCKET s,
   142  //   [in]           const sockaddr *name,
   143  //   [in]           int namelen,
   144  //   [in, optional] PVOID lpSendBuffer,
   145  //   [in]           DWORD dwSendDataLength,
   146  //   [out]          LPDWORD lpdwBytesSent,
   147  //   [in]           LPOVERLAPPED lpOverlapped
   148  // )
   149  
   150  func connectEx(
   151  	s windows.Handle,
   152  	name unsafe.Pointer,
   153  	namelen int32,
   154  	sendBuf *byte,
   155  	sendDataLen uint32,
   156  	bytesSent *uint32,
   157  	overlapped *windows.Overlapped,
   158  ) (err error) {
   159  	// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
   160  	r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
   161  		7,
   162  		uintptr(s),
   163  		uintptr(name),
   164  		uintptr(namelen),
   165  		uintptr(unsafe.Pointer(sendBuf)),
   166  		uintptr(sendDataLen),
   167  		uintptr(unsafe.Pointer(bytesSent)),
   168  		uintptr(unsafe.Pointer(overlapped)),
   169  		0,
   170  		0)
   171  	if r1 == 0 {
   172  		if e1 != 0 {
   173  			err = error(e1)
   174  		} else {
   175  			err = syscall.EINVAL
   176  		}
   177  	}
   178  	return err
   179  }
   180  

View as plain text