...

Source file src/edge-infra.dev/test/f2/x/postgres/postgres.go

Documentation: edge-infra.dev/test/f2/x/postgres

     1  package postgres
     2  
     3  import (
     4  	"database/sql"
     5  	"flag"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"path"
    11  	"path/filepath"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/bazelbuild/rules_go/go/runfiles"
    16  	pgsql "github.com/fergusstrange/embedded-postgres"
    17  	"github.com/golang-migrate/migrate/v4/database/postgres"
    18  	_ "github.com/jackc/pgx/v4/stdlib" // TODO remove once using package cloudsql to create DB connections
    19  	"github.com/stretchr/testify/require"
    20  
    21  	edgesql "edge-infra.dev/pkg/edge/api/sql"
    22  	"edge-infra.dev/pkg/edge/api/sql/plugin"
    23  	"edge-infra.dev/pkg/lib/build/bazel"
    24  	"edge-infra.dev/pkg/lib/compression"
    25  	"edge-infra.dev/pkg/lib/gcp/cloudsql"
    26  	"edge-infra.dev/pkg/lib/logging"
    27  	"edge-infra.dev/test/f2"
    28  	"edge-infra.dev/test/f2/fctx"
    29  	"edge-infra.dev/test/f2/integration"
    30  )
    31  
    32  const (
    33  	postgresName = "postgres"
    34  )
    35  
    36  const (
    37  	// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    38  	maxSchemaLength = 63
    39  )
    40  
