1 package pgx_test
2
3 import (
4 "context"
5 "os"
6 "testing"
7
8 "github.com/stretchr/testify/assert"
9
10 "github.com/jackc/pgx/v5"
11 "github.com/jackc/pgx/v5/pgconn"
12 "github.com/jackc/pgx/v5/pgxtest"
13 "github.com/stretchr/testify/require"
14 )
15
16 var defaultConnTestRunner pgxtest.ConnTestRunner
17
18 func init() {
19 defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
20 defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
21 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
22 require.NoError(t, err)
23 return config
24 }
25 }
26
27 func mustConnectString(t testing.TB, connString string) *pgx.Conn {
28 conn, err := pgx.Connect(context.Background(), connString)
29 if err != nil {
30 t.Fatalf("Unable to establish connection: %v", err)
31 }
32 return conn
33 }
34
35 func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
36 config, err := pgx.ParseConfig(connString)
37 require.Nil(t, err)
38 return config
39 }
40
41 func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
42 conn, err := pgx.ConnectConfig(context.Background(), config)
43 if err != nil {
44 t.Fatalf("Unable to establish connection: %v", err)
45 }
46 return conn
47 }
48
49 func closeConn(t testing.TB, conn *pgx.Conn) {
50 err := conn.Close(context.Background())
51 if err != nil {
52 t.Fatalf("conn.Close unexpectedly failed: %v", err)
53 }
54 }
55
56 func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
57 var err error
58 if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
59 t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
60 }
61 return
62 }
63
64
65 func ensureConnValid(t testing.TB, conn *pgx.Conn) {
66 var sum, rowCount int32
67
68 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
69 if err != nil {
70 t.Fatalf("conn.Query failed: %v", err)
71 }
72 defer rows.Close()
73
74 for rows.Next() {
75 var n int32
76 rows.Scan(&n)
77 sum += n
78 rowCount++
79 }
80
81 if rows.Err() != nil {
82 t.Fatalf("conn.Query failed: %v", rows.Err())
83 }
84
85 if rowCount != 10 {
86 t.Error("Select called onDataRow wrong number of times")
87 }
88 if sum != 55 {
89 t.Error("Wrong values returned")
90 }
91 }
92
93 func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
94 if !assert.NotNil(t, expected) {
95 return
96 }
97 if !assert.NotNil(t, actual) {
98 return
99 }
100
101 assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
102 assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
103 assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
104 assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
105 assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
106 assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
107 assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
108 assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
109 assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
110 assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
111 assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
112 assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
113
114
115 assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
116 assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
117
118 if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
119 if expected.TLSConfig != nil {
120 assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
121 assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
122 }
123 }
124
125 if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
126 for i := range expected.Fallbacks {
127 assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
128 assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
129
130 if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
131 if expected.Fallbacks[i].TLSConfig != nil {
132 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
133 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
134 }
135 }
136 }
137 }
138 }
139
View as plain text