package services

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/google/uuid"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
	"edge-infra.dev/pkg/edge/api/utils"
)

// CreateTerminalDiskEntry inserts a single disk for the terminal identified by
// terminalID and returns the created disk model.
//
// The requested device path is first checked against the terminal's existing
// disks; ErrDuplicateTerminalDiskDevicePaths is returned if it is already used.
// The insert runs in its own transaction, which is rolled back if the insert
// fails. If the rollback itself fails, the insert error is still the wrapped
// error (so errors.Is/As keep working) and the rollback error is attached as
// context.
func (t *terminalService) CreateTerminalDiskEntry(ctx context.Context, terminalID string, newTerminalDisk *model.TerminalDiskCreateInput) (*model.TerminalDisk, error) {
	if newTerminalDisk == nil {
		return nil, errors.New("terminal disk create input must not be nil")
	}
	if err := t.validateTerminalDiskDevicePath(ctx, terminalID, newTerminalDisk.DevicePath); err != nil {
		return nil, err
	}

	transaction, err := t.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}

	newDiskList := []*model.TerminalDiskCreateInput{newTerminalDisk}
	diskList, err := t.createTerminalDiskEntries(ctx, transaction, newDiskList, terminalID)
	if err != nil {
		if rollbackErr := transaction.Rollback(); rollbackErr != nil {
			// Keep the original failure as the primary (wrapped) error; the
			// rollback failure is only extra context.
			return nil, fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr)
		}
		return nil, err
	}
	if err = transaction.Commit(); err != nil {
		return nil, err
	}

	if len(diskList) == 0 {
		return nil, errors.New("no disk to return - potential error creating disk entry")
	}
	return diskList[0], nil
}

// createTerminalDiskEntries inserts every disk in newDisks for terminalID using
// the supplied transaction and returns the created models in input order.
//
// Each disk receives a freshly generated UUID. The caller owns the transaction
// and is responsible for committing or rolling it back; on the first failed
// insert this returns the error and no models.
func (t *terminalService) createTerminalDiskEntries(ctx context.Context, transaction *sql.Tx, newDisks []*model.TerminalDiskCreateInput, terminalID string) ([]*model.TerminalDisk, error) {
	terminalDisks := make([]*model.TerminalDisk, 0, len(newDisks))
	for _, newDisk := range newDisks {
		terminalDisk := utils.CreateTerminalDiskModel(uuid.NewString(), terminalID, newDisk.IncludeDisk, newDisk.ExpectEmpty, newDisk.DevicePath, newDisk.UsePart)
		args := []interface{}{
			terminalDisk.TerminalDiskID,
			terminalID,
			terminalDisk.IncludeDisk,
			terminalDisk.ExpectEmpty,
			terminalDisk.DevicePath,
			terminalDisk.UsePart,
		}
		if _, err := transaction.ExecContext(ctx, sqlquery.TerminalDiskCreateQuery, args...); err != nil {
			return nil, err
		}
		terminalDisks = append(terminalDisks, &terminalDisk)
	}
	return terminalDisks, nil
}

// DeleteTerminalDiskEntry deletes the disk identified by terminalDiskID and
// returns the owning terminal with that disk removed from its Disks list.
//
// Deleting a terminal's only disk is rejected before anything is deleted. The
// remaining disks keep their original order. On any error the returned
// terminal is nil.
func (t *terminalService) DeleteTerminalDiskEntry(ctx context.Context, terminalDiskID string) (*model.Terminal, error) {
	terminal, err := t.GetTerminalFromDisk(ctx, terminalDiskID)
	if err != nil {
		return nil, err
	}
	if len(terminal.Disks) == 1 {
		return nil, errors.New("unable to delete the last disk of the terminal")
	}

	if _, err := t.SQLDB.ExecContext(ctx, sqlquery.TerminalDiskDeleteQuery, terminalDiskID); err != nil {
		return nil, err
	}

	// Only update the in-memory view once the row is actually gone, and keep
	// the surviving disks in the order they were loaded.
	for i, disk := range terminal.Disks {
		if disk.TerminalDiskID == terminalDiskID {
			terminal.Disks = append(terminal.Disks[:i], terminal.Disks[i+1:]...)
			break
		}
	}
	return terminal, nil
}

