1 package clustersecrets
2
3 import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "math"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/google/uuid"
14
15 sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
16 "edge-infra.dev/pkg/edge/api/graph/model"
17 "edge-infra.dev/pkg/sds/clustersecrets"
18 cc "edge-infra.dev/pkg/sds/clustersecrets/common"
19 )
20
// Sentinel errors reported by the cluster secret service.
var (
	// ErrClusterSecretNotAdded is returned by AddClusterSecret when the
	// insert statement does not affect exactly one row.
	ErrClusterSecretNotAdded = errors.New("cluster secret not added")
	// ErrClusterSecretNotDeleted signals that a cluster secret version was
	// not deleted. It is not referenced by the code in this part of the file.
	ErrClusterSecretNotDeleted = errors.New("cluster secret version not deleted")
	// ErrClusterSecretNotUpdated is returned by UpdateClusterSecret when the
	// update statement does not affect exactly one row.
	ErrClusterSecretNotUpdated = errors.New("cluster secret was not updated")
	// ErrClusterSecretNotExpired is returned by ExpireClusterSecrets when the
	// expire statement does not affect exactly two rows.
	ErrClusterSecretNotExpired = errors.New("cluster secret was not expired")
)
27
28
29 func (s *clusterSecretService) AddClusterSecret(ctx context.Context, clusterSecret cc.ClusterSecret) error {
30 currentTime := time.Now().UTC()
31 expirationTime, err := s.getExpirationTime(currentTime)
32 if err != nil {
33 return err
34 }
35 result, err := s.SQLDB.ExecContext(ctx, AddClusterSecretQuery, clusterSecret.SecretEdgeID, clusterSecret.LeaseEdgeID, clusterSecret.Name, clusterSecret.Version, expirationTime, currentTime.Format(time.RFC3339), currentTime.Format(time.RFC3339), clusterSecret.Type.String())
36 if err != nil {
37 return err
38 }
39 rows, err := result.RowsAffected()
40 if err != nil {
41 return err
42 }
43 if rows != 1 {
44 return ErrClusterSecretNotAdded
45 }
46 return nil
47 }
48
49
50 func (s *clusterSecretService) UpdateClusterSecret(ctx context.Context, clusterSecretEdgeID string, clusterSecretType model.ClusterSecretType, version string) error {
51 currentTime := time.Now().UTC()
52 expirationTime, err := s.getExpirationTime(currentTime)
53 if err != nil {
54 return err
55 }
56 result, err := s.SQLDB.ExecContext(ctx, UpdateClusterSecretQuery, expirationTime, version, clusterSecretEdgeID, clusterSecretType)
57 if err != nil {
58 return err
59 }
60 rows, err := result.RowsAffected()
61 if err != nil {
62 return err
63 }
64 if rows != 1 {
65 return ErrClusterSecretNotUpdated
66 }
67 return nil
68 }
69
70
71 func (s *clusterSecretService) FetchClusterSecret(ctx context.Context, clusterEdgeID string, secretType model.ClusterSecretType) (cc.ClusterSecret, error) {
72 clusterSecret := cc.ClusterSecret{}
73 clusterSecret.Type = secretType
74
75 row := s.SQLDB.QueryRowContext(ctx, FetchClusterSecretQuery, clusterEdgeID, secretType.String())
76 if err := row.Scan(
77 &clusterSecret.SecretEdgeID,
78 &clusterSecret.LeaseEdgeID,
79 &clusterSecret.Name,
80 &clusterSecret.Version,
81 &clusterSecret.ExpirationTime,
82 ); err != nil {
83 return clusterSecret, err
84 }
85 return clusterSecret, nil
86 }
87
88
89 func (s *clusterSecretService) ExpireClusterSecrets(ctx context.Context, clusterSecretLeaseEdgeID string) error {
90 expirationTime := time.Now().UTC().Format(time.RFC3339)
91 result, err := s.SQLDB.ExecContext(ctx, ExpireClusterSecretsQuery, expirationTime, clusterSecretLeaseEdgeID)
92 if err != nil {
93 return err
94 }
95 rows, err := result.RowsAffected()
96 if err != nil {
97 return err
98 }
99 if rows != 2 {
100 return ErrClusterSecretNotExpired
101 }
102 return nil
103 }
104
105
106 func (s *clusterSecretService) FetchClusterSecretVersions(ctx context.Context, clusterEdgeID string, secretType model.ClusterSecretType) ([]*model.ClusterSecretVersionInfo, error) {
107 versionInfos := []*model.ClusterSecretVersionInfo{}
108 notLatestTerminalVersions := map[string][]string{}
109 row := s.SQLDB.QueryRowContext(ctx, FetchSecretVersionInfoQuery, secretType.String(), clusterEdgeID)
110 var expiration, currentVersion, secretName string
111 if err := row.Scan(&expiration, ¤tVersion); err != nil {
112 return nil, err
113 }
114 versionInfos = append(versionInfos, &model.ClusterSecretVersionInfo{
115 Version: currentVersion,
116 ExpiresAt: expiration,
117 })
118 terminalSecrets, err := s.FetchLatestTerminalClusterSecrets(ctx, clusterEdgeID)
119 if err != nil {
120 return nil, err
121 }
122 secretName = clustersecrets.NameFromType(secretType)
123
124 for _, terminalSec := range terminalSecrets {
125 if terminalSec.SecretType == secretName && terminalSec.Version != currentVersion {
126 notLatestTerminalVersions[terminalSec.Version] = append(notLatestTerminalVersions[terminalSec.Version], terminalSec.TerminalEdgeID)
127 }
128 }
129 for version, terminalIDs := range notLatestTerminalVersions {
130 versionInfos = append(versionInfos, &model.ClusterSecretVersionInfo{
131 Version: version,
132 ExpiresAt: fmt.Sprintf("expires once latest version syncs to terminal(s) %s", strings.Join(terminalIDs, ", ")),
133 })
134 }
135 return versionInfos, nil
136 }
137
138
139 func (s *clusterSecretService) VerifyClusterSecretExists(ctx context.Context, clusterEdgeID string, secret cc.Secret, leaseID string) error {
140 _, err := s.FetchClusterSecret(ctx, clusterEdgeID, secret.Type())
141 if !errors.Is(err, sql.ErrNoRows) {
142 return err
143 }
144 return s.createClusterSecret(ctx, clusterEdgeID, secret, leaseID)
145 }
146
147 func (s *clusterSecretService) CheckSecretIsExpired(ctx context.Context, clusterEdgeID string, clusterSecretType model.ClusterSecretType) (bool, error) {
148 clusterSecret, err := s.FetchClusterSecret(ctx, clusterEdgeID, clusterSecretType)
149 if err != nil {
150 return false, err
151 }
152 currentSecretExpirationTime, err := time.Parse(time.RFC3339, clusterSecret.ExpirationTime)
153 if err != nil {
154 return false, err
155 }
156 if currentSecretExpirationTime.Before(time.Now()) {
157 return true, nil
158 }
159 return false, nil
160 }
161
162
163 func (s *clusterSecretService) createClusterSecret(ctx context.Context, clusterEdgeID string, secret cc.Secret, leaseID string) error {
164 var err error
165 if len(leaseID) == 0 {
166 leaseID, err = s.FetchLeaseID(ctx, clusterEdgeID)
167 if err != nil {
168 return err
169 }
170 }
171 clusterSecret := cc.ClusterSecret{
172 SecretEdgeID: uuid.NewString(),
173 LeaseEdgeID: leaseID,
174 Name: secret.Names(clusterEdgeID)[1],
175 Version: secret.Version(),
176 Type: secret.Type(),
177 ExpirationTime: time.Now().Format(time.RFC3339),
178 }
179 return s.AddClusterSecret(ctx, clusterSecret)
180 }
181
182
183 func (s *clusterSecretService) FetchLatestTerminalClusterSecrets(ctx context.Context, clusterEdgeID string) ([]cc.TerminalClusterSecret, error) {
184 terminalSecrets := []cc.TerminalClusterSecret{}
185
186 rows, err := s.SQLDB.QueryContext(ctx, FetchLatestTerminalClusterSecretsQuery, clusterEdgeID)
187 if err != nil {
188 return terminalSecrets, err
189 }
190 defer rows.Close()
191 for rows.Next() {
192 var terminalClusterSecret cc.TerminalClusterSecret
193 var value string
194 if err := rows.Scan(&value); err != nil {
195 return terminalSecrets, err
196 }
197 value = strings.Trim(value, "[")
198 value = strings.Trim(value, "]")
199 valueList := strings.Split(value, ",")
200 for i := range valueList {
201 valueList[i] = strings.Trim(valueList[i], `"`)
202 }
203
204 if math.Mod(float64(len(valueList)), float64(3)) != 0 {
205 continue
206 }
207 for i := 0; i < len(valueList); i += 3 {
208 terminalClusterSecret.SecretType = valueList[i]
209 terminalClusterSecret.Version = valueList[i+1]
210 terminalClusterSecret.TerminalEdgeID = valueList[i+2]
211 terminalSecrets = append(terminalSecrets, terminalClusterSecret)
212 }
213 }
214 if err := rows.Err(); err != nil {
215 return nil, sqlerr.Wrap(err)
216 }
217 return terminalSecrets, nil
218 }
219
220
221 func (s *clusterSecretService) getExpirationTime(currentTime time.Time) (string, error) {
222 var maxSecretValidityPeriod time.Duration
223 if strings.Contains(s.Config.EdgeMaxSecretValidityPeriod, "d") {
224 edgeValidityPeriod, err := strconv.Atoi(strings.TrimSuffix(s.Config.EdgeMaxSecretValidityPeriod, "d"))
225 if err != nil {
226 return "", err
227 }
228 maxSecretValidityPeriod = time.Duration(edgeValidityPeriod) * 24 * time.Hour
229 } else {
230 var err error
231 maxSecretValidityPeriod, err = time.ParseDuration(s.Config.EdgeMaxSecretValidityPeriod)
232 if err != nil {
233 return "", err
234 }
235 }
236 expirationTime := currentTime.Add(maxSecretValidityPeriod).Format(time.RFC3339)
237 return expirationTime, nil
238 }
239
View as plain text