...

Source file src/github.com/jackc/pgx/v5/rows.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/jackc/pgx/v5/pgconn"
    12  	"github.com/jackc/pgx/v5/pgtype"
    13  )
    14  
    15  // Rows is the result set returned from *Conn.Query. Rows must be closed before
    16  // the *Conn can be used again. Rows are closed by explicitly calling Close(),
    17  // calling Next() until it returns false, or when a fatal error occurs.
    18  //
    19  // Once a Rows is closed the only methods that may be called are Close(), Err(),
    20  // and CommandTag().
    21  //
    22  // Rows is an interface instead of a struct to allow tests to mock Query. However,
    23  // adding a method to an interface is technically a breaking change. Because of this
    24  // the Rows interface is partially excluded from semantic version requirements.
    25  // Methods will not be removed or changed, but new methods may be added.
    26  type Rows interface {
    27  	// Close closes the rows, making the connection ready for use again. It is safe
    28  	// to call Close after rows is already closed.
    29  	Close()
    30  
    31  	// Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by
    32  	// calling Close or by Next returning false). If it is called early it may return nil even if there was an error
    33  	// executing the query.
    34  	Err() error
    35  
    36  	// CommandTag returns the command tag from this query. It is only available after Rows is closed.
    37  	CommandTag() pgconn.CommandTag
    38  
    39  	// FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur
    40  	// when there was an error executing the query.
    41  	FieldDescriptions() []pgconn.FieldDescription
    42  
    43  	// Next prepares the next row for reading. It returns true if there is another
    44  	// row and false if no more rows are available or a fatal error has occurred.
    45  	// It automatically closes rows when all rows are read.
    46  	//
    47  	// Callers should check rows.Err() after rows.Next() returns false to detect
    48  	// whether result-set reading ended prematurely due to an error. See
    49  	// Conn.Query for details.
    50  	//
    51  	// For simpler error handling, consider using the higher-level pgx v5
    52  	// CollectRows() and ForEachRow() helpers instead.
    53  	Next() bool
    54  
    55  	// Scan reads the values from the current row into dest values positionally.
    56  	// dest can include pointers to core types, values implementing the Scanner
    57  	// interface, and nil. nil will skip the value entirely. It is an error to
    58  	// call Scan without first calling Next() and checking that it returned true.
    59  	Scan(dest ...any) error
    60  
    61  	// Values returns the decoded row values. As with Scan(), it is an error to
    62  	// call Values without first calling Next() and checking that it returned
    63  	// true.
    64  	Values() ([]any, error)
    65  
    66  	// RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next
    67  	// call or the Rows is closed.
    68  	RawValues() [][]byte
    69  
    70  	// Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a
    71  	// *Conn (e.g. if it was created by RowsFromResultReader)
    72  	Conn() *Conn
    73  }
    74  
    75  // Row is a convenience wrapper over Rows that is returned by QueryRow.
    76  //
    77  // Row is an interface instead of a struct to allow tests to mock QueryRow. However,
    78  // adding a method to an interface is technically a breaking change. Because of this
    79  // the Row interface is partially excluded from semantic version requirements.
    80  // Methods will not be removed or changed, but new methods may be added.
    81  type Row interface {
    82  	// Scan works the same as Rows. with the following exceptions. If no
    83  	// rows were found it returns ErrNoRows. If multiple rows are returned it
    84  	// ignores all but the first.
    85  	Scan(dest ...any) error
    86  }
    87  
    88  // RowScanner scans an entire row at a time into the RowScanner.
    89  type RowScanner interface {
    90  	// ScanRows scans the row.
    91  	ScanRow(rows Rows) error
    92  }
    93  
    94  // connRow implements the Row interface for Conn.QueryRow.
    95  type connRow baseRows
    96  
    97  func (r *connRow) Scan(dest ...any) (err error) {
    98  	rows := (*baseRows)(r)
    99  
   100  	if rows.Err() != nil {
   101  		return rows.Err()
   102  	}
   103  
   104  	for _, d := range dest {
   105  		if _, ok := d.(*pgtype.DriverBytes); ok {
   106  			rows.Close()
   107  			return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
   108  		}
   109  	}
   110  
   111  	if !rows.Next() {
   112  		if rows.Err() == nil {
   113  			return ErrNoRows
   114  		}
   115  		return rows.Err()
   116  	}
   117  
   118  	rows.Scan(dest...)
   119  	rows.Close()
   120  	return rows.Err()
   121  }
   122  
