1 package pgx
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "reflect"
8 "strings"
9 "time"
10
11 "github.com/jackc/pgx/v5/pgconn"
12 "github.com/jackc/pgx/v5/pgtype"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
// Rows is the result set of a query, read one row at a time. The usual
// pattern is to call Next in a loop, read each row with Scan, Values or
// RawValues, and then check Err once Next returns false. Every helper in this
// file that consumes a Rows (ForEachRow, AppendRows, CollectRows,
// CollectOneRow, CollectExactlyOneRow) closes it before returning.
type Rows interface {
	// Close closes the rows. In the baseRows implementation it also records
	// the command tag and any error from closing the underlying reader, and
	// calling it again after the rows are closed is a no-op.
	Close()

	// Err returns the error, if any, encountered while reading, scanning or
	// closing the rows. Check it after Next returns false.
	Err() error

	// CommandTag returns the command tag of the query. NOTE(review): baseRows
	// only populates it in Close, so it is meaningful once the rows are closed.
	CommandTag() pgconn.CommandTag

	// FieldDescriptions returns the column metadata of the result.
	FieldDescriptions() []pgconn.FieldDescription

	// Next advances to the next row and reports whether one is available.
	// It returns false when the rows are exhausted, already closed, or an
	// error occurred (check Err to tell these apart); baseRows closes itself
	// when Next returns false.
	Next() bool

	// Scan reads the values of the current row into dest, positionally: dest
	// must have one entry per column, and nil entries are skipped. As a
	// special case, a single RowScanner destination is handed the whole row
	// via its ScanRow method.
	Scan(dest ...any) error

	// Values returns the values of the current row decoded into Go values.
	Values() ([]any, error)

	// RawValues returns the undecoded bytes of each column of the current row.
	// NOTE(review): in baseRows these slices come straight from the result
	// reader (Values copies them before returning them), so treat them as
	// valid only until the next call to Next or Close — confirm.
	RawValues() [][]byte

	// Conn returns the *Conn the rows were read from, or nil if there is none
	// (for example, rows built by RowsFromResultReader).
	Conn() *Conn
}
74
75
76
77
78
79
80
// Row is a single row of a query result. connRow, whose error messages refer
// to QueryRow, is the implementation in this file.
type Row interface {
	// Scan works like Rows.Scan on the first row of the result. In connRow it
	// returns ErrNoRows when the result has no rows, ignores any rows after
	// the first, and closes the underlying rows once a row has been read.
	Scan(dest ...any) error
}
87
88
// RowScanner is implemented by destinations that scan an entire row
// themselves. baseRows.Scan calls ScanRow when a RowScanner is its only
// destination.
type RowScanner interface {
	// ScanRow scans the current row of rows into the receiver.
	ScanRow(rows Rows) error
}
93
94
// connRow is the Row implementation backed by baseRows. It has the same
// underlying type as baseRows, so its methods convert the receiver with
// (*baseRows)(r) and reuse the baseRows methods.
type connRow baseRows
96
97 func (r *connRow) Scan(dest ...any) (err error) {
98 rows := (*baseRows)(r)
99
100 if rows.Err() != nil {
101 return rows.Err()
102 }
103
104 for _, d := range dest {
105 if _, ok := d.(*pgtype.DriverBytes); ok {
106 rows.Close()
107 return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
108 }
109 }
110
111 if !rows.Next() {
112 if rows.Err() == nil {
113 return ErrNoRows
114 }
115 return rows.Err()
116 }
117
118 rows.Scan(dest...)
119 rows.Close()
120 return rows.Err()
121 }
122
123
// baseRows is the Rows implementation used throughout this file. Its state is
// mutated by Next, Scan, Values, fatal and Close.
type baseRows struct {
	// typeMap decodes column values in Scan and Values.
	typeMap *pgtype.Map
	// resultReader supplies the rows; Close tolerates it being nil.
	resultReader *pgconn.ResultReader

	// values holds the raw column values of the current row, set by Next.
	values [][]byte

	// commandTag is assigned by Close from resultReader.Close.
	commandTag pgconn.CommandTag
	// err is the first error recorded (fatal and Close never overwrite it).
	err error
	// closed is set by the first call to Close.
	closed bool

	// scanPlans and scanTypes cache, per column, the scan plan and the
	// destination type it was planned for; Scan re-plans a column when the
	// destination type changes.
	scanPlans []pgtype.ScanPlan
	scanTypes []reflect.Type

	// conn is the owning connection (may be nil); Close uses its statement and
	// description caches and passes it to the tracers.
	conn *Conn
	// multiResultReader, when non-nil, is closed by Close after resultReader.
	multiResultReader *pgconn.MultiResultReader

	// Close notifies batchTracer if set, otherwise queryTracer if set.
	queryTracer QueryTracer
	batchTracer BatchTracer
	// ctx is passed to the tracers by Close.
	ctx context.Context
	// startTime is not read in this file. NOTE(review): presumably recorded by
	// the query code for tracing — confirm.
	startTime time.Time
	// sql is the cache key invalidated by Close on error and, with args, is
	// reported to the batch tracer.
	sql  string
	args []any
	// rowCount is incremented by Next for every row read.
	rowCount int
}
148
149 func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription {
150 return rows.resultReader.FieldDescriptions()
151 }
152
// Close closes the rows. Only the first call does any work:
//   - it closes resultReader (if any), storing its command tag and keeping its
//     close error unless an earlier error is already recorded;
//   - it closes multiResultReader (if any), likewise keeping the first error;
//   - if the rows ended with any error, it invalidates rows.sql in the
//     connection's statement and description caches;
//   - it reports completion to batchTracer, or otherwise to queryTracer.
//
// Subsequent calls return immediately.
func (rows *baseRows) Close() {
	if rows.closed {
		return
	}

	rows.closed = true

	if rows.resultReader != nil {
		var closeErr error
		rows.commandTag, closeErr = rows.resultReader.Close()
		// Keep the first error: an earlier fatal error takes precedence.
		if rows.err == nil {
			rows.err = closeErr
		}
	}

	if rows.multiResultReader != nil {
		closeErr := rows.multiResultReader.Close()
		if rows.err == nil {
			rows.err = closeErr
		}
	}

	// Any error drops the cached statement/description for this SQL, so the
	// next use re-prepares it. NOTE(review): this fires for every error kind,
	// not only for errors that actually invalidate the statement.
	if rows.err != nil && rows.conn != nil && rows.sql != "" {
		if sc := rows.conn.statementCache; sc != nil {
			sc.Invalidate(rows.sql)
		}

		if sc := rows.conn.descriptionCache; sc != nil {
			sc.Invalidate(rows.sql)
		}
	}

	// Batch rows report to the batch tracer; otherwise end the query trace.
	if rows.batchTracer != nil {
		rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
	} else if rows.queryTracer != nil {
		rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
	}
}
191
192 func (rows *baseRows) CommandTag() pgconn.CommandTag {
193 return rows.commandTag
194 }
195
196 func (rows *baseRows) Err() error {
197 return rows.err
198 }
199
200
201
202 func (rows *baseRows) fatal(err error) {
203 if rows.err != nil {
204 return
205 }
206
207 rows.err = err
208 rows.Close()
209 }
210
211 func (rows *baseRows) Next() bool {
212 if rows.closed {
213 return false
214 }
215
216 if rows.resultReader.NextRow() {
217 rows.rowCount++
218 rows.values = rows.resultReader.Values()
219 return true
220 } else {
221 rows.Close()
222 return false
223 }
224 }
225
226 func (rows *baseRows) Scan(dest ...any) error {
227 m := rows.typeMap
228 fieldDescriptions := rows.FieldDescriptions()
229 values := rows.values
230
231 if len(fieldDescriptions) != len(values) {
232 err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
233 rows.fatal(err)
234 return err
235 }
236
237 if len(dest) == 1 {
238 if rc, ok := dest[0].(RowScanner); ok {
239 err := rc.ScanRow(rows)
240 if err != nil {
241 rows.fatal(err)
242 }
243 return err
244 }
245 }
246
247 if len(fieldDescriptions) != len(dest) {
248 err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
249 rows.fatal(err)
250 return err
251 }
252
253 if rows.scanPlans == nil {
254 rows.scanPlans = make([]pgtype.ScanPlan, len(values))
255 rows.scanTypes = make([]reflect.Type, len(values))
256 for i := range dest {
257 rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
258 rows.scanTypes[i] = reflect.TypeOf(dest[i])
259 }
260 }
261
262 for i, dst := range dest {
263 if dst == nil {
264 continue
265 }
266
267 if rows.scanTypes[i] != reflect.TypeOf(dst) {
268 rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
269 rows.scanTypes[i] = reflect.TypeOf(dest[i])
270 }
271
272 err := rows.scanPlans[i].Scan(values[i], dst)
273 if err != nil {
274 err = ScanArgError{ColumnIndex: i, Err: err}
275 rows.fatal(err)
276 return err
277 }
278 }
279
280 return nil
281 }
282
283 func (rows *baseRows) Values() ([]any, error) {
284 if rows.closed {
285 return nil, errors.New("rows is closed")
286 }
287
288 values := make([]any, 0, len(rows.FieldDescriptions()))
289
290 for i := range rows.FieldDescriptions() {
291 buf := rows.values[i]
292 fd := &rows.FieldDescriptions()[i]
293
294 if buf == nil {
295 values = append(values, nil)
296 continue
297 }
298
299 if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok {
300 value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf)
301 if err != nil {
302 rows.fatal(err)
303 }
304 values = append(values, value)
305 } else {
306 switch fd.Format {
307 case TextFormatCode:
308 values = append(values, string(buf))
309 case BinaryFormatCode:
310 newBuf := make([]byte, len(buf))
311 copy(newBuf, buf)
312 values = append(values, newBuf)
313 default:
314 rows.fatal(errors.New("unknown format code"))
315 }
316 }
317
318 if rows.Err() != nil {
319 return nil, rows.Err()
320 }
321 }
322
323 return values, rows.Err()
324 }
325
326 func (rows *baseRows) RawValues() [][]byte {
327 return rows.values
328 }
329
330 func (rows *baseRows) Conn() *Conn {
331 return rows.conn
332 }
333
// ScanArgError is returned when the value of column ColumnIndex cannot be
// scanned into its destination. Err holds the underlying failure.
type ScanArgError struct {
	ColumnIndex int
	Err         error
}

