package services import ( "context" "database/sql" "fmt" "strings" "edge-infra.dev/pkg/edge/api/graph/model" ) //go:generate mockgen -destination=../mocks/mock_ca_bundle_service.go -package=mocks edge-infra.dev/pkg/edge/api/services CABundleService type CABundleService interface { GetCABundle(ctx context.Context, bannerID string, topLevelProject string) (*model.CaBundle, error) } type caBundleService struct { gcpClientService GcpClientService sqlDB *sql.DB } const ( getCAPoolID = `SELECT ca_pool_edge_id FROM ca_pools WHERE banner_edge_id = $1` getCert = "SELECT cert_ref, status FROM ca_certificates WHERE ca_pool_edge_id = $1" ) func (o *caBundleService) GetCABundle(ctx context.Context, bannerID string, topLevelProject string) (*model.CaBundle, error) { response := &model.CaBundle{} var caCerts []string var poolID string //First we need to use the provided bannerID to retrieve the CA pool ID err := o.sqlDB.QueryRowContext(ctx, getCAPoolID, bannerID).Scan(&poolID) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to retrieve CA pool: %w", err) } else if err == sql.ErrNoRows { return nil, fmt.Errorf("no CA pool found for banner ID: %s", bannerID) } // Now we have the pool ID, we can proceed to get the CA bundle var caBundle string var currentCert string var status string // Retrieve the retired CA certificates // there could be many retired certs so loop through to add them rows, err := o.sqlDB.QueryContext(ctx, getCert, poolID) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to retrieve retired CA certificates: %w", err) } defer rows.Close() for rows.Next() { if err := rows.Scan(¤tCert, &status); err != nil { return nil, fmt.Errorf("failed to scan CA certificate: %w", err) } certificate, err := o.retrieveCert(ctx, currentCert, topLevelProject) if err != nil { return nil, fmt.Errorf("failed to retrieve CA certificate: %w", err) } if status != "deleted" { caBundle += certificate caCerts = append(caCerts, certificate) } } response.CaBundle = caBundle response.CaCerts = caCerts return response, nil } func (o *caBundleService) retrieveCert(ctx context.Context, certRef string, topLevelProject string) (string, error) { // Find the last occurrence of '-' lastDash := strings.LastIndex(certRef, "-") if lastDash == -1 { return "", fmt.Errorf("invalid format") } // Extract name and version certName := certRef[:lastDash] // Everything before the last '-' certVersion := certRef[lastDash+1:] // Everything after the last '-' secretManager, err := o.gcpClientService.GetSecretClient(ctx, topLevelProject) if err != nil { return "", fmt.Errorf("failed to get secret manager client: %w", err) } smResponse, err := secretManager.GetSecretVersionValue(ctx, certName, certVersion) if err != nil { return "", fmt.Errorf("failed to get secret version value: %w", err) } return string(smResponse), nil } func NewCABundleService(sqlDB *sql.DB, gcpClientService GcpClientService) CABundleService { return &caBundleService{ sqlDB: sqlDB, gcpClientService: gcpClientService, } }