package clustersecrets

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/google/uuid"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	"edge-infra.dev/pkg/sds/clustersecrets"
	cc "edge-infra.dev/pkg/sds/clustersecrets/common"
)

var (
	// ErrClusterSecretNotAdded is returned when inserting a cluster secret
	// does not affect exactly one row.
	ErrClusterSecretNotAdded = errors.New("cluster secret not added")
	// ErrClusterSecretNotDeleted signals that a cluster secret version was not
	// deleted. NOTE(review): not referenced in this file.
	ErrClusterSecretNotDeleted = errors.New("cluster secret version not deleted")
	// ErrClusterSecretNotUpdated is returned when updating a cluster secret
	// does not affect exactly one row.
	ErrClusterSecretNotUpdated = errors.New("cluster secret was not updated")
	// ErrClusterSecretNotExpired is returned when expiring the secrets of a
	// lease does not affect the expected number of rows.
	ErrClusterSecretNotExpired = errors.New("cluster secret was not expired")
)

// AddClusterSecret inserts a row for clusterSecret.
//
// The stored expiry is always computed from the configured maximum validity
// period (see getExpirationTime); clusterSecret.ExpirationTime is not used.
// The current UTC time is passed twice in RFC 3339 form (presumably as the
// created/updated timestamps — confirm against AddClusterSecretQuery).
// ErrClusterSecretNotAdded is returned when the insert does not affect
// exactly one row.
func (s *clusterSecretService) AddClusterSecret(ctx context.Context, clusterSecret cc.ClusterSecret) error {
	currentTime := time.Now().UTC()
	expirationTime, err := s.getExpirationTime(currentTime)
	if err != nil {
		return err
	}
	now := currentTime.Format(time.RFC3339)

	result, err := s.SQLDB.ExecContext(ctx, AddClusterSecretQuery,
		clusterSecret.SecretEdgeID,
		clusterSecret.LeaseEdgeID,
		clusterSecret.Name,
		clusterSecret.Version,
		expirationTime,
		now,
		now,
		clusterSecret.Type.String(),
	)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return ErrClusterSecretNotAdded
	}
	return nil
}

// UpdateClusterSecret stores a new version for the cluster secret identified
// by clusterSecretEdgeID and clusterSecretType and pushes its expiry out to
// now plus the configured maximum validity period.
// ErrClusterSecretNotUpdated is returned when the update does not affect
// exactly one row.
func (s *clusterSecretService) UpdateClusterSecret(ctx context.Context, clusterSecretEdgeID string, clusterSecretType model.ClusterSecretType, version string) error {
	currentTime := time.Now().UTC()
	expirationTime, err := s.getExpirationTime(currentTime)
	if err != nil {
		return err
	}

	// Pass the type as a string, matching every other query in this file.
	result, err := s.SQLDB.ExecContext(ctx, UpdateClusterSecretQuery, expirationTime, version, clusterSecretEdgeID, clusterSecretType.String())
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return ErrClusterSecretNotUpdated
	}
	return nil
}

// FetchClusterSecret loads the cluster secret of secretType for the cluster
// identified by clusterEdgeID.
//
// A missing row surfaces as the Scan error, which VerifyClusterSecretExists
// compares against sql.ErrNoRows. On any error the zero ClusterSecret is
// returned.
func (s *clusterSecretService) FetchClusterSecret(ctx context.Context, clusterEdgeID string, secretType model.ClusterSecretType) (cc.ClusterSecret, error) {
	clusterSecret := cc.ClusterSecret{Type: secretType}
	row := s.SQLDB.QueryRowContext(ctx, FetchClusterSecretQuery, clusterEdgeID, secretType.String())
	if err := row.Scan(
		&clusterSecret.SecretEdgeID,
		&clusterSecret.LeaseEdgeID,
		&clusterSecret.Name,
		&clusterSecret.Version,
		&clusterSecret.ExpirationTime,
	); err != nil {
		return cc.ClusterSecret{}, err
	}
	return clusterSecret, nil
}