// Error reports the failing destination index together with the cause.
func (sae ScanArgError) Error() string {
	const format = "can't scan into dest[%d]: %v"
	return fmt.Sprintf(format, sae.ColumnIndex, sae.Err)
}

// Unwrap returns the underlying error so errors.Is and errors.As can see it.
func (sae ScanArgError) Unwrap() error {
	return sae.Err
}
346
347
348
349
350
351
352
353 func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error {
354 if len(fieldDescriptions) != len(values) {
355 return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
356 }
357 if len(fieldDescriptions) != len(dest) {
358 return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
359 }
360
361 for i, d := range dest {
362 if d == nil {
363 continue
364 }
365
366 err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
367 if err != nil {
368 return ScanArgError{ColumnIndex: i, Err: err}
369 }
370 }
371
372 return nil
373 }
374
375
376
377 func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows {
378 return &baseRows{
379 typeMap: typeMap,
380 resultReader: resultReader,
381 }
382 }
383
384
385
386
387 func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) {
388 defer rows.Close()
389
390 for rows.Next() {
391 err := rows.Scan(scans...)
392 if err != nil {
393 return pgconn.CommandTag{}, err
394 }
395
396 err = fn()
397 if err != nil {
398 return pgconn.CommandTag{}, err
399 }
400 }
401
402 if err := rows.Err(); err != nil {
403 return pgconn.CommandTag{}, err
404 }
405
406 return rows.CommandTag(), nil
407 }
408
409
// CollectableRow is the subset of Rows that a RowToFunc may use to read the
// current row.
type CollectableRow interface {
	FieldDescriptions() []pgconn.FieldDescription
	Scan(dest ...any) error
	Values() ([]any, error)
	RawValues() [][]byte
}