// baseRows implements the Rows interface for Conn.Query. RowsFromResultReader also returns one, and connRow (the
// QueryRow implementation) is a type conversion of it.
type baseRows struct {
	// typeMap is used to plan scans and decode values; resultReader supplies the field descriptions and rows.
	// resultReader may be nil; Close checks for that.
	typeMap      *pgtype.Map
	resultReader *pgconn.ResultReader

	// values holds the raw bytes of the current row, as set by Next.
	values [][]byte

	// commandTag is captured from resultReader by Close. err is the first error recorded (in this file, by fatal
	// or while closing) and is what Err returns. closed is set by the first Close; it makes later Close calls
	// no-ops and causes Next to return false.
	commandTag pgconn.CommandTag
	err        error
	closed     bool

	// scanPlans caches one scan plan per column, and scanTypes records the destination type each plan was built
	// for; Scan rebuilds a column's plan when the destination type changes.
	scanPlans []pgtype.ScanPlan
	scanTypes []reflect.Type

	// conn, when non-nil, is returned by Conn, passed to the tracers, and has its statement and description caches
	// invalidated for sql when the rows finish with an error. multiResultReader, when non-nil, is closed by Close
	// after resultReader.
	conn              *Conn
	multiResultReader *pgconn.MultiResultReader

	// Tracing state used by Close: batchTracer (if set) receives sql, args, the command tag and err; otherwise
	// queryTracer (if set) receives the command tag and err. ctx is passed to whichever tracer is called.
	// startTime is not read in this file and rowCount is only incremented (by Next); presumably both are consumed
	// elsewhere in the package (TODO confirm).
	queryTracer QueryTracer
	batchTracer BatchTracer
	ctx         context.Context
	startTime   time.Time
	sql         string
	args        []any
	rowCount    int
}
   148  
   149  func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription {
   150  	return rows.resultReader.FieldDescriptions()
   151  }
   152  
   153  func (rows *baseRows) Close() {
   154  	if rows.closed {
   155  		return
   156  	}
   157  
   158  	rows.closed = true
   159  
   160  	if rows.resultReader != nil {
   161  		var closeErr error
   162  		rows.commandTag, closeErr = rows.resultReader.Close()
   163  		if rows.err == nil {
   164  			rows.err = closeErr
   165  		}
   166  	}
   167  
   168  	if rows.multiResultReader != nil {
   169  		closeErr := rows.multiResultReader.Close()
   170  		if rows.err == nil {
   171  			rows.err = closeErr
   172  		}
   173  	}
   174  
   175  	if rows.err != nil && rows.conn != nil && rows.sql != "" {
   176  		if sc := rows.conn.statementCache; sc != nil {
   177  			sc.Invalidate(rows.sql)
   178  		}
   179  
   180  		if sc := rows.conn.descriptionCache; sc != nil {
   181  			sc.Invalidate(rows.sql)
   182  		}
   183  	}
   184  
   185  	if rows.batchTracer != nil {
   186  		rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
   187  	} else if rows.queryTracer != nil {
   188  		rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
   189  	}
   190  }
   191  
   192  func (rows *baseRows) CommandTag() pgconn.CommandTag {
   193  	return rows.commandTag
   194  }
   195  
   196  func (rows *baseRows) Err() error {
   197  	return rows.err
   198  }
   199  
   200  // fatal signals an error occurred after the query was sent to the server. It
   201  // closes the rows automatically.
   202  func (rows *baseRows) fatal(err error) {
   203  	if rows.err != nil {
   204  		return
   205  	}
   206  
   207  	rows.err = err
   208  	rows.Close()
   209  }
   210  
   211  func (rows *baseRows) Next() bool {
   212  	if rows.closed {
   213  		return false
   214  	}
   215  
   216  	if rows.resultReader.NextRow() {
   217  		rows.rowCount++
   218  		rows.values = rows.resultReader.Values()
   219  		return true
   220  	} else {
   221  		rows.Close()
   222  		return false
   223  	}
   224  }
   225  
   226  func (rows *baseRows) Scan(dest ...any) error {
   227  	m := rows.typeMap
   228  	fieldDescriptions := rows.FieldDescriptions()
   229  	values := rows.values
   230  
   231  	if len(fieldDescriptions) != len(values) {
   232  		err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
   233  		rows.fatal(err)
   234  		return err
   235  	}
   236  
   237  	if len(dest) == 1 {
   238  		if rc, ok := dest[0].(RowScanner); ok {
   239  			err := rc.ScanRow(rows)
   240  			if err != nil {
   241  				rows.fatal(err)
   242  			}
   243  			return err
   244  		}
   245  	}
   246  
   247  	if len(fieldDescriptions) != len(dest) {
   248  		err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
   249  		rows.fatal(err)
   250  		return err
   251  	}
   252  
   253  	if rows.scanPlans == nil {
   254  		rows.scanPlans = make([]pgtype.ScanPlan, len(values))
   255  		rows.scanTypes = make([]reflect.Type, len(values))
   256  		for i := range dest {
   257  			rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
   258  			rows.scanTypes[i] = reflect.TypeOf(dest[i])
   259  		}
   260  	}
   261  
   262  	for i, dst := range dest {
   263  		if dst == nil {
   264  			continue
   265  		}
   266  
   267  		if rows.scanTypes[i] != reflect.TypeOf(dst) {
   268  			rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
   269  			rows.scanTypes[i] = reflect.TypeOf(dest[i])
   270  		}
   271  
   272  		err := rows.scanPlans[i].Scan(values[i], dst)
   273  		if err != nil {
   274  			err = ScanArgError{ColumnIndex: i, Err: err}
   275  			rows.fatal(err)
   276  			return err
   277  		}
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  func (rows *baseRows) Values() ([]any, error) {
   284  	if rows.closed {
   285  		return nil, errors.New("rows is closed")
   286  	}
   287  
   288  	values := make([]any, 0, len(rows.FieldDescriptions()))
   289  
   290  	for i := range rows.FieldDescriptions() {
   291  		buf := rows.values[i]
   292  		fd := &rows.FieldDescriptions()[i]
   293  
   294  		if buf == nil {
   295  			values = append(values, nil)
   296  			continue
   297  		}
   298  
   299  		if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok {
   300  			value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf)
   301  			if err != nil {
   302  				rows.fatal(err)
   303  			}
   304  			values = append(values, value)
   305  		} else {
   306  			switch fd.Format {
   307  			case TextFormatCode:
   308  				values = append(values, string(buf))
   309  			case BinaryFormatCode:
   310  				newBuf := make([]byte, len(buf))
   311  				copy(newBuf, buf)
   312  				values = append(values, newBuf)
   313  			default:
   314  				rows.fatal(errors.New("unknown format code"))
   315  			}
   316  		}
   317  
   318  		if rows.Err() != nil {
   319  			return nil, rows.Err()
   320  		}
   321  	}
   322  
   323  	return values, rows.Err()
   324  }
   325  
   326  func (rows *baseRows) RawValues() [][]byte {
   327  	return rows.values
   328  }
   329  
   330  func (rows *baseRows) Conn() *Conn {
   331  	return rows.conn
   332  }
   333  
   334  type ScanArgError struct {
   335  	ColumnIndex int
   336  	Err         error
   337  }
   338  
   339  func (e ScanArgError) Error() string {
   340  	return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
   341  }
   342  
   343  func (e ScanArgError) Unwrap() error {
   344  	return e.Err
   345  }
   346  
   347  // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface.
   348  //
   349  // typeMap - OID to Go type mapping.
   350  // fieldDescriptions - OID and format of values
   351  // values - the raw data as returned from the PostgreSQL server
   352  // dest - the destination that values will be decoded into
   353  func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error {
   354  	if len(fieldDescriptions) != len(values) {
   355  		return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
   356  	}
   357  	if len(fieldDescriptions) != len(dest) {
   358  		return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
   359  	}
   360  
   361  	for i, d := range dest {
   362  		if d == nil {
   363  			continue
   364  		}
   365  
   366  		err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
   367  		if err != nil {
   368  			return ScanArgError{ColumnIndex: i, Err: err}
   369  		}
   370  	}
   371  
   372  	return nil
   373  }
   374  
   375  // RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used
   376  // to read from the lower level pgconn interface.
   377  func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows {
   378  	return &baseRows{
   379  		typeMap:      typeMap,
   380  		resultReader: resultReader,
   381  	}
   382  }
   383  
   384  // ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row
   385  // fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed
   386  // when ForEachRow returns.
   387  func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) {
   388  	defer rows.Close()
   389  
   390  	for rows.Next() {
   391  		err := rows.Scan(scans...)
   392  		if err != nil {
   393  			return pgconn.CommandTag{}, err
   394  		}
   395  
   396  		err = fn()
   397  		if err != nil {
   398  			return pgconn.CommandTag{}, err
   399  		}
   400  	}
   401  
   402  	if err := rows.Err(); err != nil {
   403  		return pgconn.CommandTag{}, err
   404  	}
   405  
   406  	return rows.CommandTag(), nil
   407  }
   408  
   409  // CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call.
   410  type CollectableRow interface {
   411  	FieldDescriptions() []pgconn.FieldDescription
   412  	Scan(dest ...any) error
   413  	Values() ([]any, error)
   414  	RawValues() [][]byte
   415  }
   416  
   417  // RowToFunc is a function that scans or otherwise converts row to a T.
   418  type RowToFunc[T any] func(row CollectableRow) (T, error)
   419  
   420  // AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
   421  func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
   422  	defer rows.Close()
   423  
   424  	for rows.Next() {
   425  		value, err := fn(rows)
   426  		if err != nil {
   427  			return nil, err
   428  		}
   429  		slice = append(slice, value)
   430  	}
   431  
   432  	if err := rows.Err(); err != nil {
   433  		return nil, err
   434  	}
   435  
   436  	return slice, nil
   437  }
   438  
   439  // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
   440  func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
   441  	return AppendRows([]T{}, rows, fn)
   442  }
   443  
   444  // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
   445  // CollectOneRow is to CollectRows as QueryRow is to Query.
   446  func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
   447  	defer rows.Close()
   448  
   449  	var value T
   450  	var err error
   451  
   452  	if !rows.Next() {
   453  		if err = rows.Err(); err != nil {
   454  			return value, err
   455  		}
   456  		return value, ErrNoRows
   457  	}
   458  
   459  	value, err = fn(rows)
   460  	if err != nil {
   461  		return value, err
   462  	}
   463  
   464  	rows.Close()
   465  	return value, rows.Err()
   466  }
   467  
   468  // CollectExactlyOneRow calls fn for the first row in rows and returns the result.
   469  //   - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
   470  //   - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
   471  func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
   472  	defer rows.Close()
   473  
   474  	var (
   475  		err   error
   476  		value T
   477  	)
   478  
   479  	if !rows.Next() {
   480  		if err = rows.Err(); err != nil {
   481  			return value, err
   482  		}
   483  
   484  		return value, ErrNoRows
   485  	}
   486  
   487  	value, err = fn(rows)
   488  	if err != nil {
   489  		return value, err
   490  	}
   491  
   492  	if rows.Next() {
   493  		var zero T
   494  
   495  		return zero, ErrTooManyRows
   496  	}
   497  
   498  	return value, rows.Err()
   499  }
   500  
   501  // RowTo returns a T scanned from row.
   502  func RowTo[T any](row CollectableRow) (T, error) {
   503  	var value T
   504  	err := row.Scan(&value)
   505  	return value, err
   506  }
   507  
   508  // RowTo returns a the address of a T scanned from row.
   509  func RowToAddrOf[T any](row CollectableRow) (*T, error) {
   510  	var value T
   511  	err := row.Scan(&value)
   512  	return &value, err
   513  }
   514  
   515  // RowToMap returns a map scanned from row.
   516  func RowToMap(row CollectableRow) (map[string]any, error) {
   517  	var value map[string]any
   518  	err := row.Scan((*mapRowScanner)(&value))
   519  	return value, err
   520  }
   521  
   522  type mapRowScanner map[string]any
   523  
   524  func (rs *mapRowScanner) ScanRow(rows Rows) error {
   525  	values, err := rows.Values()
   526  	if err != nil {
   527  		return err
   528  	}
   529  
   530  	*rs = make(mapRowScanner, len(values))
   531  
   532  	for i := range values {
   533  		(*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i]
   534  	}
   535  
   536  	return nil
   537  }
   538  
   539  // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
   540  // has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
   541  // ignored.
   542  func RowToStructByPos[T any](row CollectableRow) (T, error) {
   543  	var value T
   544  	err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
   545  	return value, err
   546  }
   547  
   548  // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
   549  // public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then
   550  // the field will be ignored.
   551  func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
   552  	var value T
   553  	err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
   554  	return &value, err
   555  }
   556  
   557  type positionalStructRowScanner struct {
   558  	ptrToStruct any
   559  }
   560  
   561  func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
   562  	dst := rs.ptrToStruct
   563  	dstValue := reflect.ValueOf(dst)
   564  	if dstValue.Kind() != reflect.Ptr {
   565  		return fmt.Errorf("dst not a pointer")
   566  	}
   567  
   568  	dstElemValue := dstValue.Elem()
   569  	scanTargets := rs.appendScanTargets(dstElemValue, nil)
   570  
   571  	if len(rows.RawValues()) > len(scanTargets) {
   572  		return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
   573  	}
   574  
   575  	return rows.Scan(scanTargets...)
   576  }
   577  
   578  func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
   579  	dstElemType := dstElemValue.Type()
   580  
   581  	if scanTargets == nil {
   582  		scanTargets = make([]any, 0, dstElemType.NumField())
   583  	}
   584  
   585  	for i := 0; i < dstElemType.NumField(); i++ {
   586  		sf := dstElemType.Field(i)
   587  		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
   588  		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
   589  			scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
   590  		} else if sf.PkgPath == "" {
   591  			dbTag, _ := sf.Tag.Lookup(structTagKey)
   592  			if dbTag == "-" {
   593  				// Field is ignored, skip it.
   594  				continue
   595  			}
   596  			scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
   597  		}
   598  	}
   599  
   600  	return scanTargets
   601  }
   602  
   603  // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
   604  // fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
   605  // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
   606  func RowToStructByName[T any](row CollectableRow) (T, error) {
   607  	var value T
   608  	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
   609  	return value, err
   610  }
   611  
   612  // RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number
   613  // of named public fields as row has fields. The row and T fields will be matched by name. The match is
   614  // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
   615  // then the field will be ignored.
   616  func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
   617  	var value T
   618  	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
   619  	return &value, err
   620  }
   621  
   622  // RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public
   623  // fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database
   624  // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
   625  func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
   626  	var value T
   627  	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
   628  	return value, err
   629  }
   630  
   631  // RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or
   632  // equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is
   633  // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
   634  // then the field will be ignored.
   635  func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
   636  	var value T
   637  	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
   638  	return &value, err
   639  }
   640  
   641  type namedStructRowScanner struct {
   642  	ptrToStruct any
   643  	lax         bool
   644  }
   645  
   646  func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
   647  	dst := rs.ptrToStruct
   648  	dstValue := reflect.ValueOf(dst)
   649  	if dstValue.Kind() != reflect.Ptr {
   650  		return fmt.Errorf("dst not a pointer")
   651  	}
   652  
   653  	dstElemValue := dstValue.Elem()
   654  	scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
   655  	if err != nil {
   656  		return err
   657  	}
   658  
   659  	for i, t := range scanTargets {
   660  		if t == nil {
   661  			return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
   662  		}
   663  	}
   664  
   665  	return rows.Scan(scanTargets...)
   666  }
   667  
   668  const structTagKey = "db"
   669  
   670  func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
   671  	i = -1
   672  	for i, desc := range fldDescs {
   673  
   674  		// Snake case support.
   675  		field = strings.ReplaceAll(field, "_", "")
   676  		descName := strings.ReplaceAll(desc.Name, "_", "")
   677  
   678  		if strings.EqualFold(descName, field) {
   679  			return i
   680  		}
   681  	}
   682  	return
   683  }
   684  
   685  func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
   686  	var err error
   687  	dstElemType := dstElemValue.Type()
   688  
   689  	if scanTargets == nil {
   690  		scanTargets = make([]any, len(fldDescs))
   691  	}
   692  
   693  	for i := 0; i < dstElemType.NumField(); i++ {
   694  		sf := dstElemType.Field(i)
   695  		if sf.PkgPath != "" && !sf.Anonymous {
   696  			// Field is unexported, skip it.
   697  			continue
   698  		}
   699  		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
   700  		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
   701  			scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
   702  			if err != nil {
   703  				return nil, err
   704  			}
   705  		} else {
   706  			dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
   707  			if dbTagPresent {
   708  				dbTag, _, _ = strings.Cut(dbTag, ",")
   709  			}
   710  			if dbTag == "-" {
   711  				// Field is ignored, skip it.
   712  				continue
   713  			}
   714  			colName := dbTag
   715  			if !dbTagPresent {
   716  				colName = sf.Name
   717  			}
   718  			fpos := fieldPosByName(fldDescs, colName)
   719  			if fpos == -1 {
   720  				if rs.lax {
   721  					continue
   722  				}
   723  				return nil, fmt.Errorf("cannot find field %s in returned row", colName)
   724  			}
   725  			if fpos >= len(scanTargets) && !rs.lax {
   726  				return nil, fmt.Errorf("cannot find field %s in returned row", colName)
   727  			}
   728  			scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
   729  		}
   730  	}
   731  
   732  	return scanTargets, err
   733  }
   734  

View as plain text