...

Source file src/edge-infra.dev/pkg/edge/api/services/channels/service.go

Documentation: edge-infra.dev/pkg/edge/api/services/channels

     1  package channels
     2  
import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"edge-infra.dev/pkg/edge/api/graph/model"
	apiServices "edge-infra.dev/pkg/edge/api/services"

	"github.com/google/uuid"
)
    13  
    14  var (
    15  	// ErrChannelDoesNotExist is returned when a provided channel does not exist.
    16  	//
    17  	// It is returned by the following service calls:
    18  	// - DeleteChannel
    19  	// - GetChannel
    20  	// - GetChannels
    21  	// - GetChannelByName
    22  	// - GetChannelsByName
    23  	// - GetChannelsForTeam
    24  	// - UpdateChannel
    25  	//
    26  	// TODO: check foreign key constraint violations in other service calls to see if the channel does not exist.
    27  	ErrChannelDoesNotExist = fmt.Errorf("channel does not exist")
    28  
    29  	// ErrNoBannerChannels is returned by GetBannerChannels if no channels are referenced by the banner.
    30  	ErrNoBannerChannels = fmt.Errorf("no channels exist in the banner")
    31  )
    32  
// ChannelService runs the channel service calls (see Service) against a SQL database.
// Construct it with NewChannelService.
type ChannelService struct {
	// db is the handle that every query and transaction in this file runs against.
	db               *sql.DB
	// foremanProjectID is stored by NewChannelService.
	// NOTE(review): not referenced in the visible part of this file; presumably used elsewhere in the package.
	foremanProjectID string
	// ChariotService is the Chariot client stored by NewChannelService.
	// NOTE(review): not referenced in the visible part of this file; presumably used elsewhere in the package.
	ChariotService   apiServices.ChariotService
}
    38  
    39  func NewChannelService(db *sql.DB, foremanProjectID string, chariotService apiServices.ChariotService) *ChannelService {
    40  	return &ChannelService{
    41  		db:               db,
    42  		foremanProjectID: foremanProjectID,
    43  		ChariotService:   chariotService,
    44  	}
    45  }
    46  
    47  type Service interface {
    48  	CreateChannel(ctx context.Context, channel Channel) (Channel, error)
    49  	DeleteChannel(ctx context.Context, channelID uuid.UUID, force bool) (Channel, error)
    50  	ExpireRotatedChannelKeyVersionNow(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error)
    51  	GetChannel(ctx context.Context, channelID uuid.UUID) (Channel, error)
    52  	GetChannelByName(ctx context.Context, name string) (Channel, error)
    53  	GetChannels(ctx context.Context, channelIDs ...uuid.UUID) ([]Channel, error)
    54  	GetChannelsByName(ctx context.Context, names ...string) ([]Channel, error)
    55  	GetChannelsForTeam(ctx context.Context, team string) ([]Channel, error)
    56  	GetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID) ([]Channel, error)
    57  	GetLatestChannelKeyVersion(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error)
    58  	RotateChannelNow(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error)
    59  	UpdateChannel(ctx context.Context, channelID uuid.UUID, request ChannelUpdateRequest) (Channel, error)
    60  	CreateChannelIAM(ctx context.Context, channelID uuid.UUID, saEmail string) (*model.ChannelIAMPolicy, error)
    61  	CreateChannelKeyVersion(ctx context.Context, ckv ChannelKeyVersion) (latest ChannelKeyVersion, err error)
    62  	GetChannelsFromHelmConfig(ctx context.Context, configYaml *string) ([]Channel, error)
    63  	CreateHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error
    64  	GetBannerChannels(ctx context.Context, bannerEdgeID uuid.UUID) ([]BannerChannel, error)
    65  	DeleteChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error)
    66  }
    67  
    68  const sqlCreateChannel = `INSERT INTO channels(name, description, team, expire_buffer_duration, rotation_interval_duration)
    69  VALUES ($1, $2, $3, $4, $5)
    70  RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`
    71  
    72  // CreateChannel returns the created Channel with all fields populated.
    73  func (cs *ChannelService) CreateChannel(ctx context.Context, channel Channel) (Channel, error) {
    74  	// TODO: find a hack that ensures the context was passed in by the resolver for a super admin.
    75  	if err := channel.validateCreate(); err != nil {
    76  		return channel, err
    77  	}
    78  
    79  	var row = cs.db.QueryRowContext(ctx, sqlCreateChannel,
    80  		channel.Name,
    81  		channel.Description,
    82  		channel.Team,
    83  		channel.ExpireBufferDuration,
    84  		channel.RotationIntervalDuration,
    85  	)
    86  
    87  	created, err := scanChannel(row)
    88  	if err != nil {
    89  		return created, fmt.Errorf("failed to create channel: %w", err)
    90  	}
    91  	return created, nil
    92  }
    93  
    94  const sqlUpdateChannel = `UPDATE channels
    95  SET (team, description, expire_buffer_duration, rotation_interval_duration) = (
    96    COALESCE($1, team),
    97    COALESCE($2, description),
    98    COALESCE($3, expire_buffer_duration),
    99    COALESCE($4, rotation_interval_duration)
   100  )
   101  WHERE channel_id = $5
   102  RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`
   103  
   104  // UpdateChannel should only be called by the `updateChannel` resolver available to super admins.
   105  //
   106  // TODO:
   107  // Run this in a transaction and recalculate the ChannelKeyVersion's `rotate_at` and `expire_at` times, when their fields have been updated.
   108  // For now, updating `expire_buffer_duration` and `rotation_interval_duration` only affects stuff that happens in the future.
   109  func (cs *ChannelService) UpdateChannel(ctx context.Context, channelID uuid.UUID, request ChannelUpdateRequest) (Channel, error) {
   110  	// TODO: find some hack that ensures the context was passed in by the resolver for a super admin.
   111  	if err := request.validate(); err != nil {
   112  		return Channel{}, err
   113  	}
   114  	var row = cs.db.QueryRowContext(ctx, sqlUpdateChannel,
   115  		request.Team,
   116  		request.Description,
   117  		request.ExpireBufferDuration,
   118  		request.RotationIntervalDuration,
   119  		channelID,
   120  	)
   121  
   122  	updated, err := scanChannel(row)
   123  	if err != nil {
   124  		return updated, fmt.Errorf("failed to update channel: %w", err)
   125  	}
   126  	return updated, err
   127  }
   128  
   129  const sqlDeleteChannel = `DELETE FROM channels
   130  WHERE channel_id = $1
   131  RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`
   132  
   133  const sqlHelmWorkloadsForChannelExists = `SELECT EXISTS (
   134    SELECT 1
   135    FROM helm_workloads_channels
   136    WHERE channel_id = $1
   137  )`
   138  
   139  func (cs *ChannelService) txCanDeleteChannel(ctx context.Context, tx *sql.Tx, channelID uuid.UUID) error {
   140  	var hasHelmWorkloads bool
   141  	var row = tx.QueryRowContext(ctx, sqlHelmWorkloadsForChannelExists, channelID)
   142  	if err := row.Scan(&hasHelmWorkloads); err != nil {
   143  		return fmt.Errorf("failed to check helm workloads used by channel: %w", err)
   144  	} else if hasHelmWorkloads {
   145  		return fmt.Errorf("channel is being used by helm workloads")
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  // DeleteChannel should only be called by the `deleteChannel` resolver available to super admins.
   152  //
   153  // When `force` is false, the DeleteChannel method returns an error if the channel is mapped to any helm workloads.
   154  func (cs *ChannelService) DeleteChannel(ctx context.Context, channelID uuid.UUID, force bool) (Channel, error) {
   155  	// TODO: find some hack that ensures the context was passed in by the resolver for a super admin.
   156  	tx, err := cs.db.BeginTx(ctx, nil)
   157  	if err != nil {
   158  		return Channel{}, err
   159  	}
   160  	defer tx.Rollback() //nolint: errcheck
   161  
   162  	if !force {
   163  		if err := cs.txCanDeleteChannel(ctx, tx, channelID); err != nil {
   164  			return Channel{}, err
   165  		}
   166  	}
   167  
   168  	var row = tx.QueryRowContext(ctx, sqlDeleteChannel, channelID)
   169  	deleted, err := scanChannel(row)
   170  	if err != nil {
   171  		return Channel{}, fmt.Errorf("failed to delete channel: %w", err)
   172  	}
   173  
   174  	if err := tx.Commit(); err != nil {
   175  		return Channel{}, err
   176  	}
   177  	return deleted, nil
   178  }
   179  
   180  const sqlGetChannels = `SELECT channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at
   181  FROM channels`
   182  
   183  const sqlGetChannelsWithIDs = sqlGetChannels + `
   184  WHERE channel_id = ANY ($1)`
   185  
   186  const sqlGetChannelsForTeam = sqlGetChannels + `
   187  WHERE team = $1`
   188  
   189  const sqlGetChannelsByName = sqlGetChannels + `
   190  WHERE name = ANY($1)`
   191  
   192  func (cs *ChannelService) GetChannel(ctx context.Context, channelID uuid.UUID) (Channel, error) {
   193  	var row = cs.db.QueryRowContext(ctx, sqlGetChannelsWithIDs, []uuid.UUID{channelID})
   194  	return scanChannel(row)
   195  }
   196  
   197  func (cs *ChannelService) GetChannelByName(ctx context.Context, name string) (Channel, error) {
   198  	var row = cs.db.QueryRowContext(ctx, sqlGetChannelsByName, []string{name})
   199  	return scanChannel(row)
   200  }
   201  
   202  // GetChannelsByName returns a Channel slice indexed in same order as the provided `names`.
   203  //
   204  // An error is returned if any of the channels do not exist.
   205  func (cs *ChannelService) GetChannelsByName(ctx context.Context, names ...string) ([]Channel, error) {
   206  	if len(names) == 0 {
   207  		return nil, nil
   208  	}
   209  
   210  	rows, err := cs.db.QueryContext(ctx, sqlGetChannelsByName, names)
   211  	if err != nil {
   212  		return nil, fmt.Errorf("failed to query channels by name: %w", err)
   213  	}
   214  	defer rows.Close() //nolint: errcheck
   215  
   216  	var channels = make(map[string]Channel)
   217  	for rows.Next() {
   218  		channel, err := scanChannel(rows)
   219  		if err != nil {
   220  			return nil, err
   221  		}
   222  		channels[channel.Name] = channel
   223  	}
   224  	if err := rows.Err(); err != nil {
   225  		return nil, fmt.Errorf("failed to scan all channels by name: %w", err)
   226  	}
   227  
   228  	var ordered []Channel
   229  	for _, name := range names {
   230  		channel, found := channels[name]
   231  		if !found {
   232  			return nil, fmt.Errorf("%w: %s", ErrChannelDoesNotExist, name)
   233  		}
   234  		ordered = append(ordered, channel)
   235  	}
   236  	return ordered, nil
   237  }
   238  
   239  // GetChannelsFromHelmConfig attempts to parse the YAML helm config, then it queries the configured channels.
   240  //
   241  // If no channels are configured, an empty slice is returned without an error.
   242  //
   243  // An error is returned if:
   244  // - the configYAML is invalid.
   245  // - any of the configured channel names are invalid.
   246  // - any of the configured channels do not exist in the database (returns ErrChannelDoesNotExist).
   247  func (cs *ChannelService) GetChannelsFromHelmConfig(ctx context.Context, configYaml *string) ([]Channel, error) {
   248  	if configYaml == nil {
   249  		return make([]Channel, 0), nil
   250  	}
   251  	config, err := ParseHelmConfigChannels(*configYaml)
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  
   256  	if !config.HasChannels() {
   257  		return nil, nil
   258  	}
   259  
   260  	return cs.GetChannelsByName(ctx, config.Names()...)
   261  }
   262  
   263  // GetChannels returns a Channel slice indexed in same order as the provided `channelIDs`.
   264  // All channels are returned if the `channelIDs` argument is empty.
   265  //
   266  // When a provided channel is not found, GetChannels returns ErrChannelDoesNotExist.
   267  func (cs *ChannelService) GetChannels(ctx context.Context, channelIDs ...uuid.UUID) ([]Channel, error) {
   268  	var args []interface{}
   269  	var stmt = sqlGetChannels
   270  	if len(channelIDs) > 0 {
   271  		args = append(args, channelIDs)
   272  		stmt = sqlGetChannelsWithIDs
   273  	}
   274  
   275  	rows, err := cs.db.QueryContext(ctx, stmt, args...)
   276  	if err != nil {
   277  		return nil, fmt.Errorf("failed to query channels: %w", err)
   278  	}
   279  	defer rows.Close() //nolint: errcheck
   280  
   281  	var channels []Channel
   282  	for rows.Next() {
   283  		channel, err := scanChannel(rows)
   284  		if err != nil {
   285  			return nil, err
   286  		}
   287  		channels = append(channels, channel)
   288  	}
   289  	if err := rows.Err(); err != nil {
   290  		return nil, fmt.Errorf("failed to scan all channels: %w", err)
   291  	}
   292  
   293  	// no need to order the channels when none are provided.
   294  	if len(channelIDs) == 0 {
   295  		return channels, nil
   296  	}
   297  
   298  	// order the channels by the same index as the `channelIDs` argument.
   299  	// this is a nice feature to have in practice, and makes testing easier.
   300  	var m = make(map[uuid.UUID]Channel)
   301  	for _, channel := range channels {
   302  		m[channel.ID] = channel
   303  	}
   304  
   305  	var ordered []Channel
   306  	for _, channelID := range channelIDs {
   307  		channel, found := m[channelID]
   308  		if !found {
   309  			return nil, fmt.Errorf("%w: %s", ErrChannelDoesNotExist, channelID)
   310  		}
   311  		ordered = append(ordered, channel)
   312  	}
   313  	return ordered, nil
   314  }
   315  
   316  // TODO: check with Luc/Lee to see if `Team` should have a UNIQUE constraint like `Name`
   317  func (cs *ChannelService) GetChannelsForTeam(ctx context.Context, team string) ([]Channel, error) {
   318  	rows, err := cs.db.QueryContext(ctx, sqlGetChannelsForTeam, team)
   319  	if err != nil {
   320  		return nil, fmt.Errorf("failed to query channels for team: %w", err)
   321  	}
   322  	defer rows.Close() //nolint: errcheck
   323  
   324  	var channels []Channel
   325  	for rows.Next() {
   326  		channel, err := scanChannel(rows)
   327  		if err != nil {
   328  			return nil, err
   329  		}
   330  		channels = append(channels, channel)
   331  	}
   332  	if err := rows.Err(); err != nil {
   333  		return nil, fmt.Errorf("failed to scan all channels for team: %w", err)
   334  	}
   335  
   336  	if len(channels) == 0 {
   337  		return nil, fmt.Errorf("%w for team: %q", ErrChannelDoesNotExist, team)
   338  	}
   339  	return channels, nil
   340  }
   341  
   342  // GetBannerChannels returns all channels being referenced by this banner.
   343  // - any channel in `helm_workloads_channels` mapped to this banner's helm workloads.
   344  // - any channel in `channels_key_versions` that references this banner.
   345  //
   346  // GetBannerChannels is intended to be used by bannerctl and clusterctl within "banner-infra" clusters.
   347  // Its purpose is to retrieve all the data needed for channel reconciliation, at the banner level, from a single call to the channel service.
   348  // This should substantially simplify their reconcile function, improve its performance, and reduce their potential for errors.
   349  //
   350  // If no channels exist in the banner, then GetBannerChannels returns `ErrNoBannerChannels`
   351  func (cs *ChannelService) GetBannerChannels(ctx context.Context, bannerEdgeID uuid.UUID) ([]BannerChannel, error) {
   352  	tx, err := cs.db.BeginTx(ctx, nil)
   353  	if err != nil {
   354  		return nil, err
   355  	}
   356  	defer tx.Rollback() //nolint: errcheck
   357  
   358  	hwcm, err := cs.txGetHelmWorkloadsChannelsForBanner(ctx, tx, bannerEdgeID)
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  
   363  	ckvm, err := cs.txGetChannelKeyVersionsForBanner(ctx, tx, bannerEdgeID)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  
   368  	// deduplicate the channels being used in this banner
   369  	var dedupMap = make(map[uuid.UUID]struct{})
   370  	for id := range hwcm {
   371  		dedupMap[id] = struct{}{}
   372  	}
   373  	for id := range ckvm {
   374  		dedupMap[id] = struct{}{}
   375  	}
   376  	var dedup []uuid.UUID
   377  	for id := range dedupMap {
   378  		dedup = append(dedup, id)
   379  	}
   380  
   381  	// return early if the banner is not using any channels
   382  	if len(dedup) == 0 {
   383  		return nil, ErrNoBannerChannels
   384  	}
   385  
   386  	// Get all channels for this banner.
   387  	rows, err := tx.QueryContext(ctx, sqlGetChannelsWithIDs, dedup)
   388  	if err != nil {
   389  		return nil, fmt.Errorf("failed to query banner channels: %w", err)
   390  	}
   391  	defer rows.Close() //nolint: errcheck
   392  
   393  	var bannerChannels []BannerChannel
   394  	for rows.Next() {
   395  		channel, err := scanChannel(rows)
   396  		if err != nil {
   397  			return nil, err
   398  		}
   399  		bannerChannels = append(bannerChannels, BannerChannel{
   400  			Channel:     channel,
   401  			KeyVersions: ckvm[channel.ID],
   402  			HelmEdgeIDs: hwcm[channel.ID],
   403  		})
   404  	}
   405  	if err := rows.Err(); err != nil {
   406  		return nil, fmt.Errorf("failed to scan all channels with IDs: %w", err)
   407  	}
   408  	return bannerChannels, nil
   409  }
   410  
   411  const sqlGetHelmWorkloadsChannelsForBanner = `SELECT channel_id, helm_edge_id
   412  FROM helm_workloads_channels
   413  WHERE helm_edge_id = ANY (
   414    SELECT helm_edge_id
   415    FROM helm_workloads
   416    WHERE banner_edge_id = $1
   417  )`
   418  
   419  // txGetHelmWorkloadsChannelsForBanner returns a map[channel_id][]helm_edge_id
   420  func (cs *ChannelService) txGetHelmWorkloadsChannelsForBanner(ctx context.Context, tx *sql.Tx, bannerEdgeID uuid.UUID) (map[uuid.UUID][]uuid.UUID, error) {
   421  	rows, err := tx.QueryContext(ctx, sqlGetHelmWorkloadsChannelsForBanner, bannerEdgeID)
   422  	if err != nil {
   423  		return nil, fmt.Errorf("failed to query helm workloads channels for banner: %w", err)
   424  	}
   425  	defer rows.Close() //nolint: errcheck
   426  
   427  	var m = make(map[uuid.UUID][]uuid.UUID)
   428  	for rows.Next() {
   429  		var channelID, helmEdgeID uuid.UUID
   430  		if err := rows.Scan(&channelID, &helmEdgeID); err != nil {
   431  			return nil, fmt.Errorf("failed to scan helm workloads channels for banner: %w", err)
   432  		}
   433  		m[channelID] = append(m[channelID], helmEdgeID)
   434  	}
   435  	if err := rows.Err(); err != nil {
   436  		return nil, fmt.Errorf("failed to scan all helm workloads channels for banner: %w", err)
   437  	}
   438  	return m, nil
   439  }
   440  
   441  const sqlGetChannelKeyVersionsForBanner = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at
   442  FROM channels_key_versions
   443  WHERE banner_edge_id = $1
   444  ORDER BY version DESC`
   445  
   446  // txGetUnexpiredChannelKeyVersionsForBanner returns a map[channel_id][]ChannelKeyVersion
   447  func (cs *ChannelService) txGetChannelKeyVersionsForBanner(ctx context.Context, tx *sql.Tx, bannerEdgeID uuid.UUID) (map[uuid.UUID][]ChannelKeyVersion, error) {
   448  	rows, err := tx.QueryContext(ctx, sqlGetChannelKeyVersionsForBanner, bannerEdgeID)
   449  	if err != nil {
   450  		return nil, fmt.Errorf("failed to query channel key versions for banner: %w", err)
   451  	}
   452  	defer rows.Close() //nolint: errcheck
   453  
   454  	var m = make(map[uuid.UUID][]ChannelKeyVersion)
   455  	for rows.Next() {
   456  		ckv, err := scanChannelKeyVersion(rows)
   457  		if err != nil {
   458  			return nil, err
   459  		}
   460  		m[ckv.ChannelID] = append(m[ckv.ChannelID], ckv)
   461  	}
   462  	if err := rows.Err(); err != nil {
   463  		return nil, fmt.Errorf("failed to scan all unexpired channels key versions for banner: %w", err)
   464  	}
   465  	return m, nil
   466  }
   467  
   468  const sqlCreateHelmWorkloadChannel = `INSERT INTO helm_workloads_channels(channel_id, helm_edge_id)
   469  VALUES ($1, $2)
   470  ON CONFLICT 
   471    ON CONSTRAINT unique_channel_id_helm_edge_id
   472    DO NOTHING`
   473  
   474  func (cs *ChannelService) CreateHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error {
   475  	if len(channelIDs) == 0 {
   476  		return fmt.Errorf("channelIDs must not be empty")
   477  	}
   478  
   479  	tx, err := cs.db.BeginTx(ctx, nil)
   480  	if err != nil {
   481  		return err
   482  	}
   483  	defer tx.Rollback() //nolint: errcheck
   484  
   485  	for _, channelID := range channelIDs {
   486  		_, err := tx.ExecContext(ctx, sqlCreateHelmWorkloadChannel, channelID, helmEdgeID)
   487  		if err != nil {
   488  			return fmt.Errorf("failed to create helm workload channel for channel %q: %w", channelID, err)
   489  		}
   490  	}
   491  
   492  	return tx.Commit()
   493  }
   494  
   495  const sqlDeleteHelmWorkloadChannels = `DELETE FROM helm_workloads_channels
   496  WHERE helm_edge_id = $1`
   497  
   498  const sqlDeleteOmittedHelmWorkloadChannels = sqlDeleteHelmWorkloadChannels + `
   499    AND NOT channel_id = ANY($2)`
   500  
   501  const sqlDeleteHelmWorkloadChannelsWithChannelIDs = sqlDeleteHelmWorkloadChannels + `
   502    AND channel_id = ANY($2)`
   503  
   504  // SetHelmWorkloadChannels sets the helm workload's channel mappings.
   505  // It creates the helm workload's mappings to the provided channels, and deletes the helm workload's mappings for omitted channels.
   506  //
   507  // If no channels are provided, this function acts like DeleteHelmWorkloadChannel, and deletes all of the helm workload's mappings.
   508  func (cs *ChannelService) SetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error {
   509  	tx, err := cs.db.BeginTx(ctx, nil)
   510  	if err != nil {
   511  		return err
   512  	}
   513  	defer tx.Rollback() //nolint: errcheck
   514  
   515  	if len(channelIDs) == 0 {
   516  		_, err = tx.ExecContext(ctx, sqlDeleteHelmWorkloadChannels, helmEdgeID)
   517  	} else {
   518  		_, err = tx.ExecContext(ctx, sqlDeleteOmittedHelmWorkloadChannels, helmEdgeID, channelIDs)
   519  	}
   520  	if err != nil {
   521  		return fmt.Errorf("failed to set deleted helm workload channels: %w", err)
   522  	}
   523  
   524  	for _, channelID := range channelIDs {
   525  		_, err := tx.ExecContext(ctx, sqlCreateHelmWorkloadChannel, channelID, helmEdgeID)
   526  		if err != nil {
   527  			return fmt.Errorf("failed to set created helm workload channels: %w", err)
   528  		}
   529  	}
   530  
   531  	return tx.Commit()
   532  }
   533  
   534  // DeleteHelmWorkloadChannels deletes the helm workload's mappings for the provided channels.
   535  // If no channels are provided, it deletes all of the helm workload's mappings.
   536  func (cs *ChannelService) DeleteHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error {
   537  	var err error
   538  	if len(channelIDs) == 0 {
   539  		_, err = cs.db.ExecContext(ctx, sqlDeleteHelmWorkloadChannels, helmEdgeID)
   540  	} else {
   541  		_, err = cs.db.ExecContext(ctx, sqlDeleteHelmWorkloadChannelsWithChannelIDs, helmEdgeID, channelIDs)
   542  	}
   543  
   544  	if err != nil {
   545  		return fmt.Errorf("failed to delete helm workload channels: %w", err)
   546  	}
   547  	return nil
   548  }
   549  
   550  const sqlGetHelmWorkloadChannels = `SELECT channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at
   551  FROM channels
   552  WHERE channel_id IN (
   553    SELECT channel_id FROM helm_workloads_channels WHERE helm_edge_id = $1
   554  )`
   555  
   556  func (cs *ChannelService) GetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID) ([]Channel, error) {
   557  	rows, err := cs.db.QueryContext(ctx, sqlGetHelmWorkloadChannels, helmEdgeID)
   558  	if err != nil {
   559  		return nil, fmt.Errorf("failed to query helm workload channels: %w", err)
   560  	}
   561  	defer rows.Close()
   562  
   563  	var channels []Channel
   564  	for rows.Next() {
   565  		channel, err := scanChannel(rows)
   566  		if err != nil {
   567  			return nil, err
   568  		}
   569  		channels = append(channels, channel)
   570  	}
   571  	if err := rows.Err(); err != nil {
   572  		return nil, fmt.Errorf("failed to scan all helm workload channels: %w", err)
   573  	}
   574  
   575  	return channels, nil
   576  }
   577  
   578  const sqlGetHelmWorkloadsForChannel = `SELECT helm_edge_id
   579  FROM helm_workloads_channels
   580  WHERE channel_id = $1
   581    AND helm_edge_id = ANY (SELECT helm_edge_id FROM helm_workloads WHERE banner_edge_id = $2)`
   582  
   583  // GetHelmWorkloadsForChannel returns all helm workloads mapped to the channel within the provided banner
   584  func (cs *ChannelService) GetHelmWorkloadsForChannel(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (helmEdgeIDs []uuid.UUID, err error) {
   585  	rows, err := cs.db.QueryContext(ctx, sqlGetHelmWorkloadsForChannel, channelID, bannerEdgeID)
   586  	if err != nil {
   587  		return nil, fmt.Errorf("failed to query helm workloads for channel: %w", err)
   588  	}
   589  	defer rows.Close() //nolint: errcheck
   590  
   591  	for rows.Next() {
   592  		var helmEdgeID uuid.UUID
   593  		if err := rows.Scan(&helmEdgeID); err != nil {
   594  			return nil, fmt.Errorf("failed to scan helm workloads for channel: %w", err)
   595  		}
   596  		helmEdgeIDs = append(helmEdgeIDs, helmEdgeID)
   597  	}
   598  	if err := rows.Err(); err != nil {
   599  		return nil, fmt.Errorf("failed to scan all helm workloads for channel: %w", err)
   600  	}
   601  	return helmEdgeIDs, nil
   602  }
   603  
   604  const sqlCreateChannelKeyVersion = `WITH convert_channel_durations_to_seconds AS (
   605    SELECT
   606      expire_buffer_duration / ` + secondStr + ` AS expire_buffer_seconds,
   607      rotation_interval_duration / ` + secondStr + ` AS rotation_interval_seconds
   608    FROM channels
   609    WHERE channel_id = $1
   610    LIMIT 1
   611  )
   612  INSERT INTO channels_key_versions (channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at)
   613  VALUES (
   614    $1,
   615    $2,
   616    $3,
   617    $4,
   618    NOW() + INTERVAL '1 SECOND' * (SELECT rotation_interval_seconds + expire_buffer_seconds FROM convert_channel_durations_to_seconds),
   619    NOW() + INTERVAL '1 SECOND' * (SELECT rotation_interval_seconds FROM convert_channel_durations_to_seconds)
   620  )
   621  RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
   622  
   623  // CreateChannelKeyVersion inserts the latest ChannelKeyVersion, which fully rotates the channel within a banner.
   624  //
   625  // The following steps are performed within a transaction:
   626  //  1. Mark the currently-latest ChannelKeyVersion as rotated, by nullifying its `rotate_at` column.
   627  //  2. Insert the new ChannelKeyVersion, which calculates its `created_at`, `expire_at`, and `rotate_at` columns.
   628  //  3. Scan the inserted row into the ChannelKeyVersion struct that CreateChannelKeyVersion returns.
   629  func (cs *ChannelService) CreateChannelKeyVersion(ctx context.Context, ckv ChannelKeyVersion) (latest ChannelKeyVersion, err error) {
   630  	if err := ckv.validateCreate(); err != nil {
   631  		return ChannelKeyVersion{}, err
   632  	}
   633  
   634  	tx, err := cs.db.BeginTx(ctx, nil)
   635  	if err != nil {
   636  		return ChannelKeyVersion{}, err
   637  	}
   638  	defer tx.Rollback() //nolint: errcheck
   639  
   640  	if err := cs.txMarkLatestChannelKeyVersionAsRotated(ctx, tx, ckv.BannerEdgeID, ckv.ChannelID); err != nil {
   641  		return ChannelKeyVersion{}, err
   642  	}
   643  
   644  	var row = tx.QueryRowContext(ctx, sqlCreateChannelKeyVersion,
   645  		ckv.ChannelID,
   646  		ckv.BannerEdgeID,
   647  		ckv.Version,
   648  		ckv.SecretManagerLink,
   649  	)
   650  
   651  	latest, err = scanChannelKeyVersion(row)
   652  	if err != nil {
   653  		return ChannelKeyVersion{}, fmt.Errorf("failed to create channel key version: %w", err)
   654  	}
   655  
   656  	if err = tx.Commit(); err != nil {
   657  		return ChannelKeyVersion{}, fmt.Errorf("failed to commit created channel key version: %w", err)
   658  	}
   659  
   660  	return latest, nil
   661  }
   662  
   663  // This query uses the `LEAST` function to prevent the `expire_at` timestamp from increasing.
   664  const sqlTxMarkLatestChannelKeyVersionAsRotated = `WITH calc_expire_buffer_seconds AS (
   665    SELECT expire_buffer_duration / ` + secondStr + ` AS seconds
   666    FROM channels
   667    WHERE channel_id = $1
   668    LIMIT 1
   669  ),
   670  calc_expire_at AS (
   671    SELECT NOW() + INTERVAL '1 SECOND' * (SELECT seconds FROM calc_expire_buffer_seconds)
   672    AS expire_at
   673  )
   674  UPDATE channels_key_versions
   675  SET 
   676    expire_at = LEAST(expire_at, (SELECT expire_at FROM calc_expire_at)),
   677    rotate_at = NULL
   678  WHERE channel_id = $1
   679    AND banner_edge_id = $2
   680    AND rotate_at IS NOT NULL`
   681  
   682  // txMarkLatestChannelKeyVersionAsRotated nullifies the `rotate_at` time for the latest ChannelKeyVersion, and recalculates the `expire_at` column.
   683  //
   684  // The `rotate_at` time only makes sense for the latest ChannelKeyVersion, so the column is set to NULL for outdated ChannelKeyVersions.
   685  func (cs *ChannelService) txMarkLatestChannelKeyVersionAsRotated(ctx context.Context, tx *sql.Tx, bannerEdgeID, channelID uuid.UUID) error {
   686  	_, err := tx.ExecContext(ctx, sqlTxMarkLatestChannelKeyVersionAsRotated, channelID, bannerEdgeID)
   687  	if err != nil {
   688  		return fmt.Errorf("failed to mark the latest channel key version as rotated: %w", err)
   689  	}
   690  	return nil
   691  }
   692  
   693  const sqlDeleteChannelKeyVersion = `DELETE FROM channels_key_versions
   694  WHERE channel_key_versions_id = $1
   695    AND expire_at < NOW()
   696  RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
   697  
   698  // DeleteChannelKeyVersion is used to finalize the deletion of an expired ChannelKeyVersion, after all its resources have been cleaned up by controllers.
   699  //
   700  // An error occurs if the ChannelKeyVersion's ExpireAt time has not elapsed, since magpie is still using it for decryption.
   701  func (cs *ChannelService) DeleteChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) {
   702  	row := cs.db.QueryRowContext(ctx, sqlDeleteChannelKeyVersion, id)
   703  	ckv, err := scanChannelKeyVersion(row)
   704  	if err != nil {
   705  		return ckv, fmt.Errorf("failed to delete channel key version: %w", err)
   706  	}
   707  	return ckv, nil
   708  }
   709  
   710  const sqlGetChannelKeyVersionByID = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at
   711  FROM channels_key_versions
   712  WHERE channel_key_versions_id = $1`
   713  
   714  func (cs *ChannelService) GetChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) {
   715  	var row = cs.db.QueryRowContext(ctx, sqlGetChannelKeyVersionByID, id)
   716  	ckv, err := scanChannelKeyVersion(row)
   717  	if err != nil {
   718  		return ckv, fmt.Errorf("failed to scan channel key version: %w", err)
   719  	}
   720  	return ckv, nil
   721  }
   722  
// sqlGetChannelKeyVersions selects every key version of one channel within one banner.
// $1 is the banner_edge_id and $2 is the channel_id.
const sqlGetChannelKeyVersions = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at
FROM channels_key_versions
WHERE banner_edge_id = $1
  AND channel_id = $2`

// sqlGetChannelKeyVersionsOrdered returns the same rows as sqlGetChannelKeyVersions, highest version first.
const sqlGetChannelKeyVersionsOrdered = sqlGetChannelKeyVersions + `
ORDER BY version DESC`

// sqlGetLatestChannelKeyVersion narrows sqlGetChannelKeyVersions to the version whose rotate_at is still set,
// i.e. the one that has not been rotated yet (rotated versions have a NULL rotate_at).
//
// NOTE(review): there is no ORDER BY, so this relies on at most one version per channel/banner having a
// non-NULL rotate_at; if that invariant were ever violated, QueryRow would pick an arbitrary matching row.
const sqlGetLatestChannelKeyVersion = sqlGetChannelKeyVersions + `
AND rotate_at IS NOT NULL`
   733  
   734  func (cs *ChannelService) GetChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) {
   735  	rows, err := cs.db.QueryContext(ctx, sqlGetChannelKeyVersionsOrdered, bannerEdgeID, channelID)
   736  	if err != nil {
   737  		return nil, fmt.Errorf("failed to query channel key versions: %w", err)
   738  	}
   739  	defer rows.Close() //nolint: errcheck
   740  
   741  	var ckvs []ChannelKeyVersion
   742  	for rows.Next() {
   743  		ckv, err := scanChannelKeyVersion(rows)
   744  		if err != nil {
   745  			return nil, err
   746  		}
   747  		ckvs = append(ckvs, ckv)
   748  	}
   749  	if err := rows.Err(); err != nil {
   750  		return nil, fmt.Errorf("failed to scan all channel key versions: %w", err)
   751  	}
   752  	return ckvs, nil
   753  }
   754  
   755  // GetLatestChannelKeyVersion returns the most recently created ChannelKeyVersion for a channel in this banner.
   756  //
   757  // If no ChannelKeyVersion exists for the channel in this banner, it returns `sql.ErrNoRows`.
   758  func (cs *ChannelService) GetLatestChannelKeyVersion(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error) {
   759  	var row = cs.db.QueryRowContext(ctx, sqlGetLatestChannelKeyVersion, bannerEdgeID, channelID)
   760  	return scanChannelKeyVersion(row)
   761  }
   762  
   763  /*
   764  	const sqlGetUnexpiredChannelKeyVersions = sqlGetChannelKeyVersions + `
   765  	  AND expire_at > NOW()
   766  	ORDER BY version DESC`
   767  
   768  	const sqlGetExpiredChannelKeyVersions = sqlGetChannelKeyVersions + `
   769  	  AND expire_at < NOW()
   770  	ORDER BY version ASC`
   771  
   772  	// GetUnexpiredChannelKeyVersions returns all unexpired key versions, sorted by the most recent version.
   773  	func (cs *ChannelService) GetUnexpiredChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) {
   774  		rows, err := cs.db.QueryContext(ctx, sqlGetUnexpiredChannelKeyVersions, bannerEdgeID, channelID)
   775  		if err != nil {
   776  			return nil, fmt.Errorf("failed to query unexpired channel key versions: %w", err)
   777  		}
   778  		defer rows.Close()
   779  
   780  		var ckvs []ChannelKeyVersion
   781  		for rows.Next() {
   782  			ckv, err := scanChannelKeyVersion(rows)
   783  			if err != nil {
   784  				return nil, err
   785  			}
   786  			ckvs = append(ckvs, ckv)
   787  		}
   788  		if err := rows.Err(); err != nil {
   789  			return nil, fmt.Errorf("failed to scan all unexpired channel key versions: %w", err)
   790  		}
   791  		return ckvs, nil
   792  	}
   793  
   794  	// GetExpiredChannelKeyVersions is provided for clean up purposes.
   795  	//
   796  	// After cleaning up a ChannelKeyVersion's resources, the `DeleteChannelKeyVersion` method should be called.
   797  	func (cs *ChannelService) GetExpiredChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) {
   798  		rows, err := cs.db.QueryContext(ctx, sqlGetExpiredChannelKeyVersions, bannerEdgeID, channelID)
   799  		if err != nil {
   800  			return nil, fmt.Errorf("failed to query expired channel key versions: %w", err)
   801  		}
   802  		defer rows.Close()
   803  
   804  		var ckvs []ChannelKeyVersion
   805  		for rows.Next() {
   806  			ckv, err := scanChannelKeyVersion(rows)
   807  			if err != nil {
   808  				return nil, err
   809  			}
   810  			ckvs = append(ckvs, ckv)
   811  		}
   812  		if err := rows.Err(); err != nil {
   813  			return nil, fmt.Errorf("failed to scan all expired channel key versions: %w", err)
   814  		}
   815  		return ckvs, nil
   816  	}
   817  */
   818  
// sqlRotateChannelNow marks the latest (not yet rotated, non-NULL rotate_at) key version of a channel in a banner
// for immediate rotation and returns the updated row. $1 is the channel_id and $2 is the banner_edge_id.
//
// Using the LEAST condition since the channel is already marked for rotation if `rotate_at` is less than NOW().
// Using the LEAST condition also ensures `rotate_at` is always less than `expire_at`, which would break scan validation assertions.
// (LEAST never moves `rotate_at` later than its current value.)
const sqlRotateChannelNow = `UPDATE channels_key_versions
SET rotate_at = LEAST(rotate_at, NOW())
WHERE channel_id = $1
  AND banner_edge_id = $2
  AND rotate_at IS NOT NULL
RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
   827  
   828  // RotateChannelNow marks the latest ChannelKeyVersion for rotation.
   829  //
   830  // The channel is actually rotated when a new ChannelKeyVersion is created.
   831  func (cs *ChannelService) RotateChannelNow(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error) {
   832  	row := cs.db.QueryRowContext(ctx, sqlRotateChannelNow, channelID, bannerEdgeID)
   833  	ckv, err := scanChannelKeyVersion(row)
   834  	if err != nil {
   835  		return ckv, fmt.Errorf("failed to rotate channel now: %w", err)
   836  	}
   837  	return ckv, nil
   838  }
   839  
// sqlExpireRotatedChannelKeyVersionNow moves expire_at of one key version ($1 = channel_key_versions_id) to NOW(),
// or leaves it unchanged if it is already earlier, and returns the updated row.
//
// Marking the latest ChannelKeyVersion as expired introduces nasty edge cases, in SQL and architecturally, so this query is limited to keys that have already been rotated.
// Rotated versions are the ones with a NULL rotate_at; for any other version no row is updated or returned.
const sqlExpireRotatedChannelKeyVersionNow = `UPDATE channels_key_versions
SET 
  expire_at = LEAST(expire_at, NOW())
WHERE channel_key_versions_id = $1
  AND rotate_at IS NULL
RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
   847  
   848  // ExpireRotatedChannelKeyVersionNow marks the ChannelKeyVersion as expired.
   849  //
   850  // An error is returned if the provided ChannelKeyVersion has not been rotated.
   851  // This prevents the latest ChannelKeyVersion's resources from being deleted while they are still being used by sparrow.
   852  //
   853  // NOTE:
   854  // Edge super admins can check if a ChannelKeyVersion has actually been rotated by calling the `getLatestChannelKeyVersion()` query.
   855  func (cs *ChannelService) ExpireRotatedChannelKeyVersionNow(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) {
   856  	row := cs.db.QueryRowContext(ctx, sqlExpireRotatedChannelKeyVersionNow, id)
   857  	ckv, err := scanChannelKeyVersion(row)
   858  	if err != nil {
   859  		return ckv, fmt.Errorf("failed to expire rotated channel key version now: %w", err)
   860  	}
   861  	return ckv, nil
   862  }
   863  

View as plain text