1 package services
2
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"edge-infra.dev/pkg/edge/api/graph/model"
)
11
12
13 type CABundleService interface {
14 GetCABundle(ctx context.Context, bannerID string, topLevelProject string) (*model.CaBundle, error)
15 }
16
17 type caBundleService struct {
18 gcpClientService GcpClientService
19 sqlDB *sql.DB
20 }
21
// SQL statements issued by caBundleService.
const (
	// getCAPoolID looks up the CA pool ID registered for a banner ($1).
	getCAPoolID = `SELECT ca_pool_edge_id FROM ca_pools WHERE banner_edge_id = $1`

	// getCert lists every certificate reference and its status in a CA pool ($1).
	getCert = `SELECT cert_ref, status FROM ca_certificates WHERE ca_pool_edge_id = $1`
)
26
27 func (o *caBundleService) GetCABundle(ctx context.Context, bannerID string, topLevelProject string) (*model.CaBundle, error) {
28 response := &model.CaBundle{}
29 var caCerts []string
30 var poolID string
31
32
33 err := o.sqlDB.QueryRowContext(ctx, getCAPoolID, bannerID).Scan(&poolID)
34 if err != nil && err != sql.ErrNoRows {
35 return nil, fmt.Errorf("failed to retrieve CA pool: %w", err)
36 } else if err == sql.ErrNoRows {
37 return nil, fmt.Errorf("no CA pool found for banner ID: %s", bannerID)
38 }
39
40
41 var caBundle string
42 var currentCert string
43 var status string
44
45
46
47 rows, err := o.sqlDB.QueryContext(ctx, getCert, poolID)
48 if err != nil && err != sql.ErrNoRows {
49 return nil, fmt.Errorf("failed to retrieve retired CA certificates: %w", err)
50 }
51 defer rows.Close()
52 for rows.Next() {
53 if err := rows.Scan(¤tCert, &status); err != nil {
54 return nil, fmt.Errorf("failed to scan CA certificate: %w", err)
55 }
56 certificate, err := o.retrieveCert(ctx, currentCert, topLevelProject)
57 if err != nil {
58 return nil, fmt.Errorf("failed to retrieve CA certificate: %w", err)
59 }
60
61 if status != "deleted" {
62 caBundle += certificate
63 caCerts = append(caCerts, certificate)
64 }
65 }
66
67 response.CaBundle = caBundle
68 response.CaCerts = caCerts
69 return response, nil
70 }
71
72 func (o *caBundleService) retrieveCert(ctx context.Context, certRef string, topLevelProject string) (string, error) {
73
74 lastDash := strings.LastIndex(certRef, "-")
75 if lastDash == -1 {
76 return "", fmt.Errorf("invalid format")
77 }
78
79
80 certName := certRef[:lastDash]
81 certVersion := certRef[lastDash+1:]
82
83 secretManager, err := o.gcpClientService.GetSecretClient(ctx, topLevelProject)
84 if err != nil {
85 return "", fmt.Errorf("failed to get secret manager client: %w", err)
86 }
87 smResponse, err := secretManager.GetSecretVersionValue(ctx, certName, certVersion)
88 if err != nil {
89 return "", fmt.Errorf("failed to get secret version value: %w", err)
90 }
91 return string(smResponse), nil
92 }
93
94 func NewCABundleService(sqlDB *sql.DB, gcpClientService GcpClientService) CABundleService {
95 return &caBundleService{
96 sqlDB: sqlDB,
97 gcpClientService: gcpClientService,
98 }
99 }
100
View as plain text