1 package sql
2
3 import (
4 "context"
5 "database/sql"
6 "embed"
7 "errors"
8 "fmt"
9 "io/fs"
10 "os"
11 "strconv"
12 "strings"
13 "time"
14
15 "github.com/google/uuid"
16 _ "github.com/jackc/pgx/v4/stdlib"
17
18 "edge-infra.dev/pkg/f8n/kinform/model"
19 sovereign "edge-infra.dev/pkg/f8n/sovereign/model"
20 )
21
// DBHandle bundles a *sql.DB connection pool with the ID of the cluster whose
// rows this process writes. FromDSN returns a handle with ClusterID left empty;
// callers are expected to fill it in.
type DBHandle struct {
	// The pool is embedded so its query/exec methods are callable directly
	// on the handle.
	*sql.DB

	// ClusterID is bound as the cluster parameter by cluster-scoped
	// statements such as InsertArtifactObserved and GetRemoteCommand.
	ClusterID string
}
26
27
28 var schemaFS embed.FS
29
const (
	// postgres is the fallback used by FromEnv for both the user name and
	// the database name when DB_USER / DB_NAME are unset.
	postgres = "postgres"
)
31
32 func FromDSN(dsn string, maxOpenConns, maxIdleConns int) (*DBHandle, error) {
33 db, err := sql.Open("pgx", dsn)
34 if err != nil {
35 return nil, fmt.Errorf("failed to open db with pgx driver: %w", err)
36 }
37
38 err = db.Ping()
39 if err != nil {
40 return nil, fmt.Errorf("failed to ping database: %w", err)
41 }
42
43 err = execSchema(db)
44 if err != nil {
45 return nil, fmt.Errorf("failed to execute schema: %w", err)
46 }
47
48 db.SetMaxOpenConns(maxOpenConns)
49 db.SetMaxIdleConns(maxIdleConns)
50 db.SetConnMaxIdleTime(time.Minute)
51
52 return &DBHandle{DB: db}, nil
53 }
54
55 func FromEnv() (*DBHandle, error) {
56 user, ok := os.LookupEnv("DB_USER")
57 if !ok {
58 user = postgres
59 }
60 password, ok := os.LookupEnv("DB_PASS")
61 if !ok {
62 password = ""
63 }
64 host, ok := os.LookupEnv("DB_HOST")
65 if !ok {
66 host = "127.0.0.1"
67 }
68 port, ok := os.LookupEnv("DB_PORT")
69 if !ok {
70 port = "5432"
71 }
72 dbName, ok := os.LookupEnv("DB_NAME")
73 if !ok {
74 dbName = postgres
75 }
76
77 var dbMaxConns int
78 dbMaxConnsStr, ok := os.LookupEnv("DB_MAX_CONNS")
79 if ok {
80 dbMaxConnsParsed, err := strconv.Atoi(dbMaxConnsStr)
81 if err != nil {
82 return nil, fmt.Errorf("failed to parse DB_MAX_CONNS: %w", err)
83 }
84 dbMaxConns = dbMaxConnsParsed
85 } else {
86
87
88 dbMaxConns = 450
89 }
90
91 dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s", host, port, user, password, dbName)
92
93 return FromDSN(dsn, dbMaxConns, dbMaxConns)
94 }
95
96 func execSchema(db *sql.DB) error {
97 schemaSQL, err := fs.ReadFile(schemaFS, "schema.sql")
98 if err != nil {
99 return err
100 }
101 _, err = db.Exec(string(schemaSQL))
102 if err != nil {
103 return err
104 }
105 return nil
106 }
107
108 func (db *DBHandle) InsertCluster(ctx context.Context, name string) (uuid.UUID, error) {
109 insertSQL := `
110 INSERT INTO clusters(name, version_major, version_minor)
111 VALUES ($1, $2, $3)
112 RETURNING id
113 `
114
115 var id uuid.UUID
116
117 row := db.QueryRowContext(ctx, insertSQL, name, 0, 0)
118 err := row.Scan(&id)
119 if err != nil {
120 return uuid.Nil, err
121 }
122
123 return id, nil
124 }
125
126 func (db *DBHandle) InsertResource(ctx context.Context, resource model.WatchedResource) error {
127 insertSQL := `
128 INSERT INTO watched_resources(api_version, kind, resource, cluster)
129 VALUES ($1, $2, $3, $4)
130 ON CONFLICT ((resource['metadata']['uid']))
131 DO UPDATE SET api_version = $1, kind = $2, resource = $3
132 `
133
134 _, err := db.ExecContext(ctx, insertSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.Cluster)
135 if err != nil {
136 return err
137 }
138
139 return nil
140 }
141
142 func (db *DBHandle) UpdateResource(ctx context.Context, resource model.WatchedResource) error {
143
144 updateSQL := `UPDATE watched_resources SET api_version=$1, kind=$2, resource=$3 WHERE resource['metadata']['uid'] = $4`
145
146 _, err := db.ExecContext(ctx, updateSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.MetadataUID)
147 if err != nil {
148 return err
149 }
150
151 return nil
152 }
153
154 func (db *DBHandle) DeleteResource(ctx context.Context, resource model.WatchedResource) error {
155 deleteSQL := "DELETE FROM watched_resources WHERE resource['metadata']['uid'] = $1"
156
157 _, err := db.ExecContext(ctx, deleteSQL, resource.MetadataUID)
158 if err != nil {
159 return err
160 }
161
162 return nil
163 }
164
165
166 func (db *DBHandle) InsertResourceObservation(ctx context.Context, resource model.WatchedResource) error {
167 insertSQL := `INSERT INTO watched_resource_observations(api_version, kind, resource, cluster) VALUES ($1, $2, $3, $4)`
168
169 _, err := db.ExecContext(ctx, insertSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.Cluster)
170 if err != nil {
171 return err
172 }
173
174 return nil
175 }
176
177
178
179 func (db *DBHandle) InsertArtifactObserved(ctx context.Context, image string) error {
180 observeSQL := `
181 WITH
182 artifact_id AS (
183 SELECT fn_artifact_id_for($1) as id
184 ),
185 artifact_version_id AS (
186 SELECT fn_artifact_version_id_for((SELECT id FROM artifact_id), $2, $3) as id
187 )
188 INSERT INTO observed_states (cluster, artifact_version)
189 VALUES ($4, (SELECT id FROM artifact_version_id))
190 ON CONFLICT (cluster, artifact_version) DO UPDATE SET observed_at = NOW()
191 `
192
193 ss := strings.Split(image, "/")
194 imageString := ss[1]
195 tag := ""
196 sha256Digest := "0000000000000000000000000000000000000000000000000000000000000000"
197 if strings.Contains(imageString, ":") {
198 tagSplit := strings.Split(imageString, ":")
199 tag = tagSplit[1]
200 }
201
202 if strings.Contains(imageString, "@") {
203 digestSplit := strings.Split(imageString, "@")
204 digest := digestSplit[1]
205 digs := strings.Split(digest, ":")
206 if digs[0] != "sha256" {
207 return fmt.Errorf("expected digest to be sha256. got: %v", digs[0])
208 }
209 sha256Digest = digs[1]
210 }
211
212 _, err := db.ExecContext(ctx, observeSQL, image, tag, sha256Digest, db.ClusterID)
213 if err != nil {
214 return err
215 }
216
217 return nil
218 }
219
220 func (db *DBHandle) UpdateClusterHeartbeatWithSession(ctx context.Context, h model.ClusterHeartbeat) error {
221 query := `
222 WITH upsert_cluster AS (
223 INSERT INTO clusters (id, version_major, version_minor)
224 VALUES ($1, $2, $3)
225 ON CONFLICT (id) DO NOTHING
226 )
227 INSERT INTO kinform_sessions (cluster, session, last_heartbeat)
228 VALUES ($1, $4, $5)
229 ON CONFLICT (session) DO UPDATE SET last_heartbeat = $5
230 `
231
232 _, err := db.ExecContext(ctx, query, h.Cluster, h.ClusterVersion.Major, h.ClusterVersion.Minor, h.SessionID, h.Timestamp)
233 if err != nil {
234 return err
235 }
236
237 return nil
238 }
239
240
// RemoteCommand is one row of the remote_commands table, as returned by
// GetRemoteCommand.
type RemoteCommand struct {
	ID      string // remote_commands.id; the key DeleteRemoteCommand deletes by
	CmdType string // remote_commands.command_type
	CmdArgs string // remote_commands.command_args
}
246
247 func (db *DBHandle) GetRemoteCommand(ctx context.Context) (RemoteCommand, bool, error) {
248 commandSQL := `SELECT id, command_type, command_args FROM remote_commands WHERE cluster = $1 LIMIT 1`
249
250 row := db.QueryRowContext(ctx, commandSQL, db.ClusterID)
251 var id string
252
253 var commandType string
254 var commandArgs string
255 if err := row.Scan(&id, &commandType, &commandArgs); err != nil {
256 if errors.Is(err, sql.ErrNoRows) {
257 fmt.Println("no commands pending for cluster:", db.ClusterID)
258 return RemoteCommand{}, false, nil
259 }
260 return RemoteCommand{}, false, err
261 }
262
263 rCmd := RemoteCommand{
264 ID: id,
265 CmdType: commandType,
266 CmdArgs: commandArgs}
267 return rCmd, true, nil
268 }
269
270
271
272 func (db *DBHandle) DeleteRemoteCommand(ctx context.Context, id string) error {
273 commandSQL := `DELETE FROM remote_commands WHERE id = $1`
274
275 _, err := db.ExecContext(ctx, commandSQL, id)
276 if err != nil {
277 return err
278 }
279
280 return nil
281 }
282
283 func (db *DBHandle) GetClustersMatchingArtifactLabels(ctx context.Context, artifact uuid.UUID) ([]uuid.UUID, error) {
284 q := `
285 WITH labels as (
286 SELECT key, value
287 FROM artifact_labels
288 WHERE artifact = $1
289 )
290 SELECT cluster
291 FROM cluster_labels c
292 JOIN labels USING (key, value)
293 GROUP BY c.cluster
294 HAVING count(*) = (SELECT count(*) FROM labels)
295 `
296 t0 := time.Now()
297 rows, err := db.QueryContext(ctx, q, artifact)
298 if err != nil && !errors.Is(err, sql.ErrNoRows) {
299 return []uuid.UUID{}, nil
300 }
301 dt := time.Since(t0)
302 fmt.Printf("Query complete in %v ms\n", dt.Milliseconds())
303 clusters := []uuid.UUID{}
304 for rows.Next() {
305 var clusterID uuid.UUID
306 err := rows.Scan(&clusterID)
307 if err != nil {
308 return []uuid.UUID{}, err
309 }
310 clusters = append(clusters, clusterID)
311 }
312
313 return clusters, nil
314 }
315
316 func (db *DBHandle) InsertArtifact(ctx context.Context, a sovereign.Artifact) (uuid.UUID, error) {
317 q := `
318 INSERT INTO artifacts (project, repository, artifact_version)
319 VALUES ($1, $2, $3)
320 RETURNING id
321 `
322
323 var id uuid.UUID
324 row := db.QueryRowContext(ctx, q, a.ProjectID, a.Repository, a.ArtifactVersion)
325 err := row.Scan(&id)
326 if err != nil {
327 return uuid.Nil, err
328 }
329
330 return id, nil
331 }
332
333 func (db *DBHandle) QueryArtifactVersion(ctx context.Context, image, sha25Digest string) (sovereign.ArtifactVersion, error) {
334 q := `SELECT id FROM artifact_versions WHERE image = $1 AND sha256_digest = $2`
335
336 var id uuid.UUID
337 row := db.QueryRowContext(ctx, q, image, sha25Digest)
338 err := row.Scan(&id)
339 if err != nil {
340 return sovereign.ArtifactVersion{}, err
341 }
342
343 av := &sovereign.ArtifactVersion{
344 ID: id,
345 Image: image,
346 Sha256Digest: sha25Digest,
347 }
348
349 return *av, nil
350 }
351
352 func (db *DBHandle) InsertArtifactVersion(ctx context.Context, image, tag, sha25Digest string) (uuid.UUID, error) {
353 q := `
354 INSERT INTO artifact_versions (image, tag, sha256_digest)
355 VALUES ($1, $2, $3)
356 RETURNING id
357 `
358
359 var id uuid.UUID
360 row := db.QueryRowContext(ctx, q, image, tag, sha25Digest)
361 err := row.Scan(&id)
362 if err != nil {
363 return uuid.Nil, err
364 }
365
366 return id, nil
367 }
368
369 func (db *DBHandle) DeleteArtifactVersion(ctx context.Context, image, sha25Digest string) error {
370 q := `
371 DELETE FROM artifact_versions
372 WHERE image = $1 AND sha256_digest = $2
373 `
374
375 _, err := db.ExecContext(ctx, q, image, sha25Digest)
376 if err != nil {
377 return err
378 }
379
380 return nil
381 }
382
383 func (db *DBHandle) InsertArtifactLabel(ctx context.Context, artifact uuid.UUID, key, value string) (uuid.UUID, error) {
384 q := `
385 INSERT INTO artifact_labels (artifact, key, value)
386 VALUES ($1, $2, $3)
387 RETURNING id
388 `
389
390 var id uuid.UUID
391 row := db.QueryRowContext(ctx, q, artifact, key, value)
392 err := row.Scan(&id)
393 if err != nil {
394 return uuid.Nil, err
395 }
396
397 return id, nil
398 }
399
400 func (db *DBHandle) InsertClusterLabel(ctx context.Context, cluster uuid.UUID, key, value string) (uuid.UUID, error) {
401 q := `
402 INSERT INTO cluster_labels (cluster, key, value)
403 VALUES ($1, $2, $3)
404 RETURNING id
405 `
406
407 var id uuid.UUID
408 row := db.QueryRowContext(ctx, q, cluster, key, value)
409 err := row.Scan(&id)
410 if err != nil {
411 return uuid.Nil, err
412 }
413
414 return id, nil
415 }
416
417 func (db *DBHandle) GetKinformPubSubSubscriptions(ctx context.Context) ([]model.PubSubSubscription, error) {
418 q := `SELECT subscription, project FROM kinform_pubsub_subscriptions`
419 rows, err := db.QueryContext(ctx, q)
420 if err != nil && !errors.Is(err, sql.ErrNoRows) {
421 return []model.PubSubSubscription{}, nil
422 }
423
424 subscriptions := []model.PubSubSubscription{}
425 for rows.Next() {
426 var sub string
427 var project string
428 err := rows.Scan(&sub, &project)
429 if err != nil {
430 return []model.PubSubSubscription{}, err
431 }
432 pss := model.PubSubSubscription{
433 SubscriptionID: sub,
434 Project: project,
435 }
436 subscriptions = append(subscriptions, pss)
437 }
438
439 return subscriptions, nil
440 }
441
442 func (db *DBHandle) InsertKinformPubSubSubscription(ctx context.Context, sub model.PubSubSubscription) (uuid.UUID, error) {
443 q := `
444 INSERT INTO kinform_pubsub_subscriptions (subscription, project)
445 VALUES ($1, $2)
446 RETURNING id
447 `
448 var id uuid.UUID
449
450 row := db.QueryRowContext(ctx, q, sub.SubscriptionID, sub.Project)
451 err := row.Scan(&id)
452 if err != nil {
453 return uuid.Nil, err
454 }
455
456 return id, nil
457 }
458
View as plain text