...
1 package exec
2
3 import (
4 "database/sql"
5 "reflect"
6
7 "github.com/doug-martin/goqu/v9/exp"
8 "github.com/doug-martin/goqu/v9/internal/errors"
9 "github.com/doug-martin/goqu/v9/internal/util"
10 )
11
12 type (
13
14 Scanner interface {
15 Next() bool
16 ScanStruct(i interface{}) error
17 ScanStructs(i interface{}) error
18 ScanVal(i interface{}) error
19 ScanVals(i interface{}) error
20 Close() error
21 Err() error
22 }
23
24 scanner struct {
25 rows *sql.Rows
26 columnMap util.ColumnMap
27 columns []string
28 }
29 )
30
31 func unableToFindFieldError(col string) error {
32 return errors.New(`unable to find corresponding field to column "%s" returned by query`, col)
33 }
34
35
36 func NewScanner(rows *sql.Rows) Scanner {
37 return &scanner{rows: rows}
38 }
39
40
41
42 func (s *scanner) Next() bool {
43 return s.rows.Next()
44 }
45
46
47
48 func (s *scanner) Err() error {
49 return s.rows.Err()
50 }
51
52
53 func (s *scanner) ScanStruct(i interface{}) error {
54
55 if s.columnMap == nil || s.columns == nil {
56 cm, err := util.GetColumnMap(i)
57 if err != nil {
58 return err
59 }
60
61 cols, err := s.rows.Columns()
62 if err != nil {
63 return err
64 }
65
66 s.columnMap = cm
67 s.columns = cols
68 }
69
70 scans := make([]interface{}, 0, len(s.columns))
71 for _, col := range s.columns {
72 data, ok := s.columnMap[col]
73 switch {
74 case !ok:
75 return unableToFindFieldError(col)
76 default:
77 scans = append(scans, reflect.New(data.GoType).Interface())
78 }
79 }
80
81 if err := s.rows.Scan(scans...); err != nil {
82 return err
83 }
84
85 record := exp.Record{}
86 for index, col := range s.columns {
87 record[col] = scans[index]
88 }
89
90 util.AssignStructVals(i, record, s.columnMap)
91
92 return s.Err()
93 }
94
95
96 func (s *scanner) ScanStructs(i interface{}) error {
97 val, err := checkScanStructsTarget(i)
98 if err != nil {
99 return err
100 }
101 return s.scanIntoSlice(val, func(i interface{}) error {
102 return s.ScanStruct(i)
103 })
104 }
105
106
107 func (s *scanner) ScanVal(i interface{}) error {
108 if err := s.rows.Scan(i); err != nil {
109 return err
110 }
111
112 return s.Err()
113 }
114
115
116 func (s *scanner) ScanVals(i interface{}) error {
117 val, err := checkScanValsTarget(i)
118 if err != nil {
119 return err
120 }
121 return s.scanIntoSlice(val, func(i interface{}) error {
122 return s.ScanVal(i)
123 })
124 }
125
126
127
128 func (s *scanner) Close() error {
129 return s.rows.Close()
130 }
131
132 func (s *scanner) scanIntoSlice(val reflect.Value, it func(i interface{}) error) error {
133 elemType := util.GetSliceElementType(val)
134
135 for s.Next() {
136 row := reflect.New(elemType)
137 if rowErr := it(row.Interface()); rowErr != nil {
138 return rowErr
139 }
140 util.AppendSliceElement(val, row)
141 }
142
143 return s.Err()
144 }
145
146 func checkScanStructsTarget(i interface{}) (reflect.Value, error) {
147 val := reflect.ValueOf(i)
148 if !util.IsPointer(val.Kind()) {
149 return val, errUnsupportedScanStructsType
150 }
151 val = reflect.Indirect(val)
152 if !util.IsSlice(val.Kind()) {
153 return val, errUnsupportedScanStructsType
154 }
155 return val, nil
156 }
157
158 func checkScanValsTarget(i interface{}) (reflect.Value, error) {
159 val := reflect.ValueOf(i)
160 if !util.IsPointer(val.Kind()) {
161 return val, errUnsupportedScanValsType
162 }
163 val = reflect.Indirect(val)
164 if !util.IsSlice(val.Kind()) {
165 return val, errUnsupportedScanValsType
166 }
167 return val, nil
168 }
169
View as plain text