package server

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"edge-infra.dev/pkg/f8n/kinform/model"
	"github.com/google/uuid"
)

// DBHandle embeds a *sql.DB and adds the queries used by this package.
type DBHandle struct {
	*sql.DB
}

// GetBannerProjectIDs returns the project_id column of every row in the
// banners table. The returned slice is nil when the table has no rows.
func (h *DBHandle) GetBannerProjectIDs(ctx context.Context) ([]string, error) {
	const stmt = "SELECT project_id FROM banners"

	rows, err := h.QueryContext(ctx, stmt)
	if err != nil {
		return nil, fmt.Errorf("failed to query banners table for project ids: %w", err)
	}
	defer rows.Close()

	var projectIDs []string
	for rows.Next() {
		var projectID string
		err = rows.Scan(&projectID)
		if err != nil {
			return nil, fmt.Errorf("failed to scan banners table for project id: %w", err)
		}
		projectIDs = append(projectIDs, projectID)
	}
	// rows.Next returning false can mean "done" or "failed"; rows.Err tells them apart.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to query all project id rows in banners table: %w", err)
	}
	return projectIDs, nil
}

// SetClusterHeartbeatTime sets infra_status_updated_at to t for the cluster
// identified by clusterEdgeID. The number of affected rows is not inspected,
// so an unknown cluster id is not reported as an error.
func (h *DBHandle) SetClusterHeartbeatTime(ctx context.Context, t time.Time, clusterEdgeID uuid.UUID) error {
	const stmt = "UPDATE clusters SET infra_status_updated_at = $1 WHERE cluster_edge_id = $2"

	_, err := h.ExecContext(ctx, stmt, t, clusterEdgeID.String())
	if err != nil {
		return fmt.Errorf("failed to update cluster %q status time: %w", clusterEdgeID.String(), err)
	}
	return nil
}

// stmtSetWatchedFieldObject upserts a watched field object and returns its
// object_id. On conflict the row is only updated (and therefore only returned)
// when the stored watched_at is older than $7.
const stmtSetWatchedFieldObject = "INSERT INTO watched_field_objects (cluster_edge_id, api_version, kind, name, namespace, watched_at, deleted) " +
	"VALUES ($1, $2, $3, $4, $5, $6, 'false') " +
	"ON CONFLICT (cluster_edge_id, api_version, kind, name, namespace) DO UPDATE " +
	"SET (watched_at, deleted) = (EXCLUDED.watched_at, 'false') " +
	"WHERE watched_field_objects.watched_at < $7 " +
	"RETURNING object_id "

// TxSetWatchedFieldObject upserts the watched field object described by wf
// inside tx and returns its object id.
//
// When the statement returns no row (sql.ErrNoRows), the result of
// IgnoredMessageErrorf is returned so callers can tell an outdated message
// apart from a database failure. Any other error is returned unwrapped.
func TxSetWatchedFieldObject(tx *sql.Tx, wf model.WatchedField) (*uuid.UUID, error) {
	// wf.Timestamp is bound twice: $6 is the new watched_at, $7 is the staleness guard.
	var vals = []interface{}{
		wf.Cluster.String(),
		wf.APIVersion,
		wf.Kind,
		wf.Name,
		wf.Namespace,
		wf.Timestamp,
		wf.Timestamp,
	}

	var objectID uuid.UUID
	if err := tx.QueryRow(stmtSetWatchedFieldObject, vals...).Scan(&objectID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, IgnoredMessageErrorf("ignoring outdated watched field object")
		}
		return nil, err
	}
	return &objectID, nil
}

// fmtSetWatchedFieldValues is a format string for a statement that upserts a
// list of (object_id, jsonpath, value, missing) tuples and then deletes every
// other value row of the same object. %s receives the tuple placeholders and
// %d the index of the trailing object_id parameter.
const fmtSetWatchedFieldValues = "WITH ins_values AS ( " +
	"INSERT INTO watched_field_values (object_id, jsonpath, value, missing) " +
	"VALUES %s " +
	"ON CONFLICT (object_id, jsonpath) DO UPDATE " +
	"SET (value, missing) = (EXCLUDED.value, EXCLUDED.missing) " +
	"RETURNING jsonpath ) " +
	"DELETE FROM watched_field_values " +
	"WHERE object_id = ($%d) AND jsonpath NOT IN (SELECT jsonpath FROM ins_values) "

// stmtSetWatchedField renders fmtSetWatchedFieldValues for valuesCount value
// tuples. The resulting statement expects 4*valuesCount+1 parameters: four per
// tuple followed by the object id used by the cleanup DELETE.
//
// With no tuples the upsert would render as "VALUES  ON CONFLICT", which is
// not valid SQL. In that case a plain DELETE of all values of the object is
// returned instead; it has the same effect and still takes the object id as $1.
func stmtSetWatchedField(valuesCount int) string {
	if valuesCount <= 0 {
		return "DELETE FROM watched_field_values WHERE object_id = ($1)"
	}

	const fmtFieldValueArg = "($%d, $%d, $%d, $%d)"

	args := make([]string, 0, valuesCount)
	for i := 0; i < valuesCount; i++ {
		x := i * 4
		args = append(args, fmt.Sprintf(fmtFieldValueArg, x+1, x+2, x+3, x+4))
	}

	lastArg := 1 + valuesCount*4
	return fmt.Sprintf(fmtSetWatchedFieldValues, strings.Join(args, ", "), lastArg)
}

