1
2
3
4 package winio
5
6 import (
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "net"
12 "os"
13 "syscall"
14 "time"
15 "unsafe"
16
17 "golang.org/x/sys/windows"
18
19 "github.com/Microsoft/go-winio/internal/socket"
20 "github.com/Microsoft/go-winio/pkg/guid"
21 )
22
// afHVSock is the Winsock address family for Hyper-V sockets (AF_HYPERV, 34).
// It is used both to create sockets and to stamp/verify rawHvsockAddr.Family.
const (
	afHVSock = 34
)
24
25
26
27
28
29 func HvsockGUIDWildcard() guid.GUID {
30 return guid.GUID{}
31 }
32
33
34 func HvsockGUIDBroadcast() guid.GUID {
35 return guid.GUID{
36 Data1: 0xffffffff,
37 Data2: 0xffff,
38 Data3: 0xffff,
39 Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
40 }
41 }
42
43
44 func HvsockGUIDLoopback() guid.GUID {
45 return guid.GUID{
46 Data1: 0xe0e16197,
47 Data2: 0xdd56,
48 Data3: 0x4a10,
49 Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
50 }
51 }
52
53
54
55
56 func HvsockGUIDSiloHost() guid.GUID {
57 return guid.GUID{
58 Data1: 0x36bd0c5c,
59 Data2: 0x7276,
60 Data3: 0x4223,
61 Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
62 }
63 }
64
65
66 func HvsockGUIDChildren() guid.GUID {
67 return guid.GUID{
68 Data1: 0x90db8b89,
69 Data2: 0xd35,
70 Data3: 0x4f79,
71 Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
72 }
73 }
74
75
76
77
78
79
80
81 func HvsockGUIDParent() guid.GUID {
82 return guid.GUID{
83 Data1: 0xa42e7cda,
84 Data2: 0xd03f,
85 Data3: 0x480c,
86 Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
87 }
88 }
89
90
91 func hvsockVsockServiceTemplate() guid.GUID {
92 return guid.GUID{
93 Data2: 0xfacb,
94 Data3: 0x11e6,
95 Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
96 }
97 }
98
99
100 type HvsockAddr struct {
101 VMID guid.GUID
102 ServiceID guid.GUID
103 }
104
// rawHvsockAddr is the in-memory form of an HvsockAddr that is handed to and
// read back from Winsock (bind, getsockname and the AcceptEx address buffer).
// Its field order and sizes are the layout the OS sees, so they must not be
// rearranged.
// NOTE(review): presumably mirrors SOCKADDR_HV from hvsocket.h — confirm.
type rawHvsockAddr struct {
	Family    uint16
	_         uint16 // unnamed slot; presumably SOCKADDR_HV.Reserved — confirm
	VMID      guid.GUID
	ServiceID guid.GUID
}

