...

Source file src/github.com/jackc/pgx/v4/conn.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/jackc/pgconn"
    12  	"github.com/jackc/pgconn/stmtcache"
    13  	"github.com/jackc/pgproto3/v2"
    14  	"github.com/jackc/pgtype"
    15  	"github.com/jackc/pgx/v4/internal/sanitize"
    16  )
    17  
    18  // ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and
    19  // then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
    20  type ConnConfig struct {
    21  	pgconn.Config
    22  	Logger   Logger
    23  	LogLevel LogLevel
    24  
    25  	// Original connection string that was parsed into config.
    26  	connString string
    27  
    28  	// BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set
    29  	// to nil to disable automatic prepared statements.
    30  	BuildStatementCache BuildStatementCacheFunc
    31  
    32  	// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
    33  	// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
    34  	// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
    35  	// and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
    36  	// used by default. The same functionality can be controlled on a per query basis by setting
    37  	// QueryExOptions.SimpleProtocol.
    38  	PreferSimpleProtocol bool
    39  
    40  	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
    41  }
    42  
    43  // Copy returns a deep copy of the config that is safe to use and modify.
    44  // The only exception is the tls.Config:
    45  // according to the tls.Config docs it must not be modified after creation.
    46  func (cc *ConnConfig) Copy() *ConnConfig {
    47  	newConfig := new(ConnConfig)
    48  	*newConfig = *cc
    49  	newConfig.Config = *newConfig.Config.Copy()
    50  	return newConfig
    51  }
    52  
    53  // ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig.
    54  func (cc *ConnConfig) ConnString() string { return cc.connString }
    55  
    56  // BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection.
    57  type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache
    58  
    59  // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access
    60  // to multiple database connections from multiple goroutines.
    61  type Conn struct {
    62  	pgConn             *pgconn.PgConn
    63  	config             *ConnConfig // config used when establishing this connection
    64  	preparedStatements map[string]*pgconn.StatementDescription
    65  	stmtcache          stmtcache.Cache
    66  	logger             Logger
    67  	logLevel           LogLevel
    68  
    69  	notifications []*pgconn.Notification
    70  
    71  	doneChan   chan struct{}
    72  	closedChan chan error
    73  
    74  	connInfo *pgtype.ConnInfo
    75  
    76  	wbuf []byte
    77  	eqb  extendedQueryBuilder
    78  }
    79  
    80  // Identifier a PostgreSQL identifier or name. Identifiers can be composed of
    81  // multiple parts such as ["schema", "table"] or ["table", "column"].
    82  type Identifier []string
    83  
    84  // Sanitize returns a sanitized string safe for SQL interpolation.
    85  func (ident Identifier) Sanitize() string {
    86  	parts := make([]string, len(ident))
    87  	for i := range ident {
    88  		s := strings.ReplaceAll(ident[i], string([]byte{0}), "")
    89  		parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
    90  	}
    91  	return strings.Join(parts, ".")
    92  }
    93  
    94  // ErrNoRows occurs when rows are expected but none are returned.
    95  var ErrNoRows = errors.New("no rows in result set")
    96  
    97  // ErrInvalidLogLevel occurs on attempt to set an invalid log level.
    98  var ErrInvalidLogLevel = errors.New("invalid log level")
    99  
   100  // Connect establishes a connection with a PostgreSQL server with a connection string. See
   101  // pgconn.Connect for details.
   102  func Connect(ctx context.Context, connString string) (*Conn, error) {
   103  	connConfig, err := ParseConfig(connString)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	return connect(ctx, connConfig)
   108  }
   109  
   110  // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
   111  // connConfig must have been created by ParseConfig.
   112  func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
   113  	return connect(ctx, connConfig)
   114  }
   115  
   116  // ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
   117  // does. In addition, it accepts the following options:
   118  //
   119  //	statement_cache_capacity
   120  //		The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512.
   121  //
   122  //	statement_cache_mode
   123  //		Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server.
   124  //		"describe" will use the anonymous prepared statement to describe a statement without creating a statement on the
   125  //		server. "describe" is primarily useful when the environment does not allow prepared statements such as when
   126  //		running a connection pooler like PgBouncer. Default: "prepare"
   127  //
   128  //	prefer_simple_protocol
   129  //		Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false
   130  func ParseConfig(connString string) (*ConnConfig, error) {
   131  	config, err := pgconn.ParseConfig(connString)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	var buildStatementCache BuildStatementCacheFunc
   137  	statementCacheCapacity := 512
   138  	statementCacheMode := stmtcache.ModePrepare
   139  	if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok {
   140  		delete(config.RuntimeParams, "statement_cache_capacity")
   141  		n, err := strconv.ParseInt(s, 10, 32)
   142  		if err != nil {
   143  			return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
   144  		}
   145  		statementCacheCapacity = int(n)
   146  	}
   147  
   148  	if s, ok := config.RuntimeParams["statement_cache_mode"]; ok {
   149  		delete(config.RuntimeParams, "statement_cache_mode")
   150  		switch s {
   151  		case "prepare":
   152  			statementCacheMode = stmtcache.ModePrepare
   153  		case "describe":
   154  			statementCacheMode = stmtcache.ModeDescribe
   155  		default:
   156  			return nil, fmt.Errorf("invalid statement_cache_mod: %s", s)
   157  		}
   158  	}
   159  
   160  	if statementCacheCapacity > 0 {
   161  		buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
   162  			return stmtcache.New(conn, statementCacheMode, statementCacheCapacity)
   163  		}
   164  	}
   165  
   166  	preferSimpleProtocol := false
   167  	if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok {
   168  		delete(config.RuntimeParams, "prefer_simple_protocol")
   169  		if b, err := strconv.ParseBool(s); err == nil {
   170  			preferSimpleProtocol = b
   171  		} else {
   172  			return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err)
   173  		}
   174  	}
   175  
   176  	connConfig := &ConnConfig{
   177  		Config:               *config,
   178  		createdByParseConfig: true,
   179  		LogLevel:             LogLevelInfo,
   180  		BuildStatementCache:  buildStatementCache,
   181  		PreferSimpleProtocol: preferSimpleProtocol,
   182  		connString:           connString,
   183  	}
   184  
   185  	return connConfig, nil
   186  }
   187  
   188  func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
   189  	// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
   190  	// zero values.
   191  	if !config.createdByParseConfig {
   192  		panic("config must be created by ParseConfig")
   193  	}
   194  	originalConfig := config
   195  
   196  	// This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting
   197  	// other connections with the same config. See https://github.com/jackc/pgx/issues/618.
   198  	{
   199  		configCopy := *config
   200  		config = &configCopy
   201  	}
   202  
   203  	c = &Conn{
   204  		config:   originalConfig,
   205  		connInfo: pgtype.NewConnInfo(),
   206  		logLevel: config.LogLevel,
   207  		logger:   config.Logger,
   208  	}
   209  
   210  	// Only install pgx notification system if no other callback handler is present.
   211  	if config.Config.OnNotification == nil {
   212  		config.Config.OnNotification = c.bufferNotifications
   213  	} else {
   214  		if c.shouldLog(LogLevelDebug) {
   215  			c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host})
   216  		}
   217  	}
   218  
   219  	if c.shouldLog(LogLevelInfo) {
   220  		c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
   221  	}
   222  	c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
   223  	if err != nil {
   224  		if c.shouldLog(LogLevelError) {
   225  			c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err})
   226  		}
   227  		return nil, err
   228  	}
   229  
   230  	c.preparedStatements = make(map[string]*pgconn.StatementDescription)
   231  	c.doneChan = make(chan struct{})
   232  	c.closedChan = make(chan error)
   233  	c.wbuf = make([]byte, 0, 1024)
   234  
   235  	if c.config.BuildStatementCache != nil {
   236  		c.stmtcache = c.config.BuildStatementCache(c.pgConn)
   237  	}
   238  
   239  	// Replication connections can't execute the queries to
   240  	// populate the c.PgTypes and c.pgsqlAfInet
   241  	if _, ok := config.Config.RuntimeParams["replication"]; ok {
   242  		return c, nil
   243  	}
   244  
   245  	return c, nil
   246  }
   247  
   248  // Close closes a connection. It is safe to call Close on a already closed
   249  // connection.
   250  func (c *Conn) Close(ctx context.Context) error {
   251  	if c.IsClosed() {
   252  		return nil
   253  	}
   254  
   255  	err := c.pgConn.Close(ctx)
   256  	if c.shouldLog(LogLevelInfo) {
   257  		c.log(ctx, LogLevelInfo, "closed connection", nil)
   258  	}
   259  	return err
   260  }
   261  
   262  // Prepare creates a prepared statement with name and sql. sql can contain placeholders
   263  // for bound parameters. These placeholders are referenced positional as $1, $2, etc.
   264  //
   265  // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
   266  // name and sql arguments. This allows a code path to Prepare and Query/Exec without
   267  // concern for if the statement has already been prepared.
   268  func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
   269  	if name != "" {
   270  		var ok bool
   271  		if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
   272  			return sd, nil
   273  		}
   274  	}
   275  
   276  	if c.shouldLog(LogLevelError) {
   277  		defer func() {
   278  			if err != nil {
   279  				c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
   280  			}
   281  		}()
   282  	}
   283  
   284  	sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	if name != "" {
   290  		c.preparedStatements[name] = sd
   291  	}
   292  
   293  	return sd, nil
   294  }
   295  
   296  // Deallocate released a prepared statement
   297  func (c *Conn) Deallocate(ctx context.Context, name string) error {
   298  	delete(c.preparedStatements, name)
   299  	_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
   300  	return err
   301  }
   302  
   303  func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) {
   304  	c.notifications = append(c.notifications, n)
   305  }
   306  
   307  // WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a
   308  // slightly more convenient form.
   309  func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
   310  	var n *pgconn.Notification
   311  
   312  	// Return already received notification immediately
   313  	if len(c.notifications) > 0 {
   314  		n = c.notifications[0]
   315  		c.notifications = c.notifications[1:]
   316  		return n, nil
   317  	}
   318  
   319  	err := c.pgConn.WaitForNotification(ctx)
   320  	if len(c.notifications) > 0 {
   321  		n = c.notifications[0]
   322  		c.notifications = c.notifications[1:]
   323  	}
   324  	return n, err
   325  }
   326  
   327  // IsClosed reports if the connection has been closed.
   328  func (c *Conn) IsClosed() bool {
   329  	return c.pgConn.IsClosed()
   330  }
   331  
   332  func (c *Conn) die(err error) {
   333  	if c.IsClosed() {
   334  		return
   335  	}
   336  
   337  	ctx, cancel := context.WithCancel(context.Background())
   338  	cancel() // force immediate hard cancel
   339  	c.pgConn.Close(ctx)
   340  }
   341  
   342  func (c *Conn) shouldLog(lvl LogLevel) bool {
   343  	return c.logger != nil && c.logLevel >= lvl
   344  }
   345  
   346  func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) {
   347  	if data == nil {
   348  		data = map[string]interface{}{}
   349  	}
   350  	if c.pgConn != nil && c.pgConn.PID() != 0 {
   351  		data["pid"] = c.pgConn.PID()
   352  	}
   353  
   354  	c.logger.Log(ctx, lvl, msg, data)
   355  }
   356  
   357  func quoteIdentifier(s string) string {
   358  	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
   359  }
   360  
   361  // Ping executes an empty sql statement against the *Conn
   362  // If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
   363  func (c *Conn) Ping(ctx context.Context) error {
   364  	_, err := c.Exec(ctx, ";")
   365  	return err
   366  }
   367  
   368  // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
   369  // PostgreSQL connection than pgx exposes.
   370  //
   371  // It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn
   372  // is used and the connection must be returned to the same state before any *pgx.Conn methods are again used.
   373  func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }
   374  
   375  // StatementCache returns the statement cache used for this connection.
   376  func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache }
   377  
   378  // ConnInfo returns the connection info used for this connection.
   379  func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo }
   380  
   381  // Config returns a copy of config that was used to establish this connection.
   382  func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
   383  
   384  // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
   385  // positionally from the sql string as $1, $2, etc.
   386  func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
   387  	startTime := time.Now()
   388  
   389  	commandTag, err := c.exec(ctx, sql, arguments...)
   390  	if err != nil {
   391  		if c.shouldLog(LogLevelError) {
   392  			endTime := time.Now()
   393  			c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)})
   394  		}
   395  		return commandTag, err
   396  	}
   397  
   398  	if c.shouldLog(LogLevelInfo) {
   399  		endTime := time.Now()
   400  		c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
   401  	}
   402  
   403  	return commandTag, err
   404  }
   405  
   406  func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
   407  	simpleProtocol := c.config.PreferSimpleProtocol
   408  
   409  optionLoop:
   410  	for len(arguments) > 0 {
   411  		switch arg := arguments[0].(type) {
   412  		case QuerySimpleProtocol:
   413  			simpleProtocol = bool(arg)
   414  			arguments = arguments[1:]
   415  		default:
   416  			break optionLoop
   417  		}
   418  	}
   419  
   420  	if sd, ok := c.preparedStatements[sql]; ok {
   421  		return c.execPrepared(ctx, sd, arguments)
   422  	}
   423  
   424  	if simpleProtocol {
   425  		return c.execSimpleProtocol(ctx, sql, arguments)
   426  	}
   427  
   428  	if len(arguments) == 0 {
   429  		return c.execSimpleProtocol(ctx, sql, arguments)
   430  	}
   431  
   432  	if c.stmtcache != nil {
   433  		sd, err := c.stmtcache.Get(ctx, sql)
   434  		if err != nil {
   435  			return nil, err
   436  		}
   437  
   438  		if c.stmtcache.Mode() == stmtcache.ModeDescribe {
   439  			return c.execParams(ctx, sd, arguments)
   440  		}
   441  		return c.execPrepared(ctx, sd, arguments)
   442  	}
   443  
   444  	sd, err := c.Prepare(ctx, "", sql)
   445  	if err != nil {
   446  		return nil, err
   447  	}
   448  	return c.execPrepared(ctx, sd, arguments)
   449  }
   450  
   451  func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
   452  	if len(arguments) > 0 {
   453  		sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
   454  		if err != nil {
   455  			return nil, err
   456  		}
   457  	}
   458  
   459  	mrr := c.pgConn.Exec(ctx, sql)
   460  	for mrr.NextResult() {
   461  		commandTag, err = mrr.ResultReader().Close()
   462  	}
   463  	err = mrr.Close()
   464  	return commandTag, err
   465  }
   466  
   467  func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error {
   468  	if len(sd.ParamOIDs) != len(arguments) {
   469  		return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments))
   470  	}
   471  
   472  	c.eqb.Reset()
   473  
   474  	args, err := convertDriverValuers(arguments)
   475  	if err != nil {
   476  		return err
   477  	}
   478  
   479  	for i := range args {
   480  		err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i])
   481  		if err != nil {
   482  			return err
   483  		}
   484  	}
   485  
   486  	for i := range sd.Fields {
   487  		c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
   488  	}
   489  
   490  	return nil
   491  }
   492  
   493  func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
   494  	err := c.execParamsAndPreparedPrefix(sd, arguments)
   495  	if err != nil {
   496  		return nil, err
   497  	}
   498  
   499  	result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
   500  	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   501  	return result.CommandTag, result.Err
   502  }
   503  
   504  func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
   505  	err := c.execParamsAndPreparedPrefix(sd, arguments)
   506  	if err != nil {
   507  		return nil, err
   508  	}
   509  
   510  	result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
   511  	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   512  	return result.CommandTag, result.Err
   513  }
   514  
   515  func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
   516  	r := &connRows{}
   517  
   518  	r.ctx = ctx
   519  	r.logger = c
   520  	r.connInfo = c.connInfo
   521  	r.startTime = time.Now()
   522  	r.sql = sql
   523  	r.args = args
   524  	r.conn = c
   525  
   526  	return r
   527  }
   528  
   529  // QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
   530  type QuerySimpleProtocol bool
   531  
   532  // QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
   533  type QueryResultFormats []int16
   534  
   535  // QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
   536  type QueryResultFormatsByOID map[uint32]int16
   537  
   538  // Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
   539  // and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to
   540  // determine if the query executed successfully.
   541  //
   542  // The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the
   543  // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It
   544  // is allowed to ignore the error returned from Query and handle it in Rows.
   545  //
   546  // Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully
   547  // as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row.
   548  //
   549  // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and
   550  // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
   551  // needed. See the documentation for those types for details.
   552  func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
   553  	var resultFormats QueryResultFormats
   554  	var resultFormatsByOID QueryResultFormatsByOID
   555  	simpleProtocol := c.config.PreferSimpleProtocol
   556  
   557  optionLoop:
   558  	for len(args) > 0 {
   559  		switch arg := args[0].(type) {
   560  		case QueryResultFormats:
   561  			resultFormats = arg
   562  			args = args[1:]
   563  		case QueryResultFormatsByOID:
   564  			resultFormatsByOID = arg
   565  			args = args[1:]
   566  		case QuerySimpleProtocol:
   567  			simpleProtocol = bool(arg)
   568  			args = args[1:]
   569  		default:
   570  			break optionLoop
   571  		}
   572  	}
   573  
   574  	rows := c.getRows(ctx, sql, args)
   575  
   576  	var err error
   577  	sd, ok := c.preparedStatements[sql]
   578  
   579  	if simpleProtocol && !ok {
   580  		sql, err = c.sanitizeForSimpleQuery(sql, args...)
   581  		if err != nil {
   582  			rows.fatal(err)
   583  			return rows, err
   584  		}
   585  
   586  		mrr := c.pgConn.Exec(ctx, sql)
   587  		if mrr.NextResult() {
   588  			rows.resultReader = mrr.ResultReader()
   589  			rows.multiResultReader = mrr
   590  		} else {
   591  			err = mrr.Close()
   592  			rows.fatal(err)
   593  			return rows, err
   594  		}
   595  
   596  		return rows, nil
   597  	}
   598  
   599  	c.eqb.Reset()
   600  
   601  	if !ok {
   602  		if c.stmtcache != nil {
   603  			sd, err = c.stmtcache.Get(ctx, sql)
   604  			if err != nil {
   605  				rows.fatal(err)
   606  				return rows, rows.err
   607  			}
   608  		} else {
   609  			sd, err = c.pgConn.Prepare(ctx, "", sql, nil)
   610  			if err != nil {
   611  				rows.fatal(err)
   612  				return rows, rows.err
   613  			}
   614  		}
   615  	}
   616  	if len(sd.ParamOIDs) != len(args) {
   617  		rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
   618  		return rows, rows.err
   619  	}
   620  
   621  	rows.sql = sd.SQL
   622  
   623  	args, err = convertDriverValuers(args)
   624  	if err != nil {
   625  		rows.fatal(err)
   626  		return rows, rows.err
   627  	}
   628  
   629  	for i := range args {
   630  		err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i])
   631  		if err != nil {
   632  			rows.fatal(err)
   633  			return rows, rows.err
   634  		}
   635  	}
   636  
   637  	if resultFormatsByOID != nil {
   638  		resultFormats = make([]int16, len(sd.Fields))
   639  		for i := range resultFormats {
   640  			resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
   641  		}
   642  	}
   643  
   644  	if resultFormats == nil {
   645  		for i := range sd.Fields {
   646  			c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
   647  		}
   648  
   649  		resultFormats = c.eqb.resultFormats
   650  	}
   651  
   652  	if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe && !ok {
   653  		rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
   654  	} else {
   655  		rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
   656  	}
   657  
   658  	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   659  
   660  	return rows, rows.err
   661  }
   662  
   663  // QueryRow is a convenience wrapper over Query. Any error that occurs while
   664  // querying is deferred until calling Scan on the returned Row. That Row will
   665  // error with ErrNoRows if no rows are returned.
   666  func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
   667  	rows, _ := c.Query(ctx, sql, args...)
   668  	return (*connRow)(rows.(*connRows))
   669  }
   670  
   671  // QueryFuncRow is the argument to the QueryFunc callback function.
   672  //
   673  // QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an
   674  // interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from
   675  // semantic version requirements. Methods will not be removed or changed, but new methods may be added.
   676  type QueryFuncRow interface {
   677  	FieldDescriptions() []pgproto3.FieldDescription
   678  
   679  	// RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current
   680  	// function call. However, the underlying byte data is safe to retain a reference to and mutate.
   681  	RawValues() [][]byte
   682  }
   683  
   684  // QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of
   685  // scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error
   686  // will be returned.
   687  func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
   688  	rows, err := c.Query(ctx, sql, args...)
   689  	if err != nil {
   690  		return nil, err
   691  	}
   692  	defer rows.Close()
   693  
   694  	for rows.Next() {
   695  		err = rows.Scan(scans...)
   696  		if err != nil {
   697  			return nil, err
   698  		}
   699  
   700  		err = f(rows)
   701  		if err != nil {
   702  			return nil, err
   703  		}
   704  	}
   705  
   706  	if err := rows.Err(); err != nil {
   707  		return nil, err
   708  	}
   709  
   710  	return rows.CommandTag(), nil
   711  }
   712  
   713  // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
   714  // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
   715  // is used again.
   716  func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
   717  	startTime := time.Now()
   718  
   719  	simpleProtocol := c.config.PreferSimpleProtocol
   720  	var sb strings.Builder
   721  	if simpleProtocol {
   722  		for i, bi := range b.items {
   723  			if i > 0 {
   724  				sb.WriteByte(';')
   725  			}
   726  			sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
   727  			if err != nil {
   728  				return &batchResults{ctx: ctx, conn: c, err: err}
   729  			}
   730  			sb.WriteString(sql)
   731  		}
   732  		mrr := c.pgConn.Exec(ctx, sb.String())
   733  		return &batchResults{
   734  			ctx:  ctx,
   735  			conn: c,
   736  			mrr:  mrr,
   737  			b:    b,
   738  			ix:   0,
   739  		}
   740  	}
   741  
   742  	distinctUnpreparedQueries := map[string]struct{}{}
   743  
   744  	for _, bi := range b.items {
   745  		if _, ok := c.preparedStatements[bi.query]; ok {
   746  			continue
   747  		}
   748  		distinctUnpreparedQueries[bi.query] = struct{}{}
   749  	}
   750  
   751  	var stmtCache stmtcache.Cache
   752  	if len(distinctUnpreparedQueries) > 0 {
   753  		if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) {
   754  			stmtCache = c.stmtcache
   755  		} else {
   756  			stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
   757  		}
   758  
   759  		for sql, _ := range distinctUnpreparedQueries {
   760  			_, err := stmtCache.Get(ctx, sql)
   761  			if err != nil {
   762  				return &batchResults{ctx: ctx, conn: c, err: err}
   763  			}
   764  		}
   765  	}
   766  
   767  	batch := &pgconn.Batch{}
   768  
   769  	for _, bi := range b.items {
   770  		c.eqb.Reset()
   771  
   772  		sd := c.preparedStatements[bi.query]
   773  		if sd == nil {
   774  			var err error
   775  			sd, err = stmtCache.Get(ctx, bi.query)
   776  			if err != nil {
   777  				return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err})
   778  			}
   779  		}
   780  
   781  		if len(sd.ParamOIDs) != len(bi.arguments) {
   782  			return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")})
   783  		}
   784  
   785  		args, err := convertDriverValuers(bi.arguments)
   786  		if err != nil {
   787  			return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err})
   788  		}
   789  
   790  		for i := range args {
   791  			err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i])
   792  			if err != nil {
   793  				return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err})
   794  			}
   795  		}
   796  
   797  		for i := range sd.Fields {
   798  			c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
   799  		}
   800  
   801  		if sd.Name == "" {
   802  			batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
   803  		} else {
   804  			batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
   805  		}
   806  	}
   807  
   808  	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
   809  
   810  	mrr := c.pgConn.ExecBatch(ctx, batch)
   811  
   812  	return c.logBatchResults(ctx, startTime, &batchResults{
   813  		ctx:  ctx,
   814  		conn: c,
   815  		mrr:  mrr,
   816  		b:    b,
   817  		ix:   0,
   818  	})
   819  }
   820  
   821  func (c *Conn) logBatchResults(ctx context.Context, startTime time.Time, results *batchResults) BatchResults {
   822  	if results.err != nil {
   823  		if c.shouldLog(LogLevelError) {
   824  			endTime := time.Now()
   825  			c.log(ctx, LogLevelError, "SendBatch", map[string]interface{}{"err": results.err, "time": endTime.Sub(startTime)})
   826  		}
   827  		return results
   828  	}
   829  
   830  	if c.shouldLog(LogLevelInfo) {
   831  		endTime := time.Now()
   832  		c.log(ctx, LogLevelInfo, "SendBatch", map[string]interface{}{"batchLen": results.b.Len(), "time": endTime.Sub(startTime)})
   833  	}
   834  
   835  	return results
   836  }
   837  
   838  func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
   839  	if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
   840  		return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
   841  	}
   842  
   843  	if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
   844  		return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
   845  	}
   846  
   847  	var err error
   848  	valueArgs := make([]interface{}, len(args))
   849  	for i, a := range args {
   850  		valueArgs[i], err = convertSimpleArgument(c.connInfo, a)
   851  		if err != nil {
   852  			return "", err
   853  		}
   854  	}
   855  
   856  	return sanitize.SanitizeSQL(sql, valueArgs...)
   857  }
   858  

View as plain text