1 package pgx
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "time"
8
9 "github.com/jackc/pgconn"
10 "github.com/jackc/pgproto3/v2"
11 "github.com/jackc/pgtype"
12 )
13
14
15
16
17
18
19
20
21
22
23
24 type Rows interface {
25
26
27 Close()
28
29
30 Err() error
31
32
33 CommandTag() pgconn.CommandTag
34
35 FieldDescriptions() []pgproto3.FieldDescription
36
37
38
39
40 Next() bool
41
42
43
44
45
46 Scan(dest ...interface{}) error
47
48
49
50
51 Values() ([]interface{}, error)
52
53
54
55 RawValues() [][]byte
56 }
57
58
59
60
61
62
63
64 type Row interface {
65
66
67
68 Scan(dest ...interface{}) error
69 }
70
71
72 type connRow connRows
73
74 func (r *connRow) Scan(dest ...interface{}) (err error) {
75 rows := (*connRows)(r)
76
77 if rows.Err() != nil {
78 return rows.Err()
79 }
80
81 if !rows.Next() {
82 if rows.Err() == nil {
83 return ErrNoRows
84 }
85 return rows.Err()
86 }
87
88 rows.Scan(dest...)
89 rows.Close()
90 return rows.Err()
91 }
92
93 type rowLog interface {
94 shouldLog(lvl LogLevel) bool
95 log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{})
96 }
97
98
// connRows is the Rows implementation that reads a result through pgconn
// result readers and decodes values with connInfo.
type connRows struct {
	// ctx is passed to logger.log when Close reports the query outcome.
	ctx context.Context
	// logger, when non-nil, receives the query outcome in Close.
	logger rowLog
	// connInfo is the type registry used by Scan and Values to decode columns.
	connInfo *pgtype.ConnInfo
	// values holds the raw bytes of the current row, refreshed by Next.
	values [][]byte
	// rowCount counts rows returned by Next; it is included in the success log.
	rowCount int
	// err is the first error recorded by fatal or Close.
	err error
	// commandTag is filled in from resultReader.Close during Close.
	commandTag pgconn.CommandTag
	// startTime is the reference point for the elapsed time logged by Close.
	startTime time.Time
	// sql and args describe the query for logging; sql is also reported to
	// the statement cache on error.
	sql  string
	args []interface{}
	// closed is set by the first call to Close and makes later calls no-ops.
	closed bool
	// conn gives Close access to the connection's statement cache.
	conn *Conn

	// resultReader supplies field descriptions and rows; it may be nil
	// (Close checks for that).
	resultReader *pgconn.ResultReader
	// multiResultReader, when set, is closed by Close as well.
	// NOTE(review): presumably owned by these rows when they come from a
	// multi-result operation; confirm with the callers that set it.
	multiResultReader *pgconn.MultiResultReader

	// scanPlans caches one scan plan per column; Scan builds it on first use.
	scanPlans []pgtype.ScanPlan
}
118
119 func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription {
120 return rows.resultReader.FieldDescriptions()
121 }
122
// Close finishes the result and releases its readers. Calling Close more than
// once is safe; only the first call has any effect.
//
// Close stores the command tag reported by the result reader and keeps the
// first error seen, so an earlier fatal error is never replaced by a close
// error. It logs the outcome when a logger is set and reports a failed
// statement to the connection's statement cache.
func (rows *connRows) Close() {
	if rows.closed {
		return
	}
	rows.closed = true

	if rows.resultReader != nil {
		var closeErr error
		rows.commandTag, closeErr = rows.resultReader.Close()
		if rows.err == nil {
			rows.err = closeErr
		}
	}

	if rows.multiResultReader != nil {
		closeErr := rows.multiResultReader.Close()
		if rows.err == nil {
			rows.err = closeErr
		}
	}

	if rows.logger != nil {
		elapsed := time.Since(rows.startTime)

		if rows.err == nil {
			if rows.logger.shouldLog(LogLevelInfo) {
				rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": elapsed, "rowCount": rows.rowCount})
			}
		} else if rows.logger.shouldLog(LogLevelError) {
			rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "time": elapsed, "args": logQueryArgs(rows.args)})
		}
	}

	// Informing the statement cache about a failed statement is a correctness
	// concern, not a logging one, so it must happen whether or not a logger is
	// configured. rows.conn is nil-checked defensively because nothing here
	// guarantees it is set.
	if rows.err != nil && rows.conn != nil && rows.conn.stmtcache != nil {
		rows.conn.stmtcache.StatementErrored(rows.sql, rows.err)
	}
}
162
163 func (rows *connRows) CommandTag() pgconn.CommandTag {
164 return rows.commandTag
165 }
166
167 func (rows *connRows) Err() error {
168 return rows.err
169 }
170
171
172
173 func (rows *connRows) fatal(err error) {
174 if rows.err != nil {
175 return
176 }
177
178 rows.err = err
179 rows.Close()
180 }
181
182 func (rows *connRows) Next() bool {
183 if rows.closed {
184 return false
185 }
186
187 if rows.resultReader.NextRow() {
188 rows.rowCount++
189 rows.values = rows.resultReader.Values()
190 return true
191 } else {
192 rows.Close()
193 return false
194 }
195 }
196
197 func (rows *connRows) Scan(dest ...interface{}) error {
198 ci := rows.connInfo
199 fieldDescriptions := rows.FieldDescriptions()
200 values := rows.values
201
202 if len(fieldDescriptions) != len(values) {
203 err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values))
204 rows.fatal(err)
205 return err
206 }
207 if len(fieldDescriptions) != len(dest) {
208 err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest))
209 rows.fatal(err)
210 return err
211 }
212
213 if rows.scanPlans == nil {
214 rows.scanPlans = make([]pgtype.ScanPlan, len(values))
215 for i := range dest {
216 rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i])
217 }
218 }
219
220 for i, dst := range dest {
221 if dst == nil {
222 continue
223 }
224
225 err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst)
226 if err != nil {
227 err = ScanArgError{ColumnIndex: i, Err: err}
228 rows.fatal(err)
229 return err
230 }
231 }
232
233 return nil
234 }
235
236 func (rows *connRows) Values() ([]interface{}, error) {
237 if rows.closed {
238 return nil, errors.New("rows is closed")
239 }
240
241 values := make([]interface{}, 0, len(rows.FieldDescriptions()))
242
243 for i := range rows.FieldDescriptions() {
244 buf := rows.values[i]
245 fd := &rows.FieldDescriptions()[i]
246
247 if buf == nil {
248 values = append(values, nil)
249 continue
250 }
251
252 if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok {
253 value := dt.Value
254
255 switch fd.Format {
256 case TextFormatCode:
257 decoder, ok := value.(pgtype.TextDecoder)
258 if !ok {
259 decoder = &pgtype.GenericText{}
260 }
261 err := decoder.DecodeText(rows.connInfo, buf)
262 if err != nil {
263 rows.fatal(err)
264 }
265 values = append(values, decoder.(pgtype.Value).Get())
266 case BinaryFormatCode:
267 decoder, ok := value.(pgtype.BinaryDecoder)
268 if !ok {
269 decoder = &pgtype.GenericBinary{}
270 }
271 err := decoder.DecodeBinary(rows.connInfo, buf)
272 if err != nil {
273 rows.fatal(err)
274 }
275 values = append(values, value.Get())
276 default:
277 rows.fatal(errors.New("Unknown format code"))
278 }
279 } else {
280 switch fd.Format {
281 case TextFormatCode:
282 decoder := &pgtype.GenericText{}
283 err := decoder.DecodeText(rows.connInfo, buf)
284 if err != nil {
285 rows.fatal(err)
286 }
287 values = append(values, decoder.Get())
288 case BinaryFormatCode:
289 decoder := &pgtype.GenericBinary{}
290 err := decoder.DecodeBinary(rows.connInfo, buf)
291 if err != nil {
292 rows.fatal(err)
293 }
294 values = append(values, decoder.Get())
295 default:
296 rows.fatal(errors.New("Unknown format code"))
297 }
298 }
299
300 if rows.Err() != nil {
301 return nil, rows.Err()
302 }
303 }
304
305 return values, rows.Err()
306 }
307
308 func (rows *connRows) RawValues() [][]byte {
309 return rows.values
310 }
311
312 type ScanArgError struct {
313 ColumnIndex int
314 Err error
315 }
316
317 func (e ScanArgError) Error() string {
318 return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
319 }
320
321 func (e ScanArgError) Unwrap() error {
322 return e.Err
323 }
324
325
326
327
328
329
330
// ScanRow decodes one row of raw values into dest using connInfo.
//
// fieldDescriptions, values and dest must all have the same length; a length
// mismatch is reported as an error before anything is scanned. A nil entry in
// dest skips that column. The first column that fails to scan is reported as
// a ScanArgError.
func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error {
	switch fieldCount := len(fieldDescriptions); {
	case fieldCount != len(values):
		return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", fieldCount, len(values))
	case fieldCount != len(dest):
		return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", fieldCount, len(dest))
	}

	for col := range dest {
		target := dest[col]
		if target == nil {
			continue
		}

		fd := &fieldDescriptions[col]
		if err := connInfo.Scan(fd.DataTypeOID, fd.Format, values[col], target); err != nil {
			return ScanArgError{ColumnIndex: col, Err: err}
		}
	}

	return nil
}
352
View as plain text