1 package stdlib_test
2
3 import (
4 "bytes"
5 "context"
6 "database/sql"
7 "encoding/json"
8 "math"
9 "os"
10 "reflect"
11 "regexp"
12 "testing"
13 "time"
14
15 "github.com/Masterminds/semver/v3"
16 "github.com/jackc/pgconn"
17 "github.com/jackc/pgx/v4"
18 "github.com/jackc/pgx/v4/stdlib"
19 "github.com/stretchr/testify/assert"
20 "github.com/stretchr/testify/require"
21 )
22
// openDB opens a *sql.DB via stdlib.OpenDB using the connection string stored
// in the PGX_TEST_DATABASE environment variable. The test fails immediately if
// that connection string cannot be parsed.
func openDB(t testing.TB) *sql.DB {
	connString := os.Getenv("PGX_TEST_DATABASE")

	cfg, err := pgx.ParseConfig(connString)
	require.NoError(t, err)

	return stdlib.OpenDB(*cfg)
}
28
// closeDB closes db and fails the test immediately if Close reports an error.
func closeDB(t testing.TB, db *sql.DB) {
	require.NoError(t, db.Close())
}
33
34 func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
35 conn, err := db.Conn(context.Background())
36 require.NoError(t, err)
37 defer conn.Close()
38
39 err = conn.Raw(func(driverConn interface{}) error {
40 conn := driverConn.(*stdlib.Conn).Conn()
41 if conn.PgConn().ParameterStatus("crdb_version") != "" {
42 t.Skip(msg)
43 }
44 return nil
45 })
46 require.NoError(t, err)
47 }
48
49 func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) {
50 conn, err := db.Conn(context.Background())
51 require.NoError(t, err)
52 defer conn.Close()
53
54 err = conn.Raw(func(driverConn interface{}) error {
55 conn := driverConn.(*stdlib.Conn).Conn()
56 serverVersionStr := conn.PgConn().ParameterStatus("server_version")
57 serverVersionStr = regexp.MustCompile(`^[0-9.]+`).FindString(serverVersionStr)
58
59 if serverVersionStr == "" {
60 return nil
61 }
62
63 serverVersion, err := semver.NewVersion(serverVersionStr)
64 if err != nil {
65 return err
66 }
67
68 c, err := semver.NewConstraint(constraintStr)
69 if err != nil {
70 return err
71 }
72
73 if c.Check(serverVersion) {
74 t.Skip(msg)
75 }
76 return nil
77 })
78 require.NoError(t, err)
79 }
80
// testWithAndWithoutPreferSimpleProtocol runs f as two subtests:
//   - "SimpleProto", with PreferSimpleProtocol forced on;
//   - "DefaultProto", with the configuration exactly as parsed.
//
// Each subtest opens its own *sql.DB from PGX_TEST_DATABASE. After f returns,
// it checks that the DB can still serve queries, then closes it and fails if
// Close errors.
func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) {
	variants := []struct {
		name        string
		forceSimple bool
	}{
		{name: "SimpleProto", forceSimple: true},
		{name: "DefaultProto", forceSimple: false},
	}

	for _, v := range variants {
		v := v
		t.Run(v.name, func(t *testing.T) {
			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
			require.NoError(t, err)

			// The default variant leaves whatever the parsed config says untouched.
			if v.forceSimple {
				config.PreferSimpleProtocol = true
			}

			db := stdlib.OpenDB(*config)
			defer func() {
				require.NoError(t, db.Close())
			}()

			f(t, db)

			ensureDBValid(t, db)
		})
	}
}
117
118
119
// ensureDBValid runs `select generate_series(1,$1)` with $1 = 10 against db.
// It fails the test unless exactly ten rows come back and their values sum to
// 55 (1..10). Tests call it at the end to confirm db still answers ordinary
// queries.
func ensureDBValid(t testing.TB, db *sql.DB) {
	var sum, rowCount int32

	rows, err := db.Query("select generate_series(1,$1)", 10)
	require.NoError(t, err)
	defer rows.Close()

	for rows.Next() {
		var n int32
		// Check Scan so a decoding failure is reported directly, instead of
		// surfacing later as a confusing "wrong values" failure.
		require.NoError(t, rows.Scan(&n))
		sum += n
		rowCount++
	}

	require.NoError(t, rows.Err())

	if rowCount != 10 {
		t.Error("Select called onDataRow wrong number of times")
	}
	if sum != 55 {
		t.Error("Wrong values returned")
	}
}
143
144 type preparer interface {
145 Prepare(query string) (*sql.Stmt, error)
146 }
147
148 func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
149 stmt, err := p.Prepare(sql)
150 require.NoError(t, err)
151 return stmt
152 }
153
154 func closeStmt(t *testing.T, stmt *sql.Stmt) {
155 err := stmt.Close()
156 require.NoError(t, err)
157 }
158
// TestSQLOpen checks that sql.Open works with both the "pgx" and "pgx/v4"
// driver names, and that the resulting DB closes cleanly.
func TestSQLOpen(t *testing.T) {
	for _, driverName := range []string{"pgx", "pgx/v4"} {
		driverName := driverName

		t.Run(driverName, func(t *testing.T) {
			db, err := sql.Open(driverName, os.Getenv("PGX_TEST_DATABASE"))
			require.NoError(t, err)
			closeDB(t, db)
		})
	}
}
177
// TestNormalLifeCycle covers the normal lifecycle of a prepared statement:
// prepare a parameterized query, read every row and check its values, close
// the result set explicitly, and confirm the DB is still usable.
func TestNormalLifeCycle(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	rows, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	var seen int64
	for rows.Next() {
		seen++

		var (
			label string
			value int64
		)
		require.NoError(t, rows.Scan(&label, &value))

		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != seen {
			t.Errorf("Expected %d, received %d", seen, value)
		}
	}
	require.NoError(t, rows.Err())
	require.EqualValues(t, 10, seen)

	require.NoError(t, rows.Close())

	ensureDBValid(t, db)
}
216
// TestStmtExec prepares and executes statements inside a transaction and
// checks that an insert reports one affected row.
func TestStmtExec(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	tx, err := db.Begin()
	require.NoError(t, err)
	// Release the transaction, and the connection it pins, even if an
	// assertion below aborts the test. After the explicit Rollback further
	// down, this returns sql.ErrTxDone, which is deliberately ignored.
	defer tx.Rollback()

	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
	_, err = createStmt.Exec()
	require.NoError(t, err)
	closeStmt(t, createStmt)

	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
	result, err := insertStmt.Exec("foo")
	require.NoError(t, err)

	n, err := result.RowsAffected()
	require.NoError(t, err)
	require.EqualValues(t, 1, n)
	closeStmt(t, insertStmt)

	// The transaction only scopes the statements. End it so its connection
	// goes back to the pool before the validity check.
	require.NoError(t, tx.Rollback())

	ensureDBValid(t, db)
}
240
// TestQueryCloseRowsEarly closes a result set without reading any rows, then
// checks that the same prepared statement can be queried again and read to
// completion.
func TestQueryCloseRowsEarly(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
	defer closeStmt(t, stmt)

	// Open a result set and abandon it before reading anything.
	abandoned, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)
	require.NoError(t, abandoned.Close())

	// The statement must still be fully usable after the early close.
	rows, err := stmt.Query(int32(1), int32(10))
	require.NoError(t, err)

	var seen int64
	for rows.Next() {
		seen++

		var (
			label string
			value int64
		)
		require.NoError(t, rows.Scan(&label, &value))
		if label != "foo" {
			t.Errorf(`Expected "foo", received "%v"`, label)
		}
		if value != seen {
			t.Errorf("Expected %d, received %d", seen, value)
		}
	}
	require.NoError(t, rows.Err())
	require.EqualValues(t, 10, seen)

	require.NoError(t, rows.Close())

	ensureDBValid(t, db)
}
285
// TestConnExec runs DDL and an insert through db.Exec and checks that the
// insert reports one affected row.
func TestConnExec(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create temporary table t(a varchar not null)")
		require.NoError(t, err)

		res, err := db.Exec("insert into t values('hey')")
		require.NoError(t, err)

		affected, err := res.RowsAffected()
		require.NoError(t, err)
		require.EqualValues(t, 1, affected)
	})
}
299
// TestConnQuery runs a parameterized db.Query without a prepared statement and
// checks all ten generated rows.
func TestConnQuery(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
		require.NoError(t, err)

		var seen int64
		for rows.Next() {
			seen++

			var (
				label string
				value int64
			)
			require.NoError(t, rows.Scan(&label, &value))
			if label != "foo" {
				t.Errorf(`Expected "foo", received "%v"`, label)
			}
			if value != seen {
				t.Errorf("Expected %d, received %d", seen, value)
			}
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 10, seen)

		require.NoError(t, rows.Close())
	})
}
330
331
// TestConnQueryDifferentScanPlansIssue781 scans a bool column and a text column
// from a single row into a bool and a string. It is a regression test,
// presumably for github.com/jackc/pgx/issues/781 given its name.
func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		var s string
		var b bool

		rows, err := db.Query("select true, 'foo'")
		require.NoError(t, err)
		// Without this, the result set (and its connection) is never released if
		// an assertion below fails.
		defer rows.Close()

		require.True(t, rows.Next())
		require.NoError(t, rows.Scan(&b, &s))
		assert.Equal(t, true, b)
		assert.Equal(t, "foo", s)

		// The query yields exactly one row; draining it also surfaces any
		// deferred iteration error.
		require.False(t, rows.Next())
		require.NoError(t, rows.Err())
	})
}
346
// TestConnQueryNull passes a nil argument, casts it to int, and checks that it
// scans back as an invalid (NULL) sql.NullInt64 in exactly one row.
func TestConnQueryNull(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		rows, err := db.Query("select $1::int", nil)
		require.NoError(t, err)
		// Release the connection even if an assertion aborts mid-iteration.
		// The explicit Close below makes this a no-op on the happy path.
		defer rows.Close()

		rowCount := int64(0)

		for rows.Next() {
			rowCount++

			var n sql.NullInt64
			err := rows.Scan(&n)
			require.NoError(t, err)
			if n.Valid {
				t.Errorf("Expected n to be null, but it was %v", n)
			}
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 1, rowCount)

		err = rows.Close()
		require.NoError(t, err)
	})
}
371
// TestConnQueryRowByteSlice scans a bytea literal (\xdeadbeef) into a []byte.
func TestConnQueryRowByteSlice(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		var got []byte
		require.NoError(t, db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&got))

		want := []byte{0xde, 0xad, 0xbe, 0xef}
		require.EqualValues(t, want, got)
	})
}
382
// TestConnQueryFailure sends a query with an unterminated string literal and
// checks that it fails with a *pgconn.PgError.
func TestConnQueryFailure(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		_, queryErr := db.Query("select 'foo")
		require.Error(t, queryErr)
		require.IsType(t, new(pgconn.PgError), queryErr)
	})
}
390
// TestConnSimpleSlicePassThrough passes a []string as a text[] parameter and
// checks that the server sees all three elements.
func TestConnSimpleSlicePassThrough(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support cardinality function")

		elems := []string{"a", "b", "c"}

		var n int64
		require.NoError(t, db.QueryRow("select cardinality($1::text[])", elems).Scan(&n))
		assert.EqualValues(t, 3, n)
	})
}
401
402
403
// TestConnQueryRowPgxBinary sends an int4[] as its text form and checks that it
// scans back into a string unchanged.
func TestConnQueryRowPgxBinary(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		const query = "select $1::int4[]"
		const want = "{1,2,3}"

		var got string
		err := db.QueryRow(query, want).Scan(&got)
		require.NoError(t, err)
		require.EqualValues(t, want, got)
	})
}
415
// TestConnQueryRowUnknownType sends a point value as text and checks that it
// scans back into a string unchanged.
func TestConnQueryRowUnknownType(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support point type")

		const query = "select $1::point"
		const pointText = "(1,2)"

		var roundTripped string
		err := db.QueryRow(query, pointText).Scan(&roundTripped)
		require.NoError(t, err)
		require.EqualValues(t, pointText, roundTripped)
	})
}
429
// TestConnQueryJSONIntoByteSlice stores a json document in a temporary table
// and checks that scanning the column into a []byte returns the document text
// unchanged.
func TestConnQueryJSONIntoByteSlice(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec(`
create temporary table docs(
body json not null
);

insert into docs(body) values('{"foo": "bar"}');
`)
		require.NoError(t, err)

		sql := `select * from docs`
		expected := []byte(`{"foo": "bar"}`)
		var actual []byte

		// Stop on a query error. Continuing would only add a misleading
		// "Expected ..., got """ failure for the nil result.
		err = db.QueryRow(sql).Scan(&actual)
		require.NoError(t, err, "Unexpected failure (sql -> %v)", sql)

		if !bytes.Equal(actual, expected) {
			t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
		}

		_, err = db.Exec(`drop table docs`)
		require.NoError(t, err)
	})
}
458
// TestConnExecInsertByteSliceIntoJSON inserts a []byte argument into a json
// column and checks that reading it back yields the same bytes.
//
// This test uses only the default protocol (openDB), not
// testWithAndWithoutPreferSimpleProtocol. NOTE(review): presumably a []byte
// argument cannot be sent as json text under the simple protocol — confirm.
func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec(`
create temporary table docs(
body json not null
);
`)
	require.NoError(t, err)

	expected := []byte(`{"foo": "bar"}`)

	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
	require.NoError(t, err)

	var actual []byte
	err = db.QueryRow(`select body from docs`).Scan(&actual)
	require.NoError(t, err)

	if !bytes.Equal(actual, expected) {
		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
	}

	_, err = db.Exec(`drop table docs`)
	require.NoError(t, err)
}
489
// TestTransactionLifeCycle checks that a rolled-back insert is not visible and
// a committed insert is.
func TestTransactionLifeCycle(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("create temporary table t(a varchar not null)")
		require.NoError(t, err)

		// countRows returns the current number of rows in t.
		countRows := func() int64 {
			var n int64
			require.NoError(t, db.QueryRow("select count(*) from t").Scan(&n))
			return n
		}

		// Rolled back: the insert must leave no trace.
		tx, err := db.Begin()
		require.NoError(t, err)
		_, err = tx.Exec("insert into t values('hi')")
		require.NoError(t, err)
		require.NoError(t, tx.Rollback())
		require.EqualValues(t, 0, countRows())

		// Committed: the insert must be visible afterwards.
		tx, err = db.Begin()
		require.NoError(t, err)
		_, err = tx.Exec("insert into t values('hi')")
		require.NoError(t, err)
		require.NoError(t, tx.Commit())
		require.EqualValues(t, 1, countRows())
	})
}
523
// TestConnBeginTxIsolation checks that each supported sql.IsolationLevel maps
// to the expected PostgreSQL transaction_isolation setting, and that the
// remaining levels make BeginTx fail.
func TestConnBeginTxIsolation(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server always uses serializable isolation level")

		var defaultIsoLevel string
		err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
		require.NoError(t, err)

		// checkLevel begins a transaction at sqlIso and reports an error unless
		// the server says it runs at pgIso. The transaction is always rolled back.
		checkLevel := func(i int, sqlIso sql.IsolationLevel, pgIso string) {
			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sqlIso})
			if err != nil {
				t.Errorf("%d. BeginTx failed: %v", i, err)
				return
			}
			defer tx.Rollback()

			var got string
			if err := tx.QueryRow("show transaction_isolation").Scan(&got); err != nil {
				t.Errorf("%d. QueryRow failed: %v", i, err)
			}
			if got != pgIso {
				t.Errorf("%d. pgIso => %s, want %s", i, got, pgIso)
			}
		}

		supportedTests := []struct {
			sqlIso sql.IsolationLevel
			pgIso  string
		}{
			{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
			{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
			{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
			{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
			{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
			{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
		}
		for i, tt := range supportedTests {
			checkLevel(i, tt.sqlIso, tt.pgIso)
		}

		// Levels outside the supported list above must be rejected by BeginTx.
		unsupported := []sql.IsolationLevel{sql.LevelWriteCommitted, sql.LevelLinearizable}
		for i, sqlIso := range unsupported {
			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sqlIso})
			if err == nil {
				t.Errorf("%d. BeginTx should have failed", i)
				tx.Rollback()
			}
		}
	})
}
579
// TestConnBeginTxReadOnly checks that TxOptions.ReadOnly produces a
// transaction whose transaction_read_only setting is "on".
func TestConnBeginTxReadOnly(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		opts := &sql.TxOptions{ReadOnly: true}
		tx, err := db.BeginTx(context.Background(), opts)
		require.NoError(t, err)
		defer tx.Rollback()

		const want = "on"

		var got string
		if err := tx.QueryRow("show transaction_read_only").Scan(&got); err != nil {
			t.Errorf("QueryRow failed: %v", err)
		}
		if got != want {
			t.Errorf("pgReadOnly => %s, want %s", got, want)
		}
	})
}
597
// TestBeginTxContextCancel cancels the transaction's context before Commit. It
// checks that Commit fails with context.Canceled or sql.ErrTxDone, and that the
// table created inside the transaction does not exist afterwards.
func TestBeginTxContextCancel(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		_, err := db.Exec("drop table if exists t")
		require.NoError(t, err)

		ctx, cancelFn := context.WithCancel(context.Background())

		tx, err := db.BeginTx(ctx, nil)
		require.NoError(t, err)

		_, err = tx.Exec("create table t(id serial)")
		require.NoError(t, err)

		cancelFn()

		switch err = tx.Commit(); err {
		case context.Canceled, sql.ErrTxDone:
			// Either outcome means the commit did not go through.
		default:
			t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
		}

		// The transaction never committed, so t must be missing
		// (SQLSTATE 42P01, undefined_table).
		var n int
		err = db.QueryRow("select count(*) from t").Scan(&n)
		pgErr, ok := err.(*pgconn.PgError)
		if !ok || pgErr.Code != "42P01" {
			t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
		}
	})
}
625
// TestAcquireConn acquires five *pgx.Conn values from db via
// stdlib.AcquireConn, uses each one, and checks that the open-connection count
// grows by one per acquisition. All connections are released at the end.
func TestAcquireConn(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		var held []*pgx.Conn

		for i := 1; i <= 5; i++ {
			conn, err := stdlib.AcquireConn(db)
			if err != nil {
				t.Errorf("%d. AcquireConn failed: %v", i, err)
				continue
			}

			var n int32
			if err := conn.QueryRow(context.Background(), "select 1").Scan(&n); err != nil {
				t.Errorf("%d. QueryRow failed: %v", i, err)
			}
			if n != 1 {
				t.Errorf("%d. n => %d, want %d", i, n, 1)
			}

			// Every connection acquired so far is still held, so iteration i
			// should see exactly i open connections.
			if open := db.Stats().OpenConnections; open != i {
				t.Errorf("%d. stats.OpenConnections => %d, want %d", i, open, i)
			}

			held = append(held, conn)
		}

		for i, conn := range held {
			if err := stdlib.ReleaseConn(db, conn); err != nil {
				t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err)
			}
		}
	})
}
661
// TestConnRaw uses *sql.Conn.Raw to reach the underlying *pgx.Conn and runs a
// query directly on it.
func TestConnRaw(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		conn, err := db.Conn(context.Background())
		require.NoError(t, err)
		// A *sql.Conn holds a pooled connection until it is closed; without this
		// the connection is never returned to db.
		defer conn.Close()

		var n int
		err = conn.Raw(func(driverConn interface{}) error {
			pgxConn := driverConn.(*stdlib.Conn).Conn()
			return pgxConn.QueryRow(context.Background(), "select 42").Scan(&n)
		})
		require.NoError(t, err)
		assert.EqualValues(t, 42, n)
	})
}
676
677
// TestReleaseConnWithTxInProgress releases a connection that still has an open
// transaction. It checks that the connection is not handed out again: the next
// AcquireConn gets a different backend PID, and only one connection remains
// open.
func TestReleaseConnWithTxInProgress(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server does not support backend PID")

		first, err := stdlib.AcquireConn(db)
		require.NoError(t, err)

		// Leave a transaction open on the first connection before releasing it.
		_, err = first.Exec(context.Background(), "begin")
		require.NoError(t, err)

		firstPID := first.PgConn().PID()
		require.NoError(t, stdlib.ReleaseConn(db, first))

		second, err := stdlib.AcquireConn(db)
		require.NoError(t, err)

		secondPID := second.PgConn().PID()
		require.NoError(t, stdlib.ReleaseConn(db, second))

		require.NotEqual(t, firstPID, secondPID)

		// Only the second (clean) connection should still be open.
		require.Equal(t, 1, db.Stats().OpenConnections)
	})
}
708
// TestConnPingContextSuccess checks that PingContext succeeds on a healthy DB.
func TestConnPingContextSuccess(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		require.NoError(t, db.PingContext(context.Background()))
	})
}
715
// TestConnPrepareContextSuccess prepares a simple statement with
// PrepareContext and closes it, expecting no errors.
func TestConnPrepareContextSuccess(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		ctx := context.Background()

		stmt, err := db.PrepareContext(ctx, "select now()")
		require.NoError(t, err)
		require.NoError(t, stmt.Close())
	})
}
724
725 func TestConnExecContextSuccess(t *testing.T) {
726 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
727 _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
728 require.NoError(t, err)
729 })
730 }
731
// TestConnExecContextFailureRetry closes the underlying pgx connection of a
// pooled connection behind database/sql's back. It then checks that a fresh
// *sql.Conn can still run ExecContext successfully.
// NOTE(review): presumably this relies on stdlib reporting the dead connection
// as bad so database/sql retries on another one — confirm against stdlib.
func TestConnExecContextFailureRetry(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		{
			conn, err := stdlib.AcquireConn(db)
			require.NoError(t, err)
			// Errors are ignored on purpose: the goal is only to hand a closed
			// connection back to the pool.
			conn.Close(context.Background())
			stdlib.ReleaseConn(db, conn)
		}

		conn, err := db.Conn(context.Background())
		require.NoError(t, err)
		// Return the dedicated connection to the pool; it was leaked before.
		defer conn.Close()

		_, err = conn.ExecContext(context.Background(), "select 1")
		require.NoError(t, err)
	})
}
748
// TestConnQueryContextSuccess runs a multi-row query through QueryContext and
// scans every row. Reading until Next returns false also closes the rows.
func TestConnQueryContextSuccess(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		ctx := context.Background()

		rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n")
		require.NoError(t, err)

		for rows.Next() {
			var n int64
			require.NoError(t, rows.Scan(&n))
		}
		require.NoError(t, rows.Err())
	})
}
762
// TestConnQueryContextFailureRetry closes the underlying pgx connection of a
// pooled connection behind database/sql's back. It then checks that a fresh
// *sql.Conn can still run QueryContext successfully.
// NOTE(review): presumably this relies on stdlib reporting the dead connection
// as bad so database/sql retries on another one — confirm against stdlib.
func TestConnQueryContextFailureRetry(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		{
			conn, err := stdlib.AcquireConn(db)
			require.NoError(t, err)
			// Errors are ignored on purpose: the goal is only to hand a closed
			// connection back to the pool.
			conn.Close(context.Background())
			stdlib.ReleaseConn(db, conn)
		}

		conn, err := db.Conn(context.Background())
		require.NoError(t, err)
		defer conn.Close()

		// Keep the result set so it can be closed. Discarding it with `_` left
		// both the rows and the connection unreleased.
		rows, err := conn.QueryContext(context.Background(), "select 1")
		require.NoError(t, err)
		require.NoError(t, rows.Close())
	})
}
780
781 func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
782 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
783 rows, err := db.Query("select 42::bigint")
784 require.NoError(t, err)
785
786 columnTypes, err := rows.ColumnTypes()
787 require.NoError(t, err)
788 require.Len(t, columnTypes, 1)
789
790 if columnTypes[0].DatabaseTypeName() != "INT8" {
791 t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
792 }
793
794 err = rows.Close()
795 require.NoError(t, err)
796 })
797 }
798
// TestStmtExecContextSuccess executes a prepared insert through
// Stmt.ExecContext and then checks that the DB is still usable.
func TestStmtExecContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(id int primary key)")
	require.NoError(t, err)

	stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
	require.NoError(t, err)
	defer stmt.Close()

	ctx := context.Background()
	_, err = stmt.ExecContext(ctx, 42)
	require.NoError(t, err)

	ensureDBValid(t, db)
}
815
// TestStmtExecContextCancel runs a prepared statement that sleeps for 5s under
// a 100ms context deadline. It expects a timeout error, then checks that the DB
// is still usable.
func TestStmtExecContextCancel(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	_, err := db.Exec("create temporary table t(id int primary key)")
	require.NoError(t, err)

	stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
	require.NoError(t, err)
	defer stmt.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if _, execErr := stmt.ExecContext(ctx, 42); !pgconn.Timeout(execErr) {
		t.Errorf("expected timeout error, got %v", execErr)
	}

	ensureDBValid(t, db)
}
837
// TestStmtQueryContextSuccess queries a prepared statement through
// Stmt.QueryContext, scans every row, and then checks that the DB is still
// usable.
func TestStmtQueryContextSuccess(t *testing.T) {
	db := openDB(t)
	defer closeDB(t, db)

	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
	require.NoError(t, err)
	defer stmt.Close()

	rows, err := stmt.QueryContext(context.Background(), 5)
	require.NoError(t, err)

	for rows.Next() {
		var n int64
		if scanErr := rows.Scan(&n); scanErr != nil {
			t.Error(scanErr)
		}
	}

	if iterErr := rows.Err(); iterErr != nil {
		t.Error(iterErr)
	}

	ensureDBValid(t, db)
}
864
865 func TestRowsColumnTypes(t *testing.T) {
866 testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
867 columnTypesTests := []struct {
868 Name string
869 TypeName string
870 Length struct {
871 Len int64
872 OK bool
873 }
874 DecimalSize struct {
875 Precision int64
876 Scale int64
877 OK bool
878 }
879 ScanType reflect.Type
880 }{
881 {
882 Name: "a",
883 TypeName: "INT8",
884 Length: struct {
885 Len int64
886 OK bool
887 }{
888 Len: 0,
889 OK: false,
890 },
891 DecimalSize: struct {
892 Precision int64
893 Scale int64
894 OK bool
895 }{
896 Precision: 0,
897 Scale: 0,
898 OK: false,
899 },
900 ScanType: reflect.TypeOf(int64(0)),
901 }, {
902 Name: "bar",
903 TypeName: "TEXT",
904 Length: struct {
905 Len int64
906 OK bool
907 }{
908 Len: math.MaxInt64,
909 OK: true,
910 },
911 DecimalSize: struct {
912 Precision int64
913 Scale int64
914 OK bool
915 }{
916 Precision: 0,
917 Scale: 0,
918 OK: false,
919 },
920 ScanType: reflect.TypeOf(""),
921 }, {
922 Name: "dec",
923 TypeName: "NUMERIC",
924 Length: struct {
925 Len int64
926 OK bool
927 }{
928 Len: 0,
929 OK: false,
930 },
931 DecimalSize: struct {
932 Precision int64
933 Scale int64
934 OK bool
935 }{
936 Precision: 9,
937 Scale: 2,
938 OK: true,
939 },
940 ScanType: reflect.TypeOf(float64(0)),
941 }, {
942 Name: "d",
943 TypeName: "1266",
944 Length: struct {
945 Len int64
946 OK bool
947 }{
948 Len: 0,
949 OK: false,
950 },
951 DecimalSize: struct {
952 Precision int64
953 Scale int64
954 OK bool
955 }{
956 Precision: 0,
957 Scale: 0,
958 OK: false,
959 },
960 ScanType: reflect.TypeOf(""),
961 },
962 }
963
964 rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
965 require.NoError(t, err)
966
967 columns, err := rows.ColumnTypes()
968 require.NoError(t, err)
969 assert.Len(t, columns, 4)
970
971 for i, tt := range columnTypesTests {
972 c := columns[i]
973 if c.Name() != tt.Name {
974 t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
975 }
976 if c.DatabaseTypeName() != tt.TypeName {
977 t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
978 }
979 l, ok := c.Length()
980 if l != tt.Length.Len {
981 t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
982 }
983 if ok != tt.Length.OK {
984 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
985 }
986 p, s, ok := c.DecimalSize()
987 if p != tt.DecimalSize.Precision {
988 t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
989 }
990 if s != tt.DecimalSize.Scale {
991 t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
992 }
993 if ok != tt.DecimalSize.OK {
994 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
995 }
996 if c.ScanType() != tt.ScanType {
997 t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
998 }
999 }
1000 })
1001 }
1002
// TestQueryLifeCycle checks two unprepared queries on the same DB. The first
// is a three-parameter query producing ten rows; the second matches nothing
// and must iterate zero times without error.
func TestQueryLifeCycle(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		// $3 only appears in the WHERE clause; with 3 it keeps every row.
		rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
		require.NoError(t, err)

		var count int64
		for rows.Next() {
			count++

			var s string
			var n int64
			require.NoError(t, rows.Scan(&s, &n))

			if s != "foo" {
				t.Errorf(`Expected "foo", received "%v"`, s)
			}
			if n != count {
				t.Errorf("Expected %d, received %d", count, n)
			}
		}
		require.NoError(t, rows.Err())
		require.NoError(t, rows.Close())

		rows, err = db.Query("select 1 where false")
		require.NoError(t, err)

		count = 0
		for rows.Next() {
			count++
		}
		require.NoError(t, rows.Err())
		require.EqualValues(t, 0, count)

		require.NoError(t, rows.Close())
	})
}
1050
1051
// TestScanJSONIntoJSONRawMessage scans a json value into a json.RawMessage and
// checks the raw bytes.
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		var raw json.RawMessage
		require.NoError(t, db.QueryRow("select '{}'::json").Scan(&raw))
		require.EqualValues(t, []byte("{}"), []byte(raw))
	})
}
1061
1062 type testLog struct {
1063 lvl pgx.LogLevel
1064 msg string
1065 data map[string]interface{}
1066 }
1067
1068 type testLogger struct {
1069 logs []testLog
1070 }
1071
1072 func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) {
1073 l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
1074 }
1075
// TestRegisterConnConfig registers a ConnConfig (with a recording logger) under
// a generated name and opens a DB with sql.Open using that name. It checks
// that the registered config is used, by finding the query in the logger's
// last entry.
func TestRegisterConnConfig(t *testing.T) {
	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	logger := &testLogger{}
	connConfig.Logger = logger

	// Generated names are numbered sequentially, and the counter keeps
	// advancing after a name is unregistered.
	// NOTE(review): the exact names assume no other test in this package
	// registers a config earlier.
	connStr := stdlib.RegisterConnConfig(connConfig)
	require.Equal(t, "registeredConnConfig0", connStr)
	stdlib.UnregisterConnConfig(connStr)

	connStr = stdlib.RegisterConnConfig(connConfig)
	defer stdlib.UnregisterConnConfig(connStr)
	require.Equal(t, "registeredConnConfig1", connStr)

	db, err := sql.Open("pgx", connStr)
	require.NoError(t, err)
	defer closeDB(t, db)

	var n int64
	err = db.QueryRow("select 1").Scan(&n)
	require.NoError(t, err)

	// Guard the index so an empty log fails the test cleanly instead of
	// panicking with an out-of-range access.
	require.NotEmpty(t, logger.logs)
	l := logger.logs[len(logger.logs)-1]
	assert.Equal(t, "Query", l.msg)
	assert.Equal(t, "select 1", l.data["sql"])
}
1105
1106
// TestConnQueryRowConstraintErrors sets up a deferred constraint trigger that
// raises when n = 4. It checks that an insert ... returning with n = 4 still
// surfaces an error through QueryRow/Scan.
func TestConnQueryRowConstraintErrors(t *testing.T) {
	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
		skipPostgreSQLVersion(t, db, "< 11", "Test requires PG 11+")
		skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")

		// Schema, trigger function, deferred constraint trigger, and seed rows,
		// executed in order.
		setup := []string{
			`create temporary table defer_test (
id text primary key,
n int not null, unique (n),
unique (n) deferrable initially deferred )`,
			`drop function if exists test_trigger cascade`,
			`create function test_trigger() returns trigger language plpgsql as $$
begin
if new.n = 4 then
raise exception 'n cant be 4!';
end if;
return new;
end$$`,
			`create constraint trigger test
after insert or update on defer_test
deferrable initially deferred
for each row
execute function test_trigger()`,
			`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`,
		}
		for _, stmt := range setup {
			_, err := db.Exec(stmt)
			require.NoError(t, err)
		}

		// The trigger is declared "initially deferred", so the exception for
		// n = 4 is raised at commit rather than during the insert itself.
		var id string
		err := db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
		assert.Error(t, err)
	})
}
1145
// TestOptionBeforeAfterConnect checks that the BeforeConnect and AfterConnect
// hooks run once per new connection. It also checks that each BeforeConnect
// call receives a distinct *pgx.ConnConfig that is not the caller's original
// config.
func TestOptionBeforeAfterConnect(t *testing.T) {
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var beforeConnConfigs []*pgx.ConnConfig
	var afterConns []*pgx.Conn

	recordBefore := func(ctx context.Context, connConfig *pgx.ConnConfig) error {
		beforeConnConfigs = append(beforeConnConfigs, connConfig)
		return nil
	}
	recordAfter := func(ctx context.Context, conn *pgx.Conn) error {
		afterConns = append(afterConns, conn)
		return nil
	}

	db := stdlib.OpenDB(*config,
		stdlib.OptionBeforeConnect(recordBefore),
		stdlib.OptionAfterConnect(recordAfter))
	defer closeDB(t, db)

	// With no idle connections kept, each Exec below has to open a fresh
	// connection, so each hook should fire once per Exec.
	db.SetMaxIdleConns(0)

	for i := 0; i < 2; i++ {
		_, err = db.Exec("select 1")
		require.NoError(t, err)
	}

	require.Len(t, beforeConnConfigs, 2)
	require.Len(t, afterConns, 2)

	// BeforeConnect must receive its own copy of the config each time, never
	// the caller's pointer and never the same pointer twice.
	require.False(t, config == beforeConnConfigs[0])
	require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
}
1180
// TestRandomizeHostOrderFunc calls RandomizeHostOrderFunc repeatedly on a
// three-host config. Every result must still contain all three hosts (as the
// primary or among the fallbacks), and each host must eventually be chosen as
// the primary.
func TestRandomizeHostOrderFunc(t *testing.T) {
	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
	require.NoError(t, err)

	allHosts := []string{"host1", "host2", "host3"}

	hostsNotSeenYet := make(map[string]struct{}, len(allHosts))
	for _, h := range allHosts {
		hostsNotSeenYet[h] = struct{}{}
	}

	// Bounded so a broken randomizer fails instead of looping forever.
	for i := 0; i < 100000; i++ {
		connCopy := *config
		stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)

		// Validate this configuration before deciding whether we are done, so
		// the iteration that completes the set is checked too.
	hostCheckLoop:
		for _, h := range allHosts {
			if connCopy.Host == h {
				continue
			}
			for _, f := range connCopy.Fallbacks {
				if f.Host == h {
					continue hostCheckLoop
				}
			}
			require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
		}

		delete(hostsNotSeenYet, connCopy.Host)
		if len(hostsNotSeenYet) == 0 {
			return
		}
	}

	require.Fail(t, "did not get all hosts as primaries after many randomizations")
}
1218
// TestResetSessionHookCalled checks that the hook installed with
// OptionResetSession has been called by the time two Pings have completed.
func TestResetSessionHookCalled(t *testing.T) {
	var mockCalled bool

	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	resetSession := func(ctx context.Context, conn *pgx.Conn) error {
		mockCalled = true
		return nil
	}
	db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(resetSession))
	defer closeDB(t, db)

	// Two pings, presumably so the second one reuses the pooled connection,
	// which is when database/sql would run the session-reset hook.
	for i := 0; i < 2; i++ {
		require.NoError(t, db.Ping())
	}

	require.True(t, mockCalled)
}
1241
View as plain text