// Postgres is an f2 extension that provides a Postgres database to tests.
//
// For L1 tests it starts an embedded Postgres server and overrides the
// connection fields below with local defaults. Otherwise the fields are
// typically populated from the flags registered by [Postgres.BindFlags] and
// point at an existing server, reached either directly (Host/Port) or through
// CloudSQL (ConnectionName).
type Postgres struct {
	// ConnectionName is the CloudSQL connection name. When non-empty it takes
	// precedence over Host/Port.
	ConnectionName string

	// Host is the host to connect to from the test binary.
	Host string

	// Port is the port to connect to. For L1 tests it is reset to 0 so that a
	// free local port is picked when the embedded server starts.
	Port uint

	// User, Password and Database are the credentials and database name used
	// to connect (and, for L1 tests, to configure the embedded server).
	User     string
	Password string
	Database string

	// MaxConns is the maximum number of open client connections.
	MaxConns int

	// MaxIdleConns is the maximum number of client connections allowed in the
	// idle pool (postgres-max-idle-conns flag).
	MaxIdleConns int

	// k8sHost optionally overrides Host for pods inside the ktest cluster; see
	// [Postgres.K8SHost].
	k8sHost string

	// epg is the embedded server; only set for L1 tests once it has started.
	epg *pgsql.EmbeddedPostgres

	// options holds the configuration passed to [New].
	options *options

	// dsn is the value returned by [Postgres.DSN]; with schema isolation it
	// includes a search_path for the current test's schema.
	dsn string

	// schema is the per-test schema created when schema isolation is enabled.
	schema string

	// gdb uses default schema
	gdb *sql.DB

	// Scoped to a specific test schema
	db *sql.DB
}
    65  
    66  const (
    67  	embeddedTxzVar = "TEST_ASSET_EMBEDDED_POSTGRES_TXZ"
    68  )
    69  
    70  // New initialises a new [Postgres] struct. The extension will be initialised
    71  // with migrations from edgesql by default. if data with custom data structures are needed and these would conflict
    72  // with edgesql tables, SkipSeedModel [Option] should be passed:
    73  //
    74  //	pg := postgres.New(postgres.SkipSeedModel())
    75  func New(opts ...Option) *Postgres {
    76  	o := makeOptions(opts...)
    77  	return &Postgres{options: o}
    78  }
    79  
    80  // DSN returns a DSN string for the extension. Can be used to connect via database/sql.Open.
    81  //
    82  // As integration tests are expected to run against both CloudSQL and regular
    83  // postgres databases, it is recommended that tests use [Postgres.DB] to fetch a
    84  // preconfigured sql.DB struct, rather than manually calling databas/sql.Open
    85  func (pg *Postgres) DSN() string {
    86  	// TODO: how will `pkg/f8n/warehouse/forwarder/subscriber_test.go` work when we connect to CloudSQL
    87  
    88  	// includes the search_path
    89  	return pg.dsn
    90  }
    91  
    92  // dsn without the search_path
    93  func (pg *Postgres) globalDSN() string {
    94  	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", pg.Host, pg.Port, pg.User, pg.Password, pg.Database)
    95  }
    96  
    97  // DB returns an initialised *sql.DB with the search_path set to help isolate tests
    98  //
    99  // DB returns a single sql.DB instance for every call within a given test. Tests
   100  // generally shouldn't call db.Close manually.
   101  func (pg *Postgres) DB() *sql.DB {
   102  	return pg.db
   103  }
   104  
   105  // Returns the schema to be used by the individual test for test isolation
   106  func (pg *Postgres) Schema() string {
   107  	return pg.schema
   108  }
   109  
   110  // K8SHost returns the hostname to be used to connect to the DB when connecting
   111  // from a pod within the ktest cluster. The [Postgres.Host] field should be used
   112  // in most cases during normal connection from the test binary.
   113  func (pg *Postgres) K8SHost() string {
   114  	if pg.k8sHost != "" {
   115  		return pg.k8sHost
   116  	}
   117  	return pg.Host
   118  }
   119  
   120  // FromContext attempts to fetch an instance of Postgres from the test context and
   121  // returns an error if it is not discovered.
   122  func FromContext(ctx fctx.Context) (*Postgres, error) {
   123  	v := fctx.ValueFrom[Postgres](ctx)
   124  	if v == nil {
   125  		return nil, fmt.Errorf("%w: warehouse.Postgres extension", fctx.ErrNotFound)
   126  	}
   127  	return v, nil
   128  }
   129  
   130  // FromContextT is a testing variant of FromContext that immediately fails the
   131  // test if Postgres isnt presnt in the testing context.
   132  func FromContextT(ctx fctx.Context, t *testing.T) *Postgres {
   133  	return fctx.ValueFromT[Postgres](ctx, t)
   134  }
   135  
   136  // IntoContext stores the framework extension in the test context.
   137  func (pg *Postgres) IntoContext(ctx fctx.Context) fctx.Context {
   138  	return fctx.ValueInto(ctx, pg)
   139  }
   140  
   141  func (pg *Postgres) RegisterFns(f f2.Framework) {
   142  	if integration.IsL1() {
   143  		f.Setup(func(ctx f2.Context) (f2.Context, error) {
   144  			// Don't use cli supplied connection details when running as L1 test
   145  			// as we are starting a new embedded postgres process
   146  			pg.ConnectionName = ""
   147  			pg.Host = "127.0.0.1"
   148  			pg.Port = 0
   149  			pg.User = postgresName
   150  			pg.Password = postgresName
   151  			pg.Database = postgresName
   152  
   153  			return ctx, pg.newEmbeddedDB()
   154  		})
   155  	}
   156  
   157  	f.Setup(func(ctx f2.Context) (f2.Context, error) {
   158  		// Set the global database struct to connect using the default schema
   159  		// Set the test specific database struct to the default schema, can be
   160  		// overridden by a test specific DB later
   161  		db, err := pg.initializeGlobalDB()
   162  		if err != nil {
   163  			return ctx, fmt.Errorf("opening database connection: %w", err)
   164  		}
   165  		pg.gdb = db
   166  
   167  		dsn := pg.globalDSN()
   168  		pg.dsn = dsn
   169  
   170  		return ctx, nil
   171  	})
   172  
   173  	// Setup schema isolation by creating a unique schema for each test
   174  	f.BeforeEachTest(func(ctx f2.Context, t *testing.T) (f2.Context, error) {
   175  		// TODO: parallel tests. Separate PG struct in child context?
   176  		if !pg.options.skipSchemaIsolation {
   177  			name := strings.ToLower(t.Name())
   178  			schemaName := name + "_" + ctx.RunID
   179  
   180  			// if the proposed namespace is above the max shorten the name
   181  			if len(schemaName) > maxSchemaLength {
   182  				t.Log("proposed schema name was too long", schemaName)
   183  				schemaName = name[:len(name)-(len(schemaName)-maxSchemaLength)] + "_" + ctx.RunID
   184  			}
   185  
   186  			// can't use placeholder parameters when creating schema, gives syntax_error
   187  			_, err := pg.gdb.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA \"%s\";", schemaName))
   188  			if err != nil {
   189  				return ctx, fmt.Errorf("error creating test schema: %w", err)
   190  			}
   191  
   192  			pg.dsn = fmt.Sprintf("%s search_path=%s", pg.globalDSN(), schemaName)
   193  			pg.schema = schemaName
   194  
   195  			db, err := pg.initializeDB()
   196  			if err != nil {
   197  				return ctx, fmt.Errorf("error initialising db: %w", err)
   198  			}
   199  
   200  			pg.db = db
   201  		}
   202  		return ctx, nil
   203  	})
   204  
   205  	// apply migration to the db
   206  	f.BeforeEachTest(func(ctx f2.Context, _ *testing.T) (f2.Context, error) {
   207  		if pg.options.applySeedModel {
   208  			db := pg.DB()
   209  
   210  			err := seedExistingDB(db, []plugin.Seed{})
   211  			if err != nil {
   212  				return ctx, err
   213  			}
   214  		}
   215  		return ctx, nil
   216  	})
   217  
   218  	f.AfterEachTest(func(ctx f2.Context, _ *testing.T) (f2.Context, error) {
   219  		if pg.options.skipSchemaIsolation {
   220  			return ctx, nil
   221  		}
   222  
   223  		_, err := pg.gdb.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA \"%s\" CASCADE;", pg.schema))
   224  		if err != nil {
   225  			return ctx, fmt.Errorf("error dropping schema (%s): %w", pg.schema, err)
   226  		}
   227  		return ctx, err
   228  	})
   229  
   230  	f.Teardown(func(ctx fctx.Context) (fctx.Context, error) {
   231  		if integration.IsL1() {
   232  			err := pg.epg.Stop()
   233  			if err != nil {
   234  				return ctx, err
   235  			}
   236  		}
   237  		return ctx, nil
   238  	})
   239  }
   240  
   241  func (pg *Postgres) initDB(schema string) (*sql.DB, error) {
   242  	// Use cloudsql to handle connection
   243  	var edgeDB *cloudsql.EdgePostgres
   244  
   245  	// Initialise the correct type of connection: CLoudSQL or local host, port pair,
   246  	// determined by the presence of the ConnectionName configuration param
   247  	switch {
   248  	case pg.ConnectionName != "":
   249  		edgeDB = cloudsql.GCPPostgresConnection(pg.ConnectionName)
   250  	case pg.ConnectionName == "" && pg.Host != "":
   251  		if pg.Port == 0 {
   252  			return nil, fmt.Errorf("postgres-port is required")
   253  		}
   254  		edgeDB = cloudsql.PostgresConnection(pg.Host, fmt.Sprint(pg.Port)).
   255  			Password(pg.Password)
   256  	default:
   257  		return nil, fmt.Errorf("postgres-connection-name or postgres-host must be provided")
   258  	}
   259  
   260  	// Open the connection to the configured DB
   261  	edgeDB = edgeDB.
   262  		DBName(pg.Database).
   263  		Username(pg.User).
   264  		MaxOpenConns(pg.MaxConns)
   265  
   266  	if schema != "" {
   267  		edgeDB = edgeDB.SearchPath(schema)
   268  	}
   269  
   270  	db, err := edgeDB.NewConnection()
   271  	if err != nil {
   272  		return nil, fmt.Errorf("error opening connection to the database: %w", err)
   273  	}
   274  	return db, nil
   275  }
   276  
   277  // Creates a new connection to the database without setting a schema
   278  func (pg *Postgres) initializeGlobalDB() (*sql.DB, error) {
   279  	return pg.initDB("")
   280  }
   281  
   282  // Creates a new connection with the currently configured schema set to the search_path
   283  func (pg *Postgres) initializeDB() (*sql.DB, error) {
   284  	return pg.initDB(pg.Schema())
   285  }
   286  
   287  // BindFlags registers test flags for the framework extension.
   288  func (pg *Postgres) BindFlags(fs *flag.FlagSet) {
   289  	fs.StringVar(&pg.ConnectionName,
   290  		"postgres-connection-name",
   291  		"",
   292  		"CloudSQL connection name",
   293  	)
   294  	fs.StringVar(&pg.Host,
   295  		"postgres-host",
   296  		"127.0.0.1",
   297  		"the host to connect to",
   298  	)
   299  	fs.UintVar(&pg.Port,
   300  		"postgres-port",
   301  		5432,
   302  		"port to connect to for L2 tests",
   303  	)
   304  	fs.StringVar(&pg.User,
   305  		"postgres-user",
   306  		"postgres",
   307  		"user to create or connect as",
   308  	)
   309  	fs.StringVar(&pg.Password,
   310  		"postgres-pass",
   311  		"postgres",
   312  		"password to set or connect with",
   313  	)
   314  	fs.StringVar(&pg.Database,
   315  		"postgres-database",
   316  		"postgres",
   317  		"name of the database to create or connect to",
   318  	)
   319  	fs.IntVar(&pg.MaxConns,
   320  		"postgres-max-conns",
   321  		10,
   322  		"maximum amount of open client connections to allow",
   323  	)
   324  	fs.IntVar(&pg.MaxIdleConns,
   325  		"postgres-max-idle-conns",
   326  		10,
   327  		"maximum amount of client connections allowed in the idle pool",
   328  	)
   329  	fs.StringVar(&pg.k8sHost,
   330  		"postgres-k8s-host",
   331  		"",
   332  		"Set this option when you need to use a different DNS address to connect to the DB from K8S containers running within the ktest cluster.",
   333  	)
   334  }
   335  
   336  func findOpenPort() (int, error) {
   337  	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
   338  	if err != nil {
   339  		return 0, err
   340  	}
   341  	l, err := net.ListenTCP("tcp", addr)
   342  	if err != nil {
   343  		return 0, err
   344  	}
   345  	defer l.Close()
   346  	return l.Addr().(*net.TCPAddr).Port, nil
   347  }
   348  
   349  func maybeSetEnv(key, bin, runfilePath string) error {
   350  	if os.Getenv(key) != "" {
   351  		return nil
   352  	}
   353  	if !bazel.IsBazelTest() {
   354  		return fmt.Errorf(`failed to find integration test dependency %q.
   355  Either re-run this test using "bazel test" or set the %s environment variable`, bin, key)
   356  	}
   357  	p, err := runfiles.Rlocation(filepath.Join(os.Getenv(bazel.TestWorkspace), runfilePath))
   358  	if err != nil {
   359  		return fmt.Errorf("failed to look up test dependency %q: %w. ensure that "+
   360  			"it is present in this test targets 'data' attribute", bin, err)
   361  	}
   362  	os.Setenv(key, p)
   363  	return nil
   364  }
   365  
   366  // initialises embedded postgres
   367  func (pg *Postgres) newEmbeddedDB() error {
   368  	err := maybeSetEnv(embeddedTxzVar, "postgres.txz", "/hack/tools/postgres.txz")
   369  	if err != nil {
   370  		return err
   371  	}
   372  
   373  	embeddedTxzFile := os.Getenv(embeddedTxzVar)
   374  
   375  	tempDir, err := bazel.NewTestTmpDir("pgsql-*")
   376  	if err != nil {
   377  		return err
   378  	}
   379  	pgRuntimePath := path.Join(tempDir, "runtime")
   380  	pgTempDir := tempDir
   381  
   382  	err = compression.DecompressTarXz(embeddedTxzFile, tempDir)
   383  	if err != nil {
   384  		return err
   385  	}
   386  
   387  	if pg.Port == 0 {
   388  		port, err := findOpenPort()
   389  		if err != nil {
   390  			return err
   391  		}
   392  		pg.Port = uint(port) /* #nosec G115 */
   393  	}
   394  
   395  	cfg := pgsql.DefaultConfig()
   396  	cfg = cfg.Port(uint32(pg.Port)) /* #nosec G115 */
   397  	cfg = cfg.Username(pg.User)
   398  	cfg = cfg.Password(pg.Password)
   399  	cfg = cfg.Version(pgsql.V14)
   400  	cfg = cfg.RuntimePath(pgRuntimePath)
   401  	cfg = cfg.BinariesPath(pgTempDir)
   402  	cfg = cfg.Database(pg.Database)
   403  	// cfg = cfg.StartParameters(map[string]string{"max_connections": "20", "shared_buffers": "40"})
   404  
   405  	epg := pgsql.NewDatabase(cfg)
   406  	if err := epg.Start(); err != nil {
   407  		return fmt.Errorf("failed to start database, err: %w", err)
   408  	}
   409  	pg.epg = epg
   410  	return nil
   411  }
   412  
   413  // Applies edge-sql migrations and seeds data if provided.
   414  func seedExistingDB(db *sql.DB, seedData []plugin.Seed) error {
   415  	driver, err := postgres.WithInstance(db, &postgres.Config{})
   416  	if err != nil {
   417  		return err
   418  	}
   419  
   420  	var pluginConfig = &plugin.Config{
   421  		MigrationAction: "up",
   422  		Ordered:         true,
   423  		TestMode:        true,
   424  		Data:            seedData,
   425  	}
   426  	// TODO ncr-swt-retail/edge-roadmap#5668 . Would be useful to have a logger here incase of error with migration or seeding.
   427  	logger := logging.New(logging.To(io.Discard))
   428  	return edgesql.SetupEdgeTables(pluginConfig, driver, logger, db)
   429  }
   430  
   431  // WithData is a f2.StepFn that can be used to add data to the initialised database.
   432  // Requires [ApplySeedModel] to be specified.
   433  //
   434  // The edge-infra SeededPostgres package includes a Seed variable which can be
   435  // used as an example to create Seed data or can be passed directly to this
   436  // function.
   437  func WithData(seedData []plugin.Seed, msgAndArgs ...interface{}) f2.StepFn {
   438  	return func(ctx f2.Context, t *testing.T) f2.Context {
   439  		pg := FromContextT(ctx, t)
   440  		db := pg.DB()
   441  		err := seedExistingDB(db, seedData)
   442  		require.NoError(t, err, msgAndArgs...)
   443  		return ctx
   444  	}
   445  }
   446  

View as plain text