package channels

import (
	"context"
	"database/sql"
	"fmt"

	"edge-infra.dev/pkg/edge/api/graph/model"
	apiServices "edge-infra.dev/pkg/edge/api/services"
	"github.com/google/uuid"
)

var (
	// ErrChannelDoesNotExist is returned when a provided channel does not exist.
	//
	// It is returned by the following service calls:
	//   - DeleteChannel
	//   - GetChannel
	//   - GetChannels
	//   - GetChannelByName
	//   - GetChannelsByName
	//   - GetChannelsForTeam
	//   - UpdateChannel
	//
	// TODO: check foreign key constraint violations in other service calls to see if the channel does not exist.
	ErrChannelDoesNotExist = fmt.Errorf("channel does not exist")

	// ErrNoBannerChannels is returned by GetBannerChannels if no channels are referenced by the banner.
	ErrNoBannerChannels = fmt.Errorf("no channels exist in the banner")
)

// ChannelService runs the channel queries in this file against its SQL database handle.
type ChannelService struct {
	db               *sql.DB
	foremanProjectID string
	ChariotService   apiServices.ChariotService
}

// NewChannelService builds a ChannelService from the given database handle,
// foreman project ID, and chariot service.
func NewChannelService(db *sql.DB, foremanProjectID string, chariotService apiServices.ChariotService) *ChannelService {
	svc := new(ChannelService)
	svc.db = db
	svc.foremanProjectID = foremanProjectID
	svc.ChariotService = chariotService
	return svc
}

// Service is the set of channel operations exposed to callers.
type Service interface {
	CreateChannel(ctx context.Context, channel Channel) (Channel, error)
	DeleteChannel(ctx context.Context, channelID uuid.UUID, force bool) (Channel, error)
	ExpireRotatedChannelKeyVersionNow(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error)
	GetChannel(ctx context.Context, channelID uuid.UUID) (Channel, error)
	GetChannelByName(ctx context.Context, name string) (Channel, error)
	GetChannels(ctx context.Context, channelIDs ...uuid.UUID) ([]Channel, error)
	GetChannelsByName(ctx context.Context, names ...string) ([]Channel, error)
	GetChannelsForTeam(ctx context.Context, team string) ([]Channel, error)
	GetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID) ([]Channel, error)
	GetLatestChannelKeyVersion(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error)
	RotateChannelNow(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error)
	UpdateChannel(ctx context.Context, channelID uuid.UUID, request ChannelUpdateRequest) (Channel, error)
	CreateChannelIAM(ctx context.Context, channelID uuid.UUID, saEmail string) (*model.ChannelIAMPolicy, error)
	CreateChannelKeyVersion(ctx context.Context, ckv ChannelKeyVersion) (latest ChannelKeyVersion, err error)
	GetChannelsFromHelmConfig(ctx context.Context, configYaml *string) ([]Channel, error)
	CreateHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error
	GetBannerChannels(ctx context.Context, bannerEdgeID uuid.UUID) ([]BannerChannel, error)
	DeleteChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error)
}

const sqlCreateChannel = `INSERT INTO channels(name, description, team, expire_buffer_duration, rotation_interval_duration) VALUES ($1, $2, $3, $4, $5) RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`

// CreateChannel validates the channel, inserts it, and returns the stored row
// with all fields populated.
//
// When validation fails, the caller's channel is handed back unchanged together with the validation error.
func (cs *ChannelService) CreateChannel(ctx context.Context, channel Channel) (Channel, error) {
	// TODO: find a hack that ensures the context was passed in by the resolver for a super admin.
	if invalid := channel.validateCreate(); invalid != nil {
		return channel, invalid
	}

	inserted, err := scanChannel(cs.db.QueryRowContext(
		ctx,
		sqlCreateChannel,
		channel.Name,
		channel.Description,
		channel.Team,
		channel.ExpireBufferDuration,
		channel.RotationIntervalDuration,
	))
	if err != nil {
		return inserted, fmt.Errorf("failed to create channel: %w", err)
	}
	return inserted, nil
}

