package services

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/google/uuid"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
	"edge-infra.dev/pkg/edge/api/utils"
)

// CreateVirtualMachineDiskEntries inserts one disk row per entry in createDisks for the
// virtual machine virtualMachineID, all inside a single transaction. The transaction is
// rolled back if any disk fails validation or insertion, if no disks were created, or if
// the virtual machine's disks (existing plus new) would contain duplicate boot orders.
// On success it returns the newly created disks in input order.
func (v *virtualMachineService) CreateVirtualMachineDiskEntries(ctx context.Context, virtualMachineID string, createDisks []*model.VirtualMachineDiskCreateInput) (createdDisks []*model.VirtualMachineDisk, err error) {
	transaction, err := v.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			return
		}
		// After a failed Commit (or a cancelled ctx) the transaction is already finished and
		// Rollback only reports sql.ErrTxDone; don't pile that onto the real error.
		if rbErr := transaction.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			err = errors.Join(err, rbErr)
		}
	}()

	createdDisks, err = v.createVirtualMachineDiskEntries(ctx, transaction, createDisks, virtualMachineID)
	if err != nil {
		return nil, err
	}
	if len(createdDisks) == 0 {
		return nil, fmt.Errorf("no vm disks to return - error creating disk entries")
	}

	// Boot orders must be unique across all of the VM's disks, including pre-existing ones,
	// so validate the full set as seen from inside the transaction.
	allDisks, err := v.getVirtualMachineDisks(ctx, transaction, virtualMachineID)
	if err != nil {
		return nil, err
	}
	if utils.HasDuplicateDiskBootOrders(allDisks) {
		return nil, fmt.Errorf("cannot create vm disks that have duplicate boot orders")
	}

	if err = transaction.Commit(); err != nil {
		return nil, err
	}
	return createdDisks, nil
}

// createVirtualMachineDiskEntries validates and inserts each create input as a new disk,
// with a freshly generated ID, belonging to virtualMachineID, using transaction.
// It returns the created disk models in input order, or the first error encountered.
func (v *virtualMachineService) createVirtualMachineDiskEntries(ctx context.Context, transaction *sql.Tx, createDisks []*model.VirtualMachineDiskCreateInput, virtualMachineID string) ([]*model.VirtualMachineDisk, error) {
	createdDisks := make([]*model.VirtualMachineDisk, 0, len(createDisks))
	for _, createDisk := range createDisks {
		// These fields are dereferenced below; reject missing ones rather than panicking.
		if createDisk == nil || createDisk.Type == nil || createDisk.Bus == nil || createDisk.Size == nil || createDisk.ContainerImageURL == nil {
			return nil, fmt.Errorf("vm disk create input must set type, bus, size and container image url")
		}

		createdDisk := utils.CreateVirtualMachineDiskModel(uuid.NewString(), virtualMachineID, *createDisk.Type, *createDisk.Bus, createDisk.BootOrder, *createDisk.Size, *createDisk.ContainerImageURL)
		if err := utils.ValidateVirtualMachineDisk(&createdDisk); err != nil {
			return nil, err
		}

		// Argument order must match the placeholders in VirtualMachineDiskCreateQuery.
		args := []any{
			createdDisk.DiskID,
			virtualMachineID,
			createdDisk.Type,
			createdDisk.Bus,
			createdDisk.BootOrder,
			createdDisk.Size,
			createdDisk.ContainerImageURL,
		}
		if _, err := transaction.ExecContext(ctx, sqlquery.VirtualMachineDiskCreateQuery, args...); err != nil {
			return nil, err
		}
		createdDisks = append(createdDisks, &createdDisk)
	}
	return createdDisks, nil
}

// DeleteVirtualMachineDiskEntry deletes the disk diskID and returns the disk as it was
// before deletion. It refuses to delete the last remaining disk of a virtual machine,
// and errors if the disk's virtual machine or the disk itself cannot be found.
func (v *virtualMachineService) DeleteVirtualMachineDiskEntry(ctx context.Context, diskID string) (*model.VirtualMachineDisk, error) {
	virtualMachine, err := v.GetVirtualMachineFromDisk(ctx, diskID)
	if err != nil {
		return nil, err
	}
	if virtualMachine == nil {
		return nil, fmt.Errorf("associated virtual machine does not exist so cannot delete disk")
	}

	disks, err := v.GetVirtualMachineDisks(ctx, virtualMachine.VirtualMachineID)
	if err != nil {
		return nil, err
	}
	// The scan helper returns an empty, non-nil slice when there are no rows,
	// so emptiness must be tested by length rather than against nil.
	if len(disks) == 0 {
		return nil, fmt.Errorf("cannot find any disks for the linked virtual machine")
	}
	if len(disks) == 1 {
		return nil, fmt.Errorf("unable to delete the last disk of the virtual machine")
	}

	var deletedDisk *model.VirtualMachineDisk
	for _, disk := range disks {
		if disk.DiskID == diskID {
			deletedDisk = disk
			break
		}
	}
	if deletedDisk == nil {
		// The disk vanished between the lookups above; don't delete blindly and return nil.
		return nil, fmt.Errorf("cannot find disk %s among the disks of the linked virtual machine", diskID)
	}

	if _, err := v.SQLDB.ExecContext(ctx, sqlquery.VirtualMachineDiskDeleteQuery, diskID); err != nil {
		return nil, err
	}
	return deletedDisk, nil
}

