package services

import (
	"context"
	"database/sql"
	"errors"

	"github.com/google/uuid"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
	"edge-infra.dev/pkg/edge/api/utils"
)

// ArtifactRegistryService manages artifact registry entries and cluster
// artifact registry entries (records that reference a cluster and a registry).
//
//go:generate mockgen -destination=../mocks/mock_artifact_registry_service.go -package=mocks edge-infra.dev/pkg/edge/api/services ArtifactRegistryService
type ArtifactRegistryService interface {
	// GetArtifactRegistry retrieves the artifact registry entry with the given ID.
	GetArtifactRegistry(ctx context.Context, registryEdgeID string) (*model.ArtifactRegistry, error)
	// GetArtifactRegistriesForBanner retrieves all artifact registry entries for the given banner.
	GetArtifactRegistriesForBanner(ctx context.Context, bannerEdgeID string) ([]*model.ArtifactRegistry, error)
	// GetArtifactRegistriesForCluster retrieves all artifact registry entries for the given cluster.
	GetArtifactRegistriesForCluster(ctx context.Context, clusterEdgeID string) ([]*model.ArtifactRegistry, error)
	// CreateArtifactRegistryEntry creates a new artifact registry entry.
	CreateArtifactRegistryEntry(ctx context.Context, createArtifactRegistry *model.ArtifactRegistryCreateInput) (*model.ArtifactRegistry, error)
	// UpdateArtifactRegistryEntry updates the artifact registry entry with the given ID.
	UpdateArtifactRegistryEntry(ctx context.Context, registryEdgeID string, updateArtifactRegistry *model.ArtifactRegistryUpdateInput) (*model.ArtifactRegistry, error)
	// DeleteArtifactRegistryEntry deletes the artifact registry entry with the given ID.
	DeleteArtifactRegistryEntry(ctx context.Context, registryEdgeID string) (bool, error)
	// GetClusterArtifactRegistry retrieves the cluster artifact registry entry with the given ID.
	GetClusterArtifactRegistry(ctx context.Context, clusterRegistryEdgeID string) (*model.ClusterArtifactRegistry, error)
	// GetClusterArtifactRegistries retrieves all cluster artifact registry entries for the given cluster.
	GetClusterArtifactRegistries(ctx context.Context, clusterEdgeID string) ([]*model.ClusterArtifactRegistry, error)
	// CreateClusterArtifactRegistryEntry creates a new cluster artifact registry entry.
	CreateClusterArtifactRegistryEntry(ctx context.Context, createClusterArtifactRegistry *model.ClusterArtifactRegistryCreateInput) (*model.ClusterArtifactRegistry, error)
	// DeleteClusterArtifactRegistryEntry deletes the cluster artifact registry entry with the given ID.
	DeleteClusterArtifactRegistryEntry(ctx context.Context, clusterRegistryEdgeID string) (bool, error)
}

// NewArtifactRegistryService returns an ArtifactRegistryService that runs its
// queries against sqlDB.
func NewArtifactRegistryService(sqlDB *sql.DB) ArtifactRegistryService {
	return &artifactRegistryService{
		SQLDB: sqlDB,
	}
}

// artifactRegistryService is the database/sql-backed ArtifactRegistryService.
type artifactRegistryService struct {
	SQLDB *sql.DB
}

// GetArtifactRegistry returns the artifact registry with the given ID. When no
// row matches, the error from row.Scan (sql.ErrNoRows) is returned as-is.
func (s *artifactRegistryService) GetArtifactRegistry(ctx context.Context, registryEdgeID string) (*model.ArtifactRegistry, error) {
	row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetArtifactRegistryByIDQuery, registryEdgeID)
	return s.scanArtifactRegistryRow(row)
}

// scanArtifactRegistryRow reads a single row in the column order
// (registry ID, banner ID, description, URL) into a model.ArtifactRegistry.
func (s *artifactRegistryService) scanArtifactRegistryRow(row *sql.Row) (*model.ArtifactRegistry, error) {
	artifactRegistry := &model.ArtifactRegistry{}
	if err := row.Scan(&artifactRegistry.RegistryEdgeID, &artifactRegistry.BannerEdgeID, &artifactRegistry.Description, &artifactRegistry.URL); err != nil {
		return nil, err
	}
	return artifactRegistry, nil
}

