1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49 package stdlib
50
51 import (
52 "context"
53 "database/sql"
54 "database/sql/driver"
55 "errors"
56 "fmt"
57 "io"
58 "math"
59 "math/rand"
60 "reflect"
61 "strconv"
62 "strings"
63 "sync"
64 "time"
65
66 "github.com/jackc/pgconn"
67 "github.com/jackc/pgtype"
68 "github.com/jackc/pgx/v4"
69 )
70
71
72 var databaseSQLResultFormats pgx.QueryResultFormatsByOID
73
74 var pgxDriver *Driver
75
76 type ctxKey int
77
78 var ctxKeyFakeTx ctxKey = 0
79
80 var ErrNotPgx = errors.New("not pgx *sql.DB")
81
// init creates the shared driver and the AcquireConn bookkeeping map,
// registers the driver with database/sql, and builds the table of result
// types requested in binary format.
func init() {
	pgxDriver = &Driver{configs: make(map[string]*pgx.ConnConfig)}
	fakeTxConns = make(map[*pgx.Conn]*sql.Tx)

	// sql.Register panics on a duplicate name, so only claim the unversioned
	// "pgx" name when no other driver has registered it already.
	if !contains(sql.Drivers(), "pgx") {
		sql.Register("pgx", pgxDriver)
	}
	sql.Register("pgx/v4", pgxDriver)

	// Every OID listed here is requested with format code 1 (binary).
	binaryOIDs := []uint32{
		pgtype.BoolOID,
		pgtype.ByteaOID,
		pgtype.CIDOID,
		pgtype.DateOID,
		pgtype.Float4OID,
		pgtype.Float8OID,
		pgtype.Int2OID,
		pgtype.Int4OID,
		pgtype.Int8OID,
		pgtype.OIDOID,
		pgtype.TimestampOID,
		pgtype.TimestamptzOID,
		pgtype.XIDOID,
	}
	databaseSQLResultFormats = make(pgx.QueryResultFormatsByOID, len(binaryOIDs))
	for _, oid := range binaryOIDs {
		databaseSQLResultFormats[oid] = 1
	}
}
111
112
113
// contains reports whether y occurs in list. It scans linearly and makes no
// assumption about ordering.
func contains(list []string, y string) bool {
	for i := range list {
		if list[i] == y {
			return true
		}
	}
	return false
}
122
123 var (
124 fakeTxMutex sync.Mutex
125 fakeTxConns map[*pgx.Conn]*sql.Tx
126 )
127
128
129 type OptionOpenDB func(*connector)
130
131
132
133 func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
134 return func(dc *connector) {
135 dc.BeforeConnect = bc
136 }
137 }
138
139
140 func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
141 return func(dc *connector) {
142 dc.AfterConnect = ac
143 }
144 }
145
146
147
148
149 func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB {
150 return func(dc *connector) {
151 dc.ResetSession = rs
152 }
153 }
154
155
156
157
158
// RandomizeHostOrderFunc is a BeforeConnect hook (see OptionBeforeConnect)
// that shuffles the primary host together with connConfig.Fallbacks using
// math/rand. One entry becomes the new primary and the rest become the
// fallbacks. It does nothing when there are no fallbacks and always returns
// nil.
func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error {
	if len(connConfig.Fallbacks) == 0 {
		return nil
	}

	// Treat the current primary as just another candidate host.
	candidates := make([]*pgconn.FallbackConfig, 0, len(connConfig.Fallbacks)+1)
	candidates = append(candidates, &pgconn.FallbackConfig{
		Host:      connConfig.Host,
		Port:      connConfig.Port,
		TLSConfig: connConfig.TLSConfig,
	})
	candidates = append(candidates, connConfig.Fallbacks...)

	rand.Shuffle(len(candidates), func(a, b int) {
		candidates[a], candidates[b] = candidates[b], candidates[a]
	})

	// After shuffling, the last candidate is promoted to primary.
	last := len(candidates) - 1
	primary := candidates[last]
	connConfig.Host, connConfig.Port, connConfig.TLSConfig = primary.Host, primary.Port, primary.TLSConfig
	connConfig.Fallbacks = candidates[:last]
	return nil
}
182
// GetConnector returns a driver.Connector that opens pgx connections from
// config. BeforeConnect, AfterConnect and ResetSession default to no-ops. The
// opts are applied in order and may override them.
func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector {
	noopConnHook := func(context.Context, *pgx.Conn) error { return nil }

	c := connector{
		ConnConfig:    config,
		BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil },
		AfterConnect:  noopConnHook,
		ResetSession:  noopConnHook,
		driver:        pgxDriver,
	}
	for _, apply := range opts {
		apply(&c)
	}
	return c
}

