...

Source file src/github.com/Microsoft/go-winio/pkg/bindfilter/bind_filter.go

Documentation: github.com/Microsoft/go-winio/pkg/bindfilter

     1  //go:build windows
     2  // +build windows
     3  
     4  package bindfilter
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"syscall"
    15  	"unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  )
    19  
    20  //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go
    21  //sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter?
    22  //sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath string)  (hr error) = bindfltapi.BfRemoveMapping?
    23  //sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte)  (hr error) = bindfltapi.BfGetMappings?
    24  
    25  // BfSetupFilter flags. See:
    26  // https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240
    27  //
    28  //nolint:revive // var-naming: ALL_CAPS
    29  const (
    30  	BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001
    31  	// Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces
    32  	// multiple targets.
    33  	BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040
    34  )
    35  
    36  //nolint:revive // var-naming: ALL_CAPS
    37  const (
    38  	BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001
    39  	BINDFLT_GET_MAPPINGS_FLAG_SILO   uint32 = 0x00000002
    40  	BINDFLT_GET_MAPPINGS_FLAG_USER   uint32 = 0x00000004
    41  )
    42  
    43  // ApplyFileBinding creates a global mount of the source in root, with an optional
    44  // read only flag.
    45  // The bind filter allows us to create mounts of directories and volumes. By default it allows
    46  // us to mount multiple sources inside a single root, acting as an overlay. Files from the
    47  // second source will superscede the first source that was mounted.
    48  // This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag
    49  // on the mount.
    50  func ApplyFileBinding(root, source string, readOnly bool) error {
    51  	// The parent directory needs to exist for the bind to work. MkdirAll stats and
    52  	// returns nil if the directory exists internally so we should be fine to mkdirall
    53  	// every time.
    54  	if err := os.MkdirAll(filepath.Dir(root), 0); err != nil {
    55  		return err
    56  	}
    57  
    58  	if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") {
    59  		// Add trailing slash to volumes, otherwise we get an error when binding it to
    60  		// a folder.
    61  		source = source + "\\"
    62  	}
    63  
    64  	flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS
    65  	if readOnly {
    66  		flags |= BINDFLT_FLAG_READ_ONLY_MAPPING
    67  	}
    68  
    69  	// Set the job handle to 0 to create a global mount.
    70  	if err := bfSetupFilter(
    71  		0,
    72  		flags,
    73  		root,
    74  		source,
    75  		nil,
    76  		0,
    77  	); err != nil {
    78  		return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err)
    79  	}
    80  	return nil
    81  }
    82  
    83  // RemoveFileBinding removes a mount from the root path.
    84  func RemoveFileBinding(root string) error {
    85  	if err := bfRemoveMapping(0, root); err != nil {
    86  		return fmt.Errorf("removing file binding: %w", err)
    87  	}
    88  	return nil
    89  }
    90  
    91  // GetBindMappings returns a list of bind mappings that have their root on a
    92  // particular volume. The volumePath parameter can be any path that exists on
    93  // a volume. For example, if a number of mappings are created in C:\ProgramData\test,
    94  // to get a list of those mappings, the volumePath parameter would have to be set to
    95  // C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child
    96  // path that exists.
    97  func GetBindMappings(volumePath string) ([]BindMapping, error) {
    98  	rootPtr, err := windows.UTF16PtrFromString(volumePath)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME
   104  	// allocate a large buffer for results
   105  	var outBuffSize uint32 = 256 * 1024
   106  	buf := make([]byte, outBuffSize)
   107  
   108  	if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, &buf[0]); err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	if outBuffSize < 12 {
   113  		return nil, fmt.Errorf("invalid buffer returned")
   114  	}
   115  
   116  	result := buf[:outBuffSize]
   117  
   118  	// The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{}
   119  	headerBuffer := result[:12]
   120  	// The alternative to using unsafe and casting it to the above defined structures, is to manually
   121  	// parse the fields. Not too terrible, but not sure it'd worth the trouble.
   122  	header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0]))
   123  
   124  	if header.MappingCount == 0 {
   125  		// no mappings
   126  		return []BindMapping{}, nil
   127  	}
   128  
   129  	mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)]
   130  	// Get a pointer to the first mapping in the slice
   131  	mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0]))
   132  	// Get slice of mappings
   133  	mappings := unsafe.Slice(mappingsPointer, header.MappingCount)
   134  
   135  	mappingEntries := make([]BindMapping, header.MappingCount)
   136  	for i := 0; i < int(header.MappingCount); i++ {
   137  		bindMapping, err := getBindMappingFromBuffer(result, mappings[i])
   138  		if err != nil {
   139  			return nil, fmt.Errorf("fetching bind mappings: %w", err)
   140  		}
   141  		mappingEntries[i] = bindMapping
   142  	}
   143  
   144  	return mappingEntries, nil
   145  }
   146  
   147  // mappingEntry holds information about where in the response buffer we can
   148  // find information about the virtual root (the mount point) and the targets (sources)
   149  // that get mounted, as well as the flags used to bind the targets to the virtual root.
   150  type mappingEntry struct {
   151  	VirtRootLength      uint32
   152  	VirtRootOffset      uint32
   153  	Flags               uint32
   154  	NumberOfTargets     uint32
   155  	TargetEntriesOffset uint32
   156  }
   157  
   158  type mappingTargetEntry struct {
   159  	TargetRootLength uint32
   160  	TargetRootOffset uint32
   161  }
   162  
   163  // getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response.
   164  // It gives us the size of the buffer, the status of the call and the number of mappings.
   165  // A response
   166  type getMappingsResponseHeader struct {
   167  	Size         uint32
   168  	Status       uint32
   169  	MappingCount uint32
   170  }
   171  
   172  type BindMapping struct {
   173  	MountPoint string
   174  	Flags      uint32
   175  	Targets    []string
   176  }
   177  
   178  func decodeEntry(buffer []byte) (string, error) {
   179  	name := make([]uint16, len(buffer)/2)
   180  	err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name)
   181  	if err != nil {
   182  		return "", fmt.Errorf("decoding name: %w", err)
   183  	}
   184  	return windows.UTF16ToString(name), nil
   185  }
   186  
   187  func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) {
   188  	if len(buffer) < offset+count*6 {
   189  		return nil, fmt.Errorf("invalid buffer")
   190  	}
   191  
   192  	targets := make([]string, count)
   193  	for i := 0; i < count; i++ {
   194  		entryBuf := buffer[offset+i*8 : offset+i*8+8]
   195  		tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0]))
   196  		if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) {
   197  			return nil, fmt.Errorf("invalid buffer")
   198  		}
   199  		decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength])
   200  		if err != nil {
   201  			return nil, fmt.Errorf("decoding name: %w", err)
   202  		}
   203  		decoded, err = getFinalPath(decoded)
   204  		if err != nil {
   205  			return nil, fmt.Errorf("fetching final path: %w", err)
   206  		}
   207  
   208  		targets[i] = decoded
   209  	}
   210  	return targets, nil
   211  }
   212  
   213  func getFinalPath(pth string) (string, error) {
   214  	// BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData.
   215  	// These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the
   216  	// DOS paths for these files.
   217  	if strings.HasPrefix(pth, `\Device`) {
   218  		pth = `\\.\GLOBALROOT` + pth
   219  	}
   220  
   221  	han, err := openPath(pth)
   222  	if err != nil {
   223  		return "", fmt.Errorf("fetching file handle: %w", err)
   224  	}
   225  	defer func() {
   226  		_ = windows.CloseHandle(han)
   227  	}()
   228  
   229  	buf := make([]uint16, 100)
   230  	var flags uint32 = 0x0
   231  	for {
   232  		n, err := windows.GetFinalPathNameByHandle(han, &buf[0], uint32(len(buf)), flags)
   233  		if err != nil {
   234  			// if we mounted a volume that does not also have a drive letter assigned, attempting to
   235  			// fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID.
   236  			if errors.Is(err, os.ErrNotExist) && flags != 0x1 {
   237  				flags = 0x1
   238  				continue
   239  			}
   240  			return "", fmt.Errorf("getting final path name: %w", err)
   241  		}
   242  		if n < uint32(len(buf)) {
   243  			break
   244  		}
   245  		buf = make([]uint16, n)
   246  	}
   247  	finalPath := syscall.UTF16ToString(buf)
   248  	// We got VOLUME_NAME_DOS, we need to strip away some leading slashes.
   249  	// Leave unchanged if we ended up requesting VOLUME_NAME_GUID
   250  	if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 {
   251  		finalPath = finalPath[4:]
   252  		if len(finalPath) > 3 && finalPath[:3] == `UNC` {
   253  			// return path like \\server\share\...
   254  			finalPath = `\` + finalPath[3:]
   255  		}
   256  	}
   257  
   258  	return finalPath, nil
   259  }
   260  
   261  func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) {
   262  	if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) {
   263  		return BindMapping{}, fmt.Errorf("invalid buffer")
   264  	}
   265  
   266  	src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+entry.VirtRootLength])
   267  	if err != nil {
   268  		return BindMapping{}, fmt.Errorf("decoding entry: %w", err)
   269  	}
   270  	targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets))
   271  	if err != nil {
   272  		return BindMapping{}, fmt.Errorf("fetching targets: %w", err)
   273  	}
   274  
   275  	src, err = getFinalPath(src)
   276  	if err != nil {
   277  		return BindMapping{}, fmt.Errorf("fetching final path: %w", err)
   278  	}
   279  
   280  	return BindMapping{
   281  		Flags:      entry.Flags,
   282  		Targets:    targets,
   283  		MountPoint: src,
   284  	}, nil
   285  }
   286  
   287  func openPath(path string) (windows.Handle, error) {
   288  	u16, err := windows.UTF16PtrFromString(path)
   289  	if err != nil {
   290  		return 0, err
   291  	}
   292  	h, err := windows.CreateFile(
   293  		u16,
   294  		0,
   295  		windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE,
   296  		nil,
   297  		windows.OPEN_EXISTING,
   298  		windows.FILE_FLAG_BACKUP_SEMANTICS, // Needed to open a directory handle.
   299  		0)
   300  	if err != nil {
   301  		return 0, &os.PathError{
   302  			Op:   "CreateFile",
   303  			Path: path,
   304  			Err:  err,
   305  		}
   306  	}
   307  	return h, nil
   308  }
   309  

View as plain text