...

Source file src/github.com/jackc/pgx/v5/tracer_test.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/jackc/pgx/v5"
     9  	"github.com/jackc/pgx/v5/pgxtest"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  type testTracer struct {
    14  	traceQueryStart    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context
    15  	traceQueryEnd      func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData)
    16  	traceBatchStart    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context
    17  	traceBatchQuery    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData)
    18  	traceBatchEnd      func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData)
    19  	traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context
    20  	traceCopyFromEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData)
    21  	tracePrepareStart  func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context
    22  	tracePrepareEnd    func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData)
    23  	traceConnectStart  func(ctx context.Context, data pgx.TraceConnectStartData) context.Context
    24  	traceConnectEnd    func(ctx context.Context, data pgx.TraceConnectEndData)
    25  }
    26  
    27  type ctxKey string
    28  
    29  func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
    30  	if tt.traceQueryStart != nil {
    31  		return tt.traceQueryStart(ctx, conn, data)
    32  	}
    33  	return ctx
    34  }
    35  
    36  func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
    37  	if tt.traceQueryEnd != nil {
    38  		tt.traceQueryEnd(ctx, conn, data)
    39  	}
    40  }
    41  
    42  func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
    43  	if tt.traceBatchStart != nil {
    44  		return tt.traceBatchStart(ctx, conn, data)
    45  	}
    46  	return ctx
    47  }
    48  
    49  func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
    50  	if tt.traceBatchQuery != nil {
    51  		tt.traceBatchQuery(ctx, conn, data)
    52  	}
    53  }
    54  
    55  func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
    56  	if tt.traceBatchEnd != nil {
    57  		tt.traceBatchEnd(ctx, conn, data)
    58  	}
    59  }
    60  
    61  func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
    62  	if tt.traceCopyFromStart != nil {
    63  		return tt.traceCopyFromStart(ctx, conn, data)
    64  	}
    65  	return ctx
    66  }
    67  
    68  func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
    69  	if tt.traceCopyFromEnd != nil {
    70  		tt.traceCopyFromEnd(ctx, conn, data)
    71  	}
    72  }
    73  
    74  func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
    75  	if tt.tracePrepareStart != nil {
    76  		return tt.tracePrepareStart(ctx, conn, data)
    77  	}
    78  	return ctx
    79  }
    80  
    81  func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
    82  	if tt.tracePrepareEnd != nil {
    83  		tt.tracePrepareEnd(ctx, conn, data)
    84  	}
    85  }
    86  
    87  func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
    88  	if tt.traceConnectStart != nil {
    89  		return tt.traceConnectStart(ctx, data)
    90  	}
    91  	return ctx
    92  }
    93  
    94  func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
    95  	if tt.traceConnectEnd != nil {
    96  		tt.traceConnectEnd(ctx, data)
    97  	}
    98  }
    99  
   100  func TestTraceExec(t *testing.T) {
   101  	t.Parallel()
   102  
   103  	tracer := &testTracer{}
   104  
   105  	ctr := defaultConnTestRunner
   106  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   107  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   108  		config.Tracer = tracer
   109  		return config
   110  	}
   111  
   112  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   113  	defer cancel()
   114  
   115  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   116  		traceQueryStartCalled := false
   117  		tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
   118  			traceQueryStartCalled = true
   119  			require.Equal(t, `select $1::text`, data.SQL)
   120  			require.Len(t, data.Args, 1)
   121  			require.Equal(t, `testing`, data.Args[0])
   122  			return context.WithValue(ctx, ctxKey(ctxKey("fromTraceQueryStart")), "foo")
   123  		}
   124  
   125  		traceQueryEndCalled := false
   126  		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
   127  			traceQueryEndCalled = true
   128  			require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceQueryStart"))))
   129  			require.Equal(t, `SELECT 1`, data.CommandTag.String())
   130  			require.NoError(t, data.Err)
   131  		}
   132  
   133  		_, err := conn.Exec(ctx, `select $1::text`, "testing")
   134  		require.NoError(t, err)
   135  		require.True(t, traceQueryStartCalled)
   136  		require.True(t, traceQueryEndCalled)
   137  	})
   138  }
   139  
   140  func TestTraceQuery(t *testing.T) {
   141  	t.Parallel()
   142  
   143  	tracer := &testTracer{}
   144  
   145  	ctr := defaultConnTestRunner
   146  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   147  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   148  		config.Tracer = tracer
   149  		return config
   150  	}
   151  
   152  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   153  	defer cancel()
   154  
   155  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   156  		traceQueryStartCalled := false
   157  		tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
   158  			traceQueryStartCalled = true
   159  			require.Equal(t, `select $1::text`, data.SQL)
   160  			require.Len(t, data.Args, 1)
   161  			require.Equal(t, `testing`, data.Args[0])
   162  			return context.WithValue(ctx, ctxKey("fromTraceQueryStart"), "foo")
   163  		}
   164  
   165  		traceQueryEndCalled := false
   166  		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
   167  			traceQueryEndCalled = true
   168  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceQueryStart")))
   169  			require.Equal(t, `SELECT 1`, data.CommandTag.String())
   170  			require.NoError(t, data.Err)
   171  		}
   172  
   173  		var s string
   174  		err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s)
   175  		require.NoError(t, err)
   176  		require.Equal(t, "testing", s)
   177  		require.True(t, traceQueryStartCalled)
   178  		require.True(t, traceQueryEndCalled)
   179  	})
   180  }
   181  
   182  func TestTraceBatchNormal(t *testing.T) {
   183  	t.Parallel()
   184  
   185  	tracer := &testTracer{}
   186  
   187  	ctr := defaultConnTestRunner
   188  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   189  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   190  		config.Tracer = tracer
   191  		return config
   192  	}
   193  
   194  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   195  	defer cancel()
   196  
   197  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   198  		traceBatchStartCalled := false
   199  		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
   200  			traceBatchStartCalled = true
   201  			require.NotNil(t, data.Batch)
   202  			require.Equal(t, 2, data.Batch.Len())
   203  			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
   204  		}
   205  
   206  		traceBatchQueryCalledCount := 0
   207  		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
   208  			traceBatchQueryCalledCount++
   209  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   210  			require.NoError(t, data.Err)
   211  		}
   212  
   213  		traceBatchEndCalled := false
   214  		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
   215  			traceBatchEndCalled = true
   216  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   217  			require.NoError(t, data.Err)
   218  		}
   219  
   220  		batch := &pgx.Batch{}
   221  		batch.Queue(`select 1`)
   222  		batch.Queue(`select 2`)
   223  
   224  		br := conn.SendBatch(context.Background(), batch)
   225  		require.True(t, traceBatchStartCalled)
   226  
   227  		var n int32
   228  		err := br.QueryRow().Scan(&n)
   229  		require.NoError(t, err)
   230  		require.EqualValues(t, 1, n)
   231  		require.EqualValues(t, 1, traceBatchQueryCalledCount)
   232  
   233  		err = br.QueryRow().Scan(&n)
   234  		require.NoError(t, err)
   235  		require.EqualValues(t, 2, n)
   236  		require.EqualValues(t, 2, traceBatchQueryCalledCount)
   237  
   238  		err = br.Close()
   239  		require.NoError(t, err)
   240  
   241  		require.True(t, traceBatchEndCalled)
   242  	})
   243  }
   244  
   245  func TestTraceBatchClose(t *testing.T) {
   246  	t.Parallel()
   247  
   248  	tracer := &testTracer{}
   249  
   250  	ctr := defaultConnTestRunner
   251  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   252  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   253  		config.Tracer = tracer
   254  		return config
   255  	}
   256  
   257  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   258  	defer cancel()
   259  
   260  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   261  		traceBatchStartCalled := false
   262  		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
   263  			traceBatchStartCalled = true
   264  			require.NotNil(t, data.Batch)
   265  			require.Equal(t, 2, data.Batch.Len())
   266  			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
   267  		}
   268  
   269  		traceBatchQueryCalledCount := 0
   270  		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
   271  			traceBatchQueryCalledCount++
   272  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   273  			require.NoError(t, data.Err)
   274  		}
   275  
   276  		traceBatchEndCalled := false
   277  		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
   278  			traceBatchEndCalled = true
   279  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   280  			require.NoError(t, data.Err)
   281  		}
   282  
   283  		batch := &pgx.Batch{}
   284  		batch.Queue(`select 1`)
   285  		batch.Queue(`select 2`)
   286  
   287  		br := conn.SendBatch(context.Background(), batch)
   288  		require.True(t, traceBatchStartCalled)
   289  		err := br.Close()
   290  		require.NoError(t, err)
   291  		require.EqualValues(t, 2, traceBatchQueryCalledCount)
   292  		require.True(t, traceBatchEndCalled)
   293  	})
   294  }
   295  
   296  func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
   297  	t.Parallel()
   298  
   299  	tracer := &testTracer{}
   300  
   301  	ctr := defaultConnTestRunner
   302  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   303  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   304  		config.Tracer = tracer
   305  		return config
   306  	}
   307  
   308  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   309  	defer cancel()
   310  
   311  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   312  		traceBatchStartCalled := false
   313  		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
   314  			traceBatchStartCalled = true
   315  			require.NotNil(t, data.Batch)
   316  			require.Equal(t, 3, data.Batch.Len())
   317  			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
   318  		}
   319  
   320  		traceBatchQueryCalledCount := 0
   321  		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
   322  			traceBatchQueryCalledCount++
   323  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   324  			if traceBatchQueryCalledCount == 2 {
   325  				require.Error(t, data.Err)
   326  			} else {
   327  				require.NoError(t, data.Err)
   328  			}
   329  		}
   330  
   331  		traceBatchEndCalled := false
   332  		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
   333  			traceBatchEndCalled = true
   334  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   335  			require.Error(t, data.Err)
   336  		}
   337  
   338  		batch := &pgx.Batch{}
   339  		batch.Queue(`select 1`)
   340  		batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
   341  		batch.Queue(`select 3`)
   342  
   343  		br := conn.SendBatch(context.Background(), batch)
   344  		require.True(t, traceBatchStartCalled)
   345  
   346  		commandTag, err := br.Exec()
   347  		require.NoError(t, err)
   348  		require.Equal(t, "SELECT 1", commandTag.String())
   349  
   350  		commandTag, err = br.Exec()
   351  		require.Error(t, err)
   352  		require.Equal(t, "", commandTag.String())
   353  
   354  		commandTag, err = br.Exec()
   355  		require.Error(t, err)
   356  		require.Equal(t, "", commandTag.String())
   357  
   358  		err = br.Close()
   359  		require.Error(t, err)
   360  		require.EqualValues(t, 2, traceBatchQueryCalledCount)
   361  		require.True(t, traceBatchEndCalled)
   362  	})
   363  }
   364  
   365  func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
   366  	t.Parallel()
   367  
   368  	tracer := &testTracer{}
   369  
   370  	ctr := defaultConnTestRunner
   371  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   372  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   373  		config.Tracer = tracer
   374  		return config
   375  	}
   376  
   377  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   378  	defer cancel()
   379  
   380  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   381  		traceBatchStartCalled := false
   382  		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
   383  			traceBatchStartCalled = true
   384  			require.NotNil(t, data.Batch)
   385  			require.Equal(t, 3, data.Batch.Len())
   386  			return context.WithValue(ctx, ctxKey("fromTraceBatchStart"), "foo")
   387  		}
   388  
   389  		traceBatchQueryCalledCount := 0
   390  		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
   391  			traceBatchQueryCalledCount++
   392  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   393  			if traceBatchQueryCalledCount == 2 {
   394  				require.Error(t, data.Err)
   395  			} else {
   396  				require.NoError(t, data.Err)
   397  			}
   398  		}
   399  
   400  		traceBatchEndCalled := false
   401  		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
   402  			traceBatchEndCalled = true
   403  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceBatchStart")))
   404  			require.Error(t, data.Err)
   405  		}
   406  
   407  		batch := &pgx.Batch{}
   408  		batch.Queue(`select 1`)
   409  		batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
   410  		batch.Queue(`select 3`)
   411  
   412  		br := conn.SendBatch(context.Background(), batch)
   413  		require.True(t, traceBatchStartCalled)
   414  		err := br.Close()
   415  		require.Error(t, err)
   416  		require.EqualValues(t, 2, traceBatchQueryCalledCount)
   417  		require.True(t, traceBatchEndCalled)
   418  	})
   419  }
   420  
   421  func TestTraceCopyFrom(t *testing.T) {
   422  	t.Parallel()
   423  
   424  	tracer := &testTracer{}
   425  
   426  	ctr := defaultConnTestRunner
   427  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   428  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   429  		config.Tracer = tracer
   430  		return config
   431  	}
   432  
   433  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   434  	defer cancel()
   435  
   436  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   437  		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   438  		defer cancel()
   439  
   440  		traceCopyFromStartCalled := false
   441  		tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
   442  			traceCopyFromStartCalled = true
   443  			require.Equal(t, pgx.Identifier{"foo"}, data.TableName)
   444  			require.Equal(t, []string{"a"}, data.ColumnNames)
   445  			return context.WithValue(ctx, ctxKey("fromTraceCopyFromStart"), "foo")
   446  		}
   447  
   448  		traceCopyFromEndCalled := false
   449  		tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
   450  			traceCopyFromEndCalled = true
   451  			require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceCopyFromStart")))
   452  			require.Equal(t, `COPY 2`, data.CommandTag.String())
   453  			require.NoError(t, data.Err)
   454  		}
   455  
   456  		_, err := conn.Exec(ctx, `create temporary table foo(a int4)`)
   457  		require.NoError(t, err)
   458  
   459  		inputRows := [][]any{
   460  			{int32(1)},
   461  			{nil},
   462  		}
   463  
   464  		copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
   465  		require.NoError(t, err)
   466  		require.EqualValues(t, len(inputRows), copyCount)
   467  		require.True(t, traceCopyFromStartCalled)
   468  		require.True(t, traceCopyFromEndCalled)
   469  	})
   470  }
   471  
   472  func TestTracePrepare(t *testing.T) {
   473  	t.Parallel()
   474  
   475  	tracer := &testTracer{}
   476  
   477  	ctr := defaultConnTestRunner
   478  	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
   479  		config := defaultConnTestRunner.CreateConfig(ctx, t)
   480  		config.Tracer = tracer
   481  		return config
   482  	}
   483  
   484  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   485  	defer cancel()
   486  
   487  	pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   488  		tracePrepareStartCalled := false
   489  		tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
   490  			tracePrepareStartCalled = true
   491  			require.Equal(t, `ps`, data.Name)
   492  			require.Equal(t, `select $1::text`, data.SQL)
   493  			return context.WithValue(ctx, ctxKey("fromTracePrepareStart"), "foo")
   494  		}
   495  
   496  		tracePrepareEndCalled := false
   497  		tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
   498  			tracePrepareEndCalled = true
   499  			require.False(t, data.AlreadyPrepared)
   500  			require.NoError(t, data.Err)
   501  		}
   502  
   503  		_, err := conn.Prepare(ctx, "ps", `select $1::text`)
   504  		require.NoError(t, err)
   505  		require.True(t, tracePrepareStartCalled)
   506  		require.True(t, tracePrepareEndCalled)
   507  
   508  		tracePrepareStartCalled = false
   509  		tracePrepareEndCalled = false
   510  		tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
   511  			tracePrepareEndCalled = true
   512  			require.True(t, data.AlreadyPrepared)
   513  			require.NoError(t, data.Err)
   514  		}
   515  
   516  		_, err = conn.Prepare(ctx, "ps", `select $1::text`)
   517  		require.NoError(t, err)
   518  		require.True(t, tracePrepareStartCalled)
   519  		require.True(t, tracePrepareEndCalled)
   520  	})
   521  }
   522  
   523  func TestTraceConnect(t *testing.T) {
   524  	t.Parallel()
   525  
   526  	tracer := &testTracer{}
   527  
   528  	config := defaultConnTestRunner.CreateConfig(context.Background(), t)
   529  	config.Tracer = tracer
   530  
   531  	traceConnectStartCalled := false
   532  	tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
   533  		traceConnectStartCalled = true
   534  		require.NotNil(t, data.ConnConfig)
   535  		return context.WithValue(ctx, ctxKey("fromTraceConnectStart"), "foo")
   536  	}
   537  
   538  	traceConnectEndCalled := false
   539  	tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
   540  		traceConnectEndCalled = true
   541  		require.NotNil(t, data.Conn)
   542  		require.NoError(t, data.Err)
   543  	}
   544  
   545  	conn1, err := pgx.ConnectConfig(context.Background(), config)
   546  	require.NoError(t, err)
   547  	defer conn1.Close(context.Background())
   548  	require.True(t, traceConnectStartCalled)
   549  	require.True(t, traceConnectEndCalled)
   550  
   551  	config, err = pgx.ParseConfig("host=/invalid")
   552  	require.NoError(t, err)
   553  	config.Tracer = tracer
   554  
   555  	traceConnectStartCalled = false
   556  	traceConnectEndCalled = false
   557  	tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
   558  		traceConnectEndCalled = true
   559  		require.Nil(t, data.Conn)
   560  		require.Error(t, data.Err)
   561  	}
   562  
   563  	conn2, err := pgx.ConnectConfig(context.Background(), config)
   564  	require.Nil(t, conn2)
   565  	require.Error(t, err)
   566  	require.True(t, traceConnectStartCalled)
   567  	require.True(t, traceConnectEndCalled)
   568  }
   569  

View as plain text