1 package pgx_test
2
3 import (
4 "context"
5 "errors"
6 "os"
7 "testing"
8 "time"
9
10 "github.com/jackc/pgconn"
11 "github.com/jackc/pgx/v4"
12 "github.com/stretchr/testify/require"
13 )
14
15 func TestTransactionSuccessfulCommit(t *testing.T) {
16 t.Parallel()
17
18 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
19 defer closeConn(t, conn)
20
21 createSql := `
22 create temporary table foo(
23 id integer,
24 unique (id)
25 );
26 `
27
28 if _, err := conn.Exec(context.Background(), createSql); err != nil {
29 t.Fatalf("Failed to create table: %v", err)
30 }
31
32 tx, err := conn.Begin(context.Background())
33 if err != nil {
34 t.Fatalf("conn.Begin failed: %v", err)
35 }
36
37 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
38 if err != nil {
39 t.Fatalf("tx.Exec failed: %v", err)
40 }
41
42 err = tx.Commit(context.Background())
43 if err != nil {
44 t.Fatalf("tx.Commit failed: %v", err)
45 }
46
47 var n int64
48 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
49 if err != nil {
50 t.Fatalf("QueryRow Scan failed: %v", err)
51 }
52 if n != 1 {
53 t.Fatalf("Did not receive correct number of rows: %v", n)
54 }
55 }
56
57 func TestTxCommitWhenTxBroken(t *testing.T) {
58 t.Parallel()
59
60 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
61 defer closeConn(t, conn)
62
63 createSql := `
64 create temporary table foo(
65 id integer,
66 unique (id)
67 );
68 `
69
70 if _, err := conn.Exec(context.Background(), createSql); err != nil {
71 t.Fatalf("Failed to create table: %v", err)
72 }
73
74 tx, err := conn.Begin(context.Background())
75 if err != nil {
76 t.Fatalf("conn.Begin failed: %v", err)
77 }
78
79 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
80 t.Fatalf("tx.Exec failed: %v", err)
81 }
82
83
84 if _, err := tx.Exec(context.Background(), "syntax error"); err == nil {
85 t.Fatal("Unexpected success")
86 }
87
88 err = tx.Commit(context.Background())
89 if err != pgx.ErrTxCommitRollback {
90 t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err)
91 }
92
93 var n int64
94 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
95 if err != nil {
96 t.Fatalf("QueryRow Scan failed: %v", err)
97 }
98 if n != 0 {
99 t.Fatalf("Did not receive correct number of rows: %v", n)
100 }
101 }
102
103 func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) {
104 t.Parallel()
105
106 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
107 defer closeConn(t, conn)
108
109 skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
110
111 createSql := `
112 create temporary table foo(
113 id integer,
114 unique (id) initially deferred
115 );
116 `
117
118 if _, err := conn.Exec(context.Background(), createSql); err != nil {
119 t.Fatalf("Failed to create table: %v", err)
120 }
121
122 tx, err := conn.Begin(context.Background())
123 if err != nil {
124 t.Fatalf("conn.Begin failed: %v", err)
125 }
126
127 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
128 t.Fatalf("tx.Exec failed: %v", err)
129 }
130
131 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil {
132 t.Fatalf("tx.Exec failed: %v", err)
133 }
134
135 err = tx.Commit(context.Background())
136 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" {
137 t.Fatalf("Expected unique constraint violation 23505, got %#v", err)
138 }
139
140 var n int64
141 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
142 if err != nil {
143 t.Fatalf("QueryRow Scan failed: %v", err)
144 }
145 if n != 0 {
146 t.Fatalf("Did not receive correct number of rows: %v", n)
147 }
148 }
149
150 func TestTxCommitSerializationFailure(t *testing.T) {
151 t.Parallel()
152
153 c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
154 defer closeConn(t, c1)
155
156 if c1.PgConn().ParameterStatus("crdb_version") != "" {
157 t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)")
158 }
159
160 c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
161 defer closeConn(t, c2)
162
163 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
164 defer cancel()
165
166 c1.Exec(ctx, `drop table if exists tx_serializable_sums`)
167 _, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`)
168 if err != nil {
169 t.Fatalf("Unable to create temporary table: %v", err)
170 }
171 defer c1.Exec(ctx, `drop table tx_serializable_sums`)
172
173 tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
174 if err != nil {
175 t.Fatalf("Begin failed: %v", err)
176 }
177 defer tx1.Rollback(ctx)
178
179 tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable})
180 if err != nil {
181 t.Fatalf("Begin failed: %v", err)
182 }
183 defer tx2.Rollback(ctx)
184
185 _, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
186 if err != nil {
187 t.Fatalf("Exec failed: %v", err)
188 }
189
190 _, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`)
191 if err != nil {
192 t.Fatalf("Exec failed: %v", err)
193 }
194
195 err = tx1.Commit(ctx)
196 if err != nil {
197 t.Fatalf("Commit failed: %v", err)
198 }
199
200 err = tx2.Commit(ctx)
201 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" {
202 t.Fatalf("Expected serialization error 40001, got %#v", err)
203 }
204
205 ensureConnValid(t, c1)
206 ensureConnValid(t, c2)
207 }
208
209 func TestTransactionSuccessfulRollback(t *testing.T) {
210 t.Parallel()
211
212 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
213 defer closeConn(t, conn)
214
215 createSql := `
216 create temporary table foo(
217 id integer,
218 unique (id)
219 );
220 `
221
222 if _, err := conn.Exec(context.Background(), createSql); err != nil {
223 t.Fatalf("Failed to create table: %v", err)
224 }
225
226 tx, err := conn.Begin(context.Background())
227 if err != nil {
228 t.Fatalf("conn.Begin failed: %v", err)
229 }
230
231 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
232 if err != nil {
233 t.Fatalf("tx.Exec failed: %v", err)
234 }
235
236 err = tx.Rollback(context.Background())
237 if err != nil {
238 t.Fatalf("tx.Rollback failed: %v", err)
239 }
240
241 var n int64
242 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
243 if err != nil {
244 t.Fatalf("QueryRow Scan failed: %v", err)
245 }
246 if n != 0 {
247 t.Fatalf("Did not receive correct number of rows: %v", n)
248 }
249 }
250
251 func TestTransactionRollbackFailsClosesConnection(t *testing.T) {
252 t.Parallel()
253
254 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
255 defer closeConn(t, conn)
256
257 ctx, cancel := context.WithCancel(context.Background())
258
259 tx, err := conn.Begin(ctx)
260 require.NoError(t, err)
261
262 cancel()
263
264 err = tx.Rollback(ctx)
265 require.Error(t, err)
266
267 require.True(t, conn.IsClosed())
268 }
269
270 func TestBeginIsoLevels(t *testing.T) {
271 t.Parallel()
272
273 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
274 defer closeConn(t, conn)
275
276 skipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)")
277
278 isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted}
279 for _, iso := range isoLevels {
280 tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso})
281 if err != nil {
282 t.Fatalf("conn.Begin failed: %v", err)
283 }
284
285 var level pgx.TxIsoLevel
286 conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level)
287 if level != iso {
288 t.Errorf("Expected to be in isolation level %v but was %v", iso, level)
289 }
290
291 err = tx.Rollback(context.Background())
292 if err != nil {
293 t.Fatalf("tx.Rollback failed: %v", err)
294 }
295 }
296 }
297
298 func TestBeginFunc(t *testing.T) {
299 t.Parallel()
300
301 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
302 defer closeConn(t, conn)
303
304 createSql := `
305 create temporary table foo(
306 id integer,
307 unique (id)
308 );
309 `
310
311 _, err := conn.Exec(context.Background(), createSql)
312 require.NoError(t, err)
313
314 err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
315 _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
316 require.NoError(t, err)
317 return nil
318 })
319 require.NoError(t, err)
320
321 var n int64
322 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
323 require.NoError(t, err)
324 require.EqualValues(t, 1, n)
325 }
326
327 func TestBeginFuncRollbackOnError(t *testing.T) {
328 t.Parallel()
329
330 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
331 defer closeConn(t, conn)
332
333 createSql := `
334 create temporary table foo(
335 id integer,
336 unique (id)
337 );
338 `
339
340 _, err := conn.Exec(context.Background(), createSql)
341 require.NoError(t, err)
342
343 err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
344 _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
345 require.NoError(t, err)
346 return errors.New("some error")
347 })
348 require.EqualError(t, err, "some error")
349
350 var n int64
351 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
352 require.NoError(t, err)
353 require.EqualValues(t, 0, n)
354 }
355
356 func TestBeginReadOnly(t *testing.T) {
357 t.Parallel()
358
359 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
360 defer closeConn(t, conn)
361
362 tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly})
363 if err != nil {
364 t.Fatalf("conn.Begin failed: %v", err)
365 }
366 defer tx.Rollback(context.Background())
367
368 _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)")
369 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" {
370 t.Errorf("Expected error SQLSTATE 25006, but got %#v", err)
371 }
372 }
373
374 func TestTxNestedTransactionCommit(t *testing.T) {
375 t.Parallel()
376
377 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
378 defer closeConn(t, conn)
379
380 createSql := `
381 create temporary table foo(
382 id integer,
383 unique (id)
384 );
385 `
386
387 if _, err := conn.Exec(context.Background(), createSql); err != nil {
388 t.Fatalf("Failed to create table: %v", err)
389 }
390
391 tx, err := conn.Begin(context.Background())
392 if err != nil {
393 t.Fatal(err)
394 }
395
396 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
397 if err != nil {
398 t.Fatalf("tx.Exec failed: %v", err)
399 }
400
401 nestedTx, err := tx.Begin(context.Background())
402 if err != nil {
403 t.Fatal(err)
404 }
405
406 _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
407 if err != nil {
408 t.Fatalf("nestedTx.Exec failed: %v", err)
409 }
410
411 doubleNestedTx, err := nestedTx.Begin(context.Background())
412 if err != nil {
413 t.Fatal(err)
414 }
415
416 _, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)")
417 if err != nil {
418 t.Fatalf("doubleNestedTx.Exec failed: %v", err)
419 }
420
421 err = doubleNestedTx.Commit(context.Background())
422 if err != nil {
423 t.Fatalf("doubleNestedTx.Commit failed: %v", err)
424 }
425
426 err = nestedTx.Commit(context.Background())
427 if err != nil {
428 t.Fatalf("nestedTx.Commit failed: %v", err)
429 }
430
431 err = tx.Commit(context.Background())
432 if err != nil {
433 t.Fatalf("tx.Commit failed: %v", err)
434 }
435
436 var n int64
437 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
438 if err != nil {
439 t.Fatalf("QueryRow Scan failed: %v", err)
440 }
441 if n != 3 {
442 t.Fatalf("Did not receive correct number of rows: %v", n)
443 }
444 }
445
446 func TestTxNestedTransactionRollback(t *testing.T) {
447 t.Parallel()
448
449 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
450 defer closeConn(t, conn)
451
452 createSql := `
453 create temporary table foo(
454 id integer,
455 unique (id)
456 );
457 `
458
459 if _, err := conn.Exec(context.Background(), createSql); err != nil {
460 t.Fatalf("Failed to create table: %v", err)
461 }
462
463 tx, err := conn.Begin(context.Background())
464 if err != nil {
465 t.Fatal(err)
466 }
467
468 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)")
469 if err != nil {
470 t.Fatalf("tx.Exec failed: %v", err)
471 }
472
473 nestedTx, err := tx.Begin(context.Background())
474 if err != nil {
475 t.Fatal(err)
476 }
477
478 _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)")
479 if err != nil {
480 t.Fatalf("nestedTx.Exec failed: %v", err)
481 }
482
483 err = nestedTx.Rollback(context.Background())
484 if err != nil {
485 t.Fatalf("nestedTx.Rollback failed: %v", err)
486 }
487
488 _, err = tx.Exec(context.Background(), "insert into foo(id) values (3)")
489 if err != nil {
490 t.Fatalf("tx.Exec failed: %v", err)
491 }
492
493 err = tx.Commit(context.Background())
494 if err != nil {
495 t.Fatalf("tx.Commit failed: %v", err)
496 }
497
498 var n int64
499 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
500 if err != nil {
501 t.Fatalf("QueryRow Scan failed: %v", err)
502 }
503 if n != 2 {
504 t.Fatalf("Did not receive correct number of rows: %v", n)
505 }
506 }
507
508 func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
509 t.Parallel()
510
511 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
512 defer closeConn(t, db)
513
514 createSql := `
515 create temporary table foo(
516 id integer,
517 unique (id)
518 );
519 `
520
521 _, err := db.Exec(context.Background(), createSql)
522 require.NoError(t, err)
523
524 err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
525 _, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
526 require.NoError(t, err)
527
528 err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
529 _, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
530 require.NoError(t, err)
531
532 err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
533 _, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
534 require.NoError(t, err)
535 return nil
536 })
537
538 return nil
539 })
540 require.NoError(t, err)
541 return nil
542 })
543 require.NoError(t, err)
544
545 var n int64
546 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
547 require.NoError(t, err)
548 require.EqualValues(t, 3, n)
549 }
550
551 func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
552 t.Parallel()
553
554 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
555 defer closeConn(t, db)
556
557 createSql := `
558 create temporary table foo(
559 id integer,
560 unique (id)
561 );
562 `
563
564 _, err := db.Exec(context.Background(), createSql)
565 require.NoError(t, err)
566
567 err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
568 _, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
569 require.NoError(t, err)
570
571 err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
572 _, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
573 require.NoError(t, err)
574 return errors.New("do a rollback")
575 })
576 require.EqualError(t, err, "do a rollback")
577
578 _, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
579 require.NoError(t, err)
580
581 return nil
582 })
583
584 var n int64
585 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
586 require.NoError(t, err)
587 require.EqualValues(t, 2, n)
588 }
589
590 func TestTxSendBatchClosed(t *testing.T) {
591 t.Parallel()
592
593 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
594 defer closeConn(t, db)
595
596 tx, err := db.Begin(context.Background())
597 require.NoError(t, err)
598 defer tx.Rollback(context.Background())
599
600 err = tx.Commit(context.Background())
601 require.NoError(t, err)
602
603 batch := &pgx.Batch{}
604 batch.Queue("select 1")
605 batch.Queue("select 2")
606 batch.Queue("select 3")
607
608 br := tx.SendBatch(context.Background(), batch)
609 defer br.Close()
610
611 var n int
612
613 _, err = br.Exec()
614 require.Error(t, err)
615
616 err = br.QueryRow().Scan(&n)
617 require.Error(t, err)
618
619 _, err = br.Query()
620 require.Error(t, err)
621 }
622
View as plain text