package postgres import ( "database/sql" "flag" "fmt" "io" "net" "os" "path" "path/filepath" "strings" "testing" "github.com/bazelbuild/rules_go/go/runfiles" pgsql "github.com/fergusstrange/embedded-postgres" "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/jackc/pgx/v4/stdlib" // TODO remove once using package cloudsql to create DB connections "github.com/stretchr/testify/require" edgesql "edge-infra.dev/pkg/edge/api/sql" "edge-infra.dev/pkg/edge/api/sql/plugin" "edge-infra.dev/pkg/lib/build/bazel" "edge-infra.dev/pkg/lib/compression" "edge-infra.dev/pkg/lib/gcp/cloudsql" "edge-infra.dev/pkg/lib/logging" "edge-infra.dev/test/f2" "edge-infra.dev/test/f2/fctx" "edge-infra.dev/test/f2/integration" ) const ( postgresName = "postgres" ) const ( // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS maxSchemaLength = 63 ) // Postgres is a f2 extension that sets up an embedded Postgres server type Postgres struct { ConnectionName string Host string Port uint User string Password string Database string MaxConns int MaxIdleConns int k8sHost string epg *pgsql.EmbeddedPostgres options *options dsn string schema string // gdb uses default schema gdb *sql.DB // Scoped to a specific test schema db *sql.DB } const ( embeddedTxzVar = "TEST_ASSET_EMBEDDED_POSTGRES_TXZ" ) // New initialises a new [Postgres] struct. The extension will be initialised // with migrations from edgesql by default. if data with custom data structures are needed and these would conflict // with edgesql tables, SkipSeedModel [Option] should be passed: // // pg := postgres.New(postgres.SkipSeedModel()) func New(opts ...Option) *Postgres { o := makeOptions(opts...) return &Postgres{options: o} } // DSN returns a DSN string for the extension. Can be used to connect via database/sql.Open. // // As integration tests are expected to run against both CloudSQL and regular // postgres databases, it is recommended that tests use [Postgres.DB] to fetch a // preconfigured sql.DB struct, rather than manually calling databas/sql.Open func (pg *Postgres) DSN() string { // TODO: how will `pkg/f8n/warehouse/forwarder/subscriber_test.go` work when we connect to CloudSQL // includes the search_path return pg.dsn } // dsn without the search_path func (pg *Postgres) globalDSN() string { return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", pg.Host, pg.Port, pg.User, pg.Password, pg.Database) } // DB returns an initialised *sql.DB with the search_path set to help isolate tests // // DB returns a single sql.DB instance for every call within a given test. Tests // generally shouldn't call db.Close manually. func (pg *Postgres) DB() *sql.DB { return pg.db } // Returns the schema to be used by the individual test for test isolation func (pg *Postgres) Schema() string { return pg.schema } // K8SHost returns the hostname to be used to connect to the DB when connecting // from a pod within the ktest cluster. The [Postgres.Host] field should be used // in most cases during normal connection from the test binary. func (pg *Postgres) K8SHost() string { if pg.k8sHost != "" { return pg.k8sHost } return pg.Host } // FromContext attempts to fetch an instance of Postgres from the test context and // returns an error if it is not discovered. func FromContext(ctx fctx.Context) (*Postgres, error) { v := fctx.ValueFrom[Postgres](ctx) if v == nil { return nil, fmt.Errorf("%w: warehouse.Postgres extension", fctx.ErrNotFound) } return v, nil } // FromContextT is a testing variant of FromContext that immediately fails the // test if Postgres isnt presnt in the testing context. func FromContextT(ctx fctx.Context, t *testing.T) *Postgres { return fctx.ValueFromT[Postgres](ctx, t) } // IntoContext stores the framework extension in the test context. func (pg *Postgres) IntoContext(ctx fctx.Context) fctx.Context { return fctx.ValueInto(ctx, pg) } func (pg *Postgres) RegisterFns(f f2.Framework) { if integration.IsL1() { f.Setup(func(ctx f2.Context) (f2.Context, error) { // Don't use cli supplied connection details when running as L1 test // as we are starting a new embedded postgres process pg.ConnectionName = "" pg.Host = "127.0.0.1" pg.Port = 0 pg.User = postgresName pg.Password = postgresName pg.Database = postgresName return ctx, pg.newEmbeddedDB() }) } f.Setup(func(ctx f2.Context) (f2.Context, error) { // Set the global database struct to connect using the default schema // Set the test specific database struct to the default schema, can be // overridden by a test specific DB later db, err := pg.initializeGlobalDB() if err != nil { return ctx, fmt.Errorf("opening database connection: %w", err) } pg.gdb = db dsn := pg.globalDSN() pg.dsn = dsn return ctx, nil }) // Setup schema isolation by creating a unique schema for each test f.BeforeEachTest(func(ctx f2.Context, t *testing.T) (f2.Context, error) { // TODO: parallel tests. Separate PG struct in child context? if !pg.options.skipSchemaIsolation { name := strings.ToLower(t.Name()) schemaName := name + "_" + ctx.RunID // if the proposed namespace is above the max shorten the name if len(schemaName) > maxSchemaLength { t.Log("proposed schema name was too long", schemaName) schemaName = name[:len(name)-(len(schemaName)-maxSchemaLength)] + "_" + ctx.RunID } // can't use placeholder parameters when creating schema, gives syntax_error _, err := pg.gdb.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA \"%s\";", schemaName)) if err != nil { return ctx, fmt.Errorf("error creating test schema: %w", err) } pg.dsn = fmt.Sprintf("%s search_path=%s", pg.globalDSN(), schemaName) pg.schema = schemaName db, err := pg.initializeDB() if err != nil { return ctx, fmt.Errorf("error initialising db: %w", err) } pg.db = db } return ctx, nil }) // apply migration to the db f.BeforeEachTest(func(ctx f2.Context, _ *testing.T) (f2.Context, error) { if pg.options.applySeedModel { db := pg.DB() err := seedExistingDB(db, []plugin.Seed{}) if err != nil { return ctx, err } } return ctx, nil }) f.AfterEachTest(func(ctx f2.Context, _ *testing.T) (f2.Context, error) { if pg.options.skipSchemaIsolation { return ctx, nil } _, err := pg.gdb.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA \"%s\" CASCADE;", pg.schema)) if err != nil { return ctx, fmt.Errorf("error dropping schema (%s): %w", pg.schema, err) } return ctx, err }) f.Teardown(func(ctx fctx.Context) (fctx.Context, error) { if integration.IsL1() { err := pg.epg.Stop() if err != nil { return ctx, err } } return ctx, nil }) } func (pg *Postgres) initDB(schema string) (*sql.DB, error) { // Use cloudsql to handle connection var edgeDB *cloudsql.EdgePostgres // Initialise the correct type of connection: CLoudSQL or local host, port pair, // determined by the presence of the ConnectionName configuration param switch { case pg.ConnectionName != "": edgeDB = cloudsql.GCPPostgresConnection(pg.ConnectionName) case pg.ConnectionName == "" && pg.Host != "": if pg.Port == 0 { return nil, fmt.Errorf("postgres-port is required") } edgeDB = cloudsql.PostgresConnection(pg.Host, fmt.Sprint(pg.Port)). Password(pg.Password) default: return nil, fmt.Errorf("postgres-connection-name or postgres-host must be provided") } // Open the connection to the configured DB edgeDB = edgeDB. DBName(pg.Database). Username(pg.User). MaxOpenConns(pg.MaxConns) if schema != "" { edgeDB = edgeDB.SearchPath(schema) } db, err := edgeDB.NewConnection() if err != nil { return nil, fmt.Errorf("error opening connection to the database: %w", err) } return db, nil } // Creates a new connection to the database without setting a schema func (pg *Postgres) initializeGlobalDB() (*sql.DB, error) { return pg.initDB("") } // Creates a new connection with the currently configured schema set to the search_path func (pg *Postgres) initializeDB() (*sql.DB, error) { return pg.initDB(pg.Schema()) } // BindFlags registers test flags for the framework extension. func (pg *Postgres) BindFlags(fs *flag.FlagSet) { fs.StringVar(&pg.ConnectionName, "postgres-connection-name", "", "CloudSQL connection name", ) fs.StringVar(&pg.Host, "postgres-host", "127.0.0.1", "the host to connect to", ) fs.UintVar(&pg.Port, "postgres-port", 5432, "port to connect to for L2 tests", ) fs.StringVar(&pg.User, "postgres-user", "postgres", "user to create or connect as", ) fs.StringVar(&pg.Password, "postgres-pass", "postgres", "password to set or connect with", ) fs.StringVar(&pg.Database, "postgres-database", "postgres", "name of the database to create or connect to", ) fs.IntVar(&pg.MaxConns, "postgres-max-conns", 10, "maximum amount of open client connections to allow", ) fs.IntVar(&pg.MaxIdleConns, "postgres-max-idle-conns", 10, "maximum amount of client connections allowed in the idle pool", ) fs.StringVar(&pg.k8sHost, "postgres-k8s-host", "", "Set this option when you need to use a different DNS address to connect to the DB from K8S containers running within the ktest cluster.", ) } func findOpenPort() (int, error) { addr, err := net.ResolveTCPAddr("tcp", "localhost:0") if err != nil { return 0, err } l, err := net.ListenTCP("tcp", addr) if err != nil { return 0, err } defer l.Close() return l.Addr().(*net.TCPAddr).Port, nil } func maybeSetEnv(key, bin, runfilePath string) error { if os.Getenv(key) != "" { return nil } if !bazel.IsBazelTest() { return fmt.Errorf(`failed to find integration test dependency %q. Either re-run this test using "bazel test" or set the %s environment variable`, bin, key) } p, err := runfiles.Rlocation(filepath.Join(os.Getenv(bazel.TestWorkspace), runfilePath)) if err != nil { return fmt.Errorf("failed to look up test dependency %q: %w. ensure that "+ "it is present in this test targets 'data' attribute", bin, err) } os.Setenv(key, p) return nil } // initialises embedded postgres func (pg *Postgres) newEmbeddedDB() error { err := maybeSetEnv(embeddedTxzVar, "postgres.txz", "/hack/tools/postgres.txz") if err != nil { return err } embeddedTxzFile := os.Getenv(embeddedTxzVar) tempDir, err := bazel.NewTestTmpDir("pgsql-*") if err != nil { return err } pgRuntimePath := path.Join(tempDir, "runtime") pgTempDir := tempDir err = compression.DecompressTarXz(embeddedTxzFile, tempDir) if err != nil { return err } if pg.Port == 0 { port, err := findOpenPort() if err != nil { return err } pg.Port = uint(port) /* #nosec G115 */ } cfg := pgsql.DefaultConfig() cfg = cfg.Port(uint32(pg.Port)) /* #nosec G115 */ cfg = cfg.Username(pg.User) cfg = cfg.Password(pg.Password) cfg = cfg.Version(pgsql.V14) cfg = cfg.RuntimePath(pgRuntimePath) cfg = cfg.BinariesPath(pgTempDir) cfg = cfg.Database(pg.Database) // cfg = cfg.StartParameters(map[string]string{"max_connections": "20", "shared_buffers": "40"}) epg := pgsql.NewDatabase(cfg) if err := epg.Start(); err != nil { return fmt.Errorf("failed to start database, err: %w", err) } pg.epg = epg return nil } // Applies edge-sql migrations and seeds data if provided. func seedExistingDB(db *sql.DB, seedData []plugin.Seed) error { driver, err := postgres.WithInstance(db, &postgres.Config{}) if err != nil { return err } var pluginConfig = &plugin.Config{ MigrationAction: "up", Ordered: true, TestMode: true, Data: seedData, } // TODO ncr-swt-retail/edge-roadmap#5668 . Would be useful to have a logger here incase of error with migration or seeding. logger := logging.New(logging.To(io.Discard)) return edgesql.SetupEdgeTables(pluginConfig, driver, logger, db) } // WithData is a f2.StepFn that can be used to add data to the initialised database. // Requires [ApplySeedModel] to be specified. // // The edge-infra SeededPostgres package includes a Seed variable which can be // used as an example to create Seed data or can be passed directly to this // function. func WithData(seedData []plugin.Seed, msgAndArgs ...interface{}) f2.StepFn { return func(ctx f2.Context, t *testing.T) f2.Context { pg := FromContextT(ctx, t) db := pg.DB() err := seedExistingDB(db, seedData) require.NoError(t, err, msgAndArgs...) return ctx } }