1 package channels
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7
8 "edge-infra.dev/pkg/edge/api/graph/model"
9 apiServices "edge-infra.dev/pkg/edge/api/services"
10
11 "github.com/google/uuid"
12 )
13
// ErrChannelDoesNotExist is the sentinel wrapped (via %w) into the errors
// returned by the lookup methods in this file when a requested channel ID,
// name or team yields no channel. Match it with errors.Is.
var ErrChannelDoesNotExist = fmt.Errorf("channel does not exist")

// ErrNoBannerChannels is returned by GetBannerChannels when neither helm
// workloads nor key versions reference any channel for the banner.
var ErrNoBannerChannels = fmt.Errorf("no channels exist in the banner")
32
33 type ChannelService struct {
34 db *sql.DB
35 foremanProjectID string
36 ChariotService apiServices.ChariotService
37 }
38
39 func NewChannelService(db *sql.DB, foremanProjectID string, chariotService apiServices.ChariotService) *ChannelService {
40 return &ChannelService{
41 db: db,
42 foremanProjectID: foremanProjectID,
43 ChariotService: chariotService,
44 }
45 }
46
47 type Service interface {
48 CreateChannel(ctx context.Context, channel Channel) (Channel, error)
49 DeleteChannel(ctx context.Context, channelID uuid.UUID, force bool) (Channel, error)
50 ExpireRotatedChannelKeyVersionNow(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error)
51 GetChannel(ctx context.Context, channelID uuid.UUID) (Channel, error)
52 GetChannelByName(ctx context.Context, name string) (Channel, error)
53 GetChannels(ctx context.Context, channelIDs ...uuid.UUID) ([]Channel, error)
54 GetChannelsByName(ctx context.Context, names ...string) ([]Channel, error)
55 GetChannelsForTeam(ctx context.Context, team string) ([]Channel, error)
56 GetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID) ([]Channel, error)
57 GetLatestChannelKeyVersion(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error)
58 RotateChannelNow(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error)
59 UpdateChannel(ctx context.Context, channelID uuid.UUID, request ChannelUpdateRequest) (Channel, error)
60 CreateChannelIAM(ctx context.Context, channelID uuid.UUID, saEmail string) (*model.ChannelIAMPolicy, error)
61 CreateChannelKeyVersion(ctx context.Context, ckv ChannelKeyVersion) (latest ChannelKeyVersion, err error)
62 GetChannelsFromHelmConfig(ctx context.Context, configYaml *string) ([]Channel, error)
63 CreateHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error
64 GetBannerChannels(ctx context.Context, bannerEdgeID uuid.UUID) ([]BannerChannel, error)
65 DeleteChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error)
66 }
67
68 const sqlCreateChannel = `INSERT INTO channels(name, description, team, expire_buffer_duration, rotation_interval_duration)
69 VALUES ($1, $2, $3, $4, $5)
70 RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`
71
72
73 func (cs *ChannelService) CreateChannel(ctx context.Context, channel Channel) (Channel, error) {
74
75 if err := channel.validateCreate(); err != nil {
76 return channel, err
77 }
78
79 var row = cs.db.QueryRowContext(ctx, sqlCreateChannel,
80 channel.Name,
81 channel.Description,
82 channel.Team,
83 channel.ExpireBufferDuration,
84 channel.RotationIntervalDuration,
85 )
86
87 created, err := scanChannel(row)
88 if err != nil {
89 return created, fmt.Errorf("failed to create channel: %w", err)
90 }
91 return created, nil
92 }
93
94 const sqlUpdateChannel = `UPDATE channels
95 SET (team, description, expire_buffer_duration, rotation_interval_duration) = (
96 COALESCE($1, team),
97 COALESCE($2, description),
98 COALESCE($3, expire_buffer_duration),
99 COALESCE($4, rotation_interval_duration)
100 )
101 WHERE channel_id = $5
102 RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`
103
104
105
106
107
108
109 func (cs *ChannelService) UpdateChannel(ctx context.Context, channelID uuid.UUID, request ChannelUpdateRequest) (Channel, error) {
110
111 if err := request.validate(); err != nil {
112 return Channel{}, err
113 }
114 var row = cs.db.QueryRowContext(ctx, sqlUpdateChannel,
115 request.Team,
116 request.Description,
117 request.ExpireBufferDuration,
118 request.RotationIntervalDuration,
119 channelID,
120 )
121
122 updated, err := scanChannel(row)
123 if err != nil {
124 return updated, fmt.Errorf("failed to update channel: %w", err)
125 }
126 return updated, err
127 }
128
const (
	// sqlDeleteChannel deletes the channel with the given ID and returns the
	// deleted row's columns.
	sqlDeleteChannel = `DELETE FROM channels
WHERE channel_id = $1
RETURNING channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at`

	// sqlHelmWorkloadsForChannelExists yields a single boolean: whether any
	// helm_workloads_channels row references the given channel ID.
	sqlHelmWorkloadsForChannelExists = `SELECT EXISTS (
SELECT 1
FROM helm_workloads_channels
WHERE channel_id = $1
)`
)
138
139 func (cs *ChannelService) txCanDeleteChannel(ctx context.Context, tx *sql.Tx, channelID uuid.UUID) error {
140 var hasHelmWorkloads bool
141 var row = tx.QueryRowContext(ctx, sqlHelmWorkloadsForChannelExists, channelID)
142 if err := row.Scan(&hasHelmWorkloads); err != nil {
143 return fmt.Errorf("failed to check helm workloads used by channel: %w", err)
144 } else if hasHelmWorkloads {
145 return fmt.Errorf("channel is being used by helm workloads")
146 }
147
148 return nil
149 }
150
151
152
153
154 func (cs *ChannelService) DeleteChannel(ctx context.Context, channelID uuid.UUID, force bool) (Channel, error) {
155
156 tx, err := cs.db.BeginTx(ctx, nil)
157 if err != nil {
158 return Channel{}, err
159 }
160 defer tx.Rollback()
161
162 if !force {
163 if err := cs.txCanDeleteChannel(ctx, tx, channelID); err != nil {
164 return Channel{}, err
165 }
166 }
167
168 var row = tx.QueryRowContext(ctx, sqlDeleteChannel, channelID)
169 deleted, err := scanChannel(row)
170 if err != nil {
171 return Channel{}, fmt.Errorf("failed to delete channel: %w", err)
172 }
173
174 if err := tx.Commit(); err != nil {
175 return Channel{}, err
176 }
177 return deleted, nil
178 }
179
const (
	// sqlGetChannels selects every channel; the variants below append a
	// WHERE clause to it.
	sqlGetChannels = `SELECT channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at
FROM channels`

	// sqlGetChannelsWithIDs restricts the selection to an array of channel IDs.
	sqlGetChannelsWithIDs = sqlGetChannels + `
WHERE channel_id = ANY ($1)`

	// sqlGetChannelsForTeam restricts the selection to a single team.
	sqlGetChannelsForTeam = sqlGetChannels + `
WHERE team = $1`

	// sqlGetChannelsByName restricts the selection to an array of names.
	sqlGetChannelsByName = sqlGetChannels + `
WHERE name = ANY($1)`
)
191
192 func (cs *ChannelService) GetChannel(ctx context.Context, channelID uuid.UUID) (Channel, error) {
193 var row = cs.db.QueryRowContext(ctx, sqlGetChannelsWithIDs, []uuid.UUID{channelID})
194 return scanChannel(row)
195 }
196
197 func (cs *ChannelService) GetChannelByName(ctx context.Context, name string) (Channel, error) {
198 var row = cs.db.QueryRowContext(ctx, sqlGetChannelsByName, []string{name})
199 return scanChannel(row)
200 }
201
202
203
204
205 func (cs *ChannelService) GetChannelsByName(ctx context.Context, names ...string) ([]Channel, error) {
206 if len(names) == 0 {
207 return nil, nil
208 }
209
210 rows, err := cs.db.QueryContext(ctx, sqlGetChannelsByName, names)
211 if err != nil {
212 return nil, fmt.Errorf("failed to query channels by name: %w", err)
213 }
214 defer rows.Close()
215
216 var channels = make(map[string]Channel)
217 for rows.Next() {
218 channel, err := scanChannel(rows)
219 if err != nil {
220 return nil, err
221 }
222 channels[channel.Name] = channel
223 }
224 if err := rows.Err(); err != nil {
225 return nil, fmt.Errorf("failed to scan all channels by name: %w", err)
226 }
227
228 var ordered []Channel
229 for _, name := range names {
230 channel, found := channels[name]
231 if !found {
232 return nil, fmt.Errorf("%w: %s", ErrChannelDoesNotExist, name)
233 }
234 ordered = append(ordered, channel)
235 }
236 return ordered, nil
237 }
238
239
240
241
242
243
244
245
246
247 func (cs *ChannelService) GetChannelsFromHelmConfig(ctx context.Context, configYaml *string) ([]Channel, error) {
248 if configYaml == nil {
249 return make([]Channel, 0), nil
250 }
251 config, err := ParseHelmConfigChannels(*configYaml)
252 if err != nil {
253 return nil, err
254 }
255
256 if !config.HasChannels() {
257 return nil, nil
258 }
259
260 return cs.GetChannelsByName(ctx, config.Names()...)
261 }
262
263
264
265
266
267 func (cs *ChannelService) GetChannels(ctx context.Context, channelIDs ...uuid.UUID) ([]Channel, error) {
268 var args []interface{}
269 var stmt = sqlGetChannels
270 if len(channelIDs) > 0 {
271 args = append(args, channelIDs)
272 stmt = sqlGetChannelsWithIDs
273 }
274
275 rows, err := cs.db.QueryContext(ctx, stmt, args...)
276 if err != nil {
277 return nil, fmt.Errorf("failed to query channels: %w", err)
278 }
279 defer rows.Close()
280
281 var channels []Channel
282 for rows.Next() {
283 channel, err := scanChannel(rows)
284 if err != nil {
285 return nil, err
286 }
287 channels = append(channels, channel)
288 }
289 if err := rows.Err(); err != nil {
290 return nil, fmt.Errorf("failed to scan all channels: %w", err)
291 }
292
293
294 if len(channelIDs) == 0 {
295 return channels, nil
296 }
297
298
299
300 var m = make(map[uuid.UUID]Channel)
301 for _, channel := range channels {
302 m[channel.ID] = channel
303 }
304
305 var ordered []Channel
306 for _, channelID := range channelIDs {
307 channel, found := m[channelID]
308 if !found {
309 return nil, fmt.Errorf("%w: %s", ErrChannelDoesNotExist, channelID)
310 }
311 ordered = append(ordered, channel)
312 }
313 return ordered, nil
314 }
315
316
317 func (cs *ChannelService) GetChannelsForTeam(ctx context.Context, team string) ([]Channel, error) {
318 rows, err := cs.db.QueryContext(ctx, sqlGetChannelsForTeam, team)
319 if err != nil {
320 return nil, fmt.Errorf("failed to query channels for team: %w", err)
321 }
322 defer rows.Close()
323
324 var channels []Channel
325 for rows.Next() {
326 channel, err := scanChannel(rows)
327 if err != nil {
328 return nil, err
329 }
330 channels = append(channels, channel)
331 }
332 if err := rows.Err(); err != nil {
333 return nil, fmt.Errorf("failed to scan all channels for team: %w", err)
334 }
335
336 if len(channels) == 0 {
337 return nil, fmt.Errorf("%w for team: %q", ErrChannelDoesNotExist, team)
338 }
339 return channels, nil
340 }
341
342
343
344
345
346
347
348
349
350
351 func (cs *ChannelService) GetBannerChannels(ctx context.Context, bannerEdgeID uuid.UUID) ([]BannerChannel, error) {
352 tx, err := cs.db.BeginTx(ctx, nil)
353 if err != nil {
354 return nil, err
355 }
356 defer tx.Rollback()
357
358 hwcm, err := cs.txGetHelmWorkloadsChannelsForBanner(ctx, tx, bannerEdgeID)
359 if err != nil {
360 return nil, err
361 }
362
363 ckvm, err := cs.txGetChannelKeyVersionsForBanner(ctx, tx, bannerEdgeID)
364 if err != nil {
365 return nil, err
366 }
367
368
369 var dedupMap = make(map[uuid.UUID]struct{})
370 for id := range hwcm {
371 dedupMap[id] = struct{}{}
372 }
373 for id := range ckvm {
374 dedupMap[id] = struct{}{}
375 }
376 var dedup []uuid.UUID
377 for id := range dedupMap {
378 dedup = append(dedup, id)
379 }
380
381
382 if len(dedup) == 0 {
383 return nil, ErrNoBannerChannels
384 }
385
386
387 rows, err := tx.QueryContext(ctx, sqlGetChannelsWithIDs, dedup)
388 if err != nil {
389 return nil, fmt.Errorf("failed to query banner channels: %w", err)
390 }
391 defer rows.Close()
392
393 var bannerChannels []BannerChannel
394 for rows.Next() {
395 channel, err := scanChannel(rows)
396 if err != nil {
397 return nil, err
398 }
399 bannerChannels = append(bannerChannels, BannerChannel{
400 Channel: channel,
401 KeyVersions: ckvm[channel.ID],
402 HelmEdgeIDs: hwcm[channel.ID],
403 })
404 }
405 if err := rows.Err(); err != nil {
406 return nil, fmt.Errorf("failed to scan all channels with IDs: %w", err)
407 }
408 return bannerChannels, nil
409 }
410
411 const sqlGetHelmWorkloadsChannelsForBanner = `SELECT channel_id, helm_edge_id
412 FROM helm_workloads_channels
413 WHERE helm_edge_id = ANY (
414 SELECT helm_edge_id
415 FROM helm_workloads
416 WHERE banner_edge_id = $1
417 )`
418
419
420 func (cs *ChannelService) txGetHelmWorkloadsChannelsForBanner(ctx context.Context, tx *sql.Tx, bannerEdgeID uuid.UUID) (map[uuid.UUID][]uuid.UUID, error) {
421 rows, err := tx.QueryContext(ctx, sqlGetHelmWorkloadsChannelsForBanner, bannerEdgeID)
422 if err != nil {
423 return nil, fmt.Errorf("failed to query helm workloads channels for banner: %w", err)
424 }
425 defer rows.Close()
426
427 var m = make(map[uuid.UUID][]uuid.UUID)
428 for rows.Next() {
429 var channelID, helmEdgeID uuid.UUID
430 if err := rows.Scan(&channelID, &helmEdgeID); err != nil {
431 return nil, fmt.Errorf("failed to scan helm workloads channels for banner: %w", err)
432 }
433 m[channelID] = append(m[channelID], helmEdgeID)
434 }
435 if err := rows.Err(); err != nil {
436 return nil, fmt.Errorf("failed to scan all helm workloads channels for banner: %w", err)
437 }
438 return m, nil
439 }
440
441 const sqlGetChannelKeyVersionsForBanner = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at
442 FROM channels_key_versions
443 WHERE banner_edge_id = $1
444 ORDER BY version DESC`
445
446
447 func (cs *ChannelService) txGetChannelKeyVersionsForBanner(ctx context.Context, tx *sql.Tx, bannerEdgeID uuid.UUID) (map[uuid.UUID][]ChannelKeyVersion, error) {
448 rows, err := tx.QueryContext(ctx, sqlGetChannelKeyVersionsForBanner, bannerEdgeID)
449 if err != nil {
450 return nil, fmt.Errorf("failed to query channel key versions for banner: %w", err)
451 }
452 defer rows.Close()
453
454 var m = make(map[uuid.UUID][]ChannelKeyVersion)
455 for rows.Next() {
456 ckv, err := scanChannelKeyVersion(rows)
457 if err != nil {
458 return nil, err
459 }
460 m[ckv.ChannelID] = append(m[ckv.ChannelID], ckv)
461 }
462 if err := rows.Err(); err != nil {
463 return nil, fmt.Errorf("failed to scan all unexpired channels key versions for banner: %w", err)
464 }
465 return m, nil
466 }
467
// sqlCreateHelmWorkloadChannel links one channel to one helm workload.
// A pair that already exists (per the unique_channel_id_helm_edge_id
// constraint) is silently skipped, so the statement is safe to repeat.
const sqlCreateHelmWorkloadChannel = `INSERT INTO helm_workloads_channels(channel_id, helm_edge_id)
VALUES ($1, $2)
ON CONFLICT
ON CONSTRAINT unique_channel_id_helm_edge_id
DO NOTHING`
473
474 func (cs *ChannelService) CreateHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error {
475 if len(channelIDs) == 0 {
476 return fmt.Errorf("channelIDs must not be empty")
477 }
478
479 tx, err := cs.db.BeginTx(ctx, nil)
480 if err != nil {
481 return err
482 }
483 defer tx.Rollback()
484
485 for _, channelID := range channelIDs {
486 _, err := tx.ExecContext(ctx, sqlCreateHelmWorkloadChannel, channelID, helmEdgeID)
487 if err != nil {
488 return fmt.Errorf("failed to create helm workload channel for channel %q: %w", channelID, err)
489 }
490 }
491
492 return tx.Commit()
493 }
494
const (
	// sqlDeleteHelmWorkloadChannels removes every channel link of one helm
	// workload; the variants below narrow it with a second parameter.
	sqlDeleteHelmWorkloadChannels = `DELETE FROM helm_workloads_channels