const sqlUpdateChannel = `UPDATE channels SET (team, description, expire_buffer_duration, rotation_interval_duration) = ( COALESCE($1, team), COALESCE($2, description), COALESCE($3, expire_buffer_duration), COALESCE($4, rotation_interval_duration) ) WHERE channel_id = $5 RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`

// UpdateChannel should only be called by the `updateChannel` resolver available to super admins.
//
// The statement wraps every updatable column in COALESCE, so a request value that
// reaches the database as NULL keeps the column's current value.
//
// TODO:
// Run this in a transaction and recalculate the ChannelKeyVersion's `rotate_at` and `expire_at` times, when their fields have been updated.
// For now, updating `expire_buffer_duration` and `rotation_interval_duration` only affects stuff that happens in the future.
func (cs *ChannelService) UpdateChannel(ctx context.Context, channelID uuid.UUID, request ChannelUpdateRequest) (Channel, error) {
	// TODO: find some hack that ensures the context was passed in by the resolver for a super admin.
	if invalid := request.validate(); invalid != nil {
		return Channel{}, invalid
	}

	updated, err := scanChannel(cs.db.QueryRowContext(
		ctx,
		sqlUpdateChannel,
		request.Team,
		request.Description,
		request.ExpireBufferDuration,
		request.RotationIntervalDuration,
		channelID,
	))
	if err != nil {
		return updated, fmt.Errorf("failed to update channel: %w", err)
	}
	return updated, nil
}

const sqlDeleteChannel = `DELETE FROM channels WHERE channel_id = $1 RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`

const sqlHelmWorkloadsForChannelExists = `SELECT EXISTS ( SELECT 1 FROM helm_workloads_channels WHERE channel_id = $1 )`

// txCanDeleteChannel reports, as an error, whether any row in
// `helm_workloads_channels` still references the channel. It returns nil when
// no helm workload is mapped to the channel.
func (cs *ChannelService) txCanDeleteChannel(ctx context.Context, tx *sql.Tx, channelID uuid.UUID) error {
	var inUse bool
	if err := tx.QueryRowContext(ctx, sqlHelmWorkloadsForChannelExists, channelID).Scan(&inUse); err != nil {
		return fmt.Errorf("failed to check helm workloads used by channel: %w", err)
	}
	if inUse {
		return fmt.Errorf("channel is being used by helm workloads")
	}
	return nil
}

