package services import ( "context" "database/sql" "fmt" "slices" "strings" sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql" "edge-infra.dev/pkg/edge/api/graph/model" sqlquery "edge-infra.dev/pkg/edge/api/sql" "edge-infra.dev/pkg/edge/api/utils" "edge-infra.dev/pkg/edge/capabilities" "edge-infra.dev/pkg/edge/chariot/client" "edge-infra.dev/pkg/lib/uuid" v1ien "edge-infra.dev/pkg/sds/ien/k8s/apis/v1" ) //go:generate mockgen -destination=../mocks/mock_terminal_label_service.go -package=mocks edge-infra.dev/pkg/edge/api/services TerminalLabelService type TerminalLabelService interface { CreateTerminalLabel(context.Context, string, ...string) error GetTerminalLabel(context.Context, string) (*model.TerminalLabel, error) DeleteTerminalLabels(context.Context, model.SearchTerminalLabelInput) ([]*model.TerminalLabel, error) SendUpdatedIENCRAfterDeletion(context.Context, []*model.TerminalLabel) error GetTerminalLabels(context.Context, model.SearchTerminalLabelInput) ([]*model.TerminalLabel, error) GetTerminalLabelsInfo(ctx context.Context, terminals []*model.Terminal) ([]*model.Terminal, error) } type terminalLabelService struct { SQLDB *sql.DB ChariotService ChariotService TerminalService TerminalService StoreClusterService StoreClusterService LabelService LabelService } func (n *terminalLabelService) CreateTerminalLabel(ctx context.Context, terminalID string, labelEdgeIDs ...string) error { tx, err := n.SQLDB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } customLabels, err := insertTerminalLabels(ctx, tx, terminalID, labelEdgeIDs...) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return rollbackErr } return err } getLabel := false terminal, err := n.TerminalService.GetTerminal(ctx, terminalID, &getLabel) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return rollbackErr } return err } cluster, err := n.StoreClusterService.GetCluster(ctx, terminal.ClusterEdgeID) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return rollbackErr } return err } clusterNetworkServices, err := n.StoreClusterService.GetClusterNetworkServices(ctx, cluster.ClusterEdgeID) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return rollbackErr } return err } ienNodeCRBase64, err := n.TerminalService.CreateDSDSIENodeCR(terminal, clusterNetworkServices, customLabels, cluster.FleetVersion) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return rollbackErr } return err } msg := client.NewChariotMessage(). SetBanner(cluster.ProjectID). SetOperation(client.Create). SetCluster(cluster.ClusterEdgeID). SetOwner(ComponentOwner). AddObject(ienNodeCRBase64) if err := n.ChariotService.InvokeChariotPubsub(ctx, msg, make(map[string]string)); err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return rollbackErr } return err } return tx.Commit() } func (n *terminalLabelService) GetTerminalLabel(ctx context.Context, terminalID string) (*model.TerminalLabel, error) { terminalLabel := &model.TerminalLabel{} row := n.SQLDB.QueryRowContext(ctx, sqlquery.GetTerminalLabel, terminalID) err := row.Scan(&terminalLabel.TerminalID, &terminalLabel.TerminalLabelEdgeID, &terminalLabel.LabelEdgeID) if err != nil { return nil, err } labelDetails, err := n.LabelService.GetLabel(ctx, terminalLabel.LabelEdgeID) if err != nil { // If label data not found, handle gracefully terminalLabel.Label = nil } else { terminalLabel.Label = labelDetails } return terminalLabel, err } func (n *terminalLabelService) GetTerminalLabels(ctx context.Context, terminalLabelInput model.SearchTerminalLabelInput) ([]*model.TerminalLabel, error) { terminalLabels := make([]*model.TerminalLabel, 0) if terminalLabelInput.ClusterEdgeID != nil { // nolint:nestif terminalrows, err := n.SQLDB.QueryContext(ctx, sqlquery.GetTerminalByClusterEdgeIDQuery, terminalLabelInput.ClusterEdgeID) if err != nil { return nil, err } terminals := []*model.Terminal{} for terminalrows.Next() { terminal := &model.Terminal{} err := terminalrows.Scan(&terminal.TerminalID, &terminal.Lane, &terminal.Role, &terminal.ClusterEdgeID, &terminal.ClusterName, &terminal.Class, &terminal.DiscoverDisks, &terminal.BootDisk, &terminal.PrimaryInterface, &terminal.ExistingEfiPart, &terminal.SwapEnabled, &terminal.Hostname) if err != nil { return nil, err } terminals = append(terminals, terminal) } if err := terminalrows.Err(); err != nil { return nil, sqlerr.Wrap(err) } for _, terminal := range terminals { terminalID := terminal.TerminalID if terminalLabelInput.TerminalID != nil { terminalID = *terminalLabelInput.TerminalID } terminalLabelsrow, err := n.SQLDB.QueryContext(ctx, sqlquery.GetTerminalLabels, terminalID, terminalLabelInput.LabelEdgeID) if err != nil { return nil, err } for terminalLabelsrow.Next() { terminalLabel := &model.TerminalLabel{Label: &model.Label{}} err := terminalLabelsrow.Scan(&terminalLabel.TerminalID, &terminalLabel.TerminalLabelEdgeID, &terminalLabel.LabelEdgeID, &terminalLabel.Label.Key, &terminalLabel.Label.Color, &terminalLabel.Label.Visible, &terminalLabel.Label.Editable, &terminalLabel.Label.BannerEdgeID, &terminalLabel.Label.Unique, &terminalLabel.Label.Description, &terminalLabel.Label.Type) if err != nil { return nil, err } terminalLabels = append(terminalLabels, terminalLabel) } if err := terminalLabelsrow.Err(); err != nil { return nil, sqlerr.Wrap(err) } } return terminalLabels, nil } rows, err := n.SQLDB.QueryContext(ctx, sqlquery.GetTerminalLabels, terminalLabelInput.TerminalID, terminalLabelInput.LabelEdgeID) if err != nil { return terminalLabels, err } for rows.Next() { terminalLabel := &model.TerminalLabel{Label: &model.Label{}} //label_key, color, visible, editable, banner_edge_id, label_unique, description, label_type err := rows.Scan(&terminalLabel.TerminalID, &terminalLabel.TerminalLabelEdgeID, &terminalLabel.LabelEdgeID, &terminalLabel.Label.Key, &terminalLabel.Label.Color, &terminalLabel.Label.Visible, &terminalLabel.Label.Editable, &terminalLabel.Label.BannerEdgeID, &terminalLabel.Label.Unique, &terminalLabel.Label.Description, &terminalLabel.Label.Type) if err != nil { return terminalLabels, err } terminalLabels = append(terminalLabels, terminalLabel) } if err := rows.Err(); err != nil { return nil, sqlerr.Wrap(err) } return terminalLabels, nil } func (n *terminalLabelService) GetTerminalLabelsInfo(ctx context.Context, terminals []*model.Terminal) ([]*model.Terminal, error) { for _, terminal := range terminals { terminalLabels, err := n.GetTerminalLabels(ctx, model.SearchTerminalLabelInput{TerminalID: &terminal.TerminalID}) if err != nil { return terminals, err } terminal.Labels = terminalLabels } return terminals, nil } func insertTerminalLabels(ctx context.Context, tx *sql.Tx, terminalID string, labelEdgeIDs ...string) (map[string]string, error) { customLabels := make(map[string]string, 0) rows, err := tx.QueryContext(ctx, sqlquery.GetTerminalLabels, terminalID, sql.NullString{}) // in a perfect world, sql error will be exported but no rows error is not exported so using contains. if err != nil && !strings.Contains(err.Error(), "no rows in result set") { return customLabels, err } for rows.Next() { terminalLabel := &model.TerminalLabel{Label: &model.Label{}} //label_key, color, visible, editable, banner_edge_id, label_unique, description, label_type err := rows.Scan(&terminalLabel.TerminalID, &terminalLabel.TerminalLabelEdgeID, &terminalLabel.LabelEdgeID, &terminalLabel.Label.Key, &terminalLabel.Label.Color, &terminalLabel.Label.Visible, &terminalLabel.Label.Editable, &terminalLabel.Label.BannerEdgeID, &terminalLabel.Label.Unique, &terminalLabel.Label.Description, &terminalLabel.Label.Type) if err != nil { return customLabels, err } // for the terminal label description, we hash and then limit the number of chars to 20 chars. labelDescription := uuid.FromUUID(terminalLabel.Label.Description).Hash() labelKey := utils.ToK8sName(terminalLabel.Label.Key) if err = utils.ValuesValidation([]string{labelDescription, labelKey}); err != nil { return customLabels, err } // avoid prefixing labels automatically added by Edge key := formatCustomLabelKey(terminalLabel.Label) customLabels[key] = labelDescription if terminalLabel.Label.Type != "" && !slices.Contains(capabilities.EdgeAutomatedCapabilityLabelTypes, terminalLabel.Label.Type) { customLabels[fmt.Sprintf(v1ien.CustomNodeLabel, terminalLabel.Label.Type)] = labelKey } } if err := rows.Err(); err != nil { return nil, sqlerr.Wrap(err) } for _, labelEdgeID := range labelEdgeIDs { _, err := tx.ExecContext(ctx, sqlquery.InsertTerminalLabel, terminalID, labelEdgeID) if err != nil { return customLabels, err } label := &model.Label{} row := tx.QueryRowContext(ctx, sqlquery.GetLabelQuery, labelEdgeID) if err := row.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type); err != nil { return customLabels, err } // for the terminal label description, we hash and then limit the number of chars to 20 chars. labelDescription := uuid.FromUUID(label.Description).Hash() labelKey := utils.ToK8sName(label.Key) if err = utils.ValuesValidation([]string{labelDescription, labelKey}); err != nil { return customLabels, err } key := formatCustomLabelKey(label) customLabels[key] = labelDescription if label.Type != "" && !slices.Contains(capabilities.EdgeAutomatedCapabilityLabelTypes, label.Type) { customLabels[fmt.Sprintf(v1ien.CustomNodeLabel, label.Type)] = labelKey } } return customLabels, nil } func (n *terminalLabelService) DeleteTerminalLabels(ctx context.Context, terminalLabelInput model.SearchTerminalLabelInput) ([]*model.TerminalLabel, error) { terminaLabels, err := n.GetTerminalLabels(ctx, terminalLabelInput) if err != nil { return nil, err } for _, terminaLabel := range terminaLabels { _, err = n.SQLDB.ExecContext(ctx, sqlquery.DeleteTerminalLabel, terminaLabel.TerminalLabelEdgeID) if err != nil { return nil, err } } return terminaLabels, nil } func (n *terminalLabelService) SendUpdatedIENCRAfterDeletion(ctx context.Context, affectedTerminalLabels []*model.TerminalLabel) error { //First set an empty map for all effected terminals so we will build an update ien cr for them updatedTerminals := make(map[string]map[string]string, 1) for _, terminalLabel := range affectedTerminalLabels { updatedTerminals[terminalLabel.TerminalID] = make(map[string]string) } getLabels := true for terminalID, customLabels := range updatedTerminals { terminal, err := n.TerminalService.GetTerminal(ctx, terminalID, &getLabels) if err != nil { return err } //add remaining labels to the custom labels to be on update ien node cr for _, terminalLabel := range terminal.Labels { key := formatCustomLabelKey(terminalLabel.Label) customLabels[key] = uuid.FromUUID(terminalLabel.Label.Description).Hash() if terminalLabel.Label.Type != "" && !slices.Contains(capabilities.EdgeAutomatedCapabilityLabelTypes, terminalLabel.Label.Type) { customLabels[fmt.Sprintf(v1ien.CustomNodeLabel, terminalLabel.Label.Key)] = utils.ToK8sName(terminalLabel.Label.Description) } } cluster, err := n.StoreClusterService.GetCluster(ctx, terminal.ClusterEdgeID) if err != nil { return err } clusterNetworkServices, err := n.StoreClusterService.GetClusterNetworkServices(ctx, cluster.ClusterEdgeID) if err != nil { return err } ienNodeCRBase64, err := n.TerminalService.CreateDSDSIENodeCR(terminal, clusterNetworkServices, customLabels, cluster.FleetVersion) if err != nil { return err } msg := client.NewChariotMessage(). SetBanner(cluster.ProjectID). SetOperation(client.Create). SetCluster(cluster.ClusterEdgeID). SetOwner(ComponentOwner). AddObject(ienNodeCRBase64) if err := n.ChariotService.InvokeChariotPubsub(ctx, msg, make(map[string]string)); err != nil { return err } } return nil } func NewTerminalLabelService(sqlDB *sql.DB, chariotSvc ChariotService, terminalSvc TerminalService, storeClusterSvc StoreClusterService, labelSvc LabelService) TerminalLabelService { //nolint return &terminalLabelService{ SQLDB: sqlDB, ChariotService: chariotSvc, TerminalService: terminalSvc, StoreClusterService: storeClusterSvc, LabelService: labelSvc, } } // formatCustomLabel returns a formatted key with custom label prefix func formatCustomLabelKey(label *model.Label) string { if !slices.Contains(capabilities.EdgeAutomatedCapabilityLabelTypes, label.Type) { return fmt.Sprintf(v1ien.CustomNodeLabel, label.Key) } return label.Key }