package stdlib_test

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"math"
	"os"
	"reflect"
	"regexp"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/stdlib"
	"github.com/jackc/pgx/v5/tracelog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// openDB opens a *sql.DB using the connection string in the PGX_TEST_DATABASE environment variable.
func openDB(t testing.TB) *sql.DB {
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	return stdlib.OpenDB(*config)
}

// closeDB closes db and fails the test if closing returns an error.
func closeDB(t testing.TB, db *sql.DB) {
	err := db.Close()
	require.NoError(t, err)
}

// skipCockroachDB skips the test with msg when the server reports a non-empty "crdb_version" parameter status.
func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
	conn, err := db.Conn(context.Background())
	require.NoError(t, err)
	defer conn.Close()

	err = conn.Raw(func(driverConn any) error {
		conn := driverConn.(*stdlib.Conn).Conn()
		if conn.PgConn().ParameterStatus("crdb_version") != "" {
			t.Skip(msg)
		}
		return nil
	})
	require.NoError(t, err)
}

// skipPostgreSQLVersionLessThan skips the test when the leading integer of the server's "server_version" parameter
// status is lower than minVersion. Servers whose version string does not start with digits are left alone.
func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) {
	conn, err := db.Conn(context.Background())
	require.NoError(t, err)
	defer conn.Close()

	err = conn.Raw(func(driverConn any) error {
		conn := driverConn.(*stdlib.Conn).Conn()
		serverVersionStr := conn.PgConn().ParameterStatus("server_version")
		serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
		// if not PostgreSQL do nothing
		if serverVersionStr == "" {
			return nil
		}

		serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
		if err != nil {
			return err
		}

		if serverVersion < minVersion {
			t.Skipf("Test requires PostgreSQL v%d+", minVersion)
		}
		return nil
	})
	require.NoError(t, err)
}

// testWithAllQueryExecModes runs f as a subtest once per pgx.QueryExecMode, each time against a freshly opened
// *sql.DB configured with that mode as its default. After f returns the DB is checked with ensureDBValid and closed.
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
	for _, mode := range []pgx.QueryExecMode{
		pgx.QueryExecModeCacheStatement,
		pgx.QueryExecModeCacheDescribe,
		pgx.QueryExecModeDescribeExec,
		pgx.QueryExecModeExec,
		pgx.QueryExecModeSimpleProtocol,
	} {
		t.Run(mode.String(),
			func(t *testing.T) {
				config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
				require.NoError(t, err)

				config.DefaultQueryExecMode = mode
				db := stdlib.OpenDB(*config)
				defer func() {
					err := db.Close()
					require.NoError(t, err)
				}()

				f(t, db)

				ensureDBValid(t, db)
			},
		)
	}
}

// Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
// cover broken connections.
func ensureDBValid(t testing.TB, db *sql.DB) {
	var sum, rowCount int32

	rows, err := db.Query("select generate_series(1,$1)", 10)
	require.NoError(t, err)
	defer rows.Close()

	for rows.Next() {
		var n int32
		// A failed scan would otherwise only surface as a confusing "Wrong values returned".
		require.NoError(t, rows.Scan(&n))
		sum += n
		rowCount++
	}
	require.NoError(t, rows.Err())

	if rowCount != 10 {
		t.Error("Select called onDataRow wrong number of times")
	}
	if sum != 55 {
		t.Error("Wrong values returned")
	}
}

// preparer is satisfied by anything that can prepare a statement, e.g. *sql.DB and *sql.Tx.
type preparer interface {
	Prepare(query string) (*sql.Stmt, error)
}

// prepareStmt prepares sql on p and fails the test on error.
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
	stmt, err := p.Prepare(sql)
	require.NoError(t, err)
	return stmt
}

// closeStmt closes stmt and fails the test on error.
func closeStmt(t *testing.T, stmt *sql.Stmt) {
	err := stmt.Close()
	require.NoError(t, err)
}

// TestSQLOpen checks that sql.Open succeeds for both registered driver names.
func TestSQLOpen(t *testing.T) {
	tests := []struct {
		driverName string
	}{
		{driverName: "pgx"},
		{driverName: "pgx/v5"},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.driverName, func(t *testing.T) {
			db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE"))
			require.NoError(t, err)
			closeDB(t, db)
		})
	}
}

