1 package pgx_test
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "testing"
9 "time"
10
11 "github.com/jackc/pgx/v5"
12 "github.com/jackc/pgx/v5/pgconn"
13 "github.com/jackc/pgx/v5/pgxtest"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 )
17
18 func TestConnSendBatch(t *testing.T) {
19 t.Parallel()
20
21 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
22 defer cancel()
23
24 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
25 pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
26
27 sql := `create temporary table ledger(
28 id serial primary key,
29 description varchar not null,
30 amount int not null
31 );`
32 mustExec(t, conn, sql)
33
34 batch := &pgx.Batch{}
35 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
36 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
37 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
38 batch.Queue("select id, description, amount from ledger order by id")
39 batch.Queue("select id, description, amount from ledger order by id")
40 batch.Queue("select * from ledger where false")
41 batch.Queue("select sum(amount) from ledger")
42
43 br := conn.SendBatch(ctx, batch)
44
45 ct, err := br.Exec()
46 if err != nil {
47 t.Error(err)
48 }
49 if ct.RowsAffected() != 1 {
50 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
51 }
52
53 ct, err = br.Exec()
54 if err != nil {
55 t.Error(err)
56 }
57 if ct.RowsAffected() != 1 {
58 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
59 }
60
61 ct, err = br.Exec()
62 if err != nil {
63 t.Error(err)
64 }
65 if ct.RowsAffected() != 1 {
66 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
67 }
68
69 selectFromLedgerExpectedRows := []struct {
70 id int32
71 description string
72 amount int32
73 }{
74 {1, "q1", 1},
75 {2, "q2", 2},
76 {3, "q3", 3},
77 }
78
79 rows, err := br.Query()
80 if err != nil {
81 t.Error(err)
82 }
83
84 var id int32
85 var description string
86 var amount int32
87 rowCount := 0
88
89 for rows.Next() {
90 if rowCount >= len(selectFromLedgerExpectedRows) {
91 t.Fatalf("got too many rows: %d", rowCount)
92 }
93
94 if err := rows.Scan(&id, &description, &amount); err != nil {
95 t.Fatalf("row %d: %v", rowCount, err)
96 }
97
98 if id != selectFromLedgerExpectedRows[rowCount].id {
99 t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
100 }
101 if description != selectFromLedgerExpectedRows[rowCount].description {
102 t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
103 }
104 if amount != selectFromLedgerExpectedRows[rowCount].amount {
105 t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
106 }
107
108 rowCount++
109 }
110
111 if rows.Err() != nil {
112 t.Fatal(rows.Err())
113 }
114
115 rowCount = 0
116 rows, _ = br.Query()
117 _, err = pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
118 if id != selectFromLedgerExpectedRows[rowCount].id {
119 t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
120 }
121 if description != selectFromLedgerExpectedRows[rowCount].description {
122 t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
123 }
124 if amount != selectFromLedgerExpectedRows[rowCount].amount {
125 t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
126 }
127
128 rowCount++
129
130 return nil
131 })
132 if err != nil {
133 t.Error(err)
134 }
135
136 err = br.QueryRow().Scan(&id, &description, &amount)
137 if !errors.Is(err, pgx.ErrNoRows) {
138 t.Errorf("expected pgx.ErrNoRows but got: %v", err)
139 }
140
141 err = br.QueryRow().Scan(&amount)
142 if err != nil {
143 t.Error(err)
144 }
145 if amount != 6 {
146 t.Errorf("amount => %v, want %v", amount, 6)
147 }
148
149 err = br.Close()
150 if err != nil {
151 t.Fatal(err)
152 }
153 })
154 }
155
156 func TestConnSendBatchQueuedQuery(t *testing.T) {
157 t.Parallel()
158
159 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
160 defer cancel()
161
162 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
163 pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
164
165 sql := `create temporary table ledger(
166 id serial primary key,
167 description varchar not null,
168 amount int not null
169 );`
170 mustExec(t, conn, sql)
171
172 batch := &pgx.Batch{}
173
174 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error {
175 assert.EqualValues(t, 1, ct.RowsAffected())
176 return nil
177 })
178
179 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error {
180 assert.EqualValues(t, 1, ct.RowsAffected())
181 return nil
182 })
183
184 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error {
185 assert.EqualValues(t, 1, ct.RowsAffected())
186 return nil
187 })
188
189 selectFromLedgerExpectedRows := []struct {
190 id int32
191 description string
192 amount int32
193 }{
194 {1, "q1", 1},
195 {2, "q2", 2},
196 {3, "q3", 3},
197 }
198
199 batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
200 rowCount := 0
201 var id int32
202 var description string
203 var amount int32
204 _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
205 assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
206 assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
207 assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
208 rowCount++
209
210 return nil
211 })
212 assert.NoError(t, err)
213 return nil
214 })
215
216 batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
217 rowCount := 0
218 var id int32
219 var description string
220 var amount int32
221 _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
222 assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
223 assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
224 assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
225 rowCount++
226
227 return nil
228 })
229 assert.NoError(t, err)
230 return nil
231 })
232
233 batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
234 err := row.Scan(nil, nil, nil)
235 assert.ErrorIs(t, err, pgx.ErrNoRows)
236 return nil
237 })
238
239 batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
240 var sumAmount int32
241 err := row.Scan(&sumAmount)
242 assert.NoError(t, err)
243 assert.EqualValues(t, 6, sumAmount)
244 return nil
245 })
246
247 err := conn.SendBatch(ctx, batch).Close()
248 assert.NoError(t, err)
249 })
250 }
251
252 func TestConnSendBatchMany(t *testing.T) {
253 t.Parallel()
254
255 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
256 defer cancel()
257
258 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
259 sql := `create temporary table ledger(
260 id serial primary key,
261 description varchar not null,
262 amount int not null
263 );`
264 mustExec(t, conn, sql)
265
266 batch := &pgx.Batch{}
267
268 numInserts := 1000
269
270 for i := 0; i < numInserts; i++ {
271 batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
272 }
273 batch.Queue("select count(*) from ledger")
274
275 br := conn.SendBatch(ctx, batch)
276
277 for i := 0; i < numInserts; i++ {
278 ct, err := br.Exec()
279 assert.NoError(t, err)
280 assert.EqualValues(t, 1, ct.RowsAffected())
281 }
282
283 var actualInserts int
284 err := br.QueryRow().Scan(&actualInserts)
285 assert.NoError(t, err)
286 assert.EqualValues(t, numInserts, actualInserts)
287
288 err = br.Close()
289 require.NoError(t, err)
290 })
291 }
292
293 func TestConnSendBatchWithPreparedStatement(t *testing.T) {
294 t.Parallel()
295
296 modes := []pgx.QueryExecMode{
297 pgx.QueryExecModeCacheStatement,
298 pgx.QueryExecModeCacheDescribe,
299 pgx.QueryExecModeDescribeExec,
300 pgx.QueryExecModeExec,
301
302 }
303 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
304 defer cancel()
305
306 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
307 pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
308 _, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
309 if err != nil {
310 t.Fatal(err)
311 }
312
313 batch := &pgx.Batch{}
314
315 queryCount := 3
316 for i := 0; i < queryCount; i++ {
317 batch.Queue("ps1", 5)
318 }
319
320 br := conn.SendBatch(ctx, batch)
321
322 for i := 0; i < queryCount; i++ {
323 rows, err := br.Query()
324 if err != nil {
325 t.Fatal(err)
326 }
327
328 for k := 0; rows.Next(); k++ {
329 var n int
330 if err := rows.Scan(&n); err != nil {
331 t.Fatal(err)
332 }
333 if n != k {
334 t.Fatalf("n => %v, want %v", n, k)
335 }
336 }
337
338 if rows.Err() != nil {
339 t.Fatal(rows.Err())
340 }
341 }
342
343 err = br.Close()
344 if err != nil {
345 t.Fatal(err)
346 }
347 })
348 }
349
350 func TestConnSendBatchWithQueryRewriter(t *testing.T) {
351 t.Parallel()
352
353 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
354 defer cancel()
355
356 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
357 batch := &pgx.Batch{}
358 batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
359 batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
360 batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
361
362 br := conn.SendBatch(ctx, batch)
363
364 var n int32
365 err := br.QueryRow().Scan(&n)
366 require.NoError(t, err)
367 require.EqualValues(t, 1, n)
368
369 var s string
370 err = br.QueryRow().Scan(&s)
371 require.NoError(t, err)
372 require.Equal(t, "hello", s)
373
374 err = br.QueryRow().Scan(&n)
375 require.NoError(t, err)
376 require.EqualValues(t, 3, n)
377
378 err = br.Close()
379 require.NoError(t, err)
380 })
381 }
382
383
384 func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
385 t.Parallel()
386
387 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
388 defer cancel()
389
390 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
391 require.NoError(t, err)
392
393 config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
394 config.StatementCacheCapacity = 0
395 config.DescriptionCacheCapacity = 0
396
397 conn := mustConnect(t, config)
398 defer closeConn(t, conn)
399
400 pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
401
402 _, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
403 if err != nil {
404 t.Fatal(err)
405 }
406
407 batch := &pgx.Batch{}
408
409 queryCount := 3
410 for i := 0; i < queryCount; i++ {
411 batch.Queue("ps1", 5)
412 }
413
414 br := conn.SendBatch(ctx, batch)
415
416 for i := 0; i < queryCount; i++ {
417 rows, err := br.Query()
418 if err != nil {
419 t.Fatal(err)
420 }
421
422 for k := 0; rows.Next(); k++ {
423 var n int
424 if err := rows.Scan(&n); err != nil {
425 t.Fatal(err)
426 }
427 if n != k {
428 t.Fatalf("n => %v, want %v", n, k)
429 }
430 }
431
432 if rows.Err() != nil {
433 t.Fatal(rows.Err())
434 }
435 }
436
437 err = br.Close()
438 if err != nil {
439 t.Fatal(err)
440 }
441
442 ensureConnValid(t, conn)
443 }
444
445 func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
446 t.Parallel()
447
448 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
449 defer cancel()
450
451 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
452
453 batch := &pgx.Batch{}
454 batch.Queue("select n from generate_series(0,5) n")
455 batch.Queue("select n from generate_series(0,5) n")
456
457 br := conn.SendBatch(ctx, batch)
458
459 rows, err := br.Query()
460 if err != nil {
461 t.Error(err)
462 }
463
464 for i := 0; i < 3; i++ {
465 if !rows.Next() {
466 t.Error("expected a row to be available")
467 }
468
469 var n int
470 if err := rows.Scan(&n); err != nil {
471 t.Error(err)
472 }
473 if n != i {
474 t.Errorf("n => %v, want %v", n, i)
475 }
476 }
477
478 rows.Close()
479
480 rows, err = br.Query()
481 if err != nil {
482 t.Error(err)
483 }
484
485 for i := 0; rows.Next(); i++ {
486 var n int
487 if err := rows.Scan(&n); err != nil {
488 t.Error(err)
489 }
490 if n != i {
491 t.Errorf("n => %v, want %v", n, i)
492 }
493 }
494
495 if rows.Err() != nil {
496 t.Error(rows.Err())
497 }
498
499 err = br.Close()
500 if err != nil {
501 t.Fatal(err)
502 }
503
504 })
505 }
506
507 func TestConnSendBatchQueryError(t *testing.T) {
508 t.Parallel()
509
510 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
511 defer cancel()
512
513 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
514
515 batch := &pgx.Batch{}
516 batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
517 batch.Queue("select n from generate_series(0,5) n")
518
519 br := conn.SendBatch(ctx, batch)
520
521 rows, err := br.Query()
522 if err != nil {
523 t.Error(err)
524 }
525
526 for i := 0; rows.Next(); i++ {
527 var n int
528 if err := rows.Scan(&n); err != nil {
529 t.Error(err)
530 }
531 if n != i {
532 t.Errorf("n => %v, want %v", n, i)
533 }
534 }
535
536 if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
537 t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
538 }
539
540 err = br.Close()
541 if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
542 t.Errorf("br.Close() => %v, want error code %v", err, 22012)
543 }
544
545 })
546 }
547
548 func TestConnSendBatchQuerySyntaxError(t *testing.T) {
549 t.Parallel()
550
551 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
552 defer cancel()
553
554 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
555
556 batch := &pgx.Batch{}
557 batch.Queue("select 1 1")
558
559 br := conn.SendBatch(ctx, batch)
560
561 var n int32
562 err := br.QueryRow().Scan(&n)
563 if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
564 t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
565 }
566
567 err = br.Close()
568 if err == nil {
569 t.Error("Expected error")
570 }
571
572 })
573 }
574
575 func TestConnSendBatchQueryRowInsert(t *testing.T) {
576 t.Parallel()
577
578 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
579 defer cancel()
580
581 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
582
583 sql := `create temporary table ledger(
584 id serial primary key,
585 description varchar not null,
586 amount int not null
587 );`
588 mustExec(t, conn, sql)
589
590 batch := &pgx.Batch{}
591 batch.Queue("select 1")
592 batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
593
594 br := conn.SendBatch(ctx, batch)
595
596 var value int
597 err := br.QueryRow().Scan(&value)
598 if err != nil {
599 t.Error(err)
600 }
601
602 ct, err := br.Exec()
603 if err != nil {
604 t.Error(err)
605 }
606 if ct.RowsAffected() != 2 {
607 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
608 }
609
610 br.Close()
611
612 })
613 }
614
615 func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
616 t.Parallel()
617
618 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
619 defer cancel()
620
621 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
622
623 sql := `create temporary table ledger(
624 id serial primary key,
625 description varchar not null,
626 amount int not null
627 );`
628 mustExec(t, conn, sql)
629
630 batch := &pgx.Batch{}
631 batch.Queue("select 1 union all select 2 union all select 3")
632 batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
633
634 br := conn.SendBatch(ctx, batch)
635
636 rows, err := br.Query()
637 if err != nil {
638 t.Error(err)
639 }
640 rows.Close()
641
642 ct, err := br.Exec()
643 if err != nil {
644 t.Error(err)
645 }
646 if ct.RowsAffected() != 2 {
647 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
648 }
649
650 br.Close()
651
652 })
653 }
654
655 func TestTxSendBatch(t *testing.T) {
656 t.Parallel()
657
658 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
659 defer cancel()
660
661 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
662
663 sql := `create temporary table ledger1(
664 id serial primary key,
665 description varchar not null
666 );`
667 mustExec(t, conn, sql)
668
669 sql = `create temporary table ledger2(
670 id int primary key,
671 amount int not null
672 );`
673 mustExec(t, conn, sql)
674
675 tx, _ := conn.Begin(ctx)
676 batch := &pgx.Batch{}
677 batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
678
679 br := tx.SendBatch(context.Background(), batch)
680
681 var id int
682 err := br.QueryRow().Scan(&id)
683 if err != nil {
684 t.Error(err)
685 }
686 br.Close()
687
688 batch = &pgx.Batch{}
689 batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
690 batch.Queue("select amount from ledger2 where id = $1", id)
691
692 br = tx.SendBatch(ctx, batch)
693
694 ct, err := br.Exec()
695 if err != nil {
696 t.Error(err)
697 }
698 if ct.RowsAffected() != 1 {
699 t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
700 }
701
702 var amount int
703 err = br.QueryRow().Scan(&amount)
704 if err != nil {
705 t.Error(err)
706 }
707
708 br.Close()
709 tx.Commit(ctx)
710
711 var count int
712 conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
713 if count != 1 {
714 t.Errorf("count => %v, want %v", count, 1)
715 }
716
717 err = br.Close()
718 if err != nil {
719 t.Fatal(err)
720 }
721
722 })
723 }
724
725 func TestTxSendBatchRollback(t *testing.T) {
726 t.Parallel()
727
728 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
729 defer cancel()
730
731 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
732
733 sql := `create temporary table ledger1(
734 id serial primary key,
735 description varchar not null
736 );`
737 mustExec(t, conn, sql)
738
739 tx, _ := conn.Begin(ctx)
740 batch := &pgx.Batch{}
741 batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
742
743 br := tx.SendBatch(ctx, batch)
744
745 var id int
746 err := br.QueryRow().Scan(&id)
747 if err != nil {
748 t.Error(err)
749 }
750 br.Close()
751 tx.Rollback(ctx)
752
753 row := conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id)
754 var count int
755 row.Scan(&count)
756 if count != 0 {
757 t.Errorf("count => %v, want %v", count, 0)
758 }
759
760 })
761 }
762
763
764 func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
765 t.Parallel()
766
767 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
768 defer cancel()
769
770 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
771 batch := &pgx.Batch{}
772 batch.Queue("select 4 / $1::int", 0)
773
774 batchResult := conn.SendBatch(ctx, batch)
775
776 _, execErr := batchResult.Exec()
777 require.Error(t, execErr)
778
779 closeErr := batchResult.Close()
780 require.Equal(t, execErr, closeErr)
781
782
783 _, err := conn.Exec(ctx, "select 1")
784 require.NoError(t, err)
785 })
786 }
787
788 func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) {
789 t.Parallel()
790
791 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
792 defer cancel()
793
794 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
795 batch := &pgx.Batch{}
796 batch.Queue("select 4 / n from generate_series(-2, 2) n")
797
798 batchResult := conn.SendBatch(ctx, batch)
799
800 _, execErr := batchResult.Exec()
801 require.Error(t, execErr)
802
803 closeErr := batchResult.Close()
804 require.Equal(t, execErr, closeErr)
805
806
807 _, err := conn.Exec(ctx, "select 1")
808 require.NoError(t, err)
809 })
810 }
811
812 func TestConnBeginBatchDeferredError(t *testing.T) {
813 t.Parallel()
814
815 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
816 defer cancel()
817
818 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
819
820 pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
821
822 mustExec(t, conn, `create temporary table t (
823 id text primary key,
824 n int not null,
825 unique (n) deferrable initially deferred
826 );
827
828 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
829
830 batch := &pgx.Batch{}
831
832 batch.Queue(`update t set n=n+1 where id='b' returning *`)
833
834 br := conn.SendBatch(ctx, batch)
835
836 rows, err := br.Query()
837 if err != nil {
838 t.Error(err)
839 }
840
841 for rows.Next() {
842 var id string
843 var n int32
844 err = rows.Scan(&id, &n)
845 if err != nil {
846 t.Fatal(err)
847 }
848 }
849
850 err = br.Close()
851 if err == nil {
852 t.Fatal("expected error 23505 but got none")
853 }
854
855 if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
856 t.Fatalf("expected error 23505, got %v", err)
857 }
858
859 })
860 }
861
862 func TestConnSendBatchNoStatementCache(t *testing.T) {
863 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
864 defer cancel()
865
866 config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
867 config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
868 config.StatementCacheCapacity = 0
869 config.DescriptionCacheCapacity = 0
870
871 conn := mustConnect(t, config)
872 defer closeConn(t, conn)
873
874 testConnSendBatch(t, ctx, conn, 3)
875 }
876
877 func TestConnSendBatchPrepareStatementCache(t *testing.T) {
878 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
879 defer cancel()
880
881 config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
882 config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
883 config.StatementCacheCapacity = 32
884
885 conn := mustConnect(t, config)
886 defer closeConn(t, conn)
887
888 testConnSendBatch(t, ctx, conn, 3)
889 }
890
891 func TestConnSendBatchDescribeStatementCache(t *testing.T) {
892 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
893 defer cancel()
894
895 config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
896 config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
897 config.DescriptionCacheCapacity = 32
898
899 conn := mustConnect(t, config)
900 defer closeConn(t, conn)
901
902 testConnSendBatch(t, ctx, conn, 3)
903 }
904
905 func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) {
906 batch := &pgx.Batch{}
907 for j := 0; j < queryCount; j++ {
908 batch.Queue("select n from generate_series(0,5) n")
909 }
910
911 br := conn.SendBatch(ctx, batch)
912
913 for j := 0; j < queryCount; j++ {
914 rows, err := br.Query()
915 require.NoError(t, err)
916
917 for k := 0; rows.Next(); k++ {
918 var n int
919 err := rows.Scan(&n)
920 require.NoError(t, err)
921 require.Equal(t, k, n)
922 }
923
924 require.NoError(t, rows.Err())
925 }
926
927 err := br.Close()
928 require.NoError(t, err)
929 }
930
931 func TestSendBatchSimpleProtocol(t *testing.T) {
932 t.Parallel()
933
934 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
935 defer cancel()
936
937 config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
938 config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
939
940 conn := mustConnect(t, config)
941 defer closeConn(t, conn)
942
943 var batch pgx.Batch
944 batch.Queue("SELECT 1::int")
945 batch.Queue("SELECT 2::int; SELECT $1::int", 3)
946 results := conn.SendBatch(ctx, &batch)
947 rows, err := results.Query()
948 assert.NoError(t, err)
949 assert.True(t, rows.Next())
950 values, err := rows.Values()
951 assert.NoError(t, err)
952 assert.EqualValues(t, 1, values[0])
953 assert.False(t, rows.Next())
954
955 rows, err = results.Query()
956 assert.NoError(t, err)
957 assert.True(t, rows.Next())
958 values, err = rows.Values()
959 assert.NoError(t, err)
960 assert.EqualValues(t, 2, values[0])
961 assert.False(t, rows.Next())
962
963 rows, err = results.Query()
964 assert.NoError(t, err)
965 assert.True(t, rows.Next())
966 values, err = rows.Values()
967 assert.NoError(t, err)
968 assert.EqualValues(t, 3, values[0])
969 assert.False(t, rows.Next())
970 }
971
972 func ExampleConn_SendBatch() {
973 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
974 defer cancel()
975
976 conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
977 if err != nil {
978 fmt.Printf("Unable to establish connection: %v", err)
979 return
980 }
981
982 batch := &pgx.Batch{}
983 batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
984 var n int32
985 err := row.Scan(&n)
986 if err != nil {
987 return err
988 }
989
990 fmt.Println(n)
991
992 return err
993 })
994
995 batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error {
996 var n int32
997 err := row.Scan(&n)
998 if err != nil {
999 return err
1000 }
1001
1002 fmt.Println(n)
1003
1004 return err
1005 })
1006
1007 batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error {
1008 var n int32
1009 err := row.Scan(&n)
1010 if err != nil {
1011 return err
1012 }
1013
1014 fmt.Println(n)
1015
1016 return err
1017 })
1018
1019 err = conn.SendBatch(ctx, batch).Close()
1020 if err != nil {
1021 fmt.Printf("SendBatch error: %v", err)
1022 return
1023 }
1024
1025
1026
1027
1028
1029 }
1030
View as plain text