1
2
3 package winio
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "io"
10 "math/rand"
11 "strings"
12 "testing"
13 "time"
14
15 "golang.org/x/sys/windows"
16
17 "github.com/Microsoft/go-winio/internal/socket"
18 "github.com/Microsoft/go-winio/pkg/guid"
19 )
20
// testStr is the fixed payload written and read back by the
// close-read/close-write tests.
const (
	testStr = "test"
)
22
23 func randHvsockAddr() *HvsockAddr {
24 p := rand.Uint32()
25 return &HvsockAddr{
26 VMID: HvsockGUIDLoopback(),
27 ServiceID: VsockServiceID(p),
28 }
29 }
30
31 func serverListen(u testUtil) (l *HvsockListener, a *HvsockAddr) {
32 var err error
33 for i := 0; i < 3; i++ {
34 a = randHvsockAddr()
35 l, err = ListenHvsock(a)
36 if errors.Is(err, windows.WSAEADDRINUSE) {
37 u.T.Logf("address collision %v", a)
38 continue
39 }
40 break
41 }
42 u.Must(err, "could not listen")
43 u.T.Cleanup(func() {
44 if l != nil {
45 u.Must(l.Close(), "Hyper-V socket listener close")
46 }
47 })
48
49 return l, a
50 }
51
52 func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) {
53 l, addr := serverListen(u)
54 ch := u.Go(func() error {
55 conn, err := l.Accept()
56 if err != nil {
57 return fmt.Errorf("listener accept: %w", err)
58 }
59 sv = conn.(*HvsockConn)
60 if err := l.Close(); err != nil {
61 return err
62 }
63 l = nil
64 return nil
65 })
66
67 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
68 defer cancel()
69 cl, err := Dial(ctx, addr)
70 u.Must(err, "could not dial")
71 u.T.Cleanup(func() {
72 if cl != nil {
73 u.Must(cl.Close(), "client close")
74 }
75 })
76
77 u.WaitErr(ch, time.Second)
78 u.T.Cleanup(func() {
79 if sv != nil {
80 u.Must(sv.Close(), "server close")
81 }
82 })
83 return cl, sv, addr
84 }
85
86 func TestHvSockConstants(t *testing.T) {
87 tests := []struct {
88 name string
89 want string
90 give guid.GUID
91 }{
92 {"wildcard", "00000000-0000-0000-0000-000000000000", HvsockGUIDWildcard()},
93 {"broadcast", "ffffffff-ffff-ffff-ffff-ffffffffffff", HvsockGUIDBroadcast()},
94 {"loopback", "e0e16197-dd56-4a10-9195-5ee7a155a838", HvsockGUIDLoopback()},
95 {"children", "90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd", HvsockGUIDChildren()},
96 {"parent", "a42e7cda-d03f-480c-9cc2-a4de20abb878", HvsockGUIDParent()},
97 {"silohost", "36bd0c5c-7276-4223-88ba-7d03b654c568", HvsockGUIDSiloHost()},
98 {"vsock template", "00000000-facb-11e6-bd58-64006a7986d3", hvsockVsockServiceTemplate()},
99 }
100 for _, tt := range tests {
101 if tt.give.String() != tt.want {
102 t.Errorf("%s give: %v; want: %s", tt.name, tt.give, tt.want)
103 }
104 }
105 }
106
107 func TestHvSockListenerAddresses(t *testing.T) {
108 u := newUtil(t)
109 l, addr := serverListen(u)
110
111 la := (l.Addr()).(*HvsockAddr)
112 u.Assert(*la == *addr, fmt.Sprintf("give: %v; want: %v", la, addr))
113
114 ra := rawHvsockAddr{}
115 sa := HvsockAddr{}
116 u.Must(socket.GetSockName(windows.Handle(l.sock.handle), &ra))
117 sa.fromRaw(&ra)
118 u.Assert(sa == *addr, fmt.Sprintf("listener local addr give: %v; want: %v", sa, addr))
119 }
120
121 func TestHvSockAddresses(t *testing.T) {
122 u := newUtil(t)
123 cl, sv, addr := clientServer(u)
124
125 sra := (sv.RemoteAddr()).(*HvsockAddr)
126 sla := (sv.LocalAddr()).(*HvsockAddr)
127 cra := (cl.RemoteAddr()).(*HvsockAddr)
128 cla := (cl.LocalAddr()).(*HvsockAddr)
129
130 t.Run("Info", func(t *testing.T) {
131 tests := []struct {
132 name string
133 give *HvsockAddr
134 want HvsockAddr
135 }{
136 {"client local", cla, HvsockAddr{HvsockGUIDChildren(), sra.ServiceID}},
137 {"client remote", cra, *addr},
138 {"server local", sla, HvsockAddr{HvsockGUIDChildren(), addr.ServiceID}},
139 {"server remote", sra, HvsockAddr{HvsockGUIDLoopback(), cla.ServiceID}},
140 }
141 for _, tt := range tests {
142 if *tt.give != tt.want {
143 t.Errorf("%s address give: %v; want: %v", tt.name, tt.give, tt.want)
144 }
145 }
146 })
147
148 t.Run("OSinfo", func(t *testing.T) {
149 u := newUtil(t)
150 ra := rawHvsockAddr{}
151 sa := HvsockAddr{}
152
153 localTests := []struct {
154 name string
155 giveSock *win32File
156 wantAddr HvsockAddr
157 }{
158 {"client", cl.sock, HvsockAddr{HvsockGUIDChildren(), cla.ServiceID}},
159
160
161
162 }
163 for _, tt := range localTests {
164 u.Must(socket.GetSockName(windows.Handle(tt.giveSock.handle), &ra))
165 sa.fromRaw(&ra)
166 if sa != tt.wantAddr {
167 t.Errorf("%s local addr give: %v; want: %v", tt.name, sa, tt.wantAddr)
168 }
169 }
170
171 remoteTests := []struct {
172 name string
173 giveConn *HvsockConn
174 }{
175 {"client", cl},
176 {"server", sv},
177 }
178 for _, tt := range remoteTests {
179 u.Must(socket.GetPeerName(windows.Handle(tt.giveConn.sock.handle), &ra))
180 sa.fromRaw(&ra)
181 if sa != tt.giveConn.remote {
182 t.Errorf("%s remote addr give: %v; want: %v", tt.name, sa, tt.giveConn.remote)
183 }
184 }
185 })
186 }
187
188 func TestHvSockReadWrite(t *testing.T) {
189 u := newUtil(t)
190 l, addr := serverListen(u)
191 tests := []struct {
192 req, rsp string
193 }{
194 {"hello ", "world!"},
195 {"ping", "pong"},
196 }
197
198
199
200 svCh := u.Go(func() error {
201 c, err := l.Accept()
202 if err != nil {
203 return fmt.Errorf("listener accept: %w", err)
204 }
205 defer c.Close()
206
207 b := make([]byte, 64)
208 for _, tt := range tests {
209 n, err := c.Read(b)
210 if err != nil {
211 return fmt.Errorf("server rx: %w", err)
212 }
213
214 r := string(b[:n])
215 if r != tt.req {
216 return fmt.Errorf("server rx error: got %q; wanted %q", r, tt.req)
217 }
218 if _, err = c.Write([]byte(tt.rsp)); err != nil {
219 return fmt.Errorf("server tx error, could not send %q: %w", tt.rsp, err)
220 }
221 }
222 n, err := c.Read(b)
223 if n != 0 {
224 return errors.New("server did not get EOF")
225 }
226 if !errors.Is(err, io.EOF) {
227 return fmt.Errorf("server did not get EOF: %w", err)
228 }
229 return nil
230 })
231
232 clCh := u.Go(func() error {
233 cl, err := Dial(context.Background(), addr)
234 if err != nil {
235 return fmt.Errorf("client dial: %w", err)
236 }
237 defer cl.Close()
238
239 b := make([]byte, 64)
240 for _, tt := range tests {
241 _, err := cl.Write([]byte(tt.req))
242 if err != nil {
243 return fmt.Errorf("client tx error, could not send %q: %w", tt.req, err)
244 }
245
246 n, err := cl.Read(b)
247 if err != nil {
248 return fmt.Errorf("client tx: %w", err)
249 }
250
251 r := string(b[:n])
252 if r != tt.rsp {
253 return fmt.Errorf("client rx error: got %q; wanted %q", b[:n], tt.rsp)
254 }
255 }
256 return cl.CloseWrite()
257 })
258
259 u.WaitErr(svCh, 15*time.Second, "server")
260 u.WaitErr(clCh, 15*time.Second, "client")
261 }
262
263 func TestHvSockReadTooSmall(t *testing.T) {
264 u := newUtil(t)
265 s := "this is a really long string that hopefully takes up more than 16 bytes ..."
266 l, addr := serverListen(u)
267
268 svCh := u.Go(func() error {
269 c, err := l.Accept()
270 if err != nil {
271 return fmt.Errorf("listener accept: %w", err)
272 }
273 defer c.Close()
274
275 b := make([]byte, 16)
276 ss := ""
277 for {
278 n, err := c.Read(b)
279 if errors.Is(err, io.EOF) {
280 break
281 }
282 if err != nil {
283 return fmt.Errorf("server rx: %w", err)
284 }
285 ss += string(b[:n])
286 }
287
288 if ss != s {
289 return fmt.Errorf("got %q, wanted: %q", ss, s)
290 }
291 return nil
292 })
293
294 clCh := u.Go(func() error {
295 cl, err := Dial(context.Background(), addr)
296 if err != nil {
297 return fmt.Errorf("client dial: %w", err)
298 }
299 defer cl.Close()
300
301 if _, err = cl.Write([]byte(s)); err != nil {
302 return fmt.Errorf("client tx error, could not send: %w", err)
303 }
304 return nil
305 })
306
307 u.WaitErr(svCh, 15*time.Second, "server")
308 u.WaitErr(clCh, 15*time.Second, "client")
309 }
310
// TestHvSockCloseReadWriteListener exercises half-close (CloseWrite and
// CloseRead) on the accepted (listener) side of a connection.
//
// Server goroutine:
//   - writes testStr, then calls CloseWrite;
//   - checks that further writes fail with WSAESHUTDOWN and that a second
//     CloseWrite succeeds;
//   - reads the client's testStr, then calls CloseRead;
//   - signals the client over ch;
//   - checks that further reads fail with WSAESHUTDOWN and that a second
//     CloseRead succeeds;
//   - closes the connection and checks that both CloseWrite and CloseRead then
//     return socket.ErrSocketClosed.
//
// Client side: it expects the server's data followed by EOF, and expects its
// own writes to keep succeeding.
func TestHvSockCloseReadWriteListener(t *testing.T) {
	u := newUtil(t)
	l, addr := serverListen(u)

	// ch is sent on once the server has shut down its read side. It is also
	// closed when the server goroutine returns.
	ch := make(chan struct{})
	svCh := u.Go(func() error {
		defer close(ch)
		c, err := l.Accept()
		if err != nil {
			return fmt.Errorf("listener accept: %w", err)
		}
		defer c.Close()

		hv := c.(*HvsockConn)

		// Write once before shutting down the write side.
		n, err := c.Write([]byte(testStr))
		if err != nil {
			return fmt.Errorf("server tx: %w", err)
		}
		if n != len(testStr) {
			return fmt.Errorf("server wrote %d bytes, wanted %d", n, len(testStr))
		}

		if err := hv.CloseWrite(); err != nil {
			return fmt.Errorf("server close write: %w", err)
		}

		// Writes after CloseWrite must fail with WSAESHUTDOWN.
		if _, err = c.Write([]byte(testStr)); !errors.Is(err, windows.WSAESHUTDOWN) {
			return fmt.Errorf("server did not shutdown writes: %w", err)
		}

		// A repeated CloseWrite must still succeed.
		if err := hv.CloseWrite(); err != nil {
			return fmt.Errorf("server second close write: %w", err)
		}

		// Reads are still allowed after CloseWrite.
		b := make([]byte, 256)
		n, err = c.Read(b)
		if err != nil {
			return fmt.Errorf("server read: %w", err)
		}
		if n != len(testStr) {
			return fmt.Errorf("server read %d bytes, wanted %d", n, len(testStr))
		}
		if string(b[:n]) != testStr {
			return fmt.Errorf("server got %q; wanted %q", b[:n], testStr)
		}
		if err := hv.CloseRead(); err != nil {
			return fmt.Errorf("server close read: %w", err)
		}

		// Let the client know reads are now shut down on this side.
		ch <- struct{}{}

		// Reads after CloseRead must fail with WSAESHUTDOWN.
		_, err = c.Read(b)
		if !errors.Is(err, windows.WSAESHUTDOWN) {
			return fmt.Errorf("server did not shutdown reads: %w", err)
		}

		// A repeated CloseRead must still succeed.
		if err := hv.CloseRead(); err != nil {
			return fmt.Errorf("server second close read: %w", err)
		}

		// After a full Close, both half-close calls report ErrSocketClosed.
		c.Close()
		if err := hv.CloseWrite(); !errors.Is(err, socket.ErrSocketClosed) {
			return fmt.Errorf("server close write: %w", err)
		}
		if err := hv.CloseRead(); !errors.Is(err, socket.ErrSocketClosed) {
			return fmt.Errorf("server close read: %w", err)
		}
		return nil
	})

	cl, err := Dial(context.Background(), addr)
	u.Must(err, "could not dial")
	defer cl.Close()

	// The client first receives the server's single write.
	b := make([]byte, 256)
	n, err := cl.Read(b)
	u.Must(err, "client read")
	u.Assert(n == len(testStr), fmt.Sprintf("client read %d bytes, wanted %d", n, len(testStr)))
	u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr))

	// The server's CloseWrite must surface here as EOF with no data.
	n, err = cl.Read(b)
	u.Assert(n == 0, "client did not get EOF")
	u.Is(err, io.EOF, "client did not get EOF")

	// This write is the data the server's Read consumes before its CloseRead.
	n, err = cl.Write([]byte(testStr))
	u.Must(err, "client write")
	u.Assert(n == len(testStr), fmt.Sprintf("client wrote %d bytes, wanted %d", n, len(testStr)))

	u.Wait(ch, time.Second)

	// The server has shut down its read side by now. The test still expects
	// this client write to succeed.
	_, err = cl.Write([]byte("test2"))
	u.Must(err, "client write")
	u.WaitErr(svCh, time.Second, "server")
}
414
415 func TestHvSockCloseReadWriteDial(t *testing.T) {
416 u := newUtil(t)
417 l, addr := serverListen(u)
418
419 ch := make(chan struct{})
420 clCh := u.Go(func() error {
421 defer close(ch)
422 c, err := l.Accept()
423 if err != nil {
424 return fmt.Errorf("listener accept: %w", err)
425 }
426 defer c.Close()
427
428 b := make([]byte, 256)
429 n, err := c.Read(b)
430 if err != nil {
431 return fmt.Errorf("server read: %w", err)
432 }
433 if string(b[:n]) != testStr {
434 return fmt.Errorf("server got %q; wanted %q", b[:n], testStr)
435 }
436
437 n, err = c.Read(b)
438 if n != 0 {
439 return fmt.Errorf("server did not get EOF")
440 }
441 if !errors.Is(err, io.EOF) {
442 return errors.New("server did not get EOF")
443 }
444
445 _, err = c.Write([]byte(testStr))
446 if err != nil {
447 return fmt.Errorf("server tx: %w", err)
448 }
449
450 ch <- struct{}{}
451
452 _, err = c.Write([]byte(testStr))
453 if err != nil {
454 return fmt.Errorf("server tx: %w", err)
455 }
456 return c.Close()
457 })
458
459 cl, err := Dial(context.Background(), addr)
460 u.Must(err, "could not dial")
461 defer cl.Close()
462
463
464
465
466 _, err = cl.Write([]byte(testStr))
467 u.Must(err, "client write")
468 u.Must(cl.CloseWrite(), "client close write")
469
470 _, err = cl.Write([]byte(testStr))
471 u.Is(err, windows.WSAESHUTDOWN, "client did not shutdown writes")
472
473
474 u.Must(cl.CloseWrite(), "client second close write")
475
476
477
478
479 b := make([]byte, 256)
480 n, err := cl.Read(b)
481 u.Must(err, "client read")
482 u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr))
483 u.Must(cl.CloseRead(), "client close read")
484
485 u.Wait(ch, time.Millisecond)
486
487
488
489 _, err = cl.Read(b)
490 u.Is(err, windows.WSAESHUTDOWN, "client did not shutdown reads")
491
492
493 u.Must(cl.CloseRead(), "client second close write")
494
495 l.Close()
496 cl.Close()
497
498 wantErr := socket.ErrSocketClosed
499 u.Is(cl.CloseWrite(), wantErr, "client close write")
500 u.Is(cl.CloseRead(), wantErr, "client close read")
501 u.WaitErr(clCh, time.Second, "client")
502 }
503
504 func TestHvSockDialNoTimeout(t *testing.T) {
505 u := newUtil(t)
506 ctx, cancel := context.WithCancel(context.Background())
507 defer cancel()
508 ch := u.Go(func() error {
509 addr := randHvsockAddr()
510 cl, err := Dial(ctx, addr)
511 if err == nil {
512 cl.Close()
513 }
514 if !errors.Is(err, windows.WSAECONNREFUSED) {
515 return err
516 }
517 return nil
518 })
519
520
521 u.WaitErr(ch, 2*time.Millisecond, "dial did not time out")
522 }
523
524 func TestHvSockDialDeadline(t *testing.T) {
525 u := newUtil(t)
526 d := &HvsockDialer{}
527 d.Deadline = time.Now().Add(50 * time.Microsecond)
528 d.Retries = 1
529
530
531 d.RetryWait = 100 * time.Millisecond
532 addr := randHvsockAddr()
533 cl, err := d.Dial(context.Background(), addr)
534 if err == nil {
535 cl.Close()
536 t.Fatalf("dial should not have finished")
537 }
538 u.Is(err, context.DeadlineExceeded, "dial did not exceed deadline")
539 }
540
541 func TestHvSockDialContext(t *testing.T) {
542 u := newUtil(t)
543 ctx, cancel := context.WithCancel(context.Background())
544 time.AfterFunc(50*time.Microsecond, cancel)
545
546 d := &HvsockDialer{}
547 d.Retries = 1
548 d.RetryWait = 100 * time.Millisecond
549 addr := randHvsockAddr()
550 cl, err := d.Dial(ctx, addr)
551 if err == nil {
552 cl.Close()
553 t.Fatalf("dial should not have finished")
554 }
555 u.Is(err, context.Canceled, "dial was not canceled")
556 }
557
558 func TestHvSockAcceptClose(t *testing.T) {
559 u := newUtil(t)
560 l, _ := serverListen(u)
561 go func() {
562 time.Sleep(50 * time.Millisecond)
563 l.Close()
564 }()
565
566 c, err := l.Accept()
567 if err == nil {
568 c.Close()
569 t.Fatal("listener should not have accepted anything")
570 }
571 u.Is(err, ErrFileClosed)
572 }
573
574
575
576
577
// testUtil bundles small assertion and goroutine helpers around a testing.TB,
// letting test bodies check a condition and fail in a single call.
type testUtil struct {
	// T receives every failure and log message produced by the helpers.
	T testing.TB
}
581
582 func newUtil(t testing.TB) testUtil {
583 return testUtil{
584 T: t,
585 }
586 }
587
588
589
590
591
592 func (*testUtil) Go(f func() error) chan error {
593 ch := make(chan error)
594 go func() {
595 defer close(ch)
596 ch <- f()
597 }()
598 return ch
599 }
600
601 func (u testUtil) Wait(ch <-chan struct{}, d time.Duration, msgs ...string) {
602 t := time.NewTimer(d)
603 defer t.Stop()
604 select {
605 case <-ch:
606 case <-t.C:
607 u.T.Helper()
608 u.T.Fatalf(msgJoin(msgs, "timed out after %v"), d)
609 }
610 }
611
612 func (u testUtil) WaitErr(ch <-chan error, d time.Duration, msgs ...string) {
613 t := time.NewTimer(d)
614 defer t.Stop()
615 select {
616 case err := <-ch:
617 if err != nil {
618 u.T.Helper()
619 u.T.Fatalf(msgJoin(msgs, "%v"), err)
620 }
621 case <-t.C:
622 u.T.Helper()
623 u.T.Fatalf(msgJoin(msgs, "timed out after %v"), d)
624 }
625 }
626
627 func (u testUtil) Assert(b bool, msgs ...string) {
628 if b {
629 return
630 }
631 u.T.Helper()
632 u.T.Fatalf(msgJoin(msgs, "failed assertion"))
633 }
634
635 func (u testUtil) Is(err, target error, msgs ...string) {
636 if errors.Is(err, target) {
637 return
638 }
639 u.T.Helper()
640 u.T.Fatalf(msgJoin(msgs, "got error %q; wanted %q"), err, target)
641 }
642
643 func (u testUtil) Must(err error, msgs ...string) {
644 if err == nil {
645 return
646 }
647 u.T.Helper()
648 u.T.Fatalf(msgJoin(msgs, "%v"), err)
649 }
650
651
652 func (u testUtil) Check() {
653 if u.T.Failed() {
654 u.T.FailNow()
655 }
656 }
657
// msgJoin joins the optional message prefixes in pre and the final message s
// with ": " separators. For example, msgJoin([]string{"client"}, "timed out")
// returns "client: timed out", and msgJoin(nil, s) returns s.
//
// pre is never modified. Appending s to pre could otherwise overwrite the
// caller's backing array when pre has spare capacity.
func msgJoin(pre []string, s string) string {
	if len(pre) == 0 {
		return s
	}
	return strings.Join(pre, ": ") + ": " + s
}
661
View as plain text