// OpenDB returns a *sql.DB whose connections are created by the connector
// that GetConnector builds from config and opts.
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
	return sql.OpenDB(GetConnector(config, opts...))
}
202
203 type connector struct {
204 pgx.ConnConfig
205 BeforeConnect func(context.Context, *pgx.ConnConfig) error
206 AfterConnect func(context.Context, *pgx.Conn) error
207 ResetSession func(context.Context, *pgx.Conn) error
208 driver *Driver
209 }
210
211
212 func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
213 var (
214 err error
215 conn *pgx.Conn
216 )
217
218
219 connConfig := c.ConnConfig
220 if err = c.BeforeConnect(ctx, &connConfig); err != nil {
221 return nil, err
222 }
223
224 if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
225 return nil, err
226 }
227
228 if err = c.AfterConnect(ctx, conn); err != nil {
229 return nil, err
230 }
231
232 return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
233 }
234
235
236 func (c connector) Driver() driver.Driver {
237 return c.driver
238 }
239
240
241
242 func GetDefaultDriver() driver.Driver {
243 return pgxDriver
244 }
245
246 type Driver struct {
247 configMutex sync.Mutex
248 configs map[string]*pgx.ConnConfig
249 sequence int
250 }
251
252 func (d *Driver) Open(name string) (driver.Conn, error) {
253 ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
254 defer cancel()
255
256 connector, err := d.OpenConnector(name)
257 if err != nil {
258 return nil, err
259 }
260 return connector.Connect(ctx)
261 }
262
263 func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
264 return &driverConnector{driver: d, name: name}, nil
265 }
266
267 func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
268 d.configMutex.Lock()
269 connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence)
270 d.sequence++
271 d.configs[connStr] = c
272 d.configMutex.Unlock()
273 return connStr
274 }
275
276 func (d *Driver) unregisterConnConfig(connStr string) {
277 d.configMutex.Lock()
278 delete(d.configs, connStr)
279 d.configMutex.Unlock()
280 }
281
282 type driverConnector struct {
283 driver *Driver
284 name string
285 }
286
287 func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
288 var connConfig *pgx.ConnConfig
289
290 dc.driver.configMutex.Lock()
291 connConfig = dc.driver.configs[dc.name]
292 dc.driver.configMutex.Unlock()
293
294 if connConfig == nil {
295 var err error
296 connConfig, err = pgx.ParseConfig(dc.name)
297 if err != nil {
298 return nil, err
299 }
300 }
301
302 conn, err := pgx.ConnectConfig(ctx, connConfig)
303 if err != nil {
304 return nil, err
305 }
306
307 c := &Conn{
308 conn: conn,
309 driver: dc.driver,
310 connConfig: *connConfig,
311 resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
312 }
313
314 return c, nil
315 }
316
317 func (dc *driverConnector) Driver() driver.Driver {
318 return dc.driver
319 }
320
321
322 func RegisterConnConfig(c *pgx.ConnConfig) string {
323 return pgxDriver.registerConnConfig(c)
324 }
325
326
327 func UnregisterConnConfig(connStr string) {
328 pgxDriver.unregisterConnConfig(connStr)
329 }
330
331 type Conn struct {
332 conn *pgx.Conn
333 psCount int64
334 driver *Driver
335 connConfig pgx.ConnConfig
336 resetSessionFunc func(context.Context, *pgx.Conn) error
337 }
338
339
340 func (c *Conn) Conn() *pgx.Conn {
341 return c.conn
342 }
343
344 func (c *Conn) Prepare(query string) (driver.Stmt, error) {
345 return c.PrepareContext(context.Background(), query)
346 }
347
348 func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
349 if c.conn.IsClosed() {
350 return nil, driver.ErrBadConn
351 }
352
353 name := fmt.Sprintf("pgx_%d", c.psCount)
354 c.psCount++
355
356 sd, err := c.conn.Prepare(ctx, name, query)
357 if err != nil {
358 return nil, err
359 }
360
361 return &Stmt{sd: sd, conn: c}, nil
362 }
363
364 func (c *Conn) Close() error {
365 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
366 defer cancel()
367 return c.conn.Close(ctx)
368 }
369
370 func (c *Conn) Begin() (driver.Tx, error) {
371 return c.BeginTx(context.Background(), driver.TxOptions{})
372 }
373
374 func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
375 if c.conn.IsClosed() {
376 return nil, driver.ErrBadConn
377 }
378
379 if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok {
380 *pconn = c.conn
381 return fakeTx{}, nil
382 }
383
384 var pgxOpts pgx.TxOptions
385 switch sql.IsolationLevel(opts.Isolation) {
386 case sql.LevelDefault:
387 case sql.LevelReadUncommitted:
388 pgxOpts.IsoLevel = pgx.ReadUncommitted
389 case sql.LevelReadCommitted:
390 pgxOpts.IsoLevel = pgx.ReadCommitted
391 case sql.LevelRepeatableRead, sql.LevelSnapshot:
392 pgxOpts.IsoLevel = pgx.RepeatableRead
393 case sql.LevelSerializable:
394 pgxOpts.IsoLevel = pgx.Serializable
395 default:
396 return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
397 }
398
399 if opts.ReadOnly {
400 pgxOpts.AccessMode = pgx.ReadOnly
401 }
402
403 tx, err := c.conn.BeginTx(ctx, pgxOpts)
404 if err != nil {
405 return nil, err
406 }
407
408 return wrapTx{ctx: ctx, tx: tx}, nil
409 }
410
// ExecContext implements driver.ExecerContext. It executes query with argsV
// and reports the number of affected rows.
//
// Errors that pgconn.SafeToRetry accepts are mapped to driver.ErrBadConn so
// database/sql can retry on another connection. Any other error is returned
// unchanged. On every error path the Result is nil rather than a partially
// populated value.
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
	if c.conn.IsClosed() {
		return nil, driver.ErrBadConn
	}

	args := namedValueToInterface(argsV)

	commandTag, err := c.conn.Exec(ctx, query, args...)
	if err != nil {
		if pgconn.SafeToRetry(err) {
			return nil, driver.ErrBadConn
		}
		return nil, err
	}
	return driver.RowsAffected(commandTag.RowsAffected()), nil
}
427
// QueryContext implements driver.QueryerContext. Errors accepted by
// pgconn.SafeToRetry are mapped to driver.ErrBadConn. The first row is read
// eagerly so that an error from it is returned from QueryContext itself.
// Rows.Next replays that first result via skipNext/skipNextMore.
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
	if c.conn.IsClosed() {
		return nil, driver.ErrBadConn
	}

	// The leading QueryResultFormatsByOID argument requests binary results
	// for the OIDs in databaseSQLResultFormats; the real arguments follow it.
	args := append([]interface{}{databaseSQLResultFormats}, namedValueToInterface(argsV)...)

	pgxRows, err := c.conn.Query(ctx, query, args...)
	if err != nil {
		if pgconn.SafeToRetry(err) {
			return nil, driver.ErrBadConn
		}
		return nil, err
	}

	hasRow := pgxRows.Next()
	if rowErr := pgxRows.Err(); rowErr != nil {
		pgxRows.Close()
		return nil, rowErr
	}
	return &Rows{conn: c, rows: pgxRows, skipNext: true, skipNextMore: hasRow}, nil
}
452
// Ping implements driver.Pinger. Any ping failure closes the connection and
// is reported as driver.ErrBadConn; the original ping error is not returned.
func (c *Conn) Ping(ctx context.Context) error {
	if c.conn.IsClosed() {
		return driver.ErrBadConn
	}

	if err := c.conn.Ping(ctx); err != nil {
		// Treat a failed ping as fatal for this connection. The close result
		// is ignored because ErrBadConn is reported either way.
		_ = c.Close()
		return driver.ErrBadConn
	}
	return nil
}

