...

Source file src/github.com/jackc/pgx/v4/batch_test.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"os"
     7  	"testing"
     8  
     9  	"github.com/jackc/pgconn"
    10  	"github.com/jackc/pgconn/stmtcache"
    11  	"github.com/jackc/pgx/v4"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func TestConnSendBatch(t *testing.T) {
    17  	t.Parallel()
    18  
    19  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    20  	defer closeConn(t, conn)
    21  
    22  	skipCockroachDB(t, conn, "Server serial type is incompatible with test")
    23  
    24  	sql := `create temporary table ledger(
    25  	  id serial primary key,
    26  	  description varchar not null,
    27  	  amount int not null
    28  	);`
    29  	mustExec(t, conn, sql)
    30  
    31  	batch := &pgx.Batch{}
    32  	batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
    33  	batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
    34  	batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
    35  	batch.Queue("select id, description, amount from ledger order by id")
    36  	batch.Queue("select id, description, amount from ledger order by id")
    37  	batch.Queue("select * from ledger where false")
    38  	batch.Queue("select sum(amount) from ledger")
    39  
    40  	br := conn.SendBatch(context.Background(), batch)
    41  
    42  	ct, err := br.Exec()
    43  	if err != nil {
    44  		t.Error(err)
    45  	}
    46  	if ct.RowsAffected() != 1 {
    47  		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
    48  	}
    49  
    50  	ct, err = br.Exec()
    51  	if err != nil {
    52  		t.Error(err)
    53  	}
    54  	if ct.RowsAffected() != 1 {
    55  		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
    56  	}
    57  
    58  	ct, err = br.Exec()
    59  	if err != nil {
    60  		t.Error(err)
    61  	}
    62  	if ct.RowsAffected() != 1 {
    63  		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
    64  	}
    65  
    66  	selectFromLedgerExpectedRows := []struct {
    67  		id          int32
    68  		description string
    69  		amount      int32
    70  	}{
    71  		{1, "q1", 1},
    72  		{2, "q2", 2},
    73  		{3, "q3", 3},
    74  	}
    75  
    76  	rows, err := br.Query()
    77  	if err != nil {
    78  		t.Error(err)
    79  	}
    80  
    81  	var id int32
    82  	var description string
    83  	var amount int32
    84  	rowCount := 0
    85  
    86  	for rows.Next() {
    87  		if rowCount >= len(selectFromLedgerExpectedRows) {
    88  			t.Fatalf("got too many rows: %d", rowCount)
    89  		}
    90  
    91  		if err := rows.Scan(&id, &description, &amount); err != nil {
    92  			t.Fatalf("row %d: %v", rowCount, err)
    93  		}
    94  
    95  		if id != selectFromLedgerExpectedRows[rowCount].id {
    96  			t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
    97  		}
    98  		if description != selectFromLedgerExpectedRows[rowCount].description {
    99  			t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
   100  		}
   101  		if amount != selectFromLedgerExpectedRows[rowCount].amount {
   102  			t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
   103  		}
   104  
   105  		rowCount++
   106  	}
   107  
   108  	if rows.Err() != nil {
   109  		t.Fatal(rows.Err())
   110  	}
   111  
   112  	rowCount = 0
   113  	_, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error {
   114  		if id != selectFromLedgerExpectedRows[rowCount].id {
   115  			t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
   116  		}
   117  		if description != selectFromLedgerExpectedRows[rowCount].description {
   118  			t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
   119  		}
   120  		if amount != selectFromLedgerExpectedRows[rowCount].amount {
   121  			t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
   122  		}
   123  
   124  		rowCount++
   125  
   126  		return nil
   127  	})
   128  	if err != nil {
   129  		t.Error(err)
   130  	}
   131  
   132  	err = br.QueryRow().Scan(&id, &description, &amount)
   133  	if !errors.Is(err, pgx.ErrNoRows) {
   134  		t.Errorf("expected pgx.ErrNoRows but got: %v", err)
   135  	}
   136  
   137  	err = br.QueryRow().Scan(&amount)
   138  	if err != nil {
   139  		t.Error(err)
   140  	}
   141  	if amount != 6 {
   142  		t.Errorf("amount => %v, want %v", amount, 6)
   143  	}
   144  
   145  	err = br.Close()
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  
   150  	ensureConnValid(t, conn)
   151  }
   152  
   153  func TestConnSendBatchMany(t *testing.T) {
   154  	t.Parallel()
   155  
   156  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   157  	defer closeConn(t, conn)
   158  
   159  	sql := `create temporary table ledger(
   160  	  id serial primary key,
   161  	  description varchar not null,
   162  	  amount int not null
   163  	);`
   164  	mustExec(t, conn, sql)
   165  
   166  	batch := &pgx.Batch{}
   167  
   168  	numInserts := 1000
   169  
   170  	for i := 0; i < numInserts; i++ {
   171  		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
   172  	}
   173  	batch.Queue("select count(*) from ledger")
   174  
   175  	br := conn.SendBatch(context.Background(), batch)
   176  
   177  	for i := 0; i < numInserts; i++ {
   178  		ct, err := br.Exec()
   179  		assert.NoError(t, err)
   180  		assert.EqualValues(t, 1, ct.RowsAffected())
   181  	}
   182  
   183  	var actualInserts int
   184  	err := br.QueryRow().Scan(&actualInserts)
   185  	assert.NoError(t, err)
   186  	assert.EqualValues(t, numInserts, actualInserts)
   187  
   188  	err = br.Close()
   189  	require.NoError(t, err)
   190  
   191  	ensureConnValid(t, conn)
   192  }
   193  
   194  func TestConnSendBatchWithPreparedStatement(t *testing.T) {
   195  	t.Parallel()
   196  
   197  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   198  	defer closeConn(t, conn)
   199  
   200  	skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   201  
   202  	_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  
   207  	batch := &pgx.Batch{}
   208  
   209  	queryCount := 3
   210  	for i := 0; i < queryCount; i++ {
   211  		batch.Queue("ps1", 5)
   212  	}
   213  
   214  	br := conn.SendBatch(context.Background(), batch)
   215  
   216  	for i := 0; i < queryCount; i++ {
   217  		rows, err := br.Query()
   218  		if err != nil {
   219  			t.Fatal(err)
   220  		}
   221  
   222  		for k := 0; rows.Next(); k++ {
   223  			var n int
   224  			if err := rows.Scan(&n); err != nil {
   225  				t.Fatal(err)
   226  			}
   227  			if n != k {
   228  				t.Fatalf("n => %v, want %v", n, k)
   229  			}
   230  		}
   231  
   232  		if rows.Err() != nil {
   233  			t.Fatal(rows.Err())
   234  		}
   235  	}
   236  
   237  	err = br.Close()
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  
   242  	ensureConnValid(t, conn)
   243  }
   244  
   245  // https://github.com/jackc/pgx/issues/856
   246  func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
   247  	t.Parallel()
   248  
   249  	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
   250  	require.NoError(t, err)
   251  
   252  	config.BuildStatementCache = nil
   253  
   254  	conn := mustConnect(t, config)
   255  	defer closeConn(t, conn)
   256  
   257  	skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
   258  
   259  	_, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
   260  	if err != nil {
   261  		t.Fatal(err)
   262  	}
   263  
   264  	batch := &pgx.Batch{}
   265  
   266  	queryCount := 3
   267  	for i := 0; i < queryCount; i++ {
   268  		batch.Queue("ps1", 5)
   269  	}
   270  
   271  	br := conn.SendBatch(context.Background(), batch)
   272  
   273  	for i := 0; i < queryCount; i++ {
   274  		rows, err := br.Query()
   275  		if err != nil {
   276  			t.Fatal(err)
   277  		}
   278  
   279  		for k := 0; rows.Next(); k++ {
   280  			var n int
   281  			if err := rows.Scan(&n); err != nil {
   282  				t.Fatal(err)
   283  			}
   284  			if n != k {
   285  				t.Fatalf("n => %v, want %v", n, k)
   286  			}
   287  		}
   288  
   289  		if rows.Err() != nil {
   290  			t.Fatal(rows.Err())
   291  		}
   292  	}
   293  
   294  	err = br.Close()
   295  	if err != nil {
   296  		t.Fatal(err)
   297  	}
   298  
   299  	ensureConnValid(t, conn)
   300  }
   301  
   302  func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
   303  	t.Parallel()
   304  
   305  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   306  	defer closeConn(t, conn)
   307  
   308  	batch := &pgx.Batch{}
   309  	batch.Queue("select n from generate_series(0,5) n")
   310  	batch.Queue("select n from generate_series(0,5) n")
   311  
   312  	br := conn.SendBatch(context.Background(), batch)
   313  
   314  	rows, err := br.Query()
   315  	if err != nil {
   316  		t.Error(err)
   317  	}
   318  
   319  	for i := 0; i < 3; i++ {
   320  		if !rows.Next() {
   321  			t.Error("expected a row to be available")
   322  		}
   323  
   324  		var n int
   325  		if err := rows.Scan(&n); err != nil {
   326  			t.Error(err)
   327  		}
   328  		if n != i {
   329  			t.Errorf("n => %v, want %v", n, i)
   330  		}
   331  	}
   332  
   333  	rows.Close()
   334  
   335  	rows, err = br.Query()
   336  	if err != nil {
   337  		t.Error(err)
   338  	}
   339  
   340  	for i := 0; rows.Next(); i++ {
   341  		var n int
   342  		if err := rows.Scan(&n); err != nil {
   343  			t.Error(err)
   344  		}
   345  		if n != i {
   346  			t.Errorf("n => %v, want %v", n, i)
   347  		}
   348  	}
   349  
   350  	if rows.Err() != nil {
   351  		t.Error(rows.Err())
   352  	}
   353  
   354  	err = br.Close()
   355  	if err != nil {
   356  		t.Fatal(err)
   357  	}
   358  
   359  	ensureConnValid(t, conn)
   360  }
   361  
   362  func TestConnSendBatchQueryError(t *testing.T) {
   363  	t.Parallel()
   364  
   365  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   366  	defer closeConn(t, conn)
   367  
   368  	batch := &pgx.Batch{}
   369  	batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
   370  	batch.Queue("select n from generate_series(0,5) n")
   371  
   372  	br := conn.SendBatch(context.Background(), batch)
   373  
   374  	rows, err := br.Query()
   375  	if err != nil {
   376  		t.Error(err)
   377  	}
   378  
   379  	for i := 0; rows.Next(); i++ {
   380  		var n int
   381  		if err := rows.Scan(&n); err != nil {
   382  			t.Error(err)
   383  		}
   384  		if n != i {
   385  			t.Errorf("n => %v, want %v", n, i)
   386  		}
   387  	}
   388  
   389  	if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
   390  		t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
   391  	}
   392  
   393  	err = br.Close()
   394  	if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
   395  		t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
   396  	}
   397  
   398  	ensureConnValid(t, conn)
   399  }
   400  
   401  func TestConnSendBatchQuerySyntaxError(t *testing.T) {
   402  	t.Parallel()
   403  
   404  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   405  	defer closeConn(t, conn)
   406  
   407  	batch := &pgx.Batch{}
   408  	batch.Queue("select 1 1")
   409  
   410  	br := conn.SendBatch(context.Background(), batch)
   411  
   412  	var n int32
   413  	err := br.QueryRow().Scan(&n)
   414  	if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
   415  		t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
   416  	}
   417  
   418  	err = br.Close()
   419  	if err == nil {
   420  		t.Error("Expected error")
   421  	}
   422  
   423  	ensureConnValid(t, conn)
   424  }
   425  
   426  func TestConnSendBatchQueryRowInsert(t *testing.T) {
   427  	t.Parallel()
   428  
   429  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   430  	defer closeConn(t, conn)
   431  
   432  	sql := `create temporary table ledger(
   433  	  id serial primary key,
   434  	  description varchar not null,
   435  	  amount int not null
   436  	);`
   437  	mustExec(t, conn, sql)
   438  
   439  	batch := &pgx.Batch{}
   440  	batch.Queue("select 1")
   441  	batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
   442  
   443  	br := conn.SendBatch(context.Background(), batch)
   444  
   445  	var value int
   446  	err := br.QueryRow().Scan(&value)
   447  	if err != nil {
   448  		t.Error(err)
   449  	}
   450  
   451  	ct, err := br.Exec()
   452  	if err != nil {
   453  		t.Error(err)
   454  	}
   455  	if ct.RowsAffected() != 2 {
   456  		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
   457  	}
   458  
   459  	br.Close()
   460  
   461  	ensureConnValid(t, conn)
   462  }
   463  
   464  func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
   465  	t.Parallel()
   466  
   467  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   468  	defer closeConn(t, conn)
   469  
   470  	sql := `create temporary table ledger(
   471  	  id serial primary key,
   472  	  description varchar not null,
   473  	  amount int not null
   474  	);`
   475  	mustExec(t, conn, sql)
   476  
   477  	batch := &pgx.Batch{}
   478  	batch.Queue("select 1 union all select 2 union all select 3")
   479  	batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
   480  
   481  	br := conn.SendBatch(context.Background(), batch)
   482  
   483  	rows, err := br.Query()
   484  	if err != nil {
   485  		t.Error(err)
   486  	}
   487  	rows.Close()
   488  
   489  	ct, err := br.Exec()
   490  	if err != nil {
   491  		t.Error(err)
   492  	}
   493  	if ct.RowsAffected() != 2 {
   494  		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
   495  	}
   496  
   497  	br.Close()
   498  
   499  	ensureConnValid(t, conn)
   500  }
   501  
   502  func TestTxSendBatch(t *testing.T) {
   503  	t.Parallel()
   504  
   505  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   506  	defer closeConn(t, conn)
   507  
   508  	sql := `create temporary table ledger1(
   509  	  id serial primary key,
   510  	  description varchar not null
   511  	);`
   512  	mustExec(t, conn, sql)
   513  
   514  	sql = `create temporary table ledger2(
   515  	  id int primary key,
   516  	  amount int not null
   517  	);`
   518  	mustExec(t, conn, sql)
   519  
   520  	tx, _ := conn.Begin(context.Background())
   521  	batch := &pgx.Batch{}
   522  	batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
   523  
   524  	br := tx.SendBatch(context.Background(), batch)
   525  
   526  	var id int
   527  	err := br.QueryRow().Scan(&id)
   528  	if err != nil {
   529  		t.Error(err)
   530  	}
   531  	br.Close()
   532  
   533  	batch = &pgx.Batch{}
   534  	batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
   535  	batch.Queue("select amount from ledger2 where id = $1", id)
   536  
   537  	br = tx.SendBatch(context.Background(), batch)
   538  
   539  	ct, err := br.Exec()
   540  	if err != nil {
   541  		t.Error(err)
   542  	}
   543  	if ct.RowsAffected() != 1 {
   544  		t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
   545  	}
   546  
   547  	var amount int
   548  	err = br.QueryRow().Scan(&amount)
   549  	if err != nil {
   550  		t.Error(err)
   551  	}
   552  
   553  	br.Close()
   554  	tx.Commit(context.Background())
   555  
   556  	var count int
   557  	conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
   558  	if count != 1 {
   559  		t.Errorf("count => %v, want %v", count, 1)
   560  	}
   561  
   562  	err = br.Close()
   563  	if err != nil {
   564  		t.Fatal(err)
   565  	}
   566  
   567  	ensureConnValid(t, conn)
   568  }
   569  
   570  func TestTxSendBatchRollback(t *testing.T) {
   571  	t.Parallel()
   572  
   573  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   574  	defer closeConn(t, conn)
   575  
   576  	sql := `create temporary table ledger1(
   577  	  id serial primary key,
   578  	  description varchar not null
   579  	);`
   580  	mustExec(t, conn, sql)
   581  
   582  	tx, _ := conn.Begin(context.Background())
   583  	batch := &pgx.Batch{}
   584  	batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
   585  
   586  	br := tx.SendBatch(context.Background(), batch)
   587  
   588  	var id int
   589  	err := br.QueryRow().Scan(&id)
   590  	if err != nil {
   591  		t.Error(err)
   592  	}
   593  	br.Close()
   594  	tx.Rollback(context.Background())
   595  
   596  	row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
   597  	var count int
   598  	row.Scan(&count)
   599  	if count != 0 {
   600  		t.Errorf("count => %v, want %v", count, 0)
   601  	}
   602  
   603  	ensureConnValid(t, conn)
   604  }
   605  
   606  func TestConnBeginBatchDeferredError(t *testing.T) {
   607  	t.Parallel()
   608  
   609  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   610  	defer closeConn(t, conn)
   611  
   612  	skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   613  
   614  	mustExec(t, conn, `create temporary table t (
   615  		id text primary key,
   616  		n int not null,
   617  		unique (n) deferrable initially deferred
   618  	);
   619  
   620  	insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
   621  
   622  	batch := &pgx.Batch{}
   623  
   624  	batch.Queue(`update t set n=n+1 where id='b' returning *`)
   625  
   626  	br := conn.SendBatch(context.Background(), batch)
   627  
   628  	rows, err := br.Query()
   629  	if err != nil {
   630  		t.Error(err)
   631  	}
   632  
   633  	for rows.Next() {
   634  		var id string
   635  		var n int32
   636  		err = rows.Scan(&id, &n)
   637  		if err != nil {
   638  			t.Fatal(err)
   639  		}
   640  	}
   641  
   642  	err = br.Close()
   643  	if err == nil {
   644  		t.Fatal("expected error 23505 but got none")
   645  	}
   646  
   647  	if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
   648  		t.Fatalf("expected error 23505, got %v", err)
   649  	}
   650  
   651  	ensureConnValid(t, conn)
   652  }
   653  
   654  func TestConnSendBatchNoStatementCache(t *testing.T) {
   655  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   656  	config.BuildStatementCache = nil
   657  
   658  	conn := mustConnect(t, config)
   659  	defer closeConn(t, conn)
   660  
   661  	testConnSendBatch(t, conn, 3)
   662  }
   663  
   664  func TestConnSendBatchPrepareStatementCache(t *testing.T) {
   665  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   666  	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
   667  		return stmtcache.New(conn, stmtcache.ModePrepare, 32)
   668  	}
   669  
   670  	conn := mustConnect(t, config)
   671  	defer closeConn(t, conn)
   672  
   673  	testConnSendBatch(t, conn, 3)
   674  }
   675  
   676  func TestConnSendBatchDescribeStatementCache(t *testing.T) {
   677  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   678  	config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
   679  		return stmtcache.New(conn, stmtcache.ModeDescribe, 32)
   680  	}
   681  
   682  	conn := mustConnect(t, config)
   683  	defer closeConn(t, conn)
   684  
   685  	testConnSendBatch(t, conn, 3)
   686  }
   687  
   688  func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
   689  	batch := &pgx.Batch{}
   690  	for j := 0; j < queryCount; j++ {
   691  		batch.Queue("select n from generate_series(0,5) n")
   692  	}
   693  
   694  	br := conn.SendBatch(context.Background(), batch)
   695  
   696  	for j := 0; j < queryCount; j++ {
   697  		rows, err := br.Query()
   698  		require.NoError(t, err)
   699  
   700  		for k := 0; rows.Next(); k++ {
   701  			var n int
   702  			err := rows.Scan(&n)
   703  			require.NoError(t, err)
   704  			require.Equal(t, k, n)
   705  		}
   706  
   707  		require.NoError(t, rows.Err())
   708  	}
   709  
   710  	err := br.Close()
   711  	require.NoError(t, err)
   712  }
   713  
   714  func TestLogBatchStatementsOnExec(t *testing.T) {
   715  	l1 := &testLogger{}
   716  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   717  	config.Logger = l1
   718  
   719  	conn := mustConnect(t, config)
   720  	defer closeConn(t, conn)
   721  
   722  	l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
   723  
   724  	batch := &pgx.Batch{}
   725  	batch.Queue("create table foo (id bigint)")
   726  	batch.Queue("drop table foo")
   727  
   728  	br := conn.SendBatch(context.Background(), batch)
   729  
   730  	_, err := br.Exec()
   731  	if err != nil {
   732  		t.Fatalf("Unexpected error creating table: %v", err)
   733  	}
   734  
   735  	_, err = br.Exec()
   736  	if err != nil {
   737  		t.Fatalf("Unexpected error dropping table: %v", err)
   738  	}
   739  
   740  	if len(l1.logs) != 3 {
   741  		t.Fatalf("Expected two log entries but got %d", len(l1.logs))
   742  	}
   743  
   744  	if l1.logs[0].msg != "SendBatch" {
   745  		t.Errorf("Expected first log message to be 'SendBatch' but was '%s'", l1.logs[0].msg)
   746  	}
   747  
   748  	if l1.logs[1].msg != "BatchResult.Exec" {
   749  		t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s'", l1.logs[0].msg)
   750  	}
   751  
   752  	if l1.logs[1].data["sql"] != "create table foo (id bigint)" {
   753  		t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"])
   754  	}
   755  
   756  	if l1.logs[2].msg != "BatchResult.Exec" {
   757  		t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg)
   758  	}
   759  
   760  	if l1.logs[2].data["sql"] != "drop table foo" {
   761  		t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"])
   762  	}
   763  }
   764  
   765  func TestLogBatchStatementsOnBatchResultClose(t *testing.T) {
   766  	l1 := &testLogger{}
   767  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   768  	config.Logger = l1
   769  
   770  	conn := mustConnect(t, config)
   771  	defer closeConn(t, conn)
   772  
   773  	l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
   774  
   775  	batch := &pgx.Batch{}
   776  	batch.Queue("select generate_series(1,$1)", 100)
   777  	batch.Queue("select 1 = 1;")
   778  
   779  	br := conn.SendBatch(context.Background(), batch)
   780  
   781  	if err := br.Close(); err != nil {
   782  		t.Fatalf("Unexpected batch error: %v", err)
   783  	}
   784  
   785  	if len(l1.logs) != 3 {
   786  		t.Fatalf("Expected 2 log statements but found %d", len(l1.logs))
   787  	}
   788  
   789  	if l1.logs[0].msg != "SendBatch" {
   790  		t.Errorf("Expected first log message to be 'SendBatch' but was '%s'", l1.logs[0].msg)
   791  	}
   792  
   793  	if l1.logs[1].msg != "BatchResult.Close" {
   794  		t.Errorf("Expected first log statement to be 'BatchResult.Close' but was '%s'", l1.logs[0].msg)
   795  	}
   796  
   797  	if l1.logs[1].data["sql"] != "select generate_series(1,$1)" {
   798  		t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"])
   799  	}
   800  
   801  	if l1.logs[2].msg != "BatchResult.Close" {
   802  		t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg)
   803  	}
   804  
   805  	if l1.logs[2].data["sql"] != "select 1 = 1;" {
   806  		t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"])
   807  	}
   808  }
   809  
   810  func TestSendBatchSimpleProtocol(t *testing.T) {
   811  	t.Parallel()
   812  
   813  	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
   814  	config.PreferSimpleProtocol = true
   815  
   816  	ctx, cancelFunc := context.WithCancel(context.Background())
   817  	defer cancelFunc()
   818  
   819  	conn := mustConnect(t, config)
   820  	defer closeConn(t, conn)
   821  
   822  	var batch pgx.Batch
   823  	batch.Queue("SELECT 1::int")
   824  	batch.Queue("SELECT 2::int; SELECT $1::int", 3)
   825  	results := conn.SendBatch(ctx, &batch)
   826  	rows, err := results.Query()
   827  	assert.NoError(t, err)
   828  	assert.True(t, rows.Next())
   829  	values, err := rows.Values()
   830  	assert.NoError(t, err)
   831  	assert.EqualValues(t, 1, values[0])
   832  	assert.False(t, rows.Next())
   833  
   834  	rows, err = results.Query()
   835  	assert.NoError(t, err)
   836  	assert.True(t, rows.Next())
   837  	values, err = rows.Values()
   838  	assert.NoError(t, err)
   839  	assert.EqualValues(t, 2, values[0])
   840  	assert.False(t, rows.Next())
   841  
   842  	rows, err = results.Query()
   843  	assert.NoError(t, err)
   844  	assert.True(t, rows.Next())
   845  	values, err = rows.Values()
   846  	assert.NoError(t, err)
   847  	assert.EqualValues(t, 3, values[0])
   848  	assert.False(t, rows.Next())
   849  }
   850  

View as plain text