...

Source file src/edge-infra.dev/pkg/edge/api/services/label_service.go

Documentation: edge-infra.dev/pkg/edge/api/services

     1  package services
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  
     8  	"github.com/lib/pq"
     9  
    10  	"edge-infra.dev/pkg/edge/api/apierror"
    11  	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
    12  	"edge-infra.dev/pkg/edge/api/graph/model"
    13  	"edge-infra.dev/pkg/edge/api/services/artifacts"
    14  	sqlquery "edge-infra.dev/pkg/edge/api/sql"
    15  	"edge-infra.dev/pkg/edge/capabilities"
    16  	"edge-infra.dev/pkg/edge/compatibility"
    17  	clusterType "edge-infra.dev/pkg/edge/constants/api/cluster"
    18  	"edge-infra.dev/pkg/edge/constants/api/fleet"
    19  )
    20  
    21  //go:generate mockgen -destination=../mocks/mock_label_service.go -package=mocks edge-infra.dev/pkg/edge/api/services LabelService
    22  type LabelService interface {
    23  	CreateLabel(ctx context.Context, label *model.LabelInput) error
    24  	UpdateLabel(ctx context.Context, label *model.LabelUpdateInput) error
    25  	GetLabels(ctx context.Context, bannerEdgeID *string) ([]*model.Label, error)
    26  	GetLabelTypes(ctx context.Context, bannerEdgeID *string) ([]string, error)
    27  	DeleteLabels(ctx context.Context, labelEdgeID string) error
    28  	CreateClusterLabel(ctx context.Context, clusterEdgeID, labelEdgeID string) error
    29  	CreateClusterLabels(ctx context.Context, clusterEdgeID string, labelEdgeIDs []string) error
    30  	GetClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) ([]*model.ClusterLabel, error)
    31  	DeleteClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) error
    32  	GetEdgeClusterLabelKeys(ctx context.Context, clusterEdgeID string) ([]string, error)
    33  	GetLabel(ctx context.Context, labelEdgeID string) (*model.Label, error)
    34  	GetLabelTenant(ctx context.Context, labelEdgeID string) (*model.Tenant, error)
    35  }
    36  
    37  func validateLabelInput(ctx context.Context, label *model.LabelInput) error {
    38  	switch {
    39  	case label.BannerEdgeID == "":
    40  		return apierror.New("missing label BannerEdgeID").SetOperationID(ctx)
    41  	case label.Type == clusterType.LabelType || label.Type == fleet.LabelType:
    42  		var msg = fmt.Sprintf("label Type is reserved: %q", label.Type)
    43  		return apierror.New(msg).SetOperationID(ctx)
    44  	case label.Key == "":
    45  		return apierror.New("missing label Key").SetOperationID(ctx)
    46  	case label.Unique && label.Type == "":
    47  		return apierror.New("missing label Type for Unique label").SetOperationID(ctx)
    48  	}
    49  	return nil
    50  }
    51  