// DeleteChannel should only be called by the `deleteChannel` resolver available to super admins.
//
// When `force` is false, the DeleteChannel method returns an error if the channel is mapped to any helm workloads.
// The check and the delete run inside the same transaction.
func (cs *ChannelService) DeleteChannel(ctx context.Context, channelID uuid.UUID, force bool) (Channel, error) {
	// TODO: find some hack that ensures the context was passed in by the resolver for a super admin.
	tx, err := cs.db.BeginTx(ctx, nil)
	if err != nil {
		return Channel{}, err
	}
	// Rollback after a successful Commit is a no-op that only returns an error, which is ignored here.
	defer tx.Rollback() //nolint: errcheck

	if !force {
		if blocked := cs.txCanDeleteChannel(ctx, tx, channelID); blocked != nil {
			return Channel{}, blocked
		}
	}

	deleted, err := scanChannel(tx.QueryRowContext(ctx, sqlDeleteChannel, channelID))
	if err != nil {
		return Channel{}, fmt.Errorf("failed to delete channel: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return Channel{}, err
	}
	return deleted, nil
}

const (
	sqlGetChannels        = `SELECT channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at FROM channels`
	sqlGetChannelsWithIDs = sqlGetChannels + ` WHERE channel_id = ANY ($1)`
	sqlGetChannelsForTeam = sqlGetChannels + ` WHERE team = $1`
	sqlGetChannelsByName  = sqlGetChannels + ` WHERE name = ANY($1)`
)

// GetChannel looks up a single channel by its ID.
func (cs *ChannelService) GetChannel(ctx context.Context, channelID uuid.UUID) (Channel, error) {
	ids := []uuid.UUID{channelID}
	return scanChannel(cs.db.QueryRowContext(ctx, sqlGetChannelsWithIDs, ids))
}

// GetChannelByName looks up a single channel by its name.
func (cs *ChannelService) GetChannelByName(ctx context.Context, name string) (Channel, error) {
	names := []string{name}
	return scanChannel(cs.db.QueryRowContext(ctx, sqlGetChannelsByName, names))
}

// GetChannelsByName returns a Channel slice indexed in same order as the provided `names`.
//
// An error is returned if any of the channels do not exist.
func (cs *ChannelService) GetChannelsByName(ctx context.Context, names ...string) ([]Channel, error) { if len(names) == 0 { return nil, nil } rows, err := cs.db.QueryContext(ctx, sqlGetChannelsByName, names) if err != nil { return nil, fmt.Errorf("failed to query channels by name: %w", err) } defer rows.Close() //nolint: errcheck var channels = make(map[string]Channel) for rows.Next() { channel, err := scanChannel(rows) if err != nil { return nil, err } channels[channel.Name] = channel } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all channels by name: %w", err) } var ordered []Channel for _, name := range names { channel, found := channels[name] if !found { return nil, fmt.Errorf("%w: %s", ErrChannelDoesNotExist, name) } ordered = append(ordered, channel) } return ordered, nil } // GetChannelsFromHelmConfig attempts to parse the YAML helm config, then it queries the configured channels. // // If no channels are configured, an empty slice is returned without an error. // // An error is returned if: // - the configYAML is invalid. // - any of the configured channel names are invalid. // - any of the configured channels do not exist in the database (returns ErrChannelDoesNotExist). func (cs *ChannelService) GetChannelsFromHelmConfig(ctx context.Context, configYaml *string) ([]Channel, error) { if configYaml == nil { return make([]Channel, 0), nil } config, err := ParseHelmConfigChannels(*configYaml) if err != nil { return nil, err } if !config.HasChannels() { return nil, nil } return cs.GetChannelsByName(ctx, config.Names()...) } // GetChannels returns a Channel slice indexed in same order as the provided `channelIDs`. // All channels are returned if the `channelIDs` argument is empty. // // When a provided channel is not found, GetChannels returns ErrChannelDoesNotExist. 
func (cs *ChannelService) GetChannels(ctx context.Context, channelIDs ...uuid.UUID) ([]Channel, error) { var args []interface{} var stmt = sqlGetChannels if len(channelIDs) > 0 { args = append(args, channelIDs) stmt = sqlGetChannelsWithIDs } rows, err := cs.db.QueryContext(ctx, stmt, args...) if err != nil { return nil, fmt.Errorf("failed to query channels: %w", err) } defer rows.Close() //nolint: errcheck var channels []Channel for rows.Next() { channel, err := scanChannel(rows) if err != nil { return nil, err } channels = append(channels, channel) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all channels: %w", err) } // no need to order the channels when none are provided. if len(channelIDs) == 0 { return channels, nil } // order the channels by the same index as the `channelIDs` argument. // this is a nice feature to have in practice, and makes testing easier. var m = make(map[uuid.UUID]Channel) for _, channel := range channels { m[channel.ID] = channel } var ordered []Channel for _, channelID := range channelIDs { channel, found := m[channelID] if !found { return nil, fmt.Errorf("%w: %s", ErrChannelDoesNotExist, channelID) } ordered = append(ordered, channel) } return ordered, nil } // TODO: check with Luc/Lee to see if `Team` should have a UNIQUE constraint like `Name` func (cs *ChannelService) GetChannelsForTeam(ctx context.Context, team string) ([]Channel, error) { rows, err := cs.db.QueryContext(ctx, sqlGetChannelsForTeam, team) if err != nil { return nil, fmt.Errorf("failed to query channels for team: %w", err) } defer rows.Close() //nolint: errcheck var channels []Channel for rows.Next() { channel, err := scanChannel(rows) if err != nil { return nil, err } channels = append(channels, channel) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all channels for team: %w", err) } if len(channels) == 0 { return nil, fmt.Errorf("%w for team: %q", ErrChannelDoesNotExist, team) } return channels, nil } // 
GetBannerChannels returns all channels being referenced by this banner. // - any channel in `helm_workloads_channels` mapped to this banner's helm workloads. // - any channel in `channels_key_versions` that references this banner. // // GetBannerChannels is intended to be used by bannerctl and clusterctl within "banner-infra" clusters. // Its purpose is to retrieve all the data needed for channel reconciliation, at the banner level, from a single call to the channel service. // This should substantially simplify their reconcile function, improve its performance, and reduce their potential for errors. // // If no channels exist in the banner, then GetBannerChannels returns `ErrNoBannerChannels` func (cs *ChannelService) GetBannerChannels(ctx context.Context, bannerEdgeID uuid.UUID) ([]BannerChannel, error) { tx, err := cs.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() //nolint: errcheck hwcm, err := cs.txGetHelmWorkloadsChannelsForBanner(ctx, tx, bannerEdgeID) if err != nil { return nil, err } ckvm, err := cs.txGetChannelKeyVersionsForBanner(ctx, tx, bannerEdgeID) if err != nil { return nil, err } // deduplicate the channels being used in this banner var dedupMap = make(map[uuid.UUID]struct{}) for id := range hwcm { dedupMap[id] = struct{}{} } for id := range ckvm { dedupMap[id] = struct{}{} } var dedup []uuid.UUID for id := range dedupMap { dedup = append(dedup, id) } // return early if the banner is not using any channels if len(dedup) == 0 { return nil, ErrNoBannerChannels } // Get all channels for this banner. 
rows, err := tx.QueryContext(ctx, sqlGetChannelsWithIDs, dedup) if err != nil { return nil, fmt.Errorf("failed to query banner channels: %w", err) } defer rows.Close() //nolint: errcheck var bannerChannels []BannerChannel for rows.Next() { channel, err := scanChannel(rows) if err != nil { return nil, err } bannerChannels = append(bannerChannels, BannerChannel{ Channel: channel, KeyVersions: ckvm[channel.ID], HelmEdgeIDs: hwcm[channel.ID], }) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all channels with IDs: %w", err) } return bannerChannels, nil } const sqlGetHelmWorkloadsChannelsForBanner = `SELECT channel_id, helm_edge_id FROM helm_workloads_channels WHERE helm_edge_id = ANY ( SELECT helm_edge_id FROM helm_workloads WHERE banner_edge_id = $1 )` // txGetHelmWorkloadsChannelsForBanner returns a map[channel_id][]helm_edge_id func (cs *ChannelService) txGetHelmWorkloadsChannelsForBanner(ctx context.Context, tx *sql.Tx, bannerEdgeID uuid.UUID) (map[uuid.UUID][]uuid.UUID, error) { rows, err := tx.QueryContext(ctx, sqlGetHelmWorkloadsChannelsForBanner, bannerEdgeID) if err != nil { return nil, fmt.Errorf("failed to query helm workloads channels for banner: %w", err) } defer rows.Close() //nolint: errcheck var m = make(map[uuid.UUID][]uuid.UUID) for rows.Next() { var channelID, helmEdgeID uuid.UUID if err := rows.Scan(&channelID, &helmEdgeID); err != nil { return nil, fmt.Errorf("failed to scan helm workloads channels for banner: %w", err) } m[channelID] = append(m[channelID], helmEdgeID) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all helm workloads channels for banner: %w", err) } return m, nil } const sqlGetChannelKeyVersionsForBanner = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at FROM channels_key_versions WHERE banner_edge_id = $1 ORDER BY version DESC` // txGetUnexpiredChannelKeyVersionsForBanner returns a 
map[channel_id][]ChannelKeyVersion func (cs *ChannelService) txGetChannelKeyVersionsForBanner(ctx context.Context, tx *sql.Tx, bannerEdgeID uuid.UUID) (map[uuid.UUID][]ChannelKeyVersion, error) { rows, err := tx.QueryContext(ctx, sqlGetChannelKeyVersionsForBanner, bannerEdgeID) if err != nil { return nil, fmt.Errorf("failed to query channel key versions for banner: %w", err) } defer rows.Close() //nolint: errcheck var m = make(map[uuid.UUID][]ChannelKeyVersion) for rows.Next() { ckv, err := scanChannelKeyVersion(rows) if err != nil { return nil, err } m[ckv.ChannelID] = append(m[ckv.ChannelID], ckv) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all unexpired channels key versions for banner: %w", err) } return m, nil } const sqlCreateHelmWorkloadChannel = `INSERT INTO helm_workloads_channels(channel_id, helm_edge_id) VALUES ($1, $2) ON CONFLICT ON CONSTRAINT unique_channel_id_helm_edge_id DO NOTHING` func (cs *ChannelService) CreateHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error { if len(channelIDs) == 0 { return fmt.Errorf("channelIDs must not be empty") } tx, err := cs.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() //nolint: errcheck for _, channelID := range channelIDs { _, err := tx.ExecContext(ctx, sqlCreateHelmWorkloadChannel, channelID, helmEdgeID) if err != nil { return fmt.Errorf("failed to create helm workload channel for channel %q: %w", channelID, err) } } return tx.Commit() } const sqlDeleteHelmWorkloadChannels = `DELETE FROM helm_workloads_channels WHERE helm_edge_id = $1` const sqlDeleteOmittedHelmWorkloadChannels = sqlDeleteHelmWorkloadChannels + ` AND NOT channel_id = ANY($2)` const sqlDeleteHelmWorkloadChannelsWithChannelIDs = sqlDeleteHelmWorkloadChannels + ` AND channel_id = ANY($2)` // SetHelmWorkloadChannels sets the helm workload's channel mappings. 
// It creates the helm workload's mappings to the provided channels, and deletes the helm workload's mappings for omitted channels. // // If no channels are provided, this function acts like DeleteHelmWorkloadChannel, and deletes all of the helm workload's mappings. func (cs *ChannelService) SetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error { tx, err := cs.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() //nolint: errcheck if len(channelIDs) == 0 { _, err = tx.ExecContext(ctx, sqlDeleteHelmWorkloadChannels, helmEdgeID) } else { _, err = tx.ExecContext(ctx, sqlDeleteOmittedHelmWorkloadChannels, helmEdgeID, channelIDs) } if err != nil { return fmt.Errorf("failed to set deleted helm workload channels: %w", err) } for _, channelID := range channelIDs { _, err := tx.ExecContext(ctx, sqlCreateHelmWorkloadChannel, channelID, helmEdgeID) if err != nil { return fmt.Errorf("failed to set created helm workload channels: %w", err) } } return tx.Commit() } // DeleteHelmWorkloadChannels deletes the helm workload's mappings for the provided channels. // If no channels are provided, it deletes all of the helm workload's mappings. 
func (cs *ChannelService) DeleteHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error { var err error if len(channelIDs) == 0 { _, err = cs.db.ExecContext(ctx, sqlDeleteHelmWorkloadChannels, helmEdgeID) } else { _, err = cs.db.ExecContext(ctx, sqlDeleteHelmWorkloadChannelsWithChannelIDs, helmEdgeID, channelIDs) } if err != nil { return fmt.Errorf("failed to delete helm workload channels: %w", err) } return nil } const sqlGetHelmWorkloadChannels = `SELECT channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at FROM channels WHERE channel_id IN ( SELECT channel_id FROM helm_workloads_channels WHERE helm_edge_id = $1 )` func (cs *ChannelService) GetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID) ([]Channel, error) { rows, err := cs.db.QueryContext(ctx, sqlGetHelmWorkloadChannels, helmEdgeID) if err != nil { return nil, fmt.Errorf("failed to query helm workload channels: %w", err) } defer rows.Close() var channels []Channel for rows.Next() { channel, err := scanChannel(rows) if err != nil { return nil, err } channels = append(channels, channel) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all helm workload channels: %w", err) } return channels, nil } const sqlGetHelmWorkloadsForChannel = `SELECT helm_edge_id FROM helm_workloads_channels WHERE channel_id = $1 AND helm_edge_id = ANY (SELECT helm_edge_id FROM helm_workloads WHERE banner_edge_id = $2)` // GetHelmWorkloadsForChannel returns all helm workloads mapped to the channel within the provided banner func (cs *ChannelService) GetHelmWorkloadsForChannel(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (helmEdgeIDs []uuid.UUID, err error) { rows, err := cs.db.QueryContext(ctx, sqlGetHelmWorkloadsForChannel, channelID, bannerEdgeID) if err != nil { return nil, fmt.Errorf("failed to query helm workloads for channel: %w", err) } defer rows.Close() //nolint: errcheck for rows.Next() { 
var helmEdgeID uuid.UUID if err := rows.Scan(&helmEdgeID); err != nil { return nil, fmt.Errorf("failed to scan helm workloads for channel: %w", err) } helmEdgeIDs = append(helmEdgeIDs, helmEdgeID) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all helm workloads for channel: %w", err) } return helmEdgeIDs, nil } const sqlCreateChannelKeyVersion = `WITH convert_channel_durations_to_seconds AS ( SELECT expire_buffer_duration / ` + secondStr + ` AS expire_buffer_seconds, rotation_interval_duration / ` + secondStr + ` AS rotation_interval_seconds FROM channels WHERE channel_id = $1 LIMIT 1 ) INSERT INTO channels_key_versions (channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at) VALUES ( $1, $2, $3, $4, NOW() + INTERVAL '1 SECOND' * (SELECT rotation_interval_seconds + expire_buffer_seconds FROM convert_channel_durations_to_seconds), NOW() + INTERVAL '1 SECOND' * (SELECT rotation_interval_seconds FROM convert_channel_durations_to_seconds) ) RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at` // CreateChannelKeyVersion inserts the latest ChannelKeyVersion, which fully rotates the channel within a banner. // // The following steps are performed within a transaction: // 1. Mark the currently-latest ChannelKeyVersion as rotated, by nullifying its `rotate_at` column. // 2. Insert the new ChannelKeyVersion, which calculates its `created_at`, `expire_at`, and `rotate_at` columns. // 3. Scan the inserted row into the ChannelKeyVersion struct that CreateChannelKeyVersion returns. 
func (cs *ChannelService) CreateChannelKeyVersion(ctx context.Context, ckv ChannelKeyVersion) (latest ChannelKeyVersion, err error) { if err := ckv.validateCreate(); err != nil { return ChannelKeyVersion{}, err } tx, err := cs.db.BeginTx(ctx, nil) if err != nil { return ChannelKeyVersion{}, err } defer tx.Rollback() //nolint: errcheck if err := cs.txMarkLatestChannelKeyVersionAsRotated(ctx, tx, ckv.BannerEdgeID, ckv.ChannelID); err != nil { return ChannelKeyVersion{}, err } var row = tx.QueryRowContext(ctx, sqlCreateChannelKeyVersion, ckv.ChannelID, ckv.BannerEdgeID, ckv.Version, ckv.SecretManagerLink, ) latest, err = scanChannelKeyVersion(row) if err != nil { return ChannelKeyVersion{}, fmt.Errorf("failed to create channel key version: %w", err) } if err = tx.Commit(); err != nil { return ChannelKeyVersion{}, fmt.Errorf("failed to commit created channel key version: %w", err) } return latest, nil } // This query uses the `LEAST` function to prevent the `expire_at` timestamp from increasing. const sqlTxMarkLatestChannelKeyVersionAsRotated = `WITH calc_expire_buffer_seconds AS ( SELECT expire_buffer_duration / ` + secondStr + ` AS seconds FROM channels WHERE channel_id = $1 LIMIT 1 ), calc_expire_at AS ( SELECT NOW() + INTERVAL '1 SECOND' * (SELECT seconds FROM calc_expire_buffer_seconds) AS expire_at ) UPDATE channels_key_versions SET expire_at = LEAST(expire_at, (SELECT expire_at FROM calc_expire_at)), rotate_at = NULL WHERE channel_id = $1 AND banner_edge_id = $2 AND rotate_at IS NOT NULL` // txMarkLatestChannelKeyVersionAsRotated nullifies the `rotate_at` time for the latest ChannelKeyVersion, and recalculates the `expire_at` column. // // The `rotate_at` time only makes sense for the latest ChannelKeyVersion, so the column is set to NULL for outdated ChannelKeyVersions. 
func (cs *ChannelService) txMarkLatestChannelKeyVersionAsRotated(ctx context.Context, tx *sql.Tx, bannerEdgeID, channelID uuid.UUID) error { _, err := tx.ExecContext(ctx, sqlTxMarkLatestChannelKeyVersionAsRotated, channelID, bannerEdgeID) if err != nil { return fmt.Errorf("failed to mark the latest channel key version as rotated: %w", err) } return nil } const sqlDeleteChannelKeyVersion = `DELETE FROM channels_key_versions WHERE channel_key_versions_id = $1 AND expire_at < NOW() RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at` // DeleteChannelKeyVersion is used to finalize the deletion of an expired ChannelKeyVersion, after all its resources have been cleaned up by controllers. // // An error occurs if the ChannelKeyVersion's ExpireAt time has not elapsed, since magpie is still using it for decryption. func (cs *ChannelService) DeleteChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) { row := cs.db.QueryRowContext(ctx, sqlDeleteChannelKeyVersion, id) ckv, err := scanChannelKeyVersion(row) if err != nil { return ckv, fmt.Errorf("failed to delete channel key version: %w", err) } return ckv, nil } const sqlGetChannelKeyVersionByID = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at FROM channels_key_versions WHERE channel_key_versions_id = $1` func (cs *ChannelService) GetChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) { var row = cs.db.QueryRowContext(ctx, sqlGetChannelKeyVersionByID, id) ckv, err := scanChannelKeyVersion(row) if err != nil { return ckv, fmt.Errorf("failed to scan channel key version: %w", err) } return ckv, nil } const sqlGetChannelKeyVersions = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at FROM channels_key_versions WHERE banner_edge_id = $1 AND channel_id = $2` const 
sqlGetChannelKeyVersionsOrdered = sqlGetChannelKeyVersions + ` ORDER BY version DESC` const sqlGetLatestChannelKeyVersion = sqlGetChannelKeyVersions + ` AND rotate_at IS NOT NULL` func (cs *ChannelService) GetChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) { rows, err := cs.db.QueryContext(ctx, sqlGetChannelKeyVersionsOrdered, bannerEdgeID, channelID) if err != nil { return nil, fmt.Errorf("failed to query channel key versions: %w", err) } defer rows.Close() //nolint: errcheck var ckvs []ChannelKeyVersion for rows.Next() { ckv, err := scanChannelKeyVersion(rows) if err != nil { return nil, err } ckvs = append(ckvs, ckv) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all channel key versions: %w", err) } return ckvs, nil } // GetLatestChannelKeyVersion returns the most recently created ChannelKeyVersion for a channel in this banner. // // If no ChannelKeyVersion exists for the channel in this banner, it returns `sql.ErrNoRows`. func (cs *ChannelService) GetLatestChannelKeyVersion(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error) { var row = cs.db.QueryRowContext(ctx, sqlGetLatestChannelKeyVersion, bannerEdgeID, channelID) return scanChannelKeyVersion(row) } /* const sqlGetUnexpiredChannelKeyVersions = sqlGetChannelKeyVersions + ` AND expire_at > NOW() ORDER BY version DESC` const sqlGetExpiredChannelKeyVersions = sqlGetChannelKeyVersions + ` AND expire_at < NOW() ORDER BY version ASC` // GetUnexpiredChannelKeyVersions returns all unexpired key versions, sorted by the most recent version. 
func (cs *ChannelService) GetUnexpiredChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) { rows, err := cs.db.QueryContext(ctx, sqlGetUnexpiredChannelKeyVersions, bannerEdgeID, channelID) if err != nil { return nil, fmt.Errorf("failed to query unexpired channel key versions: %w", err) } defer rows.Close() var ckvs []ChannelKeyVersion for rows.Next() { ckv, err := scanChannelKeyVersion(rows) if err != nil { return nil, err } ckvs = append(ckvs, ckv) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all unexpired channel key versions: %w", err) } return ckvs, nil } // GetExpiredChannelKeyVersions is provided for clean up purposes. // // After cleaning up a ChannelKeyVersion's resources, the `DeleteChannelKeyVersion` method should be called. func (cs *ChannelService) GetExpiredChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) { rows, err := cs.db.QueryContext(ctx, sqlGetExpiredChannelKeyVersions, bannerEdgeID, channelID) if err != nil { return nil, fmt.Errorf("failed to query expired channel key versions: %w", err) } defer rows.Close() var ckvs []ChannelKeyVersion for rows.Next() { ckv, err := scanChannelKeyVersion(rows) if err != nil { return nil, err } ckvs = append(ckvs, ckv) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("failed to scan all expired channel key versions: %w", err) } return ckvs, nil } */ // Using the LEAST condition since the channel is already marked for rotation if `rotate_at` is less than NOW(). // Using the LEAST condition also ensures `rotate_at` is always less than `expired_at`, which would break scan validation assertions. 
const sqlRotateChannelNow = `UPDATE channels_key_versions SET rotate_at = LEAST(rotate_at, NOW()) WHERE channel_id = $1 AND banner_edge_id = $2 AND rotate_at IS NOT NULL RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at` // RotateChannelNow marks the latest ChannelKeyVersion for rotation. // // The channel is actually rotated when a new ChannelKeyVersion is created. func (cs *ChannelService) RotateChannelNow(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error) { row := cs.db.QueryRowContext(ctx, sqlRotateChannelNow, channelID, bannerEdgeID) ckv, err := scanChannelKeyVersion(row) if err != nil { return ckv, fmt.Errorf("failed to rotate channel now: %w", err) } return ckv, nil } // Marking the latest ChannelKeyVersion as expired introduces nasty edge cases, in SQL and architecturally, so this query is limited to keys that have already been rotated. const sqlExpireRotatedChannelKeyVersionNow = `UPDATE channels_key_versions SET expire_at = LEAST(expire_at, NOW()) WHERE channel_key_versions_id = $1 AND rotate_at IS NULL RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at` // ExpireRotatedChannelKeyVersionNow marks the ChannelKeyVersion as expired. // // An error is returned if the provided ChannelKeyVersion has not been rotated. // This prevents the latest ChannelKeyVersion's resources from being deleted while they are still being used by sparrow. // // NOTE: // Edge super admins can check if a ChannelKeyVersion has actually been rotated by calling the `getLatestChannelKeyVersion()` query. 
func (cs *ChannelService) ExpireRotatedChannelKeyVersionNow(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) { row := cs.db.QueryRowContext(ctx, sqlExpireRotatedChannelKeyVersionNow, id) ckv, err := scanChannelKeyVersion(row) if err != nil { return ckv, fmt.Errorf("failed to expire rotated channel key version now: %w", err) } return ckv, nil }