// TestSQLOpenFromPool checks that a *sql.DB backed by a pgxpool.Pool is usable.
func TestSQLOpenFromPool(t *testing.T) {
	pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	t.Cleanup(pool.Close)

	db := stdlib.OpenDBFromPool(pool)
	// Deferred so the DB is closed (and the close error checked) even when ensureDBValid aborts the test. Deferred
	// calls run before t.Cleanup functions, so the DB is still closed before the pool.
	defer closeDB(t, db)

	ensureDBValid(t, db)
}

// TestNormalLifeCycle prepares a statement, reads every row of its result and closes everything in order.
func TestNormalLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	rowCount := int64(0)

	for rows.Next() {
		rowCount++

		var s string
		var n int64
		err := rows.Scan(&s, &n)
		require.NoError(t, err)

		if s != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, s)
		}
		if n != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, n)
		}
	}
	require.NoError(t, rows.Err())

	require.EqualValues(t, 10, rowCount)

	err = rows.Close()
	require.NoError(t, err)

	ensureDBValid(t, db)
}

// TestStmtExec executes prepared statements inside a transaction and checks the affected row count.
func TestStmtExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	tx, err := db.Begin()
	require.NoError(t, err)
	// The transaction is only a vehicle for the statements; make sure it is always finished.
	defer tx.Rollback()

	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
	_, err = createStmt.Exec()
	require.NoError(t, err)
	closeStmt(t, createStmt)

	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
	result, err := insertStmt.Exec("foo")
	require.NoError(t, err)

	n, err := result.RowsAffected()
	require.NoError(t, err)
	require.EqualValues(t, 1, n)
	closeStmt(t, insertStmt)

	ensureDBValid(t, db)
}

// TestQueryCloseRowsEarly closes a result set before reading it and then re-runs the same prepared statement.
func TestQueryCloseRowsEarly(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	// Close rows immediately without having read them
	err = rows.Close()
	require.NoError(t, err)

	// Run the query again to ensure the connection and statement are still ok
	rows, err = stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	rowCount := int64(0)

	for rows.Next() {
		rowCount++

		var s string
		var n int64
		err := rows.Scan(&s, &n)
		require.NoError(t, err)

		if s != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, s)
		}
		if n != rowCount {
			t.Errorf("Expected %d, received %d", rowCount, n)
		}
	}
	require.NoError(t, rows.Err())

	require.EqualValues(t, 10, rowCount)

	err = rows.Close()
	require.NoError(t, err)

	ensureDBValid(t, db)
}

// TestConnExec checks db.Exec and the reported affected row count in every query exec mode.
func TestConnExec(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create temporary table t(a varchar not null)")
		require.NoError(t, err)

		result, err := db.Exec("insert into t values('hey')")
		require.NoError(t, err)

		n, err := result.RowsAffected()
		require.NoError(t, err)
		require.EqualValues(t, 1, n)
	})
}

// TestConnQuery checks a parameterized db.Query in every query exec mode.
func TestConnQuery(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
		require.NoError(t, err)

		rowCount := int64(0)

		for rows.Next() {
			rowCount++

			var s string
			var n int64
			err := rows.Scan(&s, &n)
			require.NoError(t, err)

			if s != "foo" {
				t.Errorf(`Expected "foo", received "%v"`, s)
			}
			if n != rowCount {
				t.Errorf("Expected %d, received %d", rowCount, n)
			}
		}
		require.NoError(t, rows.Err())

		require.EqualValues(t, 10, rowCount)

		err = rows.Close()
		require.NoError(t, err)
	})
}

