...

Source file src/github.com/Microsoft/go-winio/hvsock.go

Documentation: github.com/Microsoft/go-winio

     1  //go:build windows
     2  // +build windows
     3  
     4  package winio
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"os"
    13  	"syscall"
    14  	"time"
    15  	"unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  
    19  	"github.com/Microsoft/go-winio/internal/socket"
    20  	"github.com/Microsoft/go-winio/pkg/guid"
    21  )
    22  
    23  const afHVSock = 34 // AF_HYPERV
    24  
    25  // Well known Service and VM IDs
    26  // https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
    27  
    28  // HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
    29  func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
    30  	return guid.GUID{}
    31  }
    32  
    33  // HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
    34  func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
    35  	return guid.GUID{
    36  		Data1: 0xffffffff,
    37  		Data2: 0xffff,
    38  		Data3: 0xffff,
    39  		Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
    40  	}
    41  }
    42  
    43  // HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
    44  func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
    45  	return guid.GUID{
    46  		Data1: 0xe0e16197,
    47  		Data2: 0xdd56,
    48  		Data3: 0x4a10,
    49  		Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
    50  	}
    51  }
    52  
    53  // HvsockGUIDSiloHost is the address of a silo's host partition:
    54  //   - The silo host of a hosted silo is the utility VM.
    55  //   - The silo host of a silo on a physical host is the physical host.
    56  func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
    57  	return guid.GUID{
    58  		Data1: 0x36bd0c5c,
    59  		Data2: 0x7276,
    60  		Data3: 0x4223,
    61  		Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
    62  	}
    63  }
    64  
    65  // HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
    66  func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
    67  	return guid.GUID{
    68  		Data1: 0x90db8b89,
    69  		Data2: 0xd35,
    70  		Data3: 0x4f79,
    71  		Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
    72  	}
    73  }
    74  
    75  // HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
    76  // Listening on this VmId accepts connection from:
    77  //   - Inside silos: silo host partition.
    78  //   - Inside hosted silo: host of the VM.
    79  //   - Inside VM: VM host.
    80  //   - Physical host: Not supported.
    81  func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
    82  	return guid.GUID{
    83  		Data1: 0xa42e7cda,
    84  		Data2: 0xd03f,
    85  		Data3: 0x480c,
    86  		Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
    87  	}
    88  }
    89  
    90  // hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
    91  func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
    92  	return guid.GUID{
    93  		Data2: 0xfacb,
    94  		Data3: 0x11e6,
    95  		Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
    96  	}
    97  }
    98  
// An HvsockAddr is an address for an AF_HYPERV socket.
//
// VMID selects the partition (see the HvsockGUID* helpers for well-known
// values) and ServiceID selects the service within it.
type HvsockAddr struct {
	VMID      guid.GUID
	ServiceID guid.GUID
}

// rawHvsockAddr is the in-memory socket address exchanged with Winsock
// (bind, ConnectEx, the AcceptEx address buffer, and getsockname).
//
// NOTE(review): the field order and the unnamed 16-bit padding field
// presumably mirror the Windows SOCKADDR_HV layout; do not reorder or resize
// fields. Confirm against hvsocket.h before changing.
type rawHvsockAddr struct {
	Family    uint16
	_         uint16
	VMID      guid.GUID
	ServiceID guid.GUID
}

