...

Source file src/github.com/jackc/pgx/v5/stdlib/sql.go

Documentation: github.com/jackc/pgx/v5/stdlib

     1  // Package stdlib is the compatibility layer from pgx to database/sql.
     2  //
     3  // A database/sql connection can be established through sql.Open.
     4  //
     5  //	db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
     6  //	if err != nil {
     7  //	  return err
     8  //	}
     9  //
    10  // Or from a DSN string.
    11  //
    12  //	db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
    13  //	if err != nil {
    14  //	  return err
    15  //	}
    16  //
    17  // Or from a *pgxpool.Pool.
    18  //
    19  //	pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
    20  //	if err != nil {
    21  //	  return err
    22  //	}
    23  //
    24  //	db := stdlib.OpenDBFromPool(pool)
    25  //
    26  // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
    27  // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
    28  // with sql.Open.
    29  //
    30  //	connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
    31  //	connConfig.Tracer = &tracelog.TraceLog{Logger: myLogger, LogLevel: tracelog.LogLevelInfo}
    32  //	connStr := stdlib.RegisterConnConfig(connConfig)
    33  //	db, _ := sql.Open("pgx", connStr)
    34  //
    35  // pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters.
    36  //
    37  //	db.QueryRow("select * from users where id=$1", userID)
    38  //
    39  // (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows
    40  // operations that use pgx specific functionality.
    41  //
    42  //	// Given db is a *sql.DB
    43  //	conn, err := db.Conn(context.Background())
    44  //	if err != nil {
    45  //	  // handle error from acquiring connection from DB pool
    46  //	}
    47  //
    48  //	err = conn.Raw(func(driverConn any) error {
    49  //	  conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
    50  //	  // Do pgx specific stuff with conn
    51  //	  conn.CopyFrom(...)
    52  //	  return nil
    53  //	})
    54  //	if err != nil {
    55  //	  // handle error that occurred while using *pgx.Conn
    56  //	}
    57  //
    58  // # PostgreSQL Specific Data Types
    59  //
    60  // The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes
    61  // these types usable as a sql.Scanner.
    62  //
    63  //	m := pgtype.NewMap()
    64  //	var a []int64
    65  //	err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
    66  package stdlib
    67  
    68  import (
    69  	"context"
    70  	"database/sql"
    71  	"database/sql/driver"
    72  	"errors"
    73  	"fmt"
    74  	"io"
    75  	"math"
    76  	"math/rand"
    77  	"reflect"
    78  	"strconv"
    79  	"strings"
    80  	"sync"
    81  	"time"
    82  
    83  	"github.com/jackc/pgx/v5"
    84  	"github.com/jackc/pgx/v5/pgconn"
    85  	"github.com/jackc/pgx/v5/pgtype"
    86  	"github.com/jackc/pgx/v5/pgxpool"
    87  )
    88  
    89  // Only intrinsic types should be binary format with database/sql.
    90  var databaseSQLResultFormats pgx.QueryResultFormatsByOID
    91  
    92  var pgxDriver *Driver
    93  
    94  func init() {
    95  	pgxDriver = &Driver{
    96  		configs: make(map[string]*pgx.ConnConfig),
    97  	}
    98  
    99  	// if pgx driver was already registered by different pgx major version then we
   100  	// skip registration under the default name.
   101  	if !contains(sql.Drivers(), "pgx") {
   102  		sql.Register("pgx", pgxDriver)
   103  	}
   104  	sql.Register("pgx/v5", pgxDriver)
   105  
   106  	databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
   107  		pgtype.BoolOID:        1,
   108  		pgtype.ByteaOID:       1,
   109  		pgtype.CIDOID:         1,
   110  		pgtype.DateOID:        1,
   111  		pgtype.Float4OID:      1,
   112  		pgtype.Float8OID:      1,
   113  		pgtype.Int2OID:        1,
   114  		pgtype.Int4OID:        1,
   115  		pgtype.Int8OID:        1,
   116  		pgtype.OIDOID:         1,
   117  		pgtype.TimestampOID:   1,
   118  		pgtype.TimestamptzOID: 1,
   119  		pgtype.XIDOID:         1,
   120  	}
   121  }
   122  
   123  // TODO replace by slices.Contains when experimental package will be merged to stdlib
   124  // https://pkg.go.dev/golang.org/x/exp/slices#Contains
   125  func contains(list []string, y string) bool {
   126  	for _, x := range list {
   127  		if x == y {
   128  			return true
   129  		}
   130  	}
   131  	return false
   132  }
   133  