// TestConnConcurrency inserts, updates and then reads back rows from many goroutines sharing one *sql.DB.
func TestConnConcurrency(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)")
		require.NoError(t, err)

		defer func() {
			_, err := db.Exec("drop table t")
			require.NoError(t, err)
		}()

		var wg sync.WaitGroup

		concurrency := 50
		// Buffered to concurrency so every goroutine can report its single result without blocking.
		errChan := make(chan error, concurrency)

		for i := 1; i <= concurrency; i++ {
			wg.Add(1)

			go func(idx int) {
				defer wg.Done()

				ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
				defer cancel()

				str := strconv.Itoa(idx)
				duration := time.Duration(idx) * time.Second
				_, err := db.ExecContext(ctx, "insert into t values($1)", idx)
				if err != nil {
					errChan <- fmt.Errorf("insert failed: %d %w", idx, err)
					return
				}
				_, err = db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx)
				if err != nil {
					errChan <- fmt.Errorf("update 1 failed: %d %w", idx, err)
					return
				}
				_, err = db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx)
				if err != nil {
					errChan <- fmt.Errorf("update 2 failed: %d %w", idx, err)
					return
				}

				errChan <- nil
			}(i)
		}
		wg.Wait()
		for i := 1; i <= concurrency; i++ {
			err := <-errChan
			require.NoError(t, err)
		}

		for i := 1; i <= concurrency; i++ {
			wg.Add(1)

			go func(idx int) {
				defer wg.Done()

				ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
				defer cancel()

				var id int
				var str string
				var duration pgtype.Interval
				err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration)
				if err != nil {
					errChan <- fmt.Errorf("select failed: %d %w", idx, err)
					return
				}
				if id != idx {
					errChan <- fmt.Errorf("id mismatch: %d %d", idx, id)
					return
				}
				if str != strconv.Itoa(idx) {
					errChan <- fmt.Errorf("str mismatch: %d %s", idx, str)
					return
				}
				expectedDuration := pgtype.Interval{
					Microseconds: int64(idx) * time.Second.Microseconds(),
					Valid:        true,
				}
				if duration != expectedDuration {
					errChan <- fmt.Errorf("duration mismatch: %d %v", idx, duration)
					return
				}

				errChan <- nil
			}(i)
		}
		wg.Wait()
		for i := 1; i <= concurrency; i++ {
			err := <-errChan
			require.NoError(t, err)
		}
	})
}

// https://github.com/jackc/pgx/issues/781
func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		var s string
		var b bool

		rows, err := db.Query("select true, 'foo'")
		require.NoError(t, err)
		// Only the first row is read, so the result set must be closed explicitly to release its connection.
		defer rows.Close()

		require.True(t, rows.Next())
		require.NoError(t, rows.Scan(&b, &s))
		assert.Equal(t, true, b)
		assert.Equal(t, "foo", s)
	})
}

// TestConnQueryNull checks that a nil argument round-trips as SQL NULL.
func TestConnQueryNull(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		rows, err := db.Query("select $1::int", nil)
		require.NoError(t, err)

		rowCount := int64(0)

		for rows.Next() {
			rowCount++

			var n sql.NullInt64
			err := rows.Scan(&n)
			require.NoError(t, err)

			if n.Valid != false {
				t.Errorf("Expected n to be null, but it was %v", n)
			}
		}
		require.NoError(t, rows.Err())

		require.EqualValues(t, 1, rowCount)

		err = rows.Close()
		require.NoError(t, err)
	})
}

// TestConnQueryRowByteSlice checks scanning a bytea value into a []byte.
func TestConnQueryRowByteSlice(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		expected := []byte{222, 173, 190, 239}
		var actual []byte

		err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
		require.NoError(t, err)
		require.EqualValues(t, expected, actual)
	})
}

// TestConnQueryFailure checks that a malformed query surfaces a *pgconn.PgError.
func TestConnQueryFailure(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Query("select 'foo")
		require.Error(t, err)
		require.IsType(t, new(pgconn.PgError), err)
	})
}

// TestConnSimpleSlicePassThrough checks that a Go []string can be passed directly as a text[] argument.
func TestConnSimpleSlicePassThrough(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support cardinality function")

		var n int64
		err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
		require.NoError(t, err)
		assert.EqualValues(t, 3, n)
	})
}

// TestConnQueryScanGoArray checks scanning a bigint[] into a Go slice through pgtype.Map.SQLScanner.
func TestConnQueryScanGoArray(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		m := pgtype.NewMap()

		var a []int64
		err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
		require.NoError(t, err)
		assert.Equal(t, []int64{1, 2, 3}, a)
	})
}

// TestConnQueryScanArray checks scanning both a populated and a NULL bigint[] into pgtype.Array.
func TestConnQueryScanArray(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		m := pgtype.NewMap()

		var a pgtype.Array[int64]
		err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
		require.NoError(t, err)
		assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)

		err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
		require.NoError(t, err)
		assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
	})
}