// CheckNamedValue implements driver.NamedValueChecker. It accepts every
// argument unchanged and leaves conversion to pgx.
func (c *Conn) CheckNamedValue(*driver.NamedValue) error { return nil }

// ResetSession implements driver.SessionResetter by running the configured
// reset hook. A closed connection reports driver.ErrBadConn instead.
func (c *Conn) ResetSession(ctx context.Context) error {
	if c.conn.IsClosed() {
		return driver.ErrBadConn
	}
	return c.resetSessionFunc(ctx, c.conn)
}
481
482 type Stmt struct {
483 sd *pgconn.StatementDescription
484 conn *Conn
485 }
486
487 func (s *Stmt) Close() error {
488 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
489 defer cancel()
490 return s.conn.conn.Deallocate(ctx, s.sd.Name)
491 }
492
493 func (s *Stmt) NumInput() int {
494 return len(s.sd.ParamOIDs)
495 }
496
497 func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
498 return nil, errors.New("Stmt.Exec deprecated and not implemented")
499 }
500
501 func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
502 return s.conn.ExecContext(ctx, s.sd.Name, argsV)
503 }
504
505 func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
506 return nil, errors.New("Stmt.Query deprecated and not implemented")
507 }
508
509 func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
510 return s.conn.QueryContext(ctx, s.sd.Name, argsV)
511 }
512
513 type rowValueFunc func(src []byte) (driver.Value, error)
514
515 type Rows struct {
516 conn *Conn
517 rows pgx.Rows
518 valueFuncs []rowValueFunc
519 skipNext bool
520 skipNextMore bool
521
522 columnNames []string
523 }
524
// Columns implements driver.Rows. It returns the result column names; the
// slice is built on first use and cached.
func (r *Rows) Columns() []string {
	if r.columnNames != nil {
		return r.columnNames
	}

	fields := r.rows.FieldDescriptions()
	names := make([]string, len(fields))
	for i := range fields {
		names[i] = string(fields[i].Name)
	}
	r.columnNames = names
	return names
}