// OptionOpenDB is a functional option for configuring the driver when opening a new db pool. GetConnector and
// GetPoolConnector (and therefore OpenDB and OpenDBFromPool) apply the given options in order to the connector they
// build, after its defaults have been set.
type OptionOpenDB func(*connector)
   136  
   137  // OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
   138  // be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig.
   139  func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
   140  	return func(dc *connector) {
   141  		dc.BeforeConnect = bc
   142  	}
   143  }
   144  
   145  // OptionAfterConnect provides a callback for after connect. Used only if db is opened with *pgx.ConnConfig.
   146  func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
   147  	return func(dc *connector) {
   148  		dc.AfterConnect = ac
   149  	}
   150  }
   151  
   152  // OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
   153  // connection if the connection has been used before.
   154  // If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
   155  func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB {
   156  	return func(dc *connector) {
   157  		dc.ResetSession = rs
   158  	}
   159  }
   160  
   161  // RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a
   162  // new host becomes primary each time. This is useful to distribute connections for multi-master databases like
   163  // CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well
   164  // to ensure that connections are periodically rebalanced across your nodes.
   165  func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error {
   166  	if len(connConfig.Fallbacks) == 0 {
   167  		return nil
   168  	}
   169  
   170  	newFallbacks := append([]*pgconn.FallbackConfig{{
   171  		Host:      connConfig.Host,
   172  		Port:      connConfig.Port,
   173  		TLSConfig: connConfig.TLSConfig,
   174  	}}, connConfig.Fallbacks...)
   175  
   176  	rand.Shuffle(len(newFallbacks), func(i, j int) {
   177  		newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i]
   178  	})
   179  
   180  	// Use the one that sorted last as the primary and keep the rest as the fallbacks
   181  	newPrimary := newFallbacks[len(newFallbacks)-1]
   182  	connConfig.Host = newPrimary.Host
   183  	connConfig.Port = newPrimary.Port
   184  	connConfig.TLSConfig = newPrimary.TLSConfig
   185  	connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1]
   186  	return nil
   187  }
   188  
   189  func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector {
   190  	c := connector{
   191  		ConnConfig:    config,
   192  		BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default
   193  		AfterConnect:  func(context.Context, *pgx.Conn) error { return nil },       // noop after connect by default
   194  		ResetSession:  func(context.Context, *pgx.Conn) error { return nil },       // noop reset session by default
   195  		driver:        pgxDriver,
   196  	}
   197  
   198  	for _, opt := range opts {
   199  		opt(&c)
   200  	}
   201  	return c
   202  }
   203  
   204  // GetPoolConnector creates a new driver.Connector from the given *pgxpool.Pool. By using this be sure to set the
   205  // maximum idle connections of the *sql.DB created with this connector to zero since they must be managed from the
   206  // *pgxpool.Pool. This is required to avoid acquiring all the connections from the pgxpool and starving any direct
   207  // users of the pgxpool.
   208  func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector {
   209  	c := connector{
   210  		pool:         pool,
   211  		ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default
   212  		driver:       pgxDriver,
   213  	}
   214  
   215  	for _, opt := range opts {
   216  		opt(&c)
   217  	}
   218  
   219  	return c
   220  }
   221  
   222  func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
   223  	c := GetConnector(config, opts...)
   224  	return sql.OpenDB(c)
   225  }
   226  
   227  // OpenDBFromPool creates a new *sql.DB from the given *pgxpool.Pool. Note that this method automatically sets the
   228  // maximum number of idle connections in *sql.DB to zero, since they must be managed from the *pgxpool.Pool. This is
   229  // required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool.
   230  func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB {
   231  	c := GetPoolConnector(pool, opts...)
   232  	db := sql.OpenDB(c)
   233  	db.SetMaxIdleConns(0)
   234  	return db
   235  }
   236  
