package services

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"

	"edge-infra.dev/pkg/edge/api/graph/model"
	apimock "edge-infra.dev/pkg/edge/api/mocks"
	"edge-infra.dev/pkg/edge/api/services/artifacts"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
)

// TestCreateTerminalDisk checks that CreateTerminalDiskEntry inserts a disk
// inside a transaction and returns it with the requested field values.
func TestCreateTerminalDisk(t *testing.T) {
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	includeDisk := true
	expectEmpty := true
	usePart := false
	result := sqlmock.NewResult(1, 1)

	// The terminal starts with no disks, so no duplicate device path is expected.
	mockDBGetTerminalDiskByTerminalID(mock, terminalID, []model.TerminalDisk{})
	mock.ExpectBegin()
	mockDBTerminalDiskCreate(mock, terminalID, model.TerminalDiskCreateInput{
		IncludeDisk: includeDisk,
		ExpectEmpty: expectEmpty,
		DevicePath:  devicePath,
		UsePart:     usePart,
	}, result)
	mock.ExpectCommit()

	newDisk := model.TerminalDiskCreateInput{
		IncludeDisk: includeDisk,
		ExpectEmpty: expectEmpty,
		DevicePath:  devicePath,
		UsePart:     usePart,
	}

	artifactsService := artifacts.NewArtifactsService(db, nil)
	labelSvc := NewLabelService(artifactsService, db)
	service := NewTerminalService(db, labelSvc)

	disk, err := service.CreateTerminalDiskEntry(context.Background(), terminalID, &newDisk)
	// Stop here on failure rather than dereferencing a possibly-nil result below.
	if !assert.NoError(t, err) || !assert.NotNil(t, disk) {
		t.FailNow()
	}
	assert.Equal(t, includeDisk, disk.IncludeDisk)
	assert.Equal(t, expectEmpty, disk.ExpectEmpty)
	assert.Equal(t, devicePath, disk.DevicePath)
	assert.Equal(t, usePart, disk.UsePart)
}

// TestCreateDuplicateTerminalDisk checks that creating a disk whose device path
// is already used by an existing disk on the same terminal fails with
// ErrDuplicateTerminalDiskDevicePaths.
func TestCreateDuplicateTerminalDisk(t *testing.T) {
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	terminalDisk := model.TerminalDisk{
		TerminalID:     terminalID,
		TerminalDiskID: terminalDiskID,
		ExpectEmpty:    false,
		IncludeDisk:    true,
		DevicePath:     devicePath,
		UsePart:        true,
	}

	// The terminal already owns a disk at devicePath.
	mockDBGetTerminalDiskByTerminalID(mock, terminalID, []model.TerminalDisk{terminalDisk})

	// NOTE(review): the expectations below look unreachable. The create
	// expectation uses a zero-value input, which would not match diskInput's
	// IncludeDisk/DevicePath/UsePart. No ExpectationsWereMet check is made, so
	// they are never verified. Confirm against the service before pruning.
	mock.ExpectBegin()
	mockDBTerminalDiskCreate(mock, terminalID, model.TerminalDiskCreateInput{}, sqlmock.NewResult(1, 0))
	mock.ExpectRollback()
	mock.ExpectQuery(sqlquery.GetProjectIDByClusterEdgeID).
		WithArgs(terminalClusterEdgeID).
		WillReturnRows(mock.NewRows([]string{"project_id"}).
			AddRow("test-org"))

	var getKubeResource GetKubeResourceFunc = func(_ context.Context, gotProjectID string, cluster *model.Cluster, _ model.LoqRequest) ([]string, error) {
		assert.Equal(t, projectID, gotProjectID)
		assert.Nil(t, cluster)
		return []string{versionResource}, nil
	}
	bqClientMock := createMockBQClient(t, getKubeResource)

	artifactsService := artifacts.NewArtifactsService(db, nil)
	labelSvc := NewLabelService(artifactsService, db)
	service := NewTerminalServiceBQ(db, bqClientMock, labelSvc)

	// Same device path as the existing disk.
	diskInput := model.TerminalDiskCreateInput{
		IncludeDisk: terminalDisk.IncludeDisk,
		ExpectEmpty: terminalDisk.ExpectEmpty,
		DevicePath:  terminalDisk.DevicePath,
		UsePart:     terminalDisk.UsePart,
	}
	_, err = service.CreateTerminalDiskEntry(context.Background(), terminalID, &diskInput)
	assert.ErrorIs(t, err, ErrDuplicateTerminalDiskDevicePaths)
}

