1 package pgx_test
2
3 import (
4 "context"
5 "testing"
6 "time"
7
8 "github.com/jackc/pgx/v5"
9 "github.com/jackc/pgx/v5/pgxtest"
10 "github.com/stretchr/testify/require"
11 )
12
// testTracer is a tracer whose behavior is supplied by the test at hand.
// Each field is an optional hook; the exported method with the matching
// (capitalized) name forwards to the hook when it is non-nil and otherwise
// does nothing (the *Start methods then return the context they were given).
type testTracer struct {
	traceQueryStart    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context
	traceQueryEnd      func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData)
	traceBatchStart    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context
	traceBatchQuery    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData)
	traceBatchEnd      func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData)
	traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context
	traceCopyFromEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData)
	tracePrepareStart  func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context
	tracePrepareEnd    func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData)
	traceConnectStart  func(ctx context.Context, data pgx.TraceConnectStartData) context.Context
	traceConnectEnd    func(ctx context.Context, data pgx.TraceConnectEndData)
}
26
// ctxKey is the key type used by the tests in this file when storing values
// in a context via context.WithValue and reading them back with ctx.Value.
type ctxKey string
28
29 func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
30 if tt.traceQueryStart != nil {
31 return tt.traceQueryStart(ctx, conn, data)
32 }
33 return ctx
34 }
35
36 func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
37 if tt.traceQueryEnd != nil {
38 tt.traceQueryEnd(ctx, conn, data)
39 }
40 }
41
42 func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
43 if tt.traceBatchStart != nil {
44 return tt.traceBatchStart(ctx, conn, data)
45 }
46 return ctx
47 }
48
49 func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
50 if tt.traceBatchQuery != nil {
51 tt.traceBatchQuery(ctx, conn, data)
52 }
53 }
54
55 func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
56 if tt.traceBatchEnd != nil {
57 tt.traceBatchEnd(ctx, conn, data)
58 }
59 }
60
61 func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
62 if tt.traceCopyFromStart != nil {
63 return tt.traceCopyFromStart(ctx, conn, data)
64 }
65 return ctx
66 }
67
68 func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
69 if tt.traceCopyFromEnd != nil {
70 tt.traceCopyFromEnd(ctx, conn, data)
71 }
72 }
73
74 func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
75 if tt.tracePrepareStart != nil {
76 return tt.tracePrepareStart(ctx, conn, data)
77 }
78 return ctx
79 }
80
81 func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
82 if tt.tracePrepareEnd != nil {
83 tt.tracePrepareEnd(ctx, conn, data)
84 }
85 }
86
87 func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
88 if tt.traceConnectStart != nil {
89 return tt.traceConnectStart(ctx, data)
90 }
91 return ctx
92 }
93
94 func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
95 if tt.traceConnectEnd != nil {
96 tt.traceConnectEnd(ctx, data)
97 }
98 }
99
100 func TestTraceExec(t *testing.T) {
101 t.Parallel()
102
103 tracer := &testTracer{}
104
105 ctr := defaultConnTestRunner
106 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
107 config := defaultConnTestRunner.CreateConfig(ctx, t)
108 config.Tracer = tracer
109 return config
110 }
111
112 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
113 defer cancel()
114
115 pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
116 traceQueryStartCalled := false
117 tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
118 traceQueryStartCalled = true
119 require.Equal(t, `select $1::text`, data.SQL)
120 require.Len(t, data.Args, 1)
121 require.Equal(t, `testing`, data.Args[0])
122 return context.WithValue(ctx, ctxKey(ctxKey("fromTraceQueryStart")), "foo")
123 }
124
125 traceQueryEndCalled := false
126 tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
127 traceQueryEndCalled = true
128 require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceQueryStart"))))
129 require.Equal(t, `SELECT 1`, data.CommandTag.String())
130 require.NoError(t, data.Err)
131 }
132
133 _, err := conn.Exec(ctx, `select $1::text`, "testing")
134 require.NoError(t, err)
135 require.True(t, traceQueryStartCalled)
136 require.True(t, traceQueryEndCalled)
137 })
138 }
139
140 func TestTraceQuery(t *testing.T) {
141 t.Parallel()
142
143 tracer := &testTracer{}
144
145 ctr := defaultConnTestRunner
146 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
147 config := defaultConnTestRunner.CreateConfig(ctx, t)
148 config.Tracer = tracer
149 return config
150 }
151
152 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
153 defer cancel()
154
155 pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
156 traceQueryStartCalled := false
157 tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
158 traceQueryStartCalled = true
159 require.Equal(t, `select $1::text`, data.SQL)
160 require.Len(t, data.Args, 1)
161 require.Equal(t, `testing`, data.Args[0])
162 return context.WithValue(ctx, ctxKey("fromTraceQueryStart"), "foo")
163 }
164
165 traceQueryEndCalled := false
166 tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
167 traceQueryEndCalled = true
168 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceQueryStart")))
169 require.Equal(t, `SELECT 1`, data.CommandTag.String())
170 require.NoError(t, data.Err)
171 }
172
173 var s string
174 err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s)
175 require.NoError(t, err)
176 require.Equal(t, "testing", s)
177 require.True(t, traceQueryStartCalled)
178 require.True(t, traceQueryEndCalled)
179 })
180 }
181
182 func TestTraceBatchNormal(t *testing.T) {
183 t.Parallel()
184
185 tracer := &testTracer{}
186
187 ctr := defaultConnTestRunner
188 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
189 config := defaultConnTestRunner.CreateConfig(ctx, t)
190 config.Tracer = tracer
191 return config
192 }
193
194 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
195 defer cancel()
196
197 pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
198 traceBatchStartCalled := false
199 tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
200 traceBatchStartCalled = true
201 require.NotNil(t, data.Batch)
202 require.Equal(t, 2, data.Batch.Len())
203 return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
204 }
205
206 traceBatchQueryCalledCount := 0
207 tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
208 traceBatchQueryCalledCount++
209 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
210 require.NoError(t, data.Err)
211 }
212
213 traceBatchEndCalled := false
214 tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
215 traceBatchEndCalled = true
216 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
217 require.NoError(t, data.Err)
218 }
219
220 batch := &pgx.Batch{}
221 batch.Queue(`select 1`)
222 batch.Queue(`select 2`)
223
224 br := conn.SendBatch(context.Background(), batch)
225 require.True(t, traceBatchStartCalled)
226
227 var n int32
228 err := br.QueryRow().Scan(&n)
229 require.NoError(t, err)
230 require.EqualValues(t, 1, n)
231 require.EqualValues(t, 1, traceBatchQueryCalledCount)
232
233 err = br.QueryRow().Scan(&n)
234 require.NoError(t, err)
235 require.EqualValues(t, 2, n)
236 require.EqualValues(t, 2, traceBatchQueryCalledCount)
237
238 err = br.Close()
239 require.NoError(t, err)
240
241 require.True(t, traceBatchEndCalled)
242 })
243 }
244
245 func TestTraceBatchClose(t *testing.T) {
246 t.Parallel()
247
248 tracer := &testTracer{}
249
250 ctr := defaultConnTestRunner
251 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
252 config := defaultConnTestRunner.CreateConfig(ctx, t)
253 config.Tracer = tracer
254 return config
255 }
256
257 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
258 defer cancel()
259
260 pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
261 traceBatchStartCalled := false
262 tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
263 traceBatchStartCalled = true
264 require.NotNil(t, data.Batch)
265 require.Equal(t, 2, data.Batch.Len())
266 return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
267 }
268
269 traceBatchQueryCalledCount := 0
270 tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
271 traceBatchQueryCalledCount++
272 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
273 require.NoError(t, data.Err)
274 }
275
276 traceBatchEndCalled := false
277 tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
278 traceBatchEndCalled = true
279 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
280 require.NoError(t, data.Err)
281 }
282
283 batch := &pgx.Batch{}
284 batch.Queue(`select 1`)
285 batch.Queue(`select 2`)
286
287 br := conn.SendBatch(context.Background(), batch)
288 require.True(t, traceBatchStartCalled)
289 err := br.Close()
290 require.NoError(t, err)
291 require.EqualValues(t, 2, traceBatchQueryCalledCount)
292 require.True(t, traceBatchEndCalled)
293 })
294 }
295
296 func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
297 t.Parallel()
298
299 tracer := &testTracer{}
300
301 ctr := defaultConnTestRunner
302 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
303 config := defaultConnTestRunner.CreateConfig(ctx, t)
304 config.Tracer = tracer
305 return config
306 }
307
308 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
309 defer cancel()
310
311 pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
312 traceBatchStartCalled := false
313 tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
314 traceBatchStartCalled = true
315 require.NotNil(t, data.Batch)
316 require.Equal(t, 3, data.Batch.Len())
317 return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
318 }
319
320 traceBatchQueryCalledCount := 0
321 tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
322 traceBatchQueryCalledCount++
323 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
324 if traceBatchQueryCalledCount == 2 {
325 require.Error(t, data.Err)
326 } else {
327 require.NoError(t, data.Err)
328 }
329 }
330
331 traceBatchEndCalled := false
332 tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
333 traceBatchEndCalled = true
334 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
335 require.Error(t, data.Err)
336 }
337
338 batch := &pgx.Batch{}
339 batch.Queue(`select 1`)
340 batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
341 batch.Queue(`select 3`)
342
343 br := conn.SendBatch(context.Background(), batch)
344 require.True(t, traceBatchStartCalled)
345
346 commandTag, err := br.Exec()
347 require.NoError(t, err)
348 require.Equal(t, "SELECT 1", commandTag.String())
349
350 commandTag, err = br.Exec()
351 require.Error(t, err)
352 require.Equal(t, "", commandTag.String())
353
354 commandTag, err = br.Exec()
355 require.Error(t, err)
356 require.Equal(t, "", commandTag.String())
357
358 err = br.Close()
359 require.Error(t, err)
360 require.EqualValues(t, 2, traceBatchQueryCalledCount)
361 require.True(t, traceBatchEndCalled)
362 })
363 }
364
365 func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
366 t.Parallel()
367
368 tracer := &testTracer{}
369
370 ctr := defaultConnTestRunner
371 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
372 config := defaultConnTestRunner.CreateConfig(ctx, t)
373 config.Tracer = tracer
374 return config
375 }
376
377 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
378 defer cancel()
379
380 pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
381 traceBatchStartCalled := false
382 tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
383 traceBatchStartCalled = true
384 require.NotNil(t, data.Batch)
385 require.Equal(t, 3, data.Batch.Len())
386 return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
387 }
388
389 traceBatchQueryCalledCount := 0
390 tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
391 traceBatchQueryCalledCount++
392 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
393 if traceBatchQueryCalledCount == 2 {
394 require.Error(t, data.Err)
395 } else {
396 require.NoError(t, data.Err)
397 }
398 }
399
400 traceBatchEndCalled := false
401 tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
402 traceBatchEndCalled = true
403 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
404 require.Error(t, data.Err)
405 }
406
407 batch := &pgx.Batch{}
408 batch.Queue(`select 1`)
409 batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
410 batch.Queue(`select 3`)
411
412 br := conn.SendBatch(context.Background(), batch)
413 require.True(t, traceBatchStartCalled)
414 err := br.Close()
415 require.Error(t, err)
416 require.EqualValues(t, 2, traceBatchQueryCalledCount)
417 require.True(t, traceBatchEndCalled)
418 })
419 }
420
421 func TestTraceCopyFrom(t *testing.T) {
422 t.Parallel()
423
424 tracer := &testTracer{}
425
426 ctr := defaultConnTestRunner
427 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
428 config := defaultConnTestRunner.CreateConfig(ctx, t)
429 config.Tracer = tracer
430 return config
431 }
432
433 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
434 defer cancel()
435
436 pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
437 ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
438 defer cancel()
439
440 traceCopyFromStartCalled := false
441 tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
442 traceCopyFromStartCalled = true
443 require.Equal(t, pgx.Identifier{"foo"}, data.TableName)
444 require.Equal(t, []string{"a"}, data.ColumnNames)
445 return context.WithValue(ctx, ctxKey("fromTraceCopyFromStart"), "foo")
446 }
447
448 traceCopyFromEndCalled := false
449 tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
450 traceCopyFromEndCalled = true
451 require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceCopyFromStart")))
452 require.Equal(t, `COPY 2`, data.CommandTag.String())
453 require.NoError(t, data.Err)
454 }
455
456 _, err := conn.Exec(ctx, `create temporary table foo(a int4)`)
457 require.NoError(t, err)
458
459 inputRows := [][]any{
460 {int32(1)},
461 {nil},
462 }
463
464 copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
465 require.NoError(t, err)
466 require.EqualValues(t, len(inputRows), copyCount)
467 require.True(t, traceCopyFromStartCalled)
468 require.True(t, traceCopyFromEndCalled)
469 })
470 }
471
472 func TestTracePrepare(t *testing.T) {
473 t.Parallel()
474
475 tracer := &testTracer{}
476
477 ctr := defaultConnTestRunner
478 ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
479 config := defaultConnTestRunner.CreateConfig(ctx, t)
480 config.Tracer = tracer
481 return config
482 }
483
484 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
485 defer cancel()
486
487 pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
488 tracePrepareStartCalled := false
489 tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
490 tracePrepareStartCalled = true
491 require.Equal(t, `ps`, data.Name)
492 require.Equal(t, `select $1::text`, data.SQL)
493 return context.WithValue(ctx, ctxKey("fromTracePrepareStart"), "foo")
494 }
495
496 tracePrepareEndCalled := false
497 tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
498 tracePrepareEndCalled = true
499 require.False(t, data.AlreadyPrepared)
500 require.NoError(t, data.Err)
501 }
502
503 _, err := conn.Prepare(ctx, "ps", `select $1::text`)
504 require.NoError(t, err)
505 require.True(t, tracePrepareStartCalled)
506 require.True(t, tracePrepareEndCalled)
507
508 tracePrepareStartCalled = false
509 tracePrepareEndCalled = false
510 tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
511 tracePrepareEndCalled = true
512 require.True(t, data.AlreadyPrepared)
513 require.NoError(t, data.Err)
514 }
515
516 _, err = conn.Prepare(ctx, "ps", `select $1::text`)
517 require.NoError(t, err)
518 require.True(t, tracePrepareStartCalled)
519 require.True(t, tracePrepareEndCalled)
520 })
521 }
522
523 func TestTraceConnect(t *testing.T) {
524 t.Parallel()
525
526 tracer := &testTracer{}
527
528 config := defaultConnTestRunner.CreateConfig(context.Background(), t)
529 config.Tracer = tracer
530
531 traceConnectStartCalled := false
532 tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
533 traceConnectStartCalled = true
534 require.NotNil(t, data.ConnConfig)
535 return context.WithValue(ctx, ctxKey("fromTraceConnectStart"), "foo")
536 }
537
538 traceConnectEndCalled := false
539 tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
540 traceConnectEndCalled = true
541 require.NotNil(t, data.Conn)
542 require.NoError(t, data.Err)
543 }
544
545 conn1, err := pgx.ConnectConfig(context.Background(), config)
546 require.NoError(t, err)
547 defer conn1.Close(context.Background())
548 require.True(t, traceConnectStartCalled)
549 require.True(t, traceConnectEndCalled)
550
551 config, err = pgx.ParseConfig("host=/invalid")
552 require.NoError(t, err)
553 config.Tracer = tracer
554
555 traceConnectStartCalled = false
556 traceConnectEndCalled = false
557 tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
558 traceConnectEndCalled = true
559 require.Nil(t, data.Conn)
560 require.Error(t, data.Err)
561 }
562
563 conn2, err := pgx.ConnectConfig(context.Background(), config)
564 require.Nil(t, conn2)
565 require.Error(t, err)
566 require.True(t, traceConnectStartCalled)
567 require.True(t, traceConnectEndCalled)
568 }
569
View as plain text