// TestConnQueryScanRange checks scanning an int4range into pgtype.Range.
func TestConnQueryScanRange(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support int4range")

		m := pgtype.NewMap()

		var r pgtype.Range[pgtype.Int4]
		err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
		require.NoError(t, err)
		assert.Equal(
			t,
			pgtype.Range[pgtype.Int4]{
				Lower:     pgtype.Int4{Int32: 1, Valid: true},
				Upper:     pgtype.Int4{Int32: 5, Valid: true},
				LowerType: pgtype.Inclusive,
				UpperType: pgtype.Exclusive,
				Valid:     true,
			},
			r)
	})
}

// Test type that pgx would handle natively in binary, but since it is not a
// database/sql native type should be passed through as a string
func TestConnQueryRowPgxBinary(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		sql := "select $1::int4[]"
		expected := "{1,2,3}"
		var actual string

		err := db.QueryRow(sql, expected).Scan(&actual)
		require.NoError(t, err)
		require.EqualValues(t, expected, actual)
	})
}

// TestConnQueryRowUnknownType checks that a point value round-trips as a string.
func TestConnQueryRowUnknownType(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support point type")

		sql := "select $1::point"
		expected := "(1,2)"
		var actual string

		err := db.QueryRow(sql, expected).Scan(&actual)
		require.NoError(t, err)
		require.EqualValues(t, expected, actual)
	})
}

// TestConnQueryJSONIntoByteSlice checks scanning a json column into a []byte.
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec(`
		create temporary table docs(
			body json not null
		);

		insert into docs(body) values('{"foo": "bar"}');
`)
		require.NoError(t, err)

		sql := `select * from docs`
		expected := []byte(`{"foo": "bar"}`)
		var actual []byte

		err = db.QueryRow(sql).Scan(&actual)
		if err != nil {
			t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
		}

		if !bytes.Equal(actual, expected) {
			t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
		}

		_, err = db.Exec(`drop table docs`)
		require.NoError(t, err)
	})
}

// TestConnExecInsertByteSliceIntoJSON checks that a []byte argument can be inserted into a json column.
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
	// Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data
	// that needs to escape. No way to know whether the destination is really a text compatible or a bytea.

	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec(`
		create temporary table docs(
			body json not null
		);
`)
	require.NoError(t, err)

	expected := []byte(`{"foo": "bar"}`)

	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
	require.NoError(t, err)

	var actual []byte
	err = db.QueryRow(`select body from docs`).Scan(&actual)
	require.NoError(t, err)

	if !bytes.Equal(actual, expected) {
		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
	}

	_, err = db.Exec(`drop table docs`)
	require.NoError(t, err)
}

// TestTransactionLifeCycle checks that a rolled back insert is not visible and a committed one is.
func TestTransactionLifeCycle(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create temporary table t(a varchar not null)")
		require.NoError(t, err)

		tx, err := db.Begin()
		require.NoError(t, err)

		_, err = tx.Exec("insert into t values('hi')")
		require.NoError(t, err)

		err = tx.Rollback()
		require.NoError(t, err)

		var n int64
		err = db.QueryRow("select count(*) from t").Scan(&n)
		require.NoError(t, err)
		require.EqualValues(t, 0, n)

		tx, err = db.Begin()
		require.NoError(t, err)

		_, err = tx.Exec("insert into t values('hi')")
		require.NoError(t, err)

		err = tx.Commit()
		require.NoError(t, err)

		err = db.QueryRow("select count(*) from t").Scan(&n)
		require.NoError(t, err)
		require.EqualValues(t, 1, n)
	})
}