// connector implements driver.Connector (see its Connect and Driver methods). If pool is nil, Connect dials a new
// connection from the embedded ConnConfig and runs BeforeConnect/AfterConnect around the dial. Otherwise, connections
// are acquired from pool and those two hooks are not called. ResetSession is handed to every *Conn it creates.
type connector struct {
	pgx.ConnConfig
	pool          *pgxpool.Pool
	BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
	AfterConnect  func(context.Context, *pgx.Conn) error       // function to call after creation of every new connection
	ResetSession  func(context.Context, *pgx.Conn) error       // function is called before a connection is reused
	driver        *Driver
}
   245  
   246  // Connect implement driver.Connector interface
   247  func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
   248  	var (
   249  		connConfig pgx.ConnConfig
   250  		conn       *pgx.Conn
   251  		close      func(context.Context) error
   252  		err        error
   253  	)
   254  
   255  	if c.pool == nil {
   256  		// Create a shallow copy of the config, so that BeforeConnect can safely modify it
   257  		connConfig = c.ConnConfig
   258  
   259  		if err = c.BeforeConnect(ctx, &connConfig); err != nil {
   260  			return nil, err
   261  		}
   262  
   263  		if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
   264  			return nil, err
   265  		}
   266  
   267  		if err = c.AfterConnect(ctx, conn); err != nil {
   268  			return nil, err
   269  		}
   270  
   271  		close = conn.Close
   272  	} else {
   273  		var pconn *pgxpool.Conn
   274  
   275  		pconn, err = c.pool.Acquire(ctx)
   276  		if err != nil {
   277  			return nil, err
   278  		}
   279  
   280  		conn = pconn.Conn()
   281  
   282  		close = func(_ context.Context) error {
   283  			pconn.Release()
   284  			return nil
   285  		}
   286  	}
   287  
   288  	return &Conn{
   289  		conn:             conn,
   290  		close:            close,
   291  		driver:           c.driver,
   292  		connConfig:       connConfig,
   293  		resetSessionFunc: c.ResetSession,
   294  		psRefCounts:      make(map[*pgconn.StatementDescription]int),
   295  	}, nil
   296  }
   297  
   298  // Driver implement driver.Connector interface
   299  func (c connector) Driver() driver.Driver {
   300  	return c.driver
   301  }
   302  
   303  // GetDefaultDriver returns the driver initialized in the init function
   304  // and used when the pgx driver is registered.
   305  func GetDefaultDriver() driver.Driver {
   306  	return pgxDriver
   307  }
   308  
// Driver is the database/sql driver implementation registered by init. Besides opening connections, it keeps the
// ConnConfigs registered via RegisterConnConfig, keyed by the connection strings it generated for them.
type Driver struct {
	// configMutex guards configs and sequence.
	configMutex sync.Mutex
	// configs maps generated connection strings to their registered configs.
	configs     map[string]*pgx.ConnConfig
	// sequence is the counter that makes each generated connection string unique.
	sequence    int
}
   314  
   315  func (d *Driver) Open(name string) (driver.Conn, error) {
   316  	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
   317  	defer cancel()
   318  
   319  	connector, err := d.OpenConnector(name)
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  	return connector.Connect(ctx)
   324  }
   325  
   326  func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
   327  	return &driverConnector{driver: d, name: name}, nil
   328  }
   329  
   330  func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
   331  	d.configMutex.Lock()
   332  	connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence)
   333  	d.sequence++
   334  	d.configs[connStr] = c
   335  	d.configMutex.Unlock()
   336  	return connStr
   337  }
   338  
   339  func (d *Driver) unregisterConnConfig(connStr string) {
   340  	d.configMutex.Lock()
   341  	delete(d.configs, connStr)
   342  	d.configMutex.Unlock()
   343  }
   344  
// driverConnector is the driver.Connector returned by Driver.OpenConnector. name is either a key registered with
// RegisterConnConfig or a connection string that Connect parses with pgx.ParseConfig.
type driverConnector struct {
	driver *Driver
	name   string
}
   349  
   350  func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
   351  	var connConfig *pgx.ConnConfig
   352  
   353  	dc.driver.configMutex.Lock()
   354  	connConfig = dc.driver.configs[dc.name]
   355  	dc.driver.configMutex.Unlock()
   356  
   357  	if connConfig == nil {
   358  		var err error
   359  		connConfig, err = pgx.ParseConfig(dc.name)
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  	}
   364  
   365  	conn, err := pgx.ConnectConfig(ctx, connConfig)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  
   370  	c := &Conn{
   371  		conn:             conn,
   372  		close:            conn.Close,
   373  		driver:           dc.driver,
   374  		connConfig:       *connConfig,
   375  		resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
   376  		psRefCounts:      make(map[*pgconn.StatementDescription]int),
   377  	}
   378  
   379  	return c, nil
   380  }
   381  
   382  func (dc *driverConnector) Driver() driver.Driver {
   383  	return dc.driver
   384  }
   385  
   386  // RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open.
   387  func RegisterConnConfig(c *pgx.ConnConfig) string {
   388  	return pgxDriver.registerConnConfig(c)
   389  }
   390  
   391  // UnregisterConnConfig removes the ConnConfig registration for connStr.
   392  func UnregisterConnConfig(connStr string) {
   393  	pgxDriver.unregisterConnConfig(connStr)
   394  }
   395  
