1
2
3
4 package bindfilter
5
6 import (
7 "bytes"
8 "encoding/binary"
9 "errors"
10 "fmt"
11 "os"
12 "path/filepath"
13 "strings"
14 "syscall"
15 "unsafe"
16
17 "golang.org/x/sys/windows"
18 )
19
20
21
22
23
24
25
26
27
28
// Flags passed to bfSetupFilter when creating a mapping (see
// ApplyFileBinding). Values are single bits and may be OR-ed together.
const (
	// BINDFLT_FLAG_READ_ONLY_MAPPING marks the mapping as read-only; it is
	// added by ApplyFileBinding when readOnly is true.
	BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 1 << 0

	// BINDFLT_FLAG_NO_MULTIPLE_TARGETS is always set by ApplyFileBinding.
	// NOTE(review): by its name it asks bindflt to reject mappings that would
	// resolve to more than one target — confirm against the bindflt headers.
	BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 1 << 6
)
35
36
// Selector flags for bfGetMappings. GetBindMappings queries by volume
// using BINDFLT_GET_MAPPINGS_FLAG_VOLUME; the silo and user variants are
// presumably the per-silo / per-user scopes (not used in this file).
const (
	BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 1 << 0
	BINDFLT_GET_MAPPINGS_FLAG_SILO   uint32 = 1 << 1
	BINDFLT_GET_MAPPINGS_FLAG_USER   uint32 = 1 << 2
)
42
43
44
45
46
47
48
49
50 func ApplyFileBinding(root, source string, readOnly bool) error {
51
52
53
54 if err := os.MkdirAll(filepath.Dir(root), 0); err != nil {
55 return err
56 }
57
58 if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") {
59
60
61 source = source + "\\"
62 }
63
64 flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS
65 if readOnly {
66 flags |= BINDFLT_FLAG_READ_ONLY_MAPPING
67 }
68
69
70 if err := bfSetupFilter(
71 0,
72 flags,
73 root,
74 source,
75 nil,
76 0,
77 ); err != nil {
78 return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err)
79 }
80 return nil
81 }
82
83
84 func RemoveFileBinding(root string) error {
85 if err := bfRemoveMapping(0, root); err != nil {
86 return fmt.Errorf("removing file binding: %w", err)
87 }
88 return nil
89 }
90
91
92
93
94
95
96
97 func GetBindMappings(volumePath string) ([]BindMapping, error) {
98 rootPtr, err := windows.UTF16PtrFromString(volumePath)
99 if err != nil {
100 return nil, err
101 }
102
103 flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME
104
105 var outBuffSize uint32 = 256 * 1024
106 buf := make([]byte, outBuffSize)
107
108 if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, &buf[0]); err != nil {
109 return nil, err
110 }
111
112 if outBuffSize < 12 {
113 return nil, fmt.Errorf("invalid buffer returned")
114 }
115
116 result := buf[:outBuffSize]
117
118
119 headerBuffer := result[:12]
120
121
122 header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0]))
123
124 if header.MappingCount == 0 {
125
126 return []BindMapping{}, nil
127 }
128
129 mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)]
130
131 mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0]))
132
133 mappings := unsafe.Slice(mappingsPointer, header.MappingCount)
134
135 mappingEntries := make([]BindMapping, header.MappingCount)
136 for i := 0; i < int(header.MappingCount); i++ {
137 bindMapping, err := getBindMappingFromBuffer(result, mappings[i])
138 if err != nil {
139 return nil, fmt.Errorf("fetching bind mappings: %w", err)
140 }
141 mappingEntries[i] = bindMapping
142 }
143
144 return mappingEntries, nil
145 }
146
147
148
149
// mappingEntry is one fixed-size mapping record in the bfGetMappings
// response. GetBindMappings reinterprets raw bytes as a slice of these, so
// the field order and types must match the native record exactly. Offsets
// are byte offsets from the start of the whole response buffer. Lengths
// are in bytes.
type mappingEntry struct {
	// Byte length and offset of the UTF-16 virtualization root path.
	VirtRootLength, VirtRootOffset uint32
	// Mapping flags, surfaced unchanged as BindMapping.Flags.
	Flags uint32
	// Number of mappingTargetEntry records and the byte offset of the first.
	NumberOfTargets, TargetEntriesOffset uint32
}
157
// mappingTargetEntry is one record in a mapping's target table within the
// bfGetMappings response. It is reinterpreted from raw bytes, so its
// layout must not change. The fields give the byte length and the byte
// offset (from the start of the response buffer) of a UTF-16 target path.
type mappingTargetEntry struct {
	TargetRootLength, TargetRootOffset uint32
}
162
163
164
165
// getMappingsResponseHeader is the leading header of the bfGetMappings
// response buffer. It is reinterpreted from raw bytes, so its layout must
// not change. MappingCount is the number of mappingEntry records that
// immediately follow the header. Size and Status are not read by the code in
// this file.
type getMappingsResponseHeader struct {
	Size, Status, MappingCount uint32
}
171
// BindMapping describes one bind filter mapping as returned by
// GetBindMappings.
type BindMapping struct {
	// MountPoint is the virtualization root, resolved with getFinalPath.
	MountPoint string
	// Flags holds the raw flags reported for the mapping.
	Flags uint32
	// Targets are the backing paths, each resolved with getFinalPath, in
	// the order bindflt reported them.
	Targets []string
}
177
178 func decodeEntry(buffer []byte) (string, error) {
179 name := make([]uint16, len(buffer)/2)
180 err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name)
181 if err != nil {
182 return "", fmt.Errorf("decoding name: %w", err)
183 }
184 return windows.UTF16ToString(name), nil
185 }
186
187 func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) {
188 if len(buffer) < offset+count*6 {
189 return nil, fmt.Errorf("invalid buffer")
190 }
191
192 targets := make([]string, count)
193 for i := 0; i < count; i++ {
194 entryBuf := buffer[offset+i*8 : offset+i*8+8]
195 tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0]))
196 if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) {
197 return nil, fmt.Errorf("invalid buffer")
198 }
199 decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength])
200 if err != nil {
201 return nil, fmt.Errorf("decoding name: %w", err)
202 }
203 decoded, err = getFinalPath(decoded)
204 if err != nil {
205 return nil, fmt.Errorf("fetching final path: %w", err)
206 }
207
208 targets[i] = decoded
209 }
210 return targets, nil
211 }
212
213 func getFinalPath(pth string) (string, error) {
214
215
216
217 if strings.HasPrefix(pth, `\Device`) {
218 pth = `\\.\GLOBALROOT` + pth
219 }
220
221 han, err := openPath(pth)
222 if err != nil {
223 return "", fmt.Errorf("fetching file handle: %w", err)
224 }
225 defer func() {
226 _ = windows.CloseHandle(han)
227 }()
228
229 buf := make([]uint16, 100)
230 var flags uint32 = 0x0
231 for {
232 n, err := windows.GetFinalPathNameByHandle(han, &buf[0], uint32(len(buf)), flags)
233 if err != nil {
234
235
236 if errors.Is(err, os.ErrNotExist) && flags != 0x1 {
237 flags = 0x1
238 continue
239 }
240 return "", fmt.Errorf("getting final path name: %w", err)
241 }
242 if n < uint32(len(buf)) {
243 break
244 }
245 buf = make([]uint16, n)
246 }
247 finalPath := syscall.UTF16ToString(buf)
248
249
250 if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 {
251 finalPath = finalPath[4:]
252 if len(finalPath) > 3 && finalPath[:3] == `UNC` {
253
254 finalPath = `\` + finalPath[3:]
255 }
256 }
257
258 return finalPath, nil
259 }
260
261 func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) {
262 if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) {
263 return BindMapping{}, fmt.Errorf("invalid buffer")
264 }
265
266 src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+entry.VirtRootLength])
267 if err != nil {
268 return BindMapping{}, fmt.Errorf("decoding entry: %w", err)
269 }
270 targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets))
271 if err != nil {
272 return BindMapping{}, fmt.Errorf("fetching targets: %w", err)
273 }
274
275 src, err = getFinalPath(src)
276 if err != nil {
277 return BindMapping{}, fmt.Errorf("fetching final path: %w", err)
278 }
279
280 return BindMapping{
281 Flags: entry.Flags,
282 Targets: targets,
283 MountPoint: src,
284 }, nil
285 }
286
287 func openPath(path string) (windows.Handle, error) {
288 u16, err := windows.UTF16PtrFromString(path)
289 if err != nil {
290 return 0, err
291 }
292 h, err := windows.CreateFile(
293 u16,
294 0,
295 windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE,
296 nil,
297 windows.OPEN_EXISTING,
298 windows.FILE_FLAG_BACKUP_SEMANTICS,
299 0)
300 if err != nil {
301 return 0, &os.PathError{
302 Op: "CreateFile",
303 Path: path,
304 Err: err,
305 }
306 }
307 return h, nil
308 }
309
View as plain text