// TestConnBeginTxIsolation checks how each sql.IsolationLevel maps to the server's transaction_isolation setting
// and that unsupported levels are rejected by BeginTx.
func TestConnBeginTxIsolation(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server always uses serializable isolation level")

		var defaultIsoLevel string
		err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
		require.NoError(t, err)

		supportedTests := []struct {
			sqlIso sql.IsolationLevel
			pgIso  string
		}{
			{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
			{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
			{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
			{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
			{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
			{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
		}
		for i, tt := range supportedTests {
			// Wrapped in a func so the deferred Rollback fires at the end of each iteration.
			func() {
				tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
				if err != nil {
					t.Errorf("%d. BeginTx failed: %v", i, err)
					return
				}
				defer tx.Rollback()

				var pgIso string
				err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
				if err != nil {
					t.Errorf("%d. QueryRow failed: %v", i, err)
				}

				if pgIso != tt.pgIso {
					t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
				}
			}()
		}

		unsupportedTests := []struct {
			sqlIso sql.IsolationLevel
		}{
			{sqlIso: sql.LevelWriteCommitted},
			{sqlIso: sql.LevelLinearizable},
		}
		for i, tt := range unsupportedTests {
			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
			if err == nil {
				t.Errorf("%d. BeginTx should have failed", i)
				tx.Rollback()
			}
		}
	})
}

// TestConnBeginTxReadOnly checks that TxOptions.ReadOnly starts a read-only transaction.
func TestConnBeginTxReadOnly(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
		require.NoError(t, err)
		defer tx.Rollback()

		var pgReadOnly string
		err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
		if err != nil {
			t.Errorf("QueryRow failed: %v", err)
		}

		if pgReadOnly != "on" {
			t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
		}
	})
}

// TestBeginTxContextCancel checks that cancelling the context passed to BeginTx prevents the transaction's work
// from being committed: the table created inside it must not exist afterwards.
func TestBeginTxContextCancel(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("drop table if exists t")
		require.NoError(t, err)

		ctx, cancelFn := context.WithCancel(context.Background())

		tx, err := db.BeginTx(ctx, nil)
		require.NoError(t, err)

		_, err = tx.Exec("create table t(id serial)")
		require.NoError(t, err)

		cancelFn()

		err = tx.Commit()
		if err != context.Canceled && err != sql.ErrTxDone {
			t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
		}

		var n int
		err = db.QueryRow("select count(*) from t").Scan(&n)
		if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
			t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
		}
	})
}

// TestConnRaw checks that the underlying *pgx.Conn is reachable through sql.Conn.Raw.
func TestConnRaw(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		conn, err := db.Conn(context.Background())
		require.NoError(t, err)
		// Return the dedicated connection to the pool when the subtest finishes.
		defer conn.Close()

		var n int
		err = conn.Raw(func(driverConn any) error {
			conn := driverConn.(*stdlib.Conn).Conn()
			return conn.QueryRow(context.Background(), "select 42").Scan(&n)
		})
		require.NoError(t, err)
		assert.EqualValues(t, 42, n)
	})
}

// TestConnPingContextSuccess checks db.PingContext in every query exec mode.
func TestConnPingContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		err := db.PingContext(context.Background())
		require.NoError(t, err)
	})
}

// TestConnPrepareContextSuccess checks db.PrepareContext in every query exec mode.
func TestConnPrepareContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		stmt, err := db.PrepareContext(context.Background(), "select now()")
		require.NoError(t, err)
		err = stmt.Close()
		require.NoError(t, err)
	})
}

// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281
// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
func TestConnMultiplePrepareAndDeallocate(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support pg_prepared_statements")

		sql := "select 42"
		stmt1, err := db.PrepareContext(context.Background(), sql)
		require.NoError(t, err)
		stmt2, err := db.PrepareContext(context.Background(), sql)
		require.NoError(t, err)

		err = stmt1.Close()
		require.NoError(t, err)

		var preparedStmtCount int64
		err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
		require.NoError(t, err)
		require.EqualValues(t, 1, preparedStmtCount)

		err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate.
		require.NoError(t, err)

		err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
		require.NoError(t, err)
		require.EqualValues(t, 0, preparedStmtCount)
	})
}

// TestConnExecContextSuccess checks db.ExecContext in every query exec mode.
func TestConnExecContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
		require.NoError(t, err)
	})
}

// TestConnQueryContextSuccess checks db.QueryContext in every query exec mode.
func TestConnQueryContextSuccess(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
		require.NoError(t, err)

		for rows.Next() {
			var n int64
			err := rows.Scan(&n)
			require.NoError(t, err)
		}
		require.NoError(t, rows.Err())
	})
}

