1 package pq
2
3
4
5 import (
6 "bytes"
7 _ "crypto/sha256"
8 "crypto/tls"
9 "crypto/x509"
10 "database/sql"
11 "fmt"
12 "io"
13 "net"
14 "os"
15 "path/filepath"
16 "strings"
17 "testing"
18 "time"
19 )
20
21 func maybeSkipSSLTests(t *testing.T) {
22
23 if os.Getenv("PQSSLCERTTEST_PATH") == "" {
24 t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests")
25 }
26
27 value := os.Getenv("PQGOSSLTESTS")
28 if value == "" || value == "0" {
29 t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests")
30 } else if value != "1" {
31 t.Fatalf("unexpected value %q for PQGOSSLTESTS", value)
32 }
33 }
34
35 func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) {
36 db, err := openTestConnConninfo(conninfo)
37 if err != nil {
38
39 t.Fatal(err)
40 }
41
42 tx, err := db.Begin()
43 if err == nil {
44 return db, tx.Rollback()
45 }
46 _ = db.Close()
47 return nil, err
48 }
49
50 func checkSSLSetup(t *testing.T, conninfo string) {
51 _, err := openSSLConn(t, conninfo)
52 if pge, ok := err.(*Error); ok {
53 if pge.Code.Name() != "invalid_authorization_specification" {
54 t.Fatalf("unexpected error code '%s'", pge.Code.Name())
55 }
56 } else {
57 t.Fatalf("expected %T, got %v", (*Error)(nil), err)
58 }
59 }
60
61
62 func TestSSLConnection(t *testing.T) {
63 maybeSkipSSLTests(t)
64
65 checkSSLSetup(t, "sslmode=disable user=pqgossltest")
66
67 db, err := openSSLConn(t, "sslmode=require user=pqgossltest")
68 if err != nil {
69 t.Fatal(err)
70 }
71 rows, err := db.Query("SELECT 1")
72 if err != nil {
73 t.Fatal(err)
74 }
75 rows.Close()
76 }
77
78
79 func TestSSLVerifyFull(t *testing.T) {
80 maybeSkipSSLTests(t)
81
82 checkSSLSetup(t, "sslmode=disable user=pqgossltest")
83
84
85 _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest")
86 if err == nil {
87 t.Fatal("expected error")
88 }
89 _, ok := err.(x509.UnknownAuthorityError)
90 if !ok {
91 _, ok := err.(x509.HostnameError)
92 if !ok {
93 t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err)
94 }
95 }
96
97 rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
98 rootCert := "sslrootcert=" + rootCertPath + " "
99
100 _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest")
101 if err == nil {
102 t.Fatal("expected error")
103 }
104 _, ok = err.(x509.HostnameError)
105 if !ok {
106 t.Fatalf("expected x509.HostnameError, got %#+v", err)
107 }
108
109 _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest")
110 if err != nil {
111 t.Fatal(err)
112 }
113 }
114
115
116 func TestSSLRequireWithRootCert(t *testing.T) {
117 maybeSkipSSLTests(t)
118
119 checkSSLSetup(t, "sslmode=disable user=pqgossltest")
120
121 bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt")
122 bogusRootCert := "sslrootcert=" + bogusRootCertPath + " "
123
124
125 _, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest")
126 if err == nil {
127 t.Fatal("expected error")
128 }
129 _, ok := err.(x509.UnknownAuthorityError)
130 if !ok {
131 t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err)
132 }
133
134 nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt")
135 nonExistentCert := "sslrootcert=" + nonExistentCertPath + " "
136
137
138 _, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest")
139 if err != nil {
140 t.Fatal(err)
141 }
142
143 rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
144 rootCert := "sslrootcert=" + rootCertPath + " "
145
146
147 _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest")
148 if err != nil {
149 t.Fatal(err)
150 }
151
152 _, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest")
153 if err != nil {
154 t.Fatal(err)
155 }
156 }
157
158
159 func TestSSLVerifyCA(t *testing.T) {
160 maybeSkipSSLTests(t)
161
162 checkSSLSetup(t, "sslmode=disable user=pqgossltest")
163
164
165 {
166 _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest")
167 if _, ok := err.(x509.UnknownAuthorityError); !ok {
168 t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err)
169 }
170 }
171
172
173 {
174 _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''")
175 if _, ok := err.(x509.UnknownAuthorityError); !ok {
176 t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err)
177 }
178 }
179
180 rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
181 rootCert := "sslrootcert=" + rootCertPath + " "
182
183 if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil {
184 t.Fatal(err)
185 }
186
187 if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil {
188 t.Fatal(err)
189 }
190 }
191
192
193 func TestSSLClientCertificates(t *testing.T) {
194 maybeSkipSSLTests(t)
195
196 checkSSLSetup(t, "sslmode=disable user=pqgossltest")
197
198 const baseinfo = "sslmode=require user=pqgosslcert"
199
200
201 {
202 _, err := openSSLConn(t, baseinfo)
203 if pge, ok := err.(*Error); ok {
204 if pge.Code.Name() != "invalid_authorization_specification" {
205 t.Fatalf("unexpected error code '%s'", pge.Code.Name())
206 }
207 } else {
208 t.Fatalf("expected %T, got %v", (*Error)(nil), err)
209 }
210 }
211
212
213 {
214 _, err := openSSLConn(t, baseinfo+" sslcert=''")
215 if pge, ok := err.(*Error); ok {
216 if pge.Code.Name() != "invalid_authorization_specification" {
217 t.Fatalf("unexpected error code '%s'", pge.Code.Name())
218 }
219 } else {
220 t.Fatalf("expected %T, got %v", (*Error)(nil), err)
221 }
222 }
223
224
225 {
226 _, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist")
227 if pge, ok := err.(*Error); ok {
228 if pge.Code.Name() != "invalid_authorization_specification" {
229 t.Fatalf("unexpected error code '%s'", pge.Code.Name())
230 }
231 } else {
232 t.Fatalf("expected %T, got %v", (*Error)(nil), err)
233 }
234 }
235
236 certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH")
237 if !ok {
238 t.Fatalf("PQSSLCERTTEST_PATH not present in environment")
239 }
240
241 sslcert := filepath.Join(certpath, "postgresql.crt")
242
243
244 {
245 _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert)
246 if _, ok := err.(*os.PathError); !ok {
247 t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
248 }
249 }
250
251
252 {
253 _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''")
254 if _, ok := err.(*os.PathError); !ok {
255 t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
256 }
257 }
258
259
260 {
261 _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist")
262 if _, ok := err.(*os.PathError); !ok {
263 t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
264 }
265 }
266
267
268 if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions {
269 t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err)
270 }
271
272 sslkey := filepath.Join(certpath, "postgresql.key")
273
274
275 if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil {
276 t.Fatal(err)
277 } else {
278 rows, err := db.Query("SELECT 1")
279 if err != nil {
280 t.Fatal(err)
281 }
282 if err := rows.Close(); err != nil {
283 t.Fatal(err)
284 }
285 if err := db.Close(); err != nil {
286 t.Fatal(err)
287 }
288 }
289 }
290
291
292 func TestSNISupport(t *testing.T) {
293 t.Parallel()
294 tests := []struct {
295 name string
296 conn_param string
297 hostname string
298 expected_sni string
299 }{
300 {
301 name: "SNI is set by default",
302 conn_param: "",
303 hostname: "localhost",
304 expected_sni: "localhost",
305 },
306 {
307 name: "SNI is passed when asked for",
308 conn_param: "sslsni=1",
309 hostname: "localhost",
310 expected_sni: "localhost",
311 },
312 {
313 name: "SNI is not passed when disabled",
314 conn_param: "sslsni=0",
315 hostname: "localhost",
316 expected_sni: "",
317 },
318 {
319 name: "SNI is not set for IPv4",
320 conn_param: "",
321 hostname: "127.0.0.1",
322 expected_sni: "",
323 },
324 }
325 for _, tt := range tests {
326 tt := tt
327 t.Run(tt.name, func(t *testing.T) {
328 t.Parallel()
329
330
331 listener, err := net.Listen("tcp", "127.0.0.1:")
332 if err != nil {
333 t.Fatal(err)
334 }
335 serverErrChan := make(chan error, 1)
336 serverSNINameChan := make(chan string, 1)
337 go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)
338
339 defer listener.Close()
340 defer close(serverErrChan)
341 defer close(serverSNINameChan)
342
343
344
345 port := strings.Split(listener.Addr().String(), ":")[1]
346 connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param)
347
348
349
350 db, _ := sql.Open("postgres", connStr)
351 _, _ = db.Exec("SELECT 1")
352
353
354 select {
355 case sniHost := <-serverSNINameChan:
356 if sniHost != tt.expected_sni {
357 t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost)
358 }
359 case err = <-serverErrChan:
360 t.Fatalf("mock server failed with error: %+v", err)
361 case <-time.After(time.Second):
362 t.Fatal("exceeded connection timeout without erroring out")
363 }
364 })
365 }
366 }
367
368
369
370
371
372 func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
373 var sniHost string
374
375 conn, err := listener.Accept()
376 if err != nil {
377 errChan <- err
378 return
379 }
380 defer conn.Close()
381
382 err = conn.SetDeadline(time.Now().Add(time.Second))
383 if err != nil {
384 errChan <- err
385 return
386 }
387
388
389 startupMessage := make([]byte, 8)
390 if _, err := io.ReadFull(conn, startupMessage); err != nil {
391 errChan <- err
392 return
393 }
394
395 if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
396 errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
397 return
398 }
399
400
401 _, err = conn.Write([]byte("S"))
402 if err != nil {
403 errChan <- err
404 return
405 }
406
407
408
409 srv := tls.Server(conn, &tls.Config{
410 GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
411 sniHost = argHello.ServerName
412 return nil, nil
413 },
414 })
415 defer srv.Close()
416
417
418 _ = srv.Handshake()
419
420 nameChan <- sniHost
421 }
422
View as plain text