...

Source file src/github.com/jackc/pgx/v4/query_test.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"errors"
     8  	"fmt"
     9  	"os"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/cockroachdb/apd"
    17  	"github.com/gofrs/uuid"
    18  	"github.com/jackc/pgconn"
    19  	"github.com/jackc/pgconn/stmtcache"
    20  	"github.com/jackc/pgtype"
    21  	"github.com/jackc/pgx/v4"
    22  	"github.com/shopspring/decimal"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  )
    26  
    27  func TestConnQueryScan(t *testing.T) {
    28  	t.Parallel()
    29  
    30  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    31  	defer closeConn(t, conn)
    32  
    33  	var sum, rowCount int32
    34  
    35  	rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
    36  	if err != nil {
    37  		t.Fatalf("conn.Query failed: %v", err)
    38  	}
    39  	defer rows.Close()
    40  
    41  	for rows.Next() {
    42  		var n int32
    43  		rows.Scan(&n)
    44  		sum += n
    45  		rowCount++
    46  	}
    47  
    48  	if rows.Err() != nil {
    49  		t.Fatalf("conn.Query failed: %v", err)
    50  	}
    51  
    52  	assert.Equal(t, "SELECT 10", string(rows.CommandTag()))
    53  
    54  	if rowCount != 10 {
    55  		t.Error("Select called onDataRow wrong number of times")
    56  	}
    57  	if sum != 55 {
    58  		t.Error("Wrong values returned")
    59  	}
    60  }
    61  
    62  func TestConnQueryRowsFieldDescriptionsBeforeNext(t *testing.T) {
    63  	t.Parallel()
    64  
    65  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    66  	defer closeConn(t, conn)
    67  
    68  	rows, err := conn.Query(context.Background(), "select 'hello' as msg")
    69  	require.NoError(t, err)
    70  	defer rows.Close()
    71  
    72  	require.Len(t, rows.FieldDescriptions(), 1)
    73  	assert.Equal(t, []byte("msg"), rows.FieldDescriptions()[0].Name)
    74  }
    75  
    76  func TestConnQueryWithoutResultSetCommandTag(t *testing.T) {
    77  	t.Parallel()
    78  
    79  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    80  	defer closeConn(t, conn)
    81  
    82  	rows, err := conn.Query(context.Background(), "create temporary table t (id serial);")
    83  	assert.NoError(t, err)
    84  	rows.Close()
    85  	assert.NoError(t, rows.Err())
    86  	assert.Equal(t, "CREATE TABLE", string(rows.CommandTag()))
    87  }
    88  
    89  func TestConnQueryScanWithManyColumns(t *testing.T) {
    90  	t.Parallel()
    91  
    92  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    93  	defer closeConn(t, conn)
    94  
    95  	columnCount := 1000
    96  	sql := "select "
    97  	for i := 0; i < columnCount; i++ {
    98  		if i > 0 {
    99  			sql += ","
   100  		}
   101  		sql += fmt.Sprintf(" %d", i)
   102  	}
   103  	sql += " from generate_series(1,5)"
   104  
   105  	dest := make([]int, columnCount)
   106  
   107  	var rowCount int
   108  
   109  	rows, err := conn.Query(context.Background(), sql)
   110  	if err != nil {
   111  		t.Fatalf("conn.Query failed: %v", err)
   112  	}
   113  	defer rows.Close()
   114  
   115  	for rows.Next() {
   116  		destPtrs := make([]interface{}, columnCount)
   117  		for i := range destPtrs {
   118  			destPtrs[i] = &dest[i]
   119  		}
   120  		if err := rows.Scan(destPtrs...); err != nil {
   121  			t.Fatalf("rows.Scan failed: %v", err)
   122  		}
   123  		rowCount++
   124  
   125  		for i := range dest {
   126  			if dest[i] != i {
   127  				t.Errorf("dest[%d] => %d, want %d", i, dest[i], i)
   128  			}
   129  		}
   130  	}
   131  
   132  	if rows.Err() != nil {
   133  		t.Fatalf("conn.Query failed: %v", err)
   134  	}
   135  
   136  	if rowCount != 5 {
   137  		t.Errorf("rowCount => %d, want %d", rowCount, 5)
   138  	}
   139  }
   140  
   141  func TestConnQueryValues(t *testing.T) {
   142  	t.Parallel()
   143  
   144  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   145  	defer closeConn(t, conn)
   146  
   147  	var rowCount int32
   148  
   149  	rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10)
   150  	if err != nil {
   151  		t.Fatalf("conn.Query failed: %v", err)
   152  	}
   153  	defer rows.Close()
   154  
   155  	for rows.Next() {
   156  		rowCount++
   157  
   158  		values, err := rows.Values()
   159  		require.NoError(t, err)
   160  		require.Len(t, values, 5)
   161  		assert.Equal(t, "foo", values[0])
   162  		assert.Equal(t, "bar", values[1])
   163  		assert.EqualValues(t, rowCount, values[2])
   164  		assert.Nil(t, values[3])
   165  		assert.EqualValues(t, rowCount, values[4])
   166  	}
   167  
   168  	if rows.Err() != nil {
   169  		t.Fatalf("conn.Query failed: %v", err)
   170  	}
   171  
   172  	if rowCount != 10 {
   173  		t.Error("Select called onDataRow wrong number of times")
   174  	}
   175  }
   176  
   177  // https://github.com/jackc/pgx/issues/666
   178  func TestConnQueryValuesWhenUnableToDecode(t *testing.T) {
   179  	t.Parallel()
   180  
   181  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   182  	defer closeConn(t, conn)
   183  
   184  	// Note that this relies on pgtype.Record not supporting the text protocol. This seems safe as it is impossible to
   185  	// decode the text protocol because unlike the binary protocol there is no way to determine the OIDs of the elements.
   186  	rows, err := conn.Query(context.Background(), "select (array[1::oid], null)", pgx.QueryResultFormats{pgx.TextFormatCode})
   187  	require.NoError(t, err)
   188  	defer rows.Close()
   189  
   190  	require.True(t, rows.Next())
   191  
   192  	values, err := rows.Values()
   193  	require.NoError(t, err)
   194  	require.Equal(t, "({1},)", values[0])
   195  }
   196  
   197  func TestConnQueryValuesWithUnknownOID(t *testing.T) {
   198  	t.Parallel()
   199  
   200  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   201  	defer cancel()
   202  
   203  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   204  	defer closeConn(t, conn)
   205  
   206  	tx, err := conn.Begin(ctx)
   207  	require.NoError(t, err)
   208  	defer tx.Rollback(ctx)
   209  
   210  	_, err = tx.Exec(ctx, "create type fruit as enum('orange', 'apple', 'pear')")
   211  	require.NoError(t, err)
   212  
   213  	rows, err := conn.Query(context.Background(), "select 'orange'::fruit")
   214  	require.NoError(t, err)
   215  	defer rows.Close()
   216  
   217  	require.True(t, rows.Next())
   218  
   219  	values, err := rows.Values()
   220  	require.NoError(t, err)
   221  	require.Equal(t, "orange", values[0])
   222  }
   223  
   224  // https://github.com/jackc/pgx/issues/478
   225  func TestConnQueryReadRowMultipleTimes(t *testing.T) {
   226  	t.Parallel()
   227  
   228  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   229  	defer closeConn(t, conn)
   230  
   231  	var rowCount int32
   232  
   233  	rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10)
   234  	require.NoError(t, err)
   235  	defer rows.Close()
   236  
   237  	for rows.Next() {
   238  		rowCount++
   239  
   240  		for i := 0; i < 2; i++ {
   241  			values, err := rows.Values()
   242  			require.NoError(t, err)
   243  			require.Len(t, values, 5)
   244  			require.Equal(t, "foo", values[0])
   245  			require.Equal(t, "bar", values[1])
   246  			require.EqualValues(t, rowCount, values[2])
   247  			require.Nil(t, values[3])
   248  			require.EqualValues(t, rowCount, values[4])
   249  
   250  			var a, b string
   251  			var c int32
   252  			var d pgtype.Unknown
   253  			var e int32
   254  
   255  			err = rows.Scan(&a, &b, &c, &d, &e)
   256  			require.NoError(t, err)
   257  			require.Equal(t, "foo", a)
   258  			require.Equal(t, "bar", b)
   259  			require.Equal(t, rowCount, c)
   260  			require.Equal(t, pgtype.Null, d.Status)
   261  			require.Equal(t, rowCount, e)
   262  		}
   263  	}
   264  
   265  	require.NoError(t, rows.Err())
   266  	require.Equal(t, int32(10), rowCount)
   267  }
   268  
   269  // https://github.com/jackc/pgx/issues/386
   270  func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) {
   271  	t.Parallel()
   272  
   273  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   274  	defer closeConn(t, conn)
   275  
   276  	expected0 := &pgtype.Int8Array{
   277  		Elements: []pgtype.Int8{
   278  			{Int: 1, Status: pgtype.Present},
   279  			{Int: 2, Status: pgtype.Present},
   280  			{Int: 3, Status: pgtype.Present},
   281  		},
   282  		Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}},
   283  		Status:     pgtype.Present,
   284  	}
   285  
   286  	expected1 := &pgtype.Int8Array{
   287  		Elements: []pgtype.Int8{
   288  			{Int: 4, Status: pgtype.Present},
   289  			{Int: 5, Status: pgtype.Present},
   290  			{Int: 6, Status: pgtype.Present},
   291  		},
   292  		Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}},
   293  		Status:     pgtype.Present,
   294  	}
   295  
   296  	var rowCount int32
   297  
   298  	rows, err := conn.Query(context.Background(), "select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10)
   299  	if err != nil {
   300  		t.Fatalf("conn.Query failed: %v", err)
   301  	}
   302  	defer rows.Close()
   303  
   304  	for rows.Next() {
   305  		rowCount++
   306  
   307  		values, err := rows.Values()
   308  		if err != nil {
   309  			t.Fatalf("rows.Values failed: %v", err)
   310  		}
   311  		if len(values) != 2 {
   312  			t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values))
   313  		}
   314  		if !reflect.DeepEqual(values[0], *expected0) {
   315  			t.Errorf(`Expected values[0] to be %v, but it was %v`, *expected0, values[0])
   316  		}
   317  		if !reflect.DeepEqual(values[1], *expected1) {
   318  			t.Errorf(`Expected values[1] to be %v, but it was %v`, *expected1, values[1])
   319  		}
   320  	}
   321  
   322  	if rows.Err() != nil {
   323  		t.Fatalf("conn.Query failed: %v", err)
   324  	}
   325  
   326  	if rowCount != 10 {
   327  		t.Error("Select called onDataRow wrong number of times")
   328  	}
   329  }
   330  
   331  // https://github.com/jackc/pgx/issues/228
   332  func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) {
   333  	t.Parallel()
   334  
   335  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   336  	defer closeConn(t, conn)
   337  
   338  	var s string
   339  
   340  	err := conn.QueryRow(context.Background(), "select 1").Scan(&s)
   341  	if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) {
   342  		t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err)
   343  	}
   344  
   345  	ensureConnValid(t, conn)
   346  }
   347  
   348  func TestConnQueryRawValues(t *testing.T) {
   349  	t.Parallel()
   350  
   351  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   352  	defer closeConn(t, conn)
   353  
   354  	var rowCount int32
   355  
   356  	rows, err := conn.Query(
   357  		context.Background(),
   358  		"select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n",
   359  		pgx.QuerySimpleProtocol(true),
   360  		10,
   361  	)
   362  	require.NoError(t, err)
   363  	defer rows.Close()
   364  
   365  	for rows.Next() {
   366  		rowCount++
   367  
   368  		rawValues := rows.RawValues()
   369  		assert.Len(t, rawValues, 5)
   370  		assert.Equal(t, "foo", string(rawValues[0]))
   371  		assert.Equal(t, "bar", string(rawValues[1]))
   372  		assert.Equal(t, strconv.FormatInt(int64(rowCount), 10), string(rawValues[2]))
   373  		assert.Nil(t, rawValues[3])
   374  		assert.Equal(t, strconv.FormatInt(int64(rowCount), 10), string(rawValues[4]))
   375  	}
   376  
   377  	require.NoError(t, rows.Err())
   378  	assert.EqualValues(t, 10, rowCount)
   379  }
   380  
   381  // Test that a connection stays valid when query results are closed early
   382  func TestConnQueryCloseEarly(t *testing.T) {
   383  	t.Parallel()
   384  
   385  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   386  	defer closeConn(t, conn)
   387  
   388  	// Immediately close query without reading any rows
   389  	rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
   390  	if err != nil {
   391  		t.Fatalf("conn.Query failed: %v", err)
   392  	}
   393  	rows.Close()
   394  
   395  	ensureConnValid(t, conn)
   396  
   397  	// Read partial response then close
   398  	rows, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10)
   399  	if err != nil {
   400  		t.Fatalf("conn.Query failed: %v", err)
   401  	}
   402  
   403  	ok := rows.Next()
   404  	if !ok {
   405  		t.Fatal("rows.Next terminated early")
   406  	}
   407  
   408  	var n int32
   409  	rows.Scan(&n)
   410  	if n != 1 {
   411  		t.Fatalf("Expected 1 from first row, but got %v", n)
   412  	}
   413  
   414  	rows.Close()
   415  
   416  	ensureConnValid(t, conn)
   417  }
   418  
   419  func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) {
   420  	t.Parallel()
   421  
   422  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   423  	defer closeConn(t, conn)
   424  
   425  	rows, err := conn.Query(context.Background(), "select 1/(10-n) from generate_series(1,10) n")
   426  	if err != nil {
   427  		t.Fatalf("conn.Query failed: %v", err)
   428  	}
   429  	assert.False(t, pgconn.SafeToRetry(err))
   430  	rows.Close()
   431  
   432  	ensureConnValid(t, conn)
   433  }
   434  
   435  // Test that a connection stays valid when query results read incorrectly
   436  func TestConnQueryReadWrongTypeError(t *testing.T) {
   437  	t.Parallel()
   438  
   439  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   440  	defer closeConn(t, conn)
   441  
   442  	// Read a single value incorrectly
   443  	rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
   444  	if err != nil {
   445  		t.Fatalf("conn.Query failed: %v", err)
   446  	}
   447  
   448  	rowsRead := 0
   449  
   450  	for rows.Next() {
   451  		var t time.Time
   452  		rows.Scan(&t)
   453  		rowsRead++
   454  	}
   455  
   456  	if rowsRead != 1 {
   457  		t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
   458  	}
   459  
   460  	if rows.Err() == nil {
   461  		t.Fatal("Expected Rows to have an error after an improper read but it didn't")
   462  	}
   463  
   464  	if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") {
   465  		t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
   466  	}
   467  
   468  	ensureConnValid(t, conn)
   469  }
   470  
   471  // Test that a connection stays valid when query results read incorrectly
   472  func TestConnQueryReadTooManyValues(t *testing.T) {
   473  	t.Parallel()
   474  
   475  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   476  	defer closeConn(t, conn)
   477  
   478  	// Read too many values
   479  	rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
   480  	if err != nil {
   481  		t.Fatalf("conn.Query failed: %v", err)
   482  	}
   483  
   484  	rowsRead := 0
   485  
   486  	for rows.Next() {
   487  		var n, m int32
   488  		rows.Scan(&n, &m)
   489  		rowsRead++
   490  	}
   491  
   492  	if rowsRead != 1 {
   493  		t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead)
   494  	}
   495  
   496  	if rows.Err() == nil {
   497  		t.Fatal("Expected Rows to have an error after an improper read but it didn't")
   498  	}
   499  
   500  	ensureConnValid(t, conn)
   501  }
   502  
   503  func TestConnQueryScanIgnoreColumn(t *testing.T) {
   504  	t.Parallel()
   505  
   506  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   507  	defer closeConn(t, conn)
   508  
   509  	rows, err := conn.Query(context.Background(), "select 1::int8, 2::int8, 3::int8")
   510  	if err != nil {
   511  		t.Fatalf("conn.Query failed: %v", err)
   512  	}
   513  
   514  	ok := rows.Next()
   515  	if !ok {
   516  		t.Fatal("rows.Next terminated early")
   517  	}
   518  
   519  	var n, m int64
   520  	err = rows.Scan(&n, nil, &m)
   521  	if err != nil {
   522  		t.Fatalf("rows.Scan failed: %v", err)
   523  	}
   524  	rows.Close()
   525  
   526  	if n != 1 {
   527  		t.Errorf("Expected n to equal 1, but it was %d", n)
   528  	}
   529  
   530  	if m != 3 {
   531  		t.Errorf("Expected n to equal 3, but it was %d", m)
   532  	}
   533  
   534  	ensureConnValid(t, conn)
   535  }
   536  
   537  // https://github.com/jackc/pgx/issues/570
   538  func TestConnQueryDeferredError(t *testing.T) {
   539  	t.Parallel()
   540  
   541  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   542  	defer closeConn(t, conn)
   543  
   544  	skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   545  
   546  	mustExec(t, conn, `create temporary table t (
   547  	id text primary key,
   548  	n int not null,
   549  	unique (n) deferrable initially deferred
   550  );
   551  
   552  insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
   553  
   554  	rows, err := conn.Query(context.Background(), `update t set n=n+1 where id='b' returning *`)
   555  	if err != nil {
   556  		t.Fatal(err)
   557  	}
   558  	defer rows.Close()
   559  
   560  	for rows.Next() {
   561  		var id string
   562  		var n int32
   563  		err = rows.Scan(&id, &n)
   564  		if err != nil {
   565  			t.Fatal(err)
   566  		}
   567  	}
   568  
   569  	if rows.Err() == nil {
   570  		t.Fatal("expected error 23505 but got none")
   571  	}
   572  
   573  	if err, ok := rows.Err().(*pgconn.PgError); !ok || err.Code != "23505" {
   574  		t.Fatalf("expected error 23505, got %v", err)
   575  	}
   576  
   577  	ensureConnValid(t, conn)
   578  }
   579  
   580  func TestConnQueryErrorWhileReturningRows(t *testing.T) {
   581  	t.Parallel()
   582  
   583  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   584  	defer closeConn(t, conn)
   585  
   586  	skipCockroachDB(t, conn, "Server uses numeric instead of int")
   587  
   588  	for i := 0; i < 100; i++ {
   589  		func() {
   590  			sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)`
   591  
   592  			rows, err := conn.Query(context.Background(), sql)
   593  			if err != nil {
   594  				t.Fatal(err)
   595  			}
   596  			defer rows.Close()
   597  
   598  			for rows.Next() {
   599  				var n int32
   600  				if err := rows.Scan(&n); err != nil {
   601  					t.Fatalf("Row scan failed: %v", err)
   602  				}
   603  			}
   604  
   605  			if _, ok := rows.Err().(*pgconn.PgError); !ok {
   606  				t.Fatalf("Expected pgx.PgError, got %v", rows.Err())
   607  			}
   608  
   609  			ensureConnValid(t, conn)
   610  		}()
   611  	}
   612  
   613  }
   614  
   615  func TestQueryEncodeError(t *testing.T) {
   616  	t.Parallel()
   617  
   618  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   619  	defer closeConn(t, conn)
   620  
   621  	rows, err := conn.Query(context.Background(), "select $1::integer", "wrong")
   622  	if err != nil {
   623  		t.Errorf("conn.Query failure: %v", err)
   624  	}
   625  	assert.False(t, pgconn.SafeToRetry(err))
   626  	defer rows.Close()
   627  
   628  	rows.Next()
   629  
   630  	if rows.Err() == nil {
   631  		t.Error("Expected rows.Err() to return error, but it didn't")
   632  	}
   633  	if conn.PgConn().ParameterStatus("crdb_version") != "" {
   634  		if !strings.Contains(rows.Err().Error(), "SQLSTATE 08P01") {
   635  			// CockroachDB returns protocol_violation instead of invalid_text_representation
   636  			t.Error("Expected rows.Err() to return different error:", rows.Err())
   637  		}
   638  	} else {
   639  		if !strings.Contains(rows.Err().Error(), "SQLSTATE 22P02") {
   640  			t.Error("Expected rows.Err() to return different error:", rows.Err())
   641  		}
   642  	}
   643  }
   644  
   645  func TestQueryRowCoreTypes(t *testing.T) {
   646  	t.Parallel()
   647  
   648  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   649  	defer closeConn(t, conn)
   650  
   651  	type allTypes struct {
   652  		s   string
   653  		f32 float32
   654  		f64 float64
   655  		b   bool
   656  		t   time.Time
   657  		oid uint32
   658  	}
   659  
   660  	var actual, zero allTypes
   661  
   662  	tests := []struct {
   663  		sql       string
   664  		queryArgs []interface{}
   665  		scanArgs  []interface{}
   666  		expected  allTypes
   667  	}{
   668  		{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}},
   669  		{"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}},
   670  		{"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}},
   671  		{"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}},
   672  		{"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}},
   673  		{"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}},
   674  		{"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}},
   675  		{"select $1::oid", []interface{}{uint32(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}},
   676  	}
   677  
   678  	for i, tt := range tests {
   679  		actual = zero
   680  
   681  		err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
   682  		if err != nil {
   683  			t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
   684  		}
   685  
   686  		if actual.s != tt.expected.s || actual.f32 != tt.expected.f32 || actual.b != tt.expected.b || !actual.t.Equal(tt.expected.t) || actual.oid != tt.expected.oid {
   687  			t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs)
   688  		}
   689  
   690  		ensureConnValid(t, conn)
   691  
   692  		// Check that Scan errors when a core type is null
   693  		err = conn.QueryRow(context.Background(), tt.sql, nil).Scan(tt.scanArgs...)
   694  		if err == nil {
   695  			t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql)
   696  		}
   697  
   698  		ensureConnValid(t, conn)
   699  	}
   700  }
   701  
   702  func TestQueryRowCoreIntegerEncoding(t *testing.T) {
   703  	t.Parallel()
   704  
   705  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   706  	defer closeConn(t, conn)
   707  
   708  	type allTypes struct {
   709  		ui   uint
   710  		ui8  uint8
   711  		ui16 uint16
   712  		ui32 uint32
   713  		ui64 uint64
   714  		i    int
   715  		i8   int8
   716  		i16  int16
   717  		i32  int32
   718  		i64  int64
   719  	}
   720  
   721  	var actual, zero allTypes
   722  
   723  	successfulEncodeTests := []struct {
   724  		sql      string
   725  		queryArg interface{}
   726  		scanArg  interface{}
   727  		expected allTypes
   728  	}{
   729  		// Check any integer type where value is within int2 range can be encoded
   730  		{"select $1::int2", int(42), &actual.i16, allTypes{i16: 42}},
   731  		{"select $1::int2", int8(42), &actual.i16, allTypes{i16: 42}},
   732  		{"select $1::int2", int16(42), &actual.i16, allTypes{i16: 42}},
   733  		{"select $1::int2", int32(42), &actual.i16, allTypes{i16: 42}},
   734  		{"select $1::int2", int64(42), &actual.i16, allTypes{i16: 42}},
   735  		{"select $1::int2", uint(42), &actual.i16, allTypes{i16: 42}},
   736  		{"select $1::int2", uint8(42), &actual.i16, allTypes{i16: 42}},
   737  		{"select $1::int2", uint16(42), &actual.i16, allTypes{i16: 42}},
   738  		{"select $1::int2", uint32(42), &actual.i16, allTypes{i16: 42}},
   739  		{"select $1::int2", uint64(42), &actual.i16, allTypes{i16: 42}},
   740  
   741  		// Check any integer type where value is within int4 range can be encoded
   742  		{"select $1::int4", int(42), &actual.i32, allTypes{i32: 42}},
   743  		{"select $1::int4", int8(42), &actual.i32, allTypes{i32: 42}},
   744  		{"select $1::int4", int16(42), &actual.i32, allTypes{i32: 42}},
   745  		{"select $1::int4", int32(42), &actual.i32, allTypes{i32: 42}},
   746  		{"select $1::int4", int64(42), &actual.i32, allTypes{i32: 42}},
   747  		{"select $1::int4", uint(42), &actual.i32, allTypes{i32: 42}},
   748  		{"select $1::int4", uint8(42), &actual.i32, allTypes{i32: 42}},
   749  		{"select $1::int4", uint16(42), &actual.i32, allTypes{i32: 42}},
   750  		{"select $1::int4", uint32(42), &actual.i32, allTypes{i32: 42}},
   751  		{"select $1::int4", uint64(42), &actual.i32, allTypes{i32: 42}},
   752  
   753  		// Check any integer type where value is within int8 range can be encoded
   754  		{"select $1::int8", int(42), &actual.i64, allTypes{i64: 42}},
   755  		{"select $1::int8", int8(42), &actual.i64, allTypes{i64: 42}},
   756  		{"select $1::int8", int16(42), &actual.i64, allTypes{i64: 42}},
   757  		{"select $1::int8", int32(42), &actual.i64, allTypes{i64: 42}},
   758  		{"select $1::int8", int64(42), &actual.i64, allTypes{i64: 42}},
   759  		{"select $1::int8", uint(42), &actual.i64, allTypes{i64: 42}},
   760  		{"select $1::int8", uint8(42), &actual.i64, allTypes{i64: 42}},
   761  		{"select $1::int8", uint16(42), &actual.i64, allTypes{i64: 42}},
   762  		{"select $1::int8", uint32(42), &actual.i64, allTypes{i64: 42}},
   763  		{"select $1::int8", uint64(42), &actual.i64, allTypes{i64: 42}},
   764  	}
   765  
   766  	for i, tt := range successfulEncodeTests {
   767  		actual = zero
   768  
   769  		err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(tt.scanArg)
   770  		if err != nil {
   771  			t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
   772  			continue
   773  		}
   774  
   775  		if actual != tt.expected {
   776  			t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArg -> %v)", i, tt.expected, actual, tt.sql, tt.queryArg)
   777  		}
   778  
   779  		ensureConnValid(t, conn)
   780  	}
   781  
   782  	failedEncodeTests := []struct {
   783  		sql      string
   784  		queryArg interface{}
   785  	}{
   786  		// Check any integer type where value is outside pg:int2 range cannot be encoded
   787  		{"select $1::int2", int(32769)},
   788  		{"select $1::int2", int32(32769)},
   789  		{"select $1::int2", int32(32769)},
   790  		{"select $1::int2", int64(32769)},
   791  		{"select $1::int2", uint(32769)},
   792  		{"select $1::int2", uint16(32769)},
   793  		{"select $1::int2", uint32(32769)},
   794  		{"select $1::int2", uint64(32769)},
   795  
   796  		// Check any integer type where value is outside pg:int4 range cannot be encoded
   797  		{"select $1::int4", int64(2147483649)},
   798  		{"select $1::int4", uint32(2147483649)},
   799  		{"select $1::int4", uint64(2147483649)},
   800  
   801  		// Check any integer type where value is outside pg:int8 range cannot be encoded
   802  		{"select $1::int8", uint64(9223372036854775809)},
   803  	}
   804  
   805  	for i, tt := range failedEncodeTests {
   806  		err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(nil)
   807  		if err == nil {
   808  			t.Errorf("%d. Expected failure to encode, but unexpectedly succeeded: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
   809  		} else if !strings.Contains(err.Error(), "is greater than") {
   810  			t.Errorf("%d. Expected failure to encode, but got: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg)
   811  		}
   812  
   813  		ensureConnValid(t, conn)
   814  	}
   815  }
   816  
   817  func TestQueryRowCoreIntegerDecoding(t *testing.T) {
   818  	t.Parallel()
   819  
   820  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   821  	defer closeConn(t, conn)
   822  
   823  	type allTypes struct {
   824  		ui   uint
   825  		ui8  uint8
   826  		ui16 uint16
   827  		ui32 uint32
   828  		ui64 uint64
   829  		i    int
   830  		i8   int8
   831  		i16  int16
   832  		i32  int32
   833  		i64  int64
   834  	}
   835  
   836  	var actual, zero allTypes
   837  
   838  	successfulDecodeTests := []struct {
   839  		sql      string
   840  		scanArg  interface{}
   841  		expected allTypes
   842  	}{
   843  		// Check any integer type where value is within Go:int range can be decoded
   844  		{"select 42::int2", &actual.i, allTypes{i: 42}},
   845  		{"select 42::int4", &actual.i, allTypes{i: 42}},
   846  		{"select 42::int8", &actual.i, allTypes{i: 42}},
   847  		{"select -42::int2", &actual.i, allTypes{i: -42}},
   848  		{"select -42::int4", &actual.i, allTypes{i: -42}},
   849  		{"select -42::int8", &actual.i, allTypes{i: -42}},
   850  
   851  		// Check any integer type where value is within Go:int8 range can be decoded
   852  		{"select 42::int2", &actual.i8, allTypes{i8: 42}},
   853  		{"select 42::int4", &actual.i8, allTypes{i8: 42}},
   854  		{"select 42::int8", &actual.i8, allTypes{i8: 42}},
   855  		{"select -42::int2", &actual.i8, allTypes{i8: -42}},
   856  		{"select -42::int4", &actual.i8, allTypes{i8: -42}},
   857  		{"select -42::int8", &actual.i8, allTypes{i8: -42}},
   858  
   859  		// Check any integer type where value is within Go:int16 range can be decoded
   860  		{"select 42::int2", &actual.i16, allTypes{i16: 42}},
   861  		{"select 42::int4", &actual.i16, allTypes{i16: 42}},
   862  		{"select 42::int8", &actual.i16, allTypes{i16: 42}},
   863  		{"select -42::int2", &actual.i16, allTypes{i16: -42}},
   864  		{"select -42::int4", &actual.i16, allTypes{i16: -42}},
   865  		{"select -42::int8", &actual.i16, allTypes{i16: -42}},
   866  
   867  		// Check any integer type where value is within Go:int32 range can be decoded
   868  		{"select 42::int2", &actual.i32, allTypes{i32: 42}},
   869  		{"select 42::int4", &actual.i32, allTypes{i32: 42}},
   870  		{"select 42::int8", &actual.i32, allTypes{i32: 42}},
   871  		{"select -42::int2", &actual.i32, allTypes{i32: -42}},
   872  		{"select -42::int4", &actual.i32, allTypes{i32: -42}},
   873  		{"select -42::int8", &actual.i32, allTypes{i32: -42}},
   874  
   875  		// Check any integer type where value is within Go:int64 range can be decoded
   876  		{"select 42::int2", &actual.i64, allTypes{i64: 42}},
   877  		{"select 42::int4", &actual.i64, allTypes{i64: 42}},
   878  		{"select 42::int8", &actual.i64, allTypes{i64: 42}},
   879  		{"select -42::int2", &actual.i64, allTypes{i64: -42}},
   880  		{"select -42::int4", &actual.i64, allTypes{i64: -42}},
   881  		{"select -42::int8", &actual.i64, allTypes{i64: -42}},
   882  
   883  		// Check any integer type where value is within Go:uint range can be decoded
   884  		{"select 128::int2", &actual.ui, allTypes{ui: 128}},
   885  		{"select 128::int4", &actual.ui, allTypes{ui: 128}},
   886  		{"select 128::int8", &actual.ui, allTypes{ui: 128}},
   887  
   888  		// Check any integer type where value is within Go:uint8 range can be decoded
   889  		{"select 128::int2", &actual.ui8, allTypes{ui8: 128}},
   890  		{"select 128::int4", &actual.ui8, allTypes{ui8: 128}},
   891  		{"select 128::int8", &actual.ui8, allTypes{ui8: 128}},
   892  
   893  		// Check any integer type where value is within Go:uint16 range can be decoded
   894  		{"select 42::int2", &actual.ui16, allTypes{ui16: 42}},
   895  		{"select 32768::int4", &actual.ui16, allTypes{ui16: 32768}},
   896  		{"select 32768::int8", &actual.ui16, allTypes{ui16: 32768}},
   897  
   898  		// Check any integer type where value is within Go:uint32 range can be decoded
   899  		{"select 42::int2", &actual.ui32, allTypes{ui32: 42}},
   900  		{"select 42::int4", &actual.ui32, allTypes{ui32: 42}},
   901  		{"select 2147483648::int8", &actual.ui32, allTypes{ui32: 2147483648}},
   902  
   903  		// Check any integer type where value is within Go:uint64 range can be decoded
   904  		{"select 42::int2", &actual.ui64, allTypes{ui64: 42}},
   905  		{"select 42::int4", &actual.ui64, allTypes{ui64: 42}},
   906  		{"select 42::int8", &actual.ui64, allTypes{ui64: 42}},
   907  	}
   908  
   909  	for i, tt := range successfulDecodeTests {
   910  		actual = zero
   911  
   912  		err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg)
   913  		if err != nil {
   914  			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
   915  			continue
   916  		}
   917  
   918  		if actual != tt.expected {
   919  			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
   920  		}
   921  
   922  		ensureConnValid(t, conn)
   923  	}
   924  
   925  	failedDecodeTests := []struct {
   926  		sql         string
   927  		scanArg     interface{}
   928  		expectedErr string
   929  	}{
   930  		// Check any integer type where value is outside Go:int8 range cannot be decoded
   931  		{"select 128::int2", &actual.i8, "is greater than"},
   932  		{"select 128::int4", &actual.i8, "is greater than"},
   933  		{"select 128::int8", &actual.i8, "is greater than"},
   934  		{"select -129::int2", &actual.i8, "is less than"},
   935  		{"select -129::int4", &actual.i8, "is less than"},
   936  		{"select -129::int8", &actual.i8, "is less than"},
   937  
   938  		// Check any integer type where value is outside Go:int16 range cannot be decoded
   939  		{"select 32768::int4", &actual.i16, "is greater than"},
   940  		{"select 32768::int8", &actual.i16, "is greater than"},
   941  		{"select -32769::int4", &actual.i16, "is less than"},
   942  		{"select -32769::int8", &actual.i16, "is less than"},
   943  
   944  		// Check any integer type where value is outside Go:int32 range cannot be decoded
   945  		{"select 2147483648::int8", &actual.i32, "is greater than"},
   946  		{"select -2147483649::int8", &actual.i32, "is less than"},
   947  
   948  		// Check any integer type where value is outside Go:uint range cannot be decoded
   949  		{"select -1::int2", &actual.ui, "is less than"},
   950  		{"select -1::int4", &actual.ui, "is less than"},
   951  		{"select -1::int8", &actual.ui, "is less than"},
   952  
   953  		// Check any integer type where value is outside Go:uint8 range cannot be decoded
   954  		{"select 256::int2", &actual.ui8, "is greater than"},
   955  		{"select 256::int4", &actual.ui8, "is greater than"},
   956  		{"select 256::int8", &actual.ui8, "is greater than"},
   957  		{"select -1::int2", &actual.ui8, "is less than"},
   958  		{"select -1::int4", &actual.ui8, "is less than"},
   959  		{"select -1::int8", &actual.ui8, "is less than"},
   960  
   961  		// Check any integer type where value is outside Go:uint16 cannot be decoded
   962  		{"select 65536::int4", &actual.ui16, "is greater than"},
   963  		{"select 65536::int8", &actual.ui16, "is greater than"},
   964  		{"select -1::int2", &actual.ui16, "is less than"},
   965  		{"select -1::int4", &actual.ui16, "is less than"},
   966  		{"select -1::int8", &actual.ui16, "is less than"},
   967  
   968  		// Check any integer type where value is outside Go:uint32 range cannot be decoded
   969  		{"select 4294967296::int8", &actual.ui32, "is greater than"},
   970  		{"select -1::int2", &actual.ui32, "is less than"},
   971  		{"select -1::int4", &actual.ui32, "is less than"},
   972  		{"select -1::int8", &actual.ui32, "is less than"},
   973  
   974  		// Check any integer type where value is outside Go:uint64 range cannot be decoded
   975  		{"select -1::int2", &actual.ui64, "is less than"},
   976  		{"select -1::int4", &actual.ui64, "is less than"},
   977  		{"select -1::int8", &actual.ui64, "is less than"},
   978  	}
   979  
   980  	for i, tt := range failedDecodeTests {
   981  		err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg)
   982  		if err == nil {
   983  			t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql)
   984  		} else if !strings.Contains(err.Error(), tt.expectedErr) {
   985  			t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql)
   986  		}
   987  
   988  		ensureConnValid(t, conn)
   989  	}
   990  }
   991  
   992  func TestQueryRowCoreByteSlice(t *testing.T) {
   993  	t.Parallel()
   994  
   995  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   996  	defer closeConn(t, conn)
   997  
   998  	tests := []struct {
   999  		sql      string
  1000  		queryArg interface{}
  1001  		expected []byte
  1002  	}{
  1003  		{"select $1::text", "Jack", []byte("Jack")},
  1004  		{"select $1::text", []byte("Jack"), []byte("Jack")},
  1005  		{"select $1::varchar", []byte("Jack"), []byte("Jack")},
  1006  		{"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}},
  1007  	}
  1008  
  1009  	for i, tt := range tests {
  1010  		var actual []byte
  1011  
  1012  		err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(&actual)
  1013  		if err != nil {
  1014  			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
  1015  		}
  1016  
  1017  		if !bytes.Equal(actual, tt.expected) {
  1018  			t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
  1019  		}
  1020  
  1021  		ensureConnValid(t, conn)
  1022  	}
  1023  }
  1024  
  1025  func TestQueryRowErrors(t *testing.T) {
  1026  	t.Parallel()
  1027  
  1028  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1029  	defer closeConn(t, conn)
  1030  
  1031  	type allTypes struct {
  1032  		i16 int16
  1033  		i   int
  1034  		s   string
  1035  	}
  1036  
  1037  	var actual, zero allTypes
  1038  
  1039  	tests := []struct {
  1040  		sql       string
  1041  		queryArgs []interface{}
  1042  		scanArgs  []interface{}
  1043  		err       string
  1044  	}{
  1045  		// {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`},
  1046  		// {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"},
  1047  		{"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "unable to assign"},
  1048  		// {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"},
  1049  	}
  1050  
  1051  	for i, tt := range tests {
  1052  		actual = zero
  1053  
  1054  		err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
  1055  		if err == nil {
  1056  			t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs)
  1057  		}
  1058  		if err != nil && !strings.Contains(err.Error(), tt.err) {
  1059  			t.Errorf("%d. Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs)
  1060  		}
  1061  
  1062  		ensureConnValid(t, conn)
  1063  	}
  1064  }
  1065  
  1066  func TestQueryRowNoResults(t *testing.T) {
  1067  	t.Parallel()
  1068  
  1069  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1070  	defer closeConn(t, conn)
  1071  
  1072  	var n int32
  1073  	err := conn.QueryRow(context.Background(), "select 1 where 1=0").Scan(&n)
  1074  	if err != pgx.ErrNoRows {
  1075  		t.Errorf("Expected pgx.ErrNoRows, got %v", err)
  1076  	}
  1077  
  1078  	ensureConnValid(t, conn)
  1079  }
  1080  
  1081  func TestQueryRowEmptyQuery(t *testing.T) {
  1082  	t.Parallel()
  1083  
  1084  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1085  	defer closeConn(t, conn)
  1086  
  1087  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
  1088  	defer cancel()
  1089  
  1090  	var n int32
  1091  	err := conn.QueryRow(ctx, "").Scan(&n)
  1092  	require.Error(t, err)
  1093  	require.False(t, pgconn.Timeout(err))
  1094  
  1095  	ensureConnValid(t, conn)
  1096  }
  1097  
  1098  func TestReadingValueAfterEmptyArray(t *testing.T) {
  1099  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1100  	defer closeConn(t, conn)
  1101  
  1102  	var a []string
  1103  	var b int32
  1104  	err := conn.QueryRow(context.Background(), "select '{}'::text[], 42::integer").Scan(&a, &b)
  1105  	if err != nil {
  1106  		t.Fatalf("conn.QueryRow failed: %v", err)
  1107  	}
  1108  
  1109  	if len(a) != 0 {
  1110  		t.Errorf("Expected 'a' to have length 0, but it was: %d", len(a))
  1111  	}
  1112  
  1113  	if b != 42 {
  1114  		t.Errorf("Expected 'b' to 42, but it was: %d", b)
  1115  	}
  1116  }
  1117  
  1118  func TestReadingNullByteArray(t *testing.T) {
  1119  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1120  	defer closeConn(t, conn)
  1121  
  1122  	var a []byte
  1123  	err := conn.QueryRow(context.Background(), "select null::text").Scan(&a)
  1124  	if err != nil {
  1125  		t.Fatalf("conn.QueryRow failed: %v", err)
  1126  	}
  1127  
  1128  	if a != nil {
  1129  		t.Errorf("Expected 'a' to be nil, but it was: %v", a)
  1130  	}
  1131  }
  1132  
  1133  func TestReadingNullByteArrays(t *testing.T) {
  1134  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1135  	defer closeConn(t, conn)
  1136  
  1137  	rows, err := conn.Query(context.Background(), "select null::text union all select null::text")
  1138  	if err != nil {
  1139  		t.Fatalf("conn.Query failed: %v", err)
  1140  	}
  1141  
  1142  	count := 0
  1143  	for rows.Next() {
  1144  		count++
  1145  		var a []byte
  1146  		if err := rows.Scan(&a); err != nil {
  1147  			t.Fatalf("failed to scan row: %v", err)
  1148  		}
  1149  		if a != nil {
  1150  			t.Errorf("Expected 'a' to be nil, but it was: %v", a)
  1151  		}
  1152  	}
  1153  	if count != 2 {
  1154  		t.Errorf("Expected to read 2 rows, read: %d", count)
  1155  	}
  1156  }
  1157  
  1158  // Use github.com/shopspring/decimal as real-world database/sql custom type
  1159  // to test against.
  1160  func TestConnQueryDatabaseSQLScanner(t *testing.T) {
  1161  	t.Parallel()
  1162  
  1163  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1164  	defer closeConn(t, conn)
  1165  
  1166  	var num decimal.Decimal
  1167  
  1168  	err := conn.QueryRow(context.Background(), "select '1234.567'::decimal").Scan(&num)
  1169  	if err != nil {
  1170  		t.Fatalf("Scan failed: %v", err)
  1171  	}
  1172  
  1173  	expected, err := decimal.NewFromString("1234.567")
  1174  	if err != nil {
  1175  		t.Fatal(err)
  1176  	}
  1177  
  1178  	if !num.Equals(expected) {
  1179  		t.Errorf("Expected num to be %v, but it was %v", expected, num)
  1180  	}
  1181  
  1182  	ensureConnValid(t, conn)
  1183  }
  1184  
  1185  // Use github.com/shopspring/decimal as real-world database/sql custom type
  1186  // to test against.
  1187  func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) {
  1188  	t.Parallel()
  1189  
  1190  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1191  	defer closeConn(t, conn)
  1192  
  1193  	expected, err := decimal.NewFromString("1234.567")
  1194  	if err != nil {
  1195  		t.Fatal(err)
  1196  	}
  1197  	var num decimal.Decimal
  1198  
  1199  	err = conn.QueryRow(context.Background(), "select $1::decimal", &expected).Scan(&num)
  1200  	if err != nil {
  1201  		t.Fatalf("Scan failed: %v", err)
  1202  	}
  1203  
  1204  	if !num.Equals(expected) {
  1205  		t.Errorf("Expected num to be %v, but it was %v", expected, num)
  1206  	}
  1207  
  1208  	ensureConnValid(t, conn)
  1209  }
  1210  
  1211  // https://github.com/jackc/pgx/issues/339
  1212  func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *testing.T) {
  1213  	t.Parallel()
  1214  
  1215  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1216  	defer closeConn(t, conn)
  1217  
  1218  	mustExec(t, conn, "create temporary table t(n numeric)")
  1219  
  1220  	var d *apd.Decimal
  1221  	commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d)
  1222  	if err != nil {
  1223  		t.Fatal(err)
  1224  	}
  1225  	if string(commandTag) != "INSERT 0 1" {
  1226  		t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag)
  1227  	}
  1228  
  1229  	ensureConnValid(t, conn)
  1230  }
  1231  
  1232  func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
  1233  	t.Parallel()
  1234  
  1235  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1236  	defer closeConn(t, conn)
  1237  
  1238  	expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
  1239  	if err != nil {
  1240  		t.Fatal(err)
  1241  	}
  1242  
  1243  	var u2 uuid.UUID
  1244  	err = conn.QueryRow(context.Background(), "select $1::uuid", expected).Scan(&u2)
  1245  	if err != nil {
  1246  		t.Fatalf("Scan failed: %v", err)
  1247  	}
  1248  
  1249  	if expected != u2 {
  1250  		t.Errorf("Expected u2 to be %v, but it was %v", expected, u2)
  1251  	}
  1252  
  1253  	ensureConnValid(t, conn)
  1254  }
  1255  
  1256  func TestConnQueryDatabaseSQLNullX(t *testing.T) {
  1257  	t.Parallel()
  1258  
  1259  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1260  	defer closeConn(t, conn)
  1261  
  1262  	type row struct {
  1263  		boolValid    sql.NullBool
  1264  		boolNull     sql.NullBool
  1265  		int64Valid   sql.NullInt64
  1266  		int64Null    sql.NullInt64
  1267  		float64Valid sql.NullFloat64
  1268  		float64Null  sql.NullFloat64
  1269  		stringValid  sql.NullString
  1270  		stringNull   sql.NullString
  1271  	}
  1272  
  1273  	expected := row{
  1274  		boolValid:    sql.NullBool{Bool: true, Valid: true},
  1275  		int64Valid:   sql.NullInt64{Int64: 123, Valid: true},
  1276  		float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true},
  1277  		stringValid:  sql.NullString{String: "pgx", Valid: true},
  1278  	}
  1279  
  1280  	var actual row
  1281  
  1282  	err := conn.QueryRow(
  1283  		context.Background(),
  1284  		"select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text",
  1285  		expected.boolValid,
  1286  		expected.boolNull,
  1287  		expected.int64Valid,
  1288  		expected.int64Null,
  1289  		expected.float64Valid,
  1290  		expected.float64Null,
  1291  		expected.stringValid,
  1292  		expected.stringNull,
  1293  	).Scan(
  1294  		&actual.boolValid,
  1295  		&actual.boolNull,
  1296  		&actual.int64Valid,
  1297  		&actual.int64Null,
  1298  		&actual.float64Valid,
  1299  		&actual.float64Null,
  1300  		&actual.stringValid,
  1301  		&actual.stringNull,
  1302  	)
  1303  	if err != nil {
  1304  		t.Fatalf("Scan failed: %v", err)
  1305  	}
  1306  
  1307  	if expected != actual {
  1308  		t.Errorf("Expected %v, but got %v", expected, actual)
  1309  	}
  1310  
  1311  	ensureConnValid(t, conn)
  1312  }
  1313  
  1314  func TestQueryContextSuccess(t *testing.T) {
  1315  	t.Parallel()
  1316  
  1317  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1318  	defer closeConn(t, conn)
  1319  
  1320  	ctx, cancelFunc := context.WithCancel(context.Background())
  1321  	defer cancelFunc()
  1322  
  1323  	rows, err := conn.Query(ctx, "select 42::integer")
  1324  	if err != nil {
  1325  		t.Fatal(err)
  1326  	}
  1327  
  1328  	var result, rowCount int
  1329  	for rows.Next() {
  1330  		err = rows.Scan(&result)
  1331  		if err != nil {
  1332  			t.Fatal(err)
  1333  		}
  1334  		rowCount++
  1335  	}
  1336  
  1337  	if rows.Err() != nil {
  1338  		t.Fatal(rows.Err())
  1339  	}
  1340  
  1341  	if rowCount != 1 {
  1342  		t.Fatalf("Expected 1 row, got %d", rowCount)
  1343  	}
  1344  	if result != 42 {
  1345  		t.Fatalf("Expected result 42, got %d", result)
  1346  	}
  1347  
  1348  	ensureConnValid(t, conn)
  1349  }
  1350  
  1351  func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
  1352  	t.Parallel()
  1353  
  1354  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1355  	defer closeConn(t, conn)
  1356  
  1357  	skipCockroachDB(t, conn, "Server uses numeric instead of int")
  1358  
  1359  	ctx, cancelFunc := context.WithCancel(context.Background())
  1360  	defer cancelFunc()
  1361  
  1362  	rows, err := conn.Query(ctx, "select 10/(10-n) from generate_series(1, 100) n")
  1363  	if err != nil {
  1364  		t.Fatal(err)
  1365  	}
  1366  
  1367  	var result, rowCount int
  1368  	for rows.Next() {
  1369  		err = rows.Scan(&result)
  1370  		if err != nil {
  1371  			t.Fatal(err)
  1372  		}
  1373  		rowCount++
  1374  	}
  1375  
  1376  	if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" {
  1377  		t.Fatalf("Expected division by zero error, but got %v", rows.Err())
  1378  	}
  1379  
  1380  	if rowCount != 9 {
  1381  		t.Fatalf("Expected 9 rows, got %d", rowCount)
  1382  	}
  1383  	if result != 10 {
  1384  		t.Fatalf("Expected result 10, got %d", result)
  1385  	}
  1386  
  1387  	ensureConnValid(t, conn)
  1388  }
  1389  
  1390  func TestQueryRowContextSuccess(t *testing.T) {
  1391  	t.Parallel()
  1392  
  1393  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1394  	defer closeConn(t, conn)
  1395  
  1396  	ctx, cancelFunc := context.WithCancel(context.Background())
  1397  	defer cancelFunc()
  1398  
  1399  	var result int
  1400  	err := conn.QueryRow(ctx, "select 42::integer").Scan(&result)
  1401  	if err != nil {
  1402  		t.Fatal(err)
  1403  	}
  1404  	if result != 42 {
  1405  		t.Fatalf("Expected result 42, got %d", result)
  1406  	}
  1407  
  1408  	ensureConnValid(t, conn)
  1409  }
  1410  
  1411  func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
  1412  	t.Parallel()
  1413  
  1414  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1415  	defer closeConn(t, conn)
  1416  
  1417  	ctx, cancelFunc := context.WithCancel(context.Background())
  1418  	defer cancelFunc()
  1419  
  1420  	var result int
  1421  	err := conn.QueryRow(ctx, "select 10/0").Scan(&result)
  1422  	if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
  1423  		t.Fatalf("Expected division by zero error, but got %v", err)
  1424  	}
  1425  
  1426  	ensureConnValid(t, conn)
  1427  }
  1428  
  1429  func TestQueryCloseBefore(t *testing.T) {
  1430  	t.Parallel()
  1431  
  1432  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1433  	closeConn(t, conn)
  1434  
  1435  	_, err := conn.Query(context.Background(), "select 1")
  1436  	require.Error(t, err)
  1437  	assert.True(t, pgconn.SafeToRetry(err))
  1438  }
  1439  
  1440  func TestScanRow(t *testing.T) {
  1441  	t.Parallel()
  1442  
  1443  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1444  	defer closeConn(t, conn)
  1445  
  1446  	resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil)
  1447  
  1448  	var sum, rowCount int32
  1449  
  1450  	for resultReader.NextRow() {
  1451  		var n int32
  1452  		err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n)
  1453  		assert.NoError(t, err)
  1454  		sum += n
  1455  		rowCount++
  1456  	}
  1457  
  1458  	_, err := resultReader.Close()
  1459  
  1460  	require.NoError(t, err)
  1461  	assert.EqualValues(t, 10, rowCount)
  1462  	assert.EqualValues(t, 55, sum)
  1463  }
  1464  
  1465  func TestConnSimpleProtocol(t *testing.T) {
  1466  	t.Parallel()
  1467  
  1468  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1469  	defer closeConn(t, conn)
  1470  
  1471  	// Test all supported low-level types
  1472  
  1473  	{
  1474  		expected := int64(42)
  1475  		var actual int64
  1476  		err := conn.QueryRow(
  1477  			context.Background(),
  1478  			"select $1::int8",
  1479  			pgx.QuerySimpleProtocol(true),
  1480  			expected,
  1481  		).Scan(&actual)
  1482  		if err != nil {
  1483  			t.Error(err)
  1484  		}
  1485  		if expected != actual {
  1486  			t.Errorf("expected %v got %v", expected, actual)
  1487  		}
  1488  	}
  1489  
  1490  	{
  1491  		expected := float64(1.23)
  1492  		var actual float64
  1493  		err := conn.QueryRow(
  1494  			context.Background(),
  1495  			"select $1::float8",
  1496  			pgx.QuerySimpleProtocol(true),
  1497  			expected,
  1498  		).Scan(&actual)
  1499  		if err != nil {
  1500  			t.Error(err)
  1501  		}
  1502  		if expected != actual {
  1503  			t.Errorf("expected %v got %v", expected, actual)
  1504  		}
  1505  	}
  1506  
  1507  	{
  1508  		expected := true
  1509  		var actual bool
  1510  		err := conn.QueryRow(
  1511  			context.Background(),
  1512  			"select $1",
  1513  			pgx.QuerySimpleProtocol(true),
  1514  			expected,
  1515  		).Scan(&actual)
  1516  		if err != nil {
  1517  			t.Error(err)
  1518  		}
  1519  		if expected != actual {
  1520  			t.Errorf("expected %v got %v", expected, actual)
  1521  		}
  1522  	}
  1523  
  1524  	{
  1525  		expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
  1526  		var actual []byte
  1527  		err := conn.QueryRow(
  1528  			context.Background(),
  1529  			"select $1::bytea",
  1530  			pgx.QuerySimpleProtocol(true),
  1531  			expected,
  1532  		).Scan(&actual)
  1533  		if err != nil {
  1534  			t.Error(err)
  1535  		}
  1536  		if bytes.Compare(actual, expected) != 0 {
  1537  			t.Errorf("expected %v got %v", expected, actual)
  1538  		}
  1539  	}
  1540  
  1541  	{
  1542  		expected := "test"
  1543  		var actual string
  1544  		err := conn.QueryRow(
  1545  			context.Background(),
  1546  			"select $1::text",
  1547  			pgx.QuerySimpleProtocol(true),
  1548  			expected,
  1549  		).Scan(&actual)
  1550  		if err != nil {
  1551  			t.Error(err)
  1552  		}
  1553  		if expected != actual {
  1554  			t.Errorf("expected %v got %v", expected, actual)
  1555  		}
  1556  	}
  1557  
  1558  	{
  1559  		tests := []struct {
  1560  			expected []string
  1561  		}{
  1562  			{[]string(nil)},
  1563  			{[]string{}},
  1564  			{[]string{"test", "foo", "bar"}},
  1565  			{[]string{`foo'bar"\baz;quz`, `foo'bar"\baz;quz`}},
  1566  		}
  1567  		for i, tt := range tests {
  1568  			var actual []string
  1569  			err := conn.QueryRow(
  1570  				context.Background(),
  1571  				"select $1::text[]",
  1572  				pgx.QuerySimpleProtocol(true),
  1573  				tt.expected,
  1574  			).Scan(&actual)
  1575  			assert.NoErrorf(t, err, "%d", i)
  1576  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1577  		}
  1578  	}
  1579  
  1580  	{
  1581  		tests := []struct {
  1582  			expected []int16
  1583  		}{
  1584  			{[]int16(nil)},
  1585  			{[]int16{}},
  1586  			{[]int16{1, 2, 3}},
  1587  		}
  1588  		for i, tt := range tests {
  1589  			var actual []int16
  1590  			err := conn.QueryRow(
  1591  				context.Background(),
  1592  				"select $1::smallint[]",
  1593  				pgx.QuerySimpleProtocol(true),
  1594  				tt.expected,
  1595  			).Scan(&actual)
  1596  			assert.NoErrorf(t, err, "%d", i)
  1597  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1598  		}
  1599  	}
  1600  
  1601  	{
  1602  		tests := []struct {
  1603  			expected []int32
  1604  		}{
  1605  			{[]int32(nil)},
  1606  			{[]int32{}},
  1607  			{[]int32{1, 2, 3}},
  1608  		}
  1609  		for i, tt := range tests {
  1610  			var actual []int32
  1611  			err := conn.QueryRow(
  1612  				context.Background(),
  1613  				"select $1::int[]",
  1614  				pgx.QuerySimpleProtocol(true),
  1615  				tt.expected,
  1616  			).Scan(&actual)
  1617  			assert.NoErrorf(t, err, "%d", i)
  1618  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1619  		}
  1620  	}
  1621  
  1622  	{
  1623  		tests := []struct {
  1624  			expected []int64
  1625  		}{
  1626  			{[]int64(nil)},
  1627  			{[]int64{}},
  1628  			{[]int64{1, 2, 3}},
  1629  		}
  1630  		for i, tt := range tests {
  1631  			var actual []int64
  1632  			err := conn.QueryRow(
  1633  				context.Background(),
  1634  				"select $1::bigint[]",
  1635  				pgx.QuerySimpleProtocol(true),
  1636  				tt.expected,
  1637  			).Scan(&actual)
  1638  			assert.NoErrorf(t, err, "%d", i)
  1639  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1640  		}
  1641  	}
  1642  
  1643  	{
  1644  		tests := []struct {
  1645  			expected []int
  1646  		}{
  1647  			{[]int(nil)},
  1648  			{[]int{}},
  1649  			{[]int{1, 2, 3}},
  1650  		}
  1651  		for i, tt := range tests {
  1652  			var actual []int
  1653  			err := conn.QueryRow(
  1654  				context.Background(),
  1655  				"select $1::bigint[]",
  1656  				pgx.QuerySimpleProtocol(true),
  1657  				tt.expected,
  1658  			).Scan(&actual)
  1659  			assert.NoErrorf(t, err, "%d", i)
  1660  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1661  		}
  1662  	}
  1663  
  1664  	{
  1665  		tests := []struct {
  1666  			expected []uint16
  1667  		}{
  1668  			{[]uint16(nil)},
  1669  			{[]uint16{}},
  1670  			{[]uint16{1, 2, 3}},
  1671  		}
  1672  		for i, tt := range tests {
  1673  			var actual []uint16
  1674  			err := conn.QueryRow(
  1675  				context.Background(),
  1676  				"select $1::smallint[]",
  1677  				pgx.QuerySimpleProtocol(true),
  1678  				tt.expected,
  1679  			).Scan(&actual)
  1680  			assert.NoErrorf(t, err, "%d", i)
  1681  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1682  		}
  1683  	}
  1684  
  1685  	{
  1686  		tests := []struct {
  1687  			expected []uint32
  1688  		}{
  1689  			{[]uint32(nil)},
  1690  			{[]uint32{}},
  1691  			{[]uint32{1, 2, 3}},
  1692  		}
  1693  		for i, tt := range tests {
  1694  			var actual []uint32
  1695  			err := conn.QueryRow(
  1696  				context.Background(),
  1697  				"select $1::bigint[]",
  1698  				pgx.QuerySimpleProtocol(true),
  1699  				tt.expected,
  1700  			).Scan(&actual)
  1701  			assert.NoErrorf(t, err, "%d", i)
  1702  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1703  		}
  1704  	}
  1705  
  1706  	{
  1707  		tests := []struct {
  1708  			expected []uint64
  1709  		}{
  1710  			{[]uint64(nil)},
  1711  			{[]uint64{}},
  1712  			{[]uint64{1, 2, 3}},
  1713  		}
  1714  		for i, tt := range tests {
  1715  			var actual []uint64
  1716  			err := conn.QueryRow(
  1717  				context.Background(),
  1718  				"select $1::bigint[]",
  1719  				pgx.QuerySimpleProtocol(true),
  1720  				tt.expected,
  1721  			).Scan(&actual)
  1722  			assert.NoErrorf(t, err, "%d", i)
  1723  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1724  		}
  1725  	}
  1726  
  1727  	{
  1728  		tests := []struct {
  1729  			expected []uint
  1730  		}{
  1731  			{[]uint(nil)},
  1732  			{[]uint{}},
  1733  			{[]uint{1, 2, 3}},
  1734  		}
  1735  		for i, tt := range tests {
  1736  			var actual []uint
  1737  			err := conn.QueryRow(
  1738  				context.Background(),
  1739  				"select $1::bigint[]",
  1740  				pgx.QuerySimpleProtocol(true),
  1741  				tt.expected,
  1742  			).Scan(&actual)
  1743  			assert.NoErrorf(t, err, "%d", i)
  1744  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1745  		}
  1746  	}
  1747  
  1748  	{
  1749  		tests := []struct {
  1750  			expected []float32
  1751  		}{
  1752  			{[]float32(nil)},
  1753  			{[]float32{}},
  1754  			{[]float32{1, 2, 3}},
  1755  		}
  1756  		for i, tt := range tests {
  1757  			var actual []float32
  1758  			err := conn.QueryRow(
  1759  				context.Background(),
  1760  				"select $1::float4[]",
  1761  				pgx.QuerySimpleProtocol(true),
  1762  				tt.expected,
  1763  			).Scan(&actual)
  1764  			assert.NoErrorf(t, err, "%d", i)
  1765  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1766  		}
  1767  	}
  1768  
  1769  	{
  1770  		tests := []struct {
  1771  			expected []float64
  1772  		}{
  1773  			{[]float64(nil)},
  1774  			{[]float64{}},
  1775  			{[]float64{1, 2, 3}},
  1776  		}
  1777  		for i, tt := range tests {
  1778  			var actual []float64
  1779  			err := conn.QueryRow(
  1780  				context.Background(),
  1781  				"select $1::float8[]",
  1782  				pgx.QuerySimpleProtocol(true),
  1783  				tt.expected,
  1784  			).Scan(&actual)
  1785  			assert.NoErrorf(t, err, "%d", i)
  1786  			assert.Equalf(t, tt.expected, actual, "%d", i)
  1787  		}
  1788  	}
  1789  
  1790  	// Test high-level type
  1791  
  1792  	{
  1793  		if conn.PgConn().ParameterStatus("crdb_version") == "" {
  1794  			// CockroachDB doesn't support circle type.
  1795  			expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
  1796  			actual := expected
  1797  			err := conn.QueryRow(
  1798  				context.Background(),
  1799  				"select $1::circle",
  1800  				pgx.QuerySimpleProtocol(true),
  1801  				&expected,
  1802  			).Scan(&actual)
  1803  			if err != nil {
  1804  				t.Error(err)
  1805  			}
  1806  			if expected != actual {
  1807  				t.Errorf("expected %v got %v", expected, actual)
  1808  			}
  1809  		}
  1810  	}
  1811  
  1812  	// Test multiple args in single query
  1813  
  1814  	{
  1815  		expectedInt64 := int64(234423)
  1816  		expectedFloat64 := float64(-0.2312)
  1817  		expectedBool := true
  1818  		expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
  1819  		expectedString := "test"
  1820  		var actualInt64 int64
  1821  		var actualFloat64 float64
  1822  		var actualBool bool
  1823  		var actualBytes []byte
  1824  		var actualString string
  1825  		err := conn.QueryRow(
  1826  			context.Background(),
  1827  			"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
  1828  			pgx.QuerySimpleProtocol(true),
  1829  			expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
  1830  		).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
  1831  		if err != nil {
  1832  			t.Error(err)
  1833  		}
  1834  		if expectedInt64 != actualInt64 {
  1835  			t.Errorf("expected %v got %v", expectedInt64, actualInt64)
  1836  		}
  1837  		if expectedFloat64 != actualFloat64 {
  1838  			t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
  1839  		}
  1840  		if expectedBool != actualBool {
  1841  			t.Errorf("expected %v got %v", expectedBool, actualBool)
  1842  		}
  1843  		if bytes.Compare(expectedBytes, actualBytes) != 0 {
  1844  			t.Errorf("expected %v got %v", expectedBytes, actualBytes)
  1845  		}
  1846  		if expectedString != actualString {
  1847  			t.Errorf("expected %v got %v", expectedString, actualString)
  1848  		}
  1849  	}
  1850  
  1851  	// Test dangerous cases
  1852  
  1853  	{
  1854  		expected := "foo';drop table users;"
  1855  		var actual string
  1856  		err := conn.QueryRow(
  1857  			context.Background(),
  1858  			"select $1",
  1859  			pgx.QuerySimpleProtocol(true),
  1860  			expected,
  1861  		).Scan(&actual)
  1862  		if err != nil {
  1863  			t.Error(err)
  1864  		}
  1865  		if expected != actual {
  1866  			t.Errorf("expected %v got %v", expected, actual)
  1867  		}
  1868  	}
  1869  
  1870  	ensureConnValid(t, conn)
  1871  }
  1872  
  1873  func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
  1874  	t.Parallel()
  1875  
  1876  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1877  	defer closeConn(t, conn)
  1878  
  1879  	skipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)")
  1880  
  1881  	mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
  1882  
  1883  	var expected string
  1884  	err := conn.QueryRow(
  1885  		context.Background(),
  1886  		"select $1",
  1887  		pgx.QuerySimpleProtocol(true),
  1888  		"test",
  1889  	).Scan(&expected)
  1890  	if err == nil {
  1891  		t.Error("expected error when client_encoding not UTF8, but no error occurred")
  1892  	}
  1893  
  1894  	ensureConnValid(t, conn)
  1895  }
  1896  
  1897  func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
  1898  	t.Parallel()
  1899  
  1900  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1901  	defer closeConn(t, conn)
  1902  
  1903  	skipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)")
  1904  
  1905  	mustExec(t, conn, "set standard_conforming_strings to off")
  1906  
  1907  	var expected string
  1908  	err := conn.QueryRow(
  1909  		context.Background(),
  1910  		"select $1",
  1911  		pgx.QuerySimpleProtocol(true),
  1912  		`\'; drop table users; --`,
  1913  	).Scan(&expected)
  1914  	if err == nil {
  1915  		t.Error("expected error when standard_conforming_strings is off, but no error occurred")
  1916  	}
  1917  
  1918  	ensureConnValid(t, conn)
  1919  }
  1920  
  1921  func TestQueryStatementCacheModes(t *testing.T) {
  1922  	t.Parallel()
  1923  
  1924  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
  1925  
  1926  	tests := []struct {
  1927  		name                string
  1928  		buildStatementCache pgx.BuildStatementCacheFunc
  1929  	}{
  1930  		{
  1931  			name:                "disabled",
  1932  			buildStatementCache: nil,
  1933  		},
  1934  		{
  1935  			name: "prepare",
  1936  			buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
  1937  				return stmtcache.New(conn, stmtcache.ModePrepare, 32)
  1938  			},
  1939  		},
  1940  		{
  1941  			name: "describe",
  1942  			buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache {
  1943  				return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
  1944  			},
  1945  		},
  1946  	}
  1947  
  1948  	for _, tt := range tests {
  1949  		func() {
  1950  			config.BuildStatementCache = tt.buildStatementCache
  1951  			conn := mustConnect(t, config)
  1952  			defer closeConn(t, conn)
  1953  
  1954  			var n int
  1955  			err := conn.QueryRow(context.Background(), "select 1").Scan(&n)
  1956  			assert.NoError(t, err, tt.name)
  1957  			assert.Equal(t, 1, n, tt.name)
  1958  
  1959  			err = conn.QueryRow(context.Background(), "select 2").Scan(&n)
  1960  			assert.NoError(t, err, tt.name)
  1961  			assert.Equal(t, 2, n, tt.name)
  1962  
  1963  			err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
  1964  			assert.NoError(t, err, tt.name)
  1965  			assert.Equal(t, 1, n, tt.name)
  1966  
  1967  			ensureConnValid(t, conn)
  1968  		}()
  1969  	}
  1970  }
  1971  
  1972  // https://github.com/jackc/pgx/issues/895
  1973  func TestQueryErrorWithNilStatementCacheMode(t *testing.T) {
  1974  	t.Parallel()
  1975  
  1976  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
  1977  	config.BuildStatementCache = nil
  1978  
  1979  	conn := mustConnect(t, config)
  1980  	defer closeConn(t, conn)
  1981  
  1982  	_, err := conn.Exec(context.Background(), "create temporary table t_unq(id text primary key);")
  1983  	require.NoError(t, err)
  1984  
  1985  	_, err = conn.Exec(context.Background(), "insert into t_unq (id) values ($1)", "abc")
  1986  	require.NoError(t, err)
  1987  
  1988  	rows, err := conn.Query(context.Background(), "insert into t_unq (id) values ($1)", "abc")
  1989  	require.NoError(t, err)
  1990  	rows.Close()
  1991  	err = rows.Err()
  1992  	require.Error(t, err)
  1993  	var pgErr *pgconn.PgError
  1994  	if errors.As(err, &pgErr) {
  1995  		assert.Equal(t, "23505", pgErr.Code)
  1996  	} else {
  1997  		t.Errorf("err is not a *pgconn.PgError: %T", err)
  1998  	}
  1999  
  2000  	ensureConnValid(t, conn)
  2001  }
  2002  
  2003  func TestConnQueryFunc(t *testing.T) {
  2004  	t.Parallel()
  2005  
  2006  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
  2007  		var actualResults []interface{}
  2008  
  2009  		var a, b int
  2010  		ct, err := conn.QueryFunc(
  2011  			context.Background(),
  2012  			"select n, n * 2 from generate_series(1, $1) n",
  2013  			[]interface{}{3},
  2014  			[]interface{}{&a, &b},
  2015  			func(pgx.QueryFuncRow) error {
  2016  				actualResults = append(actualResults, []interface{}{a, b})
  2017  				return nil
  2018  			},
  2019  		)
  2020  		require.NoError(t, err)
  2021  
  2022  		expectedResults := []interface{}{
  2023  			[]interface{}{1, 2},
  2024  			[]interface{}{2, 4},
  2025  			[]interface{}{3, 6},
  2026  		}
  2027  		require.Equal(t, expectedResults, actualResults)
  2028  		require.EqualValues(t, 3, ct.RowsAffected())
  2029  	})
  2030  }
  2031  
  2032  func TestConnQueryFuncScanError(t *testing.T) {
  2033  	t.Parallel()
  2034  
  2035  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
  2036  		var actualResults []interface{}
  2037  
  2038  		var a, b int
  2039  		ct, err := conn.QueryFunc(
  2040  			context.Background(),
  2041  			"select 'foo', 'bar' from generate_series(1, $1) n",
  2042  			[]interface{}{3},
  2043  			[]interface{}{&a, &b},
  2044  			func(pgx.QueryFuncRow) error {
  2045  				actualResults = append(actualResults, []interface{}{a, b})
  2046  				return nil
  2047  			},
  2048  		)
  2049  		require.EqualError(t, err, "can't scan into dest[0]: unable to assign to *int")
  2050  		require.Nil(t, ct)
  2051  	})
  2052  }
  2053  
  2054  func TestConnQueryFuncAbort(t *testing.T) {
  2055  	t.Parallel()
  2056  
  2057  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
  2058  		var a, b int
  2059  		ct, err := conn.QueryFunc(
  2060  			context.Background(),
  2061  			"select n, n * 2 from generate_series(1, $1) n",
  2062  			[]interface{}{3},
  2063  			[]interface{}{&a, &b},
  2064  			func(pgx.QueryFuncRow) error {
  2065  				return errors.New("abort")
  2066  			},
  2067  		)
  2068  		require.EqualError(t, err, "abort")
  2069  		require.Nil(t, ct)
  2070  	})
  2071  }
  2072  
  2073  func ExampleConn_QueryFunc() {
  2074  	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
  2075  	if err != nil {
  2076  		fmt.Printf("Unable to establish connection: %v", err)
  2077  		return
  2078  	}
  2079  
  2080  	var a, b int
  2081  	_, err = conn.QueryFunc(
  2082  		context.Background(),
  2083  		"select n, n * 2 from generate_series(1, $1) n",
  2084  		[]interface{}{3},
  2085  		[]interface{}{&a, &b},
  2086  		func(pgx.QueryFuncRow) error {
  2087  			fmt.Printf("%v, %v\n", a, b)
  2088  			return nil
  2089  		},
  2090  	)
  2091  	if err != nil {
  2092  		fmt.Printf("QueryFunc error: %v", err)
  2093  		return
  2094  	}
  2095  
  2096  	// Output:
  2097  	// 1, 2
  2098  	// 2, 4
  2099  	// 3, 6
  2100  }
  2101  

View as plain text