// TestRowsColumnTypeDatabaseTypeName checks the database type name reported for a bigint column.
func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		rows, err := db.Query("select 42::bigint")
		require.NoError(t, err)

		columnTypes, err := rows.ColumnTypes()
		require.NoError(t, err)
		require.Len(t, columnTypes, 1)

		if columnTypes[0].DatabaseTypeName() != "INT8" {
			t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
		}

		err = rows.Close()
		require.NoError(t, err)
	})
}

// TestStmtExecContextSuccess checks ExecContext on a prepared statement.
func TestStmtExecContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(id int primary key)")
	require.NoError(t, err)

	stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
	require.NoError(t, err)
	defer stmt.Close()

	_, err = stmt.ExecContext(context.Background(), 42)
	require.NoError(t, err)

	ensureDBValid(t, db)
}

// TestStmtExecContextCancel checks that a context timeout interrupts a slow prepared statement and that the DB
// remains usable afterwards.
func TestStmtExecContextCancel(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(id int primary key)")
	require.NoError(t, err)

	stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
	require.NoError(t, err)
	defer stmt.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	_, err = stmt.ExecContext(ctx, 42)
	if !pgconn.Timeout(err) {
		t.Errorf("expected timeout error, got %v", err)
	}

	ensureDBValid(t, db)
}

// TestStmtQueryContextSuccess checks QueryContext on a prepared statement.
func TestStmtQueryContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
	require.NoError(t, err)
	defer stmt.Close()

	rows, err := stmt.QueryContext(context.Background(), 5)
	require.NoError(t, err)

	for rows.Next() {
		var n int64
		if err := rows.Scan(&n); err != nil {
			t.Error(err)
		}
	}

	if rows.Err() != nil {
		t.Error(rows.Err())
	}

	ensureDBValid(t, db)
}

// TestRowsColumnTypes checks the name, database type name, length, decimal size and scan type that
// rows.ColumnTypes reports for several column types.
func TestRowsColumnTypes(t *testing.T) {
	// lengthInfo and decimalInfo mirror the return values of sql.ColumnType.Length and DecimalSize.
	type lengthInfo struct {
		Len int64
		OK  bool
	}
	type decimalInfo struct {
		Precision int64
		Scale     int64
		OK        bool
	}

	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		columnTypesTests := []struct {
			Name        string
			TypeName    string
			Length      lengthInfo
			DecimalSize decimalInfo
			ScanType    reflect.Type
		}{
			{
				Name:        "a",
				TypeName:    "INT8",
				Length:      lengthInfo{Len: 0, OK: false},
				DecimalSize: decimalInfo{Precision: 0, Scale: 0, OK: false},
				ScanType:    reflect.TypeOf(int64(0)),
			},
			{
				Name:        "bar",
				TypeName:    "TEXT",
				Length:      lengthInfo{Len: math.MaxInt64, OK: true},
				DecimalSize: decimalInfo{Precision: 0, Scale: 0, OK: false},
				ScanType:    reflect.TypeOf(""),
			},
			{
				Name:        "dec",
				TypeName:    "NUMERIC",
				Length:      lengthInfo{Len: 0, OK: false},
				DecimalSize: decimalInfo{Precision: 9, Scale: 2, OK: true},
				ScanType:    reflect.TypeOf(float64(0)),
			},
			{
				Name:        "d",
				TypeName:    "1266",
				Length:      lengthInfo{Len: 0, OK: false},
				DecimalSize: decimalInfo{Precision: 0, Scale: 0, OK: false},
				ScanType:    reflect.TypeOf(""),
			},
		}

		rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
		require.NoError(t, err)
		// The rows are never iterated, so they must be closed explicitly to release the connection.
		defer rows.Close()

		columns, err := rows.ColumnTypes()
		require.NoError(t, err)
		// require (not assert): indexing columns below would panic on a short result.
		require.Len(t, columns, len(columnTypesTests))

		for i, tt := range columnTypesTests {
			c := columns[i]
			if c.Name() != tt.Name {
				t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
			}
			if c.DatabaseTypeName() != tt.TypeName {
				t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
			}
			l, ok := c.Length()
			if l != tt.Length.Len {
				t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
			}
			if ok != tt.Length.OK {
				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
			}
			p, s, ok := c.DecimalSize()
			if p != tt.DecimalSize.Precision {
				t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
			}
			if s != tt.DecimalSize.Scale {
				t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
			}
			if ok != tt.DecimalSize.OK {
				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
			}
			if c.ScanType() != tt.ScanType {
				t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
			}
		}
	})
}