// labelService is the database-backed implementation of LabelService; its
// methods run against SQLDB and use ArtifactsService for edge-capability labels.
type labelService struct {
	// ArtifactsService adds or deletes cluster artifact versions when
	// edge-capability labels are attached to or detached from clusters.
	ArtifactsService artifacts.Service
	// SQLDB is the handle every label query and transaction runs against.
	SQLDB            *sql.DB
}
    56  
    57  func (s *labelService) CreateLabel(ctx context.Context, label *model.LabelInput) error {
    58  	if err := validateLabelInput(ctx, label); err != nil {
    59  		return err
    60  	}
    61  	_, err := s.SQLDB.ExecContext(ctx, sqlquery.LabelInsertQuery, label.Key, label.Color, label.Visible, label.Editable, label.BannerEdgeID, label.Unique, label.Description, label.Type)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	return nil
    66  }
    67  
    68  func (s *labelService) UpdateLabel(ctx context.Context, label *model.LabelUpdateInput) error {
    69  	if err := validateLabelInput(ctx, label.LabelValues); err != nil {
    70  		return err
    71  	}
    72  
    73  	tx, err := s.SQLDB.BeginTx(ctx, nil)
    74  	if err != nil {
    75  		return err
    76  	}
    77  	defer tx.Rollback() //nolint
    78  
    79  	existing, err := s.txGetLabel(ctx, tx, label.LabelEdgeID)
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	if existing.BannerEdgeID == nil || *existing.BannerEdgeID == "" {
    85  		return apierror.New("updating banner-wide labels is forbidden").SetOperationID(ctx)
    86  	}
    87  
    88  	_, err = tx.ExecContext(ctx, sqlquery.LabelUpdateQuery, label.LabelValues.Key, label.LabelValues.Color,
    89  		label.LabelValues.Visible, label.LabelValues.Editable, label.LabelValues.BannerEdgeID, label.LabelValues.Unique,
    90  		label.LabelValues.Description, label.LabelValues.Type, label.LabelEdgeID)
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	// Uniqueness must be checked when an update could create conflicts in cluster_labels.
    96  	switch {
    97  	case label.LabelValues.Unique && !existing.Unique: // the label became unique
    98  	case label.LabelValues.Unique && label.LabelValues.Type != existing.Type: // the unique label's type changed
    99  	default:
   100  		return tx.Commit()
   101  	}
   102  
   103  	rows, err := tx.QueryContext(ctx, sqlquery.SelectClusterLabelsUniqueConflicts, label.LabelValues.Type)
   104  	if err != nil {
   105  		return err
   106  	}
   107  	defer rows.Close() //nolint
   108  
   109  	var conflictingClusters []string
   110  	for rows.Next() {
   111  		var ceid string
   112  		if err := rows.Scan(&ceid); err != nil {
   113  			return err
   114  		}
   115  		conflictingClusters = append(conflictingClusters, ceid)
   116  	}
   117  	if err := rows.Err(); err != nil {
   118  		return err
   119  	}
   120  
   121  	if len(conflictingClusters) > 0 {
   122  		err = fmt.Errorf("uniqueness conflicts detected for label type %q in the following clusters: %v", label.LabelValues.Type, conflictingClusters)
   123  		return apierror.Wrap(err).SetOperationID(ctx)
   124  	}
   125  
   126  	return tx.Commit()
   127  }
   128  
   129  func (s *labelService) GetLabel(ctx context.Context, labelEdgeID string) (*model.Label, error) {
   130  	label := &model.Label{}
   131  	row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetLabelQuery, labelEdgeID)
   132  	err := row.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	return label, nil
   137  }
   138  
   139  func (s *labelService) txGetLabel(ctx context.Context, tx *sql.Tx, labelEdgeID string) (*model.Label, error) {
   140  	label := &model.Label{}
   141  	row := tx.QueryRowContext(ctx, sqlquery.GetLabelQuery, labelEdgeID)
   142  	err := row.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	return label, nil
   147  }
   148  
   149  func (s *labelService) txGetClusterFleetVersion(ctx context.Context, tx *sql.Tx, clusterEdgeID string) (string, error) {
   150  	var fleetVersion string
   151  	row := tx.QueryRowContext(ctx, sqlquery.GetClusterFleetVersion, clusterEdgeID)
   152  	err := row.Scan(&fleetVersion)
   153  	if err != nil {
   154  		return "", sqlerr.Wrap(err)
   155  	}
   156  	return fleetVersion, nil
   157  }
   158  
   159  func (s *labelService) GetLabelTenant(ctx context.Context, labelEdgeID string) (*model.Tenant, error) {
   160  	tenant := &model.Tenant{}
   161  	row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetTenantLabelQuery, labelEdgeID)
   162  	err := row.Scan(&tenant.TenantEdgeID, &tenant.TenantBSLId, &tenant.OrgName)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	return tenant, nil
   167  }
   168  
   169  func (s *labelService) GetLabels(ctx context.Context, bannerEdgeID *string) ([]*model.Label, error) {
   170  	var rows *sql.Rows
   171  	var err error
   172  	if bannerEdgeID == nil {
   173  		rows, err = s.SQLDB.QueryContext(ctx, sqlquery.GetLabelsQuery)
   174  	} else if *bannerEdgeID == "" {
   175  		rows, err = s.SQLDB.QueryContext(ctx, sqlquery.GetNoBannerLabelsQuery)
   176  	} else {
   177  		rows, err = s.SQLDB.QueryContext(ctx, sqlquery.GetLabelsByBannerQuery, bannerEdgeID)
   178  	}
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	labels := []*model.Label{}
   183  	defer rows.Close()
   184  	for rows.Next() {
   185  		var label model.Label
   186  		if err = rows.Scan(&label.LabelEdgeID, &label.Key, &label.Color, &label.Visible, &label.Editable, &label.BannerEdgeID, &label.Unique, &label.Description, &label.Type); err != nil {
   187  			return nil, err
   188  		}
   189  		labels = append(labels, &label)
   190  	}
   191  	if err := rows.Err(); err != nil {
   192  		return nil, sqlerr.Wrap(err)
   193  	}
   194  	return labels, nil
   195  }
   196  
   197  func (s *labelService) GetLabelTypes(ctx context.Context, bannerEdgeID *string) ([]string, error) {
   198  	var err error
   199  	var row *sql.Row
   200  	if bannerEdgeID == nil {
   201  		row = s.SQLDB.QueryRowContext(ctx, sqlquery.SelectLabelTypes)
   202  	} else {
   203  		row = s.SQLDB.QueryRowContext(ctx, sqlquery.SelectLabelTypesByBanner, bannerEdgeID)
   204  	}
   205  	labelTypes := []string{}
   206  	if err = row.Scan(pq.Array(&labelTypes)); err != nil {
   207  		return nil, err
   208  	}
   209  	return labelTypes, nil
   210  }
   211  
   212  func (s *labelService) GetEdgeClusterLabelKeys(ctx context.Context, clusterEdgeID string) ([]string, error) {
   213  	var labelKeys []string
   214  	row := s.SQLDB.QueryRowContext(ctx, sqlquery.SelectEdgeLabelKeys, clusterEdgeID)
   215  	if err := row.Scan(pq.Array(&labelKeys)); err != nil {
   216  		return nil, err
   217  	}
   218  	return labelKeys, nil
   219  }
   220  
   221  func (s *labelService) DeleteLabels(ctx context.Context, labelEdgeID string) error {
   222  	_, err := s.SQLDB.ExecContext(ctx, sqlquery.LabelDeleteQuery, labelEdgeID)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	return nil
   227  }
   228  
   229  func (s *labelService) CreateClusterLabel(ctx context.Context, clusterEdgeID, labelEdgeID string) error {
   230  	tx, err := s.SQLDB.BeginTx(ctx, nil)
   231  	if err != nil {
   232  		return err
   233  	}
   234  	defer tx.Rollback() //nolint
   235  	err = s.txCreateClusterLabel(ctx, tx, clusterEdgeID, labelEdgeID)
   236  	if err != nil {
   237  		return err
   238  	}
   239  	return tx.Commit()
   240  }
   241  
   242  func (s *labelService) txCreateClusterLabel(ctx context.Context, tx *sql.Tx, clusterEdgeID, labelEdgeID string) error {
   243  	label, err := s.txGetLabel(ctx, tx, labelEdgeID)
   244  	if err != nil {
   245  		return err
   246  	}
   247  
   248  	if label.BannerEdgeID == nil || *label.BannerEdgeID == "" {
   249  		return apierror.New("banner-wide cluster labels can only be created by the registration service").SetOperationID(ctx)
   250  	}
   251  
   252  	_, err = tx.ExecContext(ctx, sqlquery.ClusterLabelInsertQuery, clusterEdgeID, labelEdgeID)
   253  	if err != nil {
   254  		return err
   255  	}
   256  
   257  	if label.Unique {
   258  		var count int
   259  		err = tx.QueryRowContext(ctx, sqlquery.SelectClusterLabelsUniqueCount, clusterEdgeID, label.Type).Scan(&count)
   260  		if err != nil {
   261  			return err
   262  		} else if count > 1 {
   263  			return apierror.New("error creating cluster label due to uniqueness conflicts").SetOperationID(ctx)
   264  		}
   265  	}
   266  
   267  	if label.Type == capabilities.EdgeCapabilitiesLabel {
   268  		clusterFleetVersion, err := s.txGetClusterFleetVersion(ctx, tx, clusterEdgeID)
   269  		if err != nil {
   270  			return err
   271  		}
   272  
   273  		//get version edge capability label was introduced in
   274  		capability := capabilities.GetCapability(label.Key)
   275  		compatible, err := compatibility.Compare(compatibility.GreaterThanOrEqual, clusterFleetVersion, capability.VersionIntroduced)
   276  		if err != nil {
   277  			return err
   278  		}
   279  		if !compatible {
   280  			return apierror.New(fmt.Sprintf("error creating cluster label, label can only be created on clusters on %s or higher", capability.VersionIntroduced)).SetOperationID(ctx)
   281  		}
   282  
   283  		err = s.addOrDeleteEdgeCapabilityClusterArtifactVersion(ctx, tx, label, clusterEdgeID, "add")
   284  		if err != nil {
   285  			return err
   286  		}
   287  	}
   288  	return nil
   289  }
   290  
   291  func (s *labelService) CreateClusterLabels(ctx context.Context, clusterEdgeID string, labelEdgeIDs []string) error {
   292  	tx, err := s.SQLDB.BeginTx(ctx, nil)
   293  	if err != nil {
   294  		return err
   295  	}
   296  	defer tx.Rollback() //nolint
   297  	for _, labelEdgeID := range labelEdgeIDs {
   298  		err = s.txCreateClusterLabel(ctx, tx, clusterEdgeID, labelEdgeID)
   299  		if err != nil {
   300  			return err
   301  		}
   302  	}
   303  	return tx.Commit()
   304  }
   305  
   306  func (s *labelService) GetClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) ([]*model.ClusterLabel, error) {
   307  	query, vals, err := BuildClusterLabelQuery(sqlquery.GetClusterLabelsQuery, clusterEdgeID, labelEdgeID)
   308  	if err != nil {
   309  		return nil, err
   310  	}
   311  	rows, err := s.SQLDB.QueryContext(ctx, query, vals...)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  	defer rows.Close()
   316  
   317  	labels := []*model.ClusterLabel{}
   318  	for rows.Next() {
   319  		var label model.ClusterLabel
   320  		if err = rows.Scan(&label.ClusterEdgeID, &label.LabelEdgeID); err != nil {
   321  			return nil, err
   322  		}
   323  		labels = append(labels, &label)
   324  	}
   325  	if err := rows.Err(); err != nil {
   326  		return nil, sqlerr.Wrap(err)
   327  	}
   328  	return labels, nil
   329  }
   330  
   331  func BuildClusterLabelQuery(queryString string, clusterEdgeID *string, labelEdgeID *string) (string, []interface{}, error) {
   332  	var query string
   333  	var vals []interface{}
   334  	if clusterEdgeID != nil && labelEdgeID != nil {
   335  		query = fmt.Sprintf(queryString, sqlquery.WhereClusterEdgeIDAndLabelID)
   336  		vals = append(vals, clusterEdgeID, labelEdgeID)
   337  	} else if clusterEdgeID != nil {
   338  		query = fmt.Sprintf(queryString, sqlquery.WhereClusterEdgeID)
   339  		vals = append(vals, clusterEdgeID)
   340  	} else if labelEdgeID != nil {
   341  		query = fmt.Sprintf(queryString, sqlquery.WhereLabelID)
   342  		vals = append(vals, labelEdgeID)
   343  	} else {
   344  		return "", nil, fmt.Errorf("labelEdgeID and/or ClusterEdgeID must be set")
   345  	}
   346  	return query, vals, nil
   347  }
   348  
   349  func (s *labelService) DeleteClusterLabels(ctx context.Context, clusterEdgeID, labelEdgeID *string) error {
   350  	if labelEdgeID == nil || *labelEdgeID == "" {
   351  		// TODO update the mutation to make labelEdgeID required.
   352  		return apierror.New("labelEdgeID is required").SetOperationID(ctx)
   353  	}
   354  
   355  	query, vals, err := BuildClusterLabelQuery(sqlquery.ClusterLabelDeleteQuery, clusterEdgeID, labelEdgeID)
   356  	if err != nil {
   357  		return err
   358  	}
   359  	tx, err := s.SQLDB.BeginTx(ctx, nil)
   360  	if err != nil {
   361  		return err
   362  	}
   363  	defer tx.Rollback() //nolint
   364  
   365  	rows, err := tx.QueryContext(ctx, query, vals...)
   366  	if err != nil {
   367  		return err
   368  	}
   369  	defer rows.Close()
   370  
   371  	var deleted = make(map[string][]string)
   372  	for rows.Next() {
   373  		var ceid, leid string
   374  		if err = rows.Scan(&ceid, &leid); err != nil {
   375  			return err
   376  		}
   377  		deleted[leid] = append(deleted[leid], ceid)
   378  	}
   379  	if err = rows.Err(); err != nil {
   380  		return err
   381  	}
   382  
   383  	for leid, clusters := range deleted {
   384  		label, err := s.txGetLabel(ctx, tx, leid)
   385  		if err != nil {
   386  			return err
   387  		}
   388  
   389  		switch label.Type {
   390  		case clusterType.LabelType, fleet.LabelType:
   391  			// Prevent accidental deletion of banner-wide labels. They should only be deleted in the following cases:
   392  			// 1. On cascade, when a cluster is deleted.
   393  			// 2. On cascade, when a banner-wide label is deleted (through a database migration).
   394  			return apierror.New("deleting banner-wide cluster labels is forbidden").SetOperationID(ctx)
   395  		case capabilities.EdgeCapabilitiesLabel:
   396  			for _, ceid := range clusters {
   397  				err = s.addOrDeleteEdgeCapabilityClusterArtifactVersion(ctx, tx, label, ceid, "delete")
   398  				if err != nil {
   399  					return err
   400  				}
   401  			}
   402  		}
   403  	}
   404  	return tx.Commit()
   405  }
   406  
   407  func (s *labelService) addOrDeleteEdgeCapabilityClusterArtifactVersion(ctx context.Context, tx *sql.Tx, label *model.Label, clusterEdgeID string, operation string) error {
   408  	edgeCapability := capabilities.GetCapability(label.Key)
   409  	if edgeCapability == nil {
   410  		return fmt.Errorf("%s is not a valid edge capability", label.Key)
   411  	}
   412  	switch operation {
   413  	case "add":
   414  		return s.ArtifactsService.AddClusterArtifactVersion(ctx, tx, clusterEdgeID, edgeCapability.ArtifactName)
   415  	case "delete":
   416  		return s.ArtifactsService.DeleteClusterArtifactVersion(ctx, tx, clusterEdgeID, edgeCapability.ArtifactName)
   417  	default:
   418  		return nil
   419  	}
   420  }
   421  
   422  func NewLabelService(artifactService artifacts.Service, sqlDB *sql.DB) *labelService { //nolint stupid
   423  	return &labelService{
   424  		ArtifactsService: artifactService,
   425  		SQLDB:            sqlDB,
   426  	}
   427  }
   428  

View as plain text