...

Source file src/edge-infra.dev/pkg/edge/controllers/bannerctl/cert_management.go

Documentation: edge-infra.dev/pkg/edge/controllers/bannerctl

     1  package bannerctl
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/ecdsa"
     7  	"crypto/elliptic"
     8  	"crypto/rand"
     9  	"crypto/x509"
    10  	"crypto/x509/pkix"
    11  	"database/sql"
    12  	"encoding/base64"
    13  	"encoding/pem"
    14  	"errors"
    15  	"fmt"
    16  	"math/big"
    17  	"path/filepath"
    18  	"time"
    19  
    20  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    21  
    22  	bannerAPI "edge-infra.dev/pkg/edge/apis/banner/v1alpha1"
    23  	sequelApi "edge-infra.dev/pkg/edge/apis/sequel/k8s/v1alpha2"
    24  	"edge-infra.dev/pkg/edge/constants"
    25  	workloadApi "edge-infra.dev/pkg/edge/constants/api/workload"
    26  	"edge-infra.dev/pkg/k8s/runtime/controller/reconcile/recerr"
    27  	secretMgrApi "edge-infra.dev/pkg/lib/gcp/secretmanager"
    28  	"edge-infra.dev/pkg/lib/uuid"
    29  )
    30  
// SQL statements and certificate status values used while reconciling a
// banner's CA pool and CA certificates.
//
// The create* statements insert a row and return the generated edge ID; the
// get*Cert statements select the ID and expiration of a certificate in a given
// pool filtered by status; getCAPoolID looks a pool up by banner edge ID.
// staged, active, retired and deleted are the status values written to
// ca_certificates.status.
const (
	createCAPool   = "INSERT INTO ca_pools (banner_edge_id) VALUES ($1) RETURNING ca_pool_edge_id"
	createCACert   = "INSERT INTO ca_certificates (ca_pool_edge_id, status, cert_ref, private_key_ref, expiration) VALUES ($1, $2, $3, $4, $5) RETURNING ca_cert_edge_id"
	getActiveCert  = "SELECT ca_cert_edge_id, expiration FROM ca_certificates WHERE ca_pool_edge_id = $1 AND status = 'active'"
	getStagedCert  = "SELECT ca_cert_edge_id, expiration FROM ca_certificates WHERE ca_pool_edge_id = $1 AND status = 'staged'"
	getRetiredCert = "SELECT ca_cert_edge_id, expiration FROM ca_certificates WHERE ca_pool_edge_id = $1 AND status = 'retired'"
	getCAPoolID    = `SELECT ca_pool_edge_id FROM ca_pools WHERE banner_edge_id = $1`
	staged         = "staged"
	active         = "active"
	retired        = "retired"
	deleted        = "deleted"
)
    43  
    44  // reconcileCerts reconciles the CA pool and CA certs for a banner.
    45  func (r *BannerReconciler) reconcileCerts(ctx context.Context, banner *bannerAPI.Banner) recerr.Error {
    46  	log := r.Log.WithValues("banner", banner.Name)
    47  
    48  	// check or create CA pool per banner
    49  	poolID, err := r.reconcileCAPool(ctx, banner.Name)
    50  	if err != nil {
    51  		return recerr.New(err, bannerAPI.PlatformSecretsCreationFailedReason)
    52  	}
    53  
    54  	// check or create CA cert per banner
    55  	err = r.reconcileCACerts(ctx, banner, poolID)
    56  	if err != nil {
    57  		return recerr.New(err, bannerAPI.PlatformSecretsCreationFailedReason)
    58  	}
    59  
    60  	log.Info("Successfully reconciled CA pools and certs for banner", "banner", banner.Name)
    61  	return nil
    62  }
    63  
    64  // reconcileCAPool reconciles the CA pool for a banner.
    65  func (r *BannerReconciler) reconcileCAPool(ctx context.Context, bannerID string) (string, error) {
    66  	tx, err := r.EdgeDB.BeginTx(ctx, &sql.TxOptions{})
    67  	if err != nil {
    68  		return "", err
    69  	}
    70  
    71  	defer func() {
    72  		if err != nil {
    73  			err = errors.Join(err, tx.Rollback())
    74  		}
    75  	}()
    76  
    77  	log := r.Log.WithValues("banner", bannerID)
    78  
    79  	var poolID string
    80  	err = tx.QueryRowContext(ctx, getCAPoolID, bannerID).Scan(&poolID)
    81  	if err != nil && err != sql.ErrNoRows {
    82  		return "", fmt.Errorf("failed to retrieve CA pool: %w", err)
    83  	} else if err == nil && poolID != "" {
    84  		log.Info("CA pool already exists", "caPool", poolID)
    85  		return poolID, nil
    86  	}
    87  
    88  	err = tx.QueryRowContext(ctx, createCAPool, bannerID).Scan(&poolID)
    89  	if err != nil {
    90  		return "", fmt.Errorf("failed to insert CA pool into database: %w", err)
    91  	}
    92  
    93  	if err = tx.Commit(); err != nil {
    94  		return "", err
    95  	}
    96  
    97  	log.Info("Created CA pool", "caPool", poolID)
    98  	return poolID, nil
    99  }
   100  
   101  // reconcileCACerts reconciles the CA certs for a banner.
   102  // When a CA is approaching expiry (1 year left), a new CA cert is generated, with status of staged
   103  // After 3 months (tbd) of the new CA cert existing, the old CA cert is set to retired and the new CA cert is marked as active
   104  // If the CA cert is expired, it is marked as deleted
   105  func (r *BannerReconciler) reconcileCACerts(ctx context.Context, b *bannerAPI.Banner, caPool string) error {
   106  	tx, err := r.EdgeDB.BeginTx(ctx, &sql.TxOptions{})
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	defer func() {
   112  		if err != nil {
   113  			err = errors.Join(err, tx.Rollback())
   114  		}
   115  	}()
   116  
   117  	log := r.Log.WithValues("banner", b.Name)
   118  
   119  	// STEP 1 : Check if there is an active cert
   120  	// if there is no active cert, check if there is a staged cert
   121  	// if there is a staged cert make that active
   122  	// if there is no staged cert, create a new active cert
   123  
   124  	var activeCertID string
   125  	var activeCertExpiration time.Time
   126  	var stagedCertID string
   127  	var stagedCertExpiration time.Time
   128  	// This will retrieve the active cert for a given CA pool edge ID
   129  	// if there is no active cert, it will create one
   130  	err = tx.QueryRowContext(ctx, getActiveCert, caPool).Scan(&activeCertID, &activeCertExpiration)
   131  	if err != nil && err != sql.ErrNoRows {
   132  		return fmt.Errorf("failed to query active CA cert: %w", err)
   133  	} else if err == sql.ErrNoRows {
   134  		// if there is no active but there is a staged then we want to make that active
   135  		err = tx.QueryRowContext(ctx, getStagedCert, caPool).Scan(&stagedCertID, &activeCertExpiration)
   136  		if err != nil && err != sql.ErrNoRows {
   137  			return fmt.Errorf("failed to query staged CA cert: %w", err)
   138  		} else if err == sql.ErrNoRows {
   139  			// if there is no active and no staged then we want to create a new active
   140  			return r.createCA(ctx, b, caPool, active)
   141  		}
   142  		// if there is no errors then we want to make the staged cert active
   143  		// this should ideally not be used, but is a catch for if the one cycle rotation fails for any reason
   144  		err = updateCertificateStatus(ctx, tx, stagedCertID, active)
   145  		if err != nil {
   146  			return fmt.Errorf("failed to update staged certificate status: %w", err)
   147  		}
   148  	}
   149  
   150  	//STEP 2 : Check if there is a staged cert
   151  
   152  	var stagedNeeded bool
   153  	var stagedExists bool
   154  	// This will retrieve the staged cert for a given CA pool edge ID
   155  	// if there is a staged cert make stagedExists true so we dont create another one
   156  	err = tx.QueryRowContext(ctx, getStagedCert, caPool).Scan(&stagedCertID, &stagedCertExpiration)
   157  	if err != nil && err != sql.ErrNoRows {
   158  		return fmt.Errorf("failed to query staged CA cert: %w", err)
   159  	} else if err == sql.ErrNoRows {
   160  		stagedExists = false
   161  	} else {
   162  		stagedExists = true
   163  	}
   164  
   165  	// STEP 3 : Check if the active cert is expired or close to expiration
   166  	// if it is close to expiration, check if there is a staged cert
   167  	// if there is a staged cert, make the active cert retired and the staged cert active
   168  	// if there is no staged cert, create a new staged cert
   169  
   170  	// if the active cert is less than 9 months away from expiration and there is a staged cert
   171  	// we will mark the active cert as retired and the staged cert as active
   172  	// if the active cert is less than 1 year & 3 months away from expiration and there is no staged cert
   173  	// we will mark stagedNeeded as true so that we can create a new staged cert
   174  	if activeCertExpiration.Before(time.Now().AddDate(1, 3, 0)) &&
   175  		activeCertExpiration.Before(time.Now().AddDate(0, 9, 0)) && stagedExists {
   176  		//update the current active cert to retired
   177  		err = updateCertificateStatus(ctx, tx, activeCertID, retired)
   178  		if err != nil {
   179  			return fmt.Errorf("failed to update certificate status: %w", err)
   180  		}
   181  		log.Info("Updated active cert to retired", "certID", activeCertID)
   182  		//update the current staged cert to active
   183  		//this allows us to perform the rotation in one cycle
   184  		//if the above step fails for some reason then staged will become active on the next cycle
   185  		err = updateCertificateStatus(ctx, tx, stagedCertID, active)
   186  		if err != nil {
   187  			return fmt.Errorf("failed to update certificate status: %w", err)
   188  		}
   189  		log.Info("Updated staged cert to active", "certID", stagedCertID)
   190  	} else if activeCertExpiration.Before(time.Now().AddDate(1, 3, 0)) {
   191  		stagedNeeded = true
   192  	}
   193  
   194  	// if we need a staged cert and one does not exist, create one
   195  	if stagedNeeded && !stagedExists {
   196  		err = r.createCA(ctx, b, caPool, staged)
   197  		if err != nil {
   198  			return fmt.Errorf("failed to create CA cert: %w", err)
   199  		}
   200  	}
   201  
   202  	// STEP 4 : Check if the retired cert is expired
   203  	// if it is expired, mark it as deleted
   204  
   205  	var retiredCertID string
   206  	var retiredCertExpiration time.Time
   207  	// This will retrieve the retired cert for a given CA pool edge ID
   208  	// if there is a retired cert and it is expired, it will be marked as deleted
   209  	err = tx.QueryRowContext(ctx, getRetiredCert, caPool).Scan(&retiredCertID, &retiredCertExpiration)
   210  	if err != nil && err != sql.ErrNoRows {
   211  		return fmt.Errorf("failed to query retired CA cert: %w", err)
   212  	} else if err == nil {
   213  		if retiredCertExpiration.Before(time.Now()) {
   214  			err = updateCertificateStatus(ctx, tx, retiredCertID, deleted)
   215  			if err != nil {
   216  				return fmt.Errorf("failed to update certificate status: %w", err)
   217  			}
   218  			log.Info("Updated retired cert to deleted", "certID", retiredCertID)
   219  		}
   220  	}
   221  
   222  	if err = tx.Commit(); err != nil {
   223  		return err
   224  	}
   225  
   226  	log.Info("Successfully reconciled certs")
   227  	return nil
   228  }
   229  
   230  func (r *BannerReconciler) createCA(ctx context.Context, b *bannerAPI.Banner, caPoolID string, status string) error {
   231  	log := r.Log.WithValues("banner", b.Name)
   232  
   233  	// The below CA template is mimicking the CA that is created by emissary
   234  	expiration := time.Now().AddDate(5, 0, 0) // Valid for 5 years
   235  	template := &x509.Certificate{
   236  		SerialNumber: big.NewInt(1),
   237  		Subject: pkix.Name{
   238  			CommonName: "emissary-ca",
   239  		},
   240  		NotBefore:             time.Now(),
   241  		NotAfter:              expiration,
   242  		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
   243  		ExtKeyUsage:           []x509.ExtKeyUsage{},
   244  		IsCA:                  true,
   245  		BasicConstraintsValid: true,
   246  	}
   247  
   248  	cert, privateKey, err := generateCert(template)
   249  	if err != nil {
   250  		return recerr.New(err, bannerAPI.PlatformSecretsCreationFailedReason)
   251  	}
   252  
   253  	// creating a private key ref for the cert, which will be stored in SM
   254  	privateKeySecretName := "ca-private-key-" + b.Name
   255  	version, err := r.storeInSecretManager(ctx, privateKey, privateKeySecretName)
   256  	if err != nil {
   257  		return fmt.Errorf("failed to store CA cert in Secret Manager: %w", err)
   258  	}
   259  
   260  	// creating a cert ref for the CA cert, which will be stored in SM
   261  	certSecretName := "ca-cert-" + b.Name
   262  	certVersion, err := r.storeInSecretManager(ctx, cert, certSecretName)
   263  	if err != nil {
   264  		return fmt.Errorf("failed to store CA cert in Secret Manager: %w", err)
   265  	}
   266  
   267  	log.Info("Successfully stored CA cert in Secret Manager", "certVersion", certVersion)
   268  
   269  	// generate refs based off name + version
   270  	privateKeyRef := privateKeySecretName + "-" + version
   271  	certRef := certSecretName + "-" + certVersion
   272  
   273  	// add record to database
   274  	certID, err := r.addRecordToDatabase(ctx, caPoolID, certRef, privateKeyRef, status, expiration)
   275  	if err != nil {
   276  		return fmt.Errorf("failed to add CA cert to database: %w", err)
   277  	}
   278  
   279  	log.Info("Successfully created CA cert in database", "CA Cert ID", certID)
   280  	return nil
   281  }
   282  
   283  // generates a cert and returns the cert and private key in base64 format for storage in SM
   284  func generateCert(template *x509.Certificate) ([]byte, []byte, error) {
   285  	// Create private key
   286  	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   287  	if err != nil {
   288  		return nil, nil, fmt.Errorf("failed to generate private key: %v", err)
   289  	}
   290  
   291  	// Generate the certificate
   292  	caBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
   293  	if err != nil {
   294  		return nil, nil, fmt.Errorf("failed to create certificate: %v", err)
   295  	}
   296  
   297  	// Encode the certificate to PEM format
   298  	caPEM := new(bytes.Buffer)
   299  	err = pem.Encode(caPEM, &pem.Block{
   300  		Type:  "CERTIFICATE",
   301  		Bytes: caBytes,
   302  	})
   303  	if err != nil {
   304  		return nil, nil, fmt.Errorf("failed to encode certificate to PEM: %v", err)
   305  	}
   306  
   307  	// Encode the private key to PEM format
   308  	caPrivKeyPEM := new(bytes.Buffer)
   309  	privBytes, err := x509.MarshalECPrivateKey(privKey)
   310  	if err != nil {
   311  		return nil, nil, fmt.Errorf("failed to marshal private key: %v", err)
   312  	}
   313  	err = pem.Encode(caPrivKeyPEM, &pem.Block{
   314  		Type:  "EC PRIVATE KEY",
   315  		Bytes: privBytes,
   316  	})
   317  	if err != nil {
   318  		return nil, nil, fmt.Errorf("failed to encode private key to PEM: %v", err)
   319  	}
   320  
   321  	// convert to B64 for secrets -> the LS0tLS1CRUdJTiBSU0EgUFJ.... format
   322  	caCertBase64 := base64.StdEncoding.EncodeToString(caPEM.Bytes())
   323  	caPrivKeyBase64 := base64.StdEncoding.EncodeToString(caPrivKeyPEM.Bytes())
   324  
   325  	return []byte(caCertBase64), []byte(caPrivKeyBase64), nil
   326  }
   327  
   328  func (r *BannerReconciler) addRecordToDatabase(ctx context.Context, caPoolID string, certRef string, privateKeyRef string, status string, expiration time.Time) (string, error) {
   329  	tx, err := r.EdgeDB.BeginTx(ctx, &sql.TxOptions{})
   330  	if err != nil {
   331  		return "", err
   332  	}
   333  
   334  	defer func() {
   335  		if err != nil {
   336  			err = errors.Join(err, tx.Rollback())
   337  		}
   338  	}()
   339  
   340  	var caCertID string
   341  	err = tx.QueryRowContext(ctx, createCACert, caPoolID, status, certRef, privateKeyRef, expiration).Scan(&caCertID)
   342  	if err != nil {
   343  		return "", fmt.Errorf("failed to insert CA cert into database: %w", err)
   344  	}
   345  
   346  	if err = tx.Commit(); err != nil {
   347  		return "", err
   348  	}
   349  
   350  	return caCertID, nil
   351  }
   352  
   353  func (r *BannerReconciler) storeInSecretManager(ctx context.Context, secretData []byte, name string) (string, error) {
   354  	// Create a new SecretManagerClient
   355  	smClient, err := r.SecretManager.NewWithOptions(ctx, r.ForemanProjectID)
   356  	if err != nil {
   357  		return "", fmt.Errorf("error creating secretmanager writer client, err: %v", err)
   358  	}
   359  	labels := map[string]string{
   360  		secretMgrApi.SecretLabel:                  string(workloadApi.Platform),
   361  		secretMgrApi.SecretTypeLabel:              "banner-ca",
   362  		secretMgrApi.SecretOwnerLabel:             "edge",
   363  		secretMgrApi.SecretNamespaceSelectorLabel: string(constants.PlatformNamespaceSelector),
   364  	}
   365  	err = smClient.AddSecret(ctx, name, secretData, labels, false, nil, "")
   366  	if err != nil {
   367  		return "", fmt.Errorf("error adding secret, secretID: %v, err: %v", name, err)
   368  	}
   369  
   370  	fullName, err := smClient.GetLatestSecretValueInfo(ctx, name)
   371  	if err != nil {
   372  		return "", fmt.Errorf("error getting secret version, secretID: %v, err: %v", name, err)
   373  	}
   374  
   375  	return filepath.Base(fullName.GetName()), nil
   376  }
   377  
   378  func updateCertificateStatus(ctx context.Context, tx *sql.Tx, certID string, status string) error {
   379  	_, err := tx.ExecContext(ctx, "UPDATE ca_certificates SET status = $1 WHERE ca_cert_edge_id = $2", status, certID)
   380  	if err != nil {
   381  		return fmt.Errorf("failed to update certificate status: %w", err)
   382  	}
   383  
   384  	return nil
   385  }
   386  
   387  func (r *BannerReconciler) createEdgeIssuerDatabaseUser(b *bannerAPI.Banner) *sequelApi.DatabaseUser {
   388  	hash := uuid.FromUUID(b.Status.ClusterInfraClusterEdgeID).Hash()
   389  	edgeIssuerSAName := fmt.Sprintf("issuer-%s", hash)
   390  	iamUsername := fmt.Sprintf("%s@%s.iam", edgeIssuerSAName, b.Spec.GCP.ProjectID)
   391  
   392  	grant := sequelApi.Grant{
   393  		Schema: "public",
   394  		TableGrant: []sequelApi.TableGrant{
   395  			{
   396  				Table: "ca_pools",
   397  				Permissions: []sequelApi.Permissions{
   398  					{
   399  						Permission: "SELECT",
   400  					},
   401  					{
   402  						Permission: "INSERT",
   403  					},
   404  					{
   405  						Permission: "UPDATE",
   406  					},
   407  				},
   408  			},
   409  			{
   410  				Table: "ca_certificates",
   411  				Permissions: []sequelApi.Permissions{
   412  					{
   413  						Permission: "SELECT",
   414  					},
   415  					{
   416  						Permission: "INSERT",
   417  					},
   418  					{
   419  						Permission: "UPDATE",
   420  					},
   421  				},
   422  			},
   423  		},
   424  	}
   425  
   426  	return &sequelApi.DatabaseUser{
   427  		TypeMeta: gvkToTypeMeta(sequelApi.UserGVK),
   428  		ObjectMeta: metav1.ObjectMeta{
   429  			Name:      edgeIssuerSAName,
   430  			Namespace: b.Name,
   431  		},
   432  		Spec: sequelApi.UserSpec{
   433  			Type: sequelApi.CloudSAUserType,
   434  			CommonOptions: sequelApi.CommonOptions{
   435  				Prune: true,
   436  				Force: true,
   437  			},
   438  			InstanceRef: sequelApi.InstanceReference{
   439  				Name:      r.DatabaseName + dbInstance,
   440  				ProjectID: r.ForemanProjectID,
   441  			},
   442  			ServiceAccount: &sequelApi.ServiceAccount{
   443  				EmailRef:    fmt.Sprintf("%s.gserviceaccount.com", iamUsername),
   444  				IAMUsername: iamUsername,
   445  			},
   446  			Grants: []sequelApi.Grant{
   447  				grant,
   448  			},
   449  		},
   450  	}
   451  }
   452  

View as plain text