// UpdateTerminalDiskEntry applies diskInput to the disk identified by
// terminalDiskID inside a new transaction and returns the updated disk model.
//
// The transaction is rolled back if the update fails. If the rollback itself
// fails, the update error is still the wrapped error and the rollback error is
// attached as context.
func (t *terminalService) UpdateTerminalDiskEntry(ctx context.Context, terminalDiskID string, diskInput model.TerminalDiskUpdateInput) (*model.TerminalDisk, error) {
	transaction, err := t.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}

	updateDiskInput := []*model.TerminalDiskIDInput{
		{
			TerminalDiskID: terminalDiskID,
			TerminalDiskValues: &model.TerminalDiskUpdateInput{
				DevicePath:  diskInput.DevicePath,
				IncludeDisk: diskInput.IncludeDisk,
				ExpectEmpty: diskInput.ExpectEmpty,
				UsePart:     diskInput.UsePart,
			},
		},
	}

	updatedDisks, err := t.updateTerminalDiskEntries(ctx, transaction, updateDiskInput)
	if err != nil {
		if rollbackErr := transaction.Rollback(); rollbackErr != nil {
			// Keep the original failure as the primary (wrapped) error; the
			// rollback failure is only extra context.
			return nil, fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr)
		}
		return nil, err
	}
	if err = transaction.Commit(); err != nil {
		return nil, err
	}

	if len(updatedDisks) == 0 {
		return nil, errors.New("no disk to return - potential error updating disk entry")
	}
	return updatedDisks[0], nil
}

// updateTerminalDiskEntries applies each update in updateDisks using the
// supplied transaction and returns the updated models in input order.
//
// The current disk row is loaded via t.SQLDB (not the transaction), merged
// with the requested values by utils.UpdateTerminalDisk, and written back. A
// changed device path is validated against the terminal's disks and yields
// ErrDuplicateTerminalDiskDevicePaths if already in use. The caller owns the
// transaction; on the first error this returns the error and no models.
func (t *terminalService) updateTerminalDiskEntries(ctx context.Context, transaction *sql.Tx, updateDisks []*model.TerminalDiskIDInput) ([]*model.TerminalDisk, error) {
	updatedTerminalDisks := make([]*model.TerminalDisk, 0, len(updateDisks))
	for _, updateDisk := range updateDisks {
		disk, err := t.getTerminalDisk(ctx, updateDisk.TerminalDiskID)
		if err != nil {
			return nil, err
		}

		updateDiskInput := updateDisk.TerminalDiskValues
		// The disk being updated is itself part of the terminal's disk list, so
		// re-submitting its current path would be flagged as a duplicate of
		// itself. Only validate when the path actually changes.
		if updateDiskInput.DevicePath != nil && *updateDiskInput.DevicePath != disk.DevicePath {
			if err = t.validateTerminalDiskDevicePath(ctx, disk.TerminalID, *updateDiskInput.DevicePath); err != nil {
				return nil, err
			}
		}

		updatedTerminalDisk, err := utils.UpdateTerminalDisk(disk, updateDiskInput)
		if err != nil {
			return nil, err
		}

		args := []interface{}{
			updatedTerminalDisk.IncludeDisk,
			updatedTerminalDisk.ExpectEmpty,
			updatedTerminalDisk.DevicePath,
			updatedTerminalDisk.UsePart,
			updatedTerminalDisk.TerminalDiskID,
		}
		if _, err = transaction.ExecContext(ctx, sqlquery.TerminalDiskUpdateQuery, args...); err != nil {
			return nil, err
		}
		updatedTerminalDisks = append(updatedTerminalDisks, updatedTerminalDisk)
	}
	return updatedTerminalDisks, nil
}

