...

Source file src/github.com/jackc/pgx/v5/tx_test.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"os"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/jackc/pgx/v5"
    11  	"github.com/jackc/pgx/v5/pgconn"
    12  	"github.com/jackc/pgx/v5/pgxtest"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func TestTransactionSuccessfulCommit(t *testing.T) {
    17  	t.Parallel()
    18  
    19  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    20  	defer closeConn(t, conn)
    21  
    22  	createSql := `
    23      create temporary table foo(
    24        id integer,
    25        unique (id)
    26      );
    27    `
    28  
    29  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
    30  		t.Fatalf("Failed to create table: %v", err)
    31  	}
    32  
    33  	tx, err := conn.Begin(context.Background())
    34  	if err != nil {
    35  		t.Fatalf("conn.Begin failed: %v", err)
    36  	}
    37  
    38  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
    39  	if err != nil {
    40  		t.Fatalf("tx.Exec failed: %v", err)
    41  	}
    42  
    43  	err = tx.Commit(context.Background())
    44  	if err != nil {
    45  		t.Fatalf("tx.Commit failed: %v", err)
    46  	}
    47  
    48  	var n int64
    49  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
    50  	if err != nil {
    51  		t.Fatalf("QueryRow Scan failed: %v", err)
    52  	}
    53  	if n != 1 {
    54  		t.Fatalf("Did not receive correct number of rows: %v", n)
    55  	}
    56  }
    57  
    58  func TestTxCommitWhenTxBroken(t *testing.T) {
    59  	t.Parallel()
    60  
    61  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    62  	defer closeConn(t, conn)
    63  
    64  	createSql := `
    65      create temporary table foo(
    66        id integer,
    67        unique (id)
    68      );
    69    `
    70  
    71  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
    72  		t.Fatalf("Failed to create table: %v", err)
    73  	}
    74  
    75  	tx, err := conn.Begin(context.Background())
    76  	if err != nil {
    77  		t.Fatalf("conn.Begin failed: %v", err)
    78  	}
    79  
    80  	if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
    81  		t.Fatalf("tx.Exec failed: %v", err)
    82  	}
    83  
    84  	// Purposely break transaction
    85  	if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
    86  		t.Fatal("Unexpected success")
    87  	}
    88  
    89  	err = tx.Commit(context.Background())
    90  	if err != pgx.ErrTxCommitRollback {
    91  		t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
    92  	}
    93  
    94  	var n int64
    95  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
    96  	if err != nil {
    97  		t.Fatalf("QueryRow Scan failed: %v", err)
    98  	}
    99  	if n != 0 {
   100  		t.Fatalf("Did not receive correct number of rows: %v", n)
   101  	}
   102  }
   103  
   104  func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) {
   105  	t.Parallel()
   106  
   107  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   108  	defer closeConn(t, conn)
   109  
   110  	pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   111  
   112  	createSql := `
   113      create temporary table foo(
   114        id integer,
   115        unique (id) initially deferred
   116      );
   117    `
   118  
   119  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   120  		t.Fatalf("Failed to create table: %v", err)
   121  	}
   122  
   123  	tx, err := conn.Begin(context.Background())
   124  	if err != nil {
   125  		t.Fatalf("conn.Begin failed: %v", err)
   126  	}
   127  
   128  	if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
   129  		t.Fatalf("tx.Exec failed: %v", err)
   130  	}
   131  
   132  	if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
   133  		t.Fatalf("tx.Exec failed: %v", err)
   134  	}
   135  
   136  	err = tx.Commit(context.Background())
   137  	if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" {
   138  		t.Fatalf("Expected unique constraint violation 23505, got %#v", err)
   139  	}
   140  
   141  	var n int64
   142  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   143  	if err != nil {
   144  		t.Fatalf("QueryRow Scan failed: %v", err)
   145  	}
   146  	if n != 0 {
   147  		t.Fatalf("Did not receive correct number of rows: %v", n)
   148  	}
   149  }
   150  
   151  func TestTxCommitSerializationFailure(t *testing.T) {
   152  	t.Parallel()
   153  
   154  	c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   155  	defer closeConn(t, c1)
   156  
   157  	if c1.PgConn().ParameterStatus("crdb_version") != "" {
   158  		t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)")
   159  	}
   160  
   161  	c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   162  	defer closeConn(t, c2)
   163  
   164  	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
   165  	defer cancel()
   166  
   167  	c1.Exec(ctx, `drop table if exists tx_serializable_sums`)
   168  	_, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`)
   169  	if err != nil {
   170  		t.Fatalf("Unable to create temporary table: %v", err)
   171  	}
   172  	defer c1.Exec(ctx, `drop table tx_serializable_sums`)
   173  
   174  	tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
   175  	if err != nil {
   176  		t.Fatalf("Begin failed: %v", err)
   177  	}
   178  	defer tx1.Rollback(ctx)
   179  
   180  	tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
   181  	if err != nil {
   182  		t.Fatalf("Begin failed: %v", err)
   183  	}
   184  	defer tx2.Rollback(ctx)
   185  
   186  	_, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
   187  	if err != nil {
   188  		t.Fatalf("Exec failed: %v", err)
   189  	}
   190  
   191  	_, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
   192  	if err != nil {
   193  		t.Fatalf("Exec failed: %v", err)
   194  	}
   195  
   196  	err = tx1.Commit(ctx)
   197  	if err != nil {
   198  		t.Fatalf("Commit failed: %v", err)
   199  	}
   200  
   201  	err = tx2.Commit(ctx)
   202  	if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
   203  		t.Fatalf("Expected serialization error 40001, got %#v", err)
   204  	}
   205  
   206  	ensureConnValid(t, c1)
   207  	ensureConnValid(t, c2)
   208  }
   209  
   210  func TestTransactionSuccessfulRollback(t *testing.T) {
   211  	t.Parallel()
   212  
   213  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   214  	defer closeConn(t, conn)
   215  
   216  	createSql := `
   217      create temporary table foo(
   218        id integer,
   219        unique (id)
   220      );
   221    `
   222  
   223  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   224  		t.Fatalf("Failed to create table: %v", err)
   225  	}
   226  
   227  	tx, err := conn.Begin(context.Background())
   228  	if err != nil {
   229  		t.Fatalf("conn.Begin failed: %v", err)
   230  	}
   231  
   232  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
   233  	if err != nil {
   234  		t.Fatalf("tx.Exec failed: %v", err)
   235  	}
   236  
   237  	err = tx.Rollback(context.Background())
   238  	if err != nil {
   239  		t.Fatalf("tx.Rollback failed: %v", err)
   240  	}
   241  
   242  	var n int64
   243  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   244  	if err != nil {
   245  		t.Fatalf("QueryRow Scan failed: %v", err)
   246  	}
   247  	if n != 0 {
   248  		t.Fatalf("Did not receive correct number of rows: %v", n)
   249  	}
   250  }
   251  
   252  func TestTransactionRollbackFailsClosesConnection(t *testing.T) {
   253  	t.Parallel()
   254  
   255  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   256  	defer closeConn(t, conn)
   257  
   258  	ctx, cancel := context.WithCancel(context.Background())
   259  
   260  	tx, err := conn.Begin(ctx)
   261  	require.NoError(t, err)
   262  
   263  	cancel()
   264  
   265  	err = tx.Rollback(ctx)
   266  	require.Error(t, err)
   267  
   268  	require.True(t, conn.IsClosed())
   269  }
   270  
   271  func TestBeginIsoLevels(t *testing.T) {
   272  	t.Parallel()
   273  
   274  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   275  	defer closeConn(t, conn)
   276  
   277  	pgxtest.SkipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)")
   278  
   279  	isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
   280  	for _, iso := range isoLevels {
   281  		tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso})
   282  		if err != nil {
   283  			t.Fatalf("conn.Begin failed: %v", err)
   284  		}
   285  
   286  		var level pgx.TxIsoLevel
   287  		conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level)
   288  		if level != iso {
   289  			t.Errorf("Expected to be in isolation level %v but was %v", iso, level)
   290  		}
   291  
   292  		err = tx.Rollback(context.Background())
   293  		if err != nil {
   294  			t.Fatalf("tx.Rollback failed: %v", err)
   295  		}
   296  	}
   297  }
   298  
   299  func TestBeginFunc(t *testing.T) {
   300  	t.Parallel()
   301  
   302  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   303  	defer closeConn(t, conn)
   304  
   305  	createSql := `
   306      create temporary table foo(
   307        id integer,
   308        unique (id)
   309      );
   310    `
   311  
   312  	_, err := conn.Exec(context.Background(), createSql)
   313  	require.NoError(t, err)
   314  
   315  	err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
   316  		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
   317  		require.NoError(t, err)
   318  		return nil
   319  	})
   320  	require.NoError(t, err)
   321  
   322  	var n int64
   323  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   324  	require.NoError(t, err)
   325  	require.EqualValues(t, 1, n)
   326  }
   327  
   328  func TestBeginFuncRollbackOnError(t *testing.T) {
   329  	t.Parallel()
   330  
   331  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   332  	defer closeConn(t, conn)
   333  
   334  	createSql := `
   335      create temporary table foo(
   336        id integer,
   337        unique (id)
   338      );
   339    `
   340  
   341  	_, err := conn.Exec(context.Background(), createSql)
   342  	require.NoError(t, err)
   343  
   344  	err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
   345  		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
   346  		require.NoError(t, err)
   347  		return errors.New("some error")
   348  	})
   349  	require.EqualError(t, err, "some error")
   350  
   351  	var n int64
   352  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   353  	require.NoError(t, err)
   354  	require.EqualValues(t, 0, n)
   355  }
   356  
   357  func TestBeginReadOnly(t *testing.T) {
   358  	t.Parallel()
   359  
   360  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   361  	defer closeConn(t, conn)
   362  
   363  	tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly})
   364  	if err != nil {
   365  		t.Fatalf("conn.Begin failed: %v", err)
   366  	}
   367  	defer tx.Rollback(context.Background())
   368  
   369  	_, err = conn.Exec(context.Background(), "create table foo(id serial primary key)")
   370  	if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" {
   371  		t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
   372  	}
   373  }
   374  
   375  func TestBeginTxBeginQuery(t *testing.T) {
   376  	t.Parallel()
   377  
   378  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   379  	defer cancel()
   380  
   381  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   382  		tx, err := conn.BeginTx(ctx, pgx.TxOptions{BeginQuery: "begin read only"})
   383  		require.NoError(t, err)
   384  		defer tx.Rollback(ctx)
   385  
   386  		var readOnly bool
   387  		conn.QueryRow(ctx, "select current_setting('transaction_read_only')::bool").Scan(&readOnly)
   388  		require.True(t, readOnly)
   389  
   390  		err = tx.Rollback(ctx)
   391  		require.NoError(t, err)
   392  	})
   393  }
   394  
   395  func TestTxNestedTransactionCommit(t *testing.T) {
   396  	t.Parallel()
   397  
   398  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   399  	defer closeConn(t, conn)
   400  
   401  	createSql := `
   402      create temporary table foo(
   403        id integer,
   404        unique (id)
   405      );
   406    `
   407  
   408  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   409  		t.Fatalf("Failed to create table: %v", err)
   410  	}
   411  
   412  	tx, err := conn.Begin(context.Background())
   413  	if err != nil {
   414  		t.Fatal(err)
   415  	}
   416  
   417  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
   418  	if err != nil {
   419  		t.Fatalf("tx.Exec failed: %v", err)
   420  	}
   421  
   422  	nestedTx, err := tx.Begin(context.Background())
   423  	if err != nil {
   424  		t.Fatal(err)
   425  	}
   426  
   427  	_, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
   428  	if err != nil {
   429  		t.Fatalf("nestedTx.Exec failed: %v", err)
   430  	}
   431  
   432  	doubleNestedTx, err := nestedTx.Begin(context.Background())
   433  	if err != nil {
   434  		t.Fatal(err)
   435  	}
   436  
   437  	_, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)")
   438  	if err != nil {
   439  		t.Fatalf("doubleNestedTx.Exec failed: %v", err)
   440  	}
   441  
   442  	err = doubleNestedTx.Commit(context.Background())
   443  	if err != nil {
   444  		t.Fatalf("doubleNestedTx.Commit failed: %v", err)
   445  	}
   446  
   447  	err = nestedTx.Commit(context.Background())
   448  	if err != nil {
   449  		t.Fatalf("nestedTx.Commit failed: %v", err)
   450  	}
   451  
   452  	err = tx.Commit(context.Background())
   453  	if err != nil {
   454  		t.Fatalf("tx.Commit failed: %v", err)
   455  	}
   456  
   457  	var n int64
   458  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   459  	if err != nil {
   460  		t.Fatalf("QueryRow Scan failed: %v", err)
   461  	}
   462  	if n != 3 {
   463  		t.Fatalf("Did not receive correct number of rows: %v", n)
   464  	}
   465  }
   466  
   467  func TestTxNestedTransactionRollback(t *testing.T) {
   468  	t.Parallel()
   469  
   470  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   471  	defer closeConn(t, conn)
   472  
   473  	createSql := `
   474      create temporary table foo(
   475        id integer,
   476        unique (id)
   477      );
   478    `
   479  
   480  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   481  		t.Fatalf("Failed to create table: %v", err)
   482  	}
   483  
   484  	tx, err := conn.Begin(context.Background())
   485  	if err != nil {
   486  		t.Fatal(err)
   487  	}
   488  
   489  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
   490  	if err != nil {
   491  		t.Fatalf("tx.Exec failed: %v", err)
   492  	}
   493  
   494  	nestedTx, err := tx.Begin(context.Background())
   495  	if err != nil {
   496  		t.Fatal(err)
   497  	}
   498  
   499  	_, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
   500  	if err != nil {
   501  		t.Fatalf("nestedTx.Exec failed: %v", err)
   502  	}
   503  
   504  	err = nestedTx.Rollback(context.Background())
   505  	if err != nil {
   506  		t.Fatalf("nestedTx.Rollback failed: %v", err)
   507  	}
   508  
   509  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (3)")
   510  	if err != nil {
   511  		t.Fatalf("tx.Exec failed: %v", err)
   512  	}
   513  
   514  	err = tx.Commit(context.Background())
   515  	if err != nil {
   516  		t.Fatalf("tx.Commit failed: %v", err)
   517  	}
   518  
   519  	var n int64
   520  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   521  	if err != nil {
   522  		t.Fatalf("QueryRow Scan failed: %v", err)
   523  	}
   524  	if n != 2 {
   525  		t.Fatalf("Did not receive correct number of rows: %v", n)
   526  	}
   527  }
   528  
   529  func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
   530  	t.Parallel()
   531  
   532  	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   533  	defer closeConn(t, db)
   534  
   535  	createSql := `
   536      create temporary table foo(
   537        id integer,
   538        unique (id)
   539      );
   540    `
   541  
   542  	_, err := db.Exec(context.Background(), createSql)
   543  	require.NoError(t, err)
   544  
   545  	err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
   546  		_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
   547  		require.NoError(t, err)
   548  
   549  		err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
   550  			_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
   551  			require.NoError(t, err)
   552  
   553  			err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
   554  				_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
   555  				require.NoError(t, err)
   556  				return nil
   557  			})
   558  			require.NoError(t, err)
   559  
   560  			return nil
   561  		})
   562  		require.NoError(t, err)
   563  		return nil
   564  	})
   565  	require.NoError(t, err)
   566  
   567  	var n int64
   568  	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   569  	require.NoError(t, err)
   570  	require.EqualValues(t, 3, n)
   571  }
   572  
   573  func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
   574  	t.Parallel()
   575  
   576  	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   577  	defer closeConn(t, db)
   578  
   579  	createSql := `
   580      create temporary table foo(
   581        id integer,
   582        unique (id)
   583      );
   584    `
   585  
   586  	_, err := db.Exec(context.Background(), createSql)
   587  	require.NoError(t, err)
   588  
   589  	err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
   590  		_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
   591  		require.NoError(t, err)
   592  
   593  		err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
   594  			_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
   595  			require.NoError(t, err)
   596  			return errors.New("do a rollback")
   597  		})
   598  		require.EqualError(t, err, "do a rollback")
   599  
   600  		_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
   601  		require.NoError(t, err)
   602  
   603  		return nil
   604  	})
   605  	require.NoError(t, err)
   606  
   607  	var n int64
   608  	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   609  	require.NoError(t, err)
   610  	require.EqualValues(t, 2, n)
   611  }
   612  
   613  func TestTxSendBatchClosed(t *testing.T) {
   614  	t.Parallel()
   615  
   616  	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   617  	defer closeConn(t, db)
   618  
   619  	tx, err := db.Begin(context.Background())
   620  	require.NoError(t, err)
   621  	defer tx.Rollback(context.Background())
   622  
   623  	err = tx.Commit(context.Background())
   624  	require.NoError(t, err)
   625  
   626  	batch := &pgx.Batch{}
   627  	batch.Queue("select 1")
   628  	batch.Queue("select 2")
   629  	batch.Queue("select 3")
   630  
   631  	br := tx.SendBatch(context.Background(), batch)
   632  	defer br.Close()
   633  
   634  	var n int
   635  
   636  	_, err = br.Exec()
   637  	require.Error(t, err)
   638  
   639  	err = br.QueryRow().Scan(&n)
   640  	require.Error(t, err)
   641  
   642  	_, err = br.Query()
   643  	require.Error(t, err)
   644  }
   645  

View as plain text