1 package testutil
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "os"
8 "reflect"
9 "testing"
10
11 "github.com/jackc/pgtype"
12 "github.com/jackc/pgx/v4"
13 _ "github.com/jackc/pgx/v4/stdlib"
14 _ "github.com/lib/pq"
15 )
16
// MustConnectDatabaseSQL returns a database/sql handle for the database named
// by the PGX_TEST_DATABASE environment variable.
//
// driverName is the import path of the driver package under test
// ("github.com/lib/pq" or "github.com/jackc/pgx/stdlib"). It is translated to
// the name that package registers with database/sql ("postgres" or "pgx").
// Any other driverName, or an error from sql.Open, fails the test immediately.
//
// Note that sql.Open may only validate its arguments without connecting, so
// connectivity problems can surface on first use rather than here.
func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
	registered := map[string]string{
		"github.com/lib/pq":           "postgres",
		"github.com/jackc/pgx/stdlib": "pgx",
	}

	sqlDriverName, known := registered[driverName]
	if !known {
		t.Fatalf("Unknown driver %v", driverName)
	}

	db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		t.Fatal(err)
	}
	return db
}
35
36 func MustConnectPgx(t testing.TB) *pgx.Conn {
37 conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
38 if err != nil {
39 t.Fatal(err)
40 }
41
42 return conn
43 }
44
// MustClose calls conn.Close and fails the test immediately if it returns an
// error. It accepts any value with a Close() error method, such as *sql.DB or
// *sql.Stmt.
func MustClose(t testing.TB, conn interface {
	Close() error
}) {
	if closeErr := conn.Close(); closeErr != nil {
		t.Fatal(closeErr)
	}
}
53
// MustCloseContext calls conn.Close with a background context and fails the
// test immediately if it returns an error. It accepts any value whose Close
// method takes a context.Context, such as *pgx.Conn.
func MustCloseContext(t testing.TB, conn interface {
	Close(context.Context) error
}) {
	ctx := context.Background()
	if closeErr := conn.Close(ctx); closeErr != nil {
		t.Fatal(closeErr)
	}
}
62
63 type forceTextEncoder struct {
64 e pgtype.TextEncoder
65 }
66
67 func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
68 return f.e.EncodeText(ci, buf)
69 }
70
71 type forceBinaryEncoder struct {
72 e pgtype.BinaryEncoder
73 }
74
75 func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
76 return f.e.EncodeBinary(ci, buf)
77 }
78
79 func ForceEncoder(e interface{}, formatCode int16) interface{} {
80 switch formatCode {
81 case pgx.TextFormatCode:
82 if e, ok := e.(pgtype.TextEncoder); ok {
83 return forceTextEncoder{e: e}
84 }
85 case pgx.BinaryFormatCode:
86 if e, ok := e.(pgtype.BinaryEncoder); ok {
87 return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
88 }
89 }
90 return nil
91 }
92
93 func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
94 TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
95 return reflect.DeepEqual(a, b)
96 })
97 }
98
99 func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
100 TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
101 for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
102 TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
103 }
104 }
105
106 func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
107 conn := MustConnectPgx(t)
108 defer MustCloseContext(t, conn)
109
110 _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName))
111 if err != nil {
112 t.Fatal(err)
113 }
114
115 formats := []struct {
116 name string
117 formatCode int16
118 }{
119 {name: "TextFormat", formatCode: pgx.TextFormatCode},
120 {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
121 }
122
123 for i, v := range values {
124 for _, paramFormat := range formats {
125 for _, resultFormat := range formats {
126 vEncoder := ForceEncoder(v, paramFormat.formatCode)
127 if vEncoder == nil {
128 t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name)
129 continue
130 }
131 switch resultFormat.formatCode {
132 case pgx.TextFormatCode:
133 if _, ok := v.(pgtype.TextEncoder); !ok {
134 t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
135 continue
136 }
137 case pgx.BinaryFormatCode:
138 if _, ok := v.(pgtype.BinaryEncoder); !ok {
139 t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
140 continue
141 }
142 }
143
144
145 derefV := v
146 refVal := reflect.ValueOf(v)
147 if refVal.Kind() == reflect.Ptr {
148 derefV = refVal.Elem().Interface()
149 }
150
151 result := reflect.New(reflect.TypeOf(derefV))
152
153 err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface())
154 if err != nil {
155 t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err)
156 }
157
158 if !eqFunc(result.Elem().Interface(), derefV) {
159 t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface())
160 }
161 }
162 }
163 }
164 }
165
166 func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
167 conn := MustConnectDatabaseSQL(t, driverName)
168 defer MustClose(t, conn)
169
170 ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
171 if err != nil {
172 t.Fatal(err)
173 }
174
175 for i, v := range values {
176
177 derefV := v
178 refVal := reflect.ValueOf(v)
179 if refVal.Kind() == reflect.Ptr {
180 derefV = refVal.Elem().Interface()
181 }
182
183 result := reflect.New(reflect.TypeOf(derefV))
184 err := ps.QueryRow(v).Scan(result.Interface())
185 if err != nil {
186 t.Errorf("%v %d: %v", driverName, i, err)
187 }
188
189 if !eqFunc(result.Elem().Interface(), derefV) {
190 t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
191 }
192 }
193 }
194
// NormalizeTest describes one normalization check for the
// TestSuccessfulNormalize* helpers. SQL is run with no query arguments, and
// its result is scanned into a new value of Value's type (the pointed-to type
// when Value is a pointer). The scanned result is then compared with Value
// (the pointee when Value is a pointer).
type NormalizeTest struct {
	SQL   string      // query to execute; the helpers pass it no arguments
	Value interface{} // expected result; its type also selects the scan target
}
199
200 func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
201 TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
202 return reflect.DeepEqual(a, b)
203 })
204 }
205
206 func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
207 TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
208 for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
209 TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
210 }
211 }
212
213 func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
214 conn := MustConnectPgx(t)
215 defer MustCloseContext(t, conn)
216
217 formats := []struct {
218 name string
219 formatCode int16
220 }{
221 {name: "TextFormat", formatCode: pgx.TextFormatCode},
222 {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
223 }
224
225 for i, tt := range tests {
226 for _, fc := range formats {
227 psName := fmt.Sprintf("test%d", i)
228 _, err := conn.Prepare(context.Background(), psName, tt.SQL)
229 if err != nil {
230 t.Fatal(err)
231 }
232
233 queryResultFormats := pgx.QueryResultFormats{fc.formatCode}
234 if ForceEncoder(tt.Value, fc.formatCode) == nil {
235 t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
236 continue
237 }
238
239 derefV := tt.Value
240 refVal := reflect.ValueOf(tt.Value)
241 if refVal.Kind() == reflect.Ptr {
242 derefV = refVal.Elem().Interface()
243 }
244
245 result := reflect.New(reflect.TypeOf(derefV))
246 err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface())
247 if err != nil {
248 t.Errorf("%v %d: %v", fc.name, i, err)
249 }
250
251 if !eqFunc(result.Elem().Interface(), derefV) {
252 t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
253 }
254 }
255 }
256 }
257
258 func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
259 conn := MustConnectDatabaseSQL(t, driverName)
260 defer MustClose(t, conn)
261
262 for i, tt := range tests {
263 ps, err := conn.Prepare(tt.SQL)
264 if err != nil {
265 t.Errorf("%d. %v", i, err)
266 continue
267 }
268
269
270 derefV := tt.Value
271 refVal := reflect.ValueOf(tt.Value)
272 if refVal.Kind() == reflect.Ptr {
273 derefV = refVal.Elem().Interface()
274 }
275
276 result := reflect.New(reflect.TypeOf(derefV))
277 err = ps.QueryRow().Scan(result.Interface())
278 if err != nil {
279 t.Errorf("%v %d: %v", driverName, i, err)
280 }
281
282 if !eqFunc(result.Elem().Interface(), derefV) {
283 t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
284 }
285 }
286 }
287
288 func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) {
289 TestPgxGoZeroToNullConversion(t, pgTypeName, zero)
290 for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
291 TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero)
292 }
293 }
294
295 func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) {
296 TestPgxNullToGoZeroConversion(t, pgTypeName, zero)
297 for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
298 TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero)
299 }
300 }
301
302 func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) {
303 conn := MustConnectPgx(t)
304 defer MustCloseContext(t, conn)
305
306 _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName))
307 if err != nil {
308 t.Fatal(err)
309 }
310
311 formats := []struct {
312 name string
313 formatCode int16
314 }{
315 {name: "TextFormat", formatCode: pgx.TextFormatCode},
316 {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
317 }
318
319 for _, paramFormat := range formats {
320 vEncoder := ForceEncoder(zero, paramFormat.formatCode)
321 if vEncoder == nil {
322 t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name)
323 continue
324 }
325
326 var result bool
327 err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result)
328 if err != nil {
329 t.Errorf("Param %s: %v", paramFormat.name, err)
330 }
331
332 if !result {
333 t.Errorf("Param %s: did not convert zero to null", paramFormat.name)
334 }
335 }
336 }
337
338 func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) {
339 conn := MustConnectPgx(t)
340 defer MustCloseContext(t, conn)
341
342 _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName))
343 if err != nil {
344 t.Fatal(err)
345 }
346
347 formats := []struct {
348 name string
349 formatCode int16
350 }{
351 {name: "TextFormat", formatCode: pgx.TextFormatCode},
352 {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
353 }
354
355 for _, resultFormat := range formats {
356
357 switch resultFormat.formatCode {
358 case pgx.TextFormatCode:
359 if _, ok := zero.(pgtype.TextEncoder); !ok {
360 t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name)
361 continue
362 }
363 case pgx.BinaryFormatCode:
364 if _, ok := zero.(pgtype.BinaryEncoder); !ok {
365 t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name)
366 continue
367 }
368 }
369
370
371 derefZero := zero
372 refVal := reflect.ValueOf(zero)
373 if refVal.Kind() == reflect.Ptr {
374 derefZero = refVal.Elem().Interface()
375 }
376
377 result := reflect.New(reflect.TypeOf(derefZero))
378
379 err := conn.QueryRow(context.Background(), "test").Scan(result.Interface())
380 if err != nil {
381 t.Errorf("Result %s: %v", resultFormat.name, err)
382 }
383
384 if !reflect.DeepEqual(result.Elem().Interface(), derefZero) {
385 t.Errorf("Result %s: did not convert null to zero", resultFormat.name)
386 }
387 }
388 }
389
390 func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) {
391 conn := MustConnectDatabaseSQL(t, driverName)
392 defer MustClose(t, conn)
393
394 ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName))
395 if err != nil {
396 t.Fatal(err)
397 }
398
399 var result bool
400 err = ps.QueryRow(zero).Scan(&result)
401 if err != nil {
402 t.Errorf("%v %v", driverName, err)
403 }
404
405 if !result {
406 t.Errorf("%v: did not convert zero to null", driverName)
407 }
408 }
409
410 func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) {
411 conn := MustConnectDatabaseSQL(t, driverName)
412 defer MustClose(t, conn)
413
414 ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName))
415 if err != nil {
416 t.Fatal(err)
417 }
418
419
420 derefZero := zero
421 refVal := reflect.ValueOf(zero)
422 if refVal.Kind() == reflect.Ptr {
423 derefZero = refVal.Elem().Interface()
424 }
425
426 result := reflect.New(reflect.TypeOf(derefZero))
427
428 err = ps.QueryRow().Scan(result.Interface())
429 if err != nil {
430 t.Errorf("%v %v", driverName, err)
431 }
432
433 if !reflect.DeepEqual(result.Elem().Interface(), derefZero) {
434 t.Errorf("%s: did not convert null to zero", driverName)
435 }
436 }
437
View as plain text