// Compile-time check that *rawHvsockAddr satisfies socket.RawSockaddr.
var _ socket.RawSockaddr = &rawHvsockAddr{}
   113  
   114  // Network returns the address's network name, "hvsock".
   115  func (*HvsockAddr) Network() string {
   116  	return "hvsock"
   117  }
   118  
   119  func (addr *HvsockAddr) String() string {
   120  	return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
   121  }
   122  
   123  // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
   124  func VsockServiceID(port uint32) guid.GUID {
   125  	g := hvsockVsockServiceTemplate() // make a copy
   126  	g.Data1 = port
   127  	return g
   128  }
   129  
   130  func (addr *HvsockAddr) raw() rawHvsockAddr {
   131  	return rawHvsockAddr{
   132  		Family:    afHVSock,
   133  		VMID:      addr.VMID,
   134  		ServiceID: addr.ServiceID,
   135  	}
   136  }
   137  
   138  func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
   139  	addr.VMID = raw.VMID
   140  	addr.ServiceID = raw.ServiceID
   141  }
   142  
   143  // Sockaddr returns a pointer to and the size of this struct.
   144  //
   145  // Implements the [socket.RawSockaddr] interface, and allows use in
   146  // [socket.Bind] and [socket.ConnectEx].
   147  func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
   148  	return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
   149  }
   150  
   151  // Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
   152  func (r *rawHvsockAddr) FromBytes(b []byte) error {
   153  	n := int(unsafe.Sizeof(rawHvsockAddr{}))
   154  
   155  	if len(b) < n {
   156  		return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
   157  	}
   158  
   159  	copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
   160  	if r.Family != afHVSock {
   161  		return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
   162  	}
   163  
   164  	return nil
   165  }
   166  
// HvsockListener is a socket listener for the AF_HYPERV address family.
type HvsockListener struct {
	// sock is the bound, listening socket.
	sock *win32File
	// addr is the address given to ListenHvsock; Addr returns a pointer to it.
	addr HvsockAddr
}

// Compile-time check that *HvsockListener satisfies net.Listener.
var _ net.Listener = &HvsockListener{}

// HvsockConn is a connected socket of the AF_HYPERV address family.
type HvsockConn struct {
	// sock is the connected socket.
	sock          *win32File
	// local and remote are filled in by Accept (from the AcceptEx address
	// buffer) or by HvsockDialer.Dial (remote is the dialed address; local is
	// read back with getsockname).
	local, remote HvsockAddr
}

// Compile-time check that *HvsockConn satisfies net.Conn.
var _ net.Conn = &HvsockConn{}
   182  
   183  func newHVSocket() (*win32File, error) {
   184  	fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1)
   185  	if err != nil {
   186  		return nil, os.NewSyscallError("socket", err)
   187  	}
   188  	f, err := makeWin32File(fd)
   189  	if err != nil {
   190  		syscall.Close(fd)
   191  		return nil, err
   192  	}
   193  	f.socket = true
   194  	return f, nil
   195  }
   196  
   197  // ListenHvsock listens for connections on the specified hvsock address.
   198  func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
   199  	l := &HvsockListener{addr: *addr}
   200  	sock, err := newHVSocket()
   201  	if err != nil {
   202  		return nil, l.opErr("listen", err)
   203  	}
   204  	sa := addr.raw()
   205  	err = socket.Bind(windows.Handle(sock.handle), &sa)
   206  	if err != nil {
   207  		return nil, l.opErr("listen", os.NewSyscallError("socket", err))
   208  	}
   209  	err = syscall.Listen(sock.handle, 16)
   210  	if err != nil {
   211  		return nil, l.opErr("listen", os.NewSyscallError("listen", err))
   212  	}
   213  	return &HvsockListener{sock: sock, addr: *addr}, nil
   214  }
   215  
   216  func (l *HvsockListener) opErr(op string, err error) error {
   217  	return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
   218  }
   219  
   220  // Addr returns the listener's network address.
   221  func (l *HvsockListener) Addr() net.Addr {
   222  	return &l.addr
   223  }
   224  
