...

Source file src/github.com/jackc/pgx/v5/batch_test.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"os"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/jackc/pgx/v5"
    12  	"github.com/jackc/pgx/v5/pgconn"
    13  	"github.com/jackc/pgx/v5/pgxtest"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func TestConnSendBatch(t *testing.T) {
    19  	t.Parallel()
    20  
    21  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
    22  	defer cancel()
    23  
    24  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
    25  		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
    26  
    27  		sql := `create temporary table ledger(
    28  	  id serial primary key,
    29  	  description varchar not null,
    30  	  amount int not null
    31  	);`
    32  		mustExec(t, conn, sql)
    33  
    34  		batch := &pgx.Batch{}
    35  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
    36  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
    37  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
    38  		batch.Queue("select id, description, amount from ledger order by id")
    39  		batch.Queue("select id, description, amount from ledger order by id")
    40  		batch.Queue("select * from ledger where false")
    41  		batch.Queue("select sum(amount) from ledger")
    42  
    43  		br := conn.SendBatch(ctx, batch)
    44  
    45  		ct, err := br.Exec()
    46  		if err != nil {
    47  			t.Error(err)
    48  		}
    49  		if ct.RowsAffected() != 1 {
    50  			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
    51  		}
    52  
    53  		ct, err = br.Exec()
    54  		if err != nil {
    55  			t.Error(err)
    56  		}
    57  		if ct.RowsAffected() != 1 {
    58  			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
    59  		}
    60  
    61  		ct, err = br.Exec()
    62  		if err != nil {
    63  			t.Error(err)
    64  		}
    65  		if ct.RowsAffected() != 1 {
    66  			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
    67  		}
    68  
    69  		selectFromLedgerExpectedRows := []struct {
    70  			id          int32
    71  			description string
    72  			amount      int32
    73  		}{
    74  			{1, "q1", 1},
    75  			{2, "q2", 2},
    76  			{3, "q3", 3},
    77  		}
    78  
    79  		rows, err := br.Query()
    80  		if err != nil {
    81  			t.Error(err)
    82  		}
    83  
    84  		var id int32
    85  		var description string
    86  		var amount int32
    87  		rowCount := 0
    88  
    89  		for rows.Next() {
    90  			if rowCount >= len(selectFromLedgerExpectedRows) {
    91  				t.Fatalf("got too many rows: %d", rowCount)
    92  			}
    93  
    94  			if err := rows.Scan(&id, &description, &amount); err != nil {
    95  				t.Fatalf("row %d: %v", rowCount, err)
    96  			}
    97  
    98  			if id != selectFromLedgerExpectedRows[rowCount].id {
    99  				t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
   100  			}
   101  			if description != selectFromLedgerExpectedRows[rowCount].description {
   102  				t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
   103  			}
   104  			if amount != selectFromLedgerExpectedRows[rowCount].amount {
   105  				t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
   106  			}
   107  
   108  			rowCount++
   109  		}
   110  
   111  		if rows.Err() != nil {
   112  			t.Fatal(rows.Err())
   113  		}
   114  
   115  		rowCount = 0
   116  		rows, _ = br.Query()
   117  		_, err = pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
   118  			if id != selectFromLedgerExpectedRows[rowCount].id {
   119  				t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
   120  			}
   121  			if description != selectFromLedgerExpectedRows[rowCount].description {
   122  				t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
   123  			}
   124  			if amount != selectFromLedgerExpectedRows[rowCount].amount {
   125  				t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
   126  			}
   127  
   128  			rowCount++
   129  
   130  			return nil
   131  		})
   132  		if err != nil {
   133  			t.Error(err)
   134  		}
   135  
   136  		err = br.QueryRow().Scan(&id, &description, &amount)
   137  		if !errors.Is(err, pgx.ErrNoRows) {
   138  			t.Errorf("expected pgx.ErrNoRows but got: %v", err)
   139  		}
   140  
   141  		err = br.QueryRow().Scan(&amount)
   142  		if err != nil {
   143  			t.Error(err)
   144  		}
   145  		if amount != 6 {
   146  			t.Errorf("amount => %v, want %v", amount, 6)
   147  		}
   148  
   149  		err = br.Close()
   150  		if err != nil {
   151  			t.Fatal(err)
   152  		}
   153  	})
   154  }
   155  
   156  func TestConnSendBatchQueuedQuery(t *testing.T) {
   157  	t.Parallel()
   158  
   159  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   160  	defer cancel()
   161  
   162  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   163  		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
   164  
   165  		sql := `create temporary table ledger(
   166  	  id serial primary key,
   167  	  description varchar not null,
   168  	  amount int not null
   169  	);`
   170  		mustExec(t, conn, sql)
   171  
   172  		batch := &pgx.Batch{}
   173  
   174  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error {
   175  			assert.EqualValues(t, 1, ct.RowsAffected())
   176  			return nil
   177  		})
   178  
   179  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error {
   180  			assert.EqualValues(t, 1, ct.RowsAffected())
   181  			return nil
   182  		})
   183  
   184  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error {
   185  			assert.EqualValues(t, 1, ct.RowsAffected())
   186  			return nil
   187  		})
   188  
   189  		selectFromLedgerExpectedRows := []struct {
   190  			id          int32
   191  			description string
   192  			amount      int32
   193  		}{
   194  			{1, "q1", 1},
   195  			{2, "q2", 2},
   196  			{3, "q3", 3},
   197  		}
   198  
   199  		batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
   200  			rowCount := 0
   201  			var id int32
   202  			var description string
   203  			var amount int32
   204  			_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
   205  				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
   206  				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
   207  				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
   208  				rowCount++
   209  
   210  				return nil
   211  			})
   212  			assert.NoError(t, err)
   213  			return nil
   214  		})
   215  
   216  		batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
   217  			rowCount := 0
   218  			var id int32
   219  			var description string
   220  			var amount int32
   221  			_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
   222  				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
   223  				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
   224  				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
   225  				rowCount++
   226  
   227  				return nil
   228  			})
   229  			assert.NoError(t, err)
   230  			return nil
   231  		})
   232  
   233  		batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
   234  			err := row.Scan(nil, nil, nil)
   235  			assert.ErrorIs(t, err, pgx.ErrNoRows)
   236  			return nil
   237  		})
   238  
   239  		batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
   240  			var sumAmount int32
   241  			err := row.Scan(&sumAmount)
   242  			assert.NoError(t, err)
   243  			assert.EqualValues(t, 6, sumAmount)
   244  			return nil
   245  		})
   246  
   247  		err := conn.SendBatch(ctx, batch).Close()
   248  		assert.NoError(t, err)
   249  	})
   250  }
   251  
   252  func TestConnSendBatchMany(t *testing.T) {
   253  	t.Parallel()
   254  
   255  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   256  	defer cancel()
   257  
   258  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   259  		sql := `create temporary table ledger(
   260  	  id serial primary key,
   261  	  description varchar not null,
   262  	  amount int not null
   263  	);`
   264  		mustExec(t, conn, sql)
   265  
   266  		batch := &pgx.Batch{}
   267  
   268  		numInserts := 1000
   269  
   270  		for i := 0; i < numInserts; i++ {
   271  			batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
   272  		}
   273  		batch.Queue("select count(*) from ledger")
   274  
   275  		br := conn.SendBatch(ctx, batch)
   276  
   277  		for i := 0; i < numInserts; i++ {
   278  			ct, err := br.Exec()
   279  			assert.NoError(t, err)
   280  			assert.EqualValues(t, 1, ct.RowsAffected())
   281  		}
   282  
   283  		var actualInserts int
   284  		err := br.QueryRow().Scan(&actualInserts)
   285  		assert.NoError(t, err)
   286  		assert.EqualValues(t, numInserts, actualInserts)
   287  
   288  		err = br.Close()
   289  		require.NoError(t, err)
   290  	})
   291  }
   292  
   293  func TestConnSendBatchWithPreparedStatement(t *testing.T) {
   294  	t.Parallel()
   295  
   296  	modes := []pgx.QueryExecMode{
   297  		pgx.QueryExecModeCacheStatement,
   298  		pgx.QueryExecModeCacheDescribe,
   299  		pgx.QueryExecModeDescribeExec,
   300  		pgx.QueryExecModeExec,
   301  		// Don't test simple mode with prepared statements.
   302  	}
   303  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   304  	defer cancel()
   305  
   306  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   307  		pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   308  		_, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
   309  		if err != nil {
   310  			t.Fatal(err)
   311  		}
   312  
   313  		batch := &pgx.Batch{}
   314  
   315  		queryCount := 3
   316  		for i := 0; i < queryCount; i++ {
   317  			batch.Queue("ps1", 5)
   318  		}
   319  
   320  		br := conn.SendBatch(ctx, batch)
   321  
   322  		for i := 0; i < queryCount; i++ {
   323  			rows, err := br.Query()
   324  			if err != nil {
   325  				t.Fatal(err)
   326  			}
   327  
   328  			for k := 0; rows.Next(); k++ {
   329  				var n int
   330  				if err := rows.Scan(&n); err != nil {
   331  					t.Fatal(err)
   332  				}
   333  				if n != k {
   334  					t.Fatalf("n => %v, want %v", n, k)
   335  				}
   336  			}
   337  
   338  			if rows.Err() != nil {
   339  				t.Fatal(rows.Err())
   340  			}
   341  		}
   342  
   343  		err = br.Close()
   344  		if err != nil {
   345  			t.Fatal(err)
   346  		}
   347  	})
   348  }
   349  
   350  func TestConnSendBatchWithQueryRewriter(t *testing.T) {
   351  	t.Parallel()
   352  
   353  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   354  	defer cancel()
   355  
   356  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   357  		batch := &pgx.Batch{}
   358  		batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
   359  		batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
   360  		batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
   361  
   362  		br := conn.SendBatch(ctx, batch)
   363  
   364  		var n int32
   365  		err := br.QueryRow().Scan(&n)
   366  		require.NoError(t, err)
   367  		require.EqualValues(t, 1, n)
   368  
   369  		var s string
   370  		err = br.QueryRow().Scan(&s)
   371  		require.NoError(t, err)
   372  		require.Equal(t, "hello", s)
   373  
   374  		err = br.QueryRow().Scan(&n)
   375  		require.NoError(t, err)
   376  		require.EqualValues(t, 3, n)
   377  
   378  		err = br.Close()
   379  		require.NoError(t, err)
   380  	})
   381  }
   382  
   383  // https://github.com/jackc/pgx/issues/856
   384  func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
   385  	t.Parallel()
   386  
   387  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   388  	defer cancel()
   389  
   390  	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
   391  	require.NoError(t, err)
   392  
   393  	config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
   394  	config.StatementCacheCapacity = 0
   395  	config.DescriptionCacheCapacity = 0
   396  
   397  	conn := mustConnect(t, config)
   398  	defer closeConn(t, conn)
   399  
   400  	pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   401  
   402  	_, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
   403  	if err != nil {
   404  		t.Fatal(err)
   405  	}
   406  
   407  	batch := &pgx.Batch{}
   408  
   409  	queryCount := 3
   410  	for i := 0; i < queryCount; i++ {
   411  		batch.Queue("ps1", 5)
   412  	}
   413  
   414  	br := conn.SendBatch(ctx, batch)
   415  
   416  	for i := 0; i < queryCount; i++ {
   417  		rows, err := br.Query()
   418  		if err != nil {
   419  			t.Fatal(err)
   420  		}
   421  
   422  		for k := 0; rows.Next(); k++ {
   423  			var n int
   424  			if err := rows.Scan(&n); err != nil {
   425  				t.Fatal(err)
   426  			}
   427  			if n != k {
   428  				t.Fatalf("n => %v, want %v", n, k)
   429  			}
   430  		}
   431  
   432  		if rows.Err() != nil {
   433  			t.Fatal(rows.Err())
   434  		}
   435  	}
   436  
   437  	err = br.Close()
   438  	if err != nil {
   439  		t.Fatal(err)
   440  	}
   441  
   442  	ensureConnValid(t, conn)
   443  }
   444  
   445  func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
   446  	t.Parallel()
   447  
   448  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   449  	defer cancel()
   450  
   451  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   452  
   453  		batch := &pgx.Batch{}
   454  		batch.Queue("select n from generate_series(0,5) n")
   455  		batch.Queue("select n from generate_series(0,5) n")
   456  
   457  		br := conn.SendBatch(ctx, batch)
   458  
   459  		rows, err := br.Query()
   460  		if err != nil {
   461  			t.Error(err)
   462  		}
   463  
   464  		for i := 0; i < 3; i++ {
   465  			if !rows.Next() {
   466  				t.Error("expected a row to be available")
   467  			}
   468  
   469  			var n int
   470  			if err := rows.Scan(&n); err != nil {
   471  				t.Error(err)
   472  			}
   473  			if n != i {
   474  				t.Errorf("n => %v, want %v", n, i)
   475  			}
   476  		}
   477  
   478  		rows.Close()
   479  
   480  		rows, err = br.Query()
   481  		if err != nil {
   482  			t.Error(err)
   483  		}
   484  
   485  		for i := 0; rows.Next(); i++ {
   486  			var n int
   487  			if err := rows.Scan(&n); err != nil {
   488  				t.Error(err)
   489  			}
   490  			if n != i {
   491  				t.Errorf("n => %v, want %v", n, i)
   492  			}
   493  		}
   494  
   495  		if rows.Err() != nil {
   496  			t.Error(rows.Err())
   497  		}
   498  
   499  		err = br.Close()
   500  		if err != nil {
   501  			t.Fatal(err)
   502  		}
   503  
   504  	})
   505  }
   506  
   507  func TestConnSendBatchQueryError(t *testing.T) {
   508  	t.Parallel()
   509  
   510  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   511  	defer cancel()
   512  
   513  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   514  
   515  		batch := &pgx.Batch{}
   516  		batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
   517  		batch.Queue("select n from generate_series(0,5) n")
   518  
   519  		br := conn.SendBatch(ctx, batch)
   520  
   521  		rows, err := br.Query()
   522  		if err != nil {
   523  			t.Error(err)
   524  		}
   525  
   526  		for i := 0; rows.Next(); i++ {
   527  			var n int
   528  			if err := rows.Scan(&n); err != nil {
   529  				t.Error(err)
   530  			}
   531  			if n != i {
   532  				t.Errorf("n => %v, want %v", n, i)
   533  			}
   534  		}
   535  
   536  		if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
   537  			t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
   538  		}
   539  
   540  		err = br.Close()
   541  		if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
   542  			t.Errorf("br.Close() => %v, want error code %v", err, 22012)
   543  		}
   544  
   545  	})
   546  }
   547  
   548  func TestConnSendBatchQuerySyntaxError(t *testing.T) {
   549  	t.Parallel()
   550  
   551  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   552  	defer cancel()
   553  
   554  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   555  
   556  		batch := &pgx.Batch{}
   557  		batch.Queue("select 1 1")
   558  
   559  		br := conn.SendBatch(ctx, batch)
   560  
   561  		var n int32
   562  		err := br.QueryRow().Scan(&n)
   563  		if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
   564  			t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
   565  		}
   566  
   567  		err = br.Close()
   568  		if err == nil {
   569  			t.Error("Expected error")
   570  		}
   571  
   572  	})
   573  }
   574  
   575  func TestConnSendBatchQueryRowInsert(t *testing.T) {
   576  	t.Parallel()
   577  
   578  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   579  	defer cancel()
   580  
   581  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   582  
   583  		sql := `create temporary table ledger(
   584  	  id serial primary key,
   585  	  description varchar not null,
   586  	  amount int not null
   587  	);`
   588  		mustExec(t, conn, sql)
   589  
   590  		batch := &pgx.Batch{}
   591  		batch.Queue("select 1")
   592  		batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
   593  
   594  		br := conn.SendBatch(ctx, batch)
   595  
   596  		var value int
   597  		err := br.QueryRow().Scan(&value)
   598  		if err != nil {
   599  			t.Error(err)
   600  		}
   601  
   602  		ct, err := br.Exec()
   603  		if err != nil {
   604  			t.Error(err)
   605  		}
   606  		if ct.RowsAffected() != 2 {
   607  			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
   608  		}
   609  
   610  		br.Close()
   611  
   612  	})
   613  }
   614  
   615  func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
   616  	t.Parallel()
   617  
   618  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   619  	defer cancel()
   620  
   621  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   622  
   623  		sql := `create temporary table ledger(
   624  	  id serial primary key,
   625  	  description varchar not null,
   626  	  amount int not null
   627  	);`
   628  		mustExec(t, conn, sql)
   629  
   630  		batch := &pgx.Batch{}
   631  		batch.Queue("select 1 union all select 2 union all select 3")
   632  		batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
   633  
   634  		br := conn.SendBatch(ctx, batch)
   635  
   636  		rows, err := br.Query()
   637  		if err != nil {
   638  			t.Error(err)
   639  		}
   640  		rows.Close()
   641  
   642  		ct, err := br.Exec()
   643  		if err != nil {
   644  			t.Error(err)
   645  		}
   646  		if ct.RowsAffected() != 2 {
   647  			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
   648  		}
   649  
   650  		br.Close()
   651  
   652  	})
   653  }
   654  
   655  func TestTxSendBatch(t *testing.T) {
   656  	t.Parallel()
   657  
   658  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   659  	defer cancel()
   660  
   661  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   662  
   663  		sql := `create temporary table ledger1(
   664  	  id serial primary key,
   665  	  description varchar not null
   666  	);`
   667  		mustExec(t, conn, sql)
   668  
   669  		sql = `create temporary table ledger2(
   670  	  id int primary key,
   671  	  amount int not null
   672  	);`
   673  		mustExec(t, conn, sql)
   674  
   675  		tx, _ := conn.Begin(ctx)
   676  		batch := &pgx.Batch{}
   677  		batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
   678  
   679  		br := tx.SendBatch(context.Background(), batch)
   680  
   681  		var id int
   682  		err := br.QueryRow().Scan(&id)
   683  		if err != nil {
   684  			t.Error(err)
   685  		}
   686  		br.Close()
   687  
   688  		batch = &pgx.Batch{}
   689  		batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
   690  		batch.Queue("select amount from ledger2 where id = $1", id)
   691  
   692  		br = tx.SendBatch(ctx, batch)
   693  
   694  		ct, err := br.Exec()
   695  		if err != nil {
   696  			t.Error(err)
   697  		}
   698  		if ct.RowsAffected() != 1 {
   699  			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
   700  		}
   701  
   702  		var amount int
   703  		err = br.QueryRow().Scan(&amount)
   704  		if err != nil {
   705  			t.Error(err)
   706  		}
   707  
   708  		br.Close()
   709  		tx.Commit(ctx)
   710  
   711  		var count int
   712  		conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
   713  		if count != 1 {
   714  			t.Errorf("count => %v, want %v", count, 1)
   715  		}
   716  
   717  		err = br.Close()
   718  		if err != nil {
   719  			t.Fatal(err)
   720  		}
   721  
   722  	})
   723  }
   724  
   725  func TestTxSendBatchRollback(t *testing.T) {
   726  	t.Parallel()
   727  
   728  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   729  	defer cancel()
   730  
   731  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   732  
   733  		sql := `create temporary table ledger1(
   734  	  id serial primary key,
   735  	  description varchar not null
   736  	);`
   737  		mustExec(t, conn, sql)
   738  
   739  		tx, _ := conn.Begin(ctx)
   740  		batch := &pgx.Batch{}
   741  		batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
   742  
   743  		br := tx.SendBatch(ctx, batch)
   744  
   745  		var id int
   746  		err := br.QueryRow().Scan(&id)
   747  		if err != nil {
   748  			t.Error(err)
   749  		}
   750  		br.Close()
   751  		tx.Rollback(ctx)
   752  
   753  		row := conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id)
   754  		var count int
   755  		row.Scan(&count)
   756  		if count != 0 {
   757  			t.Errorf("count => %v, want %v", count, 0)
   758  		}
   759  
   760  	})
   761  }
   762  
   763  // https://github.com/jackc/pgx/issues/1578
   764  func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
   765  	t.Parallel()
   766  
   767  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   768  	defer cancel()
   769  
   770  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   771  		batch := &pgx.Batch{}
   772  		batch.Queue("select 4 / $1::int", 0)
   773  
   774  		batchResult := conn.SendBatch(ctx, batch)
   775  
   776  		_, execErr := batchResult.Exec()
   777  		require.Error(t, execErr)
   778  
   779  		closeErr := batchResult.Close()
   780  		require.Equal(t, execErr, closeErr)
   781  
   782  		// Try to use the connection.
   783  		_, err := conn.Exec(ctx, "select 1")
   784  		require.NoError(t, err)
   785  	})
   786  }
   787  
   788  func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) {
   789  	t.Parallel()
   790  
   791  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   792  	defer cancel()
   793  
   794  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   795  		batch := &pgx.Batch{}
   796  		batch.Queue("select 4 / n from generate_series(-2, 2) n")
   797  
   798  		batchResult := conn.SendBatch(ctx, batch)
   799  
   800  		_, execErr := batchResult.Exec()
   801  		require.Error(t, execErr)
   802  
   803  		closeErr := batchResult.Close()
   804  		require.Equal(t, execErr, closeErr)
   805  
   806  		// Try to use the connection.
   807  		_, err := conn.Exec(ctx, "select 1")
   808  		require.NoError(t, err)
   809  	})
   810  }
   811  
   812  func TestConnBeginBatchDeferredError(t *testing.T) {
   813  	t.Parallel()
   814  
   815  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   816  	defer cancel()
   817  
   818  	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   819  
   820  		pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   821  
   822  		mustExec(t, conn, `create temporary table t (
   823  		id text primary key,
   824  		n int not null,
   825  		unique (n) deferrable initially deferred
   826  	);
   827  
   828  	insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
   829  
   830  		batch := &pgx.Batch{}
   831  
   832  		batch.Queue(`update t set n=n+1 where id='b' returning *`)
   833  
   834  		br := conn.SendBatch(ctx, batch)
   835  
   836  		rows, err := br.Query()
   837  		if err != nil {
   838  			t.Error(err)
   839  		}
   840  
   841  		for rows.Next() {
   842  			var id string
   843  			var n int32
   844  			err = rows.Scan(&id, &n)
   845  			if err != nil {
   846  				t.Fatal(err)
   847  			}
   848  		}
   849  
   850  		err = br.Close()
   851  		if err == nil {
   852  			t.Fatal("expected error 23505 but got none")
   853  		}
   854  
   855  		if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
   856  			t.Fatalf("expected error 23505, got %v", err)
   857  		}
   858  
   859  	})
   860  }
   861  
   862  func TestConnSendBatchNoStatementCache(t *testing.T) {
   863  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   864  	defer cancel()
   865  
   866  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   867  	config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
   868  	config.StatementCacheCapacity = 0
   869  	config.DescriptionCacheCapacity = 0
   870  
   871  	conn := mustConnect(t, config)
   872  	defer closeConn(t, conn)
   873  
   874  	testConnSendBatch(t, ctx, conn, 3)
   875  }
   876  
   877  func TestConnSendBatchPrepareStatementCache(t *testing.T) {
   878  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   879  	defer cancel()
   880  
   881  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   882  	config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
   883  	config.StatementCacheCapacity = 32
   884  
   885  	conn := mustConnect(t, config)
   886  	defer closeConn(t, conn)
   887  
   888  	testConnSendBatch(t, ctx, conn, 3)
   889  }
   890  
   891  func TestConnSendBatchDescribeStatementCache(t *testing.T) {
   892  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   893  	defer cancel()
   894  
   895  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   896  	config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
   897  	config.DescriptionCacheCapacity = 32
   898  
   899  	conn := mustConnect(t, config)
   900  	defer closeConn(t, conn)
   901  
   902  	testConnSendBatch(t, ctx, conn, 3)
   903  }
   904  
   905  func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) {
   906  	batch := &pgx.Batch{}
   907  	for j := 0; j < queryCount; j++ {
   908  		batch.Queue("select n from generate_series(0,5) n")
   909  	}
   910  
   911  	br := conn.SendBatch(ctx, batch)
   912  
   913  	for j := 0; j < queryCount; j++ {
   914  		rows, err := br.Query()
   915  		require.NoError(t, err)
   916  
   917  		for k := 0; rows.Next(); k++ {
   918  			var n int
   919  			err := rows.Scan(&n)
   920  			require.NoError(t, err)
   921  			require.Equal(t, k, n)
   922  		}
   923  
   924  		require.NoError(t, rows.Err())
   925  	}
   926  
   927  	err := br.Close()
   928  	require.NoError(t, err)
   929  }
   930  
   931  func TestSendBatchSimpleProtocol(t *testing.T) {
   932  	t.Parallel()
   933  
   934  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   935  	defer cancel()
   936  
   937  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   938  	config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
   939  
   940  	conn := mustConnect(t, config)
   941  	defer closeConn(t, conn)
   942  
   943  	var batch pgx.Batch
   944  	batch.Queue("SELECT 1::int")
   945  	batch.Queue("SELECT 2::int; SELECT $1::int", 3)
   946  	results := conn.SendBatch(ctx, &batch)
   947  	rows, err := results.Query()
   948  	assert.NoError(t, err)
   949  	assert.True(t, rows.Next())
   950  	values, err := rows.Values()
   951  	assert.NoError(t, err)
   952  	assert.EqualValues(t, 1, values[0])
   953  	assert.False(t, rows.Next())
   954  
   955  	rows, err = results.Query()
   956  	assert.NoError(t, err)
   957  	assert.True(t, rows.Next())
   958  	values, err = rows.Values()
   959  	assert.NoError(t, err)
   960  	assert.EqualValues(t, 2, values[0])
   961  	assert.False(t, rows.Next())
   962  
   963  	rows, err = results.Query()
   964  	assert.NoError(t, err)
   965  	assert.True(t, rows.Next())
   966  	values, err = rows.Values()
   967  	assert.NoError(t, err)
   968  	assert.EqualValues(t, 3, values[0])
   969  	assert.False(t, rows.Next())
   970  }
   971  
   972  func ExampleConn_SendBatch() {
   973  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   974  	defer cancel()
   975  
   976  	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
   977  	if err != nil {
   978  		fmt.Printf("Unable to establish connection: %v", err)
   979  		return
   980  	}
   981  
   982  	batch := &pgx.Batch{}
   983  	batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
   984  		var n int32
   985  		err := row.Scan(&n)
   986  		if err != nil {
   987  			return err
   988  		}
   989  
   990  		fmt.Println(n)
   991  
   992  		return err
   993  	})
   994  
   995  	batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error {
   996  		var n int32
   997  		err := row.Scan(&n)
   998  		if err != nil {
   999  			return err
  1000  		}
  1001  
  1002  		fmt.Println(n)
  1003  
  1004  		return err
  1005  	})
  1006  
  1007  	batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error {
  1008  		var n int32
  1009  		err := row.Scan(&n)
  1010  		if err != nil {
  1011  			return err
  1012  		}
  1013  
  1014  		fmt.Println(n)
  1015  
  1016  		return err
  1017  	})
  1018  
  1019  	err = conn.SendBatch(ctx, batch).Close()
  1020  	if err != nil {
  1021  		fmt.Printf("SendBatch error: %v", err)
  1022  		return
  1023  	}
  1024  
  1025  	// Output:
  1026  	// 2
  1027  	// 3
  1028  	// 5
  1029  }
  1030  

View as plain text