1 package pgx_test
2
3 import (
4 "context"
5 "errors"
6 "os"
7 "testing"
8 "time"
9
10 "github.com/jackc/pgx/v5"
11 "github.com/jackc/pgx/v5/pgconn"
12 "github.com/jackc/pgx/v5/pgxtest"
13 "github.com/stretchr/testify/require"
14 )
15
16 func TestTransactionSuccessfulCommit(t *testing.T) {
17 t.Parallel()
18
19 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
20 defer closeConn(t, conn)
21
22 createSql := `
23 create temporary table foo(
24 id integer,
25 unique (id)
26 );
27 `
28
29 if _, err := conn.Exec(context.Background(), createSql); err != nil {
30 t.Fatalf("Failed to create table: %v", err)
31 }
32
33 tx, err := conn.Begin(context.Background())
34 if err != nil {
35 t.Fatalf("conn.Begin failed: %v", err)
36 }
37
38 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
39 if err != nil {
40 t.Fatalf("tx.Exec failed: %v", err)
41 }
42
43 err = tx.Commit(context.Background())
44 if err != nil {
45 t.Fatalf("tx.Commit failed: %v", err)
46 }
47
48 var n int64
49 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
50 if err != nil {
51 t.Fatalf("QueryRow Scan failed: %v", err)
52 }
53 if n != 1 {
54 t.Fatalf("Did not receive correct number of rows: %v", n)
55 }
56 }
57
58 func TestTxCommitWhenTxBroken(t *testing.T) {
59 t.Parallel()
60
61 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
62 defer closeConn(t, conn)
63
64 createSql := `
65 create temporary table foo(
66 id integer,
67 unique (id)
68 );
69 `
70
71 if _, err := conn.Exec(context.Background(), createSql); err != nil {
72 t.Fatalf("Failed to create table: %v", err)
73 }
74
75 tx, err := conn.Begin(context.Background())
76 if err != nil {
77 t.Fatalf("conn.Begin failed: %v", err)
78 }
79
80 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
81 t.Fatalf("tx.Exec failed: %v", err)
82 }
83
84
85 if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
86 t.Fatal("Unexpected success")
87 }
88
89 err = tx.Commit(context.Background())
90 if err != pgx.ErrTxCommitRollback {
91 t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
92 }
93
94 var n int64
95 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
96 if err != nil {
97 t.Fatalf("QueryRow Scan failed: %v", err)
98 }
99 if n != 0 {
100 t.Fatalf("Did not receive correct number of rows: %v", n)
101 }
102 }
103
104 func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) {
105 t.Parallel()
106
107 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
108 defer closeConn(t, conn)
109
110 pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
111
112 createSql := `
113 create temporary table foo(
114 id integer,
115 unique (id) initially deferred
116 );
117 `
118
119 if _, err := conn.Exec(context.Background(), createSql); err != nil {
120 t.Fatalf("Failed to create table: %v", err)
121 }
122
123 tx, err := conn.Begin(context.Background())
124 if err != nil {
125 t.Fatalf("conn.Begin failed: %v", err)
126 }
127
128 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
129 t.Fatalf("tx.Exec failed: %v", err)
130 }
131
132 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
133 t.Fatalf("tx.Exec failed: %v", err)
134 }
135
136 err = tx.Commit(context.Background())
137 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" {
138 t.Fatalf("Expected unique constraint violation 23505, got %#v", err)
139 }
140
141 var n int64
142 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
143 if err != nil {
144 t.Fatalf("QueryRow Scan failed: %v", err)
145 }
146 if n != 0 {
147 t.Fatalf("Did not receive correct number of rows: %v", n)
148 }
149 }
150
151 func TestTxCommitSerializationFailure(t *testing.T) {
152 t.Parallel()
153
154 c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
155 defer closeConn(t, c1)
156
157 if c1.PgConn().ParameterStatus("crdb_version") != "" {
158 t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)")
159 }
160
161 c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
162 defer closeConn(t, c2)
163
164 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
165 defer cancel()
166
167 c1.Exec(ctx, `drop table if exists tx_serializable_sums`)
168 _, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`)
169 if err != nil {
170 t.Fatalf("Unable to create temporary table: %v", err)
171 }
172 defer c1.Exec(ctx, `drop table tx_serializable_sums`)
173
174 tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
175 if err != nil {
176 t.Fatalf("Begin failed: %v", err)
177 }
178 defer tx1.Rollback(ctx)
179
180 tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
181 if err != nil {
182 t.Fatalf("Begin failed: %v", err)
183 }
184 defer tx2.Rollback(ctx)
185
186 _, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
187 if err != nil {
188 t.Fatalf("Exec failed: %v", err)
189 }
190
191 _, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
192 if err != nil {
193 t.Fatalf("Exec failed: %v", err)
194 }
195
196 err = tx1.Commit(ctx)
197 if err != nil {
198 t.Fatalf("Commit failed: %v", err)
199 }
200
201 err = tx2.Commit(ctx)
202 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
203 t.Fatalf("Expected serialization error 40001, got %#v", err)
204 }
205
206 ensureConnValid(t, c1)
207 ensureConnValid(t, c2)
208 }
209
210 func TestTransactionSuccessfulRollback(t *testing.T) {
211 t.Parallel()
212
213 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
214 defer closeConn(t, conn)
215
216 createSql := `
217 create temporary table foo(
218 id integer,
219 unique (id)
220 );
221 `
222
223 if _, err := conn.Exec(context.Background(), createSql); err != nil {
224 t.Fatalf("Failed to create table: %v", err)
225 }
226
227 tx, err := conn.Begin(context.Background())
228 if err != nil {
229 t.Fatalf("conn.Begin failed: %v", err)
230 }
231
232 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
233 if err != nil {
234 t.Fatalf("tx.Exec failed: %v", err)
235 }
236
237 err = tx.Rollback(context.Background())
238 if err != nil {
239 t.Fatalf("tx.Rollback failed: %v", err)
240 }
241
242 var n int64
243 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
244 if err != nil {
245 t.Fatalf("QueryRow Scan failed: %v", err)
246 }
247 if n != 0 {
248 t.Fatalf("Did not receive correct number of rows: %v", n)
249 }
250 }
251
252 func TestTransactionRollbackFailsClosesConnection(t *testing.T) {
253 t.Parallel()
254
255 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
256 defer closeConn(t, conn)
257
258 ctx, cancel := context.WithCancel(context.Background())
259
260 tx, err := conn.Begin(ctx)
261 require.NoError(t, err)
262
263 cancel()
264
265 err = tx.Rollback(ctx)
266 require.Error(t, err)
267
268 require.True(t, conn.IsClosed())
269 }
270
271 func TestBeginIsoLevels(t *testing.T) {
272 t.Parallel()
273
274 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
275 defer closeConn(t, conn)
276
277 pgxtest.SkipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)")
278
279 isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
280 for _, iso := range isoLevels {
281 tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso})
282 if err != nil {
283 t.Fatalf("conn.Begin failed: %v", err)
284 }
285
286 var level pgx.TxIsoLevel
287 conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level)
288 if level != iso {
289 t.Errorf("Expected to be in isolation level %v but was %v", iso, level)
290 }
291
292 err = tx.Rollback(context.Background())
293 if err != nil {
294 t.Fatalf("tx.Rollback failed: %v", err)
295 }
296 }
297 }
298
299 func TestBeginFunc(t *testing.T) {
300 t.Parallel()
301
302 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
303 defer closeConn(t, conn)
304
305 createSql := `
306 create temporary table foo(
307 id integer,
308 unique (id)
309 );
310 `
311
312 _, err := conn.Exec(context.Background(), createSql)
313 require.NoError(t, err)
314
315 err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
316 _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
317 require.NoError(t, err)
318 return nil
319 })
320 require.NoError(t, err)
321
322 var n int64
323 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
324 require.NoError(t, err)
325 require.EqualValues(t, 1, n)
326 }
327
328 func TestBeginFuncRollbackOnError(t *testing.T) {
329 t.Parallel()
330
331 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
332 defer closeConn(t, conn)
333
334 createSql := `
335 create temporary table foo(
336 id integer,
337 unique (id)
338 );
339 `
340
341 _, err := conn.Exec(context.Background(), createSql)
342 require.NoError(t, err)
343
344 err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
345 _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
346 require.NoError(t, err)
347 return errors.New("some error")
348 })
349 require.EqualError(t, err, "some error")
350
351 var n int64
352 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
353 require.NoError(t, err)
354 require.EqualValues(t, 0, n)
355 }
356
357 func TestBeginReadOnly(t *testing.T) {
358 t.Parallel()
359
360 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
361 defer closeConn(t, conn)
362
363 tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly})
364 if err != nil {
365 t.Fatalf("conn.Begin failed: %v", err)
366 }
367 defer tx.Rollback(context.Background())
368
369 _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)")
370 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" {
371 t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
372 }
373 }
374
375 func TestBeginTxBeginQuery(t *testing.T) {
376 t.Parallel()
377
378 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
379 defer cancel()
380
381 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
382 tx, err := conn.BeginTx(ctx, pgx.TxOptions{BeginQuery: "begin read only"})
383 require.NoError(t, err)
384 defer tx.Rollback(ctx)
385
386 var readOnly bool
387 conn.QueryRow(ctx, "select current_setting('transaction_read_only')::bool").Scan(&readOnly)
388 require.True(t, readOnly)
389
390 err = tx.Rollback(ctx)
391 require.NoError(t, err)
392 })
393 }
394
395 func TestTxNestedTransactionCommit(t *testing.T) {
396 t.Parallel()
397
398 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
399 defer closeConn(t, conn)
400
401 createSql := `
402 create temporary table foo(
403 id integer,
404 unique (id)
405 );
406 `
407
408 if _, err := conn.Exec(context.Background(), createSql); err != nil {
409 t.Fatalf("Failed to create table: %v", err)
410 }
411
412 tx, err := conn.Begin(context.Background())
413 if err != nil {
414 t.Fatal(err)
415 }
416
417 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
418 if err != nil {
419 t.Fatalf("tx.Exec failed: %v", err)
420 }
421
422 nestedTx, err := tx.Begin(context.Background())
423 if err != nil {
424 t.Fatal(err)
425 }
426
427 _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
428 if err != nil {
429 t.Fatalf("nestedTx.Exec failed: %v", err)
430 }
431
432 doubleNestedTx, err := nestedTx.Begin(context.Background())
433 if err != nil {
434 t.Fatal(err)
435 }
436
437 _, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)")
438 if err != nil {
439 t.Fatalf("doubleNestedTx.Exec failed: %v", err)
440 }
441
442 err = doubleNestedTx.Commit(context.Background())
443 if err != nil {
444 t.Fatalf("doubleNestedTx.Commit failed: %v", err)
445 }
446
447 err = nestedTx.Commit(context.Background())
448 if err != nil {
449 t.Fatalf("nestedTx.Commit failed: %v", err)
450 }
451
452 err = tx.Commit(context.Background())
453 if err != nil {
454 t.Fatalf("tx.Commit failed: %v", err)
455 }
456
457 var n int64
458 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
459 if err != nil {
460 t.Fatalf("QueryRow Scan failed: %v", err)
461 }
462 if n != 3 {
463 t.Fatalf("Did not receive correct number of rows: %v", n)
464 }
465 }
466
467 func TestTxNestedTransactionRollback(t *testing.T) {
468 t.Parallel()
469
470 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
471 defer closeConn(t, conn)
472
473 createSql := `
474 create temporary table foo(
475 id integer,
476 unique (id)
477 );
478 `
479
480 if _, err := conn.Exec(context.Background(), createSql); err != nil {
481 t.Fatalf("Failed to create table: %v", err)
482 }
483
484 tx, err := conn.Begin(context.Background())
485 if err != nil {
486 t.Fatal(err)
487 }
488
489 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
490 if err != nil {
491 t.Fatalf("tx.Exec failed: %v", err)
492 }
493
494 nestedTx, err := tx.Begin(context.Background())
495 if err != nil {
496 t.Fatal(err)
497 }
498
499 _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
500 if err != nil {
501 t.Fatalf("nestedTx.Exec failed: %v", err)
502 }
503
504 err = nestedTx.Rollback(context.Background())
505 if err != nil {
506 t.Fatalf("nestedTx.Rollback failed: %v", err)
507 }
508
509 _, err = tx.Exec(context.Background(), "insert into foo(id) values (3)")
510 if err != nil {
511 t.Fatalf("tx.Exec failed: %v", err)
512 }
513
514 err = tx.Commit(context.Background())
515 if err != nil {
516 t.Fatalf("tx.Commit failed: %v", err)
517 }
518
519 var n int64
520 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
521 if err != nil {
522 t.Fatalf("QueryRow Scan failed: %v", err)
523 }
524 if n != 2 {
525 t.Fatalf("Did not receive correct number of rows: %v", n)
526 }
527 }
528
529 func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
530 t.Parallel()
531
532 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
533 defer closeConn(t, db)
534
535 createSql := `
536 create temporary table foo(
537 id integer,
538 unique (id)
539 );
540 `
541
542 _, err := db.Exec(context.Background(), createSql)
543 require.NoError(t, err)
544
545 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
546 _, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
547 require.NoError(t, err)
548
549 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
550 _, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
551 require.NoError(t, err)
552
553 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
554 _, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
555 require.NoError(t, err)
556 return nil
557 })
558 require.NoError(t, err)
559
560 return nil
561 })
562 require.NoError(t, err)
563 return nil
564 })
565 require.NoError(t, err)
566
567 var n int64
568 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
569 require.NoError(t, err)
570 require.EqualValues(t, 3, n)
571 }
572
573 func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
574 t.Parallel()
575
576 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
577 defer closeConn(t, db)
578
579 createSql := `
580 create temporary table foo(
581 id integer,
582 unique (id)
583 );
584 `
585
586 _, err := db.Exec(context.Background(), createSql)
587 require.NoError(t, err)
588
589 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
590 _, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
591 require.NoError(t, err)
592
593 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error {
594 _, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
595 require.NoError(t, err)
596 return errors.New("do a rollback")
597 })
598 require.EqualError(t, err, "do a rollback")
599
600 _, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
601 require.NoError(t, err)
602
603 return nil
604 })
605 require.NoError(t, err)
606
607 var n int64
608 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
609 require.NoError(t, err)
610 require.EqualValues(t, 2, n)
611 }
612
613 func TestTxSendBatchClosed(t *testing.T) {
614 t.Parallel()
615
616 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
617 defer closeConn(t, db)
618
619 tx, err := db.Begin(context.Background())
620 require.NoError(t, err)
621 defer tx.Rollback(context.Background())
622
623 err = tx.Commit(context.Background())
624 require.NoError(t, err)
625
626 batch := &pgx.Batch{}
627 batch.Queue("select 1")
628 batch.Queue("select 2")
629 batch.Queue("select 3")
630
631 br := tx.SendBatch(context.Background(), batch)
632 defer br.Close()
633
634 var n int
635
636 _, err = br.Exec()
637 require.Error(t, err)
638
639 err = br.QueryRow().Scan(&n)
640 require.Error(t, err)
641
642 _, err = br.Query()
643 require.Error(t, err)
644 }
645
View as plain text