1 package postgres
2
3 import (
4 "database/sql"
5 "flag"
6 "fmt"
7 "io"
8 "net"
9 "os"
10 "path"
11 "path/filepath"
12 "strings"
13 "testing"
14
15 "github.com/bazelbuild/rules_go/go/runfiles"
16 pgsql "github.com/fergusstrange/embedded-postgres"
17 "github.com/golang-migrate/migrate/v4/database/postgres"
18 _ "github.com/jackc/pgx/v4/stdlib"
19 "github.com/stretchr/testify/require"
20
21 edgesql "edge-infra.dev/pkg/edge/api/sql"
22 "edge-infra.dev/pkg/edge/api/sql/plugin"
23 "edge-infra.dev/pkg/lib/build/bazel"
24 "edge-infra.dev/pkg/lib/compression"
25 "edge-infra.dev/pkg/lib/gcp/cloudsql"
26 "edge-infra.dev/pkg/lib/logging"
27 "edge-infra.dev/test/f2"
28 "edge-infra.dev/test/f2/fctx"
29 "edge-infra.dev/test/f2/integration"
30 )
31
// postgresName is used as the user, password and database name for the
// embedded Postgres instance started for L1 runs.
const postgresName = "postgres"
35
// maxSchemaLength is the maximum length, in bytes, of a per-test schema name.
// It matches PostgreSQL's default identifier limit (NAMEDATALEN-1); longer
// names are truncated before the schema is created.
const maxSchemaLength = 63
40
41
42 type Postgres struct {
43 ConnectionName string
44 Host string
45 Port uint
46 User string
47 Password string
48 Database string
49 MaxConns int
50 MaxIdleConns int
51
52 k8sHost string
53
54 epg *pgsql.EmbeddedPostgres
55 options *options
56
57 dsn string
58 schema string
59
60
61 gdb *sql.DB
62
63 db *sql.DB
64 }
65
// embeddedTxzVar names the environment variable holding the path to the
// postgres.txz archive that contains the embedded Postgres binaries.
const embeddedTxzVar = "TEST_ASSET_EMBEDDED_POSTGRES_TXZ"
69
70
71
72
73
74
75 func New(opts ...Option) *Postgres {
76 o := makeOptions(opts...)
77 return &Postgres{options: o}
78 }
79
80
81
82
83
84
85 func (pg *Postgres) DSN() string {
86
87
88
89 return pg.dsn
90 }
91
92
93 func (pg *Postgres) globalDSN() string {
94 return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", pg.Host, pg.Port, pg.User, pg.Password, pg.Database)
95 }
96
97
98
99
100
101 func (pg *Postgres) DB() *sql.DB {
102 return pg.db
103 }
104
105
106 func (pg *Postgres) Schema() string {
107 return pg.schema
108 }
109
110
111
112
113 func (pg *Postgres) K8SHost() string {
114 if pg.k8sHost != "" {
115 return pg.k8sHost
116 }
117 return pg.Host
118 }
119
120
121
122 func FromContext(ctx fctx.Context) (*Postgres, error) {
123 v := fctx.ValueFrom[Postgres](ctx)
124 if v == nil {
125 return nil, fmt.Errorf("%w: warehouse.Postgres extension", fctx.ErrNotFound)
126 }
127 return v, nil
128 }
129
130
131
132 func FromContextT(ctx fctx.Context, t *testing.T) *Postgres {
133 return fctx.ValueFromT[Postgres](ctx, t)
134 }
135
136
137 func (pg *Postgres) IntoContext(ctx fctx.Context) fctx.Context {
138 return fctx.ValueInto(ctx, pg)
139 }
140
141 func (pg *Postgres) RegisterFns(f f2.Framework) {
142 if integration.IsL1() {
143 f.Setup(func(ctx f2.Context) (f2.Context, error) {
144
145
146 pg.ConnectionName = ""
147 pg.Host = "127.0.0.1"
148 pg.Port = 0
149 pg.User = postgresName
150 pg.Password = postgresName
151 pg.Database = postgresName
152
153 return ctx, pg.newEmbeddedDB()
154 })
155 }
156
157 f.Setup(func(ctx f2.Context) (f2.Context, error) {
158
159
160
161 db, err := pg.initializeGlobalDB()
162 if err != nil {
163 return ctx, fmt.Errorf("opening database connection: %w", err)
164 }
165 pg.gdb = db
166
167 dsn := pg.globalDSN()
168 pg.dsn = dsn
169
170 return ctx, nil
171 })
172
173
174 f.BeforeEachTest(func(ctx f2.Context, t *testing.T) (f2.Context, error) {
175
176 if !pg.options.skipSchemaIsolation {
177 name := strings.ToLower(t.Name())
178 schemaName := name + "_" + ctx.RunID
179
180
181 if len(schemaName) > maxSchemaLength {
182 t.Log("proposed schema name was too long", schemaName)
183 schemaName = name[:len(name)-(len(schemaName)-maxSchemaLength)] + "_" + ctx.RunID
184 }
185
186
187 _, err := pg.gdb.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA \"%s\";", schemaName))
188 if err != nil {
189 return ctx, fmt.Errorf("error creating test schema: %w", err)
190 }
191
192 pg.dsn = fmt.Sprintf("%s search_path=%s", pg.globalDSN(), schemaName)
193 pg.schema = schemaName
194
195 db, err := pg.initializeDB()
196 if err != nil {
197 return ctx, fmt.Errorf("error initialising db: %w", err)
198 }
199
200 pg.db = db
201 }
202 return ctx, nil
203 })
204
205
206 f.BeforeEachTest(func(ctx f2.Context, _ *testing.T) (f2.Context, error) {
207 if pg.options.applySeedModel {
208 db := pg.DB()
209
210 err := seedExistingDB(db, []plugin.Seed{})
211 if err != nil {
212 return ctx, err
213 }
214 }
215 return ctx, nil
216 })
217
218 f.AfterEachTest(func(ctx f2.Context, _ *testing.T) (f2.Context, error) {
219 if pg.options.skipSchemaIsolation {
220 return ctx, nil
221 }
222
223 _, err := pg.gdb.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA \"%s\" CASCADE;", pg.schema))
224 if err != nil {
225 return ctx, fmt.Errorf("error dropping schema (%s): %w", pg.schema, err)
226 }
227 return ctx, err
228 })
229
230 f.Teardown(func(ctx fctx.Context) (fctx.Context, error) {
231 if integration.IsL1() {
232 err := pg.epg.Stop()
233 if err != nil {
234 return ctx, err
235 }
236 }
237 return ctx, nil
238 })
239 }
240
241 func (pg *Postgres) initDB(schema string) (*sql.DB, error) {
242
243 var edgeDB *cloudsql.EdgePostgres
244
245
246
247 switch {
248 case pg.ConnectionName != "":
249 edgeDB = cloudsql.GCPPostgresConnection(pg.ConnectionName)
250 case pg.ConnectionName == "" && pg.Host != "":
251 if pg.Port == 0 {
252 return nil, fmt.Errorf("postgres-port is required")
253 }
254 edgeDB = cloudsql.PostgresConnection(pg.Host, fmt.Sprint(pg.Port)).
255 Password(pg.Password)
256 default:
257 return nil, fmt.Errorf("postgres-connection-name or postgres-host must be provided")
258 }
259
260
261 edgeDB = edgeDB.
262 DBName(pg.Database).
263 Username(pg.User).
264 MaxOpenConns(pg.MaxConns)
265
266 if schema != "" {
267 edgeDB = edgeDB.SearchPath(schema)
268 }
269
270 db, err := edgeDB.NewConnection()
271 if err != nil {
272 return nil, fmt.Errorf("error opening connection to the database: %w", err)
273 }
274 return db, nil
275 }
276
277
278 func (pg *Postgres) initializeGlobalDB() (*sql.DB, error) {
279 return pg.initDB("")
280 }
281
282
283 func (pg *Postgres) initializeDB() (*sql.DB, error) {
284 return pg.initDB(pg.Schema())
285 }
286
287
288 func (pg *Postgres) BindFlags(fs *flag.FlagSet) {
289 fs.StringVar(&pg.ConnectionName,
290 "postgres-connection-name",
291 "",
292 "CloudSQL connection name",
293 )
294 fs.StringVar(&pg.Host,
295 "postgres-host",
296 "127.0.0.1",
297 "the host to connect to",
298 )
299 fs.UintVar(&pg.Port,
300 "postgres-port",
301 5432,
302 "port to connect to for L2 tests",
303 )
304 fs.StringVar(&pg.User,
305 "postgres-user",
306 "postgres",
307 "user to create or connect as",
308 )
309 fs.StringVar(&pg.Password,
310 "postgres-pass",
311 "postgres",
312 "password to set or connect with",
313 )
314 fs.StringVar(&pg.Database,
315 "postgres-database",
316 "postgres",
317 "name of the database to create or connect to",
318 )
319 fs.IntVar(&pg.MaxConns,
320 "postgres-max-conns",
321 10,
322 "maximum amount of open client connections to allow",
323 )
324 fs.IntVar(&pg.MaxIdleConns,
325 "postgres-max-idle-conns",
326 10,
327 "maximum amount of client connections allowed in the idle pool",
328 )
329 fs.StringVar(&pg.k8sHost,
330 "postgres-k8s-host",
331 "",
332 "Set this option when you need to use a different DNS address to connect to the DB from K8S containers running within the ktest cluster.",
333 )
334 }
335
// findOpenPort asks the OS for a free TCP port on localhost by briefly
// listening on port 0, then releases it and returns the chosen port number.
// The port is only free at the moment of the check; another process could
// claim it before the caller binds it.
func findOpenPort() (int, error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}

	listener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return 0, err
	}
	defer listener.Close()

	bound := listener.Addr().(*net.TCPAddr)
	return bound.Port, nil
}
348
349 func maybeSetEnv(key, bin, runfilePath string) error {
350 if os.Getenv(key) != "" {
351 return nil
352 }
353 if !bazel.IsBazelTest() {
354 return fmt.Errorf(`failed to find integration test dependency %q.
355 Either re-run this test using "bazel test" or set the %s environment variable`, bin, key)
356 }
357 p, err := runfiles.Rlocation(filepath.Join(os.Getenv(bazel.TestWorkspace), runfilePath))
358 if err != nil {
359 return fmt.Errorf("failed to look up test dependency %q: %w. ensure that "+
360 "it is present in this test targets 'data' attribute", bin, err)
361 }
362 os.Setenv(key, p)
363 return nil
364 }
365
366
367 func (pg *Postgres) newEmbeddedDB() error {
368 err := maybeSetEnv(embeddedTxzVar, "postgres.txz", "/hack/tools/postgres.txz")
369 if err != nil {
370 return err
371 }
372
373 embeddedTxzFile := os.Getenv(embeddedTxzVar)
374
375 tempDir, err := bazel.NewTestTmpDir("pgsql-*")
376 if err != nil {
377 return err
378 }
379 pgRuntimePath := path.Join(tempDir, "runtime")
380 pgTempDir := tempDir
381
382 err = compression.DecompressTarXz(embeddedTxzFile, tempDir)
383 if err != nil {
384 return err
385 }
386
387 if pg.Port == 0 {
388 port, err := findOpenPort()
389 if err != nil {
390 return err
391 }
392 pg.Port = uint(port)
393 }
394
395 cfg := pgsql.DefaultConfig()
396 cfg = cfg.Port(uint32(pg.Port))
397 cfg = cfg.Username(pg.User)
398 cfg = cfg.Password(pg.Password)
399 cfg = cfg.Version(pgsql.V14)
400 cfg = cfg.RuntimePath(pgRuntimePath)
401 cfg = cfg.BinariesPath(pgTempDir)
402 cfg = cfg.Database(pg.Database)
403
404
405 epg := pgsql.NewDatabase(cfg)
406 if err := epg.Start(); err != nil {
407 return fmt.Errorf("failed to start database, err: %w", err)
408 }
409 pg.epg = epg
410 return nil
411 }
412
413
414 func seedExistingDB(db *sql.DB, seedData []plugin.Seed) error {
415 driver, err := postgres.WithInstance(db, &postgres.Config{})
416 if err != nil {
417 return err
418 }
419
420 var pluginConfig = &plugin.Config{
421 MigrationAction: "up",
422 Ordered: true,
423 TestMode: true,
424 Data: seedData,
425 }
426
427 logger := logging.New(logging.To(io.Discard))
428 return edgesql.SetupEdgeTables(pluginConfig, driver, logger, db)
429 }
430
431
432
433
434
435
436
437 func WithData(seedData []plugin.Seed, msgAndArgs ...interface{}) f2.StepFn {
438 return func(ctx f2.Context, t *testing.T) f2.Context {
439 pg := FromContextT(ctx, t)
440 db := pg.DB()
441 err := seedExistingDB(db, seedData)
442 require.NoError(t, err, msgAndArgs...)
443 return ctx
444 }
445 }
446
View as plain text