// ColumnTypeDatabaseTypeName implements driver.RowsColumnTypeDatabaseTypeName.
// It returns the upper-cased type name registered in the connection's
// ConnInfo. When the OID is not registered, it returns the decimal OID
// instead.
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
	oid := r.rows.FieldDescriptions()[index].DataTypeOID
	dt, known := r.conn.conn.ConnInfo().DataTypeForOID(oid)
	if !known {
		return strconv.FormatInt(int64(oid), 10)
	}
	return strings.ToUpper(dt.Name)
}
545
546 const varHeaderSize = 4
547
548
549
550
551 func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
552 fd := r.rows.FieldDescriptions()[index]
553
554 switch fd.DataTypeOID {
555 case pgtype.TextOID, pgtype.ByteaOID:
556 return math.MaxInt64, true
557 case pgtype.VarcharOID, pgtype.BPCharArrayOID:
558 return int64(fd.TypeModifier - varHeaderSize), true
559 default:
560 return 0, false
561 }
562 }
563
564
565
566 func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
567 fd := r.rows.FieldDescriptions()[index]
568
569 switch fd.DataTypeOID {
570 case pgtype.NumericOID:
571 mod := fd.TypeModifier - varHeaderSize
572 precision = int64((mod >> 16) & 0xffff)
573 scale = int64(mod & 0xffff)
574 return precision, scale, true
575 default:
576 return 0, 0, false
577 }
578 }
579
580
// ColumnTypeScanType implements driver.RowsColumnTypeScanType. It reports a
// Go type per column OID and falls back to string for unlisted types.
// NOTE(review): int2/int4/float4 are reported as int16/int32/float32 here,
// but Next widens those values to int64/float64.
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
	switch r.rows.FieldDescriptions()[index].DataTypeOID {
	case pgtype.Float8OID, pgtype.NumericOID:
		return reflect.TypeOf(float64(0))
	case pgtype.Float4OID:
		return reflect.TypeOf(float32(0))
	case pgtype.Int8OID:
		return reflect.TypeOf(int64(0))
	case pgtype.Int4OID:
		return reflect.TypeOf(int32(0))
	case pgtype.Int2OID:
		return reflect.TypeOf(int16(0))
	case pgtype.BoolOID:
		return reflect.TypeOf(false)
	case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
		return reflect.TypeOf(time.Time{})
	case pgtype.ByteaOID:
		return reflect.TypeOf([]byte(nil))
	default:
		return reflect.TypeOf("")
	}
}

