...

Source file src/github.com/jackc/pgx/v4/stdlib/sql.go

Documentation: github.com/jackc/pgx/v4/stdlib

     1  // Package stdlib is the compatibility layer from pgx to database/sql.
     2  //
     3  // A database/sql connection can be established through sql.Open.
     4  //
     5  //	db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
     6  //	if err != nil {
     7  //		return err
     8  //	}
     9  //
    10  // Or from a DSN string.
    11  //
    12  //	db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
    13  //	if err != nil {
    14  //		return err
    15  //	}
    16  //
    17  // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
    18  // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
    19  // with sql.Open.
    20  //
    21  //	connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
    22  //	connConfig.Logger = myLogger
    23  //	connStr := stdlib.RegisterConnConfig(connConfig)
    24  //	db, _ := sql.Open("pgx", connStr)
    25  //
    26  // pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
    27  // It does not support named parameters.
    28  //
    29  //	db.QueryRow("select * from users where id=$1", userID)
    30  //
    31  // In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard
    32  // database/sql.DB connection pool. This allows operations that use pgx specific functionality.
    33  //
    34  //	// Given db is a *sql.DB
    35  //	conn, err := db.Conn(context.Background())
    36  //	if err != nil {
    37  //		// handle error from acquiring connection from DB pool
    38  //	}
    39  //
    40  //	err = conn.Raw(func(driverConn interface{}) error {
    41  //		conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
    42  //		// Do pgx specific stuff with conn
    43  //		conn.CopyFrom(...)
    44  //		return nil
    45  //	})
    46  //	if err != nil {
    47  //		// handle error that occurred while using *pgx.Conn
    48  //	}
    49  package stdlib
    50  
    51  import (
    52  	"context"
    53  	"database/sql"
    54  	"database/sql/driver"
    55  	"errors"
    56  	"fmt"
    57  	"io"
    58  	"math"
    59  	"math/rand"
    60  	"reflect"
    61  	"strconv"
    62  	"strings"
    63  	"sync"
    64  	"time"
    65  
    66  	"github.com/jackc/pgconn"
    67  	"github.com/jackc/pgtype"
    68  	"github.com/jackc/pgx/v4"
    69  )
    70  
    71  // Only intrinsic types should be binary format with database/sql.
    72  var databaseSQLResultFormats pgx.QueryResultFormatsByOID
    73  
    74  var pgxDriver *Driver
    75  
    76  type ctxKey int
    77  
    78  var ctxKeyFakeTx ctxKey = 0
    79  
    80  var ErrNotPgx = errors.New("not pgx *sql.DB")
    81  
    82  func init() {
    83  	pgxDriver = &Driver{
    84  		configs: make(map[string]*pgx.ConnConfig),
    85  	}
    86  	fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
    87  
    88  	// if pgx driver was already registered by different pgx major version then we
    89  	// skip registration under the default name.
    90  	if !contains(sql.Drivers(), "pgx") {
    91  		sql.Register("pgx", pgxDriver)
    92  	}
    93  	sql.Register("pgx/v4", pgxDriver)
    94  
    95  	databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
    96  		pgtype.BoolOID:        1,
    97  		pgtype.ByteaOID:       1,
    98  		pgtype.CIDOID:         1,
    99  		pgtype.DateOID:        1,
   100  		pgtype.Float4OID:      1,
   101  		pgtype.Float8OID:      1,
   102  		pgtype.Int2OID:        1,
   103  		pgtype.Int4OID:        1,
   104  		pgtype.Int8OID:        1,
   105  		pgtype.OIDOID:         1,
   106  		pgtype.TimestampOID:   1,
   107  		pgtype.TimestamptzOID: 1,
   108  		pgtype.XIDOID:         1,
   109  	}
   110  }
   111  
   112  // TODO replace by slices.Contains when experimental package will be merged to stdlib
   113  // https://pkg.go.dev/golang.org/x/exp/slices#Contains
   114  func contains(list []string, y string) bool {
   115  	for _, x := range list {
   116  		if x == y {
   117  			return true
   118  		}
   119  	}
   120  	return false
   121  }
   122  
   123  var (
   124  	fakeTxMutex sync.Mutex
   125  	fakeTxConns map[*pgx.Conn]*sql.Tx
   126  )
   127  
   128  // OptionOpenDB options for configuring the driver when opening a new db pool.
   129  type OptionOpenDB func(*connector)
   130  
   131  // OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
   132  // be used to connect, so only its immediate members should be modified.
   133  func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
   134  	return func(dc *connector) {
   135  		dc.BeforeConnect = bc
   136  	}
   137  }
   138  
   139  // OptionAfterConnect provides a callback for after connect.
   140  func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
   141  	return func(dc *connector) {
   142  		dc.AfterConnect = ac
   143  	}
   144  }
   145  
   146  // OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
   147  // connection if the connection has been used before.
   148  // If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
   149  func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB {
   150  	return func(dc *connector) {
   151  		dc.ResetSession = rs
   152  	}
   153  }
   154  
   155  // RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a
   156  // new host becomes primary each time. This is useful to distribute connections for multi-master databases like
   157  // CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well
   158  // to ensure that connections are periodically rebalanced across your nodes.
   159  func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error {
   160  	if len(connConfig.Fallbacks) == 0 {
   161  		return nil
   162  	}
   163  
   164  	newFallbacks := append([]*pgconn.FallbackConfig{&pgconn.FallbackConfig{
   165  		Host:      connConfig.Host,
   166  		Port:      connConfig.Port,
   167  		TLSConfig: connConfig.TLSConfig,
   168  	}}, connConfig.Fallbacks...)
   169  
   170  	rand.Shuffle(len(newFallbacks), func(i, j int) {
   171  		newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i]
   172  	})
   173  
   174  	// Use the one that sorted last as the primary and keep the rest as the fallbacks
   175  	newPrimary := newFallbacks[len(newFallbacks)-1]
   176  	connConfig.Host = newPrimary.Host
   177  	connConfig.Port = newPrimary.Port
   178  	connConfig.TLSConfig = newPrimary.TLSConfig
   179  	connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1]
   180  	return nil
   181  }
   182  
   183  func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector {
   184  	c := connector{
   185  		ConnConfig:    config,
   186  		BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default
   187  		AfterConnect:  func(context.Context, *pgx.Conn) error { return nil },       // noop after connect by default
   188  		ResetSession:  func(context.Context, *pgx.Conn) error { return nil },       // noop reset session by default
   189  		driver:        pgxDriver,
   190  	}
   191  
   192  	for _, opt := range opts {
   193  		opt(&c)
   194  	}
   195  	return c
   196  }
   197  
   198  func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
   199  	c := GetConnector(config, opts...)
   200  	return sql.OpenDB(c)
   201  }
   202  
   203  type connector struct {
   204  	pgx.ConnConfig
   205  	BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
   206  	AfterConnect  func(context.Context, *pgx.Conn) error       // function to call after creation of every new connection
   207  	ResetSession  func(context.Context, *pgx.Conn) error       // function is called before a connection is reused
   208  	driver        *Driver
   209  }
   210  
   211  // Connect implement driver.Connector interface
   212  func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
   213  	var (
   214  		err  error
   215  		conn *pgx.Conn
   216  	)
   217  
   218  	// Create a shallow copy of the config, so that BeforeConnect can safely modify it
   219  	connConfig := c.ConnConfig
   220  	if err = c.BeforeConnect(ctx, &connConfig); err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
   225  		return nil, err
   226  	}
   227  
   228  	if err = c.AfterConnect(ctx, conn); err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
   233  }
   234  
   235  // Driver implement driver.Connector interface
   236  func (c connector) Driver() driver.Driver {
   237  	return c.driver
   238  }
   239  
   240  // GetDefaultDriver returns the driver initialized in the init function
   241  // and used when the pgx driver is registered.
   242  func GetDefaultDriver() driver.Driver {
   243  	return pgxDriver
   244  }
   245  
   246  type Driver struct {
   247  	configMutex sync.Mutex
   248  	configs     map[string]*pgx.ConnConfig
   249  	sequence    int
   250  }
   251  
   252  func (d *Driver) Open(name string) (driver.Conn, error) {
   253  	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
   254  	defer cancel()
   255  
   256  	connector, err := d.OpenConnector(name)
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  	return connector.Connect(ctx)
   261  }
   262  
   263  func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
   264  	return &driverConnector{driver: d, name: name}, nil
   265  }
   266  
   267  func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
   268  	d.configMutex.Lock()
   269  	connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence)
   270  	d.sequence++
   271  	d.configs[connStr] = c
   272  	d.configMutex.Unlock()
   273  	return connStr
   274  }
   275  
   276  func (d *Driver) unregisterConnConfig(connStr string) {
   277  	d.configMutex.Lock()
   278  	delete(d.configs, connStr)
   279  	d.configMutex.Unlock()
   280  }
   281  
   282  type driverConnector struct {
   283  	driver *Driver
   284  	name   string
   285  }
   286  
   287  func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
   288  	var connConfig *pgx.ConnConfig
   289  
   290  	dc.driver.configMutex.Lock()
   291  	connConfig = dc.driver.configs[dc.name]
   292  	dc.driver.configMutex.Unlock()
   293  
   294  	if connConfig == nil {
   295  		var err error
   296  		connConfig, err = pgx.ParseConfig(dc.name)
   297  		if err != nil {
   298  			return nil, err
   299  		}
   300  	}
   301  
   302  	conn, err := pgx.ConnectConfig(ctx, connConfig)
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  
   307  	c := &Conn{
   308  		conn:             conn,
   309  		driver:           dc.driver,
   310  		connConfig:       *connConfig,
   311  		resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
   312  	}
   313  
   314  	return c, nil
   315  }
   316  
   317  func (dc *driverConnector) Driver() driver.Driver {
   318  	return dc.driver
   319  }
   320  
   321  // RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open.
   322  func RegisterConnConfig(c *pgx.ConnConfig) string {
   323  	return pgxDriver.registerConnConfig(c)
   324  }
   325  
   326  // UnregisterConnConfig removes the ConnConfig registration for connStr.
   327  func UnregisterConnConfig(connStr string) {
   328  	pgxDriver.unregisterConnConfig(connStr)
   329  }
   330  
   331  type Conn struct {
   332  	conn             *pgx.Conn
   333  	psCount          int64 // Counter used for creating unique prepared statement names
   334  	driver           *Driver
   335  	connConfig       pgx.ConnConfig
   336  	resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused
   337  }
   338  
   339  // Conn returns the underlying *pgx.Conn
   340  func (c *Conn) Conn() *pgx.Conn {
   341  	return c.conn
   342  }
   343  
   344  func (c *Conn) Prepare(query string) (driver.Stmt, error) {
   345  	return c.PrepareContext(context.Background(), query)
   346  }
   347  
   348  func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   349  	if c.conn.IsClosed() {
   350  		return nil, driver.ErrBadConn
   351  	}
   352  
   353  	name := fmt.Sprintf("pgx_%d", c.psCount)
   354  	c.psCount++
   355  
   356  	sd, err := c.conn.Prepare(ctx, name, query)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  
   361  	return &Stmt{sd: sd, conn: c}, nil
   362  }
   363  
   364  func (c *Conn) Close() error {
   365  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   366  	defer cancel()
   367  	return c.conn.Close(ctx)
   368  }
   369  
   370  func (c *Conn) Begin() (driver.Tx, error) {
   371  	return c.BeginTx(context.Background(), driver.TxOptions{})
   372  }
   373  
   374  func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   375  	if c.conn.IsClosed() {
   376  		return nil, driver.ErrBadConn
   377  	}
   378  
   379  	if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok {
   380  		*pconn = c.conn
   381  		return fakeTx{}, nil
   382  	}
   383  
   384  	var pgxOpts pgx.TxOptions
   385  	switch sql.IsolationLevel(opts.Isolation) {
   386  	case sql.LevelDefault:
   387  	case sql.LevelReadUncommitted:
   388  		pgxOpts.IsoLevel = pgx.ReadUncommitted
   389  	case sql.LevelReadCommitted:
   390  		pgxOpts.IsoLevel = pgx.ReadCommitted
   391  	case sql.LevelRepeatableRead, sql.LevelSnapshot:
   392  		pgxOpts.IsoLevel = pgx.RepeatableRead
   393  	case sql.LevelSerializable:
   394  		pgxOpts.IsoLevel = pgx.Serializable
   395  	default:
   396  		return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
   397  	}
   398  
   399  	if opts.ReadOnly {
   400  		pgxOpts.AccessMode = pgx.ReadOnly
   401  	}
   402  
   403  	tx, err := c.conn.BeginTx(ctx, pgxOpts)
   404  	if err != nil {
   405  		return nil, err
   406  	}
   407  
   408  	return wrapTx{ctx: ctx, tx: tx}, nil
   409  }
   410  
   411  func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
   412  	if c.conn.IsClosed() {
   413  		return nil, driver.ErrBadConn
   414  	}
   415  
   416  	args := namedValueToInterface(argsV)
   417  
   418  	commandTag, err := c.conn.Exec(ctx, query, args...)
   419  	// if we got a network error before we had a chance to send the query, retry
   420  	if err != nil {
   421  		if pgconn.SafeToRetry(err) {
   422  			return nil, driver.ErrBadConn
   423  		}
   424  	}
   425  	return driver.RowsAffected(commandTag.RowsAffected()), err
   426  }
   427  
   428  func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
   429  	if c.conn.IsClosed() {
   430  		return nil, driver.ErrBadConn
   431  	}
   432  
   433  	args := []interface{}{databaseSQLResultFormats}
   434  	args = append(args, namedValueToInterface(argsV)...)
   435  
   436  	rows, err := c.conn.Query(ctx, query, args...)
   437  	if err != nil {
   438  		if pgconn.SafeToRetry(err) {
   439  			return nil, driver.ErrBadConn
   440  		}
   441  		return nil, err
   442  	}
   443  
   444  	// Preload first row because otherwise we won't know what columns are available when database/sql asks.
   445  	more := rows.Next()
   446  	if err = rows.Err(); err != nil {
   447  		rows.Close()
   448  		return nil, err
   449  	}
   450  	return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil
   451  }
   452  
   453  func (c *Conn) Ping(ctx context.Context) error {
   454  	if c.conn.IsClosed() {
   455  		return driver.ErrBadConn
   456  	}
   457  
   458  	err := c.conn.Ping(ctx)
   459  	if err != nil {
   460  		// A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the
   461  		// failure, but manually close it just to be sure.
   462  		c.Close()
   463  		return driver.ErrBadConn
   464  	}
   465  
   466  	return nil
   467  }
   468  
   469  func (c *Conn) CheckNamedValue(*driver.NamedValue) error {
   470  	// Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly.
   471  	return nil
   472  }
   473  
   474  func (c *Conn) ResetSession(ctx context.Context) error {
   475  	if c.conn.IsClosed() {
   476  		return driver.ErrBadConn
   477  	}
   478  
   479  	return c.resetSessionFunc(ctx, c.conn)
   480  }
   481  
   482  type Stmt struct {
   483  	sd   *pgconn.StatementDescription
   484  	conn *Conn
   485  }
   486  
   487  func (s *Stmt) Close() error {
   488  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   489  	defer cancel()
   490  	return s.conn.conn.Deallocate(ctx, s.sd.Name)
   491  }
   492  
   493  func (s *Stmt) NumInput() int {
   494  	return len(s.sd.ParamOIDs)
   495  }
   496  
   497  func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
   498  	return nil, errors.New("Stmt.Exec deprecated and not implemented")
   499  }
   500  
   501  func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
   502  	return s.conn.ExecContext(ctx, s.sd.Name, argsV)
   503  }
   504  
   505  func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
   506  	return nil, errors.New("Stmt.Query deprecated and not implemented")
   507  }
   508  
   509  func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
   510  	return s.conn.QueryContext(ctx, s.sd.Name, argsV)
   511  }
   512  
   513  type rowValueFunc func(src []byte) (driver.Value, error)
   514  
   515  type Rows struct {
   516  	conn         *Conn
   517  	rows         pgx.Rows
   518  	valueFuncs   []rowValueFunc
   519  	skipNext     bool
   520  	skipNextMore bool
   521  
   522  	columnNames []string
   523  }
   524  
   525  func (r *Rows) Columns() []string {
   526  	if r.columnNames == nil {
   527  		fields := r.rows.FieldDescriptions()
   528  		r.columnNames = make([]string, len(fields))
   529  		for i, fd := range fields {
   530  			r.columnNames[i] = string(fd.Name)
   531  		}
   532  	}
   533  
   534  	return r.columnNames
   535  }
   536  
   537  // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned.
   538  func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
   539  	if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok {
   540  		return strings.ToUpper(dt.Name)
   541  	}
   542  
   543  	return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10)
   544  }
   545  
   546  const varHeaderSize = 4
   547  
   548  // ColumnTypeLength returns the length of the column type if the column is a
   549  // variable length type. If the column is not a variable length type ok
   550  // should return false.
   551  func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
   552  	fd := r.rows.FieldDescriptions()[index]
   553  
   554  	switch fd.DataTypeOID {
   555  	case pgtype.TextOID, pgtype.ByteaOID:
   556  		return math.MaxInt64, true
   557  	case pgtype.VarcharOID, pgtype.BPCharArrayOID:
   558  		return int64(fd.TypeModifier - varHeaderSize), true
   559  	default:
   560  		return 0, false
   561  	}
   562  }
   563  
   564  // ColumnTypePrecisionScale should return the precision and scale for decimal
   565  // types. If not applicable, ok should be false.
   566  func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
   567  	fd := r.rows.FieldDescriptions()[index]
   568  
   569  	switch fd.DataTypeOID {
   570  	case pgtype.NumericOID:
   571  		mod := fd.TypeModifier - varHeaderSize
   572  		precision = int64((mod >> 16) & 0xffff)
   573  		scale = int64(mod & 0xffff)
   574  		return precision, scale, true
   575  	default:
   576  		return 0, 0, false
   577  	}
   578  }
   579  
   580  // ColumnTypeScanType returns the value type that can be used to scan types into.
   581  func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
   582  	fd := r.rows.FieldDescriptions()[index]
   583  
   584  	switch fd.DataTypeOID {
   585  	case pgtype.Float8OID:
   586  		return reflect.TypeOf(float64(0))
   587  	case pgtype.Float4OID:
   588  		return reflect.TypeOf(float32(0))
   589  	case pgtype.Int8OID:
   590  		return reflect.TypeOf(int64(0))
   591  	case pgtype.Int4OID:
   592  		return reflect.TypeOf(int32(0))
   593  	case pgtype.Int2OID:
   594  		return reflect.TypeOf(int16(0))
   595  	case pgtype.BoolOID:
   596  		return reflect.TypeOf(false)
   597  	case pgtype.NumericOID:
   598  		return reflect.TypeOf(float64(0))
   599  	case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
   600  		return reflect.TypeOf(time.Time{})
   601  	case pgtype.ByteaOID:
   602  		return reflect.TypeOf([]byte(nil))
   603  	default:
   604  		return reflect.TypeOf("")
   605  	}
   606  }
   607  
   608  func (r *Rows) Close() error {
   609  	r.rows.Close()
   610  	return r.rows.Err()
   611  }
   612  
   613  func (r *Rows) Next(dest []driver.Value) error {
   614  	ci := r.conn.conn.ConnInfo()
   615  	fieldDescriptions := r.rows.FieldDescriptions()
   616  
   617  	if r.valueFuncs == nil {
   618  		r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))
   619  
   620  		for i, fd := range fieldDescriptions {
   621  			dataTypeOID := fd.DataTypeOID
   622  			format := fd.Format
   623  
   624  			switch fd.DataTypeOID {
   625  			case pgtype.BoolOID:
   626  				var d bool
   627  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   628  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   629  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   630  					return d, err
   631  				}
   632  			case pgtype.ByteaOID:
   633  				var d []byte
   634  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   635  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   636  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   637  					return d, err
   638  				}
   639  			case pgtype.CIDOID:
   640  				var d pgtype.CID
   641  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   642  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   643  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   644  					if err != nil {
   645  						return nil, err
   646  					}
   647  					return d.Value()
   648  				}
   649  			case pgtype.DateOID:
   650  				var d pgtype.Date
   651  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   652  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   653  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   654  					if err != nil {
   655  						return nil, err
   656  					}
   657  					return d.Value()
   658  				}
   659  			case pgtype.Float4OID:
   660  				var d float32
   661  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   662  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   663  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   664  					return float64(d), err
   665  				}
   666  			case pgtype.Float8OID:
   667  				var d float64
   668  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   669  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   670  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   671  					return d, err
   672  				}
   673  			case pgtype.Int2OID:
   674  				var d int16
   675  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   676  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   677  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   678  					return int64(d), err
   679  				}
   680  			case pgtype.Int4OID:
   681  				var d int32
   682  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   683  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   684  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   685  					return int64(d), err
   686  				}
   687  			case pgtype.Int8OID:
   688  				var d int64
   689  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   690  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   691  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   692  					return d, err
   693  				}
   694  			case pgtype.JSONOID:
   695  				var d pgtype.JSON
   696  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   697  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   698  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   699  					if err != nil {
   700  						return nil, err
   701  					}
   702  					return d.Value()
   703  				}
   704  			case pgtype.JSONBOID:
   705  				var d pgtype.JSONB
   706  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   707  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   708  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   709  					if err != nil {
   710  						return nil, err
   711  					}
   712  					return d.Value()
   713  				}
   714  			case pgtype.OIDOID:
   715  				var d pgtype.OIDValue
   716  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   717  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   718  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   719  					if err != nil {
   720  						return nil, err
   721  					}
   722  					return d.Value()
   723  				}
   724  			case pgtype.TimestampOID:
   725  				var d pgtype.Timestamp
   726  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   727  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   728  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   729  					if err != nil {
   730  						return nil, err
   731  					}
   732  					return d.Value()
   733  				}
   734  			case pgtype.TimestamptzOID:
   735  				var d pgtype.Timestamptz
   736  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   737  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   738  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   739  					if err != nil {
   740  						return nil, err
   741  					}
   742  					return d.Value()
   743  				}
   744  			case pgtype.XIDOID:
   745  				var d pgtype.XID
   746  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   747  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   748  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   749  					if err != nil {
   750  						return nil, err
   751  					}
   752  					return d.Value()
   753  				}
   754  			default:
   755  				var d string
   756  				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
   757  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   758  					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
   759  					return d, err
   760  				}
   761  			}
   762  		}
   763  	}
   764  
   765  	var more bool
   766  	if r.skipNext {
   767  		more = r.skipNextMore
   768  		r.skipNext = false
   769  	} else {
   770  		more = r.rows.Next()
   771  	}
   772  
   773  	if !more {
   774  		if r.rows.Err() == nil {
   775  			return io.EOF
   776  		} else {
   777  			return r.rows.Err()
   778  		}
   779  	}
   780  
   781  	for i, rv := range r.rows.RawValues() {
   782  		if rv != nil {
   783  			var err error
   784  			dest[i], err = r.valueFuncs[i](rv)
   785  			if err != nil {
   786  				return fmt.Errorf("convert field %d failed: %v", i, err)
   787  			}
   788  		} else {
   789  			dest[i] = nil
   790  		}
   791  	}
   792  
   793  	return nil
   794  }
   795  
   796  func valueToInterface(argsV []driver.Value) []interface{} {
   797  	args := make([]interface{}, 0, len(argsV))
   798  	for _, v := range argsV {
   799  		if v != nil {
   800  			args = append(args, v.(interface{}))
   801  		} else {
   802  			args = append(args, nil)
   803  		}
   804  	}
   805  	return args
   806  }
   807  
   808  func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
   809  	args := make([]interface{}, 0, len(argsV))
   810  	for _, v := range argsV {
   811  		if v.Value != nil {
   812  			args = append(args, v.Value.(interface{}))
   813  		} else {
   814  			args = append(args, nil)
   815  		}
   816  	}
   817  	return args
   818  }
   819  
   820  type wrapTx struct {
   821  	ctx context.Context
   822  	tx  pgx.Tx
   823  }
   824  
   825  func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) }
   826  
   827  func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) }
   828  
   829  type fakeTx struct{}
   830  
   831  func (fakeTx) Commit() error { return nil }
   832  
   833  func (fakeTx) Rollback() error { return nil }
   834  
   835  // AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn.
   836  //
   837  // In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method.
   838  func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
   839  	var conn *pgx.Conn
   840  	ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
   841  	tx, err := db.BeginTx(ctx, nil)
   842  	if err != nil {
   843  		return nil, err
   844  	}
   845  	if conn == nil {
   846  		tx.Rollback()
   847  		return nil, ErrNotPgx
   848  	}
   849  
   850  	fakeTxMutex.Lock()
   851  	fakeTxConns[conn] = tx
   852  	fakeTxMutex.Unlock()
   853  
   854  	return conn, nil
   855  }
   856  
   857  // ReleaseConn releases a *pgx.Conn acquired with AcquireConn.
   858  func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
   859  	var tx *sql.Tx
   860  	var ok bool
   861  
   862  	if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
   863  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   864  		defer cancel()
   865  		conn.Close(ctx)
   866  	}
   867  
   868  	fakeTxMutex.Lock()
   869  	tx, ok = fakeTxConns[conn]
   870  	if ok {
   871  		delete(fakeTxConns, conn)
   872  		fakeTxMutex.Unlock()
   873  	} else {
   874  		fakeTxMutex.Unlock()
   875  		return fmt.Errorf("can't release conn that is not acquired")
   876  	}
   877  
   878  	return tx.Rollback()
   879  }
   880  

View as plain text