package dsdssandboxes

import (
	"slices"
	"testing"

	"google.golang.org/api/compute/v1"
	"gotest.tools/v3/assert"
)

var (
	testVMs = []string{"s1", "s2", "s3"}

	mockInstanceList = compute.InstanceList{
		Items: []*compute.Instance{
			{
				Name:             "s1",
				LabelFingerprint: "",
				Labels:           map[string]string{},
			},
			{ //Test with a label that belongs to someone else
				Name:             "s2",
				LabelFingerprint: "",
				Labels: map[string]string{
					"test": "example",
				},
			},
			{ //Test with existing but different schedule labels
				Name:             "s3",
				LabelFingerprint: "",
				Labels:           preexistingSchedule.ToLabelMap(),
			},
		},
	}

	expectedLabels = map[string]map[string]string{
		"s1": testSchedule.ToLabelMap(),
		"s2": testSchedule.ToLabelMap(),
		"s3": testSchedule.ToLabelMap(),
	}

	testSchedule = Schedule{
		WeekendStart: intPointer(8),
		WeekendStop:  intPointer(20),
		WeekdayStart: intPointer(6),
		WeekdayStop:  intPointer(22),
		Timezone:     "utc",
	}

	preexistingSchedule = Schedule{
		WeekendStart: intPointer(11),
		WeekendStop:  intPointer(15),
		WeekdayStart: intPointer(2),
		WeekdayStop:  intPointer(4),
		Timezone:     "eastern_standard_time",
	}
)

// TestAddVMScheduleLabels checks that AddVMScheduleLabels sets the schedule
// labels on every listed VM, keeps labels unrelated to the schedule, replaces
// a different pre-existing schedule, and calls the progress hook for each VM.
func TestAddVMScheduleLabels(t *testing.T) {
	// Build the expectations locally instead of mutating the package-level
	// expectedLabels, so no state leaks between tests (or breaks t.Parallel).
	want := make(map[string]map[string]string, len(expectedLabels))
	for name, l := range expectedLabels {
		m := make(map[string]string, len(l)+1)
		for k, v := range l {
			m[k] = v
		}
		want[name] = m
	}
	// s2 carries a foreign label that must survive labelling.
	want["s2"]["test"] = "example"

	labels := map[string]map[string]string{}
	var progressHookCalls []string

	setLabelMock := func(instance string, request *compute.InstancesSetLabelsRequest) error {
		labels[instance] = request.Labels
		return nil
	}

	// NOTE(review): this hands out the shared package-level fixture; if
	// AddVMScheduleLabels mutates instances in place, consider returning a
	// fresh copy per call.
	listVMMock := func() (*compute.InstanceList, error) {
		return &mockInstanceList, nil
	}

	progressHook := func(instance string) {
		progressHookCalls = append(progressHookCalls, instance)
	}

	labeller := ProjectLabeller{
		Name:         "testProject",
		setLabels:    setLabelMock,
		listVMs:      listVMMock,
		progressHook: progressHook,
	}

	err := labeller.AddVMScheduleLabels(testSchedule)

	assert.NilError(t, err)
	assert.Equal(t, len(testVMs), len(labels))
	assert.Equal(t, len(testVMs), len(progressHookCalls))

	for _, instanceName := range testVMs {
		// Check (not Assert) so every missing VM is reported, with its name.
		assert.Check(t, slices.Contains(progressHookCalls, instanceName),
			"progress hook not called for %q", instanceName)
		assert.DeepEqual(t, want[instanceName], labels[instanceName])
	}
}

// intPointer returns a pointer to a newly allocated int holding value.
// Fixtures use it to fill the optional *int fields of Schedule; every call
// yields a distinct pointer.
func intPointer(value int) *int {
	p := new(int)
	*p = value
	return p
}