// RowToFunc converts the current row to a T. It is the per-row callback used
// by AppendRows, CollectRows, CollectOneRow and CollectExactlyOneRow.
type RowToFunc[T any] func(row CollectableRow) (T, error)
419
420
421 func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
422 defer rows.Close()
423
424 for rows.Next() {
425 value, err := fn(rows)
426 if err != nil {
427 return nil, err
428 }
429 slice = append(slice, value)
430 }
431
432 if err := rows.Err(); err != nil {
433 return nil, err
434 }
435
436 return slice, nil
437 }
438
439
440 func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
441 return AppendRows([]T{}, rows, fn)
442 }
443
444
445
446 func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
447 defer rows.Close()
448
449 var value T
450 var err error
451
452 if !rows.Next() {
453 if err = rows.Err(); err != nil {
454 return value, err
455 }
456 return value, ErrNoRows
457 }
458
459 value, err = fn(rows)
460 if err != nil {
461 return value, err
462 }
463
464 rows.Close()
465 return value, rows.Err()
466 }
467
468
469
470
471 func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
472 defer rows.Close()
473
474 var (
475 err error
476 value T
477 )
478
479 if !rows.Next() {
480 if err = rows.Err(); err != nil {
481 return value, err
482 }
483
484 return value, ErrNoRows
485 }
486
487 value, err = fn(rows)
488 if err != nil {
489 return value, err
490 }
491
492 if rows.Next() {
493 var zero T
494
495 return zero, ErrTooManyRows
496 }
497
498 return value, rows.Err()
499 }
500
501
502 func RowTo[T any](row CollectableRow) (T, error) {
503 var value T
504 err := row.Scan(&value)
505 return value, err
506 }
507
508
509 func RowToAddrOf[T any](row CollectableRow) (*T, error) {
510 var value T
511 err := row.Scan(&value)
512 return &value, err
513 }
514
515
516 func RowToMap(row CollectableRow) (map[string]any, error) {
517 var value map[string]any
518 err := row.Scan((*mapRowScanner)(&value))
519 return value, err
520 }
521
522 type mapRowScanner map[string]any
523
524 func (rs *mapRowScanner) ScanRow(rows Rows) error {
525 values, err := rows.Values()
526 if err != nil {
527 return err
528 }
529
530 *rs = make(mapRowScanner, len(values))
531
532 for i := range values {
533 (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i]
534 }
535
536 return nil
537 }
538
539
540
541
542 func RowToStructByPos[T any](row CollectableRow) (T, error) {
543 var value T
544 err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
545 return value, err
546 }
547
548
549
550
551 func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
552 var value T
553 err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
554 return &value, err
555 }
556
557 type positionalStructRowScanner struct {
558 ptrToStruct any
559 }
560
561 func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
562 dst := rs.ptrToStruct
563 dstValue := reflect.ValueOf(dst)
564 if dstValue.Kind() != reflect.Ptr {
565 return fmt.Errorf("dst not a pointer")
566 }
567
568 dstElemValue := dstValue.Elem()
569 scanTargets := rs.appendScanTargets(dstElemValue, nil)
570
571 if len(rows.RawValues()) > len(scanTargets) {
572 return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
573 }
574
575 return rows.Scan(scanTargets...)
576 }
577
578 func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
579 dstElemType := dstElemValue.Type()
580
581 if scanTargets == nil {
582 scanTargets = make([]any, 0, dstElemType.NumField())
583 }
584
585 for i := 0; i < dstElemType.NumField(); i++ {
586 sf := dstElemType.Field(i)
587
588 if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
589 scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
590 } else if sf.PkgPath == "" {
591 dbTag, _ := sf.Tag.Lookup(structTagKey)
592 if dbTag == "-" {
593
594 continue
595 }
596 scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
597 }
598 }
599
600 return scanTargets
601 }
602
603
604
605
606 func RowToStructByName[T any](row CollectableRow) (T, error) {
607 var value T
608 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
609 return value, err
610 }
611
612
613
614
615
616 func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
617 var value T
618 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
619 return &value, err
620 }
621
622
623
624
625 func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
626 var value T
627 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
628 return value, err
629 }
630
631
632
633
634
635 func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
636 var value T
637 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
638 return &value, err
639 }
640
641 type namedStructRowScanner struct {
642 ptrToStruct any
643 lax bool
644 }
645
646 func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
647 dst := rs.ptrToStruct
648 dstValue := reflect.ValueOf(dst)
649 if dstValue.Kind() != reflect.Ptr {
650 return fmt.Errorf("dst not a pointer")
651 }
652
653 dstElemValue := dstValue.Elem()
654 scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
655 if err != nil {
656 return err
657 }
658
659 for i, t := range scanTargets {
660 if t == nil {
661 return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
662 }
663 }
664
665 return rows.Scan(scanTargets...)
666 }
667
// structTagKey is the struct tag key read by the struct row scanners. A tag
// value of "-" skips the field; the named scanner otherwise uses the tag value
// (up to the first comma) as the column name.
const structTagKey = "db"
669
670 func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
671 i = -1
672 for i, desc := range fldDescs {
673
674
675 field = strings.ReplaceAll(field, "_", "")
676 descName := strings.ReplaceAll(desc.Name, "_", "")
677
678 if strings.EqualFold(descName, field) {
679 return i
680 }
681 }
682 return
683 }
684
685 func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
686 var err error
687 dstElemType := dstElemValue.Type()
688
689 if scanTargets == nil {
690 scanTargets = make([]any, len(fldDescs))
691 }
692
693 for i := 0; i < dstElemType.NumField(); i++ {
694 sf := dstElemType.Field(i)
695 if sf.PkgPath != "" && !sf.Anonymous {
696
697 continue
698 }
699
700 if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
701 scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
702 if err != nil {
703 return nil, err
704 }
705 } else {
706 dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
707 if dbTagPresent {
708 dbTag, _, _ = strings.Cut(dbTag, ",")
709 }
710 if dbTag == "-" {
711
712 continue
713 }
714 colName := dbTag
715 if !dbTagPresent {
716 colName = sf.Name
717 }
718 fpos := fieldPosByName(fldDescs, colName)
719 if fpos == -1 {
720 if rs.lax {
721 continue
722 }
723 return nil, fmt.Errorf("cannot find field %s in returned row", colName)
724 }
725 if fpos >= len(scanTargets) && !rs.lax {
726 return nil, fmt.Errorf("cannot find field %s in returned row", colName)
727 }
728 scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
729 }
730 }
731
732 return scanTargets, err
733 }
734
View as plain text