1 package pgx
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "strconv"
9
10 "github.com/jackc/pgconn"
11 )
12
13
// TxIsoLevel is a transaction isolation level. Its string value is written
// verbatim (not quoted or escaped) into the BEGIN statement, so callers should
// stick to the predefined constants below.
type TxIsoLevel string

// Transaction isolation levels.
const (
	Serializable    TxIsoLevel = "serializable"
	RepeatableRead  TxIsoLevel = "repeatable read"
	ReadCommitted   TxIsoLevel = "read committed"
	ReadUncommitted TxIsoLevel = "read uncommitted"
)

// TxAccessMode is a transaction access mode (read write or read only).
type TxAccessMode string

// Transaction access modes.
const (
	ReadWrite TxAccessMode = "read write"
	ReadOnly  TxAccessMode = "read only"
)

// TxDeferrableMode selects whether a transaction is deferrable.
type TxDeferrableMode string

// Transaction deferrable modes.
const (
	Deferrable    TxDeferrableMode = "deferrable"
	NotDeferrable TxDeferrableMode = "not deferrable"
)

// TxOptions are the options for a transaction started with BeginTx. A field
// left as the empty string contributes no clause to the BEGIN statement.
type TxOptions struct {
	IsoLevel       TxIsoLevel
	AccessMode     TxAccessMode
	DeferrableMode TxDeferrableMode
}

// emptyTxOptions is the zero TxOptions, used by beginSQL as a fast path.
var emptyTxOptions TxOptions

