1 package pgxpool_test
2
3 import (
4 "context"
5 "testing"
6 "time"
7
8 "github.com/jackc/pgx/v5/pgxpool"
9
10 "github.com/jackc/pgx/v5"
11 "github.com/jackc/pgx/v5/pgconn"
12 "github.com/stretchr/testify/assert"
13 "github.com/stretchr/testify/require"
14 )
15
16
17
18
// waitForReleaseToComplete blocks the calling goroutine for a fixed 500ms.
//
// Judging by its name, it is meant to be called after a pooled connection is
// released, giving any release work that finishes in the background time to
// land before the test inspects pool state. NOTE(review): a fixed sleep is a
// timing heuristic, not a synchronization guarantee.
func waitForReleaseToComplete() {
	const releaseGracePeriod = 500 * time.Millisecond
	time.Sleep(releaseGracePeriod)
}
22
23 type execer interface {
24 Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
25 }
26
27 func testExec(t *testing.T, ctx context.Context, db execer) {
28 results, err := db.Exec(ctx, "set time zone 'America/Chicago'")
29 require.NoError(t, err)
30 assert.EqualValues(t, "SET", results.String())
31 }
32
33 type queryer interface {
34 Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
35 }
36
37 func testQuery(t *testing.T, ctx context.Context, db queryer) {
38 var sum, rowCount int32
39
40 rows, err := db.Query(ctx, "select generate_series(1,$1)", 10)
41 require.NoError(t, err)
42
43 for rows.Next() {
44 var n int32
45 rows.Scan(&n)
46 sum += n
47 rowCount++
48 }
49
50 assert.NoError(t, rows.Err())
51 assert.Equal(t, int32(10), rowCount)
52 assert.Equal(t, int32(55), sum)
53 }
54
55 type queryRower interface {
56 QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
57 }
58
59 func testQueryRow(t *testing.T, ctx context.Context, db queryRower) {
60 var what, who string
61 err := db.QueryRow(ctx, "select 'hello', $1::text", "world").Scan(&what, &who)
62 assert.NoError(t, err)
63 assert.Equal(t, "hello", what)
64 assert.Equal(t, "world", who)
65 }
66
67 type sendBatcher interface {
68 SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
69 }
70
71 func testSendBatch(t *testing.T, ctx context.Context, db sendBatcher) {
72 batch := &pgx.Batch{}
73 batch.Queue("select 1")
74 batch.Queue("select 2")
75
76 br := db.SendBatch(ctx, batch)
77
78 var err error
79 var n int32
80 err = br.QueryRow().Scan(&n)
81 assert.NoError(t, err)
82 assert.EqualValues(t, 1, n)
83
84 err = br.QueryRow().Scan(&n)
85 assert.NoError(t, err)
86 assert.EqualValues(t, 2, n)
87
88 err = br.Close()
89 assert.NoError(t, err)
90 }
91
92 type copyFromer interface {
93 CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
94 }
95
96 func testCopyFrom(t *testing.T, ctx context.Context, db interface {
97 execer
98 queryer
99 copyFromer
100 }) {
101 _, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
102 require.NoError(t, err)
103
104 tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
105
106 inputRows := [][]any{
107 {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
108 {nil, nil, nil, nil, nil, nil, nil},
109 }
110
111 copyCount, err := db.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
112 assert.NoError(t, err)
113 assert.EqualValues(t, len(inputRows), copyCount)
114
115 rows, err := db.Query(ctx, "select * from foo")
116 assert.NoError(t, err)
117
118 var outputRows [][]any
119 for rows.Next() {
120 row, err := rows.Values()
121 if err != nil {
122 t.Errorf("Unexpected error for rows.Values(): %v", err)
123 }
124 outputRows = append(outputRows, row)
125 }
126
127 assert.NoError(t, rows.Err())
128 assert.Equal(t, inputRows, outputRows)
129 }
130
131 func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) {
132 if !assert.NotNil(t, expected) {
133 return
134 }
135 if !assert.NotNil(t, actual) {
136 return
137 }
138
139 assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
140
141
142 assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
143 assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
144 assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
145
146 assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)
147 assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
148 assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
149 assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
150 assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
151
152 assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)
153 }
154
155 func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
156 if !assert.NotNil(t, expected) {
157 return
158 }
159 if !assert.NotNil(t, actual) {
160 return
161 }
162
163 assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
164 assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
165 assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
166 assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
167 assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
168 assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
169 assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
170 assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
171 assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
172 assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
173 assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
174 assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
175 assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
176
177
178 assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
179 assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
180
181 if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
182 if expected.TLSConfig != nil {
183 assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
184 assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
185 }
186 }
187
188 if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
189 for i := range expected.Fallbacks {
190 assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
191 assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
192
193 if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
194 if expected.Fallbacks[i].TLSConfig != nil {
195 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
196 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
197 }
198 }
199 }
200 }
201 }
202
View as plain text