...

Source file src/edge-infra.dev/pkg/edge/api/services/clustersecrets/secret.go

Documentation: edge-infra.dev/pkg/edge/api/services/clustersecrets

     1  package clustersecrets
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/google/uuid"
    14  
    15  	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
    16  	"edge-infra.dev/pkg/edge/api/graph/model"
    17  	"edge-infra.dev/pkg/sds/clustersecrets"
    18  	cc "edge-infra.dev/pkg/sds/clustersecrets/common"
    19  )
    20  
    21  var (
    22  	ErrClusterSecretNotAdded   = errors.New("cluster secret not added")
    23  	ErrClusterSecretNotDeleted = errors.New("cluster secret version not deleted")
    24  	ErrClusterSecretNotUpdated = errors.New("cluster secret was not updated")
    25  	ErrClusterSecretNotExpired = errors.New("cluster secret was not expired")
    26  )
    27  
    28  // AddClusterSecret creates an entry in the db for a cluster secret
    29  func (s *clusterSecretService) AddClusterSecret(ctx context.Context, clusterSecret cc.ClusterSecret) error {
    30  	currentTime := time.Now().UTC()
    31  	expirationTime, err := s.getExpirationTime(currentTime)
    32  	if err != nil {
    33  		return err
    34  	}
    35  	result, err := s.SQLDB.ExecContext(ctx, AddClusterSecretQuery, clusterSecret.SecretEdgeID, clusterSecret.LeaseEdgeID, clusterSecret.Name, clusterSecret.Version, expirationTime, currentTime.Format(time.RFC3339), currentTime.Format(time.RFC3339), clusterSecret.Type.String())
    36  	if err != nil {
    37  		return err
    38  	}
    39  	rows, err := result.RowsAffected()
    40  	if err != nil {
    41  		return err
    42  	}
    43  	if rows != 1 {
    44  		return ErrClusterSecretNotAdded
    45  	}
    46  	return nil
    47  }
    48  
    49  // UpdateClusterSecret updates the cluster secret expiry time and version
    50  func (s *clusterSecretService) UpdateClusterSecret(ctx context.Context, clusterSecretEdgeID string, clusterSecretType model.ClusterSecretType, version string) error {
    51  	currentTime := time.Now().UTC()
    52  	expirationTime, err := s.getExpirationTime(currentTime)
    53  	if err != nil {
    54  		return err
    55  	}
    56  	result, err := s.SQLDB.ExecContext(ctx, UpdateClusterSecretQuery, expirationTime, version, clusterSecretEdgeID, clusterSecretType)
    57  	if err != nil {
    58  		return err
    59  	}
    60  	rows, err := result.RowsAffected()
    61  	if err != nil {
    62  		return err
    63  	}
    64  	if rows != 1 {
    65  		return ErrClusterSecretNotUpdated
    66  	}
    67  	return nil
    68  }
    69  
    70  // FetchClusterSecret retrieves the cluster secret from the db
    71  func (s *clusterSecretService) FetchClusterSecret(ctx context.Context, clusterEdgeID string, secretType model.ClusterSecretType) (cc.ClusterSecret, error) {
    72  	clusterSecret := cc.ClusterSecret{}
    73  	clusterSecret.Type = secretType
    74  
    75  	row := s.SQLDB.QueryRowContext(ctx, FetchClusterSecretQuery, clusterEdgeID, secretType.String())
    76  	if err := row.Scan(
    77  		&clusterSecret.SecretEdgeID,
    78  		&clusterSecret.LeaseEdgeID,
    79  		&clusterSecret.Name,
    80  		&clusterSecret.Version,
    81  		&clusterSecret.ExpirationTime,
    82  	); err != nil {
    83  		return clusterSecret, err
    84  	}
    85  	return clusterSecret, nil
    86  }
    87  
    88  // ExpireClusterSecret sets expiry time to now so clusterctl triggers new secret credentials
    89  func (s *clusterSecretService) ExpireClusterSecrets(ctx context.Context, clusterSecretLeaseEdgeID string) error {
    90  	expirationTime := time.Now().UTC().Format(time.RFC3339)
    91  	result, err := s.SQLDB.ExecContext(ctx, ExpireClusterSecretsQuery, expirationTime, clusterSecretLeaseEdgeID)
    92  	if err != nil {
    93  		return err
    94  	}
    95  	rows, err := result.RowsAffected()
    96  	if err != nil {
    97  		return err
    98  	}
    99  	if rows != 2 {
   100  		return ErrClusterSecretNotExpired
   101  	}
   102  	return nil
   103  }
   104  
   105  // FetchClusterSecretVersions retrieves the latest version and the expiry time of the cluster secret
   106  func (s *clusterSecretService) FetchClusterSecretVersions(ctx context.Context, clusterEdgeID string, secretType model.ClusterSecretType) ([]*model.ClusterSecretVersionInfo, error) {
   107  	versionInfos := []*model.ClusterSecretVersionInfo{}
   108  	notLatestTerminalVersions := map[string][]string{}
   109  	row := s.SQLDB.QueryRowContext(ctx, FetchSecretVersionInfoQuery, secretType.String(), clusterEdgeID)
   110  	var expiration, currentVersion, secretName string
   111  	if err := row.Scan(&expiration, &currentVersion); err != nil {
   112  		return nil, err
   113  	}
   114  	versionInfos = append(versionInfos, &model.ClusterSecretVersionInfo{
   115  		Version:   currentVersion,
   116  		ExpiresAt: expiration,
   117  	})
   118  	terminalSecrets, err := s.FetchLatestTerminalClusterSecrets(ctx, clusterEdgeID)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	secretName = clustersecrets.NameFromType(secretType)
   123  	// get terminal secret versions that aren't latest, and group them according to version
   124  	for _, terminalSec := range terminalSecrets {
   125  		if terminalSec.SecretType == secretName && terminalSec.Version != currentVersion {
   126  			notLatestTerminalVersions[terminalSec.Version] = append(notLatestTerminalVersions[terminalSec.Version], terminalSec.TerminalEdgeID)
   127  		}
   128  	}
   129  	for version, terminalIDs := range notLatestTerminalVersions {
   130  		versionInfos = append(versionInfos, &model.ClusterSecretVersionInfo{
   131  			Version:   version,
   132  			ExpiresAt: fmt.Sprintf("expires once latest version syncs to terminal(s) %s", strings.Join(terminalIDs, ", ")),
   133  		})
   134  	}
   135  	return versionInfos, nil
   136  }
   137  
   138  // VerifyClusterSecretExists checks the cluster secret exists and creates it if not
   139  func (s *clusterSecretService) VerifyClusterSecretExists(ctx context.Context, clusterEdgeID string, secret cc.Secret, leaseID string) error {
   140  	_, err := s.FetchClusterSecret(ctx, clusterEdgeID, secret.Type())
   141  	if !errors.Is(err, sql.ErrNoRows) {
   142  		return err
   143  	}
   144  	return s.createClusterSecret(ctx, clusterEdgeID, secret, leaseID)
   145  }
   146  
   147  func (s *clusterSecretService) CheckSecretIsExpired(ctx context.Context, clusterEdgeID string, clusterSecretType model.ClusterSecretType) (bool, error) {
   148  	clusterSecret, err := s.FetchClusterSecret(ctx, clusterEdgeID, clusterSecretType)
   149  	if err != nil {
   150  		return false, err
   151  	}
   152  	currentSecretExpirationTime, err := time.Parse(time.RFC3339, clusterSecret.ExpirationTime)
   153  	if err != nil {
   154  		return false, err
   155  	}
   156  	if currentSecretExpirationTime.Before(time.Now()) {
   157  		return true, nil
   158  	}
   159  	return false, nil
   160  }
   161  
   162  // createClusterSecret gets the values required for a cluster secret and requests for it to be created in the db
   163  func (s *clusterSecretService) createClusterSecret(ctx context.Context, clusterEdgeID string, secret cc.Secret, leaseID string) error {
   164  	var err error
   165  	if len(leaseID) == 0 {
   166  		leaseID, err = s.FetchLeaseID(ctx, clusterEdgeID)
   167  		if err != nil {
   168  			return err
   169  		}
   170  	}
   171  	clusterSecret := cc.ClusterSecret{
   172  		SecretEdgeID:   uuid.NewString(),
   173  		LeaseEdgeID:    leaseID,
   174  		Name:           secret.Names(clusterEdgeID)[1],
   175  		Version:        secret.Version(),
   176  		Type:           secret.Type(),
   177  		ExpirationTime: time.Now().Format(time.RFC3339),
   178  	}
   179  	return s.AddClusterSecret(ctx, clusterSecret)
   180  }
   181  
   182  // FetchLatestTerminalClusterSecrets attempts to fetch the last applied cluster secrets to each terminal including the hashed secret version.
   183  func (s *clusterSecretService) FetchLatestTerminalClusterSecrets(ctx context.Context, clusterEdgeID string) ([]cc.TerminalClusterSecret, error) {
   184  	terminalSecrets := []cc.TerminalClusterSecret{}
   185  
   186  	rows, err := s.SQLDB.QueryContext(ctx, FetchLatestTerminalClusterSecretsQuery, clusterEdgeID)
   187  	if err != nil {
   188  		return terminalSecrets, err
   189  	}
   190  	defer rows.Close()
   191  	for rows.Next() {
   192  		var terminalClusterSecret cc.TerminalClusterSecret
   193  		var value string
   194  		if err := rows.Scan(&value); err != nil {
   195  			return terminalSecrets, err
   196  		}
   197  		value = strings.Trim(value, "[")
   198  		value = strings.Trim(value, "]")
   199  		valueList := strings.Split(value, ",")
   200  		for i := range valueList {
   201  			valueList[i] = strings.Trim(valueList[i], `"`)
   202  		}
   203  		// check length of values is divisible by 3 to avoid index errors as we don't know how many secrets will be returned
   204  		if math.Mod(float64(len(valueList)), float64(3)) != 0 {
   205  			continue
   206  		}
   207  		for i := 0; i < len(valueList); i += 3 {
   208  			terminalClusterSecret.SecretType = valueList[i]
   209  			terminalClusterSecret.Version = valueList[i+1]
   210  			terminalClusterSecret.TerminalEdgeID = valueList[i+2]
   211  			terminalSecrets = append(terminalSecrets, terminalClusterSecret)
   212  		}
   213  	}
   214  	if err := rows.Err(); err != nil {
   215  		return nil, sqlerr.Wrap(err)
   216  	}
   217  	return terminalSecrets, nil
   218  }
   219  
   220  // getExpirationTime retrieves the time a cluster secret will expire
   221  func (s *clusterSecretService) getExpirationTime(currentTime time.Time) (string, error) {
   222  	var maxSecretValidityPeriod time.Duration
   223  	if strings.Contains(s.Config.EdgeMaxSecretValidityPeriod, "d") {
   224  		edgeValidityPeriod, err := strconv.Atoi(strings.TrimSuffix(s.Config.EdgeMaxSecretValidityPeriod, "d"))
   225  		if err != nil {
   226  			return "", err
   227  		}
   228  		maxSecretValidityPeriod = time.Duration(edgeValidityPeriod) * 24 * time.Hour
   229  	} else {
   230  		var err error
   231  		maxSecretValidityPeriod, err = time.ParseDuration(s.Config.EdgeMaxSecretValidityPeriod)
   232  		if err != nil {
   233  			return "", err
   234  		}
   235  	}
   236  	expirationTime := currentTime.Add(maxSecretValidityPeriod).Format(time.RFC3339)
   237  	return expirationTime, nil
   238  }
   239  

View as plain text