// UpdateVirtualMachineDiskEntries applies every update in updateDisks inside a single
// transaction and returns the updated disks in input order. The transaction is rolled
// back if any update fails, or if any virtual machine touched by the batch would end up
// with duplicate disk boot orders.
func (v *virtualMachineService) UpdateVirtualMachineDiskEntries(ctx context.Context, updateDisks []*model.VirtualMachineDiskIDInput) (updatedDisks []*model.VirtualMachineDisk, err error) {
	if len(updateDisks) == 0 {
		return nil, fmt.Errorf("no vm disks provided to update")
	}

	transaction, err := v.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			return
		}
		// After a failed Commit (or a cancelled ctx) the transaction is already finished and
		// Rollback only reports sql.ErrTxDone; don't pile that onto the real error.
		if rbErr := transaction.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			err = errors.Join(err, rbErr)
		}
	}()

	updatedDisks, err = v.updateVirtualMachineDiskEntries(ctx, transaction, updateDisks)
	if err != nil {
		return nil, err
	}

	// A batch may touch disks of several virtual machines; every one of them must keep
	// unique boot orders, not just the VM owning the first disk.
	checked := make(map[string]struct{}, len(updatedDisks))
	for _, updatedDisk := range updatedDisks {
		virtualMachineID := updatedDisk.VirtualMachineID
		if _, done := checked[virtualMachineID]; done {
			continue
		}
		checked[virtualMachineID] = struct{}{}

		var allDisks []*model.VirtualMachineDisk
		if allDisks, err = v.getVirtualMachineDisks(ctx, transaction, virtualMachineID); err != nil {
			return nil, err
		}
		if utils.HasDuplicateDiskBootOrders(allDisks) {
			return nil, fmt.Errorf("cannot update vm disks to cause duplicate boot orders")
		}
	}

	if err = transaction.Commit(); err != nil {
		return nil, err
	}
	return updatedDisks, nil
}

// updateVirtualMachineDiskEntries applies each update in order using transaction and
// returns the resulting disks, stopping at the first error.
func (v *virtualMachineService) updateVirtualMachineDiskEntries(ctx context.Context, transaction *sql.Tx, updateDisks []*model.VirtualMachineDiskIDInput) ([]*model.VirtualMachineDisk, error) {
	updatedDisks := make([]*model.VirtualMachineDisk, 0, len(updateDisks))
	for _, updateDisk := range updateDisks {
		if updateDisk == nil {
			return nil, fmt.Errorf("vm disk update input must not be nil")
		}
		updatedDisk, err := v.updateVirtualMachineDiskEntry(ctx, transaction, *updateDisk)
		if err != nil {
			return nil, err
		}
		updatedDisks = append(updatedDisks, updatedDisk)
	}
	return updatedDisks, nil
}

// updateVirtualMachineDiskEntry loads the disk updateDisk.DiskID, merges in the supplied
// values, validates the result and writes it back using transaction.
func (v *virtualMachineService) updateVirtualMachineDiskEntry(ctx context.Context, transaction *sql.Tx, updateDisk model.VirtualMachineDiskIDInput) (*model.VirtualMachineDisk, error) {
	// Read through the transaction so that earlier updates in the same batch are visible
	// and no second pooled connection is needed while the transaction is open.
	row := transaction.QueryRowContext(ctx, sqlquery.GetVirtualMachineDiskByIDQuery, updateDisk.DiskID)
	currentDisk, err := v.scanVirtualMachineDiskRow(row)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("cannot update a disk that does not exist: %w", err)
	}
	if err != nil {
		return nil, err
	}

	updatedDisk := utils.UpdateVirtualMachineDisk(currentDisk, updateDisk.VirtualMachineDiskValues)
	if err = utils.ValidateVirtualMachineDisk(updatedDisk); err != nil {
		return nil, err
	}

	// Argument order must match the placeholders in VirtualMachineDiskUpdateQuery
	// (disk ID last, for the WHERE clause presumably — confirm against the query).
	args := []any{
		updatedDisk.Type,
		updatedDisk.Bus,
		updatedDisk.BootOrder,
		updatedDisk.Size,
		updatedDisk.ContainerImageURL,
		updatedDisk.DiskID,
	}
	if _, err = transaction.ExecContext(ctx, sqlquery.VirtualMachineDiskUpdateQuery, args...); err != nil {
		return nil, err
	}
	return updatedDisk, nil
}

