...

Source file src/github.com/jackc/pgx/v4/tx_test.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"os"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/jackc/pgconn"
    11  	"github.com/jackc/pgx/v4"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func TestTransactionSuccessfulCommit(t *testing.T) {
    16  	t.Parallel()
    17  
    18  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    19  	defer closeConn(t, conn)
    20  
    21  	createSql := `
    22      create temporary table foo(
    23        id integer,
    24        unique (id)
    25      );
    26    `
    27  
    28  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
    29  		t.Fatalf("Failed to create table: %v", err)
    30  	}
    31  
    32  	tx, err := conn.Begin(context.Background())
    33  	if err != nil {
    34  		t.Fatalf("conn.Begin failed: %v", err)
    35  	}
    36  
    37  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
    38  	if err != nil {
    39  		t.Fatalf("tx.Exec failed: %v", err)
    40  	}
    41  
    42  	err = tx.Commit(context.Background())
    43  	if err != nil {
    44  		t.Fatalf("tx.Commit failed: %v", err)
    45  	}
    46  
    47  	var n int64
    48  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
    49  	if err != nil {
    50  		t.Fatalf("QueryRow Scan failed: %v", err)
    51  	}
    52  	if n != 1 {
    53  		t.Fatalf("Did not receive correct number of rows: %v", n)
    54  	}
    55  }
    56  
    57  func TestTxCommitWhenTxBroken(t *testing.T) {
    58  	t.Parallel()
    59  
    60  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    61  	defer closeConn(t, conn)
    62  
    63  	createSql := `
    64      create temporary table foo(
    65        id integer,
    66        unique (id)
    67      );
    68    `
    69  
    70  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
    71  		t.Fatalf("Failed to create table: %v", err)
    72  	}
    73  
    74  	tx, err := conn.Begin(context.Background())
    75  	if err != nil {
    76  		t.Fatalf("conn.Begin failed: %v", err)
    77  	}
    78  
    79  	if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
    80  		t.Fatalf("tx.Exec failed: %v", err)
    81  	}
    82  
    83  	// Purposely break transaction
    84  	if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
    85  		t.Fatal("Unexpected success")
    86  	}
    87  
    88  	err = tx.Commit(context.Background())
    89  	if err != pgx.ErrTxCommitRollback {
    90  		t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
    91  	}
    92  
    93  	var n int64
    94  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
    95  	if err != nil {
    96  		t.Fatalf("QueryRow Scan failed: %v", err)
    97  	}
    98  	if n != 0 {
    99  		t.Fatalf("Did not receive correct number of rows: %v", n)
   100  	}
   101  }
   102  
   103  func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) {
   104  	t.Parallel()
   105  
   106  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   107  	defer closeConn(t, conn)
   108  
   109  	skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   110  
   111  	createSql := `
   112      create temporary table foo(
   113        id integer,
   114        unique (id) initially deferred
   115      );
   116    `
   117  
   118  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   119  		t.Fatalf("Failed to create table: %v", err)
   120  	}
   121  
   122  	tx, err := conn.Begin(context.Background())
   123  	if err != nil {
   124  		t.Fatalf("conn.Begin failed: %v", err)
   125  	}
   126  
   127  	if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
   128  		t.Fatalf("tx.Exec failed: %v", err)
   129  	}
   130  
   131  	if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
   132  		t.Fatalf("tx.Exec failed: %v", err)
   133  	}
   134  
   135  	err = tx.Commit(context.Background())
   136  	if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" {
   137  		t.Fatalf("Expected unique constraint violation 23505, got %#v", err)
   138  	}
   139  
   140  	var n int64
   141  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   142  	if err != nil {
   143  		t.Fatalf("QueryRow Scan failed: %v", err)
   144  	}
   145  	if n != 0 {
   146  		t.Fatalf("Did not receive correct number of rows: %v", n)
   147  	}
   148  }
   149  
   150  func TestTxCommitSerializationFailure(t *testing.T) {
   151  	t.Parallel()
   152  
   153  	c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   154  	defer closeConn(t, c1)
   155  
   156  	if c1.PgConn().ParameterStatus("crdb_version") != "" {
   157  		t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)")
   158  	}
   159  
   160  	c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   161  	defer closeConn(t, c2)
   162  
   163  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   164  	defer cancel()
   165  
   166  	c1.Exec(ctx, `drop table if exists tx_serializable_sums`)
   167  	_, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`)
   168  	if err != nil {
   169  		t.Fatalf("Unable to create temporary table: %v", err)
   170  	}
   171  	defer c1.Exec(ctx, `drop table tx_serializable_sums`)
   172  
   173  	tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
   174  	if err != nil {
   175  		t.Fatalf("Begin failed: %v", err)
   176  	}
   177  	defer tx1.Rollback(ctx)
   178  
   179  	tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
   180  	if err != nil {
   181  		t.Fatalf("Begin failed: %v", err)
   182  	}
   183  	defer tx2.Rollback(ctx)
   184  
   185  	_, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
   186  	if err != nil {
   187  		t.Fatalf("Exec failed: %v", err)
   188  	}
   189  
   190  	_, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
   191  	if err != nil {
   192  		t.Fatalf("Exec failed: %v", err)
   193  	}
   194  
   195  	err = tx1.Commit(ctx)
   196  	if err != nil {
   197  		t.Fatalf("Commit failed: %v", err)
   198  	}
   199  
   200  	err = tx2.Commit(ctx)
   201  	if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
   202  		t.Fatalf("Expected serialization error 40001, got %#v", err)
   203  	}
   204  
   205  	ensureConnValid(t, c1)
   206  	ensureConnValid(t, c2)
   207  }
   208  
   209  func TestTransactionSuccessfulRollback(t *testing.T) {
   210  	t.Parallel()
   211  
   212  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   213  	defer closeConn(t, conn)
   214  
   215  	createSql := `
   216      create temporary table foo(
   217        id integer,
   218        unique (id)
   219      );
   220    `
   221  
   222  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   223  		t.Fatalf("Failed to create table: %v", err)
   224  	}
   225  
   226  	tx, err := conn.Begin(context.Background())
   227  	if err != nil {
   228  		t.Fatalf("conn.Begin failed: %v", err)
   229  	}
   230  
   231  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
   232  	if err != nil {
   233  		t.Fatalf("tx.Exec failed: %v", err)
   234  	}
   235  
   236  	err = tx.Rollback(context.Background())
   237  	if err != nil {
   238  		t.Fatalf("tx.Rollback failed: %v", err)
   239  	}
   240  
   241  	var n int64
   242  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   243  	if err != nil {
   244  		t.Fatalf("QueryRow Scan failed: %v", err)
   245  	}
   246  	if n != 0 {
   247  		t.Fatalf("Did not receive correct number of rows: %v", n)
   248  	}
   249  }
   250  
   251  func TestTransactionRollbackFailsClosesConnection(t *testing.T) {
   252  	t.Parallel()
   253  
   254  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   255  	defer closeConn(t, conn)
   256  
   257  	ctx, cancel := context.WithCancel(context.Background())
   258  
   259  	tx, err := conn.Begin(ctx)
   260  	require.NoError(t, err)
   261  
   262  	cancel()
   263  
   264  	err = tx.Rollback(ctx)
   265  	require.Error(t, err)
   266  
   267  	require.True(t, conn.IsClosed())
   268  }
   269  
   270  func TestBeginIsoLevels(t *testing.T) {
   271  	t.Parallel()
   272  
   273  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   274  	defer closeConn(t, conn)
   275  
   276  	skipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)")
   277  
   278  	isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
   279  	for _, iso := range isoLevels {
   280  		tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso})
   281  		if err != nil {
   282  			t.Fatalf("conn.Begin failed: %v", err)
   283  		}
   284  
   285  		var level pgx.TxIsoLevel
   286  		conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level)
   287  		if level != iso {
   288  			t.Errorf("Expected to be in isolation level %v but was %v", iso, level)
   289  		}
   290  
   291  		err = tx.Rollback(context.Background())
   292  		if err != nil {
   293  			t.Fatalf("tx.Rollback failed: %v", err)
   294  		}
   295  	}
   296  }
   297  
   298  func TestBeginFunc(t *testing.T) {
   299  	t.Parallel()
   300  
   301  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   302  	defer closeConn(t, conn)
   303  
   304  	createSql := `
   305      create temporary table foo(
   306        id integer,
   307        unique (id)
   308      );
   309    `
   310  
   311  	_, err := conn.Exec(context.Background(), createSql)
   312  	require.NoError(t, err)
   313  
   314  	err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
   315  		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
   316  		require.NoError(t, err)
   317  		return nil
   318  	})
   319  	require.NoError(t, err)
   320  
   321  	var n int64
   322  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   323  	require.NoError(t, err)
   324  	require.EqualValues(t, 1, n)
   325  }
   326  
   327  func TestBeginFuncRollbackOnError(t *testing.T) {
   328  	t.Parallel()
   329  
   330  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   331  	defer closeConn(t, conn)
   332  
   333  	createSql := `
   334      create temporary table foo(
   335        id integer,
   336        unique (id)
   337      );
   338    `
   339  
   340  	_, err := conn.Exec(context.Background(), createSql)
   341  	require.NoError(t, err)
   342  
   343  	err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
   344  		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
   345  		require.NoError(t, err)
   346  		return errors.New("some error")
   347  	})
   348  	require.EqualError(t, err, "some error")
   349  
   350  	var n int64
   351  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   352  	require.NoError(t, err)
   353  	require.EqualValues(t, 0, n)
   354  }
   355  
   356  func TestBeginReadOnly(t *testing.T) {
   357  	t.Parallel()
   358  
   359  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   360  	defer closeConn(t, conn)
   361  
   362  	tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly})
   363  	if err != nil {
   364  		t.Fatalf("conn.Begin failed: %v", err)
   365  	}
   366  	defer tx.Rollback(context.Background())
   367  
   368  	_, err = conn.Exec(context.Background(), "create table foo(id serial primary key)")
   369  	if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" {
   370  		t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
   371  	}
   372  }
   373  
   374  func TestTxNestedTransactionCommit(t *testing.T) {
   375  	t.Parallel()
   376  
   377  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   378  	defer closeConn(t, conn)
   379  
   380  	createSql := `
   381      create temporary table foo(
   382        id integer,
   383        unique (id)
   384      );
   385    `
   386  
   387  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   388  		t.Fatalf("Failed to create table: %v", err)
   389  	}
   390  
   391  	tx, err := conn.Begin(context.Background())
   392  	if err != nil {
   393  		t.Fatal(err)
   394  	}
   395  
   396  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
   397  	if err != nil {
   398  		t.Fatalf("tx.Exec failed: %v", err)
   399  	}
   400  
   401  	nestedTx, err := tx.Begin(context.Background())
   402  	if err != nil {
   403  		t.Fatal(err)
   404  	}
   405  
   406  	_, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
   407  	if err != nil {
   408  		t.Fatalf("nestedTx.Exec failed: %v", err)
   409  	}
   410  
   411  	doubleNestedTx, err := nestedTx.Begin(context.Background())
   412  	if err != nil {
   413  		t.Fatal(err)
   414  	}
   415  
   416  	_, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)")
   417  	if err != nil {
   418  		t.Fatalf("doubleNestedTx.Exec failed: %v", err)
   419  	}
   420  
   421  	err = doubleNestedTx.Commit(context.Background())
   422  	if err != nil {
   423  		t.Fatalf("doubleNestedTx.Commit failed: %v", err)
   424  	}
   425  
   426  	err = nestedTx.Commit(context.Background())
   427  	if err != nil {
   428  		t.Fatalf("nestedTx.Commit failed: %v", err)
   429  	}
   430  
   431  	err = tx.Commit(context.Background())
   432  	if err != nil {
   433  		t.Fatalf("tx.Commit failed: %v", err)
   434  	}
   435  
   436  	var n int64
   437  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   438  	if err != nil {
   439  		t.Fatalf("QueryRow Scan failed: %v", err)
   440  	}
   441  	if n != 3 {
   442  		t.Fatalf("Did not receive correct number of rows: %v", n)
   443  	}
   444  }
   445  
   446  func TestTxNestedTransactionRollback(t *testing.T) {
   447  	t.Parallel()
   448  
   449  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   450  	defer closeConn(t, conn)
   451  
   452  	createSql := `
   453      create temporary table foo(
   454        id integer,
   455        unique (id)
   456      );
   457    `
   458  
   459  	if _, err := conn.Exec(context.Background(), createSql); err != nil {
   460  		t.Fatalf("Failed to create table: %v", err)
   461  	}
   462  
   463  	tx, err := conn.Begin(context.Background())
   464  	if err != nil {
   465  		t.Fatal(err)
   466  	}
   467  
   468  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
   469  	if err != nil {
   470  		t.Fatalf("tx.Exec failed: %v", err)
   471  	}
   472  
   473  	nestedTx, err := tx.Begin(context.Background())
   474  	if err != nil {
   475  		t.Fatal(err)
   476  	}
   477  
   478  	_, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
   479  	if err != nil {
   480  		t.Fatalf("nestedTx.Exec failed: %v", err)
   481  	}
   482  
   483  	err = nestedTx.Rollback(context.Background())
   484  	if err != nil {
   485  		t.Fatalf("nestedTx.Rollback failed: %v", err)
   486  	}
   487  
   488  	_, err = tx.Exec(context.Background(), "insert into foo(id) values (3)")
   489  	if err != nil {
   490  		t.Fatalf("tx.Exec failed: %v", err)
   491  	}
   492  
   493  	err = tx.Commit(context.Background())
   494  	if err != nil {
   495  		t.Fatalf("tx.Commit failed: %v", err)
   496  	}
   497  
   498  	var n int64
   499  	err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   500  	if err != nil {
   501  		t.Fatalf("QueryRow Scan failed: %v", err)
   502  	}
   503  	if n != 2 {
   504  		t.Fatalf("Did not receive correct number of rows: %v", n)
   505  	}
   506  }
   507  
   508  func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
   509  	t.Parallel()
   510  
   511  	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   512  	defer closeConn(t, db)
   513  
   514  	createSql := `
   515      create temporary table foo(
   516        id integer,
   517        unique (id)
   518      );
   519    `
   520  
   521  	_, err := db.Exec(context.Background(), createSql)
   522  	require.NoError(t, err)
   523  
   524  	err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
   525  		_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
   526  		require.NoError(t, err)
   527  
   528  		err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
   529  			_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
   530  			require.NoError(t, err)
   531  
   532  			err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
   533  				_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
   534  				require.NoError(t, err)
   535  				return nil
   536  			})
   537  
   538  			return nil
   539  		})
   540  		require.NoError(t, err)
   541  		return nil
   542  	})
   543  	require.NoError(t, err)
   544  
   545  	var n int64
   546  	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   547  	require.NoError(t, err)
   548  	require.EqualValues(t, 3, n)
   549  }
   550  
   551  func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
   552  	t.Parallel()
   553  
   554  	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   555  	defer closeConn(t, db)
   556  
   557  	createSql := `
   558      create temporary table foo(
   559        id integer,
   560        unique (id)
   561      );
   562    `
   563  
   564  	_, err := db.Exec(context.Background(), createSql)
   565  	require.NoError(t, err)
   566  
   567  	err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
   568  		_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
   569  		require.NoError(t, err)
   570  
   571  		err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
   572  			_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
   573  			require.NoError(t, err)
   574  			return errors.New("do a rollback")
   575  		})
   576  		require.EqualError(t, err, "do a rollback")
   577  
   578  		_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
   579  		require.NoError(t, err)
   580  
   581  		return nil
   582  	})
   583  
   584  	var n int64
   585  	err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
   586  	require.NoError(t, err)
   587  	require.EqualValues(t, 2, n)
   588  }
   589  
   590  func TestTxSendBatchClosed(t *testing.T) {
   591  	t.Parallel()
   592  
   593  	db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   594  	defer closeConn(t, db)
   595  
   596  	tx, err := db.Begin(context.Background())
   597  	require.NoError(t, err)
   598  	defer tx.Rollback(context.Background())
   599  
   600  	err = tx.Commit(context.Background())
   601  	require.NoError(t, err)
   602  
   603  	batch := &pgx.Batch{}
   604  	batch.Queue("select 1")
   605  	batch.Queue("select 2")
   606  	batch.Queue("select 3")
   607  
   608  	br := tx.SendBatch(context.Background(), batch)
   609  	defer br.Close()
   610  
   611  	var n int
   612  
   613  	_, err = br.Exec()
   614  	require.Error(t, err)
   615  
   616  	err = br.QueryRow().Scan(&n)
   617  	require.Error(t, err)
   618  
   619  	_, err = br.Query()
   620  	require.Error(t, err)
   621  }
   622  

View as plain text