// Compile-time check that *rawHvsockAddr satisfies socket.RawSockaddr.
var _ socket.RawSockaddr = &rawHvsockAddr{}
113
114
115 func (*HvsockAddr) Network() string {
116 return "hvsock"
117 }
118
119 func (addr *HvsockAddr) String() string {
120 return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
121 }
122
123
124 func VsockServiceID(port uint32) guid.GUID {
125 g := hvsockVsockServiceTemplate()
126 g.Data1 = port
127 return g
128 }
129
130 func (addr *HvsockAddr) raw() rawHvsockAddr {
131 return rawHvsockAddr{
132 Family: afHVSock,
133 VMID: addr.VMID,
134 ServiceID: addr.ServiceID,
135 }
136 }
137
138 func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
139 addr.VMID = raw.VMID
140 addr.ServiceID = raw.ServiceID
141 }
142
143
144
145
146
147 func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
148 return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
149 }
150
151
152 func (r *rawHvsockAddr) FromBytes(b []byte) error {
153 n := int(unsafe.Sizeof(rawHvsockAddr{}))
154
155 if len(b) < n {
156 return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
157 }
158
159 copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
160 if r.Family != afHVSock {
161 return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
162 }
163
164 return nil
165 }
166
167
168 type HvsockListener struct {
169 sock *win32File
170 addr HvsockAddr
171 }
172
173 var _ net.Listener = &HvsockListener{}
174
175
176 type HvsockConn struct {
177 sock *win32File
178 local, remote HvsockAddr
179 }
180
181 var _ net.Conn = &HvsockConn{}
182
183 func newHVSocket() (*win32File, error) {
184 fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1)
185 if err != nil {
186 return nil, os.NewSyscallError("socket", err)
187 }
188 f, err := makeWin32File(fd)
189 if err != nil {
190 syscall.Close(fd)
191 return nil, err
192 }
193 f.socket = true
194 return f, nil
195 }
196
197
198 func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
199 l := &HvsockListener{addr: *addr}
200 sock, err := newHVSocket()
201 if err != nil {
202 return nil, l.opErr("listen", err)
203 }
204 sa := addr.raw()
205 err = socket.Bind(windows.Handle(sock.handle), &sa)
206 if err != nil {
207 return nil, l.opErr("listen", os.NewSyscallError("socket", err))
208 }
209 err = syscall.Listen(sock.handle, 16)
210 if err != nil {
211 return nil, l.opErr("listen", os.NewSyscallError("listen", err))
212 }
213 return &HvsockListener{sock: sock, addr: *addr}, nil
214 }
215
216 func (l *HvsockListener) opErr(op string, err error) error {
217 return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
218 }
219
220
221 func (l *HvsockListener) Addr() net.Addr {
222 return &l.addr
223 }
224
225
// Accept waits for the next incoming connection and returns it as an
// *HvsockConn.
//
// It pre-creates the socket that will carry the connection, passes it to
// AcceptEx on the listening socket, and waits for completion via asyncIO.
// Errors are returned as *net.OpError with Op "accept"; the pre-created socket
// is closed on every error path.
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
	sock, err := newHVSocket()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// Close the accept socket unless ownership is handed to the returned conn
	// (sock is set to nil just before the successful return).
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()
	c, err := l.sock.prepareIO()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	defer l.sock.wg.Done()

	// Per the AcceptEx documentation, each address slot must be at least 16
	// bytes larger than the transport's sockaddr. The buffer holds two slots
	// (local, then remote), and no receive data is requested.
	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
	var addrbuf [addrlen * 2]byte

	var bytes uint32
	err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
	if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
	}

	conn := &HvsockConn{
		sock: sock,
	}
	// Decode the addresses AcceptEx wrote: local at offset 0, remote at offset
	// addrlen.
	// NOTE(review): this reads the slots directly instead of calling
	// GetAcceptExSockaddrs, which is the documented way to parse the buffer —
	// confirm the layout holds for AF_HYPERV.
	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))

	// Update the accepted socket with the listening socket's context, as the
	// AcceptEx documentation requires before using it like a normally accepted
	// socket.
	if err = windows.Setsockopt(windows.Handle(sock.handle),
		windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
		(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
		return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
	}

	// Ownership passes to conn; disable the deferred close.
	sock = nil
	return conn, nil
}
275
276
277 func (l *HvsockListener) Close() error {
278 return l.sock.Close()
279 }
280
281
// HvsockDialer holds options for connecting to a Hyper-V socket. Its zero value
// makes a single connection attempt bounded only by the caller's context.
type HvsockDialer struct {
	// Deadline, when non-zero, is an absolute time limit for the whole dial,
	// applied on top of the context passed to Dial.
	Deadline time.Time
	// Retries is the number of additional connect attempts made after a
	// failure that canRedial considers retryable.
	Retries uint
	// RetryWait is the pause between attempts; zero retries immediately.
	RetryWait time.Duration

	rt *time.Timer // reused across waits by redialWait
}
295
296
297
298
299 func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
300 return (&HvsockDialer{}).Dial(ctx, addr)
301 }
302
303
304
305
306
307
// Dial connects to the Hyper-V socket address addr and returns the connection.
//
// If d.Deadline is non-zero, ctx is narrowed to that deadline for the whole
// dial. A failed connect is retried, up to d.Retries additional times, when
// canRedial reports the error as retryable; d.redialWait runs between
// attempts. All errors are returned as *net.OpError with Op "dial", and the
// socket is closed on every error path.
//
// NOTE(review): redialWait stores a timer on d, so a single HvsockDialer is
// presumably not safe for concurrent Dial calls — confirm.
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
	op := "dial"
	// Create conn up front so conn.opErr can report the remote address.
	conn = &HvsockConn{
		remote: *addr,
	}

	if !d.Deadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, d.Deadline)
		defer cancel()
	}

	// Fail fast if ctx is already cancelled or past its deadline.
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	sock, err := newHVSocket()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	// Close the socket unless ownership is handed to conn at the end.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()

	// ConnectEx requires a previously bound socket.
	// NOTE(review): the socket is bound to the target address itself rather
	// than a wildcard local address — confirm this is what AF_HYPERV expects.
	sa := addr.raw()
	err = socket.Bind(windows.Handle(sock.handle), &sa)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("bind", err))
	}

	c, err := sock.prepareIO()
	if err != nil {
		return nil, conn.opErr(op, err)
	}
	defer sock.wg.Done()
	var bytes uint32
	// One initial attempt plus up to d.Retries retries. If redialWait fails
	// (ctx done), its error replaces the connect error and ends the loop.
	for i := uint(0); i <= d.Retries; i++ {
		err = socket.ConnectEx(
			windows.Handle(sock.handle),
			&sa,
			nil, // no data is sent with the connect request
			0,   // send data length
			&bytes,
			(*windows.Overlapped)(unsafe.Pointer(&c.o)))
		_, err = sock.asyncIO(c, nil, bytes, err)
		if i < d.Retries && canRedial(err) {
			if err = d.redialWait(ctx); err == nil {
				continue
			}
		}
		break
	}
	// NOTE(review): a context error from redialWait is also reported under the
	// "connectex" syscall name here.
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
	}

	// Per the ConnectEx documentation, a socket connected this way does not
	// enable previously set properties or options until
	// SO_UPDATE_CONNECT_CONTEXT is set on it.
	if err = windows.Setsockopt(
		windows.Handle(sock.handle),
		windows.SOL_SOCKET,
		windows.SO_UPDATE_CONNECT_CONTEXT,
		nil,
		0,
	); err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
	}

	// Record the local address of the connected socket.
	var sal rawHvsockAddr
	err = socket.GetSockName(windows.Handle(sock.handle), &sal)
	if err != nil {
		return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
	}
	conn.local.fromRaw(&sal)

	// Re-check ctx in case it finished while the connect was completing.
	if err = ctx.Err(); err != nil {
		return nil, conn.opErr(op, err)
	}

	// Hand the socket to conn; clearing sock disables the deferred close.
	conn.sock = sock
	sock = nil

	return conn, nil
}
397
398
399 func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
400 if d.RetryWait == 0 {
401 return nil
402 }
403
404 if d.rt == nil {
405 d.rt = time.NewTimer(d.RetryWait)
406 } else {
407
408 d.rt.Reset(d.RetryWait)
409 }
410
411 select {
412 case <-ctx.Done():
413 case <-d.rt.C:
414 return nil
415 }
416
417
418 if !d.rt.Stop() {
419 <-d.rt.C
420 }
421 return ctx.Err()
422 }
423
424
425 func canRedial(err error) bool {
426
427 switch err {
428 case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
429 windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
430 return true
431 default:
432 return false
433 }
434 }
435
436 func (conn *HvsockConn) opErr(op string, err error) error {
437
438 if errors.Is(err, ErrFileClosed) {
439 err = socket.ErrSocketClosed
440 }
441 return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
442 }
443
444 func (conn *HvsockConn) Read(b []byte) (int, error) {
445 c, err := conn.sock.prepareIO()
446 if err != nil {
447 return 0, conn.opErr("read", err)
448 }
449 defer conn.sock.wg.Done()
450 buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
451 var flags, bytes uint32
452 err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
453 n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
454 if err != nil {
455 var eno windows.Errno
456 if errors.As(err, &eno) {
457 err = os.NewSyscallError("wsarecv", eno)
458 }
459 return 0, conn.opErr("read", err)
460 } else if n == 0 {
461 err = io.EOF
462 }
463 return n, err
464 }
465
466 func (conn *HvsockConn) Write(b []byte) (int, error) {
467 t := 0
468 for len(b) != 0 {
469 n, err := conn.write(b)
470 if err != nil {
471 return t + n, err
472 }
473 t += n
474 b = b[n:]
475 }
476 return t, nil
477 }
478
479 func (conn *HvsockConn) write(b []byte) (int, error) {
480 c, err := conn.sock.prepareIO()
481 if err != nil {
482 return 0, conn.opErr("write", err)
483 }
484 defer conn.sock.wg.Done()
485 buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
486 var bytes uint32
487 err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
488 n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
489 if err != nil {
490 var eno windows.Errno
491 if errors.As(err, &eno) {
492 err = os.NewSyscallError("wsasend", eno)
493 }
494 return 0, conn.opErr("write", err)
495 }
496 return n, err
497 }
498
499
500 func (conn *HvsockConn) Close() error {
501 return conn.sock.Close()
502 }
503
504 func (conn *HvsockConn) IsClosed() bool {
505 return conn.sock.IsClosed()
506 }
507
508
509 func (conn *HvsockConn) shutdown(how int) error {
510 if conn.IsClosed() {
511 return socket.ErrSocketClosed
512 }
513
514 err := syscall.Shutdown(conn.sock.handle, how)
515 if err != nil {
516
517 if errors.Is(err, windows.WSAENOTCONN) ||
518 errors.Is(err, windows.WSAESHUTDOWN) {
519 err = socket.ErrSocketClosed
520 }
521 return os.NewSyscallError("shutdown", err)
522 }
523 return nil
524 }
525
526
527 func (conn *HvsockConn) CloseRead() error {
528 err := conn.shutdown(syscall.SHUT_RD)
529 if err != nil {
530 return conn.opErr("closeread", err)
531 }
532 return nil
533 }
534
535
536
537 func (conn *HvsockConn) CloseWrite() error {
538 err := conn.shutdown(syscall.SHUT_WR)
539 if err != nil {
540 return conn.opErr("closewrite", err)
541 }
542 return nil
543 }
544
545
546 func (conn *HvsockConn) LocalAddr() net.Addr {
547 return &conn.local
548 }
549
550
551 func (conn *HvsockConn) RemoteAddr() net.Addr {
552 return &conn.remote
553 }
554
555
556 func (conn *HvsockConn) SetDeadline(t time.Time) error {
557
558 if err := conn.SetReadDeadline(t); err != nil {
559 return fmt.Errorf("set read deadline: %w", err)
560 }
561 if err := conn.SetWriteDeadline(t); err != nil {
562 return fmt.Errorf("set write deadline: %w", err)
563 }
564 return nil
565 }
566
567
568 func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
569 return conn.sock.SetReadDeadline(t)
570 }
571
572
573 func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
574 return conn.sock.SetWriteDeadline(t)
575 }
576
View as plain text