1 package stdlib_test
2
3 import (
4 "bytes"
5 "context"
6 "database/sql"
7 "encoding/json"
8 "fmt"
9 "math"
10 "os"
11 "reflect"
12 "regexp"
13 "strconv"
14 "sync"
15 "testing"
16 "time"
17
18 "github.com/jackc/pgx/v5"
19 "github.com/jackc/pgx/v5/pgconn"
20 "github.com/jackc/pgx/v5/pgtype"
21 "github.com/jackc/pgx/v5/pgxpool"
22 "github.com/jackc/pgx/v5/stdlib"
23 "github.com/jackc/pgx/v5/tracelog"
24 "github.com/stretchr/testify/assert"
25 "github.com/stretchr/testify/require"
26 )
27
28 func openDB(t testing.TB) *sql.DB {
29 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
30 require.NoError(t, err)
31 return stdlib.OpenDB(*config)
32 }
33
34 func closeDB(t testing.TB, db *sql.DB) {
35 err := db.Close()
36 require.NoError(t, err)
37 }
38
39 func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
40 conn, err := db.Conn(context.Background())
41 require.NoError(t, err)
42 defer conn.Close()
43
44 err = conn.Raw(func(driverConn any) error {
45 conn := driverConn.(*stdlib.Conn).Conn()
46 if conn.PgConn().ParameterStatus("crdb_version") != "" {
47 t.Skip(msg)
48 }
49 return nil
50 })
51 require.NoError(t, err)
52 }
53
54 func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) {
55 conn, err := db.Conn(context.Background())
56 require.NoError(t, err)
57 defer conn.Close()
58
59 err = conn.Raw(func(driverConn any) error {
60 conn := driverConn.(*stdlib.Conn).Conn()
61 serverVersionStr := conn.PgConn().ParameterStatus("server_version")
62 serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
63
64 if serverVersionStr == "" {
65 return nil
66 }
67
68 serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
69 if err != nil {
70 return err
71 }
72
73 if serverVersion < minVersion {
74 t.Skipf("Test requires PostgreSQL v%d+", minVersion)
75 }
76
77 return nil
78 })
79 require.NoError(t, err)
80 }
81
82 func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
83 for _, mode := range []pgx.QueryExecMode{
84 pgx.QueryExecModeCacheStatement,
85 pgx.QueryExecModeCacheDescribe,
86 pgx.QueryExecModeDescribeExec,
87 pgx.QueryExecModeExec,
88 pgx.QueryExecModeSimpleProtocol,
89 } {
90 t.Run(mode.String(),
91 func(t *testing.T) {
92 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
93 require.NoError(t, err)
94
95 config.DefaultQueryExecMode = mode
96 db := stdlib.OpenDB(*config)
97 defer func() {
98 err := db.Close()
99 require.NoError(t, err)
100 }()
101
102 f(t, db)
103
104 ensureDBValid(t, db)
105 },
106 )
107 }
108 }
109
110
111
112 func ensureDBValid(t testing.TB, db *sql.DB) {
113 var sum, rowCount int32
114
115 rows, err := db.Query("select generate_series(1,$1)", 10)
116 require.NoError(t, err)
117 defer rows.Close()
118
119 for rows.Next() {
120 var n int32
121 rows.Scan(&n)
122 sum += n
123 rowCount++
124 }
125
126 require.NoError(t, rows.Err())
127
128 if rowCount != 10 {
129 t.Error("Select called onDataRow wrong number of times")
130 }
131 if sum != 55 {
132 t.Error("Wrong values returned")
133 }
134 }
135
// preparer is the single-method interface accepted by prepareStmt. In this
// file it is satisfied by the *sql.DB and *sql.Tx values passed to that helper.
type preparer interface {
	// Prepare creates a prepared statement for query.
	Prepare(query string) (*sql.Stmt, error)
}
139
140 func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
141 stmt, err := p.Prepare(sql)
142 require.NoError(t, err)
143 return stmt
144 }
145
146 func closeStmt(t *testing.T, stmt *sql.Stmt) {
147 err := stmt.Close()
148 require.NoError(t, err)
149 }
150
151 func TestSQLOpen(t *testing.T) {
152 tests := []struct {
153 driverName string
154 }{
155 {driverName: "pgx"},
156 {driverName: "pgx/v5"},
157 }
158
159 for _, tt := range tests {
160 tt := tt
161
162 t.Run(tt.driverName, func(t *testing.T) {
163 db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE"))
164 require.NoError(t, err)
165 closeDB(t, db)
166 })
167 }
168 }
169
170 func TestSQLOpenFromPool(t *testing.T) {
171 pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
172 require.NoError(t, err)
173 t.Cleanup(pool.Close)
174
175 db := stdlib.OpenDBFromPool(pool)
176 ensureDBValid(t, db)
177
178 db.Close()
179 }
180
181 func TestNormalLifeCycle(t *testing.T) {
182 db := openDB(t)
183 defer closeDB(t, db)
184
185 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
186
187 stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
188 defer closeStmt(t, stmt)
189
190 rows, err := stmt.Query(int32(1), int32(10))
191 require.NoError(t, err)
192
193 rowCount := int64(0)
194
195 for rows.Next() {
196 rowCount++
197
198 var s string
199 var n int64
200 err := rows.Scan(&s, &n)
201 require.NoError(t, err)
202
203 if s != "foo" {
204 t.Errorf(`Expected "foo", received "%v"`, s)
205 }
206 if n != rowCount {
207 t.Errorf("Expected %d, received %d", rowCount, n)
208 }
209 }
210 require.NoError(t, rows.Err())
211
212 require.EqualValues(t, 10, rowCount)
213
214 err = rows.Close()
215 require.NoError(t, err)
216
217 ensureDBValid(t, db)
218 }
219
220 func TestStmtExec(t *testing.T) {
221 db := openDB(t)
222 defer closeDB(t, db)
223
224 tx, err := db.Begin()
225 require.NoError(t, err)
226
227 createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
228 _, err = createStmt.Exec()
229 require.NoError(t, err)
230 closeStmt(t, createStmt)
231
232 insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
233 result, err := insertStmt.Exec("foo")
234 require.NoError(t, err)
235
236 n, err := result.RowsAffected()
237 require.NoError(t, err)
238 require.EqualValues(t, 1, n)
239 closeStmt(t, insertStmt)
240
241 ensureDBValid(t, db)
242 }
243
244 func TestQueryCloseRowsEarly(t *testing.T) {
245 db := openDB(t)
246 defer closeDB(t, db)
247
248 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
249
250 stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
251 defer closeStmt(t, stmt)
252
253 rows, err := stmt.Query(int32(1), int32(10))
254 require.NoError(t, err)
255
256
257 err = rows.Close()
258 require.NoError(t, err)
259
260
261 rows, err = stmt.Query(int32(1), int32(10))
262 require.NoError(t, err)
263
264 rowCount := int64(0)
265
266 for rows.Next() {
267 rowCount++
268
269 var s string
270 var n int64
271 err := rows.Scan(&s, &n)
272 require.NoError(t, err)
273 if s != "foo" {
274 t.Errorf(`Expected "foo", received "%v"`, s)
275 }
276 if n != rowCount {
277 t.Errorf("Expected %d, received %d", rowCount, n)
278 }
279 }
280 require.NoError(t, rows.Err())
281 require.EqualValues(t, 10, rowCount)
282
283 err = rows.Close()
284 require.NoError(t, err)
285
286 ensureDBValid(t, db)
287 }
288
289 func TestConnExec(t *testing.T) {
290 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
291 _, err := db.Exec("create temporary table t(a varchar not null)")
292 require.NoError(t, err)
293
294 result, err := db.Exec("insert into t values('hey')")
295 require.NoError(t, err)
296
297 n, err := result.RowsAffected()
298 require.NoError(t, err)
299 require.EqualValues(t, 1, n)
300 })
301 }
302
303 func TestConnQuery(t *testing.T) {
304 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
305 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
306
307 rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
308 require.NoError(t, err)
309
310 rowCount := int64(0)
311
312 for rows.Next() {
313 rowCount++
314
315 var s string
316 var n int64
317 err := rows.Scan(&s, &n)
318 require.NoError(t, err)
319 if s != "foo" {
320 t.Errorf(`Expected "foo", received "%v"`, s)
321 }
322 if n != rowCount {
323 t.Errorf("Expected %d, received %d", rowCount, n)
324 }
325 }
326 require.NoError(t, rows.Err())
327 require.EqualValues(t, 10, rowCount)
328
329 err = rows.Close()
330 require.NoError(t, err)
331 })
332 }
333
334 func TestConnConcurrency(t *testing.T) {
335 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
336 _, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)")
337 require.NoError(t, err)
338
339 defer func() {
340 _, err := db.Exec("drop table t")
341 require.NoError(t, err)
342 }()
343
344 var wg sync.WaitGroup
345
346 concurrency := 50
347 errChan := make(chan error, concurrency)
348
349 for i := 1; i <= concurrency; i++ {
350 wg.Add(1)
351
352 go func(idx int) {
353 defer wg.Done()
354
355 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
356 defer cancel()
357
358 str := strconv.Itoa(idx)
359 duration := time.Duration(idx) * time.Second
360 _, err := db.ExecContext(ctx, "insert into t values($1)", idx)
361 if err != nil {
362 errChan <- fmt.Errorf("insert failed: %d %w", idx, err)
363 return
364 }
365 _, err = db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx)
366 if err != nil {
367 errChan <- fmt.Errorf("update 1 failed: %d %w", idx, err)
368 return
369 }
370 _, err = db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx)
371 if err != nil {
372 errChan <- fmt.Errorf("update 2 failed: %d %w", idx, err)
373 return
374 }
375
376 errChan <- nil
377 }(i)
378 }
379 wg.Wait()
380 for i := 1; i <= concurrency; i++ {
381 err := <-errChan
382 require.NoError(t, err)
383 }
384
385 for i := 1; i <= concurrency; i++ {
386 wg.Add(1)
387
388 go func(idx int) {
389 defer wg.Done()
390
391 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
392 defer cancel()
393
394 var id int
395 var str string
396 var duration pgtype.Interval
397 err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration)
398 if err != nil {
399 errChan <- fmt.Errorf("select failed: %d %w", idx, err)
400 return
401 }
402 if id != idx {
403 errChan <- fmt.Errorf("id mismatch: %d %d", idx, id)
404 return
405 }
406 if str != strconv.Itoa(idx) {
407 errChan <- fmt.Errorf("str mismatch: %d %s", idx, str)
408 return
409 }
410 expectedDuration := pgtype.Interval{
411 Microseconds: int64(idx) * time.Second.Microseconds(),
412 Valid: true,
413 }
414 if duration != expectedDuration {
415 errChan <- fmt.Errorf("duration mismatch: %d %v", idx, duration)
416 return
417 }
418
419 errChan <- nil
420 }(i)
421 }
422 wg.Wait()
423 for i := 1; i <= concurrency; i++ {
424 err := <-errChan
425 require.NoError(t, err)
426 }
427 })
428 }
429
430
431 func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
432 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
433 var s string
434 var b bool
435
436 rows, err := db.Query("select true, 'foo'")
437 require.NoError(t, err)
438
439 require.True(t, rows.Next())
440 require.NoError(t, rows.Scan(&b, &s))
441 assert.Equal(t, true, b)
442 assert.Equal(t, "foo", s)
443 })
444 }
445
446 func TestConnQueryNull(t *testing.T) {
447 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
448 rows, err := db.Query("select $1::int", nil)
449 require.NoError(t, err)
450
451 rowCount := int64(0)
452
453 for rows.Next() {
454 rowCount++
455
456 var n sql.NullInt64
457 err := rows.Scan(&n)
458 require.NoError(t, err)
459 if n.Valid != false {
460 t.Errorf("Expected n to be null, but it was %v", n)
461 }
462 }
463 require.NoError(t, rows.Err())
464 require.EqualValues(t, 1, rowCount)
465
466 err = rows.Close()
467 require.NoError(t, err)
468 })
469 }
470
471 func TestConnQueryRowByteSlice(t *testing.T) {
472 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
473 expected := []byte{222, 173, 190, 239}
474 var actual []byte
475
476 err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
477 require.NoError(t, err)
478 require.EqualValues(t, expected, actual)
479 })
480 }
481
482 func TestConnQueryFailure(t *testing.T) {
483 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
484 _, err := db.Query("select 'foo")
485 require.Error(t, err)
486 require.IsType(t, new(pgconn.PgError), err)
487 })
488 }
489
490 func TestConnSimpleSlicePassThrough(t *testing.T) {
491 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
492 skipCockroachDB(t, db, "Server does not support cardinality function")
493
494 var n int64
495 err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
496 require.NoError(t, err)
497 assert.EqualValues(t, 3, n)
498 })
499 }
500
501 func TestConnQueryScanGoArray(t *testing.T) {
502 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
503 m := pgtype.NewMap()
504
505 var a []int64
506 err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
507 require.NoError(t, err)
508 assert.Equal(t, []int64{1, 2, 3}, a)
509 })
510 }
511
512 func TestConnQueryScanArray(t *testing.T) {
513 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
514 m := pgtype.NewMap()
515
516 var a pgtype.Array[int64]
517 err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
518 require.NoError(t, err)
519 assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
520
521 err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
522 require.NoError(t, err)
523 assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
524 })
525 }
526
527 func TestConnQueryScanRange(t *testing.T) {
528 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
529 skipCockroachDB(t, db, "Server does not support int4range")
530
531 m := pgtype.NewMap()
532
533 var r pgtype.Range[pgtype.Int4]
534 err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
535 require.NoError(t, err)
536 assert.Equal(
537 t,
538 pgtype.Range[pgtype.Int4]{
539 Lower: pgtype.Int4{Int32: 1, Valid: true},
540 Upper: pgtype.Int4{Int32: 5, Valid: true},
541 LowerType: pgtype.Inclusive,
542 UpperType: pgtype.Exclusive,
543 Valid: true,
544 },
545 r)
546 })
547 }
548
549
550
551 func TestConnQueryRowPgxBinary(t *testing.T) {
552 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
553 sql := "select $1::int4[]"
554 expected := "{1,2,3}"
555 var actual string
556
557 err := db.QueryRow(sql, expected).Scan(&actual)
558 require.NoError(t, err)
559 require.EqualValues(t, expected, actual)
560 })
561 }
562
563 func TestConnQueryRowUnknownType(t *testing.T) {
564 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
565 skipCockroachDB(t, db, "Server does not support point type")
566
567 sql := "select $1::point"
568 expected := "(1,2)"
569 var actual string
570
571 err := db.QueryRow(sql, expected).Scan(&actual)
572 require.NoError(t, err)
573 require.EqualValues(t, expected, actual)
574 })
575 }
576
577 func TestConnQueryJSONIntoByteSlice(t *testing.T) {
578 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
579 _, err := db.Exec(`
580 create temporary table docs(
581 body json not null
582 );
583
584 insert into docs(body) values('{"foo": "bar"}');
585 `)
586 require.NoError(t, err)
587
588 sql := `select * from docs`
589 expected := []byte(`{"foo": "bar"}`)
590 var actual []byte
591
592 err = db.QueryRow(sql).Scan(&actual)
593 if err != nil {
594 t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
595 }
596
597 if !bytes.Equal(actual, expected) {
598 t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
599 }
600
601 _, err = db.Exec(`drop table docs`)
602 require.NoError(t, err)
603 })
604 }
605
606 func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
607
608
609
610 db := openDB(t)
611 defer closeDB(t, db)
612
613 _, err := db.Exec(`
614 create temporary table docs(
615 body json not null
616 );
617 `)
618 require.NoError(t, err)
619
620 expected := []byte(`{"foo": "bar"}`)
621
622 _, err = db.Exec(`insert into docs(body) values($1)`, expected)
623 require.NoError(t, err)
624
625 var actual []byte
626 err = db.QueryRow(`select body from docs`).Scan(&actual)
627 require.NoError(t, err)
628
629 if !bytes.Equal(actual, expected) {
630 t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
631 }
632
633 _, err = db.Exec(`drop table docs`)
634 require.NoError(t, err)
635 }
636
637 func TestTransactionLifeCycle(t *testing.T) {
638 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
639 _, err := db.Exec("create temporary table t(a varchar not null)")
640 require.NoError(t, err)
641
642 tx, err := db.Begin()
643 require.NoError(t, err)
644
645 _, err = tx.Exec("insert into t values('hi')")
646 require.NoError(t, err)
647
648 err = tx.Rollback()
649 require.NoError(t, err)
650
651 var n int64
652 err = db.QueryRow("select count(*) from t").Scan(&n)
653 require.NoError(t, err)
654 require.EqualValues(t, 0, n)
655
656 tx, err = db.Begin()
657 require.NoError(t, err)
658
659 _, err = tx.Exec("insert into t values('hi')")
660 require.NoError(t, err)
661
662 err = tx.Commit()
663 require.NoError(t, err)
664
665 err = db.QueryRow("select count(*) from t").Scan(&n)
666 require.NoError(t, err)
667 require.EqualValues(t, 1, n)
668 })
669 }
670
671 func TestConnBeginTxIsolation(t *testing.T) {
672 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
673 skipCockroachDB(t, db, "Server always uses serializable isolation level")
674
675 var defaultIsoLevel string
676 err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
677 require.NoError(t, err)
678
679 supportedTests := []struct {
680 sqlIso sql.IsolationLevel
681 pgIso string
682 }{
683 {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
684 {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
685 {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
686 {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
687 {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
688 {sqlIso: sql.LevelSerializable, pgIso: "serializable"},
689 }
690 for i, tt := range supportedTests {
691 func() {
692 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
693 if err != nil {
694 t.Errorf("%d. BeginTx failed: %v", i, err)
695 return
696 }
697 defer tx.Rollback()
698
699 var pgIso string
700 err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
701 if err != nil {
702 t.Errorf("%d. QueryRow failed: %v", i, err)
703 }
704
705 if pgIso != tt.pgIso {
706 t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
707 }
708 }()
709 }
710
711 unsupportedTests := []struct {
712 sqlIso sql.IsolationLevel
713 }{
714 {sqlIso: sql.LevelWriteCommitted},
715 {sqlIso: sql.LevelLinearizable},
716 }
717 for i, tt := range unsupportedTests {
718 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
719 if err == nil {
720 t.Errorf("%d. BeginTx should have failed", i)
721 tx.Rollback()
722 }
723 }
724 })
725 }
726
727 func TestConnBeginTxReadOnly(t *testing.T) {
728 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
729 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
730 require.NoError(t, err)
731 defer tx.Rollback()
732
733 var pgReadOnly string
734 err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
735 if err != nil {
736 t.Errorf("QueryRow failed: %v", err)
737 }
738
739 if pgReadOnly != "on" {
740 t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
741 }
742 })
743 }
744
745 func TestBeginTxContextCancel(t *testing.T) {
746 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
747 _, err := db.Exec("drop table if exists t")
748 require.NoError(t, err)
749
750 ctx, cancelFn := context.WithCancel(context.Background())
751
752 tx, err := db.BeginTx(ctx, nil)
753 require.NoError(t, err)
754
755 _, err = tx.Exec("create table t(id serial)")
756 require.NoError(t, err)
757
758 cancelFn()
759
760 err = tx.Commit()
761 if err != context.Canceled && err != sql.ErrTxDone {
762 t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
763 }
764
765 var n int
766 err = db.QueryRow("select count(*) from t").Scan(&n)
767 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
768 t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
769 }
770 })
771 }
772
773 func TestConnRaw(t *testing.T) {
774 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
775 conn, err := db.Conn(context.Background())
776 require.NoError(t, err)
777
778 var n int
779 err = conn.Raw(func(driverConn any) error {
780 conn := driverConn.(*stdlib.Conn).Conn()
781 return conn.QueryRow(context.Background(), "select 42").Scan(&n)
782 })
783 require.NoError(t, err)
784 assert.EqualValues(t, 42, n)
785 })
786 }
787
788 func TestConnPingContextSuccess(t *testing.T) {
789 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
790 err := db.PingContext(context.Background())
791 require.NoError(t, err)
792 })
793 }
794
795 func TestConnPrepareContextSuccess(t *testing.T) {
796 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
797 stmt, err := db.PrepareContext(context.Background(), "select now()")
798 require.NoError(t, err)
799 err = stmt.Close()
800 require.NoError(t, err)
801 })
802 }
803
804
805
806 func TestConnMultiplePrepareAndDeallocate(t *testing.T) {
807 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
808 skipCockroachDB(t, db, "Server does not support pg_prepared_statements")
809
810 sql := "select 42"
811 stmt1, err := db.PrepareContext(context.Background(), sql)
812 require.NoError(t, err)
813 stmt2, err := db.PrepareContext(context.Background(), sql)
814 require.NoError(t, err)
815 err = stmt1.Close()
816 require.NoError(t, err)
817
818 var preparedStmtCount int64
819 err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
820 require.NoError(t, err)
821 require.EqualValues(t, 1, preparedStmtCount)
822
823 err = stmt2.Close()
824 require.NoError(t, err)
825
826 err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
827 require.NoError(t, err)
828 require.EqualValues(t, 0, preparedStmtCount)
829 })
830 }
831
832 func TestConnExecContextSuccess(t *testing.T) {
833 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
834 _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
835 require.NoError(t, err)
836 })
837 }
838
839 func TestConnQueryContextSuccess(t *testing.T) {
840 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
841 rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
842 require.NoError(t, err)
843
844 for rows.Next() {
845 var n int64
846 err := rows.Scan(&n)
847 require.NoError(t, err)
848 }
849 require.NoError(t, rows.Err())
850 })
851 }
852
853 func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
854 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
855 rows, err := db.Query("select 42::bigint")
856 require.NoError(t, err)
857
858 columnTypes, err := rows.ColumnTypes()
859 require.NoError(t, err)
860 require.Len(t, columnTypes, 1)
861
862 if columnTypes[0].DatabaseTypeName() != "INT8" {
863 t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
864 }
865
866 err = rows.Close()
867 require.NoError(t, err)
868 })
869 }
870
871 func TestStmtExecContextSuccess(t *testing.T) {
872 db := openDB(t)
873 defer closeDB(t, db)
874
875 _, err := db.Exec("create temporary table t(id int primary key)")
876 require.NoError(t, err)
877
878 stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
879 require.NoError(t, err)
880 defer stmt.Close()
881
882 _, err = stmt.ExecContext(context.Background(), 42)
883 require.NoError(t, err)
884
885 ensureDBValid(t, db)
886 }
887
888 func TestStmtExecContextCancel(t *testing.T) {
889 db := openDB(t)
890 defer closeDB(t, db)
891
892 _, err := db.Exec("create temporary table t(id int primary key)")
893 require.NoError(t, err)
894
895 stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
896 require.NoError(t, err)
897 defer stmt.Close()
898
899 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
900 defer cancel()
901
902 _, err = stmt.ExecContext(ctx, 42)
903 if !pgconn.Timeout(err) {
904 t.Errorf("expected timeout error, got %v", err)
905 }
906
907 ensureDBValid(t, db)
908 }
909
910 func TestStmtQueryContextSuccess(t *testing.T) {
911 db := openDB(t)
912 defer closeDB(t, db)
913
914 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
915
916 stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
917 require.NoError(t, err)
918 defer stmt.Close()
919
920 rows, err := stmt.QueryContext(context.Background(), 5)
921 require.NoError(t, err)
922
923 for rows.Next() {
924 var n int64
925 if err := rows.Scan(&n); err != nil {
926 t.Error(err)
927 }
928 }
929
930 if rows.Err() != nil {
931 t.Error(rows.Err())
932 }
933
934 ensureDBValid(t, db)
935 }
936
937 func TestRowsColumnTypes(t *testing.T) {
938 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
939 columnTypesTests := []struct {
940 Name string
941 TypeName string
942 Length struct {
943 Len int64
944 OK bool
945 }
946 DecimalSize struct {
947 Precision int64
948 Scale int64
949 OK bool
950 }
951 ScanType reflect.Type
952 }{
953 {
954 Name: "a",
955 TypeName: "INT8",
956 Length: struct {
957 Len int64
958 OK bool
959 }{
960 Len: 0,
961 OK: false,
962 },
963 DecimalSize: struct {
964 Precision int64
965 Scale int64
966 OK bool
967 }{
968 Precision: 0,
969 Scale: 0,
970 OK: false,
971 },
972 ScanType: reflect.TypeOf(int64(0)),
973 }, {
974 Name: "bar",
975 TypeName: "TEXT",
976 Length: struct {
977 Len int64
978 OK bool
979 }{
980 Len: math.MaxInt64,
981 OK: true,
982 },
983 DecimalSize: struct {
984 Precision int64
985 Scale int64
986 OK bool
987 }{
988 Precision: 0,
989 Scale: 0,
990 OK: false,
991 },
992 ScanType: reflect.TypeOf(""),
993 }, {
994 Name: "dec",
995 TypeName: "NUMERIC",
996 Length: struct {
997 Len int64
998 OK bool
999 }{
1000 Len: 0,
1001 OK: false,
1002 },
1003 DecimalSize: struct {
1004 Precision int64
1005 Scale int64
1006 OK bool
1007 }{
1008 Precision: 9,
1009 Scale: 2,
1010 OK: true,
1011 },
1012 ScanType: reflect.TypeOf(float64(0)),
1013 }, {
1014 Name: "d",
1015 TypeName: "1266",
1016 Length: struct {
1017 Len int64
1018 OK bool
1019 }{
1020 Len: 0,
1021 OK: false,
1022 },
1023 DecimalSize: struct {
1024 Precision int64
1025 Scale int64
1026 OK bool
1027 }{
1028 Precision: 0,
1029 Scale: 0,
1030 OK: false,
1031 },
1032 ScanType: reflect.TypeOf(""),
1033 },
1034 }
1035
1036 rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
1037 require.NoError(t, err)
1038
1039 columns, err := rows.ColumnTypes()
1040 require.NoError(t, err)
1041 assert.Len(t, columns, 4)
1042
1043 for i, tt := range columnTypesTests {
1044 c := columns[i]
1045 if c.Name() != tt.Name {
1046 t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
1047 }
1048 if c.DatabaseTypeName() != tt.TypeName {
1049 t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
1050 }
1051 l, ok := c.Length()
1052 if l != tt.Length.Len {
1053 t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
1054 }
1055 if ok != tt.Length.OK {
1056 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
1057 }
1058 p, s, ok := c.DecimalSize()
1059 if p != tt.DecimalSize.Precision {
1060 t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
1061 }
1062 if s != tt.DecimalSize.Scale {
1063 t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
1064 }
1065 if ok != tt.DecimalSize.OK {
1066 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
1067 }
1068 if c.ScanType() != tt.ScanType {
1069 t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
1070 }
1071 }
1072 })
1073 }
1074
1075 func TestQueryLifeCycle(t *testing.T) {
1076 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
1077 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
1078
1079 rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
1080 require.NoError(t, err)
1081
1082 rowCount := int64(0)
1083
1084 for rows.Next() {
1085 rowCount++
1086 var (
1087 s string
1088 n int64
1089 )
1090
1091 err := rows.Scan(&s, &n)
1092 require.NoError(t, err)
1093
1094 if s != "foo" {
1095 t.Errorf(`Expected "foo", received "%v"`, s)
1096 }
1097
1098 if n != rowCount {
1099 t.Errorf("Expected %d, received %d", rowCount, n)
1100 }
1101 }
1102 require.NoError(t, rows.Err())
1103
1104 err = rows.Close()
1105 require.NoError(t, err)
1106
1107 rows, err = db.Query("select 1 where false")
1108 require.NoError(t, err)
1109
1110 rowCount = int64(0)
1111
1112 for rows.Next() {
1113 rowCount++
1114 }
1115 require.NoError(t, rows.Err())
1116 require.EqualValues(t, 0, rowCount)
1117
1118 err = rows.Close()
1119 require.NoError(t, err)
1120 })
1121 }
1122
1123
1124 func TestScanJSONIntoJSONRawMessage(t *testing.T) {
1125 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
1126 var msg json.RawMessage
1127
1128 err := db.QueryRow("select '{}'::json").Scan(&msg)
1129 require.NoError(t, err)
1130 require.EqualValues(t, []byte("{}"), []byte(msg))
1131 })
1132 }
1133
1134 type testLog struct {
1135 lvl tracelog.LogLevel
1136 msg string
1137 data map[string]any
1138 }
1139
1140 type testLogger struct {
1141 logs []testLog
1142 }
1143
1144 func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
1145 l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
1146 }
1147
1148 func TestRegisterConnConfig(t *testing.T) {
1149 connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1150 require.NoError(t, err)
1151
1152 logger := &testLogger{}
1153 connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}
1154
1155
1156
1157 connStr := stdlib.RegisterConnConfig(connConfig)
1158 require.Equal(t, "registeredConnConfig0", connStr)
1159 stdlib.UnregisterConnConfig(connStr)
1160
1161 connStr = stdlib.RegisterConnConfig(connConfig)
1162 defer stdlib.UnregisterConnConfig(connStr)
1163 require.Equal(t, "registeredConnConfig1", connStr)
1164
1165 db, err := sql.Open("pgx", connStr)
1166 require.NoError(t, err)
1167 defer closeDB(t, db)
1168
1169 var n int64
1170 err = db.QueryRow("select 1").Scan(&n)
1171 require.NoError(t, err)
1172
1173 l := logger.logs[len(logger.logs)-1]
1174 assert.Equal(t, "Query", l.msg)
1175 assert.Equal(t, "select 1", l.data["sql"])
1176 }
1177
1178
1179 func TestConnQueryRowConstraintErrors(t *testing.T) {
1180 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
1181 skipPostgreSQLVersionLessThan(t, db, 11)
1182 skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
1183
1184 _, err := db.Exec(`create temporary table defer_test (
1185 id text primary key,
1186 n int not null, unique (n),
1187 unique (n) deferrable initially deferred )`)
1188 require.NoError(t, err)
1189
1190 _, err = db.Exec(`drop function if exists test_trigger cascade`)
1191 require.NoError(t, err)
1192
1193 _, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$
1194 begin
1195 if new.n = 4 then
1196 raise exception 'n cant be 4!';
1197 end if;
1198 return new;
1199 end$$`)
1200 require.NoError(t, err)
1201
1202 _, err = db.Exec(`create constraint trigger test
1203 after insert or update on defer_test
1204 deferrable initially deferred
1205 for each row
1206 execute function test_trigger()`)
1207 require.NoError(t, err)
1208
1209 _, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
1210 require.NoError(t, err)
1211
1212 var id string
1213 err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
1214 assert.Error(t, err)
1215 })
1216 }
1217
1218 func TestOptionBeforeAfterConnect(t *testing.T) {
1219 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1220 require.NoError(t, err)
1221
1222 var beforeConnConfigs []*pgx.ConnConfig
1223 var afterConns []*pgx.Conn
1224 db := stdlib.OpenDB(*config,
1225 stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
1226 beforeConnConfigs = append(beforeConnConfigs, connConfig)
1227 return nil
1228 }),
1229 stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
1230 afterConns = append(afterConns, conn)
1231 return nil
1232 }))
1233 defer closeDB(t, db)
1234
1235
1236 db.SetMaxIdleConns(0)
1237
1238 _, err = db.Exec("select 1")
1239 require.NoError(t, err)
1240
1241 _, err = db.Exec("select 1")
1242 require.NoError(t, err)
1243
1244 require.Len(t, beforeConnConfigs, 2)
1245 require.Len(t, afterConns, 2)
1246
1247
1248
1249 require.False(t, config == beforeConnConfigs[0])
1250 require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
1251 }
1252
1253 func TestRandomizeHostOrderFunc(t *testing.T) {
1254 config, err := pgx.ParseConfig("postgres://host1,host2,host3")
1255 require.NoError(t, err)
1256
1257
1258 hostsNotSeenYet := map[string]struct{}{
1259 "host1": {},
1260 "host2": {},
1261 "host3": {},
1262 }
1263
1264
1265 for i := 0; i < 100000; i++ {
1266 connCopy := *config
1267 stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)
1268
1269 delete(hostsNotSeenYet, connCopy.Host)
1270 if len(hostsNotSeenYet) == 0 {
1271 return
1272 }
1273
1274 hostCheckLoop:
1275 for _, h := range []string{"host1", "host2", "host3"} {
1276 if connCopy.Host == h {
1277 continue
1278 }
1279 for _, f := range connCopy.Fallbacks {
1280 if f.Host == h {
1281 continue hostCheckLoop
1282 }
1283 }
1284 require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
1285 }
1286 }
1287
1288 require.Fail(t, "did not get all hosts as primaries after many randomizations")
1289 }
1290
1291 func TestResetSessionHookCalled(t *testing.T) {
1292 var mockCalled bool
1293
1294 connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1295 require.NoError(t, err)
1296
1297 db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
1298 mockCalled = true
1299
1300 return nil
1301 }))
1302
1303 defer closeDB(t, db)
1304
1305 err = db.Ping()
1306 require.NoError(t, err)
1307
1308 err = db.Ping()
1309 require.NoError(t, err)
1310
1311 require.True(t, mockCalled)
1312 }
1313
1314 func TestCheckIdleConn(t *testing.T) {
1315 controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
1316 require.NoError(t, err)
1317 defer closeDB(t, controllerConn)
1318
1319 skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
1320
1321 db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
1322 require.NoError(t, err)
1323 defer closeDB(t, db)
1324
1325 var conns []*sql.Conn
1326 for i := 0; i < 3; i++ {
1327 c, err := db.Conn(context.Background())
1328 require.NoError(t, err)
1329 conns = append(conns, c)
1330 }
1331
1332 require.EqualValues(t, 3, db.Stats().OpenConnections)
1333
1334 var pids []uint32
1335 for _, c := range conns {
1336 err := c.Raw(func(driverConn any) error {
1337 pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID())
1338 return nil
1339 })
1340 require.NoError(t, err)
1341 err = c.Close()
1342 require.NoError(t, err)
1343 }
1344
1345
1346
1347
1348 _, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids)
1349 require.NoError(t, err)
1350
1351
1352
1353
1354
1355 time.Sleep(time.Second)
1356
1357
1358 err = db.PingContext(context.Background())
1359 require.NoError(t, err)
1360
1361
1362 require.EqualValues(t, 1, db.Stats().OpenConnections)
1363 c, err := db.Conn(context.Background())
1364 require.NoError(t, err)
1365
1366 var cPID uint32
1367 err = c.Raw(func(driverConn any) error {
1368 cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
1369 return nil
1370 })
1371 require.NoError(t, err)
1372 err = c.Close()
1373 require.NoError(t, err)
1374
1375 require.NotContains(t, pids, cPID)
1376 }
1377
View as plain text