// Conn is the driver.Conn implementation produced by this package's connectors. It wraps a *pgx.Conn, which can be
// reached through the Conn method for pgx-specific functionality.
type Conn struct {
	// conn is the underlying pgx connection.
	conn                 *pgx.Conn
	// close closes conn, or releases it back to the pgxpool for pool-backed connections.
	close                func(context.Context) error
	driver               *Driver
	// connConfig is a copy of the config used to dial; it is left as the zero value for pool-backed connections.
	connConfig           pgx.ConnConfig
	resetSessionFunc     func(context.Context, *pgx.Conn) error // Function is called before a connection is reused
	// lastResetSessionTime is when ResetSession last passed its liveness check; used to ping at most about once per second.
	lastResetSessionTime time.Time

	// psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate
	// deterministic statement names from the statement text. If this query has already been prepared then the existing
	// *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt
	// then the underlying prepared statement will be closed even when the underlying prepared statement is still in use
	// by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using
	// the same underlying statement and only closes the underlying statement when the reference count reaches 0.
	psRefCounts map[*pgconn.StatementDescription]int
}
   412  
   413  // Conn returns the underlying *pgx.Conn
   414  func (c *Conn) Conn() *pgx.Conn {
   415  	return c.conn
   416  }
   417  
   418  func (c *Conn) Prepare(query string) (driver.Stmt, error) {
   419  	return c.PrepareContext(context.Background(), query)
   420  }
   421  
   422  func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   423  	if c.conn.IsClosed() {
   424  		return nil, driver.ErrBadConn
   425  	}
   426  
   427  	sd, err := c.conn.Prepare(ctx, query, query)
   428  	if err != nil {
   429  		return nil, err
   430  	}
   431  	c.psRefCounts[sd]++
   432  
   433  	return &Stmt{sd: sd, conn: c}, nil
   434  }
   435  
   436  func (c *Conn) Close() error {
   437  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   438  	defer cancel()
   439  	return c.close(ctx)
   440  }
   441  
   442  func (c *Conn) Begin() (driver.Tx, error) {
   443  	return c.BeginTx(context.Background(), driver.TxOptions{})
   444  }
   445  
   446  func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   447  	if c.conn.IsClosed() {
   448  		return nil, driver.ErrBadConn
   449  	}
   450  
   451  	var pgxOpts pgx.TxOptions
   452  	switch sql.IsolationLevel(opts.Isolation) {
   453  	case sql.LevelDefault:
   454  	case sql.LevelReadUncommitted:
   455  		pgxOpts.IsoLevel = pgx.ReadUncommitted
   456  	case sql.LevelReadCommitted:
   457  		pgxOpts.IsoLevel = pgx.ReadCommitted
   458  	case sql.LevelRepeatableRead, sql.LevelSnapshot:
   459  		pgxOpts.IsoLevel = pgx.RepeatableRead
   460  	case sql.LevelSerializable:
   461  		pgxOpts.IsoLevel = pgx.Serializable
   462  	default:
   463  		return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
   464  	}
   465  
   466  	if opts.ReadOnly {
   467  		pgxOpts.AccessMode = pgx.ReadOnly
   468  	}
   469  
   470  	tx, err := c.conn.BeginTx(ctx, pgxOpts)
   471  	if err != nil {
   472  		return nil, err
   473  	}
   474  
   475  	return wrapTx{ctx: ctx, tx: tx}, nil
   476  }
   477  
   478  func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
   479  	if c.conn.IsClosed() {
   480  		return nil, driver.ErrBadConn
   481  	}
   482  
   483  	args := namedValueToInterface(argsV)
   484  
   485  	commandTag, err := c.conn.Exec(ctx, query, args...)
   486  	// if we got a network error before we had a chance to send the query, retry
   487  	if err != nil {
   488  		if pgconn.SafeToRetry(err) {
   489  			return nil, driver.ErrBadConn
   490  		}
   491  	}
   492  	return driver.RowsAffected(commandTag.RowsAffected()), err
   493  }
   494  
   495  func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
   496  	if c.conn.IsClosed() {
   497  		return nil, driver.ErrBadConn
   498  	}
   499  
   500  	args := []any{databaseSQLResultFormats}
   501  	args = append(args, namedValueToInterface(argsV)...)
   502  
   503  	rows, err := c.conn.Query(ctx, query, args...)
   504  	if err != nil {
   505  		if pgconn.SafeToRetry(err) {
   506  			return nil, driver.ErrBadConn
   507  		}
   508  		return nil, err
   509  	}
   510  
   511  	// Preload first row because otherwise we won't know what columns are available when database/sql asks.
   512  	more := rows.Next()
   513  	if err = rows.Err(); err != nil {
   514  		rows.Close()
   515  		return nil, err
   516  	}
   517  	return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil
   518  }
   519  
   520  func (c *Conn) Ping(ctx context.Context) error {
   521  	if c.conn.IsClosed() {
   522  		return driver.ErrBadConn
   523  	}
   524  
   525  	err := c.conn.Ping(ctx)
   526  	if err != nil {
   527  		// A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the
   528  		// failure, but manually close it just to be sure.
   529  		c.Close()
   530  		return driver.ErrBadConn
   531  	}
   532  
   533  	return nil
   534  }
   535  
   536  func (c *Conn) CheckNamedValue(*driver.NamedValue) error {
   537  	// Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly.
   538  	return nil
   539  }
   540  
   541  func (c *Conn) ResetSession(ctx context.Context) error {
   542  	if c.conn.IsClosed() {
   543  		return driver.ErrBadConn
   544  	}
   545  
   546  	now := time.Now()
   547  	if now.Sub(c.lastResetSessionTime) > time.Second {
   548  		if err := c.conn.PgConn().Ping(ctx); err != nil {
   549  			return driver.ErrBadConn
   550  		}
   551  	}
   552  	c.lastResetSessionTime = now
   553  
   554  	return c.resetSessionFunc(ctx, c.conn)
   555  }
   556  
