...

Source file src/edge-infra.dev/pkg/edge/psqlinjector/sql.go

Documentation: edge-infra.dev/pkg/edge/psqlinjector

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  	"time"
    10  
    11  	"edge-infra.dev/pkg/f8n/kinform/model"
    12  
    13  	"github.com/google/uuid"
    14  )
    15  
// DBHandle wraps a *sql.DB so the query helpers in this file can be declared
// as methods on it. The handle is embedded, so its own methods (QueryContext,
// ExecContext, BeginTx, ...) remain directly callable on a DBHandle.
type DBHandle struct {
	*sql.DB
}
    19  
    20  func (h *DBHandle) GetBannerProjectIDs(ctx context.Context) ([]string, error) {
    21  	const stmt = "SELECT project_id FROM banners"
    22  	rows, err := h.QueryContext(ctx, stmt)
    23  	if err != nil {
    24  		return nil, fmt.Errorf("failed to query banners table for project ids: %w", err)
    25  	}
    26  	defer rows.Close()
    27  
    28  	var projectIDs []string
    29  	for rows.Next() {
    30  		var projectID string
    31  		err = rows.Scan(&projectID)
    32  		if err != nil {
    33  			return nil, fmt.Errorf("failed to scan banners table for project id: %w", err)
    34  		}
    35  		projectIDs = append(projectIDs, projectID)
    36  	}
    37  
    38  	if err = rows.Err(); err != nil {
    39  		return nil, fmt.Errorf("failed to query all project id rows in banners table: %w", err)
    40  	}
    41  
    42  	return projectIDs, nil
    43  }
    44  
    45  func (h *DBHandle) SetClusterHeartbeatTime(ctx context.Context, t time.Time, clusterEdgeID uuid.UUID) error {
    46  	const stmt = "UPDATE clusters SET infra_status_updated_at = $1 WHERE cluster_edge_id = $2"
    47  	_, err := h.ExecContext(ctx, stmt, t, clusterEdgeID.String())
    48  	if err != nil {
    49  		return fmt.Errorf("failed to update cluster %q status time: %w", clusterEdgeID.String(), err)
    50  	}
    51  	return nil
    52  }
    53  
    54  var stmtSetWatchedFieldObject = `INSERT INTO watched_field_objects (cluster_edge_id, api_version, kind, name, namespace, watched_at, deleted)
    55  VALUES ($1, $2, $3, $4, $5, $6, 'false')
    56  ON CONFLICT (cluster_edge_id, api_version, kind, name, namespace)
    57  DO UPDATE SET (watched_at, deleted) = (EXCLUDED.watched_at, 'false')
    58  WHERE watched_field_objects.watched_at < $7 
    59  RETURNING object_id
    60  `
    61  
    62  func TxSetWatchedFieldObject(tx *sql.Tx, wf model.WatchedField) (*uuid.UUID, error) {
    63  	var vals = []interface{}{
    64  		wf.Cluster.String(),
    65  		wf.APIVersion,
    66  		wf.Kind,
    67  		wf.Name,
    68  		wf.Namespace,
    69  		wf.Timestamp,
    70  		wf.Timestamp,
    71  	}
    72  	var objectID uuid.UUID
    73  	if err := tx.QueryRow(stmtSetWatchedFieldObject, vals...).Scan(&objectID); err != nil {
    74  		if errors.Is(err, sql.ErrNoRows) {
    75  			return nil, IgnoredMessageErrorf("ignoring outdated watched field object")
    76  		}
    77  		return nil, err
    78  	}
    79  	return &objectID, nil
    80  }
    81  
    82  const fmtSetWatchedFieldValues = `WITH
    83  ins_values AS (
    84    INSERT INTO watched_field_values (object_id, jsonpath, value, missing)
    85    VALUES %s
    86    ON CONFLICT (object_id, jsonpath) DO UPDATE
    87    SET (value, missing) = (EXCLUDED.value, EXCLUDED.missing)
    88    RETURNING jsonpath
    89  )
    90  DELETE FROM watched_field_values
    91  WHERE object_id = ($%d)
    92  AND jsonpath NOT IN (SELECT jsonpath FROM ins_values)
    93  `
    94  
    95  func stmtSetWatchedField(valuesCount int) string {
    96  	const fmtFieldValueArg = "($%d, $%d, $%d, $%d)"
    97  	var args []string
    98  	for i := 0; i < valuesCount; i++ {
    99  		var x = i * 4
   100  		var a = fmt.Sprintf(fmtFieldValueArg, x+1, x+2, x+3, x+4)
   101  		args = append(args, a)
   102  	}
   103  	var lastArg = 1 + valuesCount*4
   104  	return fmt.Sprintf(fmtSetWatchedFieldValues, strings.Join(args, ", "), lastArg)
   105  }
   106  
   107  func TxSetWatchedFieldValues(tx *sql.Tx, wf model.WatchedField, objectID *uuid.UUID) error {
   108  	var vals []interface{}
   109  	for _, fv := range wf.Fields {
   110  		vals = append(vals, objectID.String(), fv.JSONPath, fv.Value, fv.Missing)
   111  	}
   112  	vals = append(vals, objectID.String())
   113  
   114  	_, err := tx.Exec(stmtSetWatchedField(len(wf.Fields)), vals...)
   115  	return err
   116  }
   117  
   118  // SetWatchedField upserts the watched field and cleans up outdated values. If the timestamp is out of date, no data is updated.
   119  func (h *DBHandle) SetWatchedField(ctx context.Context, wf model.WatchedField) error {
   120  	tx, err := h.DB.BeginTx(ctx, nil)
   121  	if err != nil {
   122  		return fmt.Errorf("error beginning transaction for watched field: %w", err)
   123  	}
   124  	defer tx.Rollback() //nolint: errcheck
   125  
   126  	objectID, err := TxSetWatchedFieldObject(tx, wf)
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	if err := TxSetWatchedFieldValues(tx, wf, objectID); err != nil {
   132  		return err
   133  	}
   134  
   135  	if err := tx.Commit(); err != nil {
   136  		return fmt.Errorf("error committing watched field: %w", err)
   137  	}
   138  	return nil
   139  }
   140  
   141  const stmtDeleteOutdatedWatchedFieldObjects = `UPDATE watched_field_objects
   142  SET (watched_at, deleted) = ($1, 'true')
   143  WHERE (cluster_edge_id, deleted) = ($2, 'false')
   144  AND watched_at < $3
   145  RETURNING object_id
   146  `
   147  
   148  const stmtDeleteOutdatedWatchedFieldValues = `DELETE FROM watched_field_values WHERE object_id = ANY ($1)`
   149  
   150  func (h *DBHandle) DeleteOutdatedWatchedFieldObjects(ctx context.Context, sm model.ScrapeMessage) error {
   151  	tx, err := h.DB.BeginTx(ctx, nil)
   152  	if err != nil {
   153  		return fmt.Errorf("error beginning transaction to delete outdated watched field objects: %w", err)
   154  	}
   155  	defer tx.Rollback() //nolint: errcheck
   156  
   157  	var objectIDs []string
   158  	var vals = []interface{}{
   159  		sm.StartTime,
   160  		sm.Cluster.String(),
   161  		sm.StartTime,
   162  	}
   163  	// delete watched field objects
   164  	rows, err := tx.Query(stmtDeleteOutdatedWatchedFieldObjects, vals...)
   165  	if err != nil {
   166  		return fmt.Errorf("error deleting outdated watched field objects: %w", err)
   167  	}
   168  	for rows.Next() {
   169  		var objectID uuid.UUID
   170  		if err := rows.Scan(&objectID); err != nil {
   171  			return fmt.Errorf("error scanning outdated watched field objects: %w", err)
   172  		}
   173  		objectIDs = append(objectIDs, objectID.String())
   174  	}
   175  	if err := rows.Err(); err != nil {
   176  		return fmt.Errorf("error scanning all outdated watched field objects: %w", err)
   177  	}
   178  
   179  	// delete watched field values
   180  	if len(objectIDs) != 0 {
   181  		if _, err := tx.Exec(stmtDeleteOutdatedWatchedFieldValues, objectIDs); err != nil {
   182  			return fmt.Errorf("error deleting outdated watched field values: %w", err)
   183  		}
   184  	}
   185  
   186  	if err := tx.Commit(); err != nil {
   187  		return fmt.Errorf("error committing delete watched field: %w", err)
   188  	}
   189  	return nil
   190  }
   191  
   192  // stmtDeleteWatchedField also deletes the watched_field_values on cascade.
   193  const stmtDeleteWatchedFieldObject = `UPDATE watched_field_objects
   194  SET (watched_at, deleted) = ($1, 'true')
   195  WHERE (cluster_edge_id, api_version, kind, name, namespace) = ($2, $3, $4, $5, $6)
   196  AND watched_at < $7
   197  RETURNING object_id`
   198  
   199  const stmtDeleteWatchedFieldValues = `DELETE FROM watched_field_values
   200  WHERE object_id = $1
   201  `
   202  
   203  func (h *DBHandle) DeleteWatchedField(ctx context.Context, wf model.WatchedField) error {
   204  	tx, err := h.DB.BeginTx(ctx, nil)
   205  	if err != nil {
   206  		return fmt.Errorf("error beginning transaction to delete watched field: %w", err)
   207  	}
   208  	defer tx.Rollback() //nolint: errcheck
   209  
   210  	var objectID uuid.UUID
   211  	var vals = []interface{}{
   212  		wf.Timestamp,
   213  		wf.Cluster.String(),
   214  		wf.APIVersion,
   215  		wf.Kind,
   216  		wf.Name,
   217  		wf.Namespace,
   218  		wf.Timestamp,
   219  	}
   220  	if err = tx.QueryRow(stmtDeleteWatchedFieldObject, vals...).Scan(&objectID); err != nil {
   221  		if errors.Is(err, sql.ErrNoRows) {
   222  			return IgnoredMessageErrorf("outdated delete for watched field object")
   223  		}
   224  		return fmt.Errorf("error deleting watched field object: %w", err)
   225  	}
   226  
   227  	if _, err = tx.Exec(stmtDeleteWatchedFieldValues, objectID.String()); err != nil {
   228  		return fmt.Errorf("error deleting watched field values: %w", err)
   229  	}
   230  
   231  	if err := tx.Commit(); err != nil {
   232  		return fmt.Errorf("error committing delete watched field: %w", err)
   233  	}
   234  	return nil
   235  }
   236  
   237  const stmtGarbageCollectDeletedWatchedFields = `WITH
   238  gc_deleted AS (
   239    DELETE FROM watched_field_objects
   240    WHERE deleted IS TRUE
   241    AND watched_at < NOW() - INTERVAL '1 HOUR'
   242    RETURNING *
   243  )
   244  SELECT COUNT(*) FROM gc_deleted
   245  `
   246  
   247  // GarbageCollectDeletedWatchedFieldObjects garbage collects watched field objects that have been deleted for over an hour.
   248  // Since garbage collection requires a full table scan, it should only occur at startup when psqlinjector is not processing messages.
   249  // Running this query before processing messages prevents deadlocks and performance hiccups in the database.
   250  func (h *DBHandle) GarbageCollectDeletedWatchedFieldObjects(ctx context.Context) (int, error) {
   251  	tx, err := h.DB.BeginTx(ctx, nil)
   252  	if err != nil {
   253  		return 0, fmt.Errorf("error beginning transaction to garbage collect deleted watched field objects: %w", err)
   254  	}
   255  	defer tx.Rollback() //nolint: errcheck
   256  
   257  	var deletedCount int
   258  	if err = tx.QueryRow(stmtGarbageCollectDeletedWatchedFields).Scan(&deletedCount); err != nil {
   259  		return 0, fmt.Errorf("error garbage collecting deleted watched field objects: %w", err)
   260  	}
   261  
   262  	if err := tx.Commit(); err != nil {
   263  		return 0, fmt.Errorf("error committing garbage collection of deleted watched field objects: %w", err)
   264  	}
   265  	return deletedCount, nil
   266  }
   267  

View as plain text