...

Source file src/github.com/Microsoft/go-winio/hvsock_test.go

Documentation: github.com/Microsoft/go-winio

     1  //go:build windows
     2  
     3  package winio
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"math/rand"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"golang.org/x/sys/windows"
    16  
    17  	"github.com/Microsoft/go-winio/internal/socket"
    18  	"github.com/Microsoft/go-winio/pkg/guid"
    19  )
    20  
    21  const testStr = "test"
    22  
    23  func randHvsockAddr() *HvsockAddr {
    24  	p := rand.Uint32() //nolint:gosec // used for testing
    25  	return &HvsockAddr{
    26  		VMID:      HvsockGUIDLoopback(),
    27  		ServiceID: VsockServiceID(p),
    28  	}
    29  }
    30  
    31  func serverListen(u testUtil) (l *HvsockListener, a *HvsockAddr) {
    32  	var err error
    33  	for i := 0; i < 3; i++ {
    34  		a = randHvsockAddr()
    35  		l, err = ListenHvsock(a)
    36  		if errors.Is(err, windows.WSAEADDRINUSE) {
    37  			u.T.Logf("address collision %v", a)
    38  			continue
    39  		}
    40  		break
    41  	}
    42  	u.Must(err, "could not listen")
    43  	u.T.Cleanup(func() {
    44  		if l != nil {
    45  			u.Must(l.Close(), "Hyper-V socket listener close")
    46  		}
    47  	})
    48  
    49  	return l, a
    50  }
    51  
    52  func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) {
    53  	l, addr := serverListen(u)
    54  	ch := u.Go(func() error {
    55  		conn, err := l.Accept()
    56  		if err != nil {
    57  			return fmt.Errorf("listener accept: %w", err)
    58  		}
    59  		sv = conn.(*HvsockConn)
    60  		if err := l.Close(); err != nil {
    61  			return err
    62  		}
    63  		l = nil
    64  		return nil
    65  	})
    66  
    67  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    68  	defer cancel()
    69  	cl, err := Dial(ctx, addr)
    70  	u.Must(err, "could not dial")
    71  	u.T.Cleanup(func() {
    72  		if cl != nil {
    73  			u.Must(cl.Close(), "client close")
    74  		}
    75  	})
    76  
    77  	u.WaitErr(ch, time.Second)
    78  	u.T.Cleanup(func() {
    79  		if sv != nil {
    80  			u.Must(sv.Close(), "server close")
    81  		}
    82  	})
    83  	return cl, sv, addr
    84  }
    85  
    86  func TestHvSockConstants(t *testing.T) {
    87  	tests := []struct {
    88  		name string
    89  		want string
    90  		give guid.GUID
    91  	}{
    92  		{"wildcard", "00000000-0000-0000-0000-000000000000", HvsockGUIDWildcard()},
    93  		{"broadcast", "ffffffff-ffff-ffff-ffff-ffffffffffff", HvsockGUIDBroadcast()},
    94  		{"loopback", "e0e16197-dd56-4a10-9195-5ee7a155a838", HvsockGUIDLoopback()},
    95  		{"children", "90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd", HvsockGUIDChildren()},
    96  		{"parent", "a42e7cda-d03f-480c-9cc2-a4de20abb878", HvsockGUIDParent()},
    97  		{"silohost", "36bd0c5c-7276-4223-88ba-7d03b654c568", HvsockGUIDSiloHost()},
    98  		{"vsock template", "00000000-facb-11e6-bd58-64006a7986d3", hvsockVsockServiceTemplate()},
    99  	}
   100  	for _, tt := range tests {
   101  		if tt.give.String() != tt.want {
   102  			t.Errorf("%s give: %v; want: %s", tt.name, tt.give, tt.want)
   103  		}
   104  	}
   105  }
   106  
   107  func TestHvSockListenerAddresses(t *testing.T) {
   108  	u := newUtil(t)
   109  	l, addr := serverListen(u)
   110  
   111  	la := (l.Addr()).(*HvsockAddr)
   112  	u.Assert(*la == *addr, fmt.Sprintf("give: %v; want: %v", la, addr))
   113  
   114  	ra := rawHvsockAddr{}
   115  	sa := HvsockAddr{}
   116  	u.Must(socket.GetSockName(windows.Handle(l.sock.handle), &ra))
   117  	sa.fromRaw(&ra)
   118  	u.Assert(sa == *addr, fmt.Sprintf("listener local addr give: %v; want: %v", sa, addr))
   119  }
   120  
   121  func TestHvSockAddresses(t *testing.T) {
   122  	u := newUtil(t)
   123  	cl, sv, addr := clientServer(u)
   124  
   125  	sra := (sv.RemoteAddr()).(*HvsockAddr)
   126  	sla := (sv.LocalAddr()).(*HvsockAddr)
   127  	cra := (cl.RemoteAddr()).(*HvsockAddr)
   128  	cla := (cl.LocalAddr()).(*HvsockAddr)
   129  
   130  	t.Run("Info", func(t *testing.T) {
   131  		tests := []struct {
   132  			name string
   133  			give *HvsockAddr
   134  			want HvsockAddr
   135  		}{
   136  			{"client local", cla, HvsockAddr{HvsockGUIDChildren(), sra.ServiceID}},
   137  			{"client remote", cra, *addr},
   138  			{"server local", sla, HvsockAddr{HvsockGUIDChildren(), addr.ServiceID}},
   139  			{"server remote", sra, HvsockAddr{HvsockGUIDLoopback(), cla.ServiceID}},
   140  		}
   141  		for _, tt := range tests {
   142  			if *tt.give != tt.want {
   143  				t.Errorf("%s address give: %v; want: %v", tt.name, tt.give, tt.want)
   144  			}
   145  		}
   146  	})
   147  
   148  	t.Run("OSinfo", func(t *testing.T) {
   149  		u := newUtil(t)
   150  		ra := rawHvsockAddr{}
   151  		sa := HvsockAddr{}
   152  
   153  		localTests := []struct {
   154  			name     string
   155  			giveSock *win32File
   156  			wantAddr HvsockAddr
   157  		}{
   158  			{"client", cl.sock, HvsockAddr{HvsockGUIDChildren(), cla.ServiceID}},
   159  			// The server sockets local address seems arbitrary, so skip this test
   160  			// see comment in `(*HvsockListener) Accept()` for more info
   161  			// {"server", sv.sock, _sla},
   162  		}
   163  		for _, tt := range localTests {
   164  			u.Must(socket.GetSockName(windows.Handle(tt.giveSock.handle), &ra))
   165  			sa.fromRaw(&ra)
   166  			if sa != tt.wantAddr {
   167  				t.Errorf("%s local addr give: %v; want: %v", tt.name, sa, tt.wantAddr)
   168  			}
   169  		}
   170  
   171  		remoteTests := []struct {
   172  			name     string
   173  			giveConn *HvsockConn
   174  		}{
   175  			{"client", cl},
   176  			{"server", sv},
   177  		}
   178  		for _, tt := range remoteTests {
   179  			u.Must(socket.GetPeerName(windows.Handle(tt.giveConn.sock.handle), &ra))
   180  			sa.fromRaw(&ra)
   181  			if sa != tt.giveConn.remote {
   182  				t.Errorf("%s remote addr give: %v; want: %v", tt.name, sa, tt.giveConn.remote)
   183  			}
   184  		}
   185  	})
   186  }
   187  
   188  func TestHvSockReadWrite(t *testing.T) {
   189  	u := newUtil(t)
   190  	l, addr := serverListen(u)
   191  	tests := []struct {
   192  		req, rsp string
   193  	}{
   194  		{"hello ", "world!"},
   195  		{"ping", "pong"},
   196  	}
   197  
   198  	// a sync.WaitGroup doesnt offer a channel to use in a select with a timeout
   199  	// could use an errgroup.Group, but for now dual channels work fine
   200  	svCh := u.Go(func() error {
   201  		c, err := l.Accept()
   202  		if err != nil {
   203  			return fmt.Errorf("listener accept: %w", err)
   204  		}
   205  		defer c.Close()
   206  
   207  		b := make([]byte, 64)
   208  		for _, tt := range tests {
   209  			n, err := c.Read(b)
   210  			if err != nil {
   211  				return fmt.Errorf("server rx: %w", err)
   212  			}
   213  
   214  			r := string(b[:n])
   215  			if r != tt.req {
   216  				return fmt.Errorf("server rx error: got %q; wanted %q", r, tt.req)
   217  			}
   218  			if _, err = c.Write([]byte(tt.rsp)); err != nil {
   219  				return fmt.Errorf("server tx error, could not send %q: %w", tt.rsp, err)
   220  			}
   221  		}
   222  		n, err := c.Read(b)
   223  		if n != 0 {
   224  			return errors.New("server did not get EOF")
   225  		}
   226  		if !errors.Is(err, io.EOF) {
   227  			return fmt.Errorf("server did not get EOF: %w", err)
   228  		}
   229  		return nil
   230  	})
   231  
   232  	clCh := u.Go(func() error {
   233  		cl, err := Dial(context.Background(), addr)
   234  		if err != nil {
   235  			return fmt.Errorf("client dial: %w", err)
   236  		}
   237  		defer cl.Close()
   238  
   239  		b := make([]byte, 64)
   240  		for _, tt := range tests {
   241  			_, err := cl.Write([]byte(tt.req))
   242  			if err != nil {
   243  				return fmt.Errorf("client tx error, could not send %q: %w", tt.req, err)
   244  			}
   245  
   246  			n, err := cl.Read(b)
   247  			if err != nil {
   248  				return fmt.Errorf("client tx: %w", err)
   249  			}
   250  
   251  			r := string(b[:n])
   252  			if r != tt.rsp {
   253  				return fmt.Errorf("client rx error: got %q; wanted %q", b[:n], tt.rsp)
   254  			}
   255  		}
   256  		return cl.CloseWrite()
   257  	})
   258  
   259  	u.WaitErr(svCh, 15*time.Second, "server")
   260  	u.WaitErr(clCh, 15*time.Second, "client")
   261  }
   262  
   263  func TestHvSockReadTooSmall(t *testing.T) {
   264  	u := newUtil(t)
   265  	s := "this is a really long string that hopefully takes up more than 16 bytes ..."
   266  	l, addr := serverListen(u)
   267  
   268  	svCh := u.Go(func() error {
   269  		c, err := l.Accept()
   270  		if err != nil {
   271  			return fmt.Errorf("listener accept: %w", err)
   272  		}
   273  		defer c.Close()
   274  
   275  		b := make([]byte, 16)
   276  		ss := ""
   277  		for {
   278  			n, err := c.Read(b)
   279  			if errors.Is(err, io.EOF) {
   280  				break
   281  			}
   282  			if err != nil {
   283  				return fmt.Errorf("server rx: %w", err)
   284  			}
   285  			ss += string(b[:n])
   286  		}
   287  
   288  		if ss != s {
   289  			return fmt.Errorf("got %q, wanted: %q", ss, s)
   290  		}
   291  		return nil
   292  	})
   293  
   294  	clCh := u.Go(func() error {
   295  		cl, err := Dial(context.Background(), addr)
   296  		if err != nil {
   297  			return fmt.Errorf("client dial: %w", err)
   298  		}
   299  		defer cl.Close()
   300  
   301  		if _, err = cl.Write([]byte(s)); err != nil {
   302  			return fmt.Errorf("client tx error, could not send: %w", err)
   303  		}
   304  		return nil
   305  	})
   306  
   307  	u.WaitErr(svCh, 15*time.Second, "server")
   308  	u.WaitErr(clCh, 15*time.Second, "client")
   309  }
   310  
   311  func TestHvSockCloseReadWriteListener(t *testing.T) {
   312  	u := newUtil(t)
   313  	l, addr := serverListen(u)
   314  
   315  	ch := make(chan struct{})
   316  	svCh := u.Go(func() error {
   317  		defer close(ch)
   318  		c, err := l.Accept()
   319  		if err != nil {
   320  			return fmt.Errorf("listener accept: %w", err)
   321  		}
   322  		defer c.Close()
   323  
   324  		hv := c.(*HvsockConn)
   325  		//
   326  		// test CloseWrite()
   327  		//
   328  		n, err := c.Write([]byte(testStr))
   329  		if err != nil {
   330  			return fmt.Errorf("server tx: %w", err)
   331  		}
   332  		if n != len(testStr) {
   333  			return fmt.Errorf("server wrote %d bytes, wanted %d", n, len(testStr))
   334  		}
   335  
   336  		if err := hv.CloseWrite(); err != nil {
   337  			return fmt.Errorf("server close write: %w", err)
   338  		}
   339  
   340  		if _, err = c.Write([]byte(testStr)); !errors.Is(err, windows.WSAESHUTDOWN) {
   341  			return fmt.Errorf("server did not shutdown writes: %w", err)
   342  		}
   343  		// safe to call multiple times
   344  		if err := hv.CloseWrite(); err != nil {
   345  			return fmt.Errorf("server second close write: %w", err)
   346  		}
   347  
   348  		//
   349  		// test CloseRead()
   350  		//
   351  		b := make([]byte, 256)
   352  		n, err = c.Read(b)
   353  		if err != nil {
   354  			return fmt.Errorf("server read: %w", err)
   355  		}
   356  		if n != len(testStr) {
   357  			return fmt.Errorf("server read %d bytes, wanted %d", n, len(testStr))
   358  		}
   359  		if string(b[:n]) != testStr {
   360  			return fmt.Errorf("server got %q; wanted %q", b[:n], testStr)
   361  		}
   362  		if err := hv.CloseRead(); err != nil {
   363  			return fmt.Errorf("server close read: %w", err)
   364  		}
   365  
   366  		ch <- struct{}{}
   367  
   368  		// signal the client to send more info
   369  		// if it was sent before, the read would succeed if the data was buffered prior
   370  		_, err = c.Read(b)
   371  		if !errors.Is(err, windows.WSAESHUTDOWN) {
   372  			return fmt.Errorf("server did not shutdown reads: %w", err)
   373  		}
   374  		// safe to call multiple times
   375  		if err := hv.CloseRead(); err != nil {
   376  			return fmt.Errorf("server second close read: %w", err)
   377  		}
   378  
   379  		c.Close()
   380  		if err := hv.CloseWrite(); !errors.Is(err, socket.ErrSocketClosed) {
   381  			return fmt.Errorf("server close write: %w", err)
   382  		}
   383  		if err := hv.CloseRead(); !errors.Is(err, socket.ErrSocketClosed) {
   384  			return fmt.Errorf("server close read: %w", err)
   385  		}
   386  		return nil
   387  	})
   388  
   389  	cl, err := Dial(context.Background(), addr)
   390  	u.Must(err, "could not dial")
   391  	defer cl.Close()
   392  
   393  	b := make([]byte, 256)
   394  	n, err := cl.Read(b)
   395  	u.Must(err, "client read")
   396  	u.Assert(n == len(testStr), fmt.Sprintf("client read %d bytes, wanted %d", n, len(testStr)))
   397  	u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr))
   398  
   399  	n, err = cl.Read(b)
   400  	u.Assert(n == 0, "client did not get EOF")
   401  	u.Is(err, io.EOF, "client did not get EOF")
   402  
   403  	n, err = cl.Write([]byte(testStr))
   404  	u.Must(err, "client write")
   405  	u.Assert(n == len(testStr), fmt.Sprintf("client wrote %d bytes, wanted %d", n, len(testStr)))
   406  
   407  	u.Wait(ch, time.Second)
   408  
   409  	// this should succeed
   410  	_, err = cl.Write([]byte("test2"))
   411  	u.Must(err, "client write")
   412  	u.WaitErr(svCh, time.Second, "server")
   413  }
   414  
   415  func TestHvSockCloseReadWriteDial(t *testing.T) {
   416  	u := newUtil(t)
   417  	l, addr := serverListen(u)
   418  
   419  	ch := make(chan struct{})
   420  	clCh := u.Go(func() error {
   421  		defer close(ch)
   422  		c, err := l.Accept()
   423  		if err != nil {
   424  			return fmt.Errorf("listener accept: %w", err)
   425  		}
   426  		defer c.Close()
   427  
   428  		b := make([]byte, 256)
   429  		n, err := c.Read(b)
   430  		if err != nil {
   431  			return fmt.Errorf("server read: %w", err)
   432  		}
   433  		if string(b[:n]) != testStr {
   434  			return fmt.Errorf("server got %q; wanted %q", b[:n], testStr)
   435  		}
   436  
   437  		n, err = c.Read(b)
   438  		if n != 0 {
   439  			return fmt.Errorf("server did not get EOF")
   440  		}
   441  		if !errors.Is(err, io.EOF) {
   442  			return errors.New("server did not get EOF")
   443  		}
   444  
   445  		_, err = c.Write([]byte(testStr))
   446  		if err != nil {
   447  			return fmt.Errorf("server tx: %w", err)
   448  		}
   449  
   450  		ch <- struct{}{}
   451  
   452  		_, err = c.Write([]byte(testStr))
   453  		if err != nil {
   454  			return fmt.Errorf("server tx: %w", err)
   455  		}
   456  		return c.Close()
   457  	})
   458  
   459  	cl, err := Dial(context.Background(), addr)
   460  	u.Must(err, "could not dial")
   461  	defer cl.Close()
   462  
   463  	//
   464  	// test CloseWrite()
   465  	//
   466  	_, err = cl.Write([]byte(testStr))
   467  	u.Must(err, "client write")
   468  	u.Must(cl.CloseWrite(), "client close write")
   469  
   470  	_, err = cl.Write([]byte(testStr))
   471  	u.Is(err, windows.WSAESHUTDOWN, "client did not shutdown writes")
   472  
   473  	// safe to call multiple times
   474  	u.Must(cl.CloseWrite(), "client second close write")
   475  
   476  	//
   477  	// test CloseRead()
   478  	//
   479  	b := make([]byte, 256)
   480  	n, err := cl.Read(b)
   481  	u.Must(err, "client read")
   482  	u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr))
   483  	u.Must(cl.CloseRead(), "client close read")
   484  
   485  	u.Wait(ch, time.Millisecond)
   486  
   487  	// signal the client to send more info
   488  	// if it was sent before, the read would succeed if the data was buffered prior
   489  	_, err = cl.Read(b)
   490  	u.Is(err, windows.WSAESHUTDOWN, "client did not shutdown reads")
   491  
   492  	// safe to call multiple times
   493  	u.Must(cl.CloseRead(), "client second close write")
   494  
   495  	l.Close()
   496  	cl.Close()
   497  
   498  	wantErr := socket.ErrSocketClosed
   499  	u.Is(cl.CloseWrite(), wantErr, "client close write")
   500  	u.Is(cl.CloseRead(), wantErr, "client close read")
   501  	u.WaitErr(clCh, time.Second, "client")
   502  }
   503  
   504  func TestHvSockDialNoTimeout(t *testing.T) {
   505  	u := newUtil(t)
   506  	ctx, cancel := context.WithCancel(context.Background())
   507  	defer cancel()
   508  	ch := u.Go(func() error {
   509  		addr := randHvsockAddr()
   510  		cl, err := Dial(ctx, addr)
   511  		if err == nil {
   512  			cl.Close()
   513  		}
   514  		if !errors.Is(err, windows.WSAECONNREFUSED) {
   515  			return err
   516  		}
   517  		return nil
   518  	})
   519  
   520  	// connections usually take about ~500µs
   521  	u.WaitErr(ch, 2*time.Millisecond, "dial did not time out")
   522  }
   523  
   524  func TestHvSockDialDeadline(t *testing.T) {
   525  	u := newUtil(t)
   526  	d := &HvsockDialer{}
   527  	d.Deadline = time.Now().Add(50 * time.Microsecond)
   528  	d.Retries = 1
   529  	// we need the wait time to be long enough for the deadline goroutine to run first and signal
   530  	// timeout
   531  	d.RetryWait = 100 * time.Millisecond
   532  	addr := randHvsockAddr()
   533  	cl, err := d.Dial(context.Background(), addr)
   534  	if err == nil {
   535  		cl.Close()
   536  		t.Fatalf("dial should not have finished")
   537  	}
   538  	u.Is(err, context.DeadlineExceeded, "dial did not exceed deadline")
   539  }
   540  
   541  func TestHvSockDialContext(t *testing.T) {
   542  	u := newUtil(t)
   543  	ctx, cancel := context.WithCancel(context.Background())
   544  	time.AfterFunc(50*time.Microsecond, cancel)
   545  
   546  	d := &HvsockDialer{}
   547  	d.Retries = 1
   548  	d.RetryWait = 100 * time.Millisecond
   549  	addr := randHvsockAddr()
   550  	cl, err := d.Dial(ctx, addr)
   551  	if err == nil {
   552  		cl.Close()
   553  		t.Fatalf("dial should not have finished")
   554  	}
   555  	u.Is(err, context.Canceled, "dial was not canceled")
   556  }
   557  
   558  func TestHvSockAcceptClose(t *testing.T) {
   559  	u := newUtil(t)
   560  	l, _ := serverListen(u)
   561  	go func() {
   562  		time.Sleep(50 * time.Millisecond)
   563  		l.Close()
   564  	}()
   565  
   566  	c, err := l.Accept()
   567  	if err == nil {
   568  		c.Close()
   569  		t.Fatal("listener should not have accepted anything")
   570  	}
   571  	u.Is(err, ErrFileClosed)
   572  }
   573  
   574  //
   575  // helpers
   576  //
   577  
   578  type testUtil struct {
   579  	T testing.TB
   580  }
   581  
   582  func newUtil(t testing.TB) testUtil {
   583  	return testUtil{
   584  		T: t,
   585  	}
   586  }
   587  
   588  // Go launches f in a go routine and returns a channel that can be monitored for the result.
   589  // ch is closed after f completes.
   590  //
   591  // Intended for use with [testUtil.WaitErr].
   592  func (*testUtil) Go(f func() error) chan error {
   593  	ch := make(chan error)
   594  	go func() {
   595  		defer close(ch)
   596  		ch <- f()
   597  	}()
   598  	return ch
   599  }
   600  
   601  func (u testUtil) Wait(ch <-chan struct{}, d time.Duration, msgs ...string) {
   602  	t := time.NewTimer(d)
   603  	defer t.Stop()
   604  	select {
   605  	case <-ch:
   606  	case <-t.C:
   607  		u.T.Helper()
   608  		u.T.Fatalf(msgJoin(msgs, "timed out after %v"), d)
   609  	}
   610  }
   611  
   612  func (u testUtil) WaitErr(ch <-chan error, d time.Duration, msgs ...string) {
   613  	t := time.NewTimer(d)
   614  	defer t.Stop()
   615  	select {
   616  	case err := <-ch:
   617  		if err != nil {
   618  			u.T.Helper()
   619  			u.T.Fatalf(msgJoin(msgs, "%v"), err)
   620  		}
   621  	case <-t.C:
   622  		u.T.Helper()
   623  		u.T.Fatalf(msgJoin(msgs, "timed out after %v"), d)
   624  	}
   625  }
   626  
   627  func (u testUtil) Assert(b bool, msgs ...string) {
   628  	if b {
   629  		return
   630  	}
   631  	u.T.Helper()
   632  	u.T.Fatalf(msgJoin(msgs, "failed assertion"))
   633  }
   634  
   635  func (u testUtil) Is(err, target error, msgs ...string) {
   636  	if errors.Is(err, target) {
   637  		return
   638  	}
   639  	u.T.Helper()
   640  	u.T.Fatalf(msgJoin(msgs, "got error %q; wanted %q"), err, target)
   641  }
   642  
   643  func (u testUtil) Must(err error, msgs ...string) {
   644  	if err == nil {
   645  		return
   646  	}
   647  	u.T.Helper()
   648  	u.T.Fatalf(msgJoin(msgs, "%v"), err)
   649  }
   650  
   651  // Check stops execution if testing failed in another go-routine.
   652  func (u testUtil) Check() {
   653  	if u.T.Failed() {
   654  		u.T.FailNow()
   655  	}
   656  }
   657  
   658  func msgJoin(pre []string, s string) string {
   659  	return strings.Join(append(pre, s), ": ")
   660  }
   661  

View as plain text