// Stmt is the driver.Stmt returned by Conn.PrepareContext. Other Stmts prepared from the same query text on the same
// Conn may share sd; conn.psRefCounts counts those references.
type Stmt struct {
	sd   *pgconn.StatementDescription
	conn *Conn
}
   561  
   562  func (s *Stmt) Close() error {
   563  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   564  	defer cancel()
   565  
   566  	refCount := s.conn.psRefCounts[s.sd]
   567  	if refCount == 1 {
   568  		delete(s.conn.psRefCounts, s.sd)
   569  	} else {
   570  		s.conn.psRefCounts[s.sd]--
   571  		return nil
   572  	}
   573  
   574  	return s.conn.conn.Deallocate(ctx, s.sd.SQL)
   575  }
   576  
   577  func (s *Stmt) NumInput() int {
   578  	return len(s.sd.ParamOIDs)
   579  }
   580  
   581  func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
   582  	return nil, errors.New("Stmt.Exec deprecated and not implemented")
   583  }
   584  
   585  func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
   586  	return s.conn.ExecContext(ctx, s.sd.SQL, argsV)
   587  }
   588  
   589  func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
   590  	return nil, errors.New("Stmt.Query deprecated and not implemented")
   591  }
   592  
   593  func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
   594  	return s.conn.QueryContext(ctx, s.sd.SQL, argsV)
   595  }
   596  
// rowValueFunc converts one column's raw, non-NULL wire bytes (src) into the driver.Value handed to database/sql.
type rowValueFunc func(src []byte) (driver.Value, error)
   598  