// TestQueryLifeCycle runs a query that returns rows followed by one that returns none, on the same DB.
func TestQueryLifeCycle(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
		require.NoError(t, err)

		rowCount := int64(0)

		for rows.Next() {
			rowCount++
			var (
				s string
				n int64
			)

			err := rows.Scan(&s, &n)
			require.NoError(t, err)

			if s != "foo" {
				t.Errorf(`Expected "foo", received "%v"`, s)
			}

			if n != rowCount {
				t.Errorf("Expected %d, received %d", rowCount, n)
			}
		}
		require.NoError(t, rows.Err())

		// Without this an empty result set would pass the per-row checks above unnoticed.
		require.EqualValues(t, 10, rowCount)

		err = rows.Close()
		require.NoError(t, err)

		rows, err = db.Query("select 1 where false")
		require.NoError(t, err)

		rowCount = int64(0)

		for rows.Next() {
			rowCount++
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 0, rowCount)

		err = rows.Close()
		require.NoError(t, err)
	})
}

// https://github.com/jackc/pgx/issues/409
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		var msg json.RawMessage

		err := db.QueryRow("select '{}'::json").Scan(&msg)
		require.NoError(t, err)
		require.EqualValues(t, []byte("{}"), []byte(msg))
	})
}

// testLog is a single entry recorded by testLogger.
type testLog struct {
	lvl  tracelog.LogLevel
	msg  string
	data map[string]any
}

// testLogger records every log call it receives. It does no locking of its own.
type testLogger struct {
	logs []testLog
}

// Log appends the entry to l.logs.
func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
	l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
}

// TestRegisterConnConfig checks that a registered ConnConfig (including its tracer) is used by sql.Open and that
// registration names are not reused after unregistering.
func TestRegisterConnConfig(t *testing.T) {
	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	logger := &testLogger{}
	connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}

	// Issue 947: Register and unregister a ConnConfig and ensure that the
	// returned connection string is not reused.
	connStr := stdlib.RegisterConnConfig(connConfig)
	require.Equal(t, "registeredConnConfig0", connStr)
	stdlib.UnregisterConnConfig(connStr)

	connStr = stdlib.RegisterConnConfig(connConfig)
	defer stdlib.UnregisterConnConfig(connStr)
	require.Equal(t, "registeredConnConfig1", connStr)

	db, err := sql.Open("pgx", connStr)
	require.NoError(t, err)
	defer closeDB(t, db)

	var n int64
	err = db.QueryRow("select 1").Scan(&n)
	require.NoError(t, err)

	// Guard the index below: if the tracer was not wired up there are no entries at all.
	require.NotEmpty(t, logger.logs)
	l := logger.logs[len(logger.logs)-1]
	assert.Equal(t, "Query", l.msg)
	assert.Equal(t, "select 1", l.data["sql"])
}

