...

Source file src/github.com/jackc/pgx/v5/conn.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/jackc/pgx/v5/internal/anynil"
    14  	"github.com/jackc/pgx/v5/internal/sanitize"
    15  	"github.com/jackc/pgx/v5/internal/stmtcache"
    16  	"github.com/jackc/pgx/v5/pgconn"
    17  	"github.com/jackc/pgx/v5/pgtype"
    18  )
    19  
// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and
// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
type ConnConfig struct {
	pgconn.Config

	// Tracer is installed as the connection's QueryTracer. If it also implements ConnectTracer, BatchTracer,
	// CopyFromTracer, or PrepareTracer, those hooks are used as well. A nil Tracer disables tracing.
	Tracer QueryTracer

	// Original connection string that was parsed into config. Returned by ConnString.
	connString string

	// StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement"
	// query exec mode. ParseConfig sets it to 512 unless statement_cache_capacity is given. A value <= 0 disables the
	// cache; executing with QueryExecModeCacheStatement then returns an error.
	StatementCacheCapacity int

	// DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with
	// "cache_describe" query exec mode. ParseConfig sets it to 512 unless description_cache_capacity is given. A value
	// <= 0 disables the cache; executing with QueryExecModeCacheDescribe then returns an error.
	DescriptionCacheCapacity int

	// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
	// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
	// PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
	// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
	// ParseConfig sets it to QueryExecModeCacheStatement unless default_query_exec_mode is given.
	DefaultQueryExecMode QueryExecMode

	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

// ParseConfigOptions contains options that control how a config is built such as getsslpassword. It only wraps
// pgconn.ParseConfigOptions; see ParseConfigWithOptions.
type ParseConfigOptions struct {
	pgconn.ParseConfigOptions
}
    51  
    52  // Copy returns a deep copy of the config that is safe to use and modify.
    53  // The only exception is the tls.Config:
    54  // according to the tls.Config docs it must not be modified after creation.
    55  func (cc *ConnConfig) Copy() *ConnConfig {
    56  	newConfig := new(ConnConfig)
    57  	*newConfig = *cc
    58  	newConfig.Config = *newConfig.Config.Copy()
    59  	return newConfig
    60  }
    61  
    62  // ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig.
    63  func (cc *ConnConfig) ConnString() string { return cc.connString }
    64  
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access
// to multiple database connections from multiple goroutines.
type Conn struct {
	// pgConn is the low-level connection; config is owned by this Conn (connect takes ownership of it).
	// preparedStatements holds statements registered through Prepare, keyed by the name passed to Prepare.
	// statementCache and descriptionCache are nil when the corresponding capacity in config is <= 0.
	pgConn             *pgconn.PgConn
	config             *ConnConfig // config used when establishing this connection
	preparedStatements map[string]*pgconn.StatementDescription
	statementCache     stmtcache.Cache
	descriptionCache   stmtcache.Cache

	// Tracers resolved once in connect from config.Tracer; each is nil when the tracer lacks that capability.
	queryTracer    QueryTracer
	batchTracer    BatchTracer
	copyFromTracer CopyFromTracer
	prepareTracer  PrepareTracer

	// notifications is the FIFO queue filled by bufferNotifications and drained by WaitForNotification.
	notifications []*pgconn.Notification

	// Created in connect; their use is not visible in this part of the file.
	doneChan   chan struct{}
	closedChan chan error

	typeMap *pgtype.Map

	// wbuf is preallocated (capacity 1024) in connect. eqb is a reusable builder for extended-protocol parameters;
	// the exec helpers reset it after each use.
	wbuf []byte
	eqb  ExtendedQueryBuilder
}
    89  
    90  // Identifier a PostgreSQL identifier or name. Identifiers can be composed of
    91  // multiple parts such as ["schema", "table"] or ["table", "column"].
    92  type Identifier []string
    93  
    94  // Sanitize returns a sanitized string safe for SQL interpolation.
    95  func (ident Identifier) Sanitize() string {
    96  	parts := make([]string, len(ident))
    97  	for i := range ident {
    98  		s := strings.ReplaceAll(ident[i], string([]byte{0}), "")
    99  		parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
   100  	}
   101  	return strings.Join(parts, ".")
   102  }
   103  
   104  var (
   105  	// ErrNoRows occurs when rows are expected but none are returned.
   106  	ErrNoRows = errors.New("no rows in result set")
   107  	// ErrTooManyRows occurs when more rows than expected are returned.
   108  	ErrTooManyRows = errors.New("too many rows in result set")
   109  )
   110  
   111  var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
   112  var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
   113  
   114  // Connect establishes a connection with a PostgreSQL server with a connection string. See
   115  // pgconn.Connect for details.
   116  func Connect(ctx context.Context, connString string) (*Conn, error) {
   117  	connConfig, err := ParseConfig(connString)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	return connect(ctx, connConfig)
   122  }
   123  
   124  // ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to
   125  // provide a GetSSLPassword function.
   126  func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
   127  	connConfig, err := ParseConfigWithOptions(connString, options)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	return connect(ctx, connConfig)
   132  }
   133  
   134  // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
   135  // connConfig must have been created by ParseConfig.
   136  func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
   137  	// In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other
   138  	// connections with the same config. See https://github.com/jackc/pgx/issues/618.
   139  	connConfig = connConfig.Copy()
   140  
   141  	return connect(ctx, connConfig)
   142  }
   143  
   144  // ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is
   145  // only used to provide a GetSSLPassword function.
   146  func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
   147  	config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	statementCacheCapacity := 512
   153  	if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok {
   154  		delete(config.RuntimeParams, "statement_cache_capacity")
   155  		n, err := strconv.ParseInt(s, 10, 32)
   156  		if err != nil {
   157  			return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
   158  		}
   159  		statementCacheCapacity = int(n)
   160  	}
   161  
   162  	descriptionCacheCapacity := 512
   163  	if s, ok := config.RuntimeParams["description_cache_capacity"]; ok {
   164  		delete(config.RuntimeParams, "description_cache_capacity")
   165  		n, err := strconv.ParseInt(s, 10, 32)
   166  		if err != nil {
   167  			return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
   168  		}
   169  		descriptionCacheCapacity = int(n)
   170  	}
   171  
   172  	defaultQueryExecMode := QueryExecModeCacheStatement
   173  	if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok {
   174  		delete(config.RuntimeParams, "default_query_exec_mode")
   175  		switch s {
   176  		case "cache_statement":
   177  			defaultQueryExecMode = QueryExecModeCacheStatement
   178  		case "cache_describe":
   179  			defaultQueryExecMode = QueryExecModeCacheDescribe
   180  		case "describe_exec":
   181  			defaultQueryExecMode = QueryExecModeDescribeExec
   182  		case "exec":
   183  			defaultQueryExecMode = QueryExecModeExec
   184  		case "simple_protocol":
   185  			defaultQueryExecMode = QueryExecModeSimpleProtocol
   186  		default:
   187  			return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s)
   188  		}
   189  	}
   190  
   191  	connConfig := &ConnConfig{
   192  		Config:                   *config,
   193  		createdByParseConfig:     true,
   194  		StatementCacheCapacity:   statementCacheCapacity,
   195  		DescriptionCacheCapacity: descriptionCacheCapacity,
   196  		DefaultQueryExecMode:     defaultQueryExecMode,
   197  		connString:               connString,
   198  	}
   199  
   200  	return connConfig, nil
   201  }
   202  
   203  // ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig]
   204  // does. In addition, it accepts the following options:
   205  //
   206  //   - default_query_exec_mode.
   207  //     Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
   208  //     QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
   209  //
   210  //   - statement_cache_capacity.
   211  //     The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
   212  //     Default: 512.
   213  //
   214  //   - description_cache_capacity.
   215  //     The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
   216  //     Default: 512.
   217  func ParseConfig(connString string) (*ConnConfig, error) {
   218  	return ParseConfigWithOptions(connString, ParseConfigOptions{})
   219  }
   220  
   221  // connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
   222  func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
   223  	if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
   224  		ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config})
   225  		defer func() {
   226  			connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err})
   227  		}()
   228  	}
   229  
   230  	// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
   231  	// zero values.
   232  	if !config.createdByParseConfig {
   233  		panic("config must be created by ParseConfig")
   234  	}
   235  
   236  	c = &Conn{
   237  		config:      config,
   238  		typeMap:     pgtype.NewMap(),
   239  		queryTracer: config.Tracer,
   240  	}
   241  
   242  	if t, ok := c.queryTracer.(BatchTracer); ok {
   243  		c.batchTracer = t
   244  	}
   245  	if t, ok := c.queryTracer.(CopyFromTracer); ok {
   246  		c.copyFromTracer = t
   247  	}
   248  	if t, ok := c.queryTracer.(PrepareTracer); ok {
   249  		c.prepareTracer = t
   250  	}
   251  
   252  	// Only install pgx notification system if no other callback handler is present.
   253  	if config.Config.OnNotification == nil {
   254  		config.Config.OnNotification = c.bufferNotifications
   255  	}
   256  
   257  	c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  
   262  	c.preparedStatements = make(map[string]*pgconn.StatementDescription)
   263  	c.doneChan = make(chan struct{})
   264  	c.closedChan = make(chan error)
   265  	c.wbuf = make([]byte, 0, 1024)
   266  
   267  	if c.config.StatementCacheCapacity > 0 {
   268  		c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
   269  	}
   270  
   271  	if c.config.DescriptionCacheCapacity > 0 {
   272  		c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
   273  	}
   274  
   275  	return c, nil
   276  }
   277  
   278  // Close closes a connection. It is safe to call Close on an already closed
   279  // connection.
   280  func (c *Conn) Close(ctx context.Context) error {
   281  	if c.IsClosed() {
   282  		return nil
   283  	}
   284  
   285  	err := c.pgConn.Close(ctx)
   286  	return err
   287  }
   288  
   289  // Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
   290  // placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
   291  // Exec to execute the statement. It can also be used with Batch.Queue.
   292  //
   293  // The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
   294  // name == sql.
   295  //
   296  // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This
   297  // allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared.
   298  func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
   299  	if c.prepareTracer != nil {
   300  		ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
   301  	}
   302  
   303  	if name != "" {
   304  		var ok bool
   305  		if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
   306  			if c.prepareTracer != nil {
   307  				c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true})
   308  			}
   309  			return sd, nil
   310  		}
   311  	}
   312  
   313  	if c.prepareTracer != nil {
   314  		defer func() {
   315  			c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err})
   316  		}()
   317  	}
   318  
   319  	var psName, psKey string
   320  	if name == sql {
   321  		digest := sha256.Sum256([]byte(sql))
   322  		psName = "stmt_" + hex.EncodeToString(digest[0:24])
   323  		psKey = sql
   324  	} else {
   325  		psName = name
   326  		psKey = name
   327  	}
   328  
   329  	sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	if psKey != "" {
   335  		c.preparedStatements[psKey] = sd
   336  	}
   337  
   338  	return sd, nil
   339  }
   340  
   341  // Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed.
   342  func (c *Conn) Deallocate(ctx context.Context, name string) error {
   343  	var psName string
   344  	sd := c.preparedStatements[name]
   345  	if sd != nil {
   346  		psName = sd.Name
   347  	} else {
   348  		psName = name
   349  	}
   350  
   351  	err := c.pgConn.Deallocate(ctx, psName)
   352  	if err != nil {
   353  		return err
   354  	}
   355  
   356  	if sd != nil {
   357  		delete(c.preparedStatements, name)
   358  	}
   359  
   360  	return nil
   361  }
   362  
   363  // DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache.
   364  func (c *Conn) DeallocateAll(ctx context.Context) error {
   365  	c.preparedStatements = map[string]*pgconn.StatementDescription{}
   366  	if c.config.StatementCacheCapacity > 0 {
   367  		c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
   368  	}
   369  	if c.config.DescriptionCacheCapacity > 0 {
   370  		c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
   371  	}
   372  	_, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll()
   373  	return err
   374  }
   375  
   376  func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) {
   377  	c.notifications = append(c.notifications, n)
   378  }
   379  
   380  // WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a
   381  // slightly more convenient form.
   382  func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
   383  	var n *pgconn.Notification
   384  
   385  	// Return already received notification immediately
   386  	if len(c.notifications) > 0 {
   387  		n = c.notifications[0]
   388  		c.notifications = c.notifications[1:]
   389  		return n, nil
   390  	}
   391  
   392  	err := c.pgConn.WaitForNotification(ctx)
   393  	if len(c.notifications) > 0 {
   394  		n = c.notifications[0]
   395  		c.notifications = c.notifications[1:]
   396  	}
   397  	return n, err
   398  }
   399  
   400  // IsClosed reports if the connection has been closed.
   401  func (c *Conn) IsClosed() bool {
   402  	return c.pgConn.IsClosed()
   403  }
   404  
   405  func (c *Conn) die(err error) {
   406  	if c.IsClosed() {
   407  		return
   408  	}
   409  
   410  	ctx, cancel := context.WithCancel(context.Background())
   411  	cancel() // force immediate hard cancel
   412  	c.pgConn.Close(ctx)
   413  }
   414  
   415  func quoteIdentifier(s string) string {
   416  	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
   417  }
   418  
   419  // Ping delegates to the underlying *pgconn.PgConn.Ping.
   420  func (c *Conn) Ping(ctx context.Context) error {
   421  	return c.pgConn.Ping(ctx)
   422  }
   423  
   424  // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
   425  // PostgreSQL connection than pgx exposes.
   426  //
   427  // It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn
   428  // is used and the connection must be returned to the same state before any *pgx.Conn methods are again used.
   429  func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }
   430  
   431  // TypeMap returns the connection info used for this connection.
   432  func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap }
   433  
   434  // Config returns a copy of config that was used to establish this connection.
   435  func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
   436  
   437  // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
   438  // positionally from the sql string as $1, $2, etc.
   439  func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
   440  	if c.queryTracer != nil {
   441  		ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments})
   442  	}
   443  
   444  	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
   445  		return pgconn.CommandTag{}, err
   446  	}
   447  
   448  	commandTag, err := c.exec(ctx, sql, arguments...)
   449  
   450  	if c.queryTracer != nil {
   451  		c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err})
   452  	}
   453  
   454  	return commandTag, err
   455  }
   456  
   457  func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
   458  	mode := c.config.DefaultQueryExecMode
   459  	var queryRewriter QueryRewriter
   460  
   461  optionLoop:
   462  	for len(arguments) > 0 {
   463  		switch arg := arguments[0].(type) {
   464  		case QueryExecMode:
   465  			mode = arg
   466  			arguments = arguments[1:]
   467  		case QueryRewriter:
   468  			queryRewriter = arg
   469  			arguments = arguments[1:]
   470  		default:
   471  			break optionLoop
   472  		}
   473  	}
   474  
   475  	if queryRewriter != nil {
   476  		sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
   477  		if err != nil {
   478  			return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err)
   479  		}
   480  	}
   481  
   482  	// Always use simple protocol when there are no arguments.
   483  	if len(arguments) == 0 {
   484  		mode = QueryExecModeSimpleProtocol
   485  	}
   486  
   487  	if sd, ok := c.preparedStatements[sql]; ok {
   488  		return c.execPrepared(ctx, sd, arguments)
   489  	}
   490  
   491  	switch mode {
   492  	case QueryExecModeCacheStatement:
   493  		if c.statementCache == nil {
   494  			return pgconn.CommandTag{}, errDisabledStatementCache
   495  		}
   496  		sd := c.statementCache.Get(sql)
   497  		if sd == nil {
   498  			sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
   499  			if err != nil {
   500  				return pgconn.CommandTag{}, err
   501  			}
   502  			c.statementCache.Put(sd)
   503  		}
   504  
   505  		return c.execPrepared(ctx, sd, arguments)
   506  	case QueryExecModeCacheDescribe:
   507  		if c.descriptionCache == nil {
   508  			return pgconn.CommandTag{}, errDisabledDescriptionCache
   509  		}
   510  		sd := c.descriptionCache.Get(sql)
   511  		if sd == nil {
   512  			sd, err = c.Prepare(ctx, "", sql)
   513  			if err != nil {
   514  				return pgconn.CommandTag{}, err
   515  			}
   516  			c.descriptionCache.Put(sd)
   517  		}
   518  
   519  		return c.execParams(ctx, sd, arguments)
   520  	case QueryExecModeDescribeExec:
   521  		sd, err := c.Prepare(ctx, "", sql)
   522  		if err != nil {
   523  			return pgconn.CommandTag{}, err
   524  		}
   525  		return c.execPrepared(ctx, sd, arguments)
   526  	case QueryExecModeExec:
   527  		return c.execSQLParams(ctx, sql, arguments)
   528  	case QueryExecModeSimpleProtocol:
   529  		return c.execSimpleProtocol(ctx, sql, arguments)
   530  	default:
   531  		return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
   532  	}
   533  }
   534  
   535  func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) {
   536  	if len(arguments) > 0 {
   537  		sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
   538  		if err != nil {
   539  			return pgconn.CommandTag{}, err
   540  		}
   541  	}
   542  
   543  	mrr := c.pgConn.Exec(ctx, sql)
   544  	for mrr.NextResult() {
   545  		commandTag, _ = mrr.ResultReader().Close()
   546  	}
   547  	err = mrr.Close()
   548  	return commandTag, err
   549  }
   550  
   551  func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
   552  	err := c.eqb.Build(c.typeMap, sd, arguments)
   553  	if err != nil {
   554  		return pgconn.CommandTag{}, err
   555  	}
   556  
   557  	result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
   558  	c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   559  	return result.CommandTag, result.Err
   560  }
   561  
   562  func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
   563  	err := c.eqb.Build(c.typeMap, sd, arguments)
   564  	if err != nil {
   565  		return pgconn.CommandTag{}, err
   566  	}
   567  
   568  	result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
   569  	c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   570  	return result.CommandTag, result.Err
   571  }
   572  
   573  type unknownArgumentTypeQueryExecModeExecError struct {
   574  	arg any
   575  }
   576  
   577  func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
   578  	return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
   579  }
   580  
   581  func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
   582  	err := c.eqb.Build(c.typeMap, nil, args)
   583  	if err != nil {
   584  		return pgconn.CommandTag{}, err
   585  	}
   586  
   587  	result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
   588  	c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   589  	return result.CommandTag, result.Err
   590  }
   591  
   592  func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
   593  	r := &baseRows{}
   594  
   595  	r.ctx = ctx
   596  	r.queryTracer = c.queryTracer
   597  	r.typeMap = c.typeMap
   598  	r.startTime = time.Now()
   599  	r.sql = sql
   600  	r.args = args
   601  	r.conn = c
   602  
   603  	return r
   604  }
   605  
   606  type QueryExecMode int32
   607  
   608  const (
   609  	_ QueryExecMode = iota
   610  
   611  	// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round
   612  	// trip after the statement is cached. This is the default. If the database schema is modified or the search_path is
   613  	// changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the
   614  	// number of columns returned by a "SELECT *" changes or the type of a column is changed.
   615  	QueryExecModeCacheStatement
   616  
   617  	// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended
   618  	// protocol. Queries are executed in a single round trip after the description is cached. If the database schema is
   619  	// modified or the search_path is changed after a statement is cached then the first execution of a previously cached
   620  	// query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed.
   621  	QueryExecModeCacheDescribe
   622  
   623  	// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
   624  	// to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
   625  	// statement description on the first round trip and then uses it to execute the query on the second round trip. This
   626  	// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
   627  	// even when the the database schema is modified concurrently.
   628  	QueryExecModeDescribeExec
   629  
   630  	// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
   631  	// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
   632  	// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
   633  	// unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
   634  	// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
   635  	QueryExecModeExec
   636  
   637  	// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
   638  	// Queries are executed in a single round trip. Type mappings can be registered with
   639  	// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
   640  	// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
   641  	// a map[string]string directly as an argument. This mode cannot.
   642  	//
   643  	// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor
   644  	// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
   645  	//
   646  	// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
   647  	// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
   648  	// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
   649  	// not support the extended protocol.
   650  	QueryExecModeSimpleProtocol
   651  )
   652  
   653  func (m QueryExecMode) String() string {
   654  	switch m {
   655  	case QueryExecModeCacheStatement:
   656  		return "cache statement"
   657  	case QueryExecModeCacheDescribe:
   658  		return "cache describe"
   659  	case QueryExecModeDescribeExec:
   660  		return "describe exec"
   661  	case QueryExecModeExec:
   662  		return "exec"
   663  	case QueryExecModeSimpleProtocol:
   664  		return "simple protocol"
   665  	default:
   666  		return "invalid"
   667  	}
   668  }
   669  
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. Query
// consumes it when it is passed among the leading query arguments.
type QueryResultFormats []int16

// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. Query
// consumes it when it is passed among the leading query arguments.
type QueryResultFormatsByOID map[uint32]int16

// QueryRewriter rewrites a query when passed among the leading arguments of a query method. The query method calls
// RewriteQuery with the SQL and the remaining (non-option) arguments, then uses the returned SQL and arguments in
// their place. A non-nil error aborts the query and is wrapped as "rewrite query failed: ...".
type QueryRewriter interface {
	RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
}
   680  
   681  // Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
   682  // and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to
   683  // determine if the query executed successfully.
   684  //
   685  // The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the
   686  // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It
   687  // is allowed to ignore the error returned from Query and handle it in Rows.
   688  //
   689  // It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not
   690  // return an error.
   691  //
   692  // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be
   693  // collected before processing rather than processed while receiving each row. This avoids the possibility of the
   694  // application processing rows from a query that the server rejected. The CollectRows function is useful here.
   695  //
   696  // An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or
   697  // replace args. For example, NamedArgs is QueryRewriter that implements named arguments.
   698  //
   699  // For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and
   700  // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
   701  // needed. See the documentation for those types for details.
   702  func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
   703  	if c.queryTracer != nil {
   704  		ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args})
   705  	}
   706  
   707  	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
   708  		if c.queryTracer != nil {
   709  			c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
   710  		}
   711  		return &baseRows{err: err, closed: true}, err
   712  	}
   713  
   714  	var resultFormats QueryResultFormats
   715  	var resultFormatsByOID QueryResultFormatsByOID
   716  	mode := c.config.DefaultQueryExecMode
   717  	var queryRewriter QueryRewriter
   718  
   719  optionLoop:
   720  	for len(args) > 0 {
   721  		switch arg := args[0].(type) {
   722  		case QueryResultFormats:
   723  			resultFormats = arg
   724  			args = args[1:]
   725  		case QueryResultFormatsByOID:
   726  			resultFormatsByOID = arg
   727  			args = args[1:]
   728  		case QueryExecMode:
   729  			mode = arg
   730  			args = args[1:]
   731  		case QueryRewriter:
   732  			queryRewriter = arg
   733  			args = args[1:]
   734  		default:
   735  			break optionLoop
   736  		}
   737  	}
   738  
   739  	if queryRewriter != nil {
   740  		var err error
   741  		originalSQL := sql
   742  		originalArgs := args
   743  		sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
   744  		if err != nil {
   745  			rows := c.getRows(ctx, originalSQL, originalArgs)
   746  			err = fmt.Errorf("rewrite query failed: %w", err)
   747  			rows.fatal(err)
   748  			return rows, err
   749  		}
   750  	}
   751  
   752  	// Bypass any statement caching.
   753  	if sql == "" {
   754  		mode = QueryExecModeSimpleProtocol
   755  	}
   756  
   757  	c.eqb.reset()
   758  	anynil.NormalizeSlice(args)
   759  	rows := c.getRows(ctx, sql, args)
   760  
   761  	var err error
   762  	sd, explicitPreparedStatement := c.preparedStatements[sql]
   763  	if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
   764  		if sd == nil {
   765  			sd, err = c.getStatementDescription(ctx, mode, sql)
   766  			if err != nil {
   767  				rows.fatal(err)
   768  				return rows, err
   769  			}
   770  		}
   771  
   772  		if len(sd.ParamOIDs) != len(args) {
   773  			rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
   774  			return rows, rows.err
   775  		}
   776  
   777  		rows.sql = sd.SQL
   778  
   779  		err = c.eqb.Build(c.typeMap, sd, args)
   780  		if err != nil {
   781  			rows.fatal(err)
   782  			return rows, rows.err
   783  		}
   784  
   785  		if resultFormatsByOID != nil {
   786  			resultFormats = make([]int16, len(sd.Fields))
   787  			for i := range resultFormats {
   788  				resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
   789  			}
   790  		}
   791  
   792  		if resultFormats == nil {
   793  			resultFormats = c.eqb.ResultFormats
   794  		}
   795  
   796  		if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe {
   797  			rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats)
   798  		} else {
   799  			rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
   800  		}
   801  	} else if mode == QueryExecModeExec {
   802  		err := c.eqb.Build(c.typeMap, nil, args)
   803  		if err != nil {
   804  			rows.fatal(err)
   805  			return rows, rows.err
   806  		}
   807  
   808  		rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
   809  	} else if mode == QueryExecModeSimpleProtocol {
   810  		sql, err = c.sanitizeForSimpleQuery(sql, args...)
   811  		if err != nil {
   812  			rows.fatal(err)
   813  			return rows, err
   814  		}
   815  
   816  		mrr := c.pgConn.Exec(ctx, sql)
   817  		if mrr.NextResult() {
   818  			rows.resultReader = mrr.ResultReader()
   819  			rows.multiResultReader = mrr
   820  		} else {
   821  			err = mrr.Close()
   822  			rows.fatal(err)
   823  			return rows, err
   824  		}
   825  
   826  		return rows, nil
   827  	} else {
   828  		err = fmt.Errorf("unknown QueryExecMode: %v", mode)
   829  		rows.fatal(err)
   830  		return rows, rows.err
   831  	}
   832  
   833  	c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   834  
   835  	return rows, rows.err
   836  }
   837  
   838  // getStatementDescription returns the statement description of the sql query
   839  // according to the given mode.
   840  //
   841  // If the mode is one that doesn't require to know the param and result OIDs
   842  // then nil is returned without error.
   843  func (c *Conn) getStatementDescription(
   844  	ctx context.Context,
   845  	mode QueryExecMode,
   846  	sql string,
   847  ) (sd *pgconn.StatementDescription, err error) {
   848  
   849  	switch mode {
   850  	case QueryExecModeCacheStatement:
   851  		if c.statementCache == nil {
   852  			return nil, errDisabledStatementCache
   853  		}
   854  		sd = c.statementCache.Get(sql)
   855  		if sd == nil {
   856  			sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
   857  			if err != nil {
   858  				return nil, err
   859  			}
   860  			c.statementCache.Put(sd)
   861  		}
   862  	case QueryExecModeCacheDescribe:
   863  		if c.descriptionCache == nil {
   864  			return nil, errDisabledDescriptionCache
   865  		}
   866  		sd = c.descriptionCache.Get(sql)
   867  		if sd == nil {
   868  			sd, err = c.Prepare(ctx, "", sql)
   869  			if err != nil {
   870  				return nil, err
   871  			}
   872  			c.descriptionCache.Put(sd)
   873  		}
   874  	case QueryExecModeDescribeExec:
   875  		return c.Prepare(ctx, "", sql)
   876  	}
   877  	return sd, err
   878  }
   879  
   880  // QueryRow is a convenience wrapper over Query. Any error that occurs while
   881  // querying is deferred until calling Scan on the returned Row. That Row will
   882  // error with ErrNoRows if no rows are returned.
   883  func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
   884  	rows, _ := c.Query(ctx, sql, args...)
   885  	return (*connRow)(rows.(*baseRows))
   886  }
   887  
   888  // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
   889  // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
   890  // is used again.
   891  func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
   892  	if c.batchTracer != nil {
   893  		ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
   894  		defer func() {
   895  			err := br.(interface{ earlyError() error }).earlyError()
   896  			if err != nil {
   897  				c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err})
   898  			}
   899  		}()
   900  	}
   901  
   902  	if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
   903  		return &batchResults{ctx: ctx, conn: c, err: err}
   904  	}
   905  
   906  	for _, bi := range b.QueuedQueries {
   907  		var queryRewriter QueryRewriter
   908  		sql := bi.SQL
   909  		arguments := bi.Arguments
   910  
   911  	optionLoop:
   912  		for len(arguments) > 0 {
   913  			// Update Batch.Queue function comment when additional options are implemented
   914  			switch arg := arguments[0].(type) {
   915  			case QueryRewriter:
   916  				queryRewriter = arg
   917  				arguments = arguments[1:]
   918  			default:
   919  				break optionLoop
   920  			}
   921  		}
   922  
   923  		if queryRewriter != nil {
   924  			var err error
   925  			sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
   926  			if err != nil {
   927  				return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)}
   928  			}
   929  		}
   930  
   931  		bi.SQL = sql
   932  		bi.Arguments = arguments
   933  	}
   934  
   935  	// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
   936  	mode := c.config.DefaultQueryExecMode
   937  	if mode == QueryExecModeSimpleProtocol {
   938  		return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
   939  	}
   940  
   941  	// All other modes use extended protocol and thus can use prepared statements.
   942  	for _, bi := range b.QueuedQueries {
   943  		if sd, ok := c.preparedStatements[bi.SQL]; ok {
   944  			bi.sd = sd
   945  		}
   946  	}
   947  
   948  	switch mode {
   949  	case QueryExecModeExec:
   950  		return c.sendBatchQueryExecModeExec(ctx, b)
   951  	case QueryExecModeCacheStatement:
   952  		return c.sendBatchQueryExecModeCacheStatement(ctx, b)
   953  	case QueryExecModeCacheDescribe:
   954  		return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
   955  	case QueryExecModeDescribeExec:
   956  		return c.sendBatchQueryExecModeDescribeExec(ctx, b)
   957  	default:
   958  		panic("unknown QueryExecMode")
   959  	}
   960  }
   961  
   962  func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
   963  	var sb strings.Builder
   964  	for i, bi := range b.QueuedQueries {
   965  		if i > 0 {
   966  			sb.WriteByte(';')
   967  		}
   968  		sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
   969  		if err != nil {
   970  			return &batchResults{ctx: ctx, conn: c, err: err}
   971  		}
   972  		sb.WriteString(sql)
   973  	}
   974  	mrr := c.pgConn.Exec(ctx, sb.String())
   975  	return &batchResults{
   976  		ctx:   ctx,
   977  		conn:  c,
   978  		mrr:   mrr,
   979  		b:     b,
   980  		qqIdx: 0,
   981  	}
   982  }
   983  
   984  func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
   985  	batch := &pgconn.Batch{}
   986  
   987  	for _, bi := range b.QueuedQueries {
   988  		sd := bi.sd
   989  		if sd != nil {
   990  			err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
   991  			if err != nil {
   992  				return &batchResults{ctx: ctx, conn: c, err: err}
   993  			}
   994  
   995  			batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
   996  		} else {
   997  			err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
   998  			if err != nil {
   999  				return &batchResults{ctx: ctx, conn: c, err: err}
  1000  			}
  1001  			batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
  1002  		}
  1003  	}
  1004  
  1005  	c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
  1006  
  1007  	mrr := c.pgConn.ExecBatch(ctx, batch)
  1008  
  1009  	return &batchResults{
  1010  		ctx:   ctx,
  1011  		conn:  c,
  1012  		mrr:   mrr,
  1013  		b:     b,
  1014  		qqIdx: 0,
  1015  	}
  1016  }
  1017  
  1018  func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
  1019  	if c.statementCache == nil {
  1020  		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
  1021  	}
  1022  
  1023  	distinctNewQueries := []*pgconn.StatementDescription{}
  1024  	distinctNewQueriesIdxMap := make(map[string]int)
  1025  
  1026  	for _, bi := range b.QueuedQueries {
  1027  		if bi.sd == nil {
  1028  			sd := c.statementCache.Get(bi.SQL)
  1029  			if sd != nil {
  1030  				bi.sd = sd
  1031  			} else {
  1032  				if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
  1033  					bi.sd = distinctNewQueries[idx]
  1034  				} else {
  1035  					sd = &pgconn.StatementDescription{
  1036  						Name: stmtcache.StatementName(bi.SQL),
  1037  						SQL:  bi.SQL,
  1038  					}
  1039  					distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
  1040  					distinctNewQueries = append(distinctNewQueries, sd)
  1041  					bi.sd = sd
  1042  				}
  1043  			}
  1044  		}
  1045  	}
  1046  
  1047  	return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
  1048  }
  1049  
  1050  func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
  1051  	if c.descriptionCache == nil {
  1052  		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
  1053  	}
  1054  
  1055  	distinctNewQueries := []*pgconn.StatementDescription{}
  1056  	distinctNewQueriesIdxMap := make(map[string]int)
  1057  
  1058  	for _, bi := range b.QueuedQueries {
  1059  		if bi.sd == nil {
  1060  			sd := c.descriptionCache.Get(bi.SQL)
  1061  			if sd != nil {
  1062  				bi.sd = sd
  1063  			} else {
  1064  				if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
  1065  					bi.sd = distinctNewQueries[idx]
  1066  				} else {
  1067  					sd = &pgconn.StatementDescription{
  1068  						SQL: bi.SQL,
  1069  					}
  1070  					distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
  1071  					distinctNewQueries = append(distinctNewQueries, sd)
  1072  					bi.sd = sd
  1073  				}
  1074  			}
  1075  		}
  1076  	}
  1077  
  1078  	return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
  1079  }
  1080  
  1081  func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
  1082  	distinctNewQueries := []*pgconn.StatementDescription{}
  1083  	distinctNewQueriesIdxMap := make(map[string]int)
  1084  
  1085  	for _, bi := range b.QueuedQueries {
  1086  		if bi.sd == nil {
  1087  			if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
  1088  				bi.sd = distinctNewQueries[idx]
  1089  			} else {
  1090  				sd := &pgconn.StatementDescription{
  1091  					SQL: bi.SQL,
  1092  				}
  1093  				distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
  1094  				distinctNewQueries = append(distinctNewQueries, sd)
  1095  				bi.sd = sd
  1096  			}
  1097  		}
  1098  	}
  1099  
  1100  	return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
  1101  }
  1102  
  1103  func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
  1104  	pipeline := c.pgConn.StartPipeline(ctx)
  1105  	defer func() {
  1106  		if pbr != nil && pbr.err != nil {
  1107  			pipeline.Close()
  1108  		}
  1109  	}()
  1110  
  1111  	// Prepare any needed queries
  1112  	if len(distinctNewQueries) > 0 {
  1113  		for _, sd := range distinctNewQueries {
  1114  			pipeline.SendPrepare(sd.Name, sd.SQL, nil)
  1115  		}
  1116  
  1117  		err := pipeline.Sync()
  1118  		if err != nil {
  1119  			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
  1120  		}
  1121  
  1122  		for _, sd := range distinctNewQueries {
  1123  			results, err := pipeline.GetResults()
  1124  			if err != nil {
  1125  				return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
  1126  			}
  1127  
  1128  			resultSD, ok := results.(*pgconn.StatementDescription)
  1129  			if !ok {
  1130  				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
  1131  			}
  1132  
  1133  			// Fill in the previously empty / pending statement descriptions.
  1134  			sd.ParamOIDs = resultSD.ParamOIDs
  1135  			sd.Fields = resultSD.Fields
  1136  		}
  1137  
  1138  		results, err := pipeline.GetResults()
  1139  		if err != nil {
  1140  			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
  1141  		}
  1142  
  1143  		_, ok := results.(*pgconn.PipelineSync)
  1144  		if !ok {
  1145  			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
  1146  		}
  1147  	}
  1148  
  1149  	// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
  1150  	if sdCache != nil {
  1151  		for _, sd := range distinctNewQueries {
  1152  			sdCache.Put(sd)
  1153  		}
  1154  	}
  1155  
  1156  	// Queue the queries.
  1157  	for _, bi := range b.QueuedQueries {
  1158  		err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
  1159  		if err != nil {
  1160  			// we wrap the error so we the user can understand which query failed inside the batch
  1161  			err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
  1162  			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
  1163  		}
  1164  
  1165  		if bi.sd.Name == "" {
  1166  			pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
  1167  		} else {
  1168  			pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
  1169  		}
  1170  	}
  1171  
  1172  	err := pipeline.Sync()
  1173  	if err != nil {
  1174  		return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
  1175  	}
  1176  
  1177  	return &pipelineBatchResults{
  1178  		ctx:      ctx,
  1179  		conn:     c,
  1180  		pipeline: pipeline,
  1181  		b:        b,
  1182  	}
  1183  }
  1184  
  1185  func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
  1186  	if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
  1187  		return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
  1188  	}
  1189  
  1190  	if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
  1191  		return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
  1192  	}
  1193  
  1194  	var err error
  1195  	valueArgs := make([]any, len(args))
  1196  	for i, a := range args {
  1197  		valueArgs[i], err = convertSimpleArgument(c.typeMap, a)
  1198  		if err != nil {
  1199  			return "", err
  1200  		}
  1201  	}
  1202  
  1203  	return sanitize.SanitizeSQL(sql, valueArgs...)
  1204  }
  1205  
  1206  // LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
  1207  // the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
  1208  // typeName must be one of the following:
  1209  //   - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
  1210  //   - A composite type name where all field types are already registered.
  1211  //   - A domain type name where the base type is already registered.
  1212  //   - An enum type name.
  1213  //   - A range type name where the element type is already registered.
  1214  //   - A multirange type name where the element type is already registered.
  1215  func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
  1216  	var oid uint32
  1217  
  1218  	err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
  1219  	if err != nil {
  1220  		return nil, err
  1221  	}
  1222  
  1223  	var typtype string
  1224  	var typbasetype uint32
  1225  
  1226  	err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
  1227  	if err != nil {
  1228  		return nil, err
  1229  	}
  1230  
  1231  	switch typtype {
  1232  	case "b": // array
  1233  		elementOID, err := c.getArrayElementOID(ctx, oid)
  1234  		if err != nil {
  1235  			return nil, err
  1236  		}
  1237  
  1238  		dt, ok := c.TypeMap().TypeForOID(elementOID)
  1239  		if !ok {
  1240  			return nil, errors.New("array element OID not registered")
  1241  		}
  1242  
  1243  		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil
  1244  	case "c": // composite
  1245  		fields, err := c.getCompositeFields(ctx, oid)
  1246  		if err != nil {
  1247  			return nil, err
  1248  		}
  1249  
  1250  		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
  1251  	case "d": // domain
  1252  		dt, ok := c.TypeMap().TypeForOID(typbasetype)
  1253  		if !ok {
  1254  			return nil, errors.New("domain base type OID not registered")
  1255  		}
  1256  
  1257  		return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
  1258  	case "e": // enum
  1259  		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
  1260  	case "r": // range
  1261  		elementOID, err := c.getRangeElementOID(ctx, oid)
  1262  		if err != nil {
  1263  			return nil, err
  1264  		}
  1265  
  1266  		dt, ok := c.TypeMap().TypeForOID(elementOID)
  1267  		if !ok {
  1268  			return nil, errors.New("range element OID not registered")
  1269  		}
  1270  
  1271  		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
  1272  	case "m": // multirange
  1273  		elementOID, err := c.getMultiRangeElementOID(ctx, oid)
  1274  		if err != nil {
  1275  			return nil, err
  1276  		}
  1277  
  1278  		dt, ok := c.TypeMap().TypeForOID(elementOID)
  1279  		if !ok {
  1280  			return nil, errors.New("multirange element OID not registered")
  1281  		}
  1282  
  1283  		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
  1284  	default:
  1285  		return &pgtype.Type{}, errors.New("unknown typtype")
  1286  	}
  1287  }
  1288  
  1289  func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) {
  1290  	var typelem uint32
  1291  
  1292  	err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
  1293  	if err != nil {
  1294  		return 0, err
  1295  	}
  1296  
  1297  	return typelem, nil
  1298  }
  1299  
  1300  func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
  1301  	var typelem uint32
  1302  
  1303  	err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
  1304  	if err != nil {
  1305  		return 0, err
  1306  	}
  1307  
  1308  	return typelem, nil
  1309  }
  1310  
  1311  func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
  1312  	var typelem uint32
  1313  
  1314  	err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
  1315  	if err != nil {
  1316  		return 0, err
  1317  	}
  1318  
  1319  	return typelem, nil
  1320  }
  1321  
  1322  func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
  1323  	var typrelid uint32
  1324  
  1325  	err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
  1326  	if err != nil {
  1327  		return nil, err
  1328  	}
  1329  
  1330  	var fields []pgtype.CompositeCodecField
  1331  	var fieldName string
  1332  	var fieldOID uint32
  1333  	rows, _ := c.Query(ctx, `select attname, atttypid
  1334  from pg_attribute
  1335  where attrelid=$1
  1336  	and not attisdropped
  1337  	and attnum > 0
  1338  order by attnum`,
  1339  		typrelid,
  1340  	)
  1341  	_, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error {
  1342  		dt, ok := c.TypeMap().TypeForOID(fieldOID)
  1343  		if !ok {
  1344  			return fmt.Errorf("unknown composite type field OID: %v", fieldOID)
  1345  		}
  1346  		fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
  1347  		return nil
  1348  	})
  1349  	if err != nil {
  1350  		return nil, err
  1351  	}
  1352  
  1353  	return fields, nil
  1354  }
  1355  
  1356  func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
  1357  	if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
  1358  		return nil
  1359  	}
  1360  
  1361  	if c.descriptionCache != nil {
  1362  		c.descriptionCache.RemoveInvalidated()
  1363  	}
  1364  
  1365  	var invalidatedStatements []*pgconn.StatementDescription
  1366  	if c.statementCache != nil {
  1367  		invalidatedStatements = c.statementCache.GetInvalidated()
  1368  	}
  1369  
  1370  	if len(invalidatedStatements) == 0 {
  1371  		return nil
  1372  	}
  1373  
  1374  	pipeline := c.pgConn.StartPipeline(ctx)
  1375  	defer pipeline.Close()
  1376  
  1377  	for _, sd := range invalidatedStatements {
  1378  		pipeline.SendDeallocate(sd.Name)
  1379  	}
  1380  
  1381  	err := pipeline.Sync()
  1382  	if err != nil {
  1383  		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
  1384  	}
  1385  
  1386  	err = pipeline.Close()
  1387  	if err != nil {
  1388  		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
  1389  	}
  1390  
  1391  	c.statementCache.RemoveInvalidated()
  1392  	for _, sd := range invalidatedStatements {
  1393  		delete(c.preparedStatements, sd.Name)
  1394  	}
  1395  
  1396  	return nil
  1397  }
  1398  

View as plain text