package services import ( "context" "database/sql" "fmt" "github.com/lib/pq" "edge-infra.dev/pkg/edge/api/apierror" sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql" "edge-infra.dev/pkg/edge/api/graph/model" "edge-infra.dev/pkg/edge/api/services/artifacts" sqlquery "edge-infra.dev/pkg/edge/api/sql" "edge-infra.dev/pkg/edge/capabilities" "edge-infra.dev/pkg/edge/compatibility" clusterType "edge-infra.dev/pkg/edge/constants/api/cluster" "edge-infra.dev/pkg/edge/constants/api/fleet" ) //go:generate mockgen -destination=../mocks/mock_label_service.go -package=mocks edge-infra.dev/pkg/edge/api/services LabelService type LabelService interface { CreateLabel(ctx context.Context, label *model.LabelInput) error UpdateLabel(ctx context.Context, label *model.LabelUpdateInput) error GetLabels(ctx context.Context, bannerEdgeID *string) ([]*model.Label, error) GetLabelTypes(ctx context.Context, bannerEdgeID *string) ([]string, error) DeleteLabels(ctx context.Context, labelEdgeID string) error CreateClusterLabel(ctx context.Context, clusterEdgeID, labelEdgeID string) error CreateClusterLabels(ctx context.Context, clusterEdgeID string, labelEdgeIDs []string) error GetClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) ([]*model.ClusterLabel, error) DeleteClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) error GetEdgeClusterLabelKeys(ctx context.Context, clusterEdgeID string) ([]string, error) GetLabel(ctx context.Context, labelEdgeID string) (*model.Label, error) GetLabelTenant(ctx context.Context, labelEdgeID string) (*model.Tenant, error) } func validateLabelInput(ctx context.Context, label *model.LabelInput) error { switch { case label.BannerEdgeID == "": return apierror.New("missing label BannerEdgeID").SetOperationID(ctx) case label.Type == clusterType.LabelType || label.Type == fleet.LabelType: var msg = fmt.Sprintf("label Type is reserved: %q", label.Type) return apierror.New(msg).SetOperationID(ctx) case label.Key == "": return apierror.New("missing label Key").SetOperationID(ctx) case label.Unique && label.Type == "": return apierror.New("missing label Type for Unique label").SetOperationID(ctx) } return nil } type labelService struct { ArtifactsService artifacts.Service SQLDB *sql.DB } func (s *labelService) CreateLabel(ctx context.Context, label *model.LabelInput) error { if err := validateLabelInput(ctx, label); err != nil { return err } _, err := s.SQLDB.ExecContext(ctx, sqlquery.LabelInsertQuery, label.Key, label.Color, label.Visible, label.Editable, label.BannerEdgeID, label.Unique, label.Description, label.Type) if err != nil { return err } return nil } func (s *labelService) UpdateLabel(ctx context.Context, label *model.LabelUpdateInput) error { if err := validateLabelInput(ctx, label.LabelValues); err != nil { return err } tx, err := s.SQLDB.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() //nolint existing, err := s.txGetLabel(ctx, tx, label.LabelEdgeID) if err != nil { return err } if existing.BannerEdgeID == nil || *existing.BannerEdgeID == "" { return apierror.New("updating banner-wide labels is forbidden").SetOperationID(ctx) } _, err = tx.ExecContext(ctx, sqlquery.LabelUpdateQuery, label.LabelValues.Key, label.LabelValues.Color, label.LabelValues.Visible, label.LabelValues.Editable, label.LabelValues.BannerEdgeID, label.LabelValues.Unique, label.LabelValues.Description, label.LabelValues.Type, label.LabelEdgeID) if err != nil { return err } // Uniqueness must be checked when an update could create conflicts in cluster_labels. switch { case label.LabelValues.Unique && !existing.Unique: // the label became unique case label.LabelValues.Unique && label.LabelValues.Type != existing.Type: // the unique label's type changed default: return tx.Commit() } rows, err := tx.QueryContext(ctx, sqlquery.SelectClusterLabelsUniqueConflicts, label.LabelValues.Type) if err != nil { return err } defer rows.Close() //nolint var conflictingClusters []string for rows.Next() { var ceid string if err := rows.Scan(&ceid); err != nil { return err } conflictingClusters = append(conflictingClusters, ceid) } if err := rows.Err(); err != nil { return err } if len(conflictingClusters) > 0 { err = fmt.Errorf("uniqueness conflicts detected for label type %q in the following clusters: %v", label.LabelValues.Type, conflictingClusters) return apierror.Wrap(err).SetOperationID(ctx) } return tx.Commit() } func (s *labelService) GetLabel(ctx context.Context, labelEdgeID string) (*model.Label, error) { label := &model.Label{} row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetLabelQuery, labelEdgeID) err := row.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type) if err != nil { return nil, err } return label, nil } func (s *labelService) txGetLabel(ctx context.Context, tx *sql.Tx, labelEdgeID string) (*model.Label, error) { label := &model.Label{} row := tx.QueryRowContext(ctx, sqlquery.GetLabelQuery, labelEdgeID) err := row.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type) if err != nil { return nil, err } return label, nil } func (s *labelService) txGetClusterFleetVersion(ctx context.Context, tx *sql.Tx, clusterEdgeID string) (string, error) { var fleetVersion string row := tx.QueryRowContext(ctx, sqlquery.GetClusterFleetVersion, clusterEdgeID) err := row.Scan(&fleetVersion) if err != nil { return "", sqlerr.Wrap(err) } return fleetVersion, nil } func (s *labelService) GetLabelTenant(ctx context.Context, labelEdgeID string) (*model.Tenant, error) { tenant := &model.Tenant{} row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetTenantLabelQuery, labelEdgeID) err := row.Scan(&tenant.TenantEdgeID, &tenant.TenantBSLId, &tenant.OrgName) if err != nil { return nil, err } return tenant, nil } func (s *labelService) GetLabels(ctx context.Context, bannerEdgeID *string) ([]*model.Label, error) { var rows *sql.Rows var err error if bannerEdgeID == nil { rows, err = s.SQLDB.QueryContext(ctx, sqlquery.GetLabelsQuery) } else if *bannerEdgeID == "" { rows, err = s.SQLDB.QueryContext(ctx, sqlquery.GetNoBannerLabelsQuery) } else { rows, err = s.SQLDB.QueryContext(ctx, sqlquery.GetLabelsByBannerQuery, bannerEdgeID) } if err != nil { return nil, err } labels := []*model.Label{} defer rows.Close() for rows.Next() { var label model.Label if err = rows.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type); err != nil { return nil, err } labels = append(labels, &label) } if err := rows.Err(); err != nil { return nil, sqlerr.Wrap(err) } return labels, nil } func (s *labelService) GetLabelTypes(ctx context.Context, bannerEdgeID *string) ([]string, error) { var err error var row *sql.Row if bannerEdgeID == nil { row = s.SQLDB.QueryRowContext(ctx, sqlquery.SelectLabelTypes) } else { row = s.SQLDB.QueryRowContext(ctx, sqlquery.SelectLabelTypesByBanner, bannerEdgeID) } labelTypes := []string{} if err = row.Scan(pq.Array(&labelTypes)); err != nil { return nil, err } return labelTypes, nil } func (s *labelService) GetEdgeClusterLabelKeys(ctx context.Context, clusterEdgeID string) ([]string, error) { var labelKeys []string row := s.SQLDB.QueryRowContext(ctx, sqlquery.SelectEdgeLabelKeys, clusterEdgeID) if err := row.Scan(pq.Array(&labelKeys)); err != nil { return nil, err } return labelKeys, nil } func (s *labelService) DeleteLabels(ctx context.Context, labelEdgeID string) error { _, err := s.SQLDB.ExecContext(ctx, sqlquery.LabelDeleteQuery, labelEdgeID) if err != nil { return err } return nil } func (s *labelService) CreateClusterLabel(ctx context.Context, clusterEdgeID, labelEdgeID string) error { tx, err := s.SQLDB.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() //nolint err = s.txCreateClusterLabel(ctx, tx, clusterEdgeID, labelEdgeID) if err != nil { return err } return tx.Commit() } func (s *labelService) txCreateClusterLabel(ctx context.Context, tx *sql.Tx, clusterEdgeID, labelEdgeID string) error { label, err := s.txGetLabel(ctx, tx, labelEdgeID) if err != nil { return err } if label.BannerEdgeID == nil || *label.BannerEdgeID == "" { return apierror.New("banner-wide cluster labels can only be created by the registration service").SetOperationID(ctx) } _, err = tx.ExecContext(ctx, sqlquery.ClusterLabelInsertQuery, clusterEdgeID, labelEdgeID) if err != nil { return err } if label.Unique { var count int err = tx.QueryRowContext(ctx, sqlquery.SelectClusterLabelsUniqueCount, clusterEdgeID, label.Type).Scan(&count) if err != nil { return err } else if count > 1 { return apierror.New("error creating cluster label due to uniqueness conflicts").SetOperationID(ctx) } } if label.Type == capabilities.EdgeCapabilitiesLabel { clusterFleetVersion, err := s.txGetClusterFleetVersion(ctx, tx, clusterEdgeID) if err != nil { return err } //get version edge capability label was introduced in capability := capabilities.GetCapability(label.Key) compatible, err := compatibility.Compare(compatibility.GreaterThanOrEqual, clusterFleetVersion, capability.VersionIntroduced) if err != nil { return err } if !compatible { return apierror.New(fmt.Sprintf("error creating cluster label, label can only be created on clusters on %s or higher", capability.VersionIntroduced)).SetOperationID(ctx) } err = s.addOrDeleteEdgeCapabilityClusterArtifactVersion(ctx, tx, label, clusterEdgeID, "add") if err != nil { return err } } return nil } func (s *labelService) CreateClusterLabels(ctx context.Context, clusterEdgeID string, labelEdgeIDs []string) error { tx, err := s.SQLDB.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() //nolint for _, labelEdgeID := range labelEdgeIDs { err = s.txCreateClusterLabel(ctx, tx, clusterEdgeID, labelEdgeID) if err != nil { return err } } return tx.Commit() } func (s *labelService) GetClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) ([]*model.ClusterLabel, error) { query, vals, err := BuildClusterLabelQuery(sqlquery.GetClusterLabelsQuery, clusterEdgeID, labelEdgeID) if err != nil { return nil, err } rows, err := s.SQLDB.QueryContext(ctx, query, vals...) if err != nil { return nil, err } defer rows.Close() labels := []*model.ClusterLabel{} for rows.Next() { var label model.ClusterLabel if err = rows.Scan(&label.ClusterEdgeID, &label.LabelEdgeID); err != nil { return nil, err } labels = append(labels, &label) } if err := rows.Err(); err != nil { return nil, sqlerr.Wrap(err) } return labels, nil } func BuildClusterLabelQuery(queryString string, clusterEdgeID *string, labelEdgeID *string) (string, []interface{}, error) { var query string var vals []interface{} if clusterEdgeID != nil && labelEdgeID != nil { query = fmt.Sprintf(queryString, sqlquery.WhereClusterEdgeIDAndLabelID) vals = append(vals, clusterEdgeID, labelEdgeID) } else if clusterEdgeID != nil { query = fmt.Sprintf(queryString, sqlquery.WhereClusterEdgeID) vals = append(vals, clusterEdgeID) } else if labelEdgeID != nil { query = fmt.Sprintf(queryString, sqlquery.WhereLabelID) vals = append(vals, labelEdgeID) } else { return "", nil, fmt.Errorf("labelEdgeID and/or ClusterEdgeID must be set") } return query, vals, nil } func (s *labelService) DeleteClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) error { if labelEdgeID == nil || *labelEdgeID == "" { // TODO update the mutation to make labelEdgeID required. return apierror.New("labelEdgeID is required").SetOperationID(ctx) } query, vals, err := BuildClusterLabelQuery(sqlquery.ClusterLabelDeleteQuery, clusterEdgeID, labelEdgeID) if err != nil { return err } tx, err := s.SQLDB.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() //nolint rows, err := tx.QueryContext(ctx, query, vals...) if err != nil { return err } defer rows.Close() var deleted = make(map[string][]string) for rows.Next() { var ceid, leid string if err = rows.Scan(&ceid, &leid); err != nil { return err } deleted[leid] = append(deleted[leid], ceid) } if err = rows.Err(); err != nil { return err } for leid, clusters := range deleted { label, err := s.txGetLabel(ctx, tx, leid) if err != nil { return err } switch label.Type { case clusterType.LabelType, fleet.LabelType: // Prevent accidental deletion of banner-wide labels. They should only be deleted in the following cases: // 1. On cascade, when a cluster is deleted. // 2. On cascade, when a banner-wide label is deleted (through a database migration). return apierror.New("deleting banner-wide cluster labels is forbidden").SetOperationID(ctx) case capabilities.EdgeCapabilitiesLabel: for _, ceid := range clusters { err = s.addOrDeleteEdgeCapabilityClusterArtifactVersion(ctx, tx, label, ceid, "delete") if err != nil { return err } } } } return tx.Commit() } func (s *labelService) addOrDeleteEdgeCapabilityClusterArtifactVersion(ctx context.Context, tx *sql.Tx, label *model.Label, clusterEdgeID string, operation string) error { edgeCapability := capabilities.GetCapability(label.Key) if edgeCapability == nil { return fmt.Errorf("%s is not a valid edge capability", label.Key) } switch operation { case "add": return s.ArtifactsService.AddClusterArtifactVersion(ctx, tx, clusterEdgeID, edgeCapability.ArtifactName) case "delete": return s.ArtifactsService.DeleteClusterArtifactVersion(ctx, tx, clusterEdgeID, edgeCapability.ArtifactName) default: return nil } } func NewLabelService(artifactService artifacts.Service, sqlDB *sql.DB) *labelService { //nolint stupid return &labelService{ ArtifactsService: artifactService, SQLDB: sqlDB, } }