// beginSQL renders the BEGIN statement for txOptions, in the order
// "begin [isolation level X] [access mode] [deferrable mode]".
// The zero value yields just "begin".
func (txOptions TxOptions) beginSQL() string {
	if txOptions == emptyTxOptions {
		return "begin"
	}

	buf := new(bytes.Buffer)
	buf.WriteString("begin")
	if lvl := txOptions.IsoLevel; lvl != "" {
		fmt.Fprintf(buf, " isolation level %s", lvl)
	}

	// Access mode and deferrable mode are emitted as bare keywords, in this order.
	trailing := [...]string{string(txOptions.AccessMode), string(txOptions.DeferrableMode)}
	for _, clause := range trailing {
		if clause == "" {
			continue
		}
		fmt.Fprintf(buf, " %s", clause)
	}

	return buf.String()
}
69
// Sentinel errors returned by the Tx implementations in this package.
var (
	// ErrTxClosed is returned when a transaction (or savepoint-backed pseudo
	// nested transaction) is used after it has been committed or rolled back.
	ErrTxClosed = errors.New("tx is closed")

	// ErrTxCommitRollback is returned by Commit when the server answers the
	// COMMIT with a ROLLBACK command tag, i.e. the transaction was rolled back
	// instead of committed.
	ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
)
76
77
78
79 func (c *Conn) Begin(ctx context.Context) (Tx, error) {
80 return c.BeginTx(ctx, TxOptions{})
81 }
82
83
84
85 func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
86 _, err := c.Exec(ctx, txOptions.beginSQL())
87 if err != nil {
88
89
90 c.die(errors.New("failed to begin transaction"))
91 return nil, err
92 }
93
94 return &dbTx{conn: c}, nil
95 }
96
97
98
99
100 func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
101 return c.BeginTxFunc(ctx, TxOptions{}, f)
102 }
103
104
105
106
107
108 func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
109 var tx Tx
110 tx, err = c.BeginTx(ctx, txOptions)
111 if err != nil {
112 return err
113 }
114 defer func() {
115 rollbackErr := tx.Rollback(ctx)
116 if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
117 err = rollbackErr
118 }
119 }()
120
121 fErr := f(tx)
122 if fErr != nil {
123 _ = tx.Rollback(ctx)
124 return fErr
125 }
126
127 return tx.Commit(ctx)
128 }
129
130
131
132
133
134
135
136
137 type Tx interface {
138
139 Begin(ctx context.Context) (Tx, error)
140
141
142
143 BeginFunc(ctx context.Context, f func(Tx) error) (err error)
144
145
146
147
148
149 Commit(ctx context.Context) error
150
151
152
153
154
155 Rollback(ctx context.Context) error
156
157 CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
158 SendBatch(ctx context.Context, b *Batch) BatchResults
159 LargeObjects() LargeObjects
160
161 Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error)
162
163 Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
164 Query(ctx context.Context, sql string, args ...interface{}) (Rows, error)
165 QueryRow(ctx context.Context, sql string, args ...interface{}) Row
166 QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error)
167
168
169 Conn() *Conn
170 }
171
172
173
174
175
176 type dbTx struct {
177 conn *Conn
178 err error
179 savepointNum int64
180 closed bool
181 }
182
183
184 func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
185 if tx.closed {
186 return nil, ErrTxClosed
187 }
188
189 tx.savepointNum++
190 _, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10))
191 if err != nil {
192 return nil, err
193 }
194
195 return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil
196 }
197
198 func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
199 if tx.closed {
200 return ErrTxClosed
201 }
202
203 var savepoint Tx
204 savepoint, err = tx.Begin(ctx)
205 if err != nil {
206 return err
207 }
208 defer func() {
209 rollbackErr := savepoint.Rollback(ctx)
210 if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
211 err = rollbackErr
212 }
213 }()
214
215 fErr := f(savepoint)
216 if fErr != nil {
217 _ = savepoint.Rollback(ctx)
218 return fErr
219 }
220
221 return savepoint.Commit(ctx)
222 }
223
224
225 func (tx *dbTx) Commit(ctx context.Context) error {
226 if tx.closed {
227 return ErrTxClosed
228 }
229
230 commandTag, err := tx.conn.Exec(ctx, "commit")
231 tx.closed = true
232 if err != nil {
233 if tx.conn.PgConn().TxStatus() != 'I' {
234 _ = tx.conn.Close(ctx)
235 }
236 return err
237 }
238 if string(commandTag) == "ROLLBACK" {
239 return ErrTxCommitRollback
240 }
241
242 return nil
243 }
244
245
246
247
248
249 func (tx *dbTx) Rollback(ctx context.Context) error {
250 if tx.closed {
251 return ErrTxClosed
252 }
253
254 _, err := tx.conn.Exec(ctx, "rollback")
255 tx.closed = true
256 if err != nil {
257
258 tx.conn.die(fmt.Errorf("rollback failed: %w", err))
259 return err
260 }
261
262 return nil
263 }
264
265
266 func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
267 if tx.closed {
268 return pgconn.CommandTag{}, ErrTxClosed
269 }
270
271 return tx.conn.Exec(ctx, sql, arguments...)
272 }
273
274
275 func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
276 if tx.closed {
277 return nil, ErrTxClosed
278 }
279
280 return tx.conn.Prepare(ctx, name, sql)
281 }
282
283
284 func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
285 if tx.closed {
286
287 err := ErrTxClosed
288 return &connRows{closed: true, err: err}, err
289 }
290
291 return tx.conn.Query(ctx, sql, args...)
292 }
293
294
295 func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
296 rows, _ := tx.Query(ctx, sql, args...)
297 return (*connRow)(rows.(*connRows))
298 }
299
300
301 func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
302 if tx.closed {
303 return nil, ErrTxClosed
304 }
305
306 return tx.conn.QueryFunc(ctx, sql, args, scans, f)
307 }
308
309
310 func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
311 if tx.closed {
312 return 0, ErrTxClosed
313 }
314
315 return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
316 }
317
318
319 func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
320 if tx.closed {
321 return &batchResults{err: ErrTxClosed}
322 }
323
324 return tx.conn.SendBatch(ctx, b)
325 }
326
327
328 func (tx *dbTx) LargeObjects() LargeObjects {
329 return LargeObjects{tx: tx}
330 }
331
332 func (tx *dbTx) Conn() *Conn {
333 return tx.conn
334 }
335
336
337 type dbSimulatedNestedTx struct {
338 tx Tx
339 savepointNum int64
340 closed bool
341 }
342
343
344 func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) {
345 if sp.closed {
346 return nil, ErrTxClosed
347 }
348
349 return sp.tx.Begin(ctx)
350 }
351
352 func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
353 if sp.closed {
354 return ErrTxClosed
355 }
356
357 return sp.tx.BeginFunc(ctx, f)
358 }
359
360
361 func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error {
362 if sp.closed {
363 return ErrTxClosed
364 }
365
366 _, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10))
367 sp.closed = true
368 return err
369 }
370
371
372
373
374 func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error {
375 if sp.closed {
376 return ErrTxClosed
377 }
378
379 _, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10))
380 sp.closed = true
381 return err
382 }
383
384
385 func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
386 if sp.closed {
387 return nil, ErrTxClosed
388 }
389
390 return sp.tx.Exec(ctx, sql, arguments...)
391 }
392
393
394 func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
395 if sp.closed {
396 return nil, ErrTxClosed
397 }
398
399 return sp.tx.Prepare(ctx, name, sql)
400 }
401
402
403 func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
404 if sp.closed {
405
406 err := ErrTxClosed
407 return &connRows{closed: true, err: err}, err
408 }
409
410 return sp.tx.Query(ctx, sql, args...)
411 }
412
413
414 func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
415 rows, _ := sp.Query(ctx, sql, args...)
416 return (*connRow)(rows.(*connRows))
417 }
418
419
420 func (sp *dbSimulatedNestedTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
421 if sp.closed {
422 return nil, ErrTxClosed
423 }
424
425 return sp.tx.QueryFunc(ctx, sql, args, scans, f)
426 }
427
428
429 func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
430 if sp.closed {
431 return 0, ErrTxClosed
432 }
433
434 return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)
435 }
436
437
438 func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
439 if sp.closed {
440 return &batchResults{err: ErrTxClosed}
441 }
442
443 return sp.tx.SendBatch(ctx, b)
444 }
445
446 func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects {
447 return LargeObjects{tx: sp}
448 }
449
450 func (sp *dbSimulatedNestedTx) Conn() *Conn {
451 return sp.tx.Conn()
452 }
453
View as plain text