...

Source file src/edge-infra.dev/pkg/f8n/kinform/sql/sql.go

Documentation: edge-infra.dev/pkg/f8n/kinform/sql

     1  package sql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"embed"
     7  	"errors"
     8  	"fmt"
     9  	"io/fs"
    10  	"os"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/google/uuid"
    16  	_ "github.com/jackc/pgx/v4/stdlib" // nolint:revive necessary for db driver
    17  
    18  	"edge-infra.dev/pkg/f8n/kinform/model"
    19  	sovereign "edge-infra.dev/pkg/f8n/sovereign/model"
    20  )
    21  
    22  type DBHandle struct {
    23  	*sql.DB
    24  	ClusterID string
    25  }
    26  
// schemaFS embeds schema.sql from this package's directory into the binary;
// execSchema reads it from here and executes it against the database.
//
//go:embed schema.sql
var schemaFS embed.FS

// postgres is the fallback value FromEnv uses for both DB_USER and DB_NAME.
const postgres = "postgres"
    31  
    32  func FromDSN(dsn string, maxOpenConns, maxIdleConns int) (*DBHandle, error) {
    33  	db, err := sql.Open("pgx", dsn)
    34  	if err != nil {
    35  		return nil, fmt.Errorf("failed to open db with pgx driver: %w", err)
    36  	}
    37  
    38  	err = db.Ping()
    39  	if err != nil {
    40  		return nil, fmt.Errorf("failed to ping database: %w", err)
    41  	}
    42  
    43  	err = execSchema(db)
    44  	if err != nil {
    45  		return nil, fmt.Errorf("failed to execute schema: %w", err)
    46  	}
    47  
    48  	db.SetMaxOpenConns(maxOpenConns)
    49  	db.SetMaxIdleConns(maxIdleConns)
    50  	db.SetConnMaxIdleTime(time.Minute)
    51  
    52  	return &DBHandle{DB: db}, nil
    53  }
    54  
    55  func FromEnv() (*DBHandle, error) {
    56  	user, ok := os.LookupEnv("DB_USER")
    57  	if !ok {
    58  		user = postgres
    59  	}
    60  	password, ok := os.LookupEnv("DB_PASS")
    61  	if !ok {
    62  		password = ""
    63  	}
    64  	host, ok := os.LookupEnv("DB_HOST")
    65  	if !ok {
    66  		host = "127.0.0.1"
    67  	}
    68  	port, ok := os.LookupEnv("DB_PORT")
    69  	if !ok {
    70  		port = "5432"
    71  	}
    72  	dbName, ok := os.LookupEnv("DB_NAME")
    73  	if !ok {
    74  		dbName = postgres
    75  	}
    76  
    77  	var dbMaxConns int
    78  	dbMaxConnsStr, ok := os.LookupEnv("DB_MAX_CONNS")
    79  	if ok {
    80  		dbMaxConnsParsed, err := strconv.Atoi(dbMaxConnsStr)
    81  		if err != nil {
    82  			return nil, fmt.Errorf("failed to parse DB_MAX_CONNS: %w", err)
    83  		}
    84  		dbMaxConns = dbMaxConnsParsed
    85  	} else {
    86  		// AlloyDB assumed to have default config of 1000 max connections with reservation for superuser connections.
    87  		// 450 represents half of the remaining 900 user connections, allowing a second replica for rollouts, failovers.
    88  		dbMaxConns = 450
    89  	}
    90  
    91  	dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s", host, port, user, password, dbName)
    92  
    93  	return FromDSN(dsn, dbMaxConns, dbMaxConns)
    94  }
    95  
    96  func execSchema(db *sql.DB) error {
    97  	schemaSQL, err := fs.ReadFile(schemaFS, "schema.sql")
    98  	if err != nil {
    99  		return err
   100  	}
   101  	_, err = db.Exec(string(schemaSQL))
   102  	if err != nil {
   103  		return err
   104  	}
   105  	return nil
   106  }
   107  
   108  func (db *DBHandle) InsertCluster(ctx context.Context, name string) (uuid.UUID, error) {
   109  	insertSQL := `
   110  INSERT INTO clusters(name, version_major, version_minor)
   111  VALUES ($1, $2, $3)
   112  RETURNING id
   113  `
   114  
   115  	var id uuid.UUID
   116  	// TODO(dk185217): use real cluster version
   117  	row := db.QueryRowContext(ctx, insertSQL, name, 0, 0)
   118  	err := row.Scan(&id)
   119  	if err != nil {
   120  		return uuid.Nil, err
   121  	}
   122  
   123  	return id, nil
   124  }
   125  
   126  func (db *DBHandle) InsertResource(ctx context.Context, resource model.WatchedResource) error {
   127  	insertSQL := `
   128  INSERT INTO watched_resources(api_version, kind, resource, cluster)
   129  VALUES ($1, $2, $3, $4)
   130  ON CONFLICT ((resource['metadata']['uid']))
   131  DO UPDATE SET api_version = $1, kind = $2, resource = $3
   132  `
   133  
   134  	_, err := db.ExecContext(ctx, insertSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.Cluster)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  func (db *DBHandle) UpdateResource(ctx context.Context, resource model.WatchedResource) error {
   143  	// TODO(dk185217): Only update if newer. Dont depend on everything arriving in order (do this in sql)
   144  	updateSQL := `UPDATE watched_resources SET api_version=$1, kind=$2, resource=$3 WHERE resource['metadata']['uid'] = $4`
   145  
   146  	_, err := db.ExecContext(ctx, updateSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.MetadataUID)
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	return nil
   152  }
   153  
   154  func (db *DBHandle) DeleteResource(ctx context.Context, resource model.WatchedResource) error {
   155  	deleteSQL := "DELETE FROM watched_resources WHERE resource['metadata']['uid'] = $1"
   156  
   157  	_, err := db.ExecContext(ctx, deleteSQL, resource.MetadataUID)
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	return nil
   163  }
   164  
   165  // TODO(dk185217): experimental. full history of objects
   166  func (db *DBHandle) InsertResourceObservation(ctx context.Context, resource model.WatchedResource) error {
   167  	insertSQL := `INSERT INTO watched_resource_observations(api_version, kind, resource, cluster) VALUES ($1, $2, $3, $4)`
   168  
   169  	_, err := db.ExecContext(ctx, insertSQL, resource.APIVersion, resource.Kind, resource.Resource, resource.Cluster)
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  // TODO(dk185217): This should be temporary. It is parsing raw image strings from deployment specs. It should
   178  // be replaced by functionality acting on higher level information about artifacts
   179  func (db *DBHandle) InsertArtifactObserved(ctx context.Context, image string) error {
   180  	observeSQL := `
   181  WITH
   182  artifact_id AS (
   183  	SELECT fn_artifact_id_for($1) as id
   184  ),
   185  artifact_version_id AS (
   186  	SELECT fn_artifact_version_id_for((SELECT id FROM artifact_id), $2, $3) as id
   187  )
   188  INSERT INTO observed_states (cluster, artifact_version)
   189  VALUES ($4, (SELECT id FROM artifact_version_id))
   190  ON CONFLICT (cluster, artifact_version) DO UPDATE SET observed_at = NOW()
   191  `
   192  
   193  	ss := strings.Split(image, "/")
   194  	imageString := ss[1]
   195  	tag := ""
   196  	sha256Digest := "0000000000000000000000000000000000000000000000000000000000000000"
   197  	if strings.Contains(imageString, ":") {
   198  		tagSplit := strings.Split(imageString, ":")
   199  		tag = tagSplit[1]
   200  	}
   201  
   202  	if strings.Contains(imageString, "@") {
   203  		digestSplit := strings.Split(imageString, "@")
   204  		digest := digestSplit[1]
   205  		digs := strings.Split(digest, ":")
   206  		if digs[0] != "sha256" {
   207  			return fmt.Errorf("expected digest to be sha256. got: %v", digs[0])
   208  		}
   209  		sha256Digest = digs[1]
   210  	}
   211  
   212  	_, err := db.ExecContext(ctx, observeSQL, image, tag, sha256Digest, db.ClusterID)
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	return nil
   218  }
   219  
   220  func (db *DBHandle) UpdateClusterHeartbeatWithSession(ctx context.Context, h model.ClusterHeartbeat) error {
   221  	query := `
   222  WITH upsert_cluster AS (
   223  	INSERT INTO clusters (id, version_major, version_minor)
   224  	VALUES ($1, $2, $3)
   225  	ON CONFLICT (id) DO NOTHING
   226  )
   227  INSERT INTO kinform_sessions (cluster, session, last_heartbeat)
   228  VALUES ($1, $4, $5)
   229  ON CONFLICT (session) DO UPDATE SET last_heartbeat = $5
   230  `
   231  
   232  	_, err := db.ExecContext(ctx, query, h.Cluster, h.ClusterVersion.Major, h.ClusterVersion.Minor, h.SessionID, h.Timestamp)
   233  	if err != nil {
   234  		return err
   235  	}
   236  
   237  	return nil
   238  }
   239  
   240  // TODO(dk185217): Begin to refactor out ORM-style domain types (ie, entity mappings to go structs)
   241  type RemoteCommand struct {
   242  	ID      string
   243  	CmdType string
   244  	CmdArgs string
   245  }
   246  
   247  func (db *DBHandle) GetRemoteCommand(ctx context.Context) (RemoteCommand, bool, error) {
   248  	commandSQL := `SELECT id, command_type, command_args FROM remote_commands WHERE cluster = $1 LIMIT 1`
   249  
   250  	row := db.QueryRowContext(ctx, commandSQL, db.ClusterID)
   251  	var id string
   252  
   253  	var commandType string
   254  	var commandArgs string
   255  	if err := row.Scan(&id, &commandType, &commandArgs); err != nil {
   256  		if errors.Is(err, sql.ErrNoRows) {
   257  			fmt.Println("no commands pending for cluster:", db.ClusterID)
   258  			return RemoteCommand{}, false, nil
   259  		}
   260  		return RemoteCommand{}, false, err
   261  	}
   262  
   263  	rCmd := RemoteCommand{
   264  		ID:      id,
   265  		CmdType: commandType,
   266  		CmdArgs: commandArgs}
   267  	return rCmd, true, nil
   268  }
   269  
   270  // TODO(dk185217): Transact, or similar. Get, handle, and delete commands according to successful completion, and
   271  // make sure it is row locked for a single handler at once
   272  func (db *DBHandle) DeleteRemoteCommand(ctx context.Context, id string) error {
   273  	commandSQL := `DELETE FROM remote_commands WHERE id = $1`
   274  
   275  	_, err := db.ExecContext(ctx, commandSQL, id)
   276  	if err != nil {
   277  		return err
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  func (db *DBHandle) GetClustersMatchingArtifactLabels(ctx context.Context, artifact uuid.UUID) ([]uuid.UUID, error) {
   284  	q := `
   285  WITH labels as (
   286  	SELECT key, value
   287  	FROM artifact_labels
   288  	WHERE artifact = $1
   289  )
   290  SELECT cluster
   291  FROM cluster_labels c
   292  JOIN labels USING (key, value)
   293  GROUP BY c.cluster
   294  HAVING count(*) = (SELECT count(*) FROM labels)
   295  `
   296  	t0 := time.Now()
   297  	rows, err := db.QueryContext(ctx, q, artifact)
   298  	if err != nil && !errors.Is(err, sql.ErrNoRows) {
   299  		return []uuid.UUID{}, nil
   300  	}
   301  	dt := time.Since(t0)
   302  	fmt.Printf("Query complete in %v ms\n", dt.Milliseconds())
   303  	clusters := []uuid.UUID{}
   304  	for rows.Next() {
   305  		var clusterID uuid.UUID
   306  		err := rows.Scan(&clusterID)
   307  		if err != nil {
   308  			return []uuid.UUID{}, err
   309  		}
   310  		clusters = append(clusters, clusterID)
   311  	}
   312  
   313  	return clusters, nil
   314  }
   315  
   316  func (db *DBHandle) InsertArtifact(ctx context.Context, a sovereign.Artifact) (uuid.UUID, error) {
   317  	q := `
   318  INSERT INTO artifacts (project, repository, artifact_version)
   319  VALUES ($1, $2, $3)
   320  RETURNING id
   321  `
   322  
   323  	var id uuid.UUID
   324  	row := db.QueryRowContext(ctx, q, a.ProjectID, a.Repository, a.ArtifactVersion)
   325  	err := row.Scan(&id)
   326  	if err != nil {
   327  		return uuid.Nil, err
   328  	}
   329  
   330  	return id, nil
   331  }
   332  
   333  func (db *DBHandle) QueryArtifactVersion(ctx context.Context, image, sha25Digest string) (sovereign.ArtifactVersion, error) {
   334  	q := `SELECT id FROM artifact_versions WHERE image = $1 AND sha256_digest = $2`
   335  
   336  	var id uuid.UUID
   337  	row := db.QueryRowContext(ctx, q, image, sha25Digest)
   338  	err := row.Scan(&id)
   339  	if err != nil {
   340  		return sovereign.ArtifactVersion{}, err
   341  	}
   342  
   343  	av := &sovereign.ArtifactVersion{
   344  		ID:           id,
   345  		Image:        image,
   346  		Sha256Digest: sha25Digest,
   347  	}
   348  
   349  	return *av, nil
   350  }
   351  
   352  func (db *DBHandle) InsertArtifactVersion(ctx context.Context, image, tag, sha25Digest string) (uuid.UUID, error) {
   353  	q := `
   354  INSERT INTO artifact_versions (image, tag, sha256_digest)
   355  VALUES ($1, $2, $3)
   356  RETURNING id
   357  `
   358  
   359  	var id uuid.UUID
   360  	row := db.QueryRowContext(ctx, q, image, tag, sha25Digest)
   361  	err := row.Scan(&id)
   362  	if err != nil {
   363  		return uuid.Nil, err
   364  	}
   365  
   366  	return id, nil
   367  }
   368  
   369  func (db *DBHandle) DeleteArtifactVersion(ctx context.Context, image, sha25Digest string) error {
   370  	q := `
   371  DELETE FROM artifact_versions
   372  WHERE image = $1 AND sha256_digest = $2
   373  `
   374  
   375  	_, err := db.ExecContext(ctx, q, image, sha25Digest)
   376  	if err != nil {
   377  		return err
   378  	}
   379  
   380  	return nil
   381  }
   382  
   383  func (db *DBHandle) InsertArtifactLabel(ctx context.Context, artifact uuid.UUID, key, value string) (uuid.UUID, error) {
   384  	q := `
   385  INSERT INTO artifact_labels (artifact, key, value)
   386  VALUES ($1, $2, $3)
   387  RETURNING id
   388  `
   389  
   390  	var id uuid.UUID
   391  	row := db.QueryRowContext(ctx, q, artifact, key, value)
   392  	err := row.Scan(&id)
   393  	if err != nil {
   394  		return uuid.Nil, err
   395  	}
   396  
   397  	return id, nil
   398  }
   399  
   400  func (db *DBHandle) InsertClusterLabel(ctx context.Context, cluster uuid.UUID, key, value string) (uuid.UUID, error) {
   401  	q := `
   402  INSERT INTO cluster_labels (cluster, key, value)
   403  VALUES ($1, $2, $3)
   404  RETURNING id
   405  `
   406  
   407  	var id uuid.UUID
   408  	row := db.QueryRowContext(ctx, q, cluster, key, value)
   409  	err := row.Scan(&id)
   410  	if err != nil {
   411  		return uuid.Nil, err
   412  	}
   413  
   414  	return id, nil
   415  }
   416  
   417  func (db *DBHandle) GetKinformPubSubSubscriptions(ctx context.Context) ([]model.PubSubSubscription, error) {
   418  	q := `SELECT subscription, project FROM kinform_pubsub_subscriptions`
   419  	rows, err := db.QueryContext(ctx, q)
   420  	if err != nil && !errors.Is(err, sql.ErrNoRows) {
   421  		return []model.PubSubSubscription{}, nil
   422  	}
   423  
   424  	subscriptions := []model.PubSubSubscription{}
   425  	for rows.Next() {
   426  		var sub string
   427  		var project string
   428  		err := rows.Scan(&sub, &project)
   429  		if err != nil {
   430  			return []model.PubSubSubscription{}, err
   431  		}
   432  		pss := model.PubSubSubscription{
   433  			SubscriptionID: sub,
   434  			Project:        project,
   435  		}
   436  		subscriptions = append(subscriptions, pss)
   437  	}
   438  
   439  	return subscriptions, nil
   440  }
   441  
   442  func (db *DBHandle) InsertKinformPubSubSubscription(ctx context.Context, sub model.PubSubSubscription) (uuid.UUID, error) {
   443  	q := `
   444  INSERT INTO kinform_pubsub_subscriptions (subscription, project)
   445  VALUES ($1, $2)
   446  RETURNING id
   447  	`
   448  	var id uuid.UUID
   449  
   450  	row := db.QueryRowContext(ctx, q, sub.SubscriptionID, sub.Project)
   451  	err := row.Scan(&id)
   452  	if err != nil {
   453  		return uuid.Nil, err
   454  	}
   455  
   456  	return id, nil
   457  }
   458  

View as plain text