// TestCreateTerminalDiskDefaultNoPart checks that omitting UsePart from the
// create input results in a disk with UsePart == false.
func TestCreateTerminalDiskDefaultNoPart(t *testing.T) {
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	includeDisk := true
	expectEmpty := true
	result := sqlmock.NewResult(1, 1)

	mockDBGetTerminalDiskByTerminalID(mock, terminalID, []model.TerminalDisk{})
	mock.ExpectBegin()
	mockDBTerminalDiskCreate(mock, terminalID, model.TerminalDiskCreateInput{
		IncludeDisk: includeDisk,
		ExpectEmpty: expectEmpty,
		DevicePath:  devicePath,
	}, result)
	mock.ExpectCommit()

	newDisk := model.TerminalDiskCreateInput{
		IncludeDisk: includeDisk,
		ExpectEmpty: expectEmpty,
		DevicePath:  devicePath,
	}

	artifactsService := artifacts.NewArtifactsService(db, nil)
	labelSvc := NewLabelService(artifactsService, db)
	service := NewTerminalService(db, labelSvc)

	disk, err := service.CreateTerminalDiskEntry(context.Background(), terminalID, &newDisk)
	if !assert.NoError(t, err) || !assert.NotNil(t, disk) {
		t.FailNow()
	}
	assert.Equal(t, includeDisk, disk.IncludeDisk)
	assert.Equal(t, expectEmpty, disk.ExpectEmpty)
	assert.Equal(t, devicePath, disk.DevicePath)
	assert.False(t, disk.UsePart) // defaults to false
}

// TestDeleteTerminalDisk checks that DeleteTerminalDiskEntry removes the disk
// and returns the owning terminal without that disk among its Disks.
func TestDeleteTerminalDisk(t *testing.T) {
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	// Resolve the owning terminal, then load it with its interfaces,
	// addresses, disks and labels.
	mockDBGetTerminalIDFromDisk(mock, terminalDiskID, &terminalID)
	mock.ExpectQuery(sqlquery.GetTerminalByIDQuery).
		WithArgs(terminalID).
		WillReturnRows(mock.NewRows([]string{
			"terminal_id", "lane", "role", "cluster_edge_id", "cluster_name", "class",
			"discover_disks", "boot_disk", "primary_interface", "existing_efi_part",
			"swap_enabled", "hostname",
		}).
			AddRow(terminalID, lane1, terminalRoleWorker, terminalClusterEdgeID, terminalClusterName,
				terminalClassServer, terminalDiscoverDisksAll, devicePath2, terminalPrimaryInterface,
				terminalExistingEfiPart, swapEnabled, terminalHostname))
	mockDBGetTerminalInterfaceByTerminalIDQuery(mock, terminalID, []*model.TerminalInterface{
		{
			TerminalInterfaceID: terminalInterfaceID,
			MacAddress:          macAddress,
			Dhcp4:               dhcp4False,
			Dhcp6:               dhcp6False,
			Gateway4:            &gateway4,
			Gateway6:            &gateway6,
			TerminalID:          terminalID,
		},
	})
	mockDBGetTerminalAddressByInterfaceIDQuery(mock, terminalInterfaceID, []*model.TerminalAddress{
		{
			TerminalAddressID:   terminalAddressID,
			IP:                  &ipv4,
			PrefixLen:           prefixLen,
			Family:              familyInet,
			TerminalInterfaceID: terminalInterfaceID,
		},
	})
	mockDBGetTerminalDiskByTerminalID(mock, terminalID, nil)
	mock.ExpectQuery(sqlquery.GetTerminalLabels).
		WithArgs(terminalID, sql.NullString{}).
		WillReturnRows(mock.NewRows([]string{
			"terminal_id", "terminal_label_edge_id", "label_edge_id", "labelkey", "color",
			"visible", "editable", "banner", "unique", "description", "label_type",
		}).
			AddRow(terminalID, "388d1144-27c5-44e2-856a-e69a3d4f859f", testLabelEdgeID, label.Key,
				label.Color, label.Visible, label.Editable, label.BannerEdgeID, label.Unique,
				label.Description, label.Type))
	mock.ExpectQuery(sqlquery.GetProjectIDByClusterEdgeID).
		WithArgs(terminalClusterEdgeID).
		WillReturnRows(mock.NewRows([]string{"project_id"}).
			AddRow("test-org"))
	mockDBTerminalDiskDelete(mock, terminalDiskID, sqlmock.NewResult(1, 1))

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	bqmock := apimock.NewMockBQClient(ctrl)
	bqmock.EXPECT().GetKubeResource(gomock.Any(), "test-org", gomock.Any(), gomock.Any())

	artifactsService := artifacts.NewArtifactsService(db, nil)
	labelSvc := NewLabelService(artifactsService, db)
	service := NewTerminalServiceBQ(db, bqmock, labelSvc)

	newTerminal, err := service.DeleteTerminalDiskEntry(context.Background(), terminalDiskID)
	if !assert.NoError(t, err) || !assert.NotNil(t, newTerminal) {
		t.FailNow()
	}
	assert.Equal(t, terminalID, newTerminal.TerminalID)
	for _, disk := range newTerminal.Disks {
		assert.NotEqual(t, terminalDiskID, disk.TerminalDiskID)
	}
}

