//go:build linux
// +build linux

package pmem

import (
	"context"
	"fmt"
	"os"
	"testing"

	"github.com/pkg/errors"
	"golang.org/x/sys/unix"

	"github.com/Microsoft/hcsshim/internal/protocol/guestresource"
)

func clearTestDependencies() {
	osMkdirAll = nil
	osRemoveAll = nil
	unixMount = nil
	createZeroSectorLinearTarget = nil
	createVerityTarget = nil
	removeDevice = nil
	mountInternal = mount
}

// Test_Mount_Mkdir_Fails_Error checks that when osMkdirAll fails, the error
// returned by Mount has that failure as its cause.
func Test_Mount_Mkdir_Fails_Error(t *testing.T) {
	clearTestDependencies()

	mkdirErr := errors.New("mkdir : no such file or directory")
	osMkdirAll = func(_ string, _ os.FileMode) error { return mkdirErr }

	if err := Mount(context.Background(), 0, "", nil, nil); errors.Cause(err) != mkdirErr {
		t.Fatalf("expected err: %v, got: %v", mkdirErr, err)
	}
}

// Test_Mount_Mkdir_ExpectedPath checks that Mount hands the requested target
// path to osMkdirAll.
func Test_Mount_Mkdir_ExpectedPath(t *testing.T) {
	clearTestDependencies()

	// osRemoveAll is deliberately left nil: the mount succeeds here, so it
	// must never be invoked.

	const wantPath = "/fake/path"

	osMkdirAll = func(gotPath string, _ os.FileMode) error {
		if gotPath == wantPath {
			return nil
		}
		t.Errorf("expected path: %v, got: %v", wantPath, gotPath)
		return errors.New("unexpected path")
	}
	// Pretend the mount syscall succeeded.
	unixMount = func(_, _, _ string, _ uintptr, _ string) error { return nil }

	if err := Mount(context.Background(), 0, wantPath, nil, nil); err != nil {
		t.Fatalf("expected nil error got: %v", err)
	}
}

// Test_Mount_Mkdir_ExpectedPerm checks that Mount calls osMkdirAll with
// permission bits 0700.
func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) {
	clearTestDependencies()

	// osRemoveAll is deliberately left nil: the mount succeeds here, so it
	// must never be invoked.

	const (
		mountPoint = "/fake/path"
		wantPerm   = os.FileMode(0700)
	)

	osMkdirAll = func(_ string, gotPerm os.FileMode) error {
		if gotPerm == wantPerm {
			return nil
		}
		t.Errorf("expected perm: %v, got: %v", wantPerm, gotPerm)
		return errors.New("unexpected perm")
	}
	// Pretend the mount syscall succeeded.
	unixMount = func(_, _, _ string, _ uintptr, _ string) error { return nil }

	if err := Mount(context.Background(), 0, mountPoint, nil, nil); err != nil {
		t.Fatalf("expected nil error got: %v", err)
	}
}

// Test_Mount_Calls_RemoveAll_OnMountFailure checks that a unixMount failure is
// returned from Mount as the error cause and that osRemoveAll is invoked with
// the mount target.
func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) {
	clearTestDependencies()

	const mountPoint = "/fake/path"
	mountErr := errors.New("unexpected mount failure")
	cleanedUp := false

	osMkdirAll = func(_ string, _ os.FileMode) error { return nil }
	// Force the mount to fail so the removal path is exercised.
	unixMount = func(_, _, _ string, _ uintptr, _ string) error { return mountErr }
	osRemoveAll = func(removed string) error {
		cleanedUp = true
		if removed == mountPoint {
			return nil
		}
		t.Errorf("expected path: %v, got: %v", mountPoint, removed)
		return errors.New("unexpected path")
	}

	err := Mount(context.Background(), 0, mountPoint, nil, nil)
	switch {
	case errors.Cause(err) != mountErr:
		t.Fatalf("expected err: %v, got: %v", mountErr, err)
	case !cleanedUp:
		t.Fatal("expected os.RemoveAll to be called on mount failure")
	}
}

// Test_Mount_Valid_Source checks that Mount passes "/dev/pmem<device>" as the
// source argument of unixMount.
func Test_Mount_Valid_Source(t *testing.T) {
	clearTestDependencies()

	// The mount succeeds, so osRemoveAll stays nil and must never be reached.

	var device uint32 = 20
	wantSource := fmt.Sprintf("/dev/pmem%d", device)

	osMkdirAll = func(_ string, _ os.FileMode) error { return nil }
	unixMount = func(gotSource, _, _ string, _ uintptr, _ string) error {
		if gotSource == wantSource {
			return nil
		}
		t.Errorf("expected source: %s, got: %s", wantSource, gotSource)
		return errors.New("unexpected source")
	}

	if err := Mount(context.Background(), device, "/fake/path", nil, nil); err != nil {
		t.Fatalf("expected nil err, got: %v", err)
	}
}