// GetVirtualMachineDisk returns the disk with ID diskID. If no such row exists the
// error from Scan (sql.ErrNoRows) is returned.
func (v *virtualMachineService) GetVirtualMachineDisk(ctx context.Context, diskID string) (*model.VirtualMachineDisk, error) {
	row := v.SQLDB.QueryRowContext(ctx, sqlquery.GetVirtualMachineDiskByIDQuery, diskID)
	return v.scanVirtualMachineDiskRow(row)
}

// GetVirtualMachineDisks returns the disks of virtualMachineID as currently stored in the
// database. It errors if GetVirtualMachine returns no virtual machine for that ID.
func (v *virtualMachineService) GetVirtualMachineDisks(ctx context.Context, virtualMachineID string) ([]*model.VirtualMachineDisk, error) {
	virtualMachine, err := v.GetVirtualMachine(ctx, virtualMachineID)
	if err != nil {
		return nil, err
	}
	if virtualMachine == nil {
		return nil, fmt.Errorf("cannot get disks for virtual machine that does not exist")
	}

	rows, err := v.SQLDB.QueryContext(ctx, sqlquery.GetVirtualMachineDiskByVirtualMachineIDQuery, virtualMachineID)
	if err != nil {
		return nil, err
	}
	return v.scanVirtualMachineDiskRows(rows)
}

// getVirtualMachineDisks returns the disks of virtualMachineID as seen from inside
// transaction, so uncommitted inserts/updates made by that transaction are included.
func (v *virtualMachineService) getVirtualMachineDisks(ctx context.Context, transaction *sql.Tx, virtualMachineID string) ([]*model.VirtualMachineDisk, error) {
	rows, err := transaction.QueryContext(ctx, sqlquery.GetVirtualMachineDiskByVirtualMachineIDQuery, virtualMachineID)
	if err != nil {
		return nil, err
	}
	return v.scanVirtualMachineDiskRows(rows)
}

// GetVirtualMachineFromDisk looks up the ID of the virtual machine owning diskID and
// returns that virtual machine via GetVirtualMachine. If the disk does not exist, the
// error from Scan (sql.ErrNoRows) is returned.
func (v *virtualMachineService) GetVirtualMachineFromDisk(ctx context.Context, diskID string) (*model.VirtualMachine, error) {
	row := v.SQLDB.QueryRowContext(ctx, sqlquery.GetVirtualMachineIDFromDiskQuery, diskID)
	var virtualMachineID string
	if err := row.Scan(&virtualMachineID); err != nil {
		return nil, err
	}

	virtualMachine, err := v.GetVirtualMachine(ctx, virtualMachineID)
	if err != nil {
		return nil, err
	}
	return virtualMachine, nil
}

// scanVirtualMachineDiskRow scans a single disk row. The column order must match the
// disk SELECT queries: id, virtual machine id, type, bus, boot order, size, image URL.
func (v *virtualMachineService) scanVirtualMachineDiskRow(row *sql.Row) (*model.VirtualMachineDisk, error) {
	disk := model.VirtualMachineDisk{}
	err := row.Scan(&disk.DiskID, &disk.VirtualMachineID, &disk.Type, &disk.Bus, &disk.BootOrder, &disk.Size, &disk.ContainerImageURL)
	if err != nil {
		return nil, err
	}
	return &disk, nil
}

// scanVirtualMachineDiskRows scans every disk row in rows and takes ownership of rows:
// it is always closed before returning. The returned slice is never nil; it is empty
// when there are no rows.
func (v *virtualMachineService) scanVirtualMachineDiskRows(rows *sql.Rows) ([]*model.VirtualMachineDisk, error) {
	// Without this, an early return on a Scan error would leak the result set and its
	// connection (and, inside a transaction, block further statements on it).
	defer rows.Close()

	disks := []*model.VirtualMachineDisk{}
	for rows.Next() {
		disk := model.VirtualMachineDisk{}
		err := rows.Scan(&disk.DiskID, &disk.VirtualMachineID, &disk.Type, &disk.Bus, &disk.BootOrder, &disk.Size, &disk.ContainerImageURL)
		if err != nil {
			return nil, err
		}
		disks = append(disks, &disk)
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return disks, nil
}