// Rows implements driver.Rows (plus the optional column-type interfaces) on top of pgx.Rows.
type Rows struct {
	conn         *Conn
	rows         pgx.Rows
	// valueFuncs holds one converter per result column; it is built lazily on the first call to Next.
	valueFuncs   []rowValueFunc
	// skipNext is set when QueryContext already advanced to the first row, and skipNextMore holds that result, so
	// the first Next call consumes the preloaded row instead of advancing again.
	skipNext     bool
	skipNextMore bool

	// columnNames caches the result of Columns.
	columnNames []string
}
   608  
   609  func (r *Rows) Columns() []string {
   610  	if r.columnNames == nil {
   611  		fields := r.rows.FieldDescriptions()
   612  		r.columnNames = make([]string, len(fields))
   613  		for i, fd := range fields {
   614  			r.columnNames[i] = string(fd.Name)
   615  		}
   616  	}
   617  
   618  	return r.columnNames
   619  }
   620  
   621  // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned.
   622  func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
   623  	if dt, ok := r.conn.conn.TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok {
   624  		return strings.ToUpper(dt.Name)
   625  	}
   626  
   627  	return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10)
   628  }
   629  
// varHeaderSize is the offset subtracted from a column's type modifier to recover the declared length (varchar/bpchar)
// or packed precision/scale (numeric). It mirrors PostgreSQL's 4-byte VARHDRSZ.
const varHeaderSize = 4
   631  
   632  // ColumnTypeLength returns the length of the column type if the column is a
   633  // variable length type. If the column is not a variable length type ok
   634  // should return false.
   635  func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
   636  	fd := r.rows.FieldDescriptions()[index]
   637  
   638  	switch fd.DataTypeOID {
   639  	case pgtype.TextOID, pgtype.ByteaOID:
   640  		return math.MaxInt64, true
   641  	case pgtype.VarcharOID, pgtype.BPCharArrayOID:
   642  		return int64(fd.TypeModifier - varHeaderSize), true
   643  	default:
   644  		return 0, false
   645  	}
   646  }
   647  
   648  // ColumnTypePrecisionScale should return the precision and scale for decimal
   649  // types. If not applicable, ok should be false.
   650  func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
   651  	fd := r.rows.FieldDescriptions()[index]
   652  
   653  	switch fd.DataTypeOID {
   654  	case pgtype.NumericOID:
   655  		mod := fd.TypeModifier - varHeaderSize
   656  		precision = int64((mod >> 16) & 0xffff)
   657  		scale = int64(mod & 0xffff)
   658  		return precision, scale, true
   659  	default:
   660  		return 0, 0, false
   661  	}
   662  }
   663  
   664  // ColumnTypeScanType returns the value type that can be used to scan types into.
   665  func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
   666  	fd := r.rows.FieldDescriptions()[index]
   667  
   668  	switch fd.DataTypeOID {
   669  	case pgtype.Float8OID:
   670  		return reflect.TypeOf(float64(0))
   671  	case pgtype.Float4OID:
   672  		return reflect.TypeOf(float32(0))
   673  	case pgtype.Int8OID:
   674  		return reflect.TypeOf(int64(0))
   675  	case pgtype.Int4OID:
   676  		return reflect.TypeOf(int32(0))
   677  	case pgtype.Int2OID:
   678  		return reflect.TypeOf(int16(0))
   679  	case pgtype.BoolOID:
   680  		return reflect.TypeOf(false)
   681  	case pgtype.NumericOID:
   682  		return reflect.TypeOf(float64(0))
   683  	case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
   684  		return reflect.TypeOf(time.Time{})
   685  	case pgtype.ByteaOID:
   686  		return reflect.TypeOf([]byte(nil))
   687  	default:
   688  		return reflect.TypeOf("")
   689  	}
   690  }
   691  
   692  func (r *Rows) Close() error {
   693  	r.rows.Close()
   694  	return r.rows.Err()
   695  }
   696  
   697  func (r *Rows) Next(dest []driver.Value) error {
   698  	m := r.conn.conn.TypeMap()
   699  	fieldDescriptions := r.rows.FieldDescriptions()
   700  
   701  	if r.valueFuncs == nil {
   702  		r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))
   703  
   704  		for i, fd := range fieldDescriptions {
   705  			dataTypeOID := fd.DataTypeOID
   706  			format := fd.Format
   707  
   708  			switch fd.DataTypeOID {
   709  			case pgtype.BoolOID:
   710  				var d bool
   711  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   712  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   713  					err := scanPlan.Scan(src, &d)
   714  					return d, err
   715  				}
   716  			case pgtype.ByteaOID:
   717  				var d []byte
   718  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   719  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   720  					err := scanPlan.Scan(src, &d)
   721  					return d, err
   722  				}
   723  			case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID:
   724  				var d pgtype.Uint32
   725  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   726  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   727  					err := scanPlan.Scan(src, &d)
   728  					if err != nil {
   729  						return nil, err
   730  					}
   731  					return d.Value()
   732  				}
   733  			case pgtype.DateOID:
   734  				var d pgtype.Date
   735  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   736  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   737  					err := scanPlan.Scan(src, &d)
   738  					if err != nil {
   739  						return nil, err
   740  					}
   741  					return d.Value()
   742  				}
   743  			case pgtype.Float4OID:
   744  				var d float32
   745  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   746  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   747  					err := scanPlan.Scan(src, &d)
   748  					return float64(d), err
   749  				}
   750  			case pgtype.Float8OID:
   751  				var d float64
   752  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   753  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   754  					err := scanPlan.Scan(src, &d)
   755  					return d, err
   756  				}
   757  			case pgtype.Int2OID:
   758  				var d int16
   759  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   760  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   761  					err := scanPlan.Scan(src, &d)
   762  					return int64(d), err
   763  				}
   764  			case pgtype.Int4OID:
   765  				var d int32
   766  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   767  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   768  					err := scanPlan.Scan(src, &d)
   769  					return int64(d), err
   770  				}
   771  			case pgtype.Int8OID:
   772  				var d int64
   773  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   774  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   775  					err := scanPlan.Scan(src, &d)
   776  					return d, err
   777  				}
   778  			case pgtype.JSONOID, pgtype.JSONBOID:
   779  				var d []byte
   780  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   781  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   782  					err := scanPlan.Scan(src, &d)
   783  					if err != nil {
   784  						return nil, err
   785  					}
   786  					return d, nil
   787  				}
   788  			case pgtype.TimestampOID:
   789  				var d pgtype.Timestamp
   790  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   791  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   792  					err := scanPlan.Scan(src, &d)
   793  					if err != nil {
   794  						return nil, err
   795  					}
   796  					return d.Value()
   797  				}
   798  			case pgtype.TimestamptzOID:
   799  				var d pgtype.Timestamptz
   800  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   801  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   802  					err := scanPlan.Scan(src, &d)
   803  					if err != nil {
   804  						return nil, err
   805  					}
   806  					return d.Value()
   807  				}
   808  			default:
   809  				var d string
   810  				scanPlan := m.PlanScan(dataTypeOID, format, &d)
   811  				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
   812  					err := scanPlan.Scan(src, &d)
   813  					return d, err
   814  				}
   815  			}
   816  		}
   817  	}
   818  
   819  	var more bool
   820  	if r.skipNext {
   821  		more = r.skipNextMore
   822  		r.skipNext = false
   823  	} else {
   824  		more = r.rows.Next()
   825  	}
   826  
   827  	if !more {
   828  		if r.rows.Err() == nil {
   829  			return io.EOF
   830  		} else {
   831  			return r.rows.Err()
   832  		}
   833  	}
   834  
   835  	for i, rv := range r.rows.RawValues() {
   836  		if rv != nil {
   837  			var err error
   838  			dest[i], err = r.valueFuncs[i](rv)
   839  			if err != nil {
   840  				return fmt.Errorf("convert field %d failed: %w", i, err)
   841  			}
   842  		} else {
   843  			dest[i] = nil
   844  		}
   845  	}
   846  
   847  	return nil
   848  }
   849  
   850  func valueToInterface(argsV []driver.Value) []any {
   851  	args := make([]any, 0, len(argsV))
   852  	for _, v := range argsV {
   853  		if v != nil {
   854  			args = append(args, v.(any))
   855  		} else {
   856  			args = append(args, nil)
   857  		}
   858  	}
   859  	return args
   860  }
   861  
   862  func namedValueToInterface(argsV []driver.NamedValue) []any {
   863  	args := make([]any, 0, len(argsV))
   864  	for _, v := range argsV {
   865  		if v.Value != nil {
   866  			args = append(args, v.Value.(any))
   867  		} else {
   868  			args = append(args, nil)
   869  		}
   870  	}
   871  	return args
   872  }
   873  
   874  type wrapTx struct {
   875  	ctx context.Context
   876  	tx  pgx.Tx
   877  }
   878  
   879  func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) }
   880  
   881  func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) }
   882  

View as plain text