1 package pgx_test
2
3 import (
4 "context"
5 "os"
6 "testing"
7
8 "github.com/stretchr/testify/assert"
9
10 "github.com/jackc/pgconn"
11 "github.com/jackc/pgx/v4"
12 "github.com/stretchr/testify/require"
13 )
14
15 func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
16 t.Run("SimpleProto",
17 func(t *testing.T) {
18 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
19 require.NoError(t, err)
20
21 config.PreferSimpleProtocol = true
22 conn, err := pgx.ConnectConfig(context.Background(), config)
23 require.NoError(t, err)
24 defer func() {
25 err := conn.Close(context.Background())
26 require.NoError(t, err)
27 }()
28
29 f(t, conn)
30
31 ensureConnValid(t, conn)
32 },
33 )
34
35 t.Run("DefaultProto",
36 func(t *testing.T) {
37 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
38 require.NoError(t, err)
39
40 conn, err := pgx.ConnectConfig(context.Background(), config)
41 require.NoError(t, err)
42 defer func() {
43 err := conn.Close(context.Background())
44 require.NoError(t, err)
45 }()
46
47 f(t, conn)
48
49 ensureConnValid(t, conn)
50 },
51 )
52 }
53
54 func mustConnectString(t testing.TB, connString string) *pgx.Conn {
55 conn, err := pgx.Connect(context.Background(), connString)
56 if err != nil {
57 t.Fatalf("Unable to establish connection: %v", err)
58 }
59 return conn
60 }
61
62 func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
63 config, err := pgx.ParseConfig(connString)
64 require.Nil(t, err)
65 return config
66 }
67
68 func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
69 conn, err := pgx.ConnectConfig(context.Background(), config)
70 if err != nil {
71 t.Fatalf("Unable to establish connection: %v", err)
72 }
73 return conn
74 }
75
76 func closeConn(t testing.TB, conn *pgx.Conn) {
77 err := conn.Close(context.Background())
78 if err != nil {
79 t.Fatalf("conn.Close unexpectedly failed: %v", err)
80 }
81 }
82
83 func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
84 var err error
85 if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
86 t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
87 }
88 return
89 }
90
91
92 func ensureConnValid(t *testing.T, conn *pgx.Conn) {
93 var sum, rowCount int32
94
95 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
96 if err != nil {
97 t.Fatalf("conn.Query failed: %v", err)
98 }
99 defer rows.Close()
100
101 for rows.Next() {
102 var n int32
103 rows.Scan(&n)
104 sum += n
105 rowCount++
106 }
107
108 if rows.Err() != nil {
109 t.Fatalf("conn.Query failed: %v", err)
110 }
111
112 if rowCount != 10 {
113 t.Error("Select called onDataRow wrong number of times")
114 }
115 if sum != 55 {
116 t.Error("Wrong values returned")
117 }
118 }
119
120 func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
121 if !assert.NotNil(t, expected) {
122 return
123 }
124 if !assert.NotNil(t, actual) {
125 return
126 }
127
128 assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
129 assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
130 assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
131
132 assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
133 assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
134
135 assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
136 assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
137 assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
138 assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
139 assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
140 assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
141 assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
142
143
144 assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
145 assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
146
147 if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
148 if expected.TLSConfig != nil {
149 assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
150 assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
151 }
152 }
153
154 if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
155 for i := range expected.Fallbacks {
156 assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
157 assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
158
159 if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
160 if expected.Fallbacks[i].TLSConfig != nil {
161 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
162 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
163 }
164 }
165 }
166 }
167 }
168
169 func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) {
170 if conn.PgConn().ParameterStatus("crdb_version") != "" {
171 t.Skip(msg)
172 }
173 }
174
View as plain text