// GetTerminalFromDisk looks up the terminal that owns the disk identified by
// terminalDiskID and returns it as loaded by GetTerminal with labels requested.
// If no such disk exists, the error from row.Scan (sql.ErrNoRows) is returned.
func (t *terminalService) GetTerminalFromDisk(ctx context.Context, terminalDiskID string) (*model.Terminal, error) {
	row := t.SQLDB.QueryRowContext(ctx, sqlquery.GetTerminalIDFromDiskQuery, terminalDiskID)
	var terminalID string
	if err := row.Scan(&terminalID); err != nil {
		return nil, err
	}

	getLabel := true
	terminal, err := t.GetTerminal(ctx, terminalID, &getLabel)
	if err != nil {
		return nil, err
	}
	return terminal, nil
}

// getTerminalDisk loads a single disk row by its ID. If no row matches, the
// error from row.Scan (sql.ErrNoRows) is returned.
func (t *terminalService) getTerminalDisk(ctx context.Context, terminalDiskID string) (*model.TerminalDisk, error) {
	row := t.SQLDB.QueryRowContext(ctx, sqlquery.GetTerminalDiskByIDQuery, terminalDiskID)
	terminalDisk := model.TerminalDisk{}
	if err := row.Scan(
		&terminalDisk.TerminalDiskID,
		&terminalDisk.TerminalID,
		&terminalDisk.IncludeDisk,
		&terminalDisk.ExpectEmpty,
		&terminalDisk.DevicePath,
		&terminalDisk.UsePart,
	); err != nil {
		return nil, err
	}
	return &terminalDisk, nil
}

// getTerminalDisks loads all disk rows for terminalID.
func (t *terminalService) getTerminalDisks(ctx context.Context, terminalID *string) ([]*model.TerminalDisk, error) {
	rows, err := t.SQLDB.QueryContext(ctx, sqlquery.GetTerminalDiskByTerminalIDQuery, terminalID)
	if err != nil {
		return nil, err
	}
	// Always release the result set: if scanning stops early on an error,
	// rows would otherwise keep its connection checked out of the pool.
	defer rows.Close()

	disks, err := t.scanTerminalDiskRows(rows)
	if err != nil {
		return nil, err
	}
	return disks, nil
}

// validateTerminalDiskDevicePath returns ErrDuplicateTerminalDiskDevicePaths if
// any existing disk of terminalID already uses devicePath, and nil otherwise.
func (t *terminalService) validateTerminalDiskDevicePath(ctx context.Context, terminalID string, devicePath string) error {
	disks, err := t.getTerminalDisks(ctx, &terminalID)
	if err != nil {
		return err
	}
	for _, disk := range disks {
		if disk.DevicePath == devicePath {
			return ErrDuplicateTerminalDiskDevicePaths
		}
	}
	return nil
}

// scanTerminalDiskRows reads every remaining row of rows into disk models.
// Per-row scan errors are returned as-is; an iteration error reported by
// rows.Err is wrapped with sqlerr.Wrap. Closing rows is left to the caller.
func (t *terminalService) scanTerminalDiskRows(rows *sql.Rows) ([]*model.TerminalDisk, error) {
	terminalDisks := []*model.TerminalDisk{}
	for rows.Next() {
		terminalDisk := model.TerminalDisk{}
		err := rows.Scan(
			&terminalDisk.TerminalDiskID,
			&terminalDisk.TerminalID,
			&terminalDisk.IncludeDisk,
			&terminalDisk.ExpectEmpty,
			&terminalDisk.DevicePath,
			&terminalDisk.UsePart,
		)
		if err != nil {
			return nil, err
		}
		terminalDisks = append(terminalDisks, &terminalDisk)
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return terminalDisks, nil
}