// https://github.com/jackc/pgx/issues/958
func TestConnQueryRowConstraintErrors(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipPostgreSQLVersionLessThan(t, db, 11)
		skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")

		_, err := db.Exec(`create temporary table defer_test (
			id text primary key,
			n int not null, unique (n),
			unique (n) deferrable initially deferred )`)
		require.NoError(t, err)

		_, err = db.Exec(`drop function if exists test_trigger cascade`)
		require.NoError(t, err)

		_, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$
		begin
		if new.n = 4 then
			raise exception 'n cant be 4!';
		end if;
		return new;
	end$$`)
		require.NoError(t, err)

		_, err = db.Exec(`create constraint trigger test
			after insert or update on defer_test
			deferrable initially deferred
			for each row
			execute function test_trigger()`)
		require.NoError(t, err)

		_, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
		require.NoError(t, err)

		var id string
		err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
		assert.Error(t, err)
	})
}

// TestOptionBeforeAfterConnect checks that the BeforeConnect and AfterConnect hooks run once per new connection
// and that BeforeConnect receives a distinct config object each time.
func TestOptionBeforeAfterConnect(t *testing.T) {
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var beforeConnConfigs []*pgx.ConnConfig
	var afterConns []*pgx.Conn
	db := stdlib.OpenDB(*config,
		stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
			beforeConnConfigs = append(beforeConnConfigs, connConfig)
			return nil
		}),
		stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
			afterConns = append(afterConns, conn)
			return nil
		}))
	defer closeDB(t, db)

	// Force it to close and reopen a new connection after each query
	db.SetMaxIdleConns(0)

	_, err = db.Exec("select 1")
	require.NoError(t, err)

	_, err = db.Exec("select 1")
	require.NoError(t, err)

	require.Len(t, beforeConnConfigs, 2)
	require.Len(t, afterConns, 2)

	// Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we want to ensure they
	// are different objects, so can't use require.NotEqual
	require.False(t, config == beforeConnConfigs[0])
	require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
}

// TestRandomizeHostOrderFunc checks that RandomizeHostOrderFunc keeps every host in the config and eventually
// makes each of them the primary host.
func TestRandomizeHostOrderFunc(t *testing.T) {
	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
	require.NoError(t, err)

	// Test that at some point we connect to all 3 hosts
	hostsNotSeenYet := map[string]struct{}{
		"host1": {},
		"host2": {},
		"host3": {},
	}

	// If we don't succeed within this many iterations, something is certainly wrong
	for i := 0; i < 100000; i++ {
		connCopy := *config
		stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)

		delete(hostsNotSeenYet, connCopy.Host)
		if len(hostsNotSeenYet) == 0 {
			return
		}

		// Every host must still be present, either as the primary or as one of the fallbacks.
	hostCheckLoop:
		for _, h := range []string{"host1", "host2", "host3"} {
			if connCopy.Host == h {
				continue
			}
			for _, f := range connCopy.Fallbacks {
				if f.Host == h {
					continue hostCheckLoop
				}
			}
			require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
		}
	}

	require.Fail(t, "did not get all hosts as primaries after many randomizations")
}

// TestResetSessionHookCalled checks that the ResetSession hook has been invoked after two consecutive pings.
func TestResetSessionHookCalled(t *testing.T) {
	var mockCalled bool

	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
		mockCalled = true

		return nil
	}))

	defer closeDB(t, db)

	err = db.Ping()
	require.NoError(t, err)

	err = db.Ping()
	require.NoError(t, err)

	require.True(t, mockCalled)
}

// TestCheckIdleConn terminates the backends of idle pooled connections from a second DB handle and checks that
// the pool recovers by establishing a new connection.
func TestCheckIdleConn(t *testing.T) {
	controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeDB(t, controllerConn)

	skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")

	db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeDB(t, db)

	var conns []*sql.Conn
	for i := 0; i < 3; i++ {
		c, err := db.Conn(context.Background())
		require.NoError(t, err)
		conns = append(conns, c)
	}

	require.EqualValues(t, 3, db.Stats().OpenConnections)

	var pids []uint32
	for _, c := range conns {
		err := c.Raw(func(driverConn any) error {
			pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID())
			return nil
		})
		require.NoError(t, err)
		err = c.Close()
		require.NoError(t, err)
	}

	// The database/sql connection pool seems to automatically close idle connections to only keep 2 alive.
	// require.EqualValues(t, 3, db.Stats().OpenConnections)

	_, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids)
	require.NoError(t, err)

	// All conns are dead they don't know it and neither does the pool. But because of database/sql automatically closing
	// idle connections we can't be sure how many we should have, so the exact count is deliberately not asserted.
	// require.EqualValues(t, 3, db.Stats().OpenConnections)

	// Wait long enough so the pool will realize it needs to check the connections.
	time.Sleep(time.Second)

	// Pool should try all existing connections and find them dead, then create a new connection which should successfully ping.
	err = db.PingContext(context.Background())
	require.NoError(t, err)

	// The original 3 conns should have been terminated and a new conn established for the ping.
	require.EqualValues(t, 1, db.Stats().OpenConnections)
	c, err := db.Conn(context.Background())
	require.NoError(t, err)

	var cPID uint32
	err = c.Raw(func(driverConn any) error {
		cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
		return nil
	})
	require.NoError(t, err)
	err = c.Close()
	require.NoError(t, err)

	require.NotContains(t, pids, cPID)
}