1 package pgx
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7
8 "github.com/jackc/pgx/v5/pgconn"
9 )
10
11
// QueuedQuery is a single query that has been added to a Batch by Batch.Queue,
// together with an optional callback that consumes its result.
type QueuedQuery struct {
	// SQL is the query text passed to Batch.Queue.
	SQL string
	// Arguments are the query arguments passed to Batch.Queue.
	Arguments []any
	// fn is the result callback installed by Query, QueryRow, or Exec. When it
	// is non-nil, closing the batch results invokes it for this query if the
	// query's result has not been read yet.
	fn batchItemFunc
	// sd is not referenced in this part of the file; presumably the statement
	// description resolved for SQL when the batch is sent — TODO confirm.
	sd *pgconn.StatementDescription
}
18
// batchItemFunc is the internal callback form stored on a QueuedQuery. It
// receives the batch results and reads the result belonging to its query.
type batchItemFunc func(br BatchResults) error
20
21
22 func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
23 qq.fn = func(br BatchResults) error {
24 rows, _ := br.Query()
25 defer rows.Close()
26
27 err := fn(rows)
28 if err != nil {
29 return err
30 }
31 rows.Close()
32
33 return rows.Err()
34 }
35 }
36
37
38 func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
39 qq.fn = func(br BatchResults) error {
40 row := br.QueryRow()
41 return fn(row)
42 }
43 }
44
45
46 func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
47 qq.fn = func(br BatchResults) error {
48 ct, err := br.Exec()
49 if err != nil {
50 return err
51 }
52
53 return fn(ct)
54 }
55 }
56
57
58
// Batch is an ordered collection of queries that have been queued with Queue.
type Batch struct {
	// QueuedQueries holds the queued queries in the order they were added.
	QueuedQueries []*QueuedQuery
}
62
63
64
65
66 func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
67 qq := &QueuedQuery{
68 SQL: query,
69 Arguments: arguments,
70 }
71 b.QueuedQueries = append(b.QueuedQueries, qq)
72 return qq
73 }
74
75
// Len returns the number of queries that have been queued so far.
func (b *Batch) Len() int {
	return len(b.QueuedQueries)
}
79
// BatchResults gives access to the results of the queries of a batch, one
// result per call, in queue order.
type BatchResults interface {
	// Exec reads the result of the next query in the batch and returns its
	// command tag.
	Exec() (pgconn.CommandTag, error)

	// Query reads the result of the next query in the batch and returns it as
	// Rows. On failure the returned Rows is already closed and carries the
	// error.
	Query() (Rows, error)

	// QueryRow reads the result of the next query in the batch and returns it
	// as a Row. Any error is carried by the returned Row.
	QueryRow() Row

	// Close ends the batch operation. For every queued query whose result has
	// not been read yet, the callback registered through QueuedQuery.Query,
	// QueuedQuery.QueryRow, or QueuedQuery.Exec is invoked; queries without a
	// callback have their result read and discarded. Processing stops at the
	// first callback that returns an error.
	//
	// Once Close has completed, further calls do not rerun callbacks. If the
	// batch recorded an error, later calls return that same error.
	Close() error
}
105
// batchResults is the BatchResults implementation that reads results from a
// pgconn.MultiResultReader.
type batchResults struct {
	ctx  context.Context           // context passed to getRows and to the batch tracer
	conn *Conn                     // connection providing getRows and the optional batchTracer
	mrr  *pgconn.MultiResultReader // source of the per-query results
	err  error                     // sticky error; once set, Exec, Query, and Close return it
	b    *Batch                    // batch whose queries are being read; may be nil
	// qqIdx is the index in b.QueuedQueries of the next query whose result
	// will be read.
	qqIdx     int
	closed    bool // set once Close has drained the remaining results
	endTraced bool // set once TraceBatchEnd has been handled, so it runs at most once
}
116
117
118 func (br *batchResults) Exec() (pgconn.CommandTag, error) {
119 if br.err != nil {
120 return pgconn.CommandTag{}, br.err
121 }
122 if br.closed {
123 return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
124 }
125
126 query, arguments, _ := br.nextQueryAndArgs()
127
128 if !br.mrr.NextResult() {
129 err := br.mrr.Close()
130 if err == nil {
131 err = errors.New("no result")
132 }
133 if br.conn.batchTracer != nil {
134 br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
135 SQL: query,
136 Args: arguments,
137 Err: err,
138 })
139 }
140 return pgconn.CommandTag{}, err
141 }
142
143 commandTag, err := br.mrr.ResultReader().Close()
144 if err != nil {
145 br.err = err
146 br.mrr.Close()
147 }
148
149 if br.conn.batchTracer != nil {
150 br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
151 SQL: query,
152 Args: arguments,
153 CommandTag: commandTag,
154 Err: br.err,
155 })
156 }
157
158 return commandTag, br.err
159 }
160
161
162 func (br *batchResults) Query() (Rows, error) {
163 query, arguments, ok := br.nextQueryAndArgs()
164 if !ok {
165 query = "batch query"
166 }
167
168 if br.err != nil {
169 return &baseRows{err: br.err, closed: true}, br.err
170 }
171
172 if br.closed {
173 alreadyClosedErr := fmt.Errorf("batch already closed")
174 return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
175 }
176
177 rows := br.conn.getRows(br.ctx, query, arguments)
178 rows.batchTracer = br.conn.batchTracer
179
180 if !br.mrr.NextResult() {
181 rows.err = br.mrr.Close()
182 if rows.err == nil {
183 rows.err = errors.New("no result")
184 }
185 rows.closed = true
186
187 if br.conn.batchTracer != nil {
188 br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
189 SQL: query,
190 Args: arguments,
191 Err: rows.err,
192 })
193 }
194
195 return rows, rows.err
196 }
197
198 rows.resultReader = br.mrr.ResultReader()
199 return rows, nil
200 }
201
202
203 func (br *batchResults) QueryRow() Row {
204 rows, _ := br.Query()
205 return (*connRow)(rows.(*baseRows))
206
207 }
208
209
210
211 func (br *batchResults) Close() error {
212 defer func() {
213 if !br.endTraced {
214 if br.conn != nil && br.conn.batchTracer != nil {
215 br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
216 }
217 br.endTraced = true
218 }
219 }()
220
221 if br.err != nil {
222 return br.err
223 }
224
225 if br.closed {
226 return nil
227 }
228
229
230 for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
231 if br.b.QueuedQueries[br.qqIdx].fn != nil {
232 err := br.b.QueuedQueries[br.qqIdx].fn(br)
233 if err != nil {
234 br.err = err
235 }
236 } else {
237 br.Exec()
238 }
239 }
240
241 br.closed = true
242
243 err := br.mrr.Close()
244 if br.err == nil {
245 br.err = err
246 }
247
248 return br.err
249 }
250
// earlyError returns the error currently recorded on the batch results, or
// nil if there is none. Presumably used by callers to detect a failure that
// happened before any result was read — TODO confirm against call sites.
func (br *batchResults) earlyError() error {
	return br.err
}
254
255 func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
256 if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
257 bi := br.b.QueuedQueries[br.qqIdx]
258 query = bi.SQL
259 args = bi.Arguments
260 ok = true
261 br.qqIdx++
262 }
263 return
264 }
265
// pipelineBatchResults is the BatchResults implementation that reads results
// from a pgconn.Pipeline.
type pipelineBatchResults struct {
	ctx      context.Context  // context passed to getRows and to the batch tracer
	conn     *Conn            // connection providing getRows and the optional batchTracer
	pipeline *pgconn.Pipeline // source of the per-query results
	// lastRows is the rows value most recently handed out by Query. Its error
	// is checked before the next result is read and when closing.
	lastRows *baseRows
	err      error  // sticky error; once set, Exec and Query return it
	b        *Batch // batch whose queries are being read; may be nil
	// qqIdx is the index in b.QueuedQueries of the next query whose result
	// will be read.
	qqIdx     int
	closed    bool // set once Close has drained the remaining results
	endTraced bool // set once TraceBatchEnd has been handled, so it runs at most once
}
277
278
279 func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
280 if br.err != nil {
281 return pgconn.CommandTag{}, br.err
282 }
283 if br.closed {
284 return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
285 }
286 if br.lastRows != nil && br.lastRows.err != nil {
287 return pgconn.CommandTag{}, br.err
288 }
289
290 query, arguments, _ := br.nextQueryAndArgs()
291
292 results, err := br.pipeline.GetResults()
293 if err != nil {
294 br.err = err
295 return pgconn.CommandTag{}, br.err
296 }
297 var commandTag pgconn.CommandTag
298 switch results := results.(type) {
299 case *pgconn.ResultReader:
300 commandTag, br.err = results.Close()
301 default:
302 return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
303 }
304
305 if br.conn.batchTracer != nil {
306 br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
307 SQL: query,
308 Args: arguments,
309 CommandTag: commandTag,
310 Err: br.err,
311 })
312 }
313
314 return commandTag, br.err
315 }
316
317
318 func (br *pipelineBatchResults) Query() (Rows, error) {
319 if br.err != nil {
320 return &baseRows{err: br.err, closed: true}, br.err
321 }
322
323 if br.closed {
324 alreadyClosedErr := fmt.Errorf("batch already closed")
325 return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
326 }
327
328 if br.lastRows != nil && br.lastRows.err != nil {
329 br.err = br.lastRows.err
330 return &baseRows{err: br.err, closed: true}, br.err
331 }
332
333 query, arguments, ok := br.nextQueryAndArgs()
334 if !ok {
335 query = "batch query"
336 }
337
338 rows := br.conn.getRows(br.ctx, query, arguments)
339 rows.batchTracer = br.conn.batchTracer
340 br.lastRows = rows
341
342 results, err := br.pipeline.GetResults()
343 if err != nil {
344 br.err = err
345 rows.err = err
346 rows.closed = true
347
348 if br.conn.batchTracer != nil {
349 br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
350 SQL: query,
351 Args: arguments,
352 Err: err,
353 })
354 }
355 } else {
356 switch results := results.(type) {
357 case *pgconn.ResultReader:
358 rows.resultReader = results
359 default:
360 err = fmt.Errorf("unexpected pipeline result: %T", results)
361 br.err = err
362 rows.err = err
363 rows.closed = true
364 }
365 }
366
367 return rows, rows.err
368 }
369
370
371 func (br *pipelineBatchResults) QueryRow() Row {
372 rows, _ := br.Query()
373 return (*connRow)(rows.(*baseRows))
374
375 }
376
377
378
379 func (br *pipelineBatchResults) Close() error {
380 defer func() {
381 if !br.endTraced {
382 if br.conn.batchTracer != nil {
383 br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
384 }
385 br.endTraced = true
386 }
387 }()
388
389 if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
390 br.err = br.lastRows.err
391 return br.err
392 }
393
394 if br.closed {
395 return br.err
396 }
397
398
399 for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
400 if br.b.QueuedQueries[br.qqIdx].fn != nil {
401 err := br.b.QueuedQueries[br.qqIdx].fn(br)
402 if err != nil {
403 br.err = err
404 }
405 } else {
406 br.Exec()
407 }
408 }
409
410 br.closed = true
411
412 err := br.pipeline.Close()
413 if br.err == nil {
414 br.err = err
415 }
416
417 return br.err
418 }
419
// earlyError returns the error currently recorded on the batch results, or
// nil if there is none. Presumably used by callers to detect a failure that
// happened before any result was read — TODO confirm against call sites.
func (br *pipelineBatchResults) earlyError() error {
	return br.err
}
423
424 func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
425 if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
426 bi := br.b.QueuedQueries[br.qqIdx]
427 query = bi.SQL
428 args = bi.Arguments
429 ok = true
430 br.qqIdx++
431 }
432 return
433 }
434
View as plain text