1
2 package pgxtest
3
4 import (
5 "context"
6 "fmt"
7 "reflect"
8 "regexp"
9 "strconv"
10 "testing"
11
12 "github.com/jackc/pgx/v5"
13 )
14
15 var AllQueryExecModes = []pgx.QueryExecMode{
16 pgx.QueryExecModeCacheStatement,
17 pgx.QueryExecModeCacheDescribe,
18 pgx.QueryExecModeDescribeExec,
19 pgx.QueryExecModeExec,
20 pgx.QueryExecModeSimpleProtocol,
21 }
22
23
24 var KnownOIDQueryExecModes = []pgx.QueryExecMode{
25 pgx.QueryExecModeCacheStatement,
26 pgx.QueryExecModeCacheDescribe,
27 pgx.QueryExecModeDescribeExec,
28 }
29
30
31
32 type ConnTestRunner struct {
33
34 CreateConfig func(ctx context.Context, t testing.TB) *pgx.ConnConfig
35
36
37 AfterConnect func(ctx context.Context, t testing.TB, conn *pgx.Conn)
38
39
40 AfterTest func(ctx context.Context, t testing.TB, conn *pgx.Conn)
41
42
43 CloseConn func(ctx context.Context, t testing.TB, conn *pgx.Conn)
44 }
45
46
47 func DefaultConnTestRunner() ConnTestRunner {
48 return ConnTestRunner{
49 CreateConfig: func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
50 config, err := pgx.ParseConfig("")
51 if err != nil {
52 t.Fatalf("ParseConfig failed: %v", err)
53 }
54 return config
55 },
56 AfterConnect: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {},
57 AfterTest: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {},
58 CloseConn: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
59 err := conn.Close(ctx)
60 if err != nil {
61 t.Errorf("Close failed: %v", err)
62 }
63 },
64 }
65 }
66
67 func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) {
68 t.Helper()
69
70 config := ctr.CreateConfig(ctx, t)
71 conn, err := pgx.ConnectConfig(ctx, config)
72 if err != nil {
73 t.Fatalf("ConnectConfig failed: %v", err)
74 }
75 defer ctr.CloseConn(ctx, t, conn)
76
77 ctr.AfterConnect(ctx, t, conn)
78 f(ctx, t, conn)
79 ctr.AfterTest(ctx, t, conn)
80 }
81
82
83
84 func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) {
85 if modes == nil {
86 modes = AllQueryExecModes
87 }
88
89 for _, mode := range modes {
90 ctrWithMode := ctr
91 ctrWithMode.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
92 config := ctr.CreateConfig(ctx, t)
93 config.DefaultQueryExecMode = mode
94 return config
95 }
96
97 t.Run(mode.String(),
98 func(t *testing.T) {
99 ctrWithMode.RunTest(ctx, t, f)
100 },
101 )
102 }
103 }
104
// ValueRoundTripTest describes one case for RunValueRoundTripTests.
//
// Param is sent as the query's $1 argument. Result is the destination passed
// to Scan; when it is a pointer, the value it points to is what gets checked.
// Test receives that checked value and reports whether it is acceptable.
type ValueRoundTripTest struct {
	Param, Result any
	Test          func(any) bool
}
110
111 func RunValueRoundTripTests(
112 ctx context.Context,
113 t testing.TB,
114 ctr ConnTestRunner,
115 modes []pgx.QueryExecMode,
116 pgTypeName string,
117 tests []ValueRoundTripTest,
118 ) {
119 t.Helper()
120
121 if modes == nil {
122 modes = AllQueryExecModes
123 }
124
125 ctr.RunTest(ctx, t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
126 t.Helper()
127
128 sql := fmt.Sprintf("select $1::%s", pgTypeName)
129
130 for i, tt := range tests {
131 for _, mode := range modes {
132 err := conn.QueryRow(ctx, sql, mode, tt.Param).Scan(tt.Result)
133 if err != nil {
134 t.Errorf("%d. %v: %v", i, mode, err)
135 }
136
137 result := reflect.ValueOf(tt.Result)
138 if result.Kind() == reflect.Ptr {
139 result = result.Elem()
140 }
141
142 if !tt.Test(result.Interface()) {
143 t.Errorf("%d. %v: unexpected result for %v: %v", i, mode, tt.Param, result.Interface())
144 }
145 }
146 }
147 })
148 }
149
150
151 func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) {
152 if conn.PgConn().ParameterStatus("crdb_version") != "" {
153 t.Skip(msg)
154 }
155 }
156
157 func SkipPostgreSQLVersionLessThan(t testing.TB, conn *pgx.Conn, minVersion int64) {
158 serverVersionStr := conn.PgConn().ParameterStatus("server_version")
159 serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
160
161 if serverVersionStr == "" {
162 return
163 }
164
165 serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
166 if err != nil {
167 t.Fatalf("postgres version parsed failed: %s", err)
168 }
169
170 if serverVersion < minVersion {
171 t.Skipf("Test requires PostgreSQL v%d+", minVersion)
172 }
173 }
174
View as plain text