// Close implements driver.Rows. It closes the pgx rows and returns any error
// they recorded.
func (r *Rows) Close() error {
	r.rows.Close()
	return r.rows.Err()
}
612
// Next implements driver.Rows. It fills dest with the next row and stores
// NULL columns as nil. It returns io.EOF when the result set is exhausted;
// if iteration stopped because of an error, that error is returned instead.
//
// Per-column converters are built on the first call (see buildValueFuncs).
// Conn.QueryContext has already advanced the underlying rows once, so the
// first call replays that result instead of advancing again. A conversion
// failure is wrapped with %w so callers can inspect the cause with
// errors.Is/As.
func (r *Rows) Next(dest []driver.Value) error {
	if r.valueFuncs == nil {
		r.valueFuncs = r.buildValueFuncs()
	}

	var more bool
	if r.skipNext {
		more = r.skipNextMore
		r.skipNext = false
	} else {
		more = r.rows.Next()
	}

	if !more {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return io.EOF
	}

	for i, rv := range r.rows.RawValues() {
		if rv == nil {
			dest[i] = nil
			continue
		}
		v, err := r.valueFuncs[i](rv)
		if err != nil {
			return fmt.Errorf("convert field %d failed: %w", i, err)
		}
		dest[i] = v
	}
	return nil
}

// buildValueFuncs returns one converter per result column. Each converter
// decodes into a variable it owns, using a scan plan chosen once from the
// column's OID and wire format. Results are normalized to driver.Value
// types:
//   - int2 and int4 widen to int64; float4 widens to float64.
//   - pgtype-backed types return their Value().
//   - Unlisted types decode as string.
func (r *Rows) buildValueFuncs() []rowValueFunc {
	ci := r.conn.conn.ConnInfo()
	fieldDescriptions := r.rows.FieldDescriptions()
	funcs := make([]rowValueFunc, len(fieldDescriptions))

	for i, fd := range fieldDescriptions {
		dataTypeOID := fd.DataTypeOID
		format := fd.Format

		switch dataTypeOID {
		case pgtype.BoolOID:
			var d bool
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return d, err
			}
		case pgtype.ByteaOID:
			var d []byte
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return d, err
			}
		case pgtype.Float4OID:
			var d float32
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return float64(d), err
			}
		case pgtype.Float8OID:
			var d float64
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return d, err
			}
		case pgtype.Int2OID:
			var d int16
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return int64(d), err
			}
		case pgtype.Int4OID:
			var d int32
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return int64(d), err
			}
		case pgtype.Int8OID:
			var d int64
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return d, err
			}
		case pgtype.CIDOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.CID{})
		case pgtype.DateOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.Date{})
		case pgtype.JSONOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.JSON{})
		case pgtype.JSONBOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.JSONB{})
		case pgtype.OIDOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.OIDValue{})
		case pgtype.TimestampOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.Timestamp{})
		case pgtype.TimestamptzOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.Timestamptz{})
		case pgtype.XIDOID:
			funcs[i] = valuerScanFunc(ci, dataTypeOID, format, &pgtype.XID{})
		default:
			var d string
			scanPlan := ci.PlanScan(dataTypeOID, format, &d)
			funcs[i] = func(src []byte) (driver.Value, error) {
				err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
				return d, err
			}
		}
	}
	return funcs
}