// TestUpdateTerminalDisk checks that a partial update (DevicePath only)
// changes the device path and keeps the disk's other fields as stored.
func TestUpdateTerminalDisk(t *testing.T) {
	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	disks := []model.TerminalDisk{
		{
			TerminalID:     terminalID,
			TerminalDiskID: terminalDiskID,
			DevicePath:     devicePath,
			ExpectEmpty:    true,
			IncludeDisk:    true,
			UsePart:        false,
		},
		{
			TerminalID:     terminalID,
			TerminalDiskID: terminalDiskID2,
			ExpectEmpty:    false,
			IncludeDisk:    true,
			DevicePath:     devicePath2,
			UsePart:        false,
		},
	}

	mock.ExpectBegin()
	mockDBGetTerminalDiskByID(mock, terminalDiskID, &disks[0])
	mockDBGetTerminalDiskByTerminalID(mock, disks[0].TerminalID, disks)
	newDevicePath := "test-path"
	// Fields not supplied in the update input are expected to be written back
	// with the values already stored for disks[0].
	mockDBUpdateTerminalDisk(mock, disks[0].TerminalDiskID, model.TerminalDiskUpdateInput{
		DevicePath:  &newDevicePath,
		IncludeDisk: &disks[0].IncludeDisk,
		ExpectEmpty: &disks[0].ExpectEmpty,
		UsePart:     &disks[0].UsePart,
	}, sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	bqmock := apimock.NewMockBQClient(ctrl)

	artifactsService := artifacts.NewArtifactsService(db, nil)
	labelSvc := NewLabelService(artifactsService, db)
	service := NewTerminalServiceBQ(db, bqmock, labelSvc)

	updatedTerminalDisk, err := service.UpdateTerminalDiskEntry(context.Background(), terminalDiskID, model.TerminalDiskUpdateInput{
		DevicePath: &newDevicePath,
	})
	if !assert.NoError(t, err) || !assert.NotNil(t, updatedTerminalDisk) {
		t.FailNow()
	}
	assert.Equal(t, newDevicePath, updatedTerminalDisk.DevicePath)
	assert.Equal(t, disks[0].ExpectEmpty, updatedTerminalDisk.ExpectEmpty)
	assert.Equal(t, disks[0].IncludeDisk, updatedTerminalDisk.IncludeDisk)
	assert.Equal(t, disks[0].UsePart, updatedTerminalDisk.UsePart)
}

// mockDBGetTerminalDiskByID expects GetTerminalDiskByIDQuery with terminalDiskID
// and answers with returnDisk as a single row, or with no rows when returnDisk
// is nil. Values are added in the order TerminalDiskID, TerminalID,
// IncludeDisk, ExpectEmpty, DevicePath, UsePart — assumes terminalDiskColumns
// uses that order (TODO confirm where it is declared).
func mockDBGetTerminalDiskByID(mock sqlmock.Sqlmock, terminalDiskID string, returnDisk *model.TerminalDisk) {
	rows := sqlmock.NewRows(terminalDiskColumns)
	if returnDisk != nil {
		rows.AddRow(returnDisk.TerminalDiskID, returnDisk.TerminalID, returnDisk.IncludeDisk,
			returnDisk.ExpectEmpty, returnDisk.DevicePath, returnDisk.UsePart)
	}
	mock.ExpectQuery(sqlquery.GetTerminalDiskByIDQuery).
		WithArgs(terminalDiskID).
		WillReturnRows(rows)
}

// mockDBGetTerminalDiskByTerminalID expects GetTerminalDiskByTerminalIDQuery and
// answers with one row per entry of returnDisks (no rows when empty or nil).
// A nil terminalID matches any argument.
func mockDBGetTerminalDiskByTerminalID(mock sqlmock.Sqlmock, terminalID any, returnDisks []model.TerminalDisk) {
	if terminalID == nil {
		terminalID = sqlmock.AnyArg()
	}
	rows := sqlmock.NewRows(terminalDiskColumns)
	// Ranging over a nil or empty slice is a no-op, so no length guard is needed.
	for _, disk := range returnDisks {
		rows.AddRow(disk.TerminalDiskID, disk.TerminalID, disk.IncludeDisk,
			disk.ExpectEmpty, disk.DevicePath, disk.UsePart)
	}
	mock.ExpectQuery(sqlquery.GetTerminalDiskByTerminalIDQuery).
		WithArgs(terminalID).
		WillReturnRows(rows)
}

// mockDBTerminalDiskCreate expects TerminalDiskCreateQuery for terminalID with the
// fields of diskCreate and answers with result. The first argument is matched
// with AnyArg (presumably a generated disk ID — confirm against the service).
func mockDBTerminalDiskCreate(mock sqlmock.Sqlmock, terminalID string, diskCreate model.TerminalDiskCreateInput, result driver.Result) {
	mock.ExpectExec(sqlquery.TerminalDiskCreateQuery).
		WithArgs(sqlmock.AnyArg(), terminalID, diskCreate.IncludeDisk, diskCreate.ExpectEmpty,
			diskCreate.DevicePath, diskCreate.UsePart).
		WillReturnResult(result)
}

// mockDBTerminalDiskDelete expects TerminalDiskDeleteQuery for diskID and
// answers with result.
func mockDBTerminalDiskDelete(mock sqlmock.Sqlmock, diskID string, result driver.Result) {
	mock.ExpectExec(sqlquery.TerminalDiskDeleteQuery).
		WithArgs(diskID).
		WillReturnResult(result)
}

// mockDBGetTerminalIDFromDisk expects GetTerminalIDFromDiskQuery for diskID and
// answers with a single terminal_id row, or with no rows when
// returnTerminalID is nil.
func mockDBGetTerminalIDFromDisk(mock sqlmock.Sqlmock, diskID string, returnTerminalID *string) {
	rows := sqlmock.NewRows([]string{"terminal_id"})
	if returnTerminalID != nil {
		rows.AddRow(*returnTerminalID)
	}
	mock.ExpectQuery(sqlquery.GetTerminalIDFromDiskQuery).
		WithArgs(diskID).
		WillReturnRows(rows)
}

// mockDBUpdateTerminalDisk expects TerminalDiskUpdateQuery with the update
// fields in the order IncludeDisk, ExpectEmpty, DevicePath, UsePart, followed
// by terminalDiskID, and answers with result.
func mockDBUpdateTerminalDisk(mock sqlmock.Sqlmock, terminalDiskID string, updateTerminalDisk model.TerminalDiskUpdateInput, result driver.Result) {
	mock.ExpectExec(sqlquery.TerminalDiskUpdateQuery).
		WithArgs(updateTerminalDisk.IncludeDisk, updateTerminalDisk.ExpectEmpty,
			updateTerminalDisk.DevicePath, updateTerminalDisk.UsePart, terminalDiskID).
		WillReturnResult(result)
}