// TxSetWatchedFieldValues replaces the stored values of the object identified
// by objectID with wf.Fields inside tx: listed fields are upserted and all
// other values of that object are deleted. It returns an error if objectID is
// nil or if the statement fails.
func TxSetWatchedFieldValues(tx *sql.Tx, wf model.WatchedField, objectID *uuid.UUID) error {
	if objectID == nil {
		return errors.New("watched field object id is nil")
	}
	id := objectID.String()

	// Four parameters per field plus the trailing object id for the cleanup DELETE.
	vals := make([]interface{}, 0, len(wf.Fields)*4+1)
	for _, fv := range wf.Fields {
		vals = append(vals, id, fv.JSONPath, fv.Value, fv.Missing)
	}
	vals = append(vals, id)

	_, err := tx.Exec(stmtSetWatchedField(len(wf.Fields)), vals...)
	return err
}

// SetWatchedField upserts the watched field and cleans up outdated values.
// If the timestamp is out of date, no data is updated.
func (h *DBHandle) SetWatchedField(ctx context.Context, wf model.WatchedField) error { tx, err := h.DB.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("error beginning transaction for watched field: %w", err) } defer tx.Rollback() //nolint: errcheck objectID, err := TxSetWatchedFieldObject(tx, wf) if err != nil { return err } if err := TxSetWatchedFieldValues(tx, wf, objectID); err != nil { return err } if err := tx.Commit(); err != nil { return fmt.Errorf("error committing watched field: %w", err) } return nil } const stmtDeleteOutdatedWatchedFieldObjects = `UPDATE watched_field_objects SET (watched_at, deleted) = ($1, 'true') WHERE (cluster_edge_id, deleted) = ($2, 'false') AND watched_at < $3 RETURNING object_id ` const stmtDeleteOutdatedWatchedFieldValues = `DELETE FROM watched_field_values WHERE object_id = ANY ($1)` func (h *DBHandle) DeleteOutdatedWatchedFieldObjects(ctx context.Context, sm model.ScrapeMessage) error { tx, err := h.DB.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("error beginning transaction to delete outdated watched field objects: %w", err) } defer tx.Rollback() //nolint: errcheck var objectIDs []string var vals = []interface{}{ sm.StartTime, sm.Cluster.String(), sm.StartTime, } // delete watched field objects rows, err := tx.Query(stmtDeleteOutdatedWatchedFieldObjects, vals...) 
if err != nil { return fmt.Errorf("error deleting outdated watched field objects: %w", err) } for rows.Next() { var objectID uuid.UUID if err := rows.Scan(&objectID); err != nil { return fmt.Errorf("error scanning outdated watched field objects: %w", err) } objectIDs = append(objectIDs, objectID.String()) } if err := rows.Err(); err != nil { return fmt.Errorf("error scanning all outdated watched field objects: %w", err) } // delete watched field values if len(objectIDs) != 0 { if _, err := tx.Exec(stmtDeleteOutdatedWatchedFieldValues, objectIDs); err != nil { return fmt.Errorf("error deleting outdated watched field values: %w", err) } } if err := tx.Commit(); err != nil { return fmt.Errorf("error committing delete watched field: %w", err) } return nil } // stmtDeleteWatchedField also deletes the watched_field_values on cascade. const stmtDeleteWatchedFieldObject = `UPDATE watched_field_objects SET (watched_at, deleted) = ($1, 'true') WHERE (cluster_edge_id, api_version, kind, name, namespace) = ($2, $3, $4, $5, $6) AND watched_at < $7 RETURNING object_id` const stmtDeleteWatchedFieldValues = `DELETE FROM watched_field_values WHERE object_id = $1 ` func (h *DBHandle) DeleteWatchedField(ctx context.Context, wf model.WatchedField) error { tx, err := h.DB.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("error beginning transaction to delete watched field: %w", err) } defer tx.Rollback() //nolint: errcheck var objectID uuid.UUID var vals = []interface{}{ wf.Timestamp, wf.Cluster.String(), wf.APIVersion, wf.Kind, wf.Name, wf.Namespace, wf.Timestamp, } if err = tx.QueryRow(stmtDeleteWatchedFieldObject, vals...).Scan(&objectID); err != nil { if errors.Is(err, sql.ErrNoRows) { return IgnoredMessageErrorf("outdated delete for watched field object") } return fmt.Errorf("error deleting watched field object: %w", err) } if _, err = tx.Exec(stmtDeleteWatchedFieldValues, objectID.String()); err != nil { return fmt.Errorf("error deleting watched field values: %w", err) } if 
err := tx.Commit(); err != nil { return fmt.Errorf("error committing delete watched field: %w", err) } return nil } const stmtGarbageCollectDeletedWatchedFields = `WITH gc_deleted AS ( DELETE FROM watched_field_objects WHERE deleted IS TRUE AND watched_at < NOW() - INTERVAL '1 HOUR' RETURNING * ) SELECT COUNT(*) FROM gc_deleted ` // GarbageCollectDeletedWatchedFieldObjects garbage collects watched field objects that have been deleted for over an hour. // Since garbage collection requires a full table scan, it should only occur at startup when psqlinjector is not processing messages. // Running this query before processing messages prevents deadlocks and performance hiccups in the database. func (h *DBHandle) GarbageCollectDeletedWatchedFieldObjects(ctx context.Context) (int, error) { tx, err := h.DB.BeginTx(ctx, nil) if err != nil { return 0, fmt.Errorf("error beginning transaction to garbage collect deleted watched field objects: %w", err) } defer tx.Rollback() //nolint: errcheck var deletedCount int if err = tx.QueryRow(stmtGarbageCollectDeletedWatchedFields).Scan(&deletedCount); err != nil { return 0, fmt.Errorf("error garbage collecting deleted watched field objects: %w", err) } if err := tx.Commit(); err != nil { return 0, fmt.Errorf("error committing garbage collection of deleted watched field objects: %w", err) } return deletedCount, nil }