// valuerScanFunc returns a converter that scans each raw value into dst and
// returns dst.Value(). dst is a pointer to a pgtype value reused across rows.
func valuerScanFunc(ci *pgtype.ConnInfo, oid uint32, format int16, dst driver.Valuer) rowValueFunc {
	scanPlan := ci.PlanScan(oid, format, dst)
	return func(src []byte) (driver.Value, error) {
		if err := scanPlan.Scan(ci, oid, format, src, dst); err != nil {
			return nil, err
		}
		return dst.Value()
	}
}
795
// valueToInterface converts driver values into the []interface{} form pgx
// takes for query arguments. Each element is copied unchanged, and a nil
// value stays nil. The result is never nil, even for empty input.
func valueToInterface(argsV []driver.Value) []interface{} {
	args := make([]interface{}, len(argsV))
	for i, v := range argsV {
		args[i] = v
	}
	return args
}

// namedValueToInterface extracts the positional Value of each named value,
// in order, into the []interface{} form pgx takes for query arguments. Names
// and ordinals are ignored, and a nil value stays nil. The result is never
// nil, even for empty input.
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
	args := make([]interface{}, len(argsV))
	for i := range argsV {
		args[i] = argsV[i].Value
	}
	return args
}
819
820 type wrapTx struct {
821 ctx context.Context
822 tx pgx.Tx
823 }
824
825 func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) }
826
827 func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) }
828
829 type fakeTx struct{}
830
831 func (fakeTx) Commit() error { return nil }
832
833 func (fakeTx) Rollback() error { return nil }
834
835
836
837
// AcquireConn reserves a connection from db's pool and returns the
// underlying *pgx.Conn. The connection stays reserved until ReleaseConn is
// called with it. It returns ErrNotPgx when db's driver does not hand back a
// *pgx.Conn.
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
	var acquired *pgx.Conn

	// Conn.BeginTx recognizes this key, stores its *pgx.Conn into acquired,
	// and returns a no-op transaction that pins the connection.
	ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &acquired)
	holder, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}

	if acquired == nil {
		// Not this driver: undo whatever transaction it actually started.
		_ = holder.Rollback()
		return nil, ErrNotPgx
	}

	fakeTxMutex.Lock()
	fakeTxConns[acquired] = holder
	fakeTxMutex.Unlock()

	return acquired, nil
}
856
857
858 func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
859 var tx *sql.Tx
860 var ok bool
861
862 if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
863 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
864 defer cancel()
865 conn.Close(ctx)
866 }
867
868 fakeTxMutex.Lock()
869 tx, ok = fakeTxConns[conn]
870 if ok {
871 delete(fakeTxConns, conn)
872 fakeTxMutex.Unlock()
873 } else {
874 fakeTxMutex.Unlock()
875 return fmt.Errorf("can't release conn that is not acquired")
876 }
877
878 return tx.Rollback()
879 }
880
View as plain text