...

Source file src/github.com/jackc/pgx/v4/stdlib/sql_test.go

Documentation: github.com/jackc/pgx/v4/stdlib

     1  package stdlib_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"encoding/json"
     8  	"math"
     9  	"os"
    10  	"reflect"
    11  	"regexp"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/Masterminds/semver/v3"
    16  	"github.com/jackc/pgconn"
    17  	"github.com/jackc/pgx/v4"
    18  	"github.com/jackc/pgx/v4/stdlib"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func openDB(t testing.TB) *sql.DB {
    24  	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    25  	require.NoError(t, err)
    26  	return stdlib.OpenDB(*config)
    27  }
    28  
    29  func closeDB(t testing.TB, db *sql.DB) {
    30  	err := db.Close()
    31  	require.NoError(t, err)
    32  }
    33  
    34  func skipCockroachDB(t testing.TB, db *sql.DB, msg string) {
    35  	conn, err := db.Conn(context.Background())
    36  	require.NoError(t, err)
    37  	defer conn.Close()
    38  
    39  	err = conn.Raw(func(driverConn interface{}) error {
    40  		conn := driverConn.(*stdlib.Conn).Conn()
    41  		if conn.PgConn().ParameterStatus("crdb_version") != "" {
    42  			t.Skip(msg)
    43  		}
    44  		return nil
    45  	})
    46  	require.NoError(t, err)
    47  }
    48  
    49  func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) {
    50  	conn, err := db.Conn(context.Background())
    51  	require.NoError(t, err)
    52  	defer conn.Close()
    53  
    54  	err = conn.Raw(func(driverConn interface{}) error {
    55  		conn := driverConn.(*stdlib.Conn).Conn()
    56  		serverVersionStr := conn.PgConn().ParameterStatus("server_version")
    57  		serverVersionStr = regexp.MustCompile(`^[0-9.]+`).FindString(serverVersionStr)
    58  		// if not PostgreSQL do nothing
    59  		if serverVersionStr == "" {
    60  			return nil
    61  		}
    62  
    63  		serverVersion, err := semver.NewVersion(serverVersionStr)
    64  		if err != nil {
    65  			return err
    66  		}
    67  
    68  		c, err := semver.NewConstraint(constraintStr)
    69  		if err != nil {
    70  			return err
    71  		}
    72  
    73  		if c.Check(serverVersion) {
    74  			t.Skip(msg)
    75  		}
    76  		return nil
    77  	})
    78  	require.NoError(t, err)
    79  }
    80  
    81  func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) {
    82  	t.Run("SimpleProto",
    83  		func(t *testing.T) {
    84  			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    85  			require.NoError(t, err)
    86  
    87  			config.PreferSimpleProtocol = true
    88  			db := stdlib.OpenDB(*config)
    89  			defer func() {
    90  				err := db.Close()
    91  				require.NoError(t, err)
    92  			}()
    93  
    94  			f(t, db)
    95  
    96  			ensureDBValid(t, db)
    97  		},
    98  	)
    99  
   100  	t.Run("DefaultProto",
   101  		func(t *testing.T) {
   102  			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
   103  			require.NoError(t, err)
   104  
   105  			db := stdlib.OpenDB(*config)
   106  			defer func() {
   107  				err := db.Close()
   108  				require.NoError(t, err)
   109  			}()
   110  
   111  			f(t, db)
   112  
   113  			ensureDBValid(t, db)
   114  		},
   115  	)
   116  }
   117  
   118  // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should
   119  // cover an broken connections.
   120  func ensureDBValid(t testing.TB, db *sql.DB) {
   121  	var sum, rowCount int32
   122  
   123  	rows, err := db.Query("select generate_series(1,$1)", 10)
   124  	require.NoError(t, err)
   125  	defer rows.Close()
   126  
   127  	for rows.Next() {
   128  		var n int32
   129  		rows.Scan(&n)
   130  		sum += n
   131  		rowCount++
   132  	}
   133  
   134  	require.NoError(t, rows.Err())
   135  
   136  	if rowCount != 10 {
   137  		t.Error("Select called onDataRow wrong number of times")
   138  	}
   139  	if sum != 55 {
   140  		t.Error("Wrong values returned")
   141  	}
   142  }
   143  
   144  type preparer interface {
   145  	Prepare(query string) (*sql.Stmt, error)
   146  }
   147  
   148  func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
   149  	stmt, err := p.Prepare(sql)
   150  	require.NoError(t, err)
   151  	return stmt
   152  }
   153  
   154  func closeStmt(t *testing.T, stmt *sql.Stmt) {
   155  	err := stmt.Close()
   156  	require.NoError(t, err)
   157  }
   158  
   159  func TestSQLOpen(t *testing.T) {
   160  	tests := []struct {
   161  		driverName string
   162  	}{
   163  		{driverName: "pgx"},
   164  		{driverName: "pgx/v4"},
   165  	}
   166  
   167  	for _, tt := range tests {
   168  		tt := tt
   169  
   170  		t.Run(tt.driverName, func(t *testing.T) {
   171  			db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE"))
   172  			require.NoError(t, err)
   173  			closeDB(t, db)
   174  		})
   175  	}
   176  }
   177  
   178  func TestNormalLifeCycle(t *testing.T) {
   179  	db := openDB(t)
   180  	defer closeDB(t, db)
   181  
   182  	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   183  
   184  	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
   185  	defer closeStmt(t, stmt)
   186  
   187  	rows, err := stmt.Query(int32(1), int32(10))
   188  	require.NoError(t, err)
   189  
   190  	rowCount := int64(0)
   191  
   192  	for rows.Next() {
   193  		rowCount++
   194  
   195  		var s string
   196  		var n int64
   197  		err := rows.Scan(&s, &n)
   198  		require.NoError(t, err)
   199  
   200  		if s != "foo" {
   201  			t.Errorf(`Expected "foo", received "%v"`, s)
   202  		}
   203  		if n != rowCount {
   204  			t.Errorf("Expected %d, received %d", rowCount, n)
   205  		}
   206  	}
   207  	require.NoError(t, rows.Err())
   208  
   209  	require.EqualValues(t, 10, rowCount)
   210  
   211  	err = rows.Close()
   212  	require.NoError(t, err)
   213  
   214  	ensureDBValid(t, db)
   215  }
   216  
   217  func TestStmtExec(t *testing.T) {
   218  	db := openDB(t)
   219  	defer closeDB(t, db)
   220  
   221  	tx, err := db.Begin()
   222  	require.NoError(t, err)
   223  
   224  	createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
   225  	_, err = createStmt.Exec()
   226  	require.NoError(t, err)
   227  	closeStmt(t, createStmt)
   228  
   229  	insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
   230  	result, err := insertStmt.Exec("foo")
   231  	require.NoError(t, err)
   232  
   233  	n, err := result.RowsAffected()
   234  	require.NoError(t, err)
   235  	require.EqualValues(t, 1, n)
   236  	closeStmt(t, insertStmt)
   237  
   238  	ensureDBValid(t, db)
   239  }
   240  
   241  func TestQueryCloseRowsEarly(t *testing.T) {
   242  	db := openDB(t)
   243  	defer closeDB(t, db)
   244  
   245  	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   246  
   247  	stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
   248  	defer closeStmt(t, stmt)
   249  
   250  	rows, err := stmt.Query(int32(1), int32(10))
   251  	require.NoError(t, err)
   252  
   253  	// Close rows immediately without having read them
   254  	err = rows.Close()
   255  	require.NoError(t, err)
   256  
   257  	// Run the query again to ensure the connection and statement are still ok
   258  	rows, err = stmt.Query(int32(1), int32(10))
   259  	require.NoError(t, err)
   260  
   261  	rowCount := int64(0)
   262  
   263  	for rows.Next() {
   264  		rowCount++
   265  
   266  		var s string
   267  		var n int64
   268  		err := rows.Scan(&s, &n)
   269  		require.NoError(t, err)
   270  		if s != "foo" {
   271  			t.Errorf(`Expected "foo", received "%v"`, s)
   272  		}
   273  		if n != rowCount {
   274  			t.Errorf("Expected %d, received %d", rowCount, n)
   275  		}
   276  	}
   277  	require.NoError(t, rows.Err())
   278  	require.EqualValues(t, 10, rowCount)
   279  
   280  	err = rows.Close()
   281  	require.NoError(t, err)
   282  
   283  	ensureDBValid(t, db)
   284  }
   285  
   286  func TestConnExec(t *testing.T) {
   287  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   288  		_, err := db.Exec("create temporary table t(a varchar not null)")
   289  		require.NoError(t, err)
   290  
   291  		result, err := db.Exec("insert into t values('hey')")
   292  		require.NoError(t, err)
   293  
   294  		n, err := result.RowsAffected()
   295  		require.NoError(t, err)
   296  		require.EqualValues(t, 1, n)
   297  	})
   298  }
   299  
   300  func TestConnQuery(t *testing.T) {
   301  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   302  		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   303  
   304  		rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
   305  		require.NoError(t, err)
   306  
   307  		rowCount := int64(0)
   308  
   309  		for rows.Next() {
   310  			rowCount++
   311  
   312  			var s string
   313  			var n int64
   314  			err := rows.Scan(&s, &n)
   315  			require.NoError(t, err)
   316  			if s != "foo" {
   317  				t.Errorf(`Expected "foo", received "%v"`, s)
   318  			}
   319  			if n != rowCount {
   320  				t.Errorf("Expected %d, received %d", rowCount, n)
   321  			}
   322  		}
   323  		require.NoError(t, rows.Err())
   324  		require.EqualValues(t, 10, rowCount)
   325  
   326  		err = rows.Close()
   327  		require.NoError(t, err)
   328  	})
   329  }
   330  
   331  // https://github.com/jackc/pgx/issues/781
   332  func TestConnQueryDifferentScanPlansIssue781(t *testing.T) {
   333  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   334  		var s string
   335  		var b bool
   336  
   337  		rows, err := db.Query("select true, 'foo'")
   338  		require.NoError(t, err)
   339  
   340  		require.True(t, rows.Next())
   341  		require.NoError(t, rows.Scan(&b, &s))
   342  		assert.Equal(t, true, b)
   343  		assert.Equal(t, "foo", s)
   344  	})
   345  }
   346  
   347  func TestConnQueryNull(t *testing.T) {
   348  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   349  		rows, err := db.Query("select $1::int", nil)
   350  		require.NoError(t, err)
   351  
   352  		rowCount := int64(0)
   353  
   354  		for rows.Next() {
   355  			rowCount++
   356  
   357  			var n sql.NullInt64
   358  			err := rows.Scan(&n)
   359  			require.NoError(t, err)
   360  			if n.Valid != false {
   361  				t.Errorf("Expected n to be null, but it was %v", n)
   362  			}
   363  		}
   364  		require.NoError(t, rows.Err())
   365  		require.EqualValues(t, 1, rowCount)
   366  
   367  		err = rows.Close()
   368  		require.NoError(t, err)
   369  	})
   370  }
   371  
   372  func TestConnQueryRowByteSlice(t *testing.T) {
   373  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   374  		expected := []byte{222, 173, 190, 239}
   375  		var actual []byte
   376  
   377  		err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual)
   378  		require.NoError(t, err)
   379  		require.EqualValues(t, expected, actual)
   380  	})
   381  }
   382  
   383  func TestConnQueryFailure(t *testing.T) {
   384  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   385  		_, err := db.Query("select 'foo")
   386  		require.Error(t, err)
   387  		require.IsType(t, new(pgconn.PgError), err)
   388  	})
   389  }
   390  
   391  func TestConnSimpleSlicePassThrough(t *testing.T) {
   392  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   393  		skipCockroachDB(t, db, "Server does not support cardinality function")
   394  
   395  		var n int64
   396  		err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n)
   397  		require.NoError(t, err)
   398  		assert.EqualValues(t, 3, n)
   399  	})
   400  }
   401  
   402  // Test type that pgx would handle natively in binary, but since it is not a
   403  // database/sql native type should be passed through as a string
   404  func TestConnQueryRowPgxBinary(t *testing.T) {
   405  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   406  		sql := "select $1::int4[]"
   407  		expected := "{1,2,3}"
   408  		var actual string
   409  
   410  		err := db.QueryRow(sql, expected).Scan(&actual)
   411  		require.NoError(t, err)
   412  		require.EqualValues(t, expected, actual)
   413  	})
   414  }
   415  
   416  func TestConnQueryRowUnknownType(t *testing.T) {
   417  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   418  		skipCockroachDB(t, db, "Server does not support point type")
   419  
   420  		sql := "select $1::point"
   421  		expected := "(1,2)"
   422  		var actual string
   423  
   424  		err := db.QueryRow(sql, expected).Scan(&actual)
   425  		require.NoError(t, err)
   426  		require.EqualValues(t, expected, actual)
   427  	})
   428  }
   429  
   430  func TestConnQueryJSONIntoByteSlice(t *testing.T) {
   431  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   432  		_, err := db.Exec(`
   433  		create temporary table docs(
   434  			body json not null
   435  		);
   436  
   437  		insert into docs(body) values('{"foo": "bar"}');
   438  `)
   439  		require.NoError(t, err)
   440  
   441  		sql := `select * from docs`
   442  		expected := []byte(`{"foo": "bar"}`)
   443  		var actual []byte
   444  
   445  		err = db.QueryRow(sql).Scan(&actual)
   446  		if err != nil {
   447  			t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql)
   448  		}
   449  
   450  		if bytes.Compare(actual, expected) != 0 {
   451  			t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql)
   452  		}
   453  
   454  		_, err = db.Exec(`drop table docs`)
   455  		require.NoError(t, err)
   456  	})
   457  }
   458  
   459  func TestConnExecInsertByteSliceIntoJSON(t *testing.T) {
   460  	// Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data
   461  	// that needs to escape. No way to know whether the destination is really a text compatible or a bytea.
   462  
   463  	db := openDB(t)
   464  	defer closeDB(t, db)
   465  
   466  	_, err := db.Exec(`
   467  		create temporary table docs(
   468  			body json not null
   469  		);
   470  `)
   471  	require.NoError(t, err)
   472  
   473  	expected := []byte(`{"foo": "bar"}`)
   474  
   475  	_, err = db.Exec(`insert into docs(body) values($1)`, expected)
   476  	require.NoError(t, err)
   477  
   478  	var actual []byte
   479  	err = db.QueryRow(`select body from docs`).Scan(&actual)
   480  	require.NoError(t, err)
   481  
   482  	if bytes.Compare(actual, expected) != 0 {
   483  		t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual))
   484  	}
   485  
   486  	_, err = db.Exec(`drop table docs`)
   487  	require.NoError(t, err)
   488  }
   489  
   490  func TestTransactionLifeCycle(t *testing.T) {
   491  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   492  		_, err := db.Exec("create temporary table t(a varchar not null)")
   493  		require.NoError(t, err)
   494  
   495  		tx, err := db.Begin()
   496  		require.NoError(t, err)
   497  
   498  		_, err = tx.Exec("insert into t values('hi')")
   499  		require.NoError(t, err)
   500  
   501  		err = tx.Rollback()
   502  		require.NoError(t, err)
   503  
   504  		var n int64
   505  		err = db.QueryRow("select count(*) from t").Scan(&n)
   506  		require.NoError(t, err)
   507  		require.EqualValues(t, 0, n)
   508  
   509  		tx, err = db.Begin()
   510  		require.NoError(t, err)
   511  
   512  		_, err = tx.Exec("insert into t values('hi')")
   513  		require.NoError(t, err)
   514  
   515  		err = tx.Commit()
   516  		require.NoError(t, err)
   517  
   518  		err = db.QueryRow("select count(*) from t").Scan(&n)
   519  		require.NoError(t, err)
   520  		require.EqualValues(t, 1, n)
   521  	})
   522  }
   523  
   524  func TestConnBeginTxIsolation(t *testing.T) {
   525  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   526  		skipCockroachDB(t, db, "Server always uses serializable isolation level")
   527  
   528  		var defaultIsoLevel string
   529  		err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel)
   530  		require.NoError(t, err)
   531  
   532  		supportedTests := []struct {
   533  			sqlIso sql.IsolationLevel
   534  			pgIso  string
   535  		}{
   536  			{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
   537  			{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
   538  			{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
   539  			{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
   540  			{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
   541  			{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
   542  		}
   543  		for i, tt := range supportedTests {
   544  			func() {
   545  				tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
   546  				if err != nil {
   547  					t.Errorf("%d. BeginTx failed: %v", i, err)
   548  					return
   549  				}
   550  				defer tx.Rollback()
   551  
   552  				var pgIso string
   553  				err = tx.QueryRow("show transaction_isolation").Scan(&pgIso)
   554  				if err != nil {
   555  					t.Errorf("%d. QueryRow failed: %v", i, err)
   556  				}
   557  
   558  				if pgIso != tt.pgIso {
   559  					t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso)
   560  				}
   561  			}()
   562  		}
   563  
   564  		unsupportedTests := []struct {
   565  			sqlIso sql.IsolationLevel
   566  		}{
   567  			{sqlIso: sql.LevelWriteCommitted},
   568  			{sqlIso: sql.LevelLinearizable},
   569  		}
   570  		for i, tt := range unsupportedTests {
   571  			tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso})
   572  			if err == nil {
   573  				t.Errorf("%d. BeginTx should have failed", i)
   574  				tx.Rollback()
   575  			}
   576  		}
   577  	})
   578  }
   579  
   580  func TestConnBeginTxReadOnly(t *testing.T) {
   581  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   582  		tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
   583  		require.NoError(t, err)
   584  		defer tx.Rollback()
   585  
   586  		var pgReadOnly string
   587  		err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly)
   588  		if err != nil {
   589  			t.Errorf("QueryRow failed: %v", err)
   590  		}
   591  
   592  		if pgReadOnly != "on" {
   593  			t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on")
   594  		}
   595  	})
   596  }
   597  
   598  func TestBeginTxContextCancel(t *testing.T) {
   599  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   600  		_, err := db.Exec("drop table if exists t")
   601  		require.NoError(t, err)
   602  
   603  		ctx, cancelFn := context.WithCancel(context.Background())
   604  
   605  		tx, err := db.BeginTx(ctx, nil)
   606  		require.NoError(t, err)
   607  
   608  		_, err = tx.Exec("create table t(id serial)")
   609  		require.NoError(t, err)
   610  
   611  		cancelFn()
   612  
   613  		err = tx.Commit()
   614  		if err != context.Canceled && err != sql.ErrTxDone {
   615  			t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone)
   616  		}
   617  
   618  		var n int
   619  		err = db.QueryRow("select count(*) from t").Scan(&n)
   620  		if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" {
   621  			t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err)
   622  		}
   623  	})
   624  }
   625  
   626  func TestAcquireConn(t *testing.T) {
   627  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   628  		var conns []*pgx.Conn
   629  
   630  		for i := 1; i < 6; i++ {
   631  			conn, err := stdlib.AcquireConn(db)
   632  			if err != nil {
   633  				t.Errorf("%d. AcquireConn failed: %v", i, err)
   634  				continue
   635  			}
   636  
   637  			var n int32
   638  			err = conn.QueryRow(context.Background(), "select 1").Scan(&n)
   639  			if err != nil {
   640  				t.Errorf("%d. QueryRow failed: %v", i, err)
   641  			}
   642  			if n != 1 {
   643  				t.Errorf("%d. n => %d, want %d", i, n, 1)
   644  			}
   645  
   646  			stats := db.Stats()
   647  			if stats.OpenConnections != i {
   648  				t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i)
   649  			}
   650  
   651  			conns = append(conns, conn)
   652  		}
   653  
   654  		for i, conn := range conns {
   655  			if err := stdlib.ReleaseConn(db, conn); err != nil {
   656  				t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err)
   657  			}
   658  		}
   659  	})
   660  }
   661  
   662  func TestConnRaw(t *testing.T) {
   663  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   664  		conn, err := db.Conn(context.Background())
   665  		require.NoError(t, err)
   666  
   667  		var n int
   668  		err = conn.Raw(func(driverConn interface{}) error {
   669  			conn := driverConn.(*stdlib.Conn).Conn()
   670  			return conn.QueryRow(context.Background(), "select 42").Scan(&n)
   671  		})
   672  		require.NoError(t, err)
   673  		assert.EqualValues(t, 42, n)
   674  	})
   675  }
   676  
   677  // https://github.com/jackc/pgx/issues/673
   678  func TestReleaseConnWithTxInProgress(t *testing.T) {
   679  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   680  		skipCockroachDB(t, db, "Server does not support backend PID")
   681  
   682  		c1, err := stdlib.AcquireConn(db)
   683  		require.NoError(t, err)
   684  
   685  		_, err = c1.Exec(context.Background(), "begin")
   686  		require.NoError(t, err)
   687  
   688  		c1PID := c1.PgConn().PID()
   689  
   690  		err = stdlib.ReleaseConn(db, c1)
   691  		require.NoError(t, err)
   692  
   693  		c2, err := stdlib.AcquireConn(db)
   694  		require.NoError(t, err)
   695  
   696  		c2PID := c2.PgConn().PID()
   697  
   698  		err = stdlib.ReleaseConn(db, c2)
   699  		require.NoError(t, err)
   700  
   701  		require.NotEqual(t, c1PID, c2PID)
   702  
   703  		// Releasing a conn with a tx in progress should close the connection
   704  		stats := db.Stats()
   705  		require.Equal(t, 1, stats.OpenConnections)
   706  	})
   707  }
   708  
   709  func TestConnPingContextSuccess(t *testing.T) {
   710  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   711  		err := db.PingContext(context.Background())
   712  		require.NoError(t, err)
   713  	})
   714  }
   715  
   716  func TestConnPrepareContextSuccess(t *testing.T) {
   717  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   718  		stmt, err := db.PrepareContext(context.Background(), "select now()")
   719  		require.NoError(t, err)
   720  		err = stmt.Close()
   721  		require.NoError(t, err)
   722  	})
   723  }
   724  
   725  func TestConnExecContextSuccess(t *testing.T) {
   726  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   727  		_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")
   728  		require.NoError(t, err)
   729  	})
   730  }
   731  
   732  func TestConnExecContextFailureRetry(t *testing.T) {
   733  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   734  		// We get a connection, immediately close it, and then get it back;
   735  		// DB.Conn along with Conn.ResetSession does the retry for us.
   736  		{
   737  			conn, err := stdlib.AcquireConn(db)
   738  			require.NoError(t, err)
   739  			conn.Close(context.Background())
   740  			stdlib.ReleaseConn(db, conn)
   741  		}
   742  		conn, err := db.Conn(context.Background())
   743  		require.NoError(t, err)
   744  		_, err = conn.ExecContext(context.Background(), "select 1")
   745  		require.NoError(t, err)
   746  	})
   747  }
   748  
   749  func TestConnQueryContextSuccess(t *testing.T) {
   750  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   751  		rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n")
   752  		require.NoError(t, err)
   753  
   754  		for rows.Next() {
   755  			var n int64
   756  			err := rows.Scan(&n)
   757  			require.NoError(t, err)
   758  		}
   759  		require.NoError(t, rows.Err())
   760  	})
   761  }
   762  
   763  func TestConnQueryContextFailureRetry(t *testing.T) {
   764  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   765  		// We get a connection, immediately close it, and then get it back;
   766  		// DB.Conn along with Conn.ResetSession does the retry for us.
   767  		{
   768  			conn, err := stdlib.AcquireConn(db)
   769  			require.NoError(t, err)
   770  			conn.Close(context.Background())
   771  			stdlib.ReleaseConn(db, conn)
   772  		}
   773  		conn, err := db.Conn(context.Background())
   774  		require.NoError(t, err)
   775  
   776  		_, err = conn.QueryContext(context.Background(), "select 1")
   777  		require.NoError(t, err)
   778  	})
   779  }
   780  
   781  func TestRowsColumnTypeDatabaseTypeName(t *testing.T) {
   782  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   783  		rows, err := db.Query("select 42::bigint")
   784  		require.NoError(t, err)
   785  
   786  		columnTypes, err := rows.ColumnTypes()
   787  		require.NoError(t, err)
   788  		require.Len(t, columnTypes, 1)
   789  
   790  		if columnTypes[0].DatabaseTypeName() != "INT8" {
   791  			t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8")
   792  		}
   793  
   794  		err = rows.Close()
   795  		require.NoError(t, err)
   796  	})
   797  }
   798  
   799  func TestStmtExecContextSuccess(t *testing.T) {
   800  	db := openDB(t)
   801  	defer closeDB(t, db)
   802  
   803  	_, err := db.Exec("create temporary table t(id int primary key)")
   804  	require.NoError(t, err)
   805  
   806  	stmt, err := db.Prepare("insert into t(id) values ($1::int4)")
   807  	require.NoError(t, err)
   808  	defer stmt.Close()
   809  
   810  	_, err = stmt.ExecContext(context.Background(), 42)
   811  	require.NoError(t, err)
   812  
   813  	ensureDBValid(t, db)
   814  }
   815  
   816  func TestStmtExecContextCancel(t *testing.T) {
   817  	db := openDB(t)
   818  	defer closeDB(t, db)
   819  
   820  	_, err := db.Exec("create temporary table t(id int primary key)")
   821  	require.NoError(t, err)
   822  
   823  	stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)")
   824  	require.NoError(t, err)
   825  	defer stmt.Close()
   826  
   827  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   828  	defer cancel()
   829  
   830  	_, err = stmt.ExecContext(ctx, 42)
   831  	if !pgconn.Timeout(err) {
   832  		t.Errorf("expected timeout error, got %v", err)
   833  	}
   834  
   835  	ensureDBValid(t, db)
   836  }
   837  
   838  func TestStmtQueryContextSuccess(t *testing.T) {
   839  	db := openDB(t)
   840  	defer closeDB(t, db)
   841  
   842  	skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   843  
   844  	stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
   845  	require.NoError(t, err)
   846  	defer stmt.Close()
   847  
   848  	rows, err := stmt.QueryContext(context.Background(), 5)
   849  	require.NoError(t, err)
   850  
   851  	for rows.Next() {
   852  		var n int64
   853  		if err := rows.Scan(&n); err != nil {
   854  			t.Error(err)
   855  		}
   856  	}
   857  
   858  	if rows.Err() != nil {
   859  		t.Error(rows.Err())
   860  	}
   861  
   862  	ensureDBValid(t, db)
   863  }
   864  
   865  func TestRowsColumnTypes(t *testing.T) {
   866  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
   867  		columnTypesTests := []struct {
   868  			Name     string
   869  			TypeName string
   870  			Length   struct {
   871  				Len int64
   872  				OK  bool
   873  			}
   874  			DecimalSize struct {
   875  				Precision int64
   876  				Scale     int64
   877  				OK        bool
   878  			}
   879  			ScanType reflect.Type
   880  		}{
   881  			{
   882  				Name:     "a",
   883  				TypeName: "INT8",
   884  				Length: struct {
   885  					Len int64
   886  					OK  bool
   887  				}{
   888  					Len: 0,
   889  					OK:  false,
   890  				},
   891  				DecimalSize: struct {
   892  					Precision int64
   893  					Scale     int64
   894  					OK        bool
   895  				}{
   896  					Precision: 0,
   897  					Scale:     0,
   898  					OK:        false,
   899  				},
   900  				ScanType: reflect.TypeOf(int64(0)),
   901  			}, {
   902  				Name:     "bar",
   903  				TypeName: "TEXT",
   904  				Length: struct {
   905  					Len int64
   906  					OK  bool
   907  				}{
   908  					Len: math.MaxInt64,
   909  					OK:  true,
   910  				},
   911  				DecimalSize: struct {
   912  					Precision int64
   913  					Scale     int64
   914  					OK        bool
   915  				}{
   916  					Precision: 0,
   917  					Scale:     0,
   918  					OK:        false,
   919  				},
   920  				ScanType: reflect.TypeOf(""),
   921  			}, {
   922  				Name:     "dec",
   923  				TypeName: "NUMERIC",
   924  				Length: struct {
   925  					Len int64
   926  					OK  bool
   927  				}{
   928  					Len: 0,
   929  					OK:  false,
   930  				},
   931  				DecimalSize: struct {
   932  					Precision int64
   933  					Scale     int64
   934  					OK        bool
   935  				}{
   936  					Precision: 9,
   937  					Scale:     2,
   938  					OK:        true,
   939  				},
   940  				ScanType: reflect.TypeOf(float64(0)),
   941  			}, {
   942  				Name:     "d",
   943  				TypeName: "1266",
   944  				Length: struct {
   945  					Len int64
   946  					OK  bool
   947  				}{
   948  					Len: 0,
   949  					OK:  false,
   950  				},
   951  				DecimalSize: struct {
   952  					Precision int64
   953  					Scale     int64
   954  					OK        bool
   955  				}{
   956  					Precision: 0,
   957  					Scale:     0,
   958  					OK:        false,
   959  				},
   960  				ScanType: reflect.TypeOf(""),
   961  			},
   962  		}
   963  
   964  		rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d")
   965  		require.NoError(t, err)
   966  
   967  		columns, err := rows.ColumnTypes()
   968  		require.NoError(t, err)
   969  		assert.Len(t, columns, 4)
   970  
   971  		for i, tt := range columnTypesTests {
   972  			c := columns[i]
   973  			if c.Name() != tt.Name {
   974  				t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name)
   975  			}
   976  			if c.DatabaseTypeName() != tt.TypeName {
   977  				t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName)
   978  			}
   979  			l, ok := c.Length()
   980  			if l != tt.Length.Len {
   981  				t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len)
   982  			}
   983  			if ok != tt.Length.OK {
   984  				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK)
   985  			}
   986  			p, s, ok := c.DecimalSize()
   987  			if p != tt.DecimalSize.Precision {
   988  				t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision)
   989  			}
   990  			if s != tt.DecimalSize.Scale {
   991  				t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale)
   992  			}
   993  			if ok != tt.DecimalSize.OK {
   994  				t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK)
   995  			}
   996  			if c.ScanType() != tt.ScanType {
   997  				t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType)
   998  			}
   999  		}
  1000  	})
  1001  }
  1002  
  1003  func TestQueryLifeCycle(t *testing.T) {
  1004  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
  1005  		skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
  1006  
  1007  		rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
  1008  		require.NoError(t, err)
  1009  
  1010  		rowCount := int64(0)
  1011  
  1012  		for rows.Next() {
  1013  			rowCount++
  1014  			var (
  1015  				s string
  1016  				n int64
  1017  			)
  1018  
  1019  			err := rows.Scan(&s, &n)
  1020  			require.NoError(t, err)
  1021  
  1022  			if s != "foo" {
  1023  				t.Errorf(`Expected "foo", received "%v"`, s)
  1024  			}
  1025  
  1026  			if n != rowCount {
  1027  				t.Errorf("Expected %d, received %d", rowCount, n)
  1028  			}
  1029  		}
  1030  		require.NoError(t, rows.Err())
  1031  
  1032  		err = rows.Close()
  1033  		require.NoError(t, err)
  1034  
  1035  		rows, err = db.Query("select 1 where false")
  1036  		require.NoError(t, err)
  1037  
  1038  		rowCount = int64(0)
  1039  
  1040  		for rows.Next() {
  1041  			rowCount++
  1042  		}
  1043  		require.NoError(t, rows.Err())
  1044  		require.EqualValues(t, 0, rowCount)
  1045  
  1046  		err = rows.Close()
  1047  		require.NoError(t, err)
  1048  	})
  1049  }
  1050  
  1051  // https://github.com/jackc/pgx/issues/409
  1052  func TestScanJSONIntoJSONRawMessage(t *testing.T) {
  1053  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
  1054  		var msg json.RawMessage
  1055  
  1056  		err := db.QueryRow("select '{}'::json").Scan(&msg)
  1057  		require.NoError(t, err)
  1058  		require.EqualValues(t, []byte("{}"), []byte(msg))
  1059  	})
  1060  }
  1061  
  1062  type testLog struct {
  1063  	lvl  pgx.LogLevel
  1064  	msg  string
  1065  	data map[string]interface{}
  1066  }
  1067  
  1068  type testLogger struct {
  1069  	logs []testLog
  1070  }
  1071  
  1072  func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]interface{}) {
  1073  	l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
  1074  }
  1075  
  1076  func TestRegisterConnConfig(t *testing.T) {
  1077  	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
  1078  	require.NoError(t, err)
  1079  
  1080  	logger := &testLogger{}
  1081  	connConfig.Logger = logger
  1082  
  1083  	// Issue 947: Register and unregister a ConnConfig and ensure that the
  1084  	// returned connection string is not reused.
  1085  	connStr := stdlib.RegisterConnConfig(connConfig)
  1086  	require.Equal(t, "registeredConnConfig0", connStr)
  1087  	stdlib.UnregisterConnConfig(connStr)
  1088  
  1089  	connStr = stdlib.RegisterConnConfig(connConfig)
  1090  	defer stdlib.UnregisterConnConfig(connStr)
  1091  	require.Equal(t, "registeredConnConfig1", connStr)
  1092  
  1093  	db, err := sql.Open("pgx", connStr)
  1094  	require.NoError(t, err)
  1095  	defer closeDB(t, db)
  1096  
  1097  	var n int64
  1098  	err = db.QueryRow("select 1").Scan(&n)
  1099  	require.NoError(t, err)
  1100  
  1101  	l := logger.logs[len(logger.logs)-1]
  1102  	assert.Equal(t, "Query", l.msg)
  1103  	assert.Equal(t, "select 1", l.data["sql"])
  1104  }
  1105  
  1106  // https://github.com/jackc/pgx/issues/958
  1107  func TestConnQueryRowConstraintErrors(t *testing.T) {
  1108  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
  1109  		skipPostgreSQLVersion(t, db, "< 11", "Test requires PG 11+")
  1110  		skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
  1111  
  1112  		_, err := db.Exec(`create temporary table defer_test (
  1113  			id text primary key,
  1114  			n int not null, unique (n),
  1115  			unique (n) deferrable initially deferred )`)
  1116  		require.NoError(t, err)
  1117  
  1118  		_, err = db.Exec(`drop function if exists test_trigger cascade`)
  1119  		require.NoError(t, err)
  1120  
  1121  		_, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$
  1122  		begin
  1123  		if new.n = 4 then
  1124  			raise exception 'n cant be 4!';
  1125  		end if;
  1126  		return new;
  1127  	end$$`)
  1128  		require.NoError(t, err)
  1129  
  1130  		_, err = db.Exec(`create constraint trigger test
  1131  			after insert or update on defer_test
  1132  			deferrable initially deferred
  1133  			for each row
  1134  			execute function test_trigger()`)
  1135  		require.NoError(t, err)
  1136  
  1137  		_, err = db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
  1138  		require.NoError(t, err)
  1139  
  1140  		var id string
  1141  		err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
  1142  		assert.Error(t, err)
  1143  	})
  1144  }
  1145  
  1146  func TestOptionBeforeAfterConnect(t *testing.T) {
  1147  	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
  1148  	require.NoError(t, err)
  1149  
  1150  	var beforeConnConfigs []*pgx.ConnConfig
  1151  	var afterConns []*pgx.Conn
  1152  	db := stdlib.OpenDB(*config,
  1153  		stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
  1154  			beforeConnConfigs = append(beforeConnConfigs, connConfig)
  1155  			return nil
  1156  		}),
  1157  		stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
  1158  			afterConns = append(afterConns, conn)
  1159  			return nil
  1160  		}))
  1161  	defer closeDB(t, db)
  1162  
  1163  	// Force it to close and reopen a new connection after each query
  1164  	db.SetMaxIdleConns(0)
  1165  
  1166  	_, err = db.Exec("select 1")
  1167  	require.NoError(t, err)
  1168  
  1169  	_, err = db.Exec("select 1")
  1170  	require.NoError(t, err)
  1171  
  1172  	require.Len(t, beforeConnConfigs, 2)
  1173  	require.Len(t, afterConns, 2)
  1174  
  1175  	// Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we wean to ensure they
  1176  	// are different objects, so can't use require.NotEqual
  1177  	require.False(t, config == beforeConnConfigs[0])
  1178  	require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
  1179  }
  1180  
  1181  func TestRandomizeHostOrderFunc(t *testing.T) {
  1182  	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
  1183  	require.NoError(t, err)
  1184  
  1185  	// Test that at some point we connect to all 3 hosts
  1186  	hostsNotSeenYet := map[string]struct{}{
  1187  		"host1": struct{}{},
  1188  		"host2": struct{}{},
  1189  		"host3": struct{}{},
  1190  	}
  1191  
  1192  	// If we don't succeed within this many iterations, something is certainly wrong
  1193  	for i := 0; i < 100000; i++ {
  1194  		connCopy := *config
  1195  		stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)
  1196  
  1197  		delete(hostsNotSeenYet, connCopy.Host)
  1198  		if len(hostsNotSeenYet) == 0 {
  1199  			return
  1200  		}
  1201  
  1202  	hostCheckLoop:
  1203  		for _, h := range []string{"host1", "host2", "host3"} {
  1204  			if connCopy.Host == h {
  1205  				continue
  1206  			}
  1207  			for _, f := range connCopy.Fallbacks {
  1208  				if f.Host == h {
  1209  					continue hostCheckLoop
  1210  				}
  1211  			}
  1212  			require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
  1213  		}
  1214  	}
  1215  
  1216  	require.Fail(t, "did not get all hosts as primaries after many randomizations")
  1217  }
  1218  
  1219  func TestResetSessionHookCalled(t *testing.T) {
  1220  	var mockCalled bool
  1221  
  1222  	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
  1223  	require.NoError(t, err)
  1224  
  1225  	db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
  1226  		mockCalled = true
  1227  
  1228  		return nil
  1229  	}))
  1230  
  1231  	defer closeDB(t, db)
  1232  
  1233  	err = db.Ping()
  1234  	require.NoError(t, err)
  1235  
  1236  	err = db.Ping()
  1237  	require.NoError(t, err)
  1238  
  1239  	require.True(t, mockCalled)
  1240  }
  1241  

View as plain text