// Test_Mount_Valid_Target checks that Mount forwards its target argument
// unchanged to unixMount.
func Test_Mount_Valid_Target(t *testing.T) {
	clearTestDependencies()

	// The mount succeeds, so osRemoveAll stays nil and must never be reached.

	const wantTarget = "/fake/path"

	osMkdirAll = func(_ string, _ os.FileMode) error { return nil }
	unixMount = func(_, gotTarget, _ string, _ uintptr, _ string) error {
		if gotTarget == wantTarget {
			return nil
		}
		t.Errorf("expected target: %s, got: %s", wantTarget, gotTarget)
		return errors.New("unexpected target")
	}

	if err := Mount(context.Background(), 0, wantTarget, nil, nil); err != nil {
		t.Fatalf("expected nil err, got: %v", err)
	}
}

// Test_Mount_Valid_FSType checks that Mount asks unixMount for the "ext4"
// filesystem type.
func Test_Mount_Valid_FSType(t *testing.T) {
	clearTestDependencies()

	// The mount succeeds, so osRemoveAll stays nil and must never be reached.

	const wantFSType = "ext4"

	osMkdirAll = func(_ string, _ os.FileMode) error { return nil }
	unixMount = func(_, _, gotFSType string, _ uintptr, _ string) error {
		if gotFSType == wantFSType {
			return nil
		}
		t.Errorf("expected fstype: %s, got: %s", wantFSType, gotFSType)
		return errors.New("unexpected fstype")
	}

	if err := Mount(context.Background(), 0, "/fake/path", nil, nil); err != nil {
		t.Fatalf("expected nil err, got: %v", err)
	}
}

// Test_Mount_Valid_Flags checks that the flags Mount passes to unixMount are
// exactly unix.MS_RDONLY.
func Test_Mount_Valid_Flags(t *testing.T) {
	clearTestDependencies()

	// The mount succeeds, so osRemoveAll stays nil and must never be reached.

	wantFlags := uintptr(unix.MS_RDONLY)

	osMkdirAll = func(_ string, _ os.FileMode) error { return nil }
	unixMount = func(_, _, _ string, gotFlags uintptr, _ string) error {
		if gotFlags == wantFlags {
			return nil
		}
		t.Errorf("expected flags: %v, got: %v", wantFlags, gotFlags)
		return errors.New("unexpected flags")
	}

	if err := Mount(context.Background(), 0, "/fake/path", nil, nil); err != nil {
		t.Fatalf("expected nil err, got: %v", err)
	}
}

// Test_Mount_Valid_Data checks that Mount passes "noload" as the data argument
// of unixMount.
func Test_Mount_Valid_Data(t *testing.T) {
	clearTestDependencies()

	// The mount succeeds, so osRemoveAll stays nil and must never be reached.

	const wantData = "noload"

	osMkdirAll = func(_ string, _ os.FileMode) error { return nil }
	unixMount = func(_, _, _ string, _ uintptr, gotData string) error {
		if gotData == wantData {
			return nil
		}
		t.Errorf("expected data: %s, got: %s", wantData, gotData)
		return errors.New("unexpected data")
	}

	if err := Mount(context.Background(), 0, "/fake/path", nil, nil); err != nil {
		t.Fatalf("expected nil err, got: %v", err)
	}
}

// device mapper tests

// Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters checks the
// mapping-only path of Mount: createZeroSectorLinearTarget must receive the
// pmem device path and the linear device name derived from the mapping info,
// and mountInternal must then be called with the path returned by
// createZeroSectorLinearTarget and the requested target.
func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
	clearTestDependencies()

	mappingInfo := &guestresource.LCOWVPMemMappingInfo{
		DeviceOffsetInBytes: 0,
		DeviceSizeInBytes:   1024,
	}
	expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
	expectedSource := "/dev/pmem0"
	expectedTarget := "/foo"
	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
	createZSLTCalled := false

	osMkdirAll = func(_ string, _ os.FileMode) error {
		return nil
	}

	mountInternal = func(_ context.Context, source, target string) error {
		if source != mapperPath {
			t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
		}
		if target != expectedTarget {
			// Report the offending target, not the source.
			t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target)
		}
		return nil
	}

	createZeroSectorLinearTarget = func(_ context.Context, source, name string, _ *guestresource.LCOWVPMemMappingInfo) (string, error) {
		createZSLTCalled = true
		if source != expectedSource {
			t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
		}
		if name != expectedLinearName {
			t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
		}
		return mapperPath, nil
	}

	if err := Mount(
		context.Background(),
		0,
		expectedTarget,
		mappingInfo,
		nil,
	); err != nil {
		t.Fatalf("unexpected error during Mount: %s", err)
	}
	if !createZSLTCalled {
		t.Fatalf("createZeroSectorLinearTarget not called")
	}
}

// Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters checks
// the verity-only path of Mount: createVerityTarget must receive the pmem
// device path and the verity device name, and mountInternal must be called
// with the path createVerityTarget returned and the requested target.
func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
	clearTestDependencies()

	const (
		pmemDevice = "/dev/pmem0"
		mountPoint = "/foo"
	)
	verity := &guestresource.DeviceVerityInfo{RootDigest: "hash"}
	wantName := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
	wantMapperPath := "/dev/mapper/" + wantName
	verityCreated := false

	createVerityTarget = func(_ context.Context, gotSource, gotName string, _ *guestresource.DeviceVerityInfo) (string, error) {
		verityCreated = true
		if gotSource != pmemDevice {
			t.Errorf("expected createVerityTarget source %s, got %s", pmemDevice, gotSource)
		}
		if gotName != wantName {
			t.Errorf("expected createVerityTarget name %s, got %s", wantName, gotName)
		}
		return wantMapperPath, nil
	}
	mountInternal = func(_ context.Context, gotSource, gotTarget string) error {
		if gotSource != wantMapperPath {
			t.Errorf("expected mountInternal source %s, got %s", wantMapperPath, gotSource)
		}
		if gotTarget != mountPoint {
			t.Errorf("expected mountInternal target %s, got %s", mountPoint, gotTarget)
		}
		return nil
	}

	err := Mount(context.Background(), 0, mountPoint, nil, verity)
	if err != nil {
		t.Fatalf("unexpected Mount failure: %s", err)
	}
	if !verityCreated {
		t.Fatal("createVerityTarget not called")
	}
}

// Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly checks
// the path of Mount where both mapping and verity info are supplied: the
// linear target is created from the pmem device, the verity target is created
// from the linear target's path, and mountInternal receives the verity
// target's path as its source.
func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
	clearTestDependencies()

	const pmemDevice = "/dev/pmem0"
	linearInfo := &guestresource.LCOWVPMemMappingInfo{
		DeviceOffsetInBytes: 0,
		DeviceSizeInBytes:   1024,
	}
	verity := &guestresource.DeviceVerityInfo{RootDigest: "hash"}

	wantLinearName := fmt.Sprintf(linearDeviceFmt, 0, linearInfo.DeviceOffsetInBytes, linearInfo.DeviceSizeInBytes)
	wantVerityName := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
	linearPath := "/dev/mapper/" + wantLinearName
	verityPath := "/dev/mapper/" + wantVerityName

	// Record which hooks were reached.
	var linearCreated, verityCreated, mounted bool

	createZeroSectorLinearTarget = func(_ context.Context, gotSource, gotName string, _ *guestresource.LCOWVPMemMappingInfo) (string, error) {
		linearCreated = true
		if gotSource != pmemDevice {
			t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", pmemDevice, gotSource)
		}
		if gotName != wantLinearName {
			t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", wantLinearName, gotName)
		}
		return linearPath, nil
	}
	createVerityTarget = func(_ context.Context, gotSource, gotName string, _ *guestresource.DeviceVerityInfo) (string, error) {
		verityCreated = true
		if gotSource != linearPath {
			t.Errorf("expected createVerityTarget source %s, got %s", linearPath, gotSource)
		}
		if gotName != wantVerityName {
			t.Errorf("expected createVerityTarget target name %s, got %s", wantVerityName, gotName)
		}
		return verityPath, nil
	}
	mountInternal = func(_ context.Context, gotSource, _ string) error {
		mounted = true
		if gotSource != verityPath {
			t.Errorf("expected Mount source %s, got %s", verityPath, gotSource)
		}
		return nil
	}

	err := Mount(context.Background(), 0, "/foo", linearInfo, verity)
	if err != nil {
		t.Fatalf("unexpected error during Mount call: %s", err)
	}
	switch {
	case !linearCreated:
		t.Fatal("expected createZeroSectorLinearTarget call")
	case !verityCreated:
		t.Fatal("expected createVerityTarget call")
	case !mounted:
		t.Fatal("expected mountInternal call")
	}
}

// Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure checks
// that when a linear target was created and mountInternal then fails, Mount
// returns the mountInternal error unchanged and removeDevice is called with
// the linear device name.
func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) {
	clearTestDependencies()

	mappingInfo := &guestresource.LCOWVPMemMappingInfo{
		DeviceOffsetInBytes: 0,
		DeviceSizeInBytes:   1024,
	}
	expectedError := errors.New("mountInternal error")
	expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
	removeDeviceCalled := false

	createZeroSectorLinearTarget = func(_ context.Context, _, _ string, _ *guestresource.LCOWVPMemMappingInfo) (string, error) {
		return mapperPath, nil
	}
	mountInternal = func(_ context.Context, _, _ string) error {
		return expectedError
	}
	removeDevice = func(name string) error {
		removeDeviceCalled = true
		if name != expectedTarget {
			t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name)
		}
		return nil
	}

	if err := Mount(
		context.Background(),
		0,
		"/foo",
		mappingInfo,
		nil,
	); err != expectedError {
		// %v, not %s: err is nil when Mount unexpectedly succeeds, and %s
		// would render that as "%!s(<nil>)".
		t.Fatalf("expected Mount error %s, got %v", expectedError, err)
	}
	if !removeDeviceCalled {
		t.Fatal("expected removeDevice to be called")
	}
}

// Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure checks
// that when a verity target was created and mountInternal then fails, Mount
// returns the mountInternal error unchanged and removeDevice is called with
// the verity device name.
func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) {
	clearTestDependencies()

	verity := &guestresource.DeviceVerityInfo{
		RootDigest: "hash",
	}
	expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
	expectedError := errors.New("mountInternal error")
	mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
	removeDeviceCalled := false

	createVerityTarget = func(_ context.Context, _, _ string, _ *guestresource.DeviceVerityInfo) (string, error) {
		return mapperPath, nil
	}
	mountInternal = func(_ context.Context, _, _ string) error {
		return expectedError
	}
	removeDevice = func(name string) error {
		removeDeviceCalled = true
		if name != expectedVerityTarget {
			t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name)
		}
		return nil
	}

	if err := Mount(
		context.Background(),
		0,
		"/foo",
		nil,
		verity,
	); err != expectedError {
		// %v, not %s: err is nil when Mount unexpectedly succeeds, and %s
		// would render that as "%!s(<nil>)".
		t.Fatalf("expected Mount error %s, got %v", expectedError, err)
	}
	if !removeDeviceCalled {
		t.Fatal("expected removeDevice to be called")
	}
}

// Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure checks
// that when both a linear and a verity target were created and mountInternal
// then fails, Mount returns the mountInternal error unchanged and removeDevice
// is called for the linear device name as well as the verity device name.
func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) {
	clearTestDependencies()

	mapping := &guestresource.LCOWVPMemMappingInfo{
		DeviceOffsetInBytes: 0,
		DeviceSizeInBytes:   1024,
	}
	verity := &guestresource.DeviceVerityInfo{
		RootDigest: "hash",
	}
	expectedError := errors.New("mountInternal error")
	expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
	expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
	expectedPMemDevice := "/dev/pmem0"
	mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
	mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
	rmLinearCalled := false
	rmVerityCalled := false

	createZeroSectorLinearTarget = func(_ context.Context, source, _ string, _ *guestresource.LCOWVPMemMappingInfo) (string, error) {
		if source != expectedPMemDevice {
			t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
		}
		return mapperLinearPath, nil
	}
	// The verity target must be built from the linear target's path.
	createVerityTarget = func(_ context.Context, source, name string, _ *guestresource.DeviceVerityInfo) (string, error) {
		if source != mapperLinearPath {
			t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source)
		}
		if name != expectedVerityTarget {
			t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name)
		}
		return mapperVerityPath, nil
	}
	// Record which of the two devices removeDevice was asked to remove and
	// flag any other name as unexpected.
	removeDevice = func(name string) error {
		if name != expectedLinearTarget && name != expectedVerityTarget {
			t.Errorf("unexpected removeDevice target name %s", name)
		}
		if name == expectedLinearTarget {
			rmLinearCalled = true
		}
		if name == expectedVerityTarget {
			rmVerityCalled = true
		}
		return nil
	}
	mountInternal = func(_ context.Context, _, _ string) error {
		return expectedError
	}

	if err := Mount(
		context.Background(),
		0,
		"/foo",
		mapping,
		verity,
	); err != expectedError {
		// %v, not %s: err is nil when Mount unexpectedly succeeds, and %s
		// would render that as "%!s(<nil>)".
		t.Fatalf("expected Mount error %s, got %v", expectedError, err)
	}
	if !rmLinearCalled {
		t.Fatal("expected removeDevice for linear target to be called")
	}
	if !rmVerityCalled {
		t.Fatal("expected removeDevice for verity target to be called")
	}
}