...

Source file src/edge-infra.dev/pkg/edge/api/services/ca_bundle_service.go

Documentation: edge-infra.dev/pkg/edge/api/services

     1  package services
     2  
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"edge-infra.dev/pkg/edge/api/graph/model"
)
    11  
    12  //go:generate mockgen -destination=../mocks/mock_ca_bundle_service.go -package=mocks edge-infra.dev/pkg/edge/api/services CABundleService
    13  type CABundleService interface {
    14  	GetCABundle(ctx context.Context, bannerID string, topLevelProject string) (*model.CaBundle, error)
    15  }
    16  
    17  type caBundleService struct {
    18  	gcpClientService GcpClientService
    19  	sqlDB            *sql.DB
    20  }
    21  
    22  const (
    23  	getCAPoolID = `SELECT ca_pool_edge_id FROM ca_pools WHERE banner_edge_id = $1`
    24  	getCert     = "SELECT cert_ref, status FROM ca_certificates WHERE ca_pool_edge_id = $1"
    25  )
    26  
    27  func (o *caBundleService) GetCABundle(ctx context.Context, bannerID string, topLevelProject string) (*model.CaBundle, error) {
    28  	response := &model.CaBundle{}
    29  	var caCerts []string
    30  	var poolID string
    31  
    32  	//First we need to use the provided bannerID to retrieve the CA pool ID
    33  	err := o.sqlDB.QueryRowContext(ctx, getCAPoolID, bannerID).Scan(&poolID)
    34  	if err != nil && err != sql.ErrNoRows {
    35  		return nil, fmt.Errorf("failed to retrieve CA pool: %w", err)
    36  	} else if err == sql.ErrNoRows {
    37  		return nil, fmt.Errorf("no CA pool found for banner ID: %s", bannerID)
    38  	}
    39  
    40  	// Now we have the pool ID, we can proceed to get the CA bundle
    41  	var caBundle string
    42  	var currentCert string
    43  	var status string
    44  
    45  	// Retrieve the retired CA certificates
    46  	// there could be many retired certs so loop through to add them
    47  	rows, err := o.sqlDB.QueryContext(ctx, getCert, poolID)
    48  	if err != nil && err != sql.ErrNoRows {
    49  		return nil, fmt.Errorf("failed to retrieve retired CA certificates: %w", err)
    50  	}
    51  	defer rows.Close()
    52  	for rows.Next() {
    53  		if err := rows.Scan(&currentCert, &status); err != nil {
    54  			return nil, fmt.Errorf("failed to scan CA certificate: %w", err)
    55  		}
    56  		certificate, err := o.retrieveCert(ctx, currentCert, topLevelProject)
    57  		if err != nil {
    58  			return nil, fmt.Errorf("failed to retrieve CA certificate: %w", err)
    59  		}
    60  
    61  		if status != "deleted" {
    62  			caBundle += certificate
    63  			caCerts = append(caCerts, certificate)
    64  		}
    65  	}
    66  
    67  	response.CaBundle = caBundle
    68  	response.CaCerts = caCerts
    69  	return response, nil
    70  }
    71  
    72  func (o *caBundleService) retrieveCert(ctx context.Context, certRef string, topLevelProject string) (string, error) {
    73  	// Find the last occurrence of '-'
    74  	lastDash := strings.LastIndex(certRef, "-")
    75  	if lastDash == -1 {
    76  		return "", fmt.Errorf("invalid format")
    77  	}
    78  
    79  	// Extract name and version
    80  	certName := certRef[:lastDash]      // Everything before the last '-'
    81  	certVersion := certRef[lastDash+1:] // Everything after the last '-'
    82  
    83  	secretManager, err := o.gcpClientService.GetSecretClient(ctx, topLevelProject)
    84  	if err != nil {
    85  		return "", fmt.Errorf("failed to get secret manager client: %w", err)
    86  	}
    87  	smResponse, err := secretManager.GetSecretVersionValue(ctx, certName, certVersion)
    88  	if err != nil {
    89  		return "", fmt.Errorf("failed to get secret version value: %w", err)
    90  	}
    91  	return string(smResponse), nil
    92  }
    93  
    94  func NewCABundleService(sqlDB *sql.DB, gcpClientService GcpClientService) CABundleService {
    95  	return &caBundleService{
    96  		sqlDB:            sqlDB,
    97  		gcpClientService: gcpClientService,
    98  	}
    99  }
   100  

View as plain text