// GetArtifactRegistriesForBanner returns every artifact registry belonging to
// the given banner. The result is an empty (non-nil) slice when none exist.
func (s *artifactRegistryService) GetArtifactRegistriesForBanner(ctx context.Context, bannerEdgeID string) ([]*model.ArtifactRegistry, error) {
	rows, err := s.SQLDB.QueryContext(ctx, sqlquery.GetArtifactRegistriesForBannerQuery, bannerEdgeID)
	if err != nil {
		return nil, err
	}
	return s.scanArtifactRegistryRows(rows)
}

// GetArtifactRegistriesForCluster returns every artifact registry associated
// with the given cluster. The result is an empty (non-nil) slice when none exist.
func (s *artifactRegistryService) GetArtifactRegistriesForCluster(ctx context.Context, clusterEdgeID string) ([]*model.ArtifactRegistry, error) {
	rows, err := s.SQLDB.QueryContext(ctx, sqlquery.GetArtifactRegistriesForClusterQuery, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	return s.scanArtifactRegistryRows(rows)
}

// scanArtifactRegistryRows drains rows into a slice of model.ArtifactRegistry
// and always closes rows. Iteration errors are wrapped with sqlerr.Wrap.
func (s *artifactRegistryService) scanArtifactRegistryRows(rows *sql.Rows) ([]*model.ArtifactRegistry, error) {
	// Close explicitly: rows only auto-close when Next returns false, so an
	// early return on a Scan error would otherwise leak the connection.
	defer rows.Close()

	// Non-nil so callers get an empty list rather than nil when nothing matches.
	artifactRegistries := []*model.ArtifactRegistry{}
	for rows.Next() {
		artifactRegistry := &model.ArtifactRegistry{}
		if err := rows.Scan(&artifactRegistry.RegistryEdgeID, &artifactRegistry.BannerEdgeID, &artifactRegistry.Description, &artifactRegistry.URL); err != nil {
			return nil, err
		}
		artifactRegistries = append(artifactRegistries, artifactRegistry)
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return artifactRegistries, nil
}

// CreateArtifactRegistryEntry validates the input, inserts a new artifact
// registry with a freshly generated UUID, and returns the created entry.
// On any failure the transaction is rolled back and a nil entry is returned.
func (s *artifactRegistryService) CreateArtifactRegistryEntry(ctx context.Context, createArtifactRegistry *model.ArtifactRegistryCreateInput) (artifactRegistry *model.ArtifactRegistry, err error) {
	if err := utils.ValidateArtifactRegistryCreateInput(createArtifactRegistry); err != nil {
		return nil, err
	}

	transaction, err := s.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// The deferred closure observes the final value of the named err result.
	defer func() { err = rollbackOnError(transaction, err) }()

	created := utils.CreateArtifactRegistryModel(
		uuid.NewString(),
		createArtifactRegistry.BannerEdgeID,
		createArtifactRegistry.URL,
		createArtifactRegistry.Description,
	)

	args := []any{
		created.RegistryEdgeID,
		created.BannerEdgeID,
		created.Description,
		created.URL,
	}
	if _, err = transaction.ExecContext(ctx, sqlquery.ArtifactRegistryCreateQuery, args...); err != nil {
		return nil, err
	}
	if err = transaction.Commit(); err != nil {
		return nil, err
	}
	return created, nil
}

// UpdateArtifactRegistryEntry validates the input, loads the current entry and,
// if anything changed, writes the updated description and URL. When no change
// is needed the current entry is returned without opening a transaction.
// On any failure a nil entry is returned.
func (s *artifactRegistryService) UpdateArtifactRegistryEntry(ctx context.Context, registryEdgeID string, updateArtifactRegistry *model.ArtifactRegistryUpdateInput) (artifactRegistry *model.ArtifactRegistry, err error) {
	if err := utils.ValidateArtifactRegistryUpdateInput(updateArtifactRegistry); err != nil {
		return nil, err
	}

	currentArtifactRegistry, err := s.GetArtifactRegistry(ctx, registryEdgeID)
	if err != nil {
		return nil, err
	}

	// Return the current entry if there are no changes to be made.
	if !utils.ArtifactRegistryNeedsUpdating(currentArtifactRegistry, updateArtifactRegistry) {
		return currentArtifactRegistry, nil
	}

	transaction, err := s.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// The deferred closure observes the final value of the named err result.
	defer func() { err = rollbackOnError(transaction, err) }()

	updated := utils.UpdateArtifactRegistry(currentArtifactRegistry, updateArtifactRegistry)

	args := []any{
		updated.Description,
		updated.URL,
		registryEdgeID,
	}
	if _, err = transaction.ExecContext(ctx, sqlquery.ArtifactRegistryUpdateQuery, args...); err != nil {
		return nil, err
	}
	if err = transaction.Commit(); err != nil {
		return nil, err
	}
	return updated, nil
}

// DeleteArtifactRegistryEntry deletes the artifact registry with the given ID.
// It returns (false, nil) when no such entry exists and (true, nil) after a
// successful delete.
func (s *artifactRegistryService) DeleteArtifactRegistryEntry(ctx context.Context, registryEdgeID string) (bool, error) {
	exists, err := s.artifactRegistryEntryExists(ctx, registryEdgeID)
	if err != nil {
		return false, err
	}
	// Do nothing if the entry does not exist.
	if !exists {
		return false, nil
	}

	if _, err := s.SQLDB.ExecContext(ctx, sqlquery.ArtifactRegistryDeleteQuery, registryEdgeID); err != nil {
		return false, err
	}
	return true, nil
}

// artifactRegistryEntryExists reports whether an artifact registry with the
// given ID exists, using the boolean returned by ArtifactRegistryExistsQuery.
func (s *artifactRegistryService) artifactRegistryEntryExists(ctx context.Context, registryEdgeID string) (bool, error) {
	var exists bool
	row := s.SQLDB.QueryRowContext(ctx, sqlquery.ArtifactRegistryExistsQuery, registryEdgeID)
	if err := row.Scan(&exists); err != nil {
		return false, err
	}
	return exists, nil
}

// GetClusterArtifactRegistry returns the cluster artifact registry with the
// given ID. When no row matches, sql.ErrNoRows from row.Scan is returned as-is.
func (s *artifactRegistryService) GetClusterArtifactRegistry(ctx context.Context, clusterRegistryEdgeID string) (*model.ClusterArtifactRegistry, error) {
	row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetClusterArtifactRegistryByIDQuery, clusterRegistryEdgeID)
	return s.scanClusterArtifactRegistryRow(row)
}

// scanClusterArtifactRegistryRow reads a single row in the column order
// (cluster registry ID, cluster ID, registry ID) into a model.ClusterArtifactRegistry.
func (s *artifactRegistryService) scanClusterArtifactRegistryRow(row *sql.Row) (*model.ClusterArtifactRegistry, error) {
	clusterArtifactRegistry := &model.ClusterArtifactRegistry{}
	if err := row.Scan(&clusterArtifactRegistry.ClusterRegistryEdgeID, &clusterArtifactRegistry.ClusterEdgeID, &clusterArtifactRegistry.RegistryEdgeID); err != nil {
		return nil, err
	}
	return clusterArtifactRegistry, nil
}

// GetClusterArtifactRegistries returns every cluster artifact registry entry
// for the given cluster. The result is an empty (non-nil) slice when none exist.
func (s *artifactRegistryService) GetClusterArtifactRegistries(ctx context.Context, clusterEdgeID string) ([]*model.ClusterArtifactRegistry, error) {
	rows, err := s.SQLDB.QueryContext(ctx, sqlquery.GetClusterArtifactRegistriesForClusterQuery, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	return s.scanClusterArtifactRegistryRows(rows)
}

// scanClusterArtifactRegistryRows drains rows into a slice of
// model.ClusterArtifactRegistry and always closes rows. Iteration errors are
// wrapped with sqlerr.Wrap.
func (s *artifactRegistryService) scanClusterArtifactRegistryRows(rows *sql.Rows) ([]*model.ClusterArtifactRegistry, error) {
	// Close explicitly: rows only auto-close when Next returns false, so an
	// early return on a Scan error would otherwise leak the connection.
	defer rows.Close()

	// Non-nil so callers get an empty list rather than nil when nothing matches.
	clusterArtifactRegistries := []*model.ClusterArtifactRegistry{}
	for rows.Next() {
		clusterArtifactRegistry := &model.ClusterArtifactRegistry{}
		if err := rows.Scan(&clusterArtifactRegistry.ClusterRegistryEdgeID, &clusterArtifactRegistry.ClusterEdgeID, &clusterArtifactRegistry.RegistryEdgeID); err != nil {
			return nil, err
		}
		clusterArtifactRegistries = append(clusterArtifactRegistries, clusterArtifactRegistry)
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return clusterArtifactRegistries, nil
}

// CreateClusterArtifactRegistryEntry inserts a new cluster artifact registry
// entry with a freshly generated UUID and returns it. On any failure the
// transaction is rolled back and a nil entry is returned.
func (s *artifactRegistryService) CreateClusterArtifactRegistryEntry(ctx context.Context, createClusterArtifactRegistry *model.ClusterArtifactRegistryCreateInput) (clusterArtifactRegistry *model.ClusterArtifactRegistry, err error) {
	transaction, err := s.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	// The deferred closure observes the final value of the named err result.
	defer func() { err = rollbackOnError(transaction, err) }()

	created := utils.CreateClusterArtifactRegistryModel(
		uuid.NewString(),
		createClusterArtifactRegistry.ClusterEdgeID,
		createClusterArtifactRegistry.RegistryEdgeID,
	)

	args := []any{
		created.ClusterRegistryEdgeID,
		created.ClusterEdgeID,
		created.RegistryEdgeID,
	}
	if _, err = transaction.ExecContext(ctx, sqlquery.ClusterArtifactRegistryCreateQuery, args...); err != nil {
		return nil, err
	}
	if err = transaction.Commit(); err != nil {
		return nil, err
	}
	return created, nil
}

// DeleteClusterArtifactRegistryEntry deletes the cluster artifact registry
// entry with the given ID. It returns (false, nil) when no such entry exists
// and (true, nil) after a successful delete.
func (s *artifactRegistryService) DeleteClusterArtifactRegistryEntry(ctx context.Context, clusterRegistryEdgeID string) (bool, error) {
	exists, err := s.clusterArtifactRegistryEntryExists(ctx, clusterRegistryEdgeID)
	if err != nil {
		return false, err
	}
	// Do nothing if the entry does not exist.
	if !exists {
		return false, nil
	}

	if _, err := s.SQLDB.ExecContext(ctx, sqlquery.ClusterArtifactRegistryDeleteQuery, clusterRegistryEdgeID); err != nil {
		return false, err
	}
	return true, nil
}

// clusterArtifactRegistryEntryExists reports whether a cluster artifact
// registry entry with the given ID exists, using the boolean returned by
// ClusterArtifactRegistryExistsQuery.
func (s *artifactRegistryService) clusterArtifactRegistryEntryExists(ctx context.Context, clusterRegistryEdgeID string) (bool, error) {
	var exists bool
	row := s.SQLDB.QueryRowContext(ctx, sqlquery.ClusterArtifactRegistryExistsQuery, clusterRegistryEdgeID)
	if err := row.Scan(&exists); err != nil {
		return false, err
	}
	return exists, nil
}

// rollbackOnError rolls back transaction when err is non-nil and returns err,
// joined with the rollback error if the rollback itself failed. It returns nil
// when err is nil.
//
// sql.ErrTxDone from Rollback is ignored: it only means the transaction has
// already ended (for example after a failed Commit, or when the context was
// canceled), so joining it would add noise and make
// errors.Is(err, sql.ErrTxDone) misleadingly true.
func rollbackOnError(transaction *sql.Tx, err error) error {
	if err == nil {
		return nil
	}
	if rbErr := transaction.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
		return errors.Join(err, rbErr)
	}
	return err
}