...

Source file src/github.com/jackc/pgx/v5/stdlib/sql_test.go

Documentation: github.com/jackc/pgx/v5/stdlib

     1  package stdlib_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"encoding/json"
     8  	"fmt"
     9  	"math"
    10  	"os"
    11  	"reflect"
    12  	"regexp"
    13  	"strconv"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/jackc/pgx/v5"
    19  	"github.com/jackc/pgx/v5/pgconn"
    20  	"github.com/jackc/pgx/v5/pgtype"
    21  	"github.com/jackc/pgx/v5/pgxpool"
    22  	"github.com/jackc/pgx/v5/stdlib"
    23  	"github.com/jackc/pgx/v5/tracelog"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  func openDB(t testing.TB) *sql.DB {
    29  	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    30  	require.NoError(t, err)
    31  	return stdlib.OpenDB(*config)
    32  }
    33  
    34  func closeDB(t testing.TB, db *sql.DB) {
    35  	err := db.Close()
    36  	require.NoError(t, err)
    37  }
    38  
    39  func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
    40  	conn, err := db.Conn(context.Background())
    41  	require.NoError(t, err)
    42  	defer conn.Close()
    43  
    44  	err = conn.Raw(func(driverConn any) error {
    45  		conn := driverConn.(*stdlib.Conn).Conn()
    46  		if conn.PgConn().ParameterStatus("crdb_version") != "" {
    47  			t.Skip(msg)
    48  		}
    49  		return nil
    50  	})
    51  	require.NoError(t, err)
    52  }
    53  
    54  func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) {
    55  	conn, err := db.Conn(context.Background())
    56  	require.NoError(t, err)
    57  	defer conn.Close()
    58  
    59  	err = conn.Raw(func(driverConn any) error {
    60  		conn := driverConn.(*stdlib.Conn).Conn()
    61  		serverVersionStr := conn.PgConn().ParameterStatus("server_version")
    62  		serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
    63  		// if not PostgreSQL do nothing
    64  		if serverVersionStr == "" {
    65  			return nil
    66  		}
    67  
    68  		serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
    69  		if err != nil {
    70  			return err
    71  		}
    72  
    73  		if serverVersion < minVersion {
    74  			t.Skipf("Test requires PostgreSQL v%d+", minVersion)
    75  		}
    76  
    77  		return nil
    78  	})
    79  	require.NoError(t, err)
    80  }
    81  
    82  func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) {
    83  	for _, mode := range []pgx.QueryExecMode{
    84  		pgx.QueryExecModeCacheStatement,
    85  		pgx.QueryExecModeCacheDescribe,
    86  		pgx.QueryExecModeDescribeExec,
    87  		pgx.QueryExecModeExec,
    88  		pgx.QueryExecModeSimpleProtocol,
    89  	} {
    90  		t.Run(mode.String(),
    91  			func(t *testing.T) {
    92  				config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    93  				require.NoError(t, err)
    94  
    95  				config.DefaultQueryExecMode = mode
    96  				db := stdlib.OpenDB(*config)
    97  				defer func() {
    98  					err := db.Close()
    99  					require.NoError(t, err)
   100  				}()
   101  
   102  				f(t, db)
   103  
   104  				ensureDBValid(t, db)
   105  			},
   106  		)
   107  	}
   108  }
   109  
   110  // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
   111  // cover broken connections.
   112  func ensureDBValid(t testing.TB, db *sql.DB) {
   113  	var sum, rowCount int32
   114  
   115  	rows, err := db.Query("select generate_series(1,$1)", 10)
   116  	require.NoError(t, err)
   117  	defer rows.Close()
   118  
   119  	for rows.Next() {
   120  		var n int32
   121  		rows.Scan(&n)
   122  		sum += n
   123  		rowCount++
   124  	}
   125  
   126  	require.NoError(t, rows.Err())
   127  
   128  	if rowCount != 10 {
   129  		t.Error("Select called onDataRow wrong number of times")
   130  	}
   131  	if sum != 55 {
   132  		t.Error("Wrong values returned")
   133  	}
   134  }
   135  
   136  type preparer interface {
   137  	Prepare(query string) (*sql.Stmt, error)
   138  }
   139  
   140  func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
   141  	stmt, err := p.Prepare(sql)
   142  	require.NoError(t, err)
   143  	return stmt
   144  }
   145  
   146  func closeStmt(t *testing.T, stmt *sql.Stmt) {
   147  	err := stmt.Close()
   148  	require.NoError(t, err)
   149  }
   150  
   151  func TestSQLOpen(t *testing.T) {
   152  	tests := []struct {
   153  		driverName string
   154  	}{
   155  		{driverName: "pgx"},
   156  		{driverName: "pgx/v5"},
   157  	}
   158  
   159  	for _, tt := range tests {
   160  		tt := tt
   161  
   162  		t.Run(tt.driverName, func(t *testing.T) {
   163  			db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE"))
   164  			require.NoError(t, err)
   165  			closeDB(t, db)
   166  		})
   167  	}
   168  }
   169  
   170  func TestSQLOpenFromPool(t *testing.T) {
   171  	pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
   172  	require.NoError(t, err)
   173  	t.Cleanup(pool.Close)
   174  
   175  	db := stdlib.OpenDBFromPool(pool)
   176  	ensureDBValid(t, db)
   177  
   178  	db.Close()
   179  }
   180  
   181  func TestNormalLifeCycle(t *testing.T) {
   182  	db := openDB(t)
   183  	defer closeDB(t, db)
   184  
   185  	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   186  
   187  	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
   188  	defer closeStmt(t, stmt)
   189  
   190  	rows, err := stmt.Query(int32(1), int32(10))
   191  	require.NoError(t, err)
   192  
   193  	rowCount := int64(0)
   194  
   195  	for rows.Next() {
   196  		rowCount++
   197  
   198  		var s string
   199  		var n int64
   200  		err := rows.Scan(&s, &n)
   201  		require.NoError(t, err)
   202  
   203  		if s != "foo" {
   204  			t.Errorf(`Expected "foo", received "%v"`, s)
   205  		}
   206  		if n != rowCount {
   207  			t.Errorf("Expected %d, received %d", rowCount, n)
   208  		}
   209  	}
   210  	require.NoError(t, rows.Err())
   211  
   212  	require.EqualValues(t, 10, rowCount)
   213  
   214  	err = rows.Close()
   215  	require.NoError(t, err)
   216  
   217  	ensureDBValid(t, db)
   218  }
   219  
   220  func TestStmtExec(t *testing.T) {
   221  	db := openDB(t)
   222  	defer closeDB(t, db)
   223  
   224  	tx, err := db.Begin()
   225  	require.NoError(t, err)
   226  
   227  	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
   228  	_, err = createStmt.Exec()
   229  	require.NoError(t, err)
   230  	closeStmt(t, createStmt)
   231  
   232  	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
   233  	result, err := insertStmt.Exec("foo")
   234  	require.NoError(t, err)
   235  
   236  	n, err := result.RowsAffected()
   237  	require.NoError(t, err)
   238  	require.EqualValues(t, 1, n)
   239  	closeStmt(t, insertStmt)
   240  
   241  	ensureDBValid(t, db)
   242  }
   243  
   244  func TestQueryCloseRowsEarly(t *testing.T) {
   245  	db := openDB(t)
   246  	defer closeDB(t, db)
   247  
   248  	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   249  
   250  	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
   251  	defer closeStmt(t, stmt)
   252  
   253  	rows, err := stmt.Query(int32(1), int32(10))
   254  	require.NoError(t, err)
   255  
   256  	// Close rows immediately without having read them
   257  	err = rows.Close()
   258  	require.NoError(t, err)
   259  
   260  	// Run the query again to ensure the connection and statement are still ok
   261  	rows, err = stmt.Query(int32(1), int32(10))
   262  	require.NoError(t, err)
   263  
   264  	rowCount := int64(0)
   265  
   266  	for rows.Next() {
   267  		rowCount++
   268  
   269  		var s string
   270  		var n int64
   271  		err := rows.Scan(&s, &n)
   272  		require.NoError(t, err)
   273  		if s != "foo" {
   274  			t.Errorf(`Expected "foo", received "%v"`, s)
   275  		}
   276  		if n != rowCount {
   277  			t.Errorf("Expected %d, received %d", rowCount, n)
   278  		}
   279  	}
   280  	require.NoError(t, rows.Err())
   281  	require.EqualValues(t, 10, rowCount)
   282  
   283  	err = rows.Close()
   284  	require.NoError(t, err)
   285  
   286  	ensureDBValid(t, db)
   287  }
   288  
   289  func TestConnExec(t *testing.T) {
   290  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   291  		_, err := db.Exec("create temporary table t(a varchar not null)")
   292  		require.NoError(t, err)
   293  
   294  		result, err := db.Exec("insert into t values('hey')")
   295  		require.NoError(t, err)
   296  
   297  		n, err := result.RowsAffected()
   298  		require.NoError(t, err)
   299  		require.EqualValues(t, 1, n)
   300  	})
   301  }
   302  
   303  func TestConnQuery(t *testing.T) {
   304  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   305  		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   306  
   307  		rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
   308  		require.NoError(t, err)
   309  
   310  		rowCount := int64(0)
   311  
   312  		for rows.Next() {
   313  			rowCount++
   314  
   315  			var s string
   316  			var n int64
   317  			err := rows.Scan(&s, &n)
   318  			require.NoError(t, err)
   319  			if s != "foo" {
   320  				t.Errorf(`Expected "foo", received "%v"`, s)
   321  			}
   322  			if n != rowCount {
   323  				t.Errorf("Expected %d, received %d", rowCount, n)
   324  			}
   325  		}
   326  		require.NoError(t, rows.Err())
   327  		require.EqualValues(t, 10, rowCount)
   328  
   329  		err = rows.Close()
   330  		require.NoError(t, err)
   331  	})
   332  }
   333  
   334  func TestConnConcurrency(t *testing.T) {
   335  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   336  		_, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)")
   337  		require.NoError(t, err)
   338  
   339  		defer func() {
   340  			_, err := db.Exec("drop table t")
   341  			require.NoError(t, err)
   342  		}()
   343  
   344  		var wg sync.WaitGroup
   345  
   346  		concurrency := 50
   347  		errChan := make(chan error, concurrency)
   348  
   349  		for i := 1; i <= concurrency; i++ {
   350  			wg.Add(1)
   351  
   352  			go func(idx int) {
   353  				defer wg.Done()
   354  
   355  				ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   356  				defer cancel()
   357  
   358  				str := strconv.Itoa(idx)
   359  				duration := time.Duration(idx) * time.Second
   360  				_, err := db.ExecContext(ctx, "insert into t values($1)", idx)
   361  				if err != nil {
   362  					errChan <- fmt.Errorf("insert failed: %d %w", idx, err)
   363  					return
   364  				}
   365  				_, err = db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx)
   366  				if err != nil {
   367  					errChan <- fmt.Errorf("update 1 failed: %d %w", idx, err)
   368  					return
   369  				}
   370  				_, err = db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx)
   371  				if err != nil {
   372  					errChan <- fmt.Errorf("update 2 failed: %d %w", idx, err)
   373  					return
   374  				}
   375  
   376  				errChan <- nil
   377  			}(i)
   378  		}
   379  		wg.Wait()
   380  		for i := 1; i <= concurrency; i++ {
   381  			err := <-errChan
   382  			require.NoError(t, err)
   383  		}
   384  
   385  		for i := 1; i <= concurrency; i++ {
   386  			wg.Add(1)
   387  
   388  			go func(idx int) {
   389  				defer wg.Done()
   390  
   391  				ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   392  				defer cancel()
   393  
   394  				var id int
   395  				var str string
   396  				var duration pgtype.Interval
   397  				err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration)
   398  				if err != nil {
   399  					errChan <- fmt.Errorf("select failed: %d %w", idx, err)
   400  					return
   401  				}
   402  				if id != idx {
   403  					errChan <- fmt.Errorf("id mismatch: %d %d", idx, id)
   404  					return
   405  				}
   406  				if str != strconv.Itoa(idx) {
   407  					errChan <- fmt.Errorf("str mismatch: %d %s", idx, str)
   408  					return
   409  				}
   410  				expectedDuration := pgtype.Interval{
   411  					Microseconds: int64(idx) * time.Second.Microseconds(),
   412  					Valid:        true,
   413  				}
   414  				if duration != expectedDuration {
   415  					errChan <- fmt.Errorf("duration mismatch: %d %v", idx, duration)
   416  					return
   417  				}
   418  
   419  				errChan <- nil
   420  			}(i)
   421  		}
   422  		wg.Wait()
   423  		for i := 1; i <= concurrency; i++ {
   424  			err := <-errChan
   425  			require.NoError(t, err)
   426  		}
   427  	})
   428  }
   429  
   430  // https://github.com/jackc/pgx/issues/781
   431  func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
   432  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   433  		var s string
   434  		var b bool
   435  
   436  		rows, err := db.Query("select true, 'foo'")
   437  		require.NoError(t, err)
   438  
   439  		require.True(t, rows.Next())
   440  		require.NoError(t, rows.Scan(&b, &s))
   441  		assert.Equal(t, true, b)
   442  		assert.Equal(t, "foo", s)
   443  	})
   444  }
   445  
   446  func TestConnQueryNull(t *testing.T) {
   447  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   448  		rows, err := db.Query("select $1::int", nil)
   449  		require.NoError(t, err)
   450  
   451  		rowCount := int64(0)
   452  
   453  		for rows.Next() {
   454  			rowCount++
   455  
   456  			var n sql.NullInt64
   457  			err := rows.Scan(&n)
   458  			require.NoError(t, err)
   459  			if n.Valid != false {
   460  				t.Errorf("Expected n to be null, but it was %v", n)
   461  			}
   462  		}
   463  		require.NoError(t, rows.Err())
   464  		require.EqualValues(t, 1, rowCount)
   465  
   466  		err = rows.Close()
   467  		require.NoError(t, err)
   468  	})
   469  }
   470  
   471  func TestConnQueryRowByteSlice(t *testing.T) {
   472  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   473  		expected := []byte{222, 173, 190, 239}
   474  		var actual []byte
   475  
   476  		err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
   477  		require.NoError(t, err)
   478  		require.EqualValues(t, expected, actual)
   479  	})
   480  }
   481  
   482  func TestConnQueryFailure(t *testing.T) {
   483  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   484  		_, err := db.Query("select 'foo")
   485  		require.Error(t, err)
   486  		require.IsType(t, new(pgconn.PgError), err)
   487  	})
   488  }
   489  
   490  func TestConnSimpleSlicePassThrough(t *testing.T) {
   491  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   492  		skipCockroachDB(t, db, "Server does not support cardinality function")
   493  
   494  		var n int64
   495  		err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
   496  		require.NoError(t, err)
   497  		assert.EqualValues(t, 3, n)
   498  	})
   499  }
   500  
   501  func TestConnQueryScanGoArray(t *testing.T) {
   502  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   503  		m := pgtype.NewMap()
   504  
   505  		var a []int64
   506  		err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
   507  		require.NoError(t, err)
   508  		assert.Equal(t, []int64{1, 2, 3}, a)
   509  	})
   510  }
   511  
   512  func TestConnQueryScanArray(t *testing.T) {
   513  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   514  		m := pgtype.NewMap()
   515  
   516  		var a pgtype.Array[int64]
   517  		err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
   518  		require.NoError(t, err)
   519  		assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a)
   520  
   521  		err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a))
   522  		require.NoError(t, err)
   523  		assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a)
   524  	})
   525  }
   526  
   527  func TestConnQueryScanRange(t *testing.T) {
   528  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   529  		skipCockroachDB(t, db, "Server does not support int4range")
   530  
   531  		m := pgtype.NewMap()
   532  
   533  		var r pgtype.Range[pgtype.Int4]
   534  		err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r))
   535  		require.NoError(t, err)
   536  		assert.Equal(
   537  			t,
   538  			pgtype.Range[pgtype.Int4]{
   539  				Lower:     pgtype.Int4{Int32: 1, Valid: true},
   540  				Upper:     pgtype.Int4{Int32: 5, Valid: true},
   541  				LowerType: pgtype.Inclusive,
   542  				UpperType: pgtype.Exclusive,
   543  				Valid:     true,
   544  			},
   545  			r)
   546  	})
   547  }
   548  
   549  // Test type that pgx would handle natively in binary, but since it is not a
   550  // database/sql native type should be passed through as a string
   551  func TestConnQueryRowPgxBinary(t *testing.T) {
   552  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   553  		sql := "select $1::int4[]"
   554  		expected := "{1,2,3}"
   555  		var actual string
   556  
   557  		err := db.QueryRow(sql, expected).Scan(&actual)
   558  		require.NoError(t, err)
   559  		require.EqualValues(t, expected, actual)
   560  	})
   561  }
   562  
   563  func TestConnQueryRowUnknownType(t *testing.T) {
   564  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   565  		skipCockroachDB(t, db, "Server does not support point type")
   566  
   567  		sql := "select $1::point"
   568  		expected := "(1,2)"
   569  		var actual string
   570  
   571  		err := db.QueryRow(sql, expected).Scan(&actual)
   572  		require.NoError(t, err)
   573  		require.EqualValues(t, expected, actual)
   574  	})
   575  }
   576  
   577  func TestConnQueryJSONIntoByteSlice(t *testing.T) {
   578  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   579  		_, err := db.Exec(`
   580  		create temporary table docs(
   581  			body json not null
   582  		);
   583  
   584  		insert into docs(body) values('{"foo": "bar"}');
   585  `)
   586  		require.NoError(t, err)
   587  
   588  		sql := `select * from docs`
   589  		expected := []byte(`{"foo": "bar"}`)
   590  		var actual []byte
   591  
   592  		err = db.QueryRow(sql).Scan(&actual)
   593  		if err != nil {
   594  			t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
   595  		}
   596  
   597  		if !bytes.Equal(actual, expected) {
   598  			t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
   599  		}
   600  
   601  		_, err = db.Exec(`drop table docs`)
   602  		require.NoError(t, err)
   603  	})
   604  }
   605  
   606  func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
   607  	// Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data
   608  	// that needs to escape. No way to know whether the destination is really a text compatible or a bytea.
   609  
   610  	db := openDB(t)
   611  	defer closeDB(t, db)
   612  
   613  	_, err := db.Exec(`
   614  		create temporary table docs(
   615  			body json not null
   616  		);
   617  `)
   618  	require.NoError(t, err)
   619  
   620  	expected := []byte(`{"foo": "bar"}`)
   621  
   622  	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
   623  	require.NoError(t, err)
   624  
   625  	var actual []byte
   626  	err = db.QueryRow(`select body from docs`).Scan(&actual)
   627  	require.NoError(t, err)
   628  
   629  	if !bytes.Equal(actual, expected) {
   630  		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
   631  	}
   632  
   633  	_, err = db.Exec(`drop table docs`)
   634  	require.NoError(t, err)
   635  }
   636  
   637  func TestTransactionLifeCycle(t *testing.T) {
   638  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   639  		_, err := db.Exec("create temporary table t(a varchar not null)")
   640  		require.NoError(t, err)
   641  
   642  		tx, err := db.Begin()
   643  		require.NoError(t, err)
   644  
   645  		_, err = tx.Exec("insert into t values('hi')")
   646  		require.NoError(t, err)
   647  
   648  		err = tx.Rollback()
   649  		require.NoError(t, err)
   650  
   651  		var n int64
   652  		err = db.QueryRow("select count(*) from t").Scan(&n)
   653  		require.NoError(t, err)
   654  		require.EqualValues(t, 0, n)
   655  
   656  		tx, err = db.Begin()
   657  		require.NoError(t, err)
   658  
   659  		_, err = tx.Exec("insert into t values('hi')")
   660  		require.NoError(t, err)
   661  
   662  		err = tx.Commit()
   663  		require.NoError(t, err)
   664  
   665  		err = db.QueryRow("select count(*) from t").Scan(&n)
   666  		require.NoError(t, err)
   667  		require.EqualValues(t, 1, n)
   668  	})
   669  }
   670  
   671  func TestConnBeginTxIsolation(t *testing.T) {
   672  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   673  		skipCockroachDB(t, db, "Server always uses serializable isolation level")
   674  
   675  		var defaultIsoLevel string
   676  		err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
   677  		require.NoError(t, err)
   678  
   679  		supportedTests := []struct {
   680  			sqlIso sql.IsolationLevel
   681  			pgIso  string
   682  		}{
   683  			{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
   684  			{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
   685  			{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
   686  			{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
   687  			{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
   688  			{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
   689  		}
   690  		for i, tt := range supportedTests {
   691  			func() {
   692  				tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
   693  				if err != nil {
   694  					t.Errorf("%d. BeginTx failed: %v", i, err)
   695  					return
   696  				}
   697  				defer tx.Rollback()
   698  
   699  				var pgIso string
   700  				err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
   701  				if err != nil {
   702  					t.Errorf("%d. QueryRow failed: %v", i, err)
   703  				}
   704  
   705  				if pgIso != tt.pgIso {
   706  					t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
   707  				}
   708  			}()
   709  		}
   710  
   711  		unsupportedTests := []struct {
   712  			sqlIso sql.IsolationLevel
   713  		}{
   714  			{sqlIso: sql.LevelWriteCommitted},
   715  			{sqlIso: sql.LevelLinearizable},
   716  		}
   717  		for i, tt := range unsupportedTests {
   718  			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
   719  			if err == nil {
   720  				t.Errorf("%d. BeginTx should have failed", i)
   721  				tx.Rollback()
   722  			}
   723  		}
   724  	})
   725  }
   726  
   727  func TestConnBeginTxReadOnly(t *testing.T) {
   728  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   729  		tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
   730  		require.NoError(t, err)
   731  		defer tx.Rollback()
   732  
   733  		var pgReadOnly string
   734  		err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
   735  		if err != nil {
   736  			t.Errorf("QueryRow failed: %v", err)
   737  		}
   738  
   739  		if pgReadOnly != "on" {
   740  			t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
   741  		}
   742  	})
   743  }
   744  
   745  func TestBeginTxContextCancel(t *testing.T) {
   746  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   747  		_, err := db.Exec("drop table if exists t")
   748  		require.NoError(t, err)
   749  
   750  		ctx, cancelFn := context.WithCancel(context.Background())
   751  
   752  		tx, err := db.BeginTx(ctx, nil)
   753  		require.NoError(t, err)
   754  
   755  		_, err = tx.Exec("create table t(id serial)")
   756  		require.NoError(t, err)
   757  
   758  		cancelFn()
   759  
   760  		err = tx.Commit()
   761  		if err != context.Canceled && err != sql.ErrTxDone {
   762  			t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
   763  		}
   764  
   765  		var n int
   766  		err = db.QueryRow("select count(*) from t").Scan(&n)
   767  		if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
   768  			t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
   769  		}
   770  	})
   771  }
   772  
   773  func TestConnRaw(t *testing.T) {
   774  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   775  		conn, err := db.Conn(context.Background())
   776  		require.NoError(t, err)
   777  
   778  		var n int
   779  		err = conn.Raw(func(driverConn any) error {
   780  			conn := driverConn.(*stdlib.Conn).Conn()
   781  			return conn.QueryRow(context.Background(), "select 42").Scan(&n)
   782  		})
   783  		require.NoError(t, err)
   784  		assert.EqualValues(t, 42, n)
   785  	})
   786  }
   787  
   788  func TestConnPingContextSuccess(t *testing.T) {
   789  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   790  		err := db.PingContext(context.Background())
   791  		require.NoError(t, err)
   792  	})
   793  }
   794  
   795  func TestConnPrepareContextSuccess(t *testing.T) {
   796  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   797  		stmt, err := db.PrepareContext(context.Background(), "select now()")
   798  		require.NoError(t, err)
   799  		err = stmt.Close()
   800  		require.NoError(t, err)
   801  	})
   802  }
   803  
   804  // https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281
   805  // https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
   806  func TestConnMultiplePrepareAndDeallocate(t *testing.T) {
   807  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   808  		skipCockroachDB(t, db, "Server does not support pg_prepared_statements")
   809  
   810  		sql := "select 42"
   811  		stmt1, err := db.PrepareContext(context.Background(), sql)
   812  		require.NoError(t, err)
   813  		stmt2, err := db.PrepareContext(context.Background(), sql)
   814  		require.NoError(t, err)
   815  		err = stmt1.Close()
   816  		require.NoError(t, err)
   817  
   818  		var preparedStmtCount int64
   819  		err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
   820  		require.NoError(t, err)
   821  		require.EqualValues(t, 1, preparedStmtCount)
   822  
   823  		err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate.
   824  		require.NoError(t, err)
   825  
   826  		err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
   827  		require.NoError(t, err)
   828  		require.EqualValues(t, 0, preparedStmtCount)
   829  	})
   830  }
   831  
   832  func TestConnExecContextSuccess(t *testing.T) {
   833  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   834  		_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
   835  		require.NoError(t, err)
   836  	})
   837  }
   838  
   839  func TestConnQueryContextSuccess(t *testing.T) {
   840  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   841  		rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
   842  		require.NoError(t, err)
   843  
   844  		for rows.Next() {
   845  			var n int64
   846  			err := rows.Scan(&n)
   847  			require.NoError(t, err)
   848  		}
   849  		require.NoError(t, rows.Err())
   850  	})
   851  }
   852  
   853  func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
   854  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   855  		rows, err := db.Query("select 42::bigint")
   856  		require.NoError(t, err)
   857  
   858  		columnTypes, err := rows.ColumnTypes()
   859  		require.NoError(t, err)
   860  		require.Len(t, columnTypes, 1)
   861  
   862  		if columnTypes[0].DatabaseTypeName() != "INT8" {
   863  			t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
   864  		}
   865  
   866  		err = rows.Close()
   867  		require.NoError(t, err)
   868  	})
   869  }
   870  
   871  func TestStmtExecContextSuccess(t *testing.T) {
   872  	db := openDB(t)
   873  	defer closeDB(t, db)
   874  
   875  	_, err := db.Exec("create temporary table t(id int primary key)")
   876  	require.NoError(t, err)
   877  
   878  	stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
   879  	require.NoError(t, err)
   880  	defer stmt.Close()
   881  
   882  	_, err = stmt.ExecContext(context.Background(), 42)
   883  	require.NoError(t, err)
   884  
   885  	ensureDBValid(t, db)
   886  }
   887  
   888  func TestStmtExecContextCancel(t *testing.T) {
   889  	db := openDB(t)
   890  	defer closeDB(t, db)
   891  
   892  	_, err := db.Exec("create temporary table t(id int primary key)")
   893  	require.NoError(t, err)
   894  
   895  	stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
   896  	require.NoError(t, err)
   897  	defer stmt.Close()
   898  
   899  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   900  	defer cancel()
   901  
   902  	_, err = stmt.ExecContext(ctx, 42)
   903  	if !pgconn.Timeout(err) {
   904  		t.Errorf("expected timeout error, got %v", err)
   905  	}
   906  
   907  	ensureDBValid(t, db)
   908  }
   909  
   910  func TestStmtQueryContextSuccess(t *testing.T) {
   911  	db := openDB(t)
   912  	defer closeDB(t, db)
   913  
   914  	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   915  
   916  	stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
   917  	require.NoError(t, err)
   918  	defer stmt.Close()
   919  
   920  	rows, err := stmt.QueryContext(context.Background(), 5)
   921  	require.NoError(t, err)
   922  
   923  	for rows.Next() {
   924  		var n int64
   925  		if err := rows.Scan(&n); err != nil {
   926  			t.Error(err)
   927  		}
   928  	}
   929  
   930  	if rows.Err() != nil {
   931  		t.Error(rows.Err())
   932  	}
   933  
   934  	ensureDBValid(t, db)
   935  }
   936  
   937  func TestRowsColumnTypes(t *testing.T) {
   938  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
   939  		columnTypesTests := []struct {
   940  			Name     string
   941  			TypeName string
   942  			Length   struct {
   943  				Len int64
   944  				OK  bool
   945  			}
   946  			DecimalSize struct {
   947  				Precision int64
   948  				Scale     int64
   949  				OK        bool
   950  			}
   951  			ScanType reflect.Type
   952  		}{
   953  			{
   954  				Name:     "a",
   955  				TypeName: "INT8",
   956  				Length: struct {
   957  					Len int64
   958  					OK  bool
   959  				}{
   960  					Len: 0,
   961  					OK:  false,
   962  				},
   963  				DecimalSize: struct {
   964  					Precision int64
   965  					Scale     int64
   966  					OK        bool
   967  				}{
   968  					Precision: 0,
   969  					Scale:     0,
   970  					OK:        false,
   971  				},
   972  				ScanType: reflect.TypeOf(int64(0)),
   973  			}, {
   974  				Name:     "bar",
   975  				TypeName: "TEXT",
   976  				Length: struct {
   977  					Len int64
   978  					OK  bool
   979  				}{
   980  					Len: math.MaxInt64,
   981  					OK:  true,
   982  				},
   983  				DecimalSize: struct {
   984  					Precision int64
   985  					Scale     int64
   986  					OK        bool
   987  				}{
   988  					Precision: 0,
   989  					Scale:     0,
   990  					OK:        false,
   991  				},
   992  				ScanType: reflect.TypeOf(""),
   993  			}, {
   994  				Name:     "dec",
   995  				TypeName: "NUMERIC",
   996  				Length: struct {
   997  					Len int64
   998  					OK  bool
   999  				}{
  1000  					Len: 0,
  1001  					OK:  false,
  1002  				},
  1003  				DecimalSize: struct {
  1004  					Precision int64
  1005  					Scale     int64
  1006  					OK        bool
  1007  				}{
  1008  					Precision: 9,
  1009  					Scale:     2,
  1010  					OK:        true,
  1011  				},
  1012  				ScanType: reflect.TypeOf(float64(0)),
  1013  			}, {
  1014  				Name:     "d",
  1015  				TypeName: "1266",
  1016  				Length: struct {
  1017  					Len int64
  1018  					OK  bool
  1019  				}{
  1020  					Len: 0,
  1021  					OK:  false,
  1022  				},
  1023  				DecimalSize: struct {
  1024  					Precision int64
  1025  					Scale     int64
  1026  					OK        bool
  1027  				}{
  1028  					Precision: 0,
  1029  					Scale:     0,
  1030  					OK:        false,
  1031  				},
  1032  				ScanType: reflect.TypeOf(""),
  1033  			},
  1034  		}
  1035  
  1036  		rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
  1037  		require.NoError(t, err)
  1038  
  1039  		columns, err := rows.ColumnTypes()
  1040  		require.NoError(t, err)
  1041  		assert.Len(t, columns, 4)
  1042  
  1043  		for i, tt := range columnTypesTests {
  1044  			c := columns[i]
  1045  			if c.Name() != tt.Name {
  1046  				t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
  1047  			}
  1048  			if c.DatabaseTypeName() != tt.TypeName {
  1049  				t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
  1050  			}
  1051  			l, ok := c.Length()
  1052  			if l != tt.Length.Len {
  1053  				t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
  1054  			}
  1055  			if ok != tt.Length.OK {
  1056  				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
  1057  			}
  1058  			p, s, ok := c.DecimalSize()
  1059  			if p != tt.DecimalSize.Precision {
  1060  				t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
  1061  			}
  1062  			if s != tt.DecimalSize.Scale {
  1063  				t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
  1064  			}
  1065  			if ok != tt.DecimalSize.OK {
  1066  				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
  1067  			}
  1068  			if c.ScanType() != tt.ScanType {
  1069  				t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
  1070  			}
  1071  		}
  1072  	})
  1073  }
  1074  
  1075  func TestQueryLifeCycle(t *testing.T) {
  1076  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
  1077  		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
  1078  
  1079  		rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
  1080  		require.NoError(t, err)
  1081  
  1082  		rowCount := int64(0)
  1083  
  1084  		for rows.Next() {
  1085  			rowCount++
  1086  			var (
  1087  				s string
  1088  				n int64
  1089  			)
  1090  
  1091  			err := rows.Scan(&s, &n)
  1092  			require.NoError(t, err)
  1093  
  1094  			if s != "foo" {
  1095  				t.Errorf(`Expected "foo", received "%v"`, s)
  1096  			}
  1097  
  1098  			if n != rowCount {
  1099  				t.Errorf("Expected %d, received %d", rowCount, n)
  1100  			}
  1101  		}
  1102  		require.NoError(t, rows.Err())
  1103  
  1104  		err = rows.Close()
  1105  		require.NoError(t, err)
  1106  
  1107  		rows, err = db.Query("select 1 where false")
  1108  		require.NoError(t, err)
  1109  
  1110  		rowCount = int64(0)
  1111  
  1112  		for rows.Next() {
  1113  			rowCount++
  1114  		}
  1115  		require.NoError(t, rows.Err())
  1116  		require.EqualValues(t, 0, rowCount)
  1117  
  1118  		err = rows.Close()
  1119  		require.NoError(t, err)
  1120  	})
  1121  }
  1122  
  1123  // https://github.com/jackc/pgx/issues/409
  1124  func TestScanJSONIntoJSONRawMessage(t *testing.T) {
  1125  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
  1126  		var msg json.RawMessage
  1127  
  1128  		err := db.QueryRow("select '{}'::json").Scan(&msg)
  1129  		require.NoError(t, err)
  1130  		require.EqualValues(t, []byte("{}"), []byte(msg))
  1131  	})
  1132  }
  1133  
  1134  type testLog struct {
  1135  	lvl  tracelog.LogLevel
  1136  	msg  string
  1137  	data map[string]any
  1138  }
  1139  
  1140  type testLogger struct {
  1141  	logs []testLog
  1142  }
  1143  
  1144  func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
  1145  	l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
  1146  }
  1147  
  1148  func TestRegisterConnConfig(t *testing.T) {
  1149  	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
  1150  	require.NoError(t, err)
  1151  
  1152  	logger := &testLogger{}
  1153  	connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}
  1154  
  1155  	// Issue 947: Register and unregister a ConnConfig and ensure that the
  1156  	// returned connection string is not reused.
  1157  	connStr := stdlib.RegisterConnConfig(connConfig)
  1158  	require.Equal(t, "registeredConnConfig0", connStr)
  1159  	stdlib.UnregisterConnConfig(connStr)
  1160  
  1161  	connStr = stdlib.RegisterConnConfig(connConfig)
  1162  	defer stdlib.UnregisterConnConfig(connStr)
  1163  	require.Equal(t, "registeredConnConfig1", connStr)
  1164  
  1165  	db, err := sql.Open("pgx", connStr)
  1166  	require.NoError(t, err)
  1167  	defer closeDB(t, db)
  1168  
  1169  	var n int64
  1170  	err = db.QueryRow("select 1").Scan(&n)
  1171  	require.NoError(t, err)
  1172  
  1173  	l := logger.logs[len(logger.logs)-1]
  1174  	assert.Equal(t, "Query", l.msg)
  1175  	assert.Equal(t, "select 1", l.data["sql"])
  1176  }
  1177  
  1178  // https://github.com/jackc/pgx/issues/958
  1179  func TestConnQueryRowConstraintErrors(t *testing.T) {
  1180  	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
  1181  		skipPostgreSQLVersionLessThan(t, db, 11)
  1182  		skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
  1183  
  1184  		_, err := db.Exec(`create temporary table defer_test (
  1185  			id text primary key,
  1186  			n int not null, unique (n),
  1187  			unique (n) deferrable initially deferred )`)
  1188  		require.NoError(t, err)
  1189  
  1190  		_, err = db.Exec(`drop function if exists test_trigger cascade`)
  1191  		require.NoError(t, err)
  1192  
  1193  		_, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$
  1194  		begin
  1195  		if new.n = 4 then
  1196  			raise exception 'n cant be 4!';
  1197  		end if;
  1198  		return new;
  1199  	end$$`)
  1200  		require.NoError(t, err)
  1201  
  1202  		_, err = db.Exec(`create constraint trigger test
  1203  			after insert or update on defer_test
  1204  			deferrable initially deferred
  1205  			for each row
  1206  			execute function test_trigger()`)
  1207  		require.NoError(t, err)
  1208  
  1209  		_, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
  1210  		require.NoError(t, err)
  1211  
  1212  		var id string
  1213  		err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
  1214  		assert.Error(t, err)
  1215  	})
  1216  }
  1217  
  1218  func TestOptionBeforeAfterConnect(t *testing.T) {
  1219  	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
  1220  	require.NoError(t, err)
  1221  
  1222  	var beforeConnConfigs []*pgx.ConnConfig
  1223  	var afterConns []*pgx.Conn
  1224  	db := stdlib.OpenDB(*config,
  1225  		stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
  1226  			beforeConnConfigs = append(beforeConnConfigs, connConfig)
  1227  			return nil
  1228  		}),
  1229  		stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
  1230  			afterConns = append(afterConns, conn)
  1231  			return nil
  1232  		}))
  1233  	defer closeDB(t, db)
  1234  
  1235  	// Force it to close and reopen a new connection after each query
  1236  	db.SetMaxIdleConns(0)
  1237  
  1238  	_, err = db.Exec("select 1")
  1239  	require.NoError(t, err)
  1240  
  1241  	_, err = db.Exec("select 1")
  1242  	require.NoError(t, err)
  1243  
  1244  	require.Len(t, beforeConnConfigs, 2)
  1245  	require.Len(t, afterConns, 2)
  1246  
  1247  	// Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we wean to ensure they
  1248  	// are different objects, so can't use require.NotEqual
  1249  	require.False(t, config == beforeConnConfigs[0])
  1250  	require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
  1251  }
  1252  
  1253  func TestRandomizeHostOrderFunc(t *testing.T) {
  1254  	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
  1255  	require.NoError(t, err)
  1256  
  1257  	// Test that at some point we connect to all 3 hosts
  1258  	hostsNotSeenYet := map[string]struct{}{
  1259  		"host1": {},
  1260  		"host2": {},
  1261  		"host3": {},
  1262  	}
  1263  
  1264  	// If we don't succeed within this many iterations, something is certainly wrong
  1265  	for i := 0; i < 100000; i++ {
  1266  		connCopy := *config
  1267  		stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)
  1268  
  1269  		delete(hostsNotSeenYet, connCopy.Host)
  1270  		if len(hostsNotSeenYet) == 0 {
  1271  			return
  1272  		}
  1273  
  1274  	hostCheckLoop:
  1275  		for _, h := range []string{"host1", "host2", "host3"} {
  1276  			if connCopy.Host == h {
  1277  				continue
  1278  			}
  1279  			for _, f := range connCopy.Fallbacks {
  1280  				if f.Host == h {
  1281  					continue hostCheckLoop
  1282  				}
  1283  			}
  1284  			require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
  1285  		}
  1286  	}
  1287  
  1288  	require.Fail(t, "did not get all hosts as primaries after many randomizations")
  1289  }
  1290  
  1291  func TestResetSessionHookCalled(t *testing.T) {
  1292  	var mockCalled bool
  1293  
  1294  	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
  1295  	require.NoError(t, err)
  1296  
  1297  	db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
  1298  		mockCalled = true
  1299  
  1300  		return nil
  1301  	}))
  1302  
  1303  	defer closeDB(t, db)
  1304  
  1305  	err = db.Ping()
  1306  	require.NoError(t, err)
  1307  
  1308  	err = db.Ping()
  1309  	require.NoError(t, err)
  1310  
  1311  	require.True(t, mockCalled)
  1312  }
  1313  
  1314  func TestCheckIdleConn(t *testing.T) {
  1315  	controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
  1316  	require.NoError(t, err)
  1317  	defer closeDB(t, controllerConn)
  1318  
  1319  	skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
  1320  
  1321  	db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
  1322  	require.NoError(t, err)
  1323  	defer closeDB(t, db)
  1324  
  1325  	var conns []*sql.Conn
  1326  	for i := 0; i < 3; i++ {
  1327  		c, err := db.Conn(context.Background())
  1328  		require.NoError(t, err)
  1329  		conns = append(conns, c)
  1330  	}
  1331  
  1332  	require.EqualValues(t, 3, db.Stats().OpenConnections)
  1333  
  1334  	var pids []uint32
  1335  	for _, c := range conns {
  1336  		err := c.Raw(func(driverConn any) error {
  1337  			pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID())
  1338  			return nil
  1339  		})
  1340  		require.NoError(t, err)
  1341  		err = c.Close()
  1342  		require.NoError(t, err)
  1343  	}
  1344  
  1345  	// The database/sql connection pool seems to automatically close idle connections to only keep 2 alive.
  1346  	// require.EqualValues(t, 3, db.Stats().OpenConnections)
  1347  
  1348  	_, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids)
  1349  	require.NoError(t, err)
  1350  
  1351  	// All conns are dead they don't know it and neither does the pool. But because of database/sql automatically closing
  1352  	// idle connections we can't be sure how many we should have. require.EqualValues(t, 3, db.Stats().OpenConnections)
  1353  
  1354  	// Wait long enough so the pool will realize it needs to check the connections.
  1355  	time.Sleep(time.Second)
  1356  
  1357  	// Pool should try all existing connections and find them dead, then create a new connection which should successfully ping.
  1358  	err = db.PingContext(context.Background())
  1359  	require.NoError(t, err)
  1360  
  1361  	// The original 3 conns should have been terminated and the a new conn established for the ping.
  1362  	require.EqualValues(t, 1, db.Stats().OpenConnections)
  1363  	c, err := db.Conn(context.Background())
  1364  	require.NoError(t, err)
  1365  
  1366  	var cPID uint32
  1367  	err = c.Raw(func(driverConn any) error {
  1368  		cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
  1369  		return nil
  1370  	})
  1371  	require.NoError(t, err)
  1372  	err = c.Close()
  1373  	require.NoError(t, err)
  1374  
  1375  	require.NotContains(t, pids, cPID)
  1376  }
  1377  

View as plain text