1
2
3 package socket
4
5 import (
6 "errors"
7 "fmt"
8 "net"
9 "sync"
10 "syscall"
11 "unsafe"
12
13 "github.com/Microsoft/go-winio/pkg/guid"
14 "golang.org/x/sys/windows"
15 )
16
17
18
19
20
21
22
// socketError is the all-ones 32-bit pattern (0xffffffff) held in a uintptr,
// i.e. the same value as uintptr(^uint32(0)).
// NOTE(review): presumably compared against the raw r1 of Winsock calls that
// report failure as SOCKET_ERROR (-1 as a 32-bit int); confirm against the
// generated syscall wrappers that reference it.
const socketError uintptr = 0xffffffff

// ErrBufferSize is a sentinel error with the message "buffer size".
// NOTE(review): presumably returned when a sockaddr buffer is mis-sized —
// confirm at the return sites.
var ErrBufferSize = errors.New("buffer size")

// ErrAddrFamily is a sentinel error with the message "address family".
// NOTE(review): presumably returned for an unexpected address family — confirm.
var ErrAddrFamily = errors.New("address family")

// ErrInvalidPointer is a sentinel error with the message "invalid pointer".
var ErrInvalidPointer = errors.New("invalid pointer")

// ErrSocketClosed wraps net.ErrClosed, so errors.Is(err, net.ErrClosed)
// also matches it.
var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
33
34
35
36
37
38 func GetSockName(s windows.Handle, rsa RawSockaddr) error {
39 ptr, l, err := rsa.Sockaddr()
40 if err != nil {
41 return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
42 }
43
44
45
46 return getsockname(s, ptr, &l)
47 }
48
49
50
51
52 func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
53 ptr, l, err := rsa.Sockaddr()
54 if err != nil {
55 return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
56 }
57
58 return getpeername(s, ptr, &l)
59 }
60
61 func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
62 ptr, l, err := rsa.Sockaddr()
63 if err != nil {
64 return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
65 }
66
67 return bind(s, ptr, l)
68 }
69
70
71
72
73
74
75
76
77
78
79 type runtimeFunc struct {
80 id guid.GUID
81 once sync.Once
82 addr uintptr
83 err error
84 }
85
86 func (f *runtimeFunc) Load() error {
87 f.once.Do(func() {
88 var s windows.Handle
89 s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
90 if f.err != nil {
91 return
92 }
93 defer windows.CloseHandle(s)
94
95 var n uint32
96 f.err = windows.WSAIoctl(s,
97 windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
98 (*byte)(unsafe.Pointer(&f.id)),
99 uint32(unsafe.Sizeof(f.id)),
100 (*byte)(unsafe.Pointer(&f.addr)),
101 uint32(unsafe.Sizeof(f.addr)),
102 &n,
103 nil,
104 0,
105 )
106 })
107 return f.err
108 }
109
110 var (
111
112 WSAID_CONNECTEX = guid.GUID{
113 Data1: 0x25a207b9,
114 Data2: 0xddf3,
115 Data3: 0x4660,
116 Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
117 }
118
119 connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
120 )
121
122 func ConnectEx(
123 fd windows.Handle,
124 rsa RawSockaddr,
125 sendBuf *byte,
126 sendDataLen uint32,
127 bytesSent *uint32,
128 overlapped *windows.Overlapped,
129 ) error {
130 if err := connectExFunc.Load(); err != nil {
131 return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
132 }
133 ptr, n, err := rsa.Sockaddr()
134 if err != nil {
135 return err
136 }
137 return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
138 }
139
140
141
142
143
144
145
146
147
148
149
150 func connectEx(
151 s windows.Handle,
152 name unsafe.Pointer,
153 namelen int32,
154 sendBuf *byte,
155 sendDataLen uint32,
156 bytesSent *uint32,
157 overlapped *windows.Overlapped,
158 ) (err error) {
159
160 r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
161 7,
162 uintptr(s),
163 uintptr(name),
164 uintptr(namelen),
165 uintptr(unsafe.Pointer(sendBuf)),
166 uintptr(sendDataLen),
167 uintptr(unsafe.Pointer(bytesSent)),
168 uintptr(unsafe.Pointer(overlapped)),
169 0,
170 0)
171 if r1 == 0 {
172 if e1 != 0 {
173 err = error(e1)
174 } else {
175 err = syscall.EINVAL
176 }
177 }
178 return err
179 }
180
View as plain text