// Accept waits for the next connection and returns it.
//
// A new AF_HYPERV socket is created for the incoming connection and handed to
// an overlapped AcceptEx on the listening socket. On success the returned
// net.Conn is an *HvsockConn whose local and remote addresses are decoded from
// the AcceptEx address buffer, and whose socket has been updated with the
// listening socket's properties via SO_UPDATE_ACCEPT_CONTEXT. Errors are
// returned as a *net.OpError with Op "accept"; on any error the new socket is
// closed.
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
	sock, err := newHVSocket()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// Close the accept socket on every error path. sock is set to nil just
	// before the successful return, handing ownership to the conn.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()
	c, err := l.sock.prepareIO()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// Pairs with prepareIO; presumably lets Close wait for this in-flight
	// operation — confirm in win32File.
	defer l.sock.wg.Done()

	// AcceptEx, per documentation, requires an extra 16 bytes per address.
	//
	// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
	// Receives the local address followed by the remote address (no initial
	// data is requested: rxdatalen is 0).
	// NOTE(review): Microsoft documents GetAcceptExSockaddrs for parsing this
	// buffer; the fixed offsets used below assume its layout — confirm.
	var addrbuf [addrlen * 2]byte

	var bytes uint32
	err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
	// asyncIO completes the overlapped AcceptEx; the nil argument is where
	// Read/Write pass a deadline, so no deadline applies to Accept.
	if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
	}

	conn := &HvsockConn{
		sock: sock,
	}
	// The local address returned in the AcceptEx buffer is the same as the listener socket's
	// address. However, the service GUID reported by GetSockName is different from the listener's
	// socket, and is sometimes the same as the local address of the socket that dialed the
	// address, with the service GUID.Data1 incremented, but other times is different.
	// todo: does the local address matter? is the listener's address or the actual address appropriate?
	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))

	// initialize the accepted socket and update its properties with those of the listening socket
	if err = windows.Setsockopt(windows.Handle(sock.handle),
		windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
		(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
		return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
	}

	// Success: disarm the deferred Close; the conn now owns the socket.
	sock = nil
	return conn, nil
}
   275  
   276  // Close closes the listener, causing any pending Accept calls to fail.
   277  func (l *HvsockListener) Close() error {
   278  	return l.sock.Close()
   279  }
   280  
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
//
// The zero value makes a single attempt with no dialer-imposed deadline; the
// package-level [Dial] uses it.
//
// NOTE(review): redialWait lazily creates and reuses rt without locking, so a
// dialer is presumably not safe for concurrent Dial calls — confirm before
// sharing one across goroutines.
type HvsockDialer struct {
	// Deadline is the time the Dial operation must connect before erroring.
	// A zero Deadline adds no deadline beyond the one carried by ctx.
	Deadline time.Time

	// Retries is the number of additional connects to try if the connection times out, is refused,
	// or the host is unreachable (see canRedial for the exact set of errors).
	Retries uint

	// RetryWait is the time to wait after a connection error to retry.
	// Zero means retry immediately.
	RetryWait time.Duration

	rt *time.Timer // redial wait timer, created on first use and reused
}
   295  
   296  // Dial the Hyper-V socket at addr.
   297  //
   298  // See [HvsockDialer.Dial] for more information.
   299  func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
   300  	return (&HvsockDialer{}).Dial(ctx, addr)
   301  }
   302  
// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
// retries.
//
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
//
// Only errors accepted by canRedial are retried; any other error ends dialing immediately.
// All errors are returned as a *net.OpError with Op "dial". On success the connection's local
// address is read back with getsockname.
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
	op := "dial"
	// create the conn early to use opErr()
	conn = &HvsockConn{
		remote: *addr,
	}

	// Fold the dialer's absolute Deadline into ctx so that both cancellation
	// sources are observed through one context.
	if !d.Deadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
		defer cancel()
	}

	// preemptive timeout/cancellation check
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	sock, err := newHVSocket()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	// Close the socket on every error path; ownership moves to conn (and sock
	// is set to nil) only right before the successful return.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()

	// Presumably bound first because ConnectEx requires a previously bound
	// socket (per the Winsock docs).
	// NOTE(review): the socket is bound to the same address it then connects
	// to (sa is reused below) — confirm this is the intended AF_HYPERV usage.
	sa := addr.raw()
	err = socket.Bind(windows.Handle(sock.handle), &sa)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("bind", err))
	}

	c, err := sock.prepareIO()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	defer sock.wg.Done()
	var bytes uint32
	// One initial attempt plus up to d.Retries retries, all reusing the same
	// overlapped operation c.
	for i := uint(0); i <= d.Retries; i++ {
		err = socket.ConnectEx(
			windows.Handle(sock.handle),
			&sa,
			nil, // sendBuf
			0,   // sendDataLen
			&bytes,
			(*windows.Overlapped)(unsafe.Pointer(&c.o)))
		_, err = sock.asyncIO(c, nil, bytes, err)
		// Retry only redialable errors while attempts remain. If redialWait
		// fails (ctx done), its error replaces err and the loop stops; it is
		// then reported wrapped as a "connectex" syscall error below.
		if i < d.Retries && canRedial(err) {
			if err = d.redialWait(ctx); err == nil {
				continue
			}
		}
		break
	}
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
	}

	// update the connection properties, so shutdown can be used
	if err = windows.Setsockopt(
		windows.Handle(sock.handle),
		windows.SOL_SOCKET,
		windows.SO_UPDATE_CONNECT_CONTEXT,
		nil, // optvalue
		0,   // optlen
	); err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
	}

	// get the local name
	var sal rawHvsockAddr
	err = socket.GetSockName(windows.Handle(sock.handle), &sal)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
	}
	conn.local.fromRaw(&sal)

	// one last check for timeout, since asyncIO doesn't check the context
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	// Success: hand the socket to conn and disarm the deferred Close.
	conn.sock = sock
	sock = nil

	return conn, nil
}
   397  
   398  // redialWait waits before attempting to redial, resetting the timer as appropriate.
   399  func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
   400  	if d.RetryWait == 0 {
   401  		return nil
   402  	}
   403  
   404  	if d.rt == nil {
   405  		d.rt = time.NewTimer(d.RetryWait)
   406  	} else {
   407  		// should already be stopped and drained
   408  		d.rt.Reset(d.RetryWait)
   409  	}
   410  
   411  	select {
   412  	case <-ctx.Done():
   413  	case <-d.rt.C:
   414  		return nil
   415  	}
   416  
   417  	// stop and drain the timer
   418  	if !d.rt.Stop() {
   419  		<-d.rt.C
   420  	}
   421  	return ctx.Err()
   422  }
   423  
   424  // assumes error is a plain, unwrapped syscall.Errno provided by direct syscall.
   425  func canRedial(err error) bool {
   426  	//nolint:errorlint // guaranteed to be an Errno
   427  	switch err {
   428  	case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
   429  		windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
   430  		return true
   431  	default:
   432  		return false
   433  	}
   434  }
   435  
   436  func (conn *HvsockConn) opErr(op string, err error) error {
   437  	// translate from "file closed" to "socket closed"
   438  	if errors.Is(err, ErrFileClosed) {
   439  		err = socket.ErrSocketClosed
   440  	}
   441  	return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
   442  }
   443  
   444  func (conn *HvsockConn) Read(b []byte) (int, error) {
   445  	c, err := conn.sock.prepareIO()
   446  	if err != nil {
   447  		return 0, conn.opErr("read", err)
   448  	}
   449  	defer conn.sock.wg.Done()
   450  	buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
   451  	var flags, bytes uint32
   452  	err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
   453  	n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
   454  	if err != nil {
   455  		var eno windows.Errno
   456  		if errors.As(err, &eno) {
   457  			err = os.NewSyscallError("wsarecv", eno)
   458  		}
   459  		return 0, conn.opErr("read", err)
   460  	} else if n == 0 {
   461  		err = io.EOF
   462  	}
   463  	return n, err
   464  }
   465  
   466  func (conn *HvsockConn) Write(b []byte) (int, error) {
   467  	t := 0
   468  	for len(b) != 0 {
   469  		n, err := conn.write(b)
   470  		if err != nil {
   471  			return t + n, err
   472  		}
   473  		t += n
   474  		b = b[n:]
   475  	}
   476  	return t, nil
   477  }
   478  
   479  func (conn *HvsockConn) write(b []byte) (int, error) {
   480  	c, err := conn.sock.prepareIO()
   481  	if err != nil {
   482  		return 0, conn.opErr("write", err)
   483  	}
   484  	defer conn.sock.wg.Done()
   485  	buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
   486  	var bytes uint32
   487  	err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
   488  	n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
   489  	if err != nil {
   490  		var eno windows.Errno
   491  		if errors.As(err, &eno) {
   492  			err = os.NewSyscallError("wsasend", eno)
   493  		}
   494  		return 0, conn.opErr("write", err)
   495  	}
   496  	return n, err
   497  }
   498  
   499  // Close closes the socket connection, failing any pending read or write calls.
   500  func (conn *HvsockConn) Close() error {
   501  	return conn.sock.Close()
   502  }
   503  
   504  func (conn *HvsockConn) IsClosed() bool {
   505  	return conn.sock.IsClosed()
   506  }
   507  
   508  // shutdown disables sending or receiving on a socket.
   509  func (conn *HvsockConn) shutdown(how int) error {
   510  	if conn.IsClosed() {
   511  		return socket.ErrSocketClosed
   512  	}
   513  
   514  	err := syscall.Shutdown(conn.sock.handle, how)
   515  	if err != nil {
   516  		// If the connection was closed, shutdowns fail with "not connected"
   517  		if errors.Is(err, windows.WSAENOTCONN) ||
   518  			errors.Is(err, windows.WSAESHUTDOWN) {
   519  			err = socket.ErrSocketClosed
   520  		}
   521  		return os.NewSyscallError("shutdown", err)
   522  	}
   523  	return nil
   524  }
   525  
   526  // CloseRead shuts down the read end of the socket, preventing future read operations.
   527  func (conn *HvsockConn) CloseRead() error {
   528  	err := conn.shutdown(syscall.SHUT_RD)
   529  	if err != nil {
   530  		return conn.opErr("closeread", err)
   531  	}
   532  	return nil
   533  }
   534  
   535  // CloseWrite shuts down the write end of the socket, preventing future write operations and
   536  // notifying the other endpoint that no more data will be written.
   537  func (conn *HvsockConn) CloseWrite() error {
   538  	err := conn.shutdown(syscall.SHUT_WR)
   539  	if err != nil {
   540  		return conn.opErr("closewrite", err)
   541  	}
   542  	return nil
   543  }
   544  
   545  // LocalAddr returns the local address of the connection.
   546  func (conn *HvsockConn) LocalAddr() net.Addr {
   547  	return &conn.local
   548  }
   549  
   550  // RemoteAddr returns the remote address of the connection.
   551  func (conn *HvsockConn) RemoteAddr() net.Addr {
   552  	return &conn.remote
   553  }
   554  
   555  // SetDeadline implements the net.Conn SetDeadline method.
   556  func (conn *HvsockConn) SetDeadline(t time.Time) error {
   557  	// todo: implement `SetDeadline` for `win32File`
   558  	if err := conn.SetReadDeadline(t); err != nil {
   559  		return fmt.Errorf("set read deadline: %w", err)
   560  	}
   561  	if err := conn.SetWriteDeadline(t); err != nil {
   562  		return fmt.Errorf("set write deadline: %w", err)
   563  	}
   564  	return nil
   565  }
   566  
   567  // SetReadDeadline implements the net.Conn SetReadDeadline method.
   568  func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
   569  	return conn.sock.SetReadDeadline(t)
   570  }
   571  
   572  // SetWriteDeadline implements the net.Conn SetWriteDeadline method.
   573  func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
   574  	return conn.sock.SetWriteDeadline(t)
   575  }
   576  

View as plain text