// ExpireClusterSecrets sets the expiry time of the cluster secrets attached to
// the lease identified by clusterSecretLeaseEdgeID to the current UTC time, so
// that clusterctl triggers new secret credentials.
//
// The update must affect exactly two rows; otherwise ErrClusterSecretNotExpired
// is returned.
func (s *clusterSecretService) ExpireClusterSecrets(ctx context.Context, clusterSecretLeaseEdgeID string) error {
	// NOTE(review): presumably one row per cluster secret type held on the
	// lease — confirm if new secret types are added.
	const expectedExpiredRows = 2

	expirationTime := time.Now().UTC().Format(time.RFC3339)
	result, err := s.SQLDB.ExecContext(ctx, ExpireClusterSecretsQuery, expirationTime, clusterSecretLeaseEdgeID)
	if err != nil {
		return err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rows != expectedExpiredRows {
		return ErrClusterSecretNotExpired
	}
	return nil
}

// FetchClusterSecretVersions returns version information for the cluster
// secret of secretType on the cluster identified by clusterEdgeID.
//
// The first entry is the current version with its stored expiry time. It is
// followed by one entry per older version that terminals still report as their
// latest applied secret; for those, ExpiresAt is a human-readable note listing
// the terminals that have not yet synced the current version. Older-version
// entries are sorted by version string so the result is deterministic.
func (s *clusterSecretService) FetchClusterSecretVersions(ctx context.Context, clusterEdgeID string, secretType model.ClusterSecretType) ([]*model.ClusterSecretVersionInfo, error) {
	var expiration, currentVersion string
	row := s.SQLDB.QueryRowContext(ctx, FetchSecretVersionInfoQuery, secretType.String(), clusterEdgeID)
	if err := row.Scan(&expiration, &currentVersion); err != nil {
		return nil, err
	}
	versionInfos := []*model.ClusterSecretVersionInfo{{
		Version:   currentVersion,
		ExpiresAt: expiration,
	}}

	terminalSecrets, err := s.FetchLatestTerminalClusterSecrets(ctx, clusterEdgeID)
	if err != nil {
		return nil, err
	}

	// Group terminals by the non-current version of this secret they last applied.
	secretName := clustersecrets.NameFromType(secretType)
	staleTerminals := map[string][]string{}
	for _, terminalSecret := range terminalSecrets {
		if terminalSecret.SecretType != secretName || terminalSecret.Version == currentVersion {
			continue
		}
		staleTerminals[terminalSecret.Version] = append(staleTerminals[terminalSecret.Version], terminalSecret.TerminalEdgeID)
	}

	// Map iteration order is random; sort the versions so the response is stable.
	staleVersions := make([]string, 0, len(staleTerminals))
	for version := range staleTerminals {
		staleVersions = append(staleVersions, version)
	}
	sort.Strings(staleVersions)

	for _, version := range staleVersions {
		versionInfos = append(versionInfos, &model.ClusterSecretVersionInfo{
			Version:   version,
			ExpiresAt: fmt.Sprintf("expires once latest version syncs to terminal(s) %s", strings.Join(staleTerminals[version], ", ")),
		})
	}
	return versionInfos, nil
}

// VerifyClusterSecretExists ensures a cluster secret of secret.Type() exists
// for the cluster, creating it when the lookup reports sql.ErrNoRows.
//
// Any other lookup error is returned unchanged. A nil result means the secret
// already existed or was created. leaseID may be empty, in which case
// createClusterSecret looks the lease up.
func (s *clusterSecretService) VerifyClusterSecretExists(ctx context.Context, clusterEdgeID string, secret cc.Secret, leaseID string) error {
	_, err := s.FetchClusterSecret(ctx, clusterEdgeID, secret.Type())
	switch {
	case err == nil:
		return nil
	case errors.Is(err, sql.ErrNoRows):
		return s.createClusterSecret(ctx, clusterEdgeID, secret, leaseID)
	default:
		return err
	}
}

// CheckSecretIsExpired reports whether the stored expiry time of the cluster
// secret of clusterSecretType is before the current time.
//
// The stored expiry must be an RFC 3339 timestamp; fetch and parse failures
// are returned as errors.
func (s *clusterSecretService) CheckSecretIsExpired(ctx context.Context, clusterEdgeID string, clusterSecretType model.ClusterSecretType) (bool, error) {
	clusterSecret, err := s.FetchClusterSecret(ctx, clusterEdgeID, clusterSecretType)
	if err != nil {
		return false, err
	}
	expiresAt, err := time.Parse(time.RFC3339, clusterSecret.ExpirationTime)
	if err != nil {
		return false, err
	}
	return expiresAt.Before(time.Now()), nil
}

// createClusterSecret builds a new cluster secret record for the cluster and
// stores it via AddClusterSecret.
//
// When leaseID is empty the lease is looked up with FetchLeaseID. The record
// name is the second entry of secret.Names(clusterEdgeID); an error is returned
// if fewer than two names are provided, rather than panicking on the index.
func (s *clusterSecretService) createClusterSecret(ctx context.Context, clusterEdgeID string, secret cc.Secret, leaseID string) error {
	if leaseID == "" {
		var err error
		leaseID, err = s.FetchLeaseID(ctx, clusterEdgeID)
		if err != nil {
			return err
		}
	}

	names := secret.Names(clusterEdgeID)
	if len(names) < 2 {
		return fmt.Errorf("creating cluster secret for cluster %q: secret provided %d name(s), need at least 2", clusterEdgeID, len(names))
	}

	clusterSecret := cc.ClusterSecret{
		SecretEdgeID: uuid.NewString(),
		LeaseEdgeID:  leaseID,
		Name:         names[1],
		Version:      secret.Version(),
		Type:         secret.Type(),
		// AddClusterSecret derives the persisted expiry from config and does
		// not read this field; UTC keeps it consistent with the rest of the file.
		ExpirationTime: time.Now().UTC().Format(time.RFC3339),
	}
	return s.AddClusterSecret(ctx, clusterSecret)
}

// FetchLatestTerminalClusterSecrets returns, for the cluster identified by
// clusterEdgeID, the cluster secrets last applied to each terminal, including
// the hashed secret version.
//
// Each row of FetchLatestTerminalClusterSecretsQuery is scanned as one string.
// That string holds a bracketed, comma-separated list of quoted values
// (presumably a flat JSON array — confirm against the query) made of repeating
// (secret type, version, terminal edge ID) triples. Rows whose value count is
// not a multiple of three are skipped. On error a nil slice is returned.
func (s *clusterSecretService) FetchLatestTerminalClusterSecrets(ctx context.Context, clusterEdgeID string) ([]cc.TerminalClusterSecret, error) {
	rows, err := s.SQLDB.QueryContext(ctx, FetchLatestTerminalClusterSecretsQuery, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	terminalSecrets := []cc.TerminalClusterSecret{}
	for rows.Next() {
		var value string
		if err := rows.Scan(&value); err != nil {
			return nil, err
		}

		value = strings.Trim(value, "[")
		value = strings.Trim(value, "]")
		valueList := strings.Split(value, ",")
		for i := range valueList {
			// Trim whitespace before quotes: JSON rendered with a space after
			// each comma (e.g. `["a", "b"]`) would otherwise leave ` "b`.
			valueList[i] = strings.Trim(strings.TrimSpace(valueList[i]), `"`)
		}

		// Values are consumed as triples; skip malformed rows rather than
		// indexing out of range.
		if math.Mod(float64(len(valueList)), 3) != 0 {
			continue
		}
		for i := 0; i+2 < len(valueList); i += 3 {
			terminalSecrets = append(terminalSecrets, cc.TerminalClusterSecret{
				SecretType:     valueList[i],
				Version:        valueList[i+1],
				TerminalEdgeID: valueList[i+2],
			})
		}
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return terminalSecrets, nil
}

// getExpirationTime returns currentTime advanced by the configured maximum
// secret validity period (s.Config.EdgeMaxSecretValidityPeriod), formatted as
// RFC 3339.
//
// The period is either a whole number of days with a trailing "d" (e.g. "30d")
// or any value accepted by time.ParseDuration (e.g. "720h").
func (s *clusterSecretService) getExpirationTime(currentTime time.Time) (string, error) {
	period := s.Config.EdgeMaxSecretValidityPeriod

	var validity time.Duration
	if strings.HasSuffix(period, "d") {
		days, err := strconv.Atoi(strings.TrimSuffix(period, "d"))
		if err != nil {
			return "", fmt.Errorf("parsing max secret validity period %q: %w", period, err)
		}
		validity = time.Duration(days) * 24 * time.Hour
	} else {
		var err error
		validity, err = time.ParseDuration(period)
		if err != nil {
			return "", fmt.Errorf("parsing max secret validity period %q: %w", period, err)
		}
	}

	return currentTime.Add(validity).Format(time.RFC3339), nil
}