WHERE helm_edge_id = $1`

	// sqlDeleteOmittedHelmWorkloadChannels removes only the links whose
	// channel is NOT in the given array.
	sqlDeleteOmittedHelmWorkloadChannels = sqlDeleteHelmWorkloadChannels + `
AND NOT channel_id = ANY($2)`

	// sqlDeleteHelmWorkloadChannelsWithChannelIDs removes only the links
	// whose channel IS in the given array.
	sqlDeleteHelmWorkloadChannelsWithChannelIDs = sqlDeleteHelmWorkloadChannels + `
AND channel_id = ANY($2)`
)
503
504
505
506
507
508 func (cs *ChannelService) SetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error {
509 tx, err := cs.db.BeginTx(ctx, nil)
510 if err != nil {
511 return err
512 }
513 defer tx.Rollback()
514
515 if len(channelIDs) == 0 {
516 _, err = tx.ExecContext(ctx, sqlDeleteHelmWorkloadChannels, helmEdgeID)
517 } else {
518 _, err = tx.ExecContext(ctx, sqlDeleteOmittedHelmWorkloadChannels, helmEdgeID, channelIDs)
519 }
520 if err != nil {
521 return fmt.Errorf("failed to set deleted helm workload channels: %w", err)
522 }
523
524 for _, channelID := range channelIDs {
525 _, err := tx.ExecContext(ctx, sqlCreateHelmWorkloadChannel, channelID, helmEdgeID)
526 if err != nil {
527 return fmt.Errorf("failed to set created helm workload channels: %w", err)
528 }
529 }
530
531 return tx.Commit()
532 }
533
534
535
536 func (cs *ChannelService) DeleteHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID, channelIDs ...uuid.UUID) error {
537 var err error
538 if len(channelIDs) == 0 {
539 _, err = cs.db.ExecContext(ctx, sqlDeleteHelmWorkloadChannels, helmEdgeID)
540 } else {
541 _, err = cs.db.ExecContext(ctx, sqlDeleteHelmWorkloadChannelsWithChannelIDs, helmEdgeID, channelIDs)
542 }
543
544 if err != nil {
545 return fmt.Errorf("failed to delete helm workload channels: %w", err)
546 }
547 return nil
548 }
549
550 const sqlGetHelmWorkloadChannels = `SELECT channel_id, name, description, team, expire_buffer_duration, rotation_interval_duration, created_at
551 FROM channels
552 WHERE channel_id IN (
553 SELECT channel_id FROM helm_workloads_channels WHERE helm_edge_id = $1
554 )`
555
556 func (cs *ChannelService) GetHelmWorkloadChannels(ctx context.Context, helmEdgeID uuid.UUID) ([]Channel, error) {
557 rows, err := cs.db.QueryContext(ctx, sqlGetHelmWorkloadChannels, helmEdgeID)
558 if err != nil {
559 return nil, fmt.Errorf("failed to query helm workload channels: %w", err)
560 }
561 defer rows.Close()
562
563 var channels []Channel
564 for rows.Next() {
565 channel, err := scanChannel(rows)
566 if err != nil {
567 return nil, err
568 }
569 channels = append(channels, channel)
570 }
571 if err := rows.Err(); err != nil {
572 return nil, fmt.Errorf("failed to scan all helm workload channels: %w", err)
573 }
574
575 return channels, nil
576 }
577
578 const sqlGetHelmWorkloadsForChannel = `SELECT helm_edge_id
579 FROM helm_workloads_channels
580 WHERE channel_id = $1
581 AND helm_edge_id = ANY (SELECT helm_edge_id FROM helm_workloads WHERE banner_edge_id = $2)`
582
583
584 func (cs *ChannelService) GetHelmWorkloadsForChannel(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (helmEdgeIDs []uuid.UUID, err error) {
585 rows, err := cs.db.QueryContext(ctx, sqlGetHelmWorkloadsForChannel, channelID, bannerEdgeID)
586 if err != nil {
587 return nil, fmt.Errorf("failed to query helm workloads for channel: %w", err)
588 }
589 defer rows.Close()
590
591 for rows.Next() {
592 var helmEdgeID uuid.UUID
593 if err := rows.Scan(&helmEdgeID); err != nil {
594 return nil, fmt.Errorf("failed to scan helm workloads for channel: %w", err)
595 }
596 helmEdgeIDs = append(helmEdgeIDs, helmEdgeID)
597 }
598 if err := rows.Err(); err != nil {
599 return nil, fmt.Errorf("failed to scan all helm workloads for channel: %w", err)
600 }
601 return helmEdgeIDs, nil
602 }
603
604 const sqlCreateChannelKeyVersion = `WITH convert_channel_durations_to_seconds AS (
605 SELECT
606 expire_buffer_duration / ` + secondStr + ` AS expire_buffer_seconds,
607 rotation_interval_duration / ` + secondStr + ` AS rotation_interval_seconds
608 FROM channels
609 WHERE channel_id = $1
610 LIMIT 1
611 )
612 INSERT INTO channels_key_versions (channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at)
613 VALUES (
614 $1,
615 $2,
616 $3,
617 $4,
618 NOW() + INTERVAL '1 SECOND' * (SELECT rotation_interval_seconds + expire_buffer_seconds FROM convert_channel_durations_to_seconds),
619 NOW() + INTERVAL '1 SECOND' * (SELECT rotation_interval_seconds FROM convert_channel_durations_to_seconds)
620 )
621 RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
622
623
624
625
626
627
628
629 func (cs *ChannelService) CreateChannelKeyVersion(ctx context.Context, ckv ChannelKeyVersion) (latest ChannelKeyVersion, err error) {
630 if err := ckv.validateCreate(); err != nil {
631 return ChannelKeyVersion{}, err
632 }
633
634 tx, err := cs.db.BeginTx(ctx, nil)
635 if err != nil {
636 return ChannelKeyVersion{}, err
637 }
638 defer tx.Rollback()
639
640 if err := cs.txMarkLatestChannelKeyVersionAsRotated(ctx, tx, ckv.BannerEdgeID, ckv.ChannelID); err != nil {
641 return ChannelKeyVersion{}, err
642 }
643
644 var row = tx.QueryRowContext(ctx, sqlCreateChannelKeyVersion,
645 ckv.ChannelID,
646 ckv.BannerEdgeID,
647 ckv.Version,
648 ckv.SecretManagerLink,
649 )
650
651 latest, err = scanChannelKeyVersion(row)
652 if err != nil {
653 return ChannelKeyVersion{}, fmt.Errorf("failed to create channel key version: %w", err)
654 }
655
656 if err = tx.Commit(); err != nil {
657 return ChannelKeyVersion{}, fmt.Errorf("failed to commit created channel key version: %w", err)
658 }
659
660 return latest, nil
661 }
662
663
664 const sqlTxMarkLatestChannelKeyVersionAsRotated = `WITH calc_expire_buffer_seconds AS (
665 SELECT expire_buffer_duration / ` + secondStr + ` AS seconds
666 FROM channels
667 WHERE channel_id = $1
668 LIMIT 1
669 ),
670 calc_expire_at AS (
671 SELECT NOW() + INTERVAL '1 SECOND' * (SELECT seconds FROM calc_expire_buffer_seconds)
672 AS expire_at
673 )
674 UPDATE channels_key_versions
675 SET
676 expire_at = LEAST(expire_at, (SELECT expire_at FROM calc_expire_at)),
677 rotate_at = NULL
678 WHERE channel_id = $1
679 AND banner_edge_id = $2
680 AND rotate_at IS NOT NULL`
681
682
683
684
685 func (cs *ChannelService) txMarkLatestChannelKeyVersionAsRotated(ctx context.Context, tx *sql.Tx, bannerEdgeID, channelID uuid.UUID) error {
686 _, err := tx.ExecContext(ctx, sqlTxMarkLatestChannelKeyVersionAsRotated, channelID, bannerEdgeID)
687 if err != nil {
688 return fmt.Errorf("failed to mark the latest channel key version as rotated: %w", err)
689 }
690 return nil
691 }
692
693 const sqlDeleteChannelKeyVersion = `DELETE FROM channels_key_versions
694 WHERE channel_key_versions_id = $1
695 AND expire_at < NOW()
696 RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
697
698
699
700
701 func (cs *ChannelService) DeleteChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) {
702 row := cs.db.QueryRowContext(ctx, sqlDeleteChannelKeyVersion, id)
703 ckv, err := scanChannelKeyVersion(row)
704 if err != nil {
705 return ckv, fmt.Errorf("failed to delete channel key version: %w", err)
706 }
707 return ckv, nil
708 }
709
710 const sqlGetChannelKeyVersionByID = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at
711 FROM channels_key_versions
712 WHERE channel_key_versions_id = $1`
713
714 func (cs *ChannelService) GetChannelKeyVersion(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) {
715 var row = cs.db.QueryRowContext(ctx, sqlGetChannelKeyVersionByID, id)
716 ckv, err := scanChannelKeyVersion(row)
717 if err != nil {
718 return ckv, fmt.Errorf("failed to scan channel key version: %w", err)
719 }
720 return ckv, nil
721 }
722
const (
	// sqlGetChannelKeyVersions selects all key versions of one banner ($1)
	// and channel ($2); the variants below extend it.
	sqlGetChannelKeyVersions = `SELECT channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at
FROM channels_key_versions
WHERE banner_edge_id = $1
AND channel_id = $2`

	// sqlGetChannelKeyVersionsOrdered returns them highest version first.
	sqlGetChannelKeyVersionsOrdered = sqlGetChannelKeyVersions + `
ORDER BY version DESC`

	// sqlGetLatestChannelKeyVersion keeps only versions that still have a
	// rotate_at, i.e. that have not been marked as rotated.
	sqlGetLatestChannelKeyVersion = sqlGetChannelKeyVersions + `
AND rotate_at IS NOT NULL`
)
733
734 func (cs *ChannelService) GetChannelKeyVersions(ctx context.Context, bannerEdgeID, channelID uuid.UUID) ([]ChannelKeyVersion, error) {
735 rows, err := cs.db.QueryContext(ctx, sqlGetChannelKeyVersionsOrdered, bannerEdgeID, channelID)
736 if err != nil {
737 return nil, fmt.Errorf("failed to query channel key versions: %w", err)
738 }
739 defer rows.Close()
740
741 var ckvs []ChannelKeyVersion
742 for rows.Next() {
743 ckv, err := scanChannelKeyVersion(rows)
744 if err != nil {
745 return nil, err
746 }
747 ckvs = append(ckvs, ckv)
748 }
749 if err := rows.Err(); err != nil {
750 return nil, fmt.Errorf("failed to scan all channel key versions: %w", err)
751 }
752 return ckvs, nil
753 }
754
755
756
757
758 func (cs *ChannelService) GetLatestChannelKeyVersion(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error) {
759 var row = cs.db.QueryRowContext(ctx, sqlGetLatestChannelKeyVersion, bannerEdgeID, channelID)
760 return scanChannelKeyVersion(row)
761 }
762
763
818
819
820
821 const sqlRotateChannelNow = `UPDATE channels_key_versions
822 SET rotate_at = LEAST(rotate_at, NOW())
823 WHERE channel_id = $1
824 AND banner_edge_id = $2
825 AND rotate_at IS NOT NULL
826 RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
827
828
829
830
831 func (cs *ChannelService) RotateChannelNow(ctx context.Context, bannerEdgeID, channelID uuid.UUID) (ChannelKeyVersion, error) {
832 row := cs.db.QueryRowContext(ctx, sqlRotateChannelNow, channelID, bannerEdgeID)
833 ckv, err := scanChannelKeyVersion(row)
834 if err != nil {
835 return ckv, fmt.Errorf("failed to rotate channel now: %w", err)
836 }
837 return ckv, nil
838 }
839
840
841 const sqlExpireRotatedChannelKeyVersionNow = `UPDATE channels_key_versions
842 SET
843 expire_at = LEAST(expire_at, NOW())
844 WHERE channel_key_versions_id = $1
845 AND rotate_at IS NULL
846 RETURNING channel_key_versions_id, channel_id, banner_edge_id, version, sm_link, expire_at, rotate_at, created_at`
847
848
849
850
851
852
853
854
855 func (cs *ChannelService) ExpireRotatedChannelKeyVersionNow(ctx context.Context, id uuid.UUID) (ChannelKeyVersion, error) {
856 row := cs.db.QueryRowContext(ctx, sqlExpireRotatedChannelKeyVersionNow, id)
857 ckv, err := scanChannelKeyVersion(row)
858 if err != nil {
859 return ckv, fmt.Errorf("failed to expire rotated channel key version now: %w", err)
860 }
861 return ckv, nil
862 }
863
View as plain text