...

Source file src/github.com/Microsoft/hcsshim/internal/guest/storage/pmem/pmem_test.go

Documentation: github.com/Microsoft/hcsshim/internal/guest/storage/pmem

     1  //go:build linux
     2  // +build linux
     3  
     4  package pmem
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"os"
    10  	"testing"
    11  
    12  	"github.com/pkg/errors"
    13  	"golang.org/x/sys/unix"
    14  
    15  	"github.com/Microsoft/hcsshim/internal/protocol/guestresource"
    16  )
    17  
    18  func clearTestDependencies() {
    19  	osMkdirAll = nil
    20  	osRemoveAll = nil
    21  	unixMount = nil
    22  	createZeroSectorLinearTarget = nil
    23  	createVerityTarget = nil
    24  	removeDevice = nil
    25  	mountInternal = mount
    26  }
    27  
    28  func Test_Mount_Mkdir_Fails_Error(t *testing.T) {
    29  	clearTestDependencies()
    30  
    31  	expectedErr := errors.New("mkdir : no such file or directory")
    32  	osMkdirAll = func(path string, perm os.FileMode) error {
    33  		return expectedErr
    34  	}
    35  	err := Mount(context.Background(), 0, "", nil, nil)
    36  	if errors.Cause(err) != expectedErr {
    37  		t.Fatalf("expected err: %v, got: %v", expectedErr, err)
    38  	}
    39  }
    40  
    41  func Test_Mount_Mkdir_ExpectedPath(t *testing.T) {
    42  	clearTestDependencies()
    43  
    44  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
    45  	// be called.
    46  
    47  	target := "/fake/path"
    48  	osMkdirAll = func(path string, perm os.FileMode) error {
    49  		if path != target {
    50  			t.Errorf("expected path: %v, got: %v", target, path)
    51  			return errors.New("unexpected path")
    52  		}
    53  		return nil
    54  	}
    55  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
    56  		// Fake the mount success
    57  		return nil
    58  	}
    59  	err := Mount(context.Background(), 0, target, nil, nil)
    60  	if err != nil {
    61  		t.Fatalf("expected nil error got: %v", err)
    62  	}
    63  }
    64  
    65  func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) {
    66  	clearTestDependencies()
    67  
    68  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
    69  	// be called.
    70  
    71  	target := "/fake/path"
    72  	osMkdirAll = func(path string, perm os.FileMode) error {
    73  		if perm != os.FileMode(0700) {
    74  			t.Errorf("expected perm: %v, got: %v", os.FileMode(0700), perm)
    75  			return errors.New("unexpected perm")
    76  		}
    77  		return nil
    78  	}
    79  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
    80  		// Fake the mount success
    81  		return nil
    82  	}
    83  	err := Mount(context.Background(), 0, target, nil, nil)
    84  	if err != nil {
    85  		t.Fatalf("expected nil error got: %v", err)
    86  	}
    87  }
    88  
    89  func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) {
    90  	clearTestDependencies()
    91  
    92  	osMkdirAll = func(path string, perm os.FileMode) error {
    93  		return nil
    94  	}
    95  	target := "/fake/path"
    96  	removeAllCalled := false
    97  	osRemoveAll = func(path string) error {
    98  		removeAllCalled = true
    99  		if path != target {
   100  			t.Errorf("expected path: %v, got: %v", target, path)
   101  			return errors.New("unexpected path")
   102  		}
   103  		return nil
   104  	}
   105  	expectedErr := errors.New("unexpected mount failure")
   106  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
   107  		// Fake the mount failure to test remove is called
   108  		return expectedErr
   109  	}
   110  	err := Mount(context.Background(), 0, target, nil, nil)
   111  	if errors.Cause(err) != expectedErr {
   112  		t.Fatalf("expected err: %v, got: %v", expectedErr, err)
   113  	}
   114  	if !removeAllCalled {
   115  		t.Fatal("expected os.RemoveAll to be called on mount failure")
   116  	}
   117  }
   118  
   119  func Test_Mount_Valid_Source(t *testing.T) {
   120  	clearTestDependencies()
   121  
   122  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
   123  	// be called.
   124  
   125  	osMkdirAll = func(path string, perm os.FileMode) error {
   126  		return nil
   127  	}
   128  	device := uint32(20)
   129  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
   130  		expected := fmt.Sprintf("/dev/pmem%d", device)
   131  		if source != expected {
   132  			t.Errorf("expected source: %s, got: %s", expected, source)
   133  			return errors.New("unexpected source")
   134  		}
   135  		return nil
   136  	}
   137  	err := Mount(context.Background(), device, "/fake/path", nil, nil)
   138  	if err != nil {
   139  		t.Fatalf("expected nil err, got: %v", err)
   140  	}
   141  }
   142  
   143  func Test_Mount_Valid_Target(t *testing.T) {
   144  	clearTestDependencies()
   145  
   146  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
   147  	// be called.
   148  
   149  	osMkdirAll = func(path string, perm os.FileMode) error {
   150  		return nil
   151  	}
   152  	expectedTarget := "/fake/path"
   153  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
   154  		if expectedTarget != target {
   155  			t.Errorf("expected target: %s, got: %s", expectedTarget, target)
   156  			return errors.New("unexpected target")
   157  		}
   158  		return nil
   159  	}
   160  	err := Mount(context.Background(), 0, expectedTarget, nil, nil)
   161  	if err != nil {
   162  		t.Fatalf("expected nil err, got: %v", err)
   163  	}
   164  }
   165  
   166  func Test_Mount_Valid_FSType(t *testing.T) {
   167  	clearTestDependencies()
   168  
   169  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
   170  	// be called.
   171  
   172  	osMkdirAll = func(path string, perm os.FileMode) error {
   173  		return nil
   174  	}
   175  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
   176  		expectedFSType := "ext4"
   177  		if expectedFSType != fstype {
   178  			t.Errorf("expected fstype: %s, got: %s", expectedFSType, fstype)
   179  			return errors.New("unexpected fstype")
   180  		}
   181  		return nil
   182  	}
   183  	err := Mount(context.Background(), 0, "/fake/path", nil, nil)
   184  	if err != nil {
   185  		t.Fatalf("expected nil err, got: %v", err)
   186  	}
   187  }
   188  
   189  func Test_Mount_Valid_Flags(t *testing.T) {
   190  	clearTestDependencies()
   191  
   192  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
   193  	// be called.
   194  
   195  	osMkdirAll = func(path string, perm os.FileMode) error {
   196  		return nil
   197  	}
   198  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
   199  		expectedFlags := uintptr(unix.MS_RDONLY)
   200  		if expectedFlags != flags {
   201  			t.Errorf("expected flags: %v, got: %v", expectedFlags, flags)
   202  			return errors.New("unexpected flags")
   203  		}
   204  		return nil
   205  	}
   206  	err := Mount(context.Background(), 0, "/fake/path", nil, nil)
   207  	if err != nil {
   208  		t.Fatalf("expected nil err, got: %v", err)
   209  	}
   210  }
   211  
   212  func Test_Mount_Valid_Data(t *testing.T) {
   213  	clearTestDependencies()
   214  
   215  	// NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to
   216  	// be called.
   217  
   218  	osMkdirAll = func(path string, perm os.FileMode) error {
   219  		return nil
   220  	}
   221  	unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
   222  		expectedData := "noload"
   223  		if expectedData != data {
   224  			t.Errorf("expected data: %s, got: %s", expectedData, data)
   225  			return errors.New("unexpected data")
   226  		}
   227  		return nil
   228  	}
   229  	err := Mount(context.Background(), 0, "/fake/path", nil, nil)
   230  	if err != nil {
   231  		t.Fatalf("expected nil err, got: %v", err)
   232  	}
   233  }
   234  
   235  // device mapper tests
   236  func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
   237  	clearTestDependencies()
   238  
   239  	mappingInfo := &guestresource.LCOWVPMemMappingInfo{
   240  		DeviceOffsetInBytes: 0,
   241  		DeviceSizeInBytes:   1024,
   242  	}
   243  	expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
   244  	expectedSource := "/dev/pmem0"
   245  	expectedTarget := "/foo"
   246  	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
   247  	createZSLTCalled := false
   248  
   249  	osMkdirAll = func(_ string, _ os.FileMode) error {
   250  		return nil
   251  	}
   252  
   253  	mountInternal = func(_ context.Context, source, target string) error {
   254  		if source != mapperPath {
   255  			t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
   256  		}
   257  		if target != expectedTarget {
   258  			t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source)
   259  		}
   260  		return nil
   261  	}
   262  
   263  	createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *guestresource.LCOWVPMemMappingInfo) (string, error) {
   264  		createZSLTCalled = true
   265  		if source != expectedSource {
   266  			t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
   267  		}
   268  		if name != expectedLinearName {
   269  			t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
   270  		}
   271  		return mapperPath, nil
   272  	}
   273  
   274  	if err := Mount(
   275  		context.Background(),
   276  		0,
   277  		expectedTarget,
   278  		mappingInfo,
   279  		nil,
   280  	); err != nil {
   281  		t.Fatalf("unexpected error during Mount: %s", err)
   282  	}
   283  	if !createZSLTCalled {
   284  		t.Fatalf("createZeroSectorLinearTarget not called")
   285  	}
   286  }
   287  
   288  func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
   289  	clearTestDependencies()
   290  
   291  	verityInfo := &guestresource.DeviceVerityInfo{
   292  		RootDigest: "hash",
   293  	}
   294  	expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
   295  	expectedSource := "/dev/pmem0"
   296  	expectedTarget := "/foo"
   297  	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName)
   298  	createVerityTargetCalled := false
   299  
   300  	mountInternal = func(_ context.Context, source, target string) error {
   301  		if source != mapperPath {
   302  			t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
   303  		}
   304  		if target != expectedTarget {
   305  			t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target)
   306  		}
   307  		return nil
   308  	}
   309  	createVerityTarget = func(_ context.Context, source, name string, verity *guestresource.DeviceVerityInfo) (string, error) {
   310  		createVerityTargetCalled = true
   311  		if source != expectedSource {
   312  			t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source)
   313  		}
   314  		if name != expectedVerityName {
   315  			t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name)
   316  		}
   317  		return mapperPath, nil
   318  	}
   319  
   320  	if err := Mount(
   321  		context.Background(),
   322  		0,
   323  		expectedTarget,
   324  		nil,
   325  		verityInfo,
   326  	); err != nil {
   327  		t.Fatalf("unexpected Mount failure: %s", err)
   328  	}
   329  	if !createVerityTargetCalled {
   330  		t.Fatal("createVerityTarget not called")
   331  	}
   332  }
   333  
   334  func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
   335  	clearTestDependencies()
   336  
   337  	verityInfo := &guestresource.DeviceVerityInfo{
   338  		RootDigest: "hash",
   339  	}
   340  	mapping := &guestresource.LCOWVPMemMappingInfo{
   341  		DeviceOffsetInBytes: 0,
   342  		DeviceSizeInBytes:   1024,
   343  	}
   344  	expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
   345  	expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
   346  	expectedPMemDevice := "/dev/pmem0"
   347  	mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
   348  	mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
   349  	dmLinearCalled := false
   350  	dmVerityCalled := false
   351  	mountCalled := false
   352  
   353  	createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *guestresource.LCOWVPMemMappingInfo) (string, error) {
   354  		dmLinearCalled = true
   355  		if source != expectedPMemDevice {
   356  			t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
   357  		}
   358  		if name != expectedLinearTarget {
   359  			t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name)
   360  		}
   361  		return mapperLinearPath, nil
   362  	}
   363  	createVerityTarget = func(_ context.Context, source, name string, verity *guestresource.DeviceVerityInfo) (string, error) {
   364  		dmVerityCalled = true
   365  		if source != mapperLinearPath {
   366  			t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source)
   367  		}
   368  		if name != expectedVerityTarget {
   369  			t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name)
   370  		}
   371  		return mapperVerityPath, nil
   372  	}
   373  	mountInternal = func(_ context.Context, source, target string) error {
   374  		mountCalled = true
   375  		if source != mapperVerityPath {
   376  			t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source)
   377  		}
   378  		return nil
   379  	}
   380  
   381  	if err := Mount(
   382  		context.Background(),
   383  		0,
   384  		"/foo",
   385  		mapping,
   386  		verityInfo,
   387  	); err != nil {
   388  		t.Fatalf("unexpected error during Mount call: %s", err)
   389  	}
   390  	if !dmLinearCalled {
   391  		t.Fatal("expected createZeroSectorLinearTarget call")
   392  	}
   393  	if !dmVerityCalled {
   394  		t.Fatal("expected createVerityTarget call")
   395  	}
   396  	if !mountCalled {
   397  		t.Fatal("expected mountInternal call")
   398  	}
   399  }
   400  
   401  func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) {
   402  	clearTestDependencies()
   403  
   404  	mappingInfo := &guestresource.LCOWVPMemMappingInfo{
   405  		DeviceOffsetInBytes: 0,
   406  		DeviceSizeInBytes:   1024,
   407  	}
   408  	expectedError := errors.New("mountInternal error")
   409  	expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
   410  	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
   411  	removeDeviceCalled := false
   412  
   413  	createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *guestresource.LCOWVPMemMappingInfo) (string, error) {
   414  		return mapperPath, nil
   415  	}
   416  	mountInternal = func(_ context.Context, source, target string) error {
   417  		return expectedError
   418  	}
   419  	removeDevice = func(name string) error {
   420  		removeDeviceCalled = true
   421  		if name != expectedTarget {
   422  			t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name)
   423  		}
   424  		return nil
   425  	}
   426  
   427  	if err := Mount(
   428  		context.Background(),
   429  		0,
   430  		"/foo",
   431  		mappingInfo,
   432  		nil,
   433  	); err != expectedError {
   434  		t.Fatalf("expected Mount error %s, got %s", expectedError, err)
   435  	}
   436  	if !removeDeviceCalled {
   437  		t.Fatal("expected removeDevice to be callled")
   438  	}
   439  }
   440  
   441  func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) {
   442  	clearTestDependencies()
   443  
   444  	verity := &guestresource.DeviceVerityInfo{
   445  		RootDigest: "hash",
   446  	}
   447  	expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
   448  	expectedError := errors.New("mountInternal error")
   449  	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
   450  	removeDeviceCalled := false
   451  
   452  	createVerityTarget = func(_ context.Context, source, name string, verity *guestresource.DeviceVerityInfo) (string, error) {
   453  		return mapperPath, nil
   454  	}
   455  	mountInternal = func(_ context.Context, _, _ string) error {
   456  		return expectedError
   457  	}
   458  	removeDevice = func(name string) error {
   459  		removeDeviceCalled = true
   460  		if name != expectedVerityTarget {
   461  			t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name)
   462  		}
   463  		return nil
   464  	}
   465  
   466  	if err := Mount(
   467  		context.Background(),
   468  		0,
   469  		"/foo",
   470  		nil,
   471  		verity,
   472  	); err != expectedError {
   473  		t.Fatalf("expected Mount error %s, got %s", expectedError, err)
   474  	}
   475  	if !removeDeviceCalled {
   476  		t.Fatal("expected removeDevice to be called")
   477  	}
   478  }
   479  
   480  func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) {
   481  	clearTestDependencies()
   482  
   483  	mapping := &guestresource.LCOWVPMemMappingInfo{
   484  		DeviceOffsetInBytes: 0,
   485  		DeviceSizeInBytes:   1024,
   486  	}
   487  	verity := &guestresource.DeviceVerityInfo{
   488  		RootDigest: "hash",
   489  	}
   490  	expectedError := errors.New("mountInternal error")
   491  	expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
   492  	expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
   493  	expectedPMemDevice := "/dev/pmem0"
   494  	mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
   495  	mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
   496  	rmLinearCalled := false
   497  	rmVerityCalled := false
   498  
   499  	createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *guestresource.LCOWVPMemMappingInfo) (string, error) {
   500  		if source != expectedPMemDevice {
   501  			t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
   502  		}
   503  		return mapperLinearPath, nil
   504  	}
   505  	createVerityTarget = func(_ context.Context, source, name string, v *guestresource.DeviceVerityInfo) (string, error) {
   506  		if source != mapperLinearPath {
   507  			t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source)
   508  		}
   509  		if name != expectedVerityTarget {
   510  			t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name)
   511  		}
   512  		return mapperVerityPath, nil
   513  	}
   514  	removeDevice = func(name string) error {
   515  		if name != expectedLinearTarget && name != expectedVerityTarget {
   516  			t.Errorf("unexpected removeDevice target name %s", name)
   517  		}
   518  		if name == expectedLinearTarget {
   519  			rmLinearCalled = true
   520  		}
   521  		if name == expectedVerityTarget {
   522  			rmVerityCalled = true
   523  		}
   524  		return nil
   525  	}
   526  	mountInternal = func(_ context.Context, _, _ string) error {
   527  		return expectedError
   528  	}
   529  
   530  	if err := Mount(
   531  		context.Background(),
   532  		0,
   533  		"/foo",
   534  		mapping,
   535  		verity,
   536  	); err != expectedError {
   537  		t.Fatalf("expected Mount error %s, got %s", expectedError, err)
   538  	}
   539  	if !rmLinearCalled {
   540  		t.Fatal("expected removeDevice for linear target to be called")
   541  	}
   542  	if !rmVerityCalled {
   543  		t.Fatal("expected removeDevice for verity target to be called")
   544  	}
   545  }
   546  

View as plain text