package sql import ( "context" "database/sql" "embed" "errors" "fmt" "io/fs" "os" "strconv" "strings" "time" "github.com/google/uuid" _ "github.com/jackc/pgx/v4/stdlib" // nolint:revive necessary for db driver "edge-infra.dev/pkg/f8n/kinform/model" sovereign "edge-infra.dev/pkg/f8n/sovereign/model" ) type DBHandle struct { *sql.DB ClusterID string } //go:embed schema.sql var schemaFS embed.FS const postgres = "postgres" func FromDSN(dsn string, maxOpenConns, maxIdleConns int) (*DBHandle, error) { db, err := sql.Open("pgx", dsn) if err != nil { return nil, fmt.Errorf("failed to open db with pgx driver: %w", err) } err = db.Ping() if err != nil { return nil, fmt.Errorf("failed to ping database: %w", err) } err = execSchema(db) if err != nil { return nil, fmt.Errorf("failed to execute schema: %w", err) } db.SetMaxOpenConns(maxOpenConns) db.SetMaxIdleConns(maxIdleConns) db.SetConnMaxIdleTime(time.Minute) return &DBHandle{DB: db}, nil } func FromEnv() (*DBHandle, error) { user, ok := os.LookupEnv("DB_USER") if !ok { user = postgres } password, ok := os.LookupEnv("DB_PASS") if !ok { password = "" } host, ok := os.LookupEnv("DB_HOST") if !ok { host = "127.0.0.1" } port, ok := os.LookupEnv("DB_PORT") if !ok { port = "5432" } dbName, ok := os.LookupEnv("DB_NAME") if !ok { dbName = postgres } var dbMaxConns int dbMaxConnsStr, ok := os.LookupEnv("DB_MAX_CONNS") if ok { dbMaxConnsParsed, err := strconv.Atoi(dbMaxConnsStr) if err != nil { return nil, fmt.Errorf("failed to parse DB_MAX_CONNS: %w", err) } dbMaxConns = dbMaxConnsParsed } else { // AlloyDB assumed to have default config of 1000 max connections with reservation for superuser connections. // 450 represents half of the remaining 900 user connections, allowing a second replica for rollouts, failovers. dbMaxConns = 450 } dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s", host, port, user, password, dbName) return FromDSN(dsn, dbMaxConns, dbMaxConns) } func execSchema(db *sql.DB) error { schemaSQL, err := fs.ReadFile(schemaFS, "schema.sql") if err != nil { return err } _, err = db.Exec(string(schemaSQL)) if err != nil { return err } return nil } func (db *DBHandle) InsertCluster(ctx context.Context, name string) (uuid.UUID, error) { insertSQL := ` INSERT INTO clusters(name, version_major, version_minor) VALUES ($1, $2, $3) RETURNING id ` var id uuid.UUID // TODO(dk185217): use real cluster version row := db.QueryRowContext(ctx, insertSQL, name, 0, 0) err := row.Scan(&id) if err != nil { return uuid.Nil, err } return id, nil } func (db *DBHandle) InsertResource(ctx context.Context, resource model.WatchedResource) error { insertSQL := ` INSERT INTO watched_resources(api_version, kind, resource, cluster) VALUES ($1, $2, $3, $4) ON CONFLICT ((resource['metadata']['uid'])) DO UPDATE SET api_version = $1, kind = $2, resource = $3 ` _, err := db.ExecContext(ctx, insertSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.Cluster) if err != nil { return err } return nil } func (db *DBHandle) UpdateResource(ctx context.Context, resource model.WatchedResource) error { // TODO(dk185217): Only update if newer. Dont depend on everything arriving in order (do this in sql) updateSQL := `UPDATE watched_resources SET api_version=$1, kind=$2, resource=$3 WHERE resource['metadata']['uid'] = $4` _, err := db.ExecContext(ctx, updateSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.MetadataUID) if err != nil { return err } return nil } func (db *DBHandle) DeleteResource(ctx context.Context, resource model.WatchedResource) error { deleteSQL := "DELETE FROM watched_resources WHERE resource['metadata']['uid'] = $1" _, err := db.ExecContext(ctx, deleteSQL, resource.MetadataUID) if err != nil { return err } return nil } // TODO(dk185217): experimental. full history of objects func (db *DBHandle) InsertResourceObservation(ctx context.Context, resource model.WatchedResource) error { insertSQL := `INSERT INTO watched_resource_observations(api_version, kind, resource, cluster) VALUES ($1, $2, $3, $4)` _, err := db.ExecContext(ctx, insertSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.Cluster) if err != nil { return err } return nil } // TODO(dk185217): This should be temporary. It is parsing raw image strings from deployment specs. It should // be replaced by functionality acting on higher level information about artifacts func (db *DBHandle) InsertArtifactObserved(ctx context.Context, image string) error { observeSQL := ` WITH artifact_id AS ( SELECT fn_artifact_id_for($1) as id ), artifact_version_id AS ( SELECT fn_artifact_version_id_for((SELECT id FROM artifact_id), $2, $3) as id ) INSERT INTO observed_states (cluster, artifact_version) VALUES ($4, (SELECT id FROM artifact_version_id)) ON CONFLICT (cluster, artifact_version) DO UPDATE SET observed_at = NOW() ` ss := strings.Split(image, "/") imageString := ss[1] tag := "" sha256Digest := "0000000000000000000000000000000000000000000000000000000000000000" if strings.Contains(imageString, ":") { tagSplit := strings.Split(imageString, ":") tag = tagSplit[1] } if strings.Contains(imageString, "@") { digestSplit := strings.Split(imageString, "@") digest := digestSplit[1] digs := strings.Split(digest, ":") if digs[0] != "sha256" { return fmt.Errorf("expected digest to be sha256. got: %v", digs[0]) } sha256Digest = digs[1] } _, err := db.ExecContext(ctx, observeSQL, image, tag, sha256Digest, db.ClusterID) if err != nil { return err } return nil } func (db *DBHandle) UpdateClusterHeartbeatWithSession(ctx context.Context, h model.ClusterHeartbeat) error { query := ` WITH upsert_cluster AS ( INSERT INTO clusters (id, version_major, version_minor) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING ) INSERT INTO kinform_sessions (cluster, session, last_heartbeat) VALUES ($1, $4, $5) ON CONFLICT (session) DO UPDATE SET last_heartbeat = $5 ` _, err := db.ExecContext(ctx, query, h.Cluster, h.ClusterVersion.Major, h.ClusterVersion.Minor, h.SessionID, h.Timestamp) if err != nil { return err } return nil } // TODO(dk185217): Begin to refactor out ORM-style domain types (ie, entity mappings to go structs) type RemoteCommand struct { ID string CmdType string CmdArgs string } func (db *DBHandle) GetRemoteCommand(ctx context.Context) (RemoteCommand, bool, error) { commandSQL := `SELECT id, command_type, command_args FROM remote_commands WHERE cluster = $1 LIMIT 1` row := db.QueryRowContext(ctx, commandSQL, db.ClusterID) var id string var commandType string var commandArgs string if err := row.Scan(&id, &commandType, &commandArgs); err != nil { if errors.Is(err, sql.ErrNoRows) { fmt.Println("no commands pending for cluster:", db.ClusterID) return RemoteCommand{}, false, nil } return RemoteCommand{}, false, err } rCmd := RemoteCommand{ ID: id, CmdType: commandType, CmdArgs: commandArgs} return rCmd, true, nil } // TODO(dk185217): Transact, or similar. Get, handle, and delete commands according to successful completion, and // make sure it is row locked for a single handler at once func (db *DBHandle) DeleteRemoteCommand(ctx context.Context, id string) error { commandSQL := `DELETE FROM remote_commands WHERE id = $1` _, err := db.ExecContext(ctx, commandSQL, id) if err != nil { return err } return nil } func (db *DBHandle) GetClustersMatchingArtifactLabels(ctx context.Context, artifact uuid.UUID) ([]uuid.UUID, error) { q := ` WITH labels as ( SELECT key, value FROM artifact_labels WHERE artifact = $1 ) SELECT cluster FROM cluster_labels c JOIN labels USING (key, value) GROUP BY c.cluster HAVING count(*) = (SELECT count(*) FROM labels) ` t0 := time.Now() rows, err := db.QueryContext(ctx, q, artifact) if err != nil && !errors.Is(err, sql.ErrNoRows) { return []uuid.UUID{}, nil } dt := time.Since(t0) fmt.Printf("Query complete in %v ms\n", dt.Milliseconds()) clusters := []uuid.UUID{} for rows.Next() { var clusterID uuid.UUID err := rows.Scan(&clusterID) if err != nil { return []uuid.UUID{}, err } clusters = append(clusters, clusterID) } return clusters, nil } func (db *DBHandle) InsertArtifact(ctx context.Context, a sovereign.Artifact) (uuid.UUID, error) { q := ` INSERT INTO artifacts (project, repository, artifact_version) VALUES ($1, $2, $3) RETURNING id ` var id uuid.UUID row := db.QueryRowContext(ctx, q, a.ProjectID, a.Repository, a.ArtifactVersion) err := row.Scan(&id) if err != nil { return uuid.Nil, err } return id, nil } func (db *DBHandle) QueryArtifactVersion(ctx context.Context, image, sha25Digest string) (sovereign.ArtifactVersion, error) { q := `SELECT id FROM artifact_versions WHERE image = $1 AND sha256_digest = $2` var id uuid.UUID row := db.QueryRowContext(ctx, q, image, sha25Digest) err := row.Scan(&id) if err != nil { return sovereign.ArtifactVersion{}, err } av := &sovereign.ArtifactVersion{ ID: id, Image: image, Sha256Digest: sha25Digest, } return *av, nil } func (db *DBHandle) InsertArtifactVersion(ctx context.Context, image, tag, sha25Digest string) (uuid.UUID, error) { q := ` INSERT INTO artifact_versions (image, tag, sha256_digest) VALUES ($1, $2, $3) RETURNING id ` var id uuid.UUID row := db.QueryRowContext(ctx, q, image, tag, sha25Digest) err := row.Scan(&id) if err != nil { return uuid.Nil, err } return id, nil } func (db *DBHandle) DeleteArtifactVersion(ctx context.Context, image, sha25Digest string) error { q := ` DELETE FROM artifact_versions WHERE image = $1 AND sha256_digest = $2 ` _, err := db.ExecContext(ctx, q, image, sha25Digest) if err != nil { return err } return nil } func (db *DBHandle) InsertArtifactLabel(ctx context.Context, artifact uuid.UUID, key, value string) (uuid.UUID, error) { q := ` INSERT INTO artifact_labels (artifact, key, value) VALUES ($1, $2, $3) RETURNING id ` var id uuid.UUID row := db.QueryRowContext(ctx, q, artifact, key, value) err := row.Scan(&id) if err != nil { return uuid.Nil, err } return id, nil } func (db *DBHandle) InsertClusterLabel(ctx context.Context, cluster uuid.UUID, key, value string) (uuid.UUID, error) { q := ` INSERT INTO cluster_labels (cluster, key, value) VALUES ($1, $2, $3) RETURNING id ` var id uuid.UUID row := db.QueryRowContext(ctx, q, cluster, key, value) err := row.Scan(&id) if err != nil { return uuid.Nil, err } return id, nil } func (db *DBHandle) GetKinformPubSubSubscriptions(ctx context.Context) ([]model.PubSubSubscription, error) { q := `SELECT subscription, project FROM kinform_pubsub_subscriptions` rows, err := db.QueryContext(ctx, q) if err != nil && !errors.Is(err, sql.ErrNoRows) { return []model.PubSubSubscription{}, nil } subscriptions := []model.PubSubSubscription{} for rows.Next() { var sub string var project string err := rows.Scan(&sub, &project) if err != nil { return []model.PubSubSubscription{}, err } pss := model.PubSubSubscription{ SubscriptionID: sub, Project: project, } subscriptions = append(subscriptions, pss) } return subscriptions, nil } func (db *DBHandle) InsertKinformPubSubSubscription(ctx context.Context, sub model.PubSubSubscription) (uuid.UUID, error) { q := ` INSERT INTO kinform_pubsub_subscriptions (subscription, project) VALUES ($1, $2) RETURNING id ` var id uuid.UUID row := db.QueryRowContext(ctx, q, sub.SubscriptionID, sub.Project) err := row.Scan(&id) if err != nil { return uuid.Nil, err } return id, nil }