1 package pgconn_test
2
3 import (
4 "bytes"
5 "compress/gzip"
6 "context"
7 "crypto/tls"
8 "errors"
9 "fmt"
10 "io"
11 "log"
12 "math"
13 "net"
14 "os"
15 "strconv"
16 "strings"
17 "testing"
18 "time"
19
20 "github.com/stretchr/testify/assert"
21 "github.com/stretchr/testify/require"
22
23 "github.com/jackc/pgx/v5"
24 "github.com/jackc/pgx/v5/internal/pgio"
25 "github.com/jackc/pgx/v5/internal/pgmock"
26 "github.com/jackc/pgx/v5/pgconn"
27 "github.com/jackc/pgx/v5/pgproto3"
28 "github.com/jackc/pgx/v5/pgtype"
29 )
30
31 const pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING"
32
33 func TestConnect(t *testing.T) {
34 tests := []struct {
35 name string
36 env string
37 }{
38 {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
39 {"TCP", "PGX_TEST_TCP_CONN_STRING"},
40 {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
41 {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
42 {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
43 }
44
45 for _, tt := range tests {
46 tt := tt
47 t.Run(tt.name, func(t *testing.T) {
48 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
49 defer cancel()
50
51 connString := os.Getenv(tt.env)
52 if connString == "" {
53 t.Skipf("Skipping due to missing environment variable %v", tt.env)
54 }
55
56 conn, err := pgconn.Connect(ctx, connString)
57 require.NoError(t, err)
58
59 closeConn(t, conn)
60 })
61 }
62 }
63
64 func TestConnectWithOptions(t *testing.T) {
65 tests := []struct {
66 name string
67 env string
68 }{
69 {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
70 {"TCP", "PGX_TEST_TCP_CONN_STRING"},
71 {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
72 {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
73 {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
74 }
75
76 for _, tt := range tests {
77 tt := tt
78 t.Run(tt.name, func(t *testing.T) {
79 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
80 defer cancel()
81
82 connString := os.Getenv(tt.env)
83 if connString == "" {
84 t.Skipf("Skipping due to missing environment variable %v", tt.env)
85 }
86 var sslOptions pgconn.ParseConfigOptions
87 sslOptions.GetSSLPassword = GetSSLPassword
88 conn, err := pgconn.ConnectWithOptions(ctx, connString, sslOptions)
89 require.NoError(t, err)
90
91 closeConn(t, conn)
92 })
93 }
94 }
95
96
97
98 func TestConnectTLS(t *testing.T) {
99 t.Parallel()
100
101 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
102 defer cancel()
103
104 connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
105 if connString == "" {
106 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
107 }
108
109 conn, err := pgconn.Connect(ctx, connString)
110 require.NoError(t, err)
111
112 result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
113 require.NoError(t, result.Err)
114 require.Len(t, result.Rows, 1)
115 require.Len(t, result.Rows[0], 1)
116 require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
117
118 closeConn(t, conn)
119 }
120
121 func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) {
122 t.Parallel()
123
124 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
125 defer cancel()
126
127 connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
128 if connString == "" {
129 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
130 }
131 if os.Getenv("PGX_SSL_PASSWORD") == "" {
132 t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
133 }
134
135 connString += " sslpassword=" + os.Getenv("PGX_SSL_PASSWORD")
136
137 conn, err := pgconn.Connect(ctx, connString)
138 require.NoError(t, err)
139
140 result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
141 require.NoError(t, result.Err)
142 require.Len(t, result.Rows, 1)
143 require.Len(t, result.Rows[0], 1)
144 require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
145
146 closeConn(t, conn)
147 }
148
149 func TestConnectTLSPasswordProtectedClientCertWithGetSSLPasswordConfigOption(t *testing.T) {
150 t.Parallel()
151
152 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
153 defer cancel()
154
155 connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
156 if connString == "" {
157 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
158 }
159 if os.Getenv("PGX_SSL_PASSWORD") == "" {
160 t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
161 }
162
163 var sslOptions pgconn.ParseConfigOptions
164 sslOptions.GetSSLPassword = GetSSLPassword
165 config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
166 require.Nil(t, err)
167
168 conn, err := pgconn.ConnectConfig(ctx, config)
169 require.NoError(t, err)
170
171 result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
172 require.NoError(t, result.Err)
173 require.Len(t, result.Rows, 1)
174 require.Len(t, result.Rows[0], 1)
175 require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
176
177 closeConn(t, conn)
178 }
179
180 type pgmockWaitStep time.Duration
181
182 func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
183 time.Sleep(time.Duration(s))
184 return nil
185 }
186
187 func TestConnectTimeout(t *testing.T) {
188 t.Parallel()
189 tests := []struct {
190 name string
191 connect func(connStr string) error
192 }{
193 {
194 name: "via context that times out",
195 connect: func(connStr string) error {
196 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
197 defer cancel()
198 _, err := pgconn.Connect(ctx, connStr)
199 return err
200 },
201 },
202 {
203 name: "via config ConnectTimeout",
204 connect: func(connStr string) error {
205 conf, err := pgconn.ParseConfig(connStr)
206 require.NoError(t, err)
207 conf.ConnectTimeout = time.Microsecond * 50
208 _, err = pgconn.ConnectConfig(context.Background(), conf)
209 return err
210 },
211 },
212 }
213 for _, tt := range tests {
214 tt := tt
215 t.Run(tt.name, func(t *testing.T) {
216 t.Parallel()
217 script := &pgmock.Script{
218 Steps: []pgmock.Step{
219 pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
220 pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
221 pgmockWaitStep(time.Millisecond * 500),
222 pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
223 pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
224 },
225 }
226
227 ln, err := net.Listen("tcp", "127.0.0.1:")
228 require.NoError(t, err)
229 defer ln.Close()
230
231 serverErrChan := make(chan error, 1)
232 go func() {
233 defer close(serverErrChan)
234
235 conn, err := ln.Accept()
236 if err != nil {
237 serverErrChan <- err
238 return
239 }
240 defer conn.Close()
241
242 err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
243 if err != nil {
244 serverErrChan <- err
245 return
246 }
247
248 err = script.Run(pgproto3.NewBackend(conn, conn))
249 if err != nil {
250 serverErrChan <- err
251 return
252 }
253 }()
254
255 host, port, _ := strings.Cut(ln.Addr().String(), ":")
256 connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
257 tooLate := time.Now().Add(time.Millisecond * 500)
258
259 err = tt.connect(connStr)
260 require.True(t, pgconn.Timeout(err), err)
261 require.True(t, time.Now().Before(tooLate))
262 })
263 }
264 }
265
266 func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
267 t.Parallel()
268 tests := []struct {
269 name string
270 connect func(connStr string) error
271 }{
272 {
273 name: "via context that times out",
274 connect: func(connStr string) error {
275 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
276 defer cancel()
277 _, err := pgconn.Connect(ctx, connStr)
278 return err
279 },
280 },
281 {
282 name: "via config ConnectTimeout",
283 connect: func(connStr string) error {
284 conf, err := pgconn.ParseConfig(connStr)
285 require.NoError(t, err)
286 conf.ConnectTimeout = time.Millisecond * 10
287 _, err = pgconn.ConnectConfig(context.Background(), conf)
288 return err
289 },
290 },
291 }
292 for _, tt := range tests {
293 tt := tt
294 t.Run(tt.name, func(t *testing.T) {
295 t.Parallel()
296 ln, err := net.Listen("tcp", "127.0.0.1:")
297 require.NoError(t, err)
298 defer ln.Close()
299
300 serverErrChan := make(chan error, 1)
301 go func() {
302 conn, err := ln.Accept()
303 if err != nil {
304 serverErrChan <- err
305 return
306 }
307 defer conn.Close()
308
309 var buf []byte
310 _, err = conn.Read(buf)
311 if err != nil {
312 serverErrChan <- err
313 return
314 }
315
316
317 time.Sleep(time.Minute)
318 }()
319
320 host, port, _ := strings.Cut(ln.Addr().String(), ":")
321 connStr := fmt.Sprintf("host=%s port=%s", host, port)
322
323 errChan := make(chan error)
324 go func() {
325 err := tt.connect(connStr)
326 errChan <- err
327 }()
328
329 select {
330 case err = <-errChan:
331 require.True(t, pgconn.Timeout(err), err)
332 case err = <-serverErrChan:
333 t.Fatalf("server failed with error: %s", err)
334 case <-time.After(time.Millisecond * 500):
335 t.Fatal("exceeded connection timeout without erroring out")
336 }
337 })
338 }
339 }
340
341 func TestConnectInvalidUser(t *testing.T) {
342 t.Parallel()
343
344 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
345 defer cancel()
346
347 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
348 if connString == "" {
349 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
350 }
351
352 config, err := pgconn.ParseConfig(connString)
353 require.NoError(t, err)
354
355 config.User = "pgxinvalidusertest"
356
357 _, err = pgconn.ConnectConfig(ctx, config)
358 require.Error(t, err)
359 pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
360 if !ok {
361 t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
362 }
363 if pgErr.Code != "28000" && pgErr.Code != "28P01" {
364 t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
365 }
366 }
367
368 func TestConnectWithConnectionRefused(t *testing.T) {
369 t.Parallel()
370
371 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
372 defer cancel()
373
374
375 conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1")
376 if err == nil {
377 conn.Close(ctx)
378 t.Fatal("Expected error establishing connection to bad port")
379 }
380 }
381
382 func TestConnectCustomDialer(t *testing.T) {
383 t.Parallel()
384
385 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
386 defer cancel()
387
388 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
389 require.NoError(t, err)
390
391 dialed := false
392 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
393 dialed = true
394 return net.Dial(network, address)
395 }
396
397 conn, err := pgconn.ConnectConfig(ctx, config)
398 require.NoError(t, err)
399 require.True(t, dialed)
400 closeConn(t, conn)
401 }
402
403 func TestConnectCustomLookup(t *testing.T) {
404 t.Parallel()
405
406 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
407 defer cancel()
408
409 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
410 if connString == "" {
411 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
412 }
413
414 config, err := pgconn.ParseConfig(connString)
415 require.NoError(t, err)
416
417 looked := false
418 config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
419 looked = true
420 return net.LookupHost(host)
421 }
422
423 conn, err := pgconn.ConnectConfig(ctx, config)
424 require.NoError(t, err)
425 require.True(t, looked)
426 closeConn(t, conn)
427 }
428
429 func TestConnectCustomLookupWithPort(t *testing.T) {
430 t.Parallel()
431
432 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
433 defer cancel()
434
435 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
436 if connString == "" {
437 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
438 }
439
440 config, err := pgconn.ParseConfig(connString)
441 require.NoError(t, err)
442
443 origPort := config.Port
444
445 config.Port = 0
446
447 looked := false
448 config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
449 looked = true
450 addrs, err := net.LookupHost(host)
451 if err != nil {
452 return nil, err
453 }
454 for i := range addrs {
455 addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
456 }
457 return addrs, nil
458 }
459
460 conn, err := pgconn.ConnectConfig(ctx, config)
461 require.NoError(t, err)
462 require.True(t, looked)
463 closeConn(t, conn)
464 }
465
466 func TestConnectWithRuntimeParams(t *testing.T) {
467 t.Parallel()
468
469 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
470 defer cancel()
471
472 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
473 require.NoError(t, err)
474
475 config.RuntimeParams = map[string]string{
476 "application_name": "pgxtest",
477 "search_path": "myschema",
478 }
479
480 conn, err := pgconn.ConnectConfig(ctx, config)
481 require.NoError(t, err)
482 defer closeConn(t, conn)
483
484 result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read()
485 require.Nil(t, result.Err)
486 assert.Equal(t, 1, len(result.Rows))
487 assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
488
489 result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read()
490 require.Nil(t, result.Err)
491 assert.Equal(t, 1, len(result.Rows))
492 assert.Equal(t, "myschema", string(result.Rows[0][0]))
493 }
494
495 func TestConnectWithFallback(t *testing.T) {
496 t.Parallel()
497
498 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
499 defer cancel()
500
501 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
502 require.NoError(t, err)
503
504
505 config.Fallbacks = append([]*pgconn.FallbackConfig{
506 {
507 Host: config.Host,
508 Port: config.Port,
509 TLSConfig: config.TLSConfig,
510 },
511 }, config.Fallbacks...)
512
513
514 config.Host = "localhost"
515 config.Port = 1
516
517
518 config.Fallbacks = append([]*pgconn.FallbackConfig{
519 {
520 Host: "localhost",
521 Port: 1,
522 TLSConfig: config.TLSConfig,
523 },
524 }, config.Fallbacks...)
525
526 conn, err := pgconn.ConnectConfig(ctx, config)
527 require.NoError(t, err)
528 closeConn(t, conn)
529 }
530
531 func TestConnectWithValidateConnect(t *testing.T) {
532 t.Parallel()
533
534 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
535 defer cancel()
536
537 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
538 require.NoError(t, err)
539
540 dialCount := 0
541 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
542 dialCount++
543 return net.Dial(network, address)
544 }
545
546 acceptConnCount := 0
547 config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
548 acceptConnCount++
549 if acceptConnCount < 2 {
550 return errors.New("reject first conn")
551 }
552 return nil
553 }
554
555
556 config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
557 Host: config.Host,
558 Port: config.Port,
559 TLSConfig: config.TLSConfig,
560 })
561
562
563 config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
564
565 conn, err := pgconn.ConnectConfig(ctx, config)
566 require.NoError(t, err)
567 closeConn(t, conn)
568
569 assert.True(t, dialCount > 1)
570 assert.True(t, acceptConnCount > 1)
571 }
572
573 func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
574 t.Parallel()
575
576 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
577 defer cancel()
578
579 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
580 require.NoError(t, err)
581
582 config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
583 config.RuntimeParams["default_transaction_read_only"] = "on"
584
585 conn, err := pgconn.ConnectConfig(ctx, config)
586 if !assert.NotNil(t, err) {
587 conn.Close(ctx)
588 }
589 }
590
591 func TestConnectWithAfterConnect(t *testing.T) {
592 t.Parallel()
593
594 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
595 defer cancel()
596
597 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
598 require.NoError(t, err)
599
600 config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
601 _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
602 return err
603 }
604
605 conn, err := pgconn.ConnectConfig(ctx, config)
606 require.NoError(t, err)
607
608 results, err := conn.Exec(ctx, "show search_path;").ReadAll()
609 require.NoError(t, err)
610 defer closeConn(t, conn)
611
612 assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
613 }
614
615 func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
616 t.Parallel()
617
618 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
619 defer cancel()
620
621 config := &pgconn.Config{}
622
623 require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(ctx, config) })
624 }
625
626 func TestConnPrepareSyntaxError(t *testing.T) {
627 t.Parallel()
628
629 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
630 defer cancel()
631
632 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
633 require.NoError(t, err)
634 defer closeConn(t, pgConn)
635
636 psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil)
637 require.Nil(t, psd)
638 require.NotNil(t, err)
639
640 ensureConnValid(t, pgConn)
641 }
642
643 func TestConnPrepareContextPrecanceled(t *testing.T) {
644 t.Parallel()
645
646 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
647 defer cancel()
648
649 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
650 require.NoError(t, err)
651 defer closeConn(t, pgConn)
652
653 cancel()
654
655 psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
656 assert.Nil(t, psd)
657 assert.Error(t, err)
658 assert.True(t, errors.Is(err, context.Canceled))
659 assert.True(t, pgconn.SafeToRetry(err))
660
661 ensureConnValid(t, pgConn)
662 }
663
664 func TestConnDeallocate(t *testing.T) {
665 t.Parallel()
666
667 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
668 defer cancel()
669
670 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
671 require.NoError(t, err)
672 defer closeConn(t, pgConn)
673
674 _, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
675 require.NoError(t, err)
676
677 _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
678 require.NoError(t, err)
679
680 err = pgConn.Deallocate(ctx, "ps1")
681 require.NoError(t, err)
682
683 _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
684 require.Error(t, err)
685 var pgErr *pgconn.PgError
686 require.ErrorAs(t, err, &pgErr)
687 require.Equal(t, "26000", pgErr.Code)
688
689 ensureConnValid(t, pgConn)
690 }
691
692 func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) {
693 t.Parallel()
694
695 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
696 defer cancel()
697
698 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
699 require.NoError(t, err)
700 defer closeConn(t, pgConn)
701
702 err = pgConn.Exec(ctx, "begin").Close()
703 require.NoError(t, err)
704
705 _, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
706 require.NoError(t, err)
707
708 _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
709 require.NoError(t, err)
710
711 err = pgConn.Exec(ctx, "select 1/0").Close()
712 require.Error(t, err)
713 var pgErr *pgconn.PgError
714 require.ErrorAs(t, err, &pgErr)
715 require.Equal(t, "22012", pgErr.Code)
716
717 err = pgConn.Deallocate(ctx, "ps1")
718 require.NoError(t, err)
719
720 err = pgConn.Exec(ctx, "rollback").Close()
721 require.NoError(t, err)
722
723 _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
724 require.Error(t, err)
725 require.ErrorAs(t, err, &pgErr)
726 require.Equal(t, "26000", pgErr.Code)
727
728 ensureConnValid(t, pgConn)
729 }
730
731 func TestConnDeallocateNonExistantStatementSucceeds(t *testing.T) {
732 t.Parallel()
733
734 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
735 defer cancel()
736
737 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
738 require.NoError(t, err)
739 defer closeConn(t, pgConn)
740
741 err = pgConn.Deallocate(ctx, "ps1")
742 require.NoError(t, err)
743
744 ensureConnValid(t, pgConn)
745 }
746
747 func TestConnExec(t *testing.T) {
748 t.Parallel()
749
750 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
751 defer cancel()
752
753 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
754 require.NoError(t, err)
755 defer closeConn(t, pgConn)
756
757 results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
758 assert.NoError(t, err)
759
760 assert.Len(t, results, 1)
761 assert.Nil(t, results[0].Err)
762 assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
763 assert.Len(t, results[0].Rows, 1)
764 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
765
766 ensureConnValid(t, pgConn)
767 }
768
769 func TestConnExecEmpty(t *testing.T) {
770 t.Parallel()
771
772 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
773 defer cancel()
774
775 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
776 require.NoError(t, err)
777 defer closeConn(t, pgConn)
778
779 multiResult := pgConn.Exec(ctx, ";")
780
781 resultCount := 0
782 for multiResult.NextResult() {
783 resultCount++
784 multiResult.ResultReader().Close()
785 }
786 assert.Equal(t, 0, resultCount)
787 err = multiResult.Close()
788 assert.NoError(t, err)
789
790 ensureConnValid(t, pgConn)
791 }
792
793 func TestConnExecMultipleQueries(t *testing.T) {
794 t.Parallel()
795
796 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
797 defer cancel()
798
799 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
800 require.NoError(t, err)
801 defer closeConn(t, pgConn)
802
803 results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll()
804 assert.NoError(t, err)
805
806 assert.Len(t, results, 2)
807
808 assert.Nil(t, results[0].Err)
809 assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
810 assert.Len(t, results[0].Rows, 1)
811 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
812
813 assert.Nil(t, results[1].Err)
814 assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
815 assert.Len(t, results[1].Rows, 1)
816 assert.Equal(t, "1", string(results[1].Rows[0][0]))
817
818 ensureConnValid(t, pgConn)
819 }
820
821 func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
822 t.Parallel()
823
824 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
825 defer cancel()
826
827 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
828 require.NoError(t, err)
829 defer closeConn(t, pgConn)
830
831 mrr := pgConn.Exec(ctx, "select 'Hello, world' as msg; select 1 as num")
832
833 require.True(t, mrr.NextResult())
834 require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
835 assert.Equal(t, "msg", mrr.ResultReader().FieldDescriptions()[0].Name)
836 _, err = mrr.ResultReader().Close()
837 require.NoError(t, err)
838
839 require.True(t, mrr.NextResult())
840 require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
841 assert.Equal(t, "num", mrr.ResultReader().FieldDescriptions()[0].Name)
842 _, err = mrr.ResultReader().Close()
843 require.NoError(t, err)
844
845 require.False(t, mrr.NextResult())
846
847 require.NoError(t, mrr.Close())
848
849 ensureConnValid(t, pgConn)
850 }
851
852 func TestConnExecMultipleQueriesError(t *testing.T) {
853 t.Parallel()
854
855 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
856 defer cancel()
857
858 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
859 require.NoError(t, err)
860 defer closeConn(t, pgConn)
861
862 results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll()
863 require.NotNil(t, err)
864 if pgErr, ok := err.(*pgconn.PgError); ok {
865 assert.Equal(t, "22012", pgErr.Code)
866 } else {
867 t.Errorf("unexpected error: %v", err)
868 }
869
870 if pgConn.ParameterStatus("crdb_version") != "" {
871
872 require.Len(t, results, 2)
873 assert.Len(t, results[0].Rows, 1)
874 assert.Equal(t, "1", string(results[0].Rows[0][0]))
875 assert.Len(t, results[1].Rows, 0)
876 } else {
877
878 require.Len(t, results, 1)
879 assert.Len(t, results[0].Rows, 1)
880 assert.Equal(t, "1", string(results[0].Rows[0][0]))
881 }
882
883 ensureConnValid(t, pgConn)
884 }
885
886 func TestConnExecDeferredError(t *testing.T) {
887 t.Parallel()
888
889 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
890 defer cancel()
891
892 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
893 require.NoError(t, err)
894 defer closeConn(t, pgConn)
895
896 if pgConn.ParameterStatus("crdb_version") != "" {
897 t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
898 }
899
900 setupSQL := `create temporary table t (
901 id text primary key,
902 n int not null,
903 unique (n) deferrable initially deferred
904 );
905
906 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
907
908 _, err = pgConn.Exec(ctx, setupSQL).ReadAll()
909 assert.NoError(t, err)
910
911 _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll()
912 require.NotNil(t, err)
913
914 var pgErr *pgconn.PgError
915 require.True(t, errors.As(err, &pgErr))
916 require.Equal(t, "23505", pgErr.Code)
917
918 ensureConnValid(t, pgConn)
919 }
920
921 func TestConnExecContextCanceled(t *testing.T) {
922 t.Parallel()
923
924 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
925 defer cancel()
926
927 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
928 require.NoError(t, err)
929 defer closeConn(t, pgConn)
930 cancel()
931
932 ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
933 defer cancel()
934 multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")
935
936 for multiResult.NextResult() {
937 }
938 err = multiResult.Close()
939 assert.True(t, pgconn.Timeout(err))
940 assert.ErrorIs(t, err, context.DeadlineExceeded)
941 assert.True(t, pgConn.IsClosed())
942 select {
943 case <-pgConn.CleanupDone():
944 case <-time.After(5 * time.Second):
945 t.Fatal("Connection cleanup exceeded maximum time")
946 }
947 }
948
949 func TestConnExecContextPrecanceled(t *testing.T) {
950 t.Parallel()
951
952 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
953 defer cancel()
954
955 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
956 require.NoError(t, err)
957 defer closeConn(t, pgConn)
958
959 cancel()
960 _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
961 assert.Error(t, err)
962 assert.True(t, errors.Is(err, context.Canceled))
963 assert.True(t, pgconn.SafeToRetry(err))
964
965 ensureConnValid(t, pgConn)
966 }
967
968 func TestConnExecParams(t *testing.T) {
969 t.Parallel()
970
971 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
972 defer cancel()
973
974 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
975 require.NoError(t, err)
976 defer closeConn(t, pgConn)
977
978 result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
979 require.Len(t, result.FieldDescriptions(), 1)
980 assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
981
982 rowCount := 0
983 for result.NextRow() {
984 rowCount += 1
985 assert.Equal(t, "Hello, world", string(result.Values()[0]))
986 }
987 assert.Equal(t, 1, rowCount)
988 commandTag, err := result.Close()
989 assert.Equal(t, "SELECT 1", commandTag.String())
990 assert.NoError(t, err)
991
992 ensureConnValid(t, pgConn)
993 }
994
995 func TestConnExecParamsDeferredError(t *testing.T) {
996 t.Parallel()
997
998 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
999 defer cancel()
1000
1001 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1002 require.NoError(t, err)
1003 defer closeConn(t, pgConn)
1004
1005 if pgConn.ParameterStatus("crdb_version") != "" {
1006 t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
1007 }
1008
1009 setupSQL := `create temporary table t (
1010 id text primary key,
1011 n int not null,
1012 unique (n) deferrable initially deferred
1013 );
1014
1015 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
1016
1017 _, err = pgConn.Exec(ctx, setupSQL).ReadAll()
1018 assert.NoError(t, err)
1019
1020 result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
1021 require.NotNil(t, result.Err)
1022 var pgErr *pgconn.PgError
1023 require.True(t, errors.As(result.Err, &pgErr))
1024 require.Equal(t, "23505", pgErr.Code)
1025
1026 ensureConnValid(t, pgConn)
1027 }
1028
1029 func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
1030 t.Parallel()
1031
1032 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1033 defer cancel()
1034
1035 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1036 require.NoError(t, err)
1037 defer closeConn(t, pgConn)
1038
1039 paramCount := math.MaxUint16
1040 params := make([]string, 0, paramCount)
1041 args := make([][]byte, 0, paramCount)
1042 for i := 0; i < paramCount; i++ {
1043 params = append(params, fmt.Sprintf("($%d::text)", i+1))
1044 args = append(args, []byte(strconv.Itoa(i)))
1045 }
1046 sql := "values" + strings.Join(params, ", ")
1047
1048 result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
1049 require.NoError(t, result.Err)
1050 require.Len(t, result.Rows, paramCount)
1051
1052 ensureConnValid(t, pgConn)
1053 }
1054
1055 func TestConnExecParamsTooManyParams(t *testing.T) {
1056 t.Parallel()
1057
1058 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1059 defer cancel()
1060
1061 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1062 require.NoError(t, err)
1063 defer closeConn(t, pgConn)
1064
1065 paramCount := math.MaxUint16 + 1
1066 params := make([]string, 0, paramCount)
1067 args := make([][]byte, 0, paramCount)
1068 for i := 0; i < paramCount; i++ {
1069 params = append(params, fmt.Sprintf("($%d::text)", i+1))
1070 args = append(args, []byte(strconv.Itoa(i)))
1071 }
1072 sql := "values" + strings.Join(params, ", ")
1073
1074 result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
1075 require.Error(t, result.Err)
1076 require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
1077
1078 ensureConnValid(t, pgConn)
1079 }
1080
1081 func TestConnExecParamsCanceled(t *testing.T) {
1082 t.Parallel()
1083
1084 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1085 defer cancel()
1086
1087 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1088 require.NoError(t, err)
1089 defer closeConn(t, pgConn)
1090
1091 ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
1092 defer cancel()
1093 result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
1094 rowCount := 0
1095 for result.NextRow() {
1096 rowCount += 1
1097 }
1098 assert.Equal(t, 0, rowCount)
1099 commandTag, err := result.Close()
1100 assert.Equal(t, pgconn.CommandTag{}, commandTag)
1101 assert.True(t, pgconn.Timeout(err))
1102 assert.ErrorIs(t, err, context.DeadlineExceeded)
1103
1104 assert.True(t, pgConn.IsClosed())
1105 select {
1106 case <-pgConn.CleanupDone():
1107 case <-time.After(5 * time.Second):
1108 t.Fatal("Connection cleanup exceeded maximum time")
1109 }
1110 }
1111
1112 func TestConnExecParamsPrecanceled(t *testing.T) {
1113 t.Parallel()
1114
1115 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1116 defer cancel()
1117
1118 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1119 require.NoError(t, err)
1120 defer closeConn(t, pgConn)
1121
1122 cancel()
1123 result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
1124 require.Error(t, result.Err)
1125 assert.True(t, errors.Is(result.Err, context.Canceled))
1126 assert.True(t, pgconn.SafeToRetry(result.Err))
1127
1128 ensureConnValid(t, pgConn)
1129 }
1130
1131 func TestConnExecParamsEmptySQL(t *testing.T) {
1132 t.Parallel()
1133
1134 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1135 defer cancel()
1136
1137 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1138 require.NoError(t, err)
1139 defer closeConn(t, pgConn)
1140
1141 result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
1142 assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
1143 assert.Len(t, result.Rows, 0)
1144 assert.NoError(t, result.Err)
1145
1146 ensureConnValid(t, pgConn)
1147 }
1148
1149
1150 func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
1151 t.Parallel()
1152
1153 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1154 defer cancel()
1155
1156 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1157 require.NoError(t, err)
1158 defer closeConn(t, pgConn)
1159
1160 result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
1161 require.Len(t, result.FieldDescriptions(), 1)
1162 assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
1163
1164 rowCount := 0
1165 for result.NextRow() {
1166 rowCount += 1
1167 assert.Equal(t, "Hello, world", string(result.Values()[0]))
1168 assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0]))
1169 }
1170 assert.Equal(t, 1, rowCount)
1171 commandTag, err := result.Close()
1172 assert.Equal(t, "SELECT 1", commandTag.String())
1173 assert.NoError(t, err)
1174
1175 ensureConnValid(t, pgConn)
1176 }
1177
1178 func TestConnExecPrepared(t *testing.T) {
1179 t.Parallel()
1180
1181 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1182 defer cancel()
1183
1184 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1185 require.NoError(t, err)
1186 defer closeConn(t, pgConn)
1187
1188 psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil)
1189 require.NoError(t, err)
1190 require.NotNil(t, psd)
1191 assert.Len(t, psd.ParamOIDs, 1)
1192 assert.Len(t, psd.Fields, 1)
1193
1194 result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
1195 require.Len(t, result.FieldDescriptions(), 1)
1196 assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
1197
1198 rowCount := 0
1199 for result.NextRow() {
1200 rowCount += 1
1201 assert.Equal(t, "Hello, world", string(result.Values()[0]))
1202 }
1203 assert.Equal(t, 1, rowCount)
1204 commandTag, err := result.Close()
1205 assert.Equal(t, "SELECT 1", commandTag.String())
1206 assert.NoError(t, err)
1207
1208 ensureConnValid(t, pgConn)
1209 }
1210
1211 func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
1212 t.Parallel()
1213
1214 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1215 defer cancel()
1216
1217 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1218 require.NoError(t, err)
1219 defer closeConn(t, pgConn)
1220
1221 paramCount := math.MaxUint16
1222 params := make([]string, 0, paramCount)
1223 args := make([][]byte, 0, paramCount)
1224 for i := 0; i < paramCount; i++ {
1225 params = append(params, fmt.Sprintf("($%d::text)", i+1))
1226 args = append(args, []byte(strconv.Itoa(i)))
1227 }
1228 sql := "values" + strings.Join(params, ", ")
1229
1230 psd, err := pgConn.Prepare(ctx, "ps1", sql, nil)
1231 require.NoError(t, err)
1232 require.NotNil(t, psd)
1233 assert.Len(t, psd.ParamOIDs, paramCount)
1234 assert.Len(t, psd.Fields, 1)
1235
1236 result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read()
1237 require.NoError(t, result.Err)
1238 require.Len(t, result.Rows, paramCount)
1239
1240 ensureConnValid(t, pgConn)
1241 }
1242
1243 func TestConnExecPreparedTooManyParams(t *testing.T) {
1244 t.Parallel()
1245
1246 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1247 defer cancel()
1248
1249 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1250 require.NoError(t, err)
1251 defer closeConn(t, pgConn)
1252
1253 paramCount := math.MaxUint16 + 1
1254 params := make([]string, 0, paramCount)
1255 args := make([][]byte, 0, paramCount)
1256 for i := 0; i < paramCount; i++ {
1257 params = append(params, fmt.Sprintf("($%d::text)", i+1))
1258 args = append(args, []byte(strconv.Itoa(i)))
1259 }
1260 sql := "values" + strings.Join(params, ", ")
1261
1262 psd, err := pgConn.Prepare(ctx, "ps1", sql, nil)
1263 if pgConn.ParameterStatus("crdb_version") != "" {
1264
1265 require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
1266 } else {
1267
1268 require.NoError(t, err)
1269 require.NotNil(t, psd)
1270 assert.Len(t, psd.ParamOIDs, paramCount)
1271 assert.Len(t, psd.Fields, 1)
1272
1273 result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read()
1274 require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
1275 }
1276
1277 ensureConnValid(t, pgConn)
1278 }
1279
1280 func TestConnExecPreparedCanceled(t *testing.T) {
1281 t.Parallel()
1282
1283 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1284 defer cancel()
1285
1286 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1287 require.NoError(t, err)
1288 defer closeConn(t, pgConn)
1289
1290 _, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil)
1291 require.NoError(t, err)
1292
1293 ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
1294 defer cancel()
1295 result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
1296 rowCount := 0
1297 for result.NextRow() {
1298 rowCount += 1
1299 }
1300 assert.Equal(t, 0, rowCount)
1301 commandTag, err := result.Close()
1302 assert.Equal(t, pgconn.CommandTag{}, commandTag)
1303 assert.True(t, pgconn.Timeout(err))
1304 assert.True(t, pgConn.IsClosed())
1305 select {
1306 case <-pgConn.CleanupDone():
1307 case <-time.After(5 * time.Second):
1308 t.Fatal("Connection cleanup exceeded maximum time")
1309 }
1310 }
1311
1312 func TestConnExecPreparedPrecanceled(t *testing.T) {
1313 t.Parallel()
1314
1315 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1316 defer cancel()
1317
1318 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1319 require.NoError(t, err)
1320 defer closeConn(t, pgConn)
1321
1322 _, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil)
1323 require.NoError(t, err)
1324
1325 cancel()
1326 result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
1327 require.Error(t, result.Err)
1328 assert.True(t, errors.Is(result.Err, context.Canceled))
1329 assert.True(t, pgconn.SafeToRetry(result.Err))
1330
1331 ensureConnValid(t, pgConn)
1332 }
1333
1334 func TestConnExecPreparedEmptySQL(t *testing.T) {
1335 t.Parallel()
1336
1337 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1338 defer cancel()
1339
1340 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1341 require.NoError(t, err)
1342 defer closeConn(t, pgConn)
1343
1344 _, err = pgConn.Prepare(ctx, "ps1", "", nil)
1345 require.NoError(t, err)
1346
1347 result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
1348 assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
1349 assert.Len(t, result.Rows, 0)
1350 assert.NoError(t, result.Err)
1351
1352 ensureConnValid(t, pgConn)
1353 }
1354
1355 func TestConnExecBatch(t *testing.T) {
1356 t.Parallel()
1357
1358 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1359 defer cancel()
1360
1361 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1362 require.NoError(t, err)
1363 defer closeConn(t, pgConn)
1364
1365 _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil)
1366 require.NoError(t, err)
1367
1368 batch := &pgconn.Batch{}
1369
1370 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
1371 batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
1372 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
1373 results, err := pgConn.ExecBatch(ctx, batch).ReadAll()
1374 require.NoError(t, err)
1375 require.Len(t, results, 3)
1376
1377 require.Len(t, results[0].Rows, 1)
1378 require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
1379 assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
1380
1381 require.Len(t, results[1].Rows, 1)
1382 require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
1383 assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
1384
1385 require.Len(t, results[2].Rows, 1)
1386 require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
1387 assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
1388 }
1389
1390 func TestConnExecBatchDeferredError(t *testing.T) {
1391 t.Parallel()
1392
1393 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1394 defer cancel()
1395
1396 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1397 require.NoError(t, err)
1398 defer closeConn(t, pgConn)
1399
1400 if pgConn.ParameterStatus("crdb_version") != "" {
1401 t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
1402 }
1403
1404 setupSQL := `create temporary table t (
1405 id text primary key,
1406 n int not null,
1407 unique (n) deferrable initially deferred
1408 );
1409
1410 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
1411
1412 _, err = pgConn.Exec(ctx, setupSQL).ReadAll()
1413 require.NoError(t, err)
1414
1415 batch := &pgconn.Batch{}
1416
1417 batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
1418 _, err = pgConn.ExecBatch(ctx, batch).ReadAll()
1419 require.NotNil(t, err)
1420 var pgErr *pgconn.PgError
1421 require.True(t, errors.As(err, &pgErr))
1422 require.Equal(t, "23505", pgErr.Code)
1423
1424 ensureConnValid(t, pgConn)
1425 }
1426
1427 func TestConnExecBatchPrecanceled(t *testing.T) {
1428 t.Parallel()
1429
1430 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1431 defer cancel()
1432
1433 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1434 require.NoError(t, err)
1435 defer closeConn(t, pgConn)
1436
1437 _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil)
1438 require.NoError(t, err)
1439
1440 batch := &pgconn.Batch{}
1441
1442 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
1443 batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
1444 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
1445
1446 cancel()
1447 _, err = pgConn.ExecBatch(ctx, batch).ReadAll()
1448 require.Error(t, err)
1449 assert.True(t, errors.Is(err, context.Canceled))
1450 assert.True(t, pgconn.SafeToRetry(err))
1451
1452 ensureConnValid(t, pgConn)
1453 }
1454
1455
1456
1457
1458 func TestConnExecBatchHuge(t *testing.T) {
1459 if testing.Short() {
1460 t.Skip("skipping test in short mode.")
1461 }
1462
1463 t.Parallel()
1464
1465 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1466 defer cancel()
1467
1468 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1469 require.NoError(t, err)
1470 defer closeConn(t, pgConn)
1471
1472 batch := &pgconn.Batch{}
1473
1474 queryCount := 100000
1475 args := make([]string, queryCount)
1476
1477 for i := range args {
1478 args[i] = strconv.Itoa(i)
1479 batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil)
1480 }
1481
1482 results, err := pgConn.ExecBatch(ctx, batch).ReadAll()
1483 require.NoError(t, err)
1484 require.Len(t, results, queryCount)
1485
1486 for i := range args {
1487 require.Len(t, results[i].Rows, 1)
1488 require.Equal(t, args[i], string(results[i].Rows[0][0]))
1489 assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
1490 }
1491 }
1492
1493 func TestConnExecBatchImplicitTransaction(t *testing.T) {
1494 t.Parallel()
1495
1496 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1497 defer cancel()
1498
1499 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1500 require.NoError(t, err)
1501 defer closeConn(t, pgConn)
1502
1503 if pgConn.ParameterStatus("crdb_version") != "" {
1504 t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)")
1505 }
1506
1507 _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll()
1508 require.NoError(t, err)
1509
1510 batch := &pgconn.Batch{}
1511
1512 batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil)
1513 batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil)
1514 batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil)
1515 batch.ExecParams("select 1/0", nil, nil, nil, nil)
1516 _, err = pgConn.ExecBatch(ctx, batch).ReadAll()
1517 require.Error(t, err)
1518
1519 result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read()
1520 require.Equal(t, "0", string(result.Rows[0][0]))
1521 }
1522
1523 func TestConnLocking(t *testing.T) {
1524 t.Parallel()
1525
1526 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1527 defer cancel()
1528
1529 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1530 require.NoError(t, err)
1531 defer closeConn(t, pgConn)
1532
1533 mrr := pgConn.Exec(ctx, "select 'Hello, world'")
1534 _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
1535 assert.Error(t, err)
1536 assert.Equal(t, "conn busy", err.Error())
1537 assert.True(t, pgconn.SafeToRetry(err))
1538
1539 results, err := mrr.ReadAll()
1540 assert.NoError(t, err)
1541 assert.Len(t, results, 1)
1542 assert.Nil(t, results[0].Err)
1543 assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
1544 assert.Len(t, results[0].Rows, 1)
1545 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
1546
1547 ensureConnValid(t, pgConn)
1548 }
1549
1550 func TestConnOnNotice(t *testing.T) {
1551 t.Parallel()
1552
1553 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1554 defer cancel()
1555
1556 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1557 require.NoError(t, err)
1558
1559 var msg string
1560 config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) {
1561 msg = notice.Message
1562 }
1563 config.RuntimeParams["client_min_messages"] = "notice"
1564
1565 pgConn, err := pgconn.ConnectConfig(ctx, config)
1566 require.NoError(t, err)
1567 defer closeConn(t, pgConn)
1568
1569 if pgConn.ParameterStatus("crdb_version") != "" {
1570 t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)")
1571 }
1572
1573 multiResult := pgConn.Exec(ctx, `do $$
1574 begin
1575 raise notice 'hello, world';
1576 end$$;`)
1577 err = multiResult.Close()
1578 require.NoError(t, err)
1579 assert.Equal(t, "hello, world", msg)
1580
1581 ensureConnValid(t, pgConn)
1582 }
1583
1584 func TestConnOnNotification(t *testing.T) {
1585 t.Parallel()
1586
1587 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1588 defer cancel()
1589
1590 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1591 require.NoError(t, err)
1592
1593 var msg string
1594 config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
1595 msg = n.Payload
1596 }
1597
1598 pgConn, err := pgconn.ConnectConfig(ctx, config)
1599 require.NoError(t, err)
1600 defer closeConn(t, pgConn)
1601
1602 if pgConn.ParameterStatus("crdb_version") != "" {
1603 t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
1604 }
1605
1606 _, err = pgConn.Exec(ctx, "listen foo").ReadAll()
1607 require.NoError(t, err)
1608
1609 notifier, err := pgconn.ConnectConfig(ctx, config)
1610 require.NoError(t, err)
1611 defer closeConn(t, notifier)
1612 _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll()
1613 require.NoError(t, err)
1614
1615 _, err = pgConn.Exec(ctx, "select 1").ReadAll()
1616 require.NoError(t, err)
1617
1618 assert.Equal(t, "bar", msg)
1619
1620 ensureConnValid(t, pgConn)
1621 }
1622
1623 func TestConnWaitForNotification(t *testing.T) {
1624 t.Parallel()
1625
1626 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1627 defer cancel()
1628
1629 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1630 require.NoError(t, err)
1631
1632 var msg string
1633 config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
1634 msg = n.Payload
1635 }
1636
1637 pgConn, err := pgconn.ConnectConfig(ctx, config)
1638 require.NoError(t, err)
1639 defer closeConn(t, pgConn)
1640
1641 if pgConn.ParameterStatus("crdb_version") != "" {
1642 t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
1643 }
1644
1645 _, err = pgConn.Exec(ctx, "listen foo").ReadAll()
1646 require.NoError(t, err)
1647
1648 notifier, err := pgconn.ConnectConfig(ctx, config)
1649 require.NoError(t, err)
1650 defer closeConn(t, notifier)
1651 _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll()
1652 require.NoError(t, err)
1653
1654 err = pgConn.WaitForNotification(ctx)
1655 require.NoError(t, err)
1656
1657 assert.Equal(t, "bar", msg)
1658
1659 ensureConnValid(t, pgConn)
1660 }
1661
1662 func TestConnWaitForNotificationPrecanceled(t *testing.T) {
1663 t.Parallel()
1664
1665 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1666 defer cancel()
1667
1668 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1669 require.NoError(t, err)
1670
1671 pgConn, err := pgconn.ConnectConfig(ctx, config)
1672 require.NoError(t, err)
1673 defer closeConn(t, pgConn)
1674
1675 cancel()
1676 err = pgConn.WaitForNotification(ctx)
1677 require.ErrorIs(t, err, context.Canceled)
1678
1679 ensureConnValid(t, pgConn)
1680 }
1681
1682 func TestConnWaitForNotificationTimeout(t *testing.T) {
1683 t.Parallel()
1684
1685 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1686 defer cancel()
1687
1688 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
1689 require.NoError(t, err)
1690
1691 pgConn, err := pgconn.ConnectConfig(ctx, config)
1692 require.NoError(t, err)
1693 defer closeConn(t, pgConn)
1694
1695 ctx, cancel = context.WithTimeout(ctx, 5*time.Millisecond)
1696 err = pgConn.WaitForNotification(ctx)
1697 cancel()
1698 assert.True(t, pgconn.Timeout(err))
1699 assert.ErrorIs(t, err, context.DeadlineExceeded)
1700
1701 ensureConnValid(t, pgConn)
1702 }
1703
1704 func TestConnCopyToSmall(t *testing.T) {
1705 t.Parallel()
1706
1707 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1708 defer cancel()
1709
1710 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1711 require.NoError(t, err)
1712 defer closeConn(t, pgConn)
1713
1714 if pgConn.ParameterStatus("crdb_version") != "" {
1715 t.Skip("Server does support COPY TO")
1716 }
1717
1718 _, err = pgConn.Exec(ctx, `create temporary table foo(
1719 a int2,
1720 b int4,
1721 c int8,
1722 d varchar,
1723 e text,
1724 f date,
1725 g json
1726 )`).ReadAll()
1727 require.NoError(t, err)
1728
1729 _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
1730 require.NoError(t, err)
1731
1732 _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
1733 require.NoError(t, err)
1734
1735 inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
1736 "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
1737
1738 outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
1739
1740 res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout")
1741 require.NoError(t, err)
1742
1743 assert.Equal(t, int64(2), res.RowsAffected())
1744 assert.Equal(t, inputBytes, outputWriter.Bytes())
1745
1746 ensureConnValid(t, pgConn)
1747 }
1748
1749 func TestConnCopyToLarge(t *testing.T) {
1750 t.Parallel()
1751
1752 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1753 defer cancel()
1754
1755 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1756 require.NoError(t, err)
1757 defer closeConn(t, pgConn)
1758
1759 if pgConn.ParameterStatus("crdb_version") != "" {
1760 t.Skip("Server does support COPY TO")
1761 }
1762
1763 _, err = pgConn.Exec(ctx, `create temporary table foo(
1764 a int2,
1765 b int4,
1766 c int8,
1767 d varchar,
1768 e text,
1769 f date,
1770 g json,
1771 h bytea
1772 )`).ReadAll()
1773 require.NoError(t, err)
1774
1775 inputBytes := make([]byte, 0)
1776
1777 for i := 0; i < 1000; i++ {
1778 _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
1779 require.NoError(t, err)
1780 inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
1781 }
1782
1783 outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
1784
1785 res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout")
1786 require.NoError(t, err)
1787
1788 assert.Equal(t, int64(1000), res.RowsAffected())
1789 assert.Equal(t, inputBytes, outputWriter.Bytes())
1790
1791 ensureConnValid(t, pgConn)
1792 }
1793
1794 func TestConnCopyToQueryError(t *testing.T) {
1795 t.Parallel()
1796
1797 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1798 defer cancel()
1799
1800 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1801 require.NoError(t, err)
1802 defer closeConn(t, pgConn)
1803
1804 outputWriter := bytes.NewBuffer(make([]byte, 0))
1805
1806 res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout")
1807 require.Error(t, err)
1808 assert.IsType(t, &pgconn.PgError{}, err)
1809 assert.Equal(t, int64(0), res.RowsAffected())
1810
1811 ensureConnValid(t, pgConn)
1812 }
1813
1814 func TestConnCopyToCanceled(t *testing.T) {
1815 t.Parallel()
1816
1817 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1818 defer cancel()
1819
1820 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1821 require.NoError(t, err)
1822 defer closeConn(t, pgConn)
1823
1824 if pgConn.ParameterStatus("crdb_version") != "" {
1825 t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
1826 }
1827
1828 outputWriter := &bytes.Buffer{}
1829
1830 ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
1831 defer cancel()
1832 res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
1833 assert.Error(t, err)
1834 assert.Equal(t, pgconn.CommandTag{}, res)
1835
1836 assert.True(t, pgConn.IsClosed())
1837 select {
1838 case <-pgConn.CleanupDone():
1839 case <-time.After(5 * time.Second):
1840 t.Fatal("Connection cleanup exceeded maximum time")
1841 }
1842 }
1843
1844 func TestConnCopyToPrecanceled(t *testing.T) {
1845 t.Parallel()
1846
1847 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1848 defer cancel()
1849
1850 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1851 require.NoError(t, err)
1852 defer closeConn(t, pgConn)
1853
1854 outputWriter := &bytes.Buffer{}
1855
1856 cancel()
1857 res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
1858 require.Error(t, err)
1859 assert.True(t, errors.Is(err, context.Canceled))
1860 assert.True(t, pgconn.SafeToRetry(err))
1861 assert.Equal(t, pgconn.CommandTag{}, res)
1862
1863 ensureConnValid(t, pgConn)
1864 }
1865
1866 func TestConnCopyFrom(t *testing.T) {
1867 t.Parallel()
1868
1869 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1870 defer cancel()
1871
1872 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1873 require.NoError(t, err)
1874 defer closeConn(t, pgConn)
1875
1876 _, err = pgConn.Exec(ctx, `create temporary table foo(
1877 a int4,
1878 b varchar
1879 )`).ReadAll()
1880 require.NoError(t, err)
1881
1882 srcBuf := &bytes.Buffer{}
1883
1884 inputRows := [][][]byte{}
1885 for i := 0; i < 1000; i++ {
1886 a := strconv.Itoa(i)
1887 b := "foo " + a + " bar"
1888 inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
1889 _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
1890 require.NoError(t, err)
1891 }
1892
1893 copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
1894 if pgConn.ParameterStatus("crdb_version") != "" {
1895 copySql = "COPY foo FROM STDIN WITH CSV"
1896 }
1897 ct, err := pgConn.CopyFrom(ctx, srcBuf, copySql)
1898 require.NoError(t, err)
1899 assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
1900
1901 result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
1902 require.NoError(t, result.Err)
1903
1904 assert.Equal(t, inputRows, result.Rows)
1905
1906 ensureConnValid(t, pgConn)
1907 }
1908
1909 func TestConnCopyFromBinary(t *testing.T) {
1910 t.Parallel()
1911
1912 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1913 defer cancel()
1914
1915 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1916 require.NoError(t, err)
1917 defer closeConn(t, pgConn)
1918
1919 _, err = pgConn.Exec(ctx, `create temporary table foo(
1920 a int4,
1921 b varchar
1922 )`).ReadAll()
1923 require.NoError(t, err)
1924
1925 buf := []byte{}
1926 buf = append(buf, "PGCOPY\n\377\r\n\000"...)
1927 buf = pgio.AppendInt32(buf, 0)
1928 buf = pgio.AppendInt32(buf, 0)
1929
1930 inputRows := [][][]byte{}
1931 for i := 0; i < 1000; i++ {
1932
1933 buf = pgio.AppendInt16(buf, int16(2))
1934 a := i
1935
1936
1937 buf = pgio.AppendInt32(buf, 4)
1938 buf, err = pgtype.NewMap().Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf)
1939 require.NoError(t, err)
1940
1941 b := "foo " + strconv.Itoa(a) + " bar"
1942 lenB := int32(len([]byte(b)))
1943
1944 buf = pgio.AppendInt32(buf, lenB)
1945 buf, err = pgtype.NewMap().Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf)
1946 require.NoError(t, err)
1947
1948 inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)})
1949 }
1950
1951 srcBuf := &bytes.Buffer{}
1952 srcBuf.Write(buf)
1953 ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo (a, b) FROM STDIN BINARY;")
1954 require.NoError(t, err)
1955 assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
1956
1957 result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
1958 require.NoError(t, result.Err)
1959
1960 assert.Equal(t, inputRows, result.Rows)
1961
1962 ensureConnValid(t, pgConn)
1963 }
1964
1965 func TestConnCopyFromCanceled(t *testing.T) {
1966 t.Parallel()
1967
1968 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1969 defer cancel()
1970
1971 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
1972 require.NoError(t, err)
1973 defer closeConn(t, pgConn)
1974
1975 _, err = pgConn.Exec(ctx, `create temporary table foo(
1976 a int4,
1977 b varchar
1978 )`).ReadAll()
1979 require.NoError(t, err)
1980
1981 r, w := io.Pipe()
1982 go func() {
1983 for i := 0; i < 1000000; i++ {
1984 a := strconv.Itoa(i)
1985 b := "foo " + a + " bar"
1986 _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
1987 if err != nil {
1988 return
1989 }
1990 time.Sleep(time.Microsecond)
1991 }
1992 }()
1993
1994 ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
1995 copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
1996 if pgConn.ParameterStatus("crdb_version") != "" {
1997 copySql = "COPY foo FROM STDIN WITH CSV"
1998 }
1999 ct, err := pgConn.CopyFrom(ctx, r, copySql)
2000 cancel()
2001 assert.Equal(t, int64(0), ct.RowsAffected())
2002 assert.Error(t, err)
2003
2004 assert.True(t, pgConn.IsClosed())
2005 select {
2006 case <-pgConn.CleanupDone():
2007 case <-time.After(5 * time.Second):
2008 t.Fatal("Connection cleanup exceeded maximum time")
2009 }
2010 }
2011
2012 func TestConnCopyFromPrecanceled(t *testing.T) {
2013 t.Parallel()
2014
2015 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2016 defer cancel()
2017
2018 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2019 require.NoError(t, err)
2020 defer closeConn(t, pgConn)
2021
2022 _, err = pgConn.Exec(ctx, `create temporary table foo(
2023 a int4,
2024 b varchar
2025 )`).ReadAll()
2026 require.NoError(t, err)
2027
2028 r, w := io.Pipe()
2029 go func() {
2030 for i := 0; i < 1000000; i++ {
2031 a := strconv.Itoa(i)
2032 b := "foo " + a + " bar"
2033 _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
2034 if err != nil {
2035 return
2036 }
2037 time.Sleep(time.Microsecond)
2038 }
2039 }()
2040
2041 ctx, cancel = context.WithCancel(ctx)
2042 cancel()
2043 ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
2044 require.Error(t, err)
2045 assert.True(t, errors.Is(err, context.Canceled))
2046 assert.True(t, pgconn.SafeToRetry(err))
2047 assert.Equal(t, pgconn.CommandTag{}, ct)
2048
2049 ensureConnValid(t, pgConn)
2050 }
2051
2052 func TestConnCopyFromGzipReader(t *testing.T) {
2053 t.Parallel()
2054
2055 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2056 defer cancel()
2057
2058 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2059 require.NoError(t, err)
2060 defer closeConn(t, pgConn)
2061
2062 if pgConn.ParameterStatus("crdb_version") != "" {
2063 t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
2064 }
2065
2066 _, err = pgConn.Exec(ctx, `create temporary table foo(
2067 a int4,
2068 b varchar
2069 )`).ReadAll()
2070 require.NoError(t, err)
2071
2072 f, err := os.CreateTemp("", "*")
2073 require.NoError(t, err)
2074
2075 gw := gzip.NewWriter(f)
2076
2077 inputRows := [][][]byte{}
2078 for i := 0; i < 1000; i++ {
2079 a := strconv.Itoa(i)
2080 b := "foo " + a + " bar"
2081 inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
2082 _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
2083 require.NoError(t, err)
2084 }
2085
2086 err = gw.Close()
2087 require.NoError(t, err)
2088
2089 _, err = f.Seek(0, 0)
2090 require.NoError(t, err)
2091
2092 gr, err := gzip.NewReader(f)
2093 require.NoError(t, err)
2094
2095 copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
2096 if pgConn.ParameterStatus("crdb_version") != "" {
2097 copySql = "COPY foo FROM STDIN WITH CSV"
2098 }
2099 ct, err := pgConn.CopyFrom(ctx, gr, copySql)
2100 require.NoError(t, err)
2101 assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
2102
2103 err = gr.Close()
2104 require.NoError(t, err)
2105
2106 err = f.Close()
2107 require.NoError(t, err)
2108
2109 err = os.Remove(f.Name())
2110 require.NoError(t, err)
2111
2112 result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
2113 require.NoError(t, result.Err)
2114
2115 assert.Equal(t, inputRows, result.Rows)
2116
2117 ensureConnValid(t, pgConn)
2118 }
2119
2120 func TestConnCopyFromQuerySyntaxError(t *testing.T) {
2121 t.Parallel()
2122
2123 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2124 defer cancel()
2125
2126 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2127 require.NoError(t, err)
2128 defer closeConn(t, pgConn)
2129
2130 _, err = pgConn.Exec(ctx, `create temporary table foo(
2131 a int4,
2132 b varchar
2133 )`).ReadAll()
2134 require.NoError(t, err)
2135
2136 srcBuf := &bytes.Buffer{}
2137
2138
2139
2140 inputRows := [][][]byte{}
2141 for i := 0; i < 1000; i++ {
2142 a := strconv.Itoa(i)
2143 b := "foo " + a + " bar"
2144 inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
2145 _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
2146 require.NoError(t, err)
2147 }
2148
2149 res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)")
2150 require.Error(t, err)
2151 assert.IsType(t, &pgconn.PgError{}, err)
2152 assert.Equal(t, int64(0), res.RowsAffected())
2153
2154 ensureConnValid(t, pgConn)
2155 }
2156
2157 func TestConnCopyFromQueryNoTableError(t *testing.T) {
2158 t.Parallel()
2159
2160 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2161 defer cancel()
2162
2163 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2164 require.NoError(t, err)
2165 defer closeConn(t, pgConn)
2166
2167 srcBuf := &bytes.Buffer{}
2168
2169 res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout")
2170 require.Error(t, err)
2171 assert.IsType(t, &pgconn.PgError{}, err)
2172 assert.Equal(t, int64(0), res.RowsAffected())
2173
2174 ensureConnValid(t, pgConn)
2175 }
2176
2177
2178 func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
2179 t.Parallel()
2180
2181 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2182 defer cancel()
2183
2184 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2185 require.NoError(t, err)
2186 defer closeConn(t, pgConn)
2187
2188 if pgConn.ParameterStatus("crdb_version") != "" {
2189 t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)")
2190 }
2191
2192 _, err = pgConn.Exec(ctx, `create temporary table sentences(
2193 t text,
2194 ts tsvector
2195 )`).ReadAll()
2196 require.NoError(t, err)
2197
2198 _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
2199 begin
2200 new.ts := to_tsvector(new.t);
2201 return new;
2202 end
2203 $$ language plpgsql;`).ReadAll()
2204 require.NoError(t, err)
2205
2206 _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
2207 require.NoError(t, err)
2208
2209 longString := make([]byte, 10001)
2210 for i := range longString {
2211 longString[i] = 'x'
2212 }
2213
2214 buf := &bytes.Buffer{}
2215 for i := 0; i < 1000; i++ {
2216 buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
2217 }
2218
2219 _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
2220 require.NoError(t, err)
2221 }
2222
// delayedReader wraps r and pauses for one millisecond before every Read. The
// COPY tests use it to slow the data source down (presumably so the data writer
// is still running when the server responds).
type delayedReader struct {
	r io.Reader
}

// Read sleeps briefly and then forwards the call to the wrapped reader,
// returning its results unchanged.
func (dr delayedReader) Read(buf []byte) (n int, err error) {
	time.Sleep(time.Millisecond)
	n, err = dr.r.Read(buf)
	return n, err
}
2232
2233
2234 func TestConnCopyFromDataWriteAfterErrorAndReturn(t *testing.T) {
2235 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2236 defer cancel()
2237
2238 connString := os.Getenv("PGX_TEST_DATABASE")
2239 if connString == "" {
2240 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_DATABASE")
2241 }
2242
2243 config, err := pgconn.ParseConfig(connString)
2244 require.NoError(t, err)
2245
2246 pgConn, err := pgconn.ConnectConfig(ctx, config)
2247 require.NoError(t, err)
2248
2249 if pgConn.ParameterStatus("crdb_version") != "" {
2250 t.Skip("Server does not fully support COPY FROM")
2251 }
2252
2253 setupSQL := `create temporary table t (
2254 id text primary key,
2255 n int not null
2256 );`
2257
2258 _, err = pgConn.Exec(ctx, setupSQL).ReadAll()
2259 assert.NoError(t, err)
2260
2261 r1 := delayedReader{r: strings.NewReader(`id 0\n`)}
2262
2263 _, err = pgConn.CopyFrom(ctx, r1, "COPY nosuchtable FROM STDIN ")
2264 assert.Error(t, err)
2265
2266 r2 := delayedReader{r: strings.NewReader(`id 0\n`)}
2267 _, err = pgConn.CopyFrom(ctx, r2, "COPY t FROM STDIN")
2268 assert.NoError(t, err)
2269 }
2270
2271 func TestConnEscapeString(t *testing.T) {
2272 t.Parallel()
2273
2274 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2275 defer cancel()
2276
2277 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2278 require.NoError(t, err)
2279 defer closeConn(t, pgConn)
2280
2281 tests := []struct {
2282 in string
2283 out string
2284 }{
2285 {in: "", out: ""},
2286 {in: "42", out: "42"},
2287 {in: "'", out: "''"},
2288 {in: "hi'there", out: "hi''there"},
2289 {in: "'hi there'", out: "''hi there''"},
2290 }
2291
2292 for i, tt := range tests {
2293 value, err := pgConn.EscapeString(tt.in)
2294 if assert.NoErrorf(t, err, "%d.", i) {
2295 assert.Equalf(t, tt.out, value, "%d.", i)
2296 }
2297 }
2298
2299 ensureConnValid(t, pgConn)
2300 }
2301
2302 func TestConnCancelRequest(t *testing.T) {
2303 t.Parallel()
2304
2305 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2306 defer cancel()
2307
2308 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2309 require.NoError(t, err)
2310 defer closeConn(t, pgConn)
2311
2312 if pgConn.ParameterStatus("crdb_version") != "" {
2313 t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
2314 }
2315
2316 multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(25)")
2317
2318 errChan := make(chan error)
2319 go func() {
2320
2321
2322 time.Sleep(1 * time.Second)
2323
2324 err := pgConn.CancelRequest(ctx)
2325 errChan <- err
2326 }()
2327
2328 for multiResult.NextResult() {
2329 }
2330 err = multiResult.Close()
2331
2332 require.IsType(t, &pgconn.PgError{}, err)
2333 require.Equal(t, "57014", err.(*pgconn.PgError).Code)
2334
2335 err = <-errChan
2336 require.NoError(t, err)
2337
2338 ensureConnValid(t, pgConn)
2339 }
2340
2341
2342 func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
2343 t.Parallel()
2344
2345 t.Run("postgres", func(t *testing.T) {
2346 t.Parallel()
2347
2348 testConnContextCanceledCancelsRunningQueryOnServer(t, os.Getenv("PGX_TEST_DATABASE"), "postgres")
2349 })
2350
2351 t.Run("pgbouncer", func(t *testing.T) {
2352 t.Parallel()
2353
2354 connString := os.Getenv(pgbouncerConnStringEnvVar)
2355 if connString == "" {
2356 t.Skipf("Skipping due to missing environment variable %v", pgbouncerConnStringEnvVar)
2357 }
2358
2359 testConnContextCanceledCancelsRunningQueryOnServer(t, connString, "pgbouncer")
2360 })
2361 }
2362
2363 func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString, dbType string) {
2364 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2365 defer cancel()
2366
2367 pgConn, err := pgconn.Connect(ctx, connString)
2368 require.NoError(t, err)
2369 defer closeConn(t, pgConn)
2370
2371 ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
2372 defer cancel()
2373
2374
2375
2376
2377 queryID := fmt.Sprintf("%s testConnContextCanceled %d", dbType, time.Now().UnixNano())
2378
2379 multiResult := pgConn.Exec(ctx, fmt.Sprintf(`
2380 -- %v
2381 select 'Hello, world', pg_sleep(30)
2382 `, queryID))
2383
2384 for multiResult.NextResult() {
2385 }
2386 err = multiResult.Close()
2387 assert.True(t, pgconn.Timeout(err))
2388 assert.True(t, pgConn.IsClosed())
2389 select {
2390 case <-pgConn.CleanupDone():
2391 case <-time.After(5 * time.Second):
2392 t.Fatal("Connection cleanup exceeded maximum time")
2393 }
2394
2395 ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
2396 defer cancel()
2397
2398 otherConn, err := pgconn.Connect(ctx, connString)
2399 require.NoError(t, err)
2400 defer closeConn(t, otherConn)
2401
2402 ctx, cancel = context.WithTimeout(ctx, time.Second*5)
2403 defer cancel()
2404
2405 for {
2406 result := otherConn.ExecParams(ctx,
2407 `select 1 from pg_stat_activity where query like $1`,
2408 [][]byte{[]byte("%" + queryID + "%")},
2409 nil,
2410 nil,
2411 nil,
2412 ).Read()
2413 require.NoError(t, result.Err)
2414
2415 if len(result.Rows) == 0 {
2416 break
2417 }
2418 }
2419 }
2420
2421 func TestHijackAndConstruct(t *testing.T) {
2422 t.Parallel()
2423
2424 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2425 defer cancel()
2426
2427 origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2428 require.NoError(t, err)
2429
2430 err = origConn.SyncConn(ctx)
2431 require.NoError(t, err)
2432
2433 hc, err := origConn.Hijack()
2434 require.NoError(t, err)
2435
2436 _, err = origConn.Exec(ctx, "select 'Hello, world'").ReadAll()
2437 require.Error(t, err)
2438
2439 newConn, err := pgconn.Construct(hc)
2440 require.NoError(t, err)
2441
2442 defer closeConn(t, newConn)
2443
2444 results, err := newConn.Exec(ctx, "select 'Hello, world'").ReadAll()
2445 assert.NoError(t, err)
2446
2447 assert.Len(t, results, 1)
2448 assert.Nil(t, results[0].Err)
2449 assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
2450 assert.Len(t, results[0].Rows, 1)
2451 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
2452
2453 ensureConnValid(t, newConn)
2454 }
2455
2456 func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) {
2457 t.Parallel()
2458
2459 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2460 defer cancel()
2461
2462 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2463 require.NoError(t, err)
2464
2465 pgConn.Exec(ctx, "select n from generate_series(1,10) n")
2466
2467 closeCtx, _ := context.WithCancel(ctx)
2468 pgConn.Close(closeCtx)
2469 select {
2470 case <-pgConn.CleanupDone():
2471 case <-time.After(5 * time.Second):
2472 t.Fatal("Connection cleanup exceeded maximum time")
2473 }
2474 }
2475
2476
2477 func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
2478 t.Parallel()
2479
2480 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2481 defer cancel()
2482
2483 steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
2484 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
2485 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{}))
2486 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
2487 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{}))
2488 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{}))
2489 steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
2490 {Name: []byte("mock")},
2491 }}))
2492 steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}))
2493 steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
2494
2495 script := &pgmock.Script{Steps: steps}
2496
2497 ln, err := net.Listen("tcp", "127.0.0.1:")
2498 require.NoError(t, err)
2499 defer ln.Close()
2500
2501 serverErrChan := make(chan error, 1)
2502 go func() {
2503 defer close(serverErrChan)
2504
2505 conn, err := ln.Accept()
2506 if err != nil {
2507 serverErrChan <- err
2508 return
2509 }
2510 defer conn.Close()
2511
2512 err = conn.SetDeadline(time.Now().Add(5 * time.Second))
2513 if err != nil {
2514 serverErrChan <- err
2515 return
2516 }
2517
2518 err = script.Run(pgproto3.NewBackend(conn, conn))
2519 if err != nil {
2520 serverErrChan <- err
2521 return
2522 }
2523 }()
2524
2525 host, port, _ := strings.Cut(ln.Addr().String(), ":")
2526 connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
2527
2528 ctx, cancel = context.WithTimeout(ctx, 5*time.Second)
2529 defer cancel()
2530 conn, err := pgconn.Connect(ctx, connStr)
2531 require.NoError(t, err)
2532
2533 rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil)
2534
2535 for rr.NextRow() {
2536 }
2537
2538 _, err = rr.Close()
2539 require.Error(t, err)
2540 }
2541
2542
2543 func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
2544 t.Parallel()
2545
2546 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2547 defer cancel()
2548
2549 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2550 require.NoError(t, err)
2551 defer closeConn(t, pgConn)
2552
2553 _, err = pgConn.Exec(ctx, "set client_min_messages = debug5").ReadAll()
2554 require.NoError(t, err)
2555
2556
2557
2558
2559 paramCount := math.MaxUint16
2560 params := make([]string, 0, paramCount)
2561 args := make([][]byte, 0, paramCount)
2562 for i := 0; i < paramCount; i++ {
2563 params = append(params, fmt.Sprintf("($%d::text)", i+1))
2564 args = append(args, []byte(strconv.Itoa(i)))
2565 }
2566 sql := "values" + strings.Join(params, ", ")
2567
2568 result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
2569 require.NoError(t, result.Err)
2570 require.Len(t, result.Rows, paramCount)
2571
2572 ensureConnValid(t, pgConn)
2573 }
2574
2575 func TestConnCheckConn(t *testing.T) {
2576 t.Parallel()
2577
2578 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2579 defer cancel()
2580
2581
2582
2583 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
2584 if connString == "" {
2585 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
2586 }
2587
2588 c1, err := pgconn.Connect(ctx, connString)
2589 require.NoError(t, err)
2590 defer c1.Close(ctx)
2591
2592 if c1.ParameterStatus("crdb_version") != "" {
2593 t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
2594 }
2595
2596 err = c1.CheckConn()
2597 require.NoError(t, err)
2598
2599 c2, err := pgconn.Connect(ctx, connString)
2600 require.NoError(t, err)
2601 defer c2.Close(ctx)
2602
2603 _, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
2604 require.NoError(t, err)
2605
2606
2607
2608 for err == nil && ctx.Err() == nil {
2609 time.Sleep(50 * time.Millisecond)
2610 err = c1.CheckConn()
2611 }
2612 require.Error(t, err)
2613 }
2614
2615 func TestConnPing(t *testing.T) {
2616 t.Parallel()
2617
2618 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2619 defer cancel()
2620
2621
2622
2623 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
2624 if connString == "" {
2625 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
2626 }
2627
2628 c1, err := pgconn.Connect(ctx, connString)
2629 require.NoError(t, err)
2630 defer c1.Close(ctx)
2631
2632 if c1.ParameterStatus("crdb_version") != "" {
2633 t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
2634 }
2635
2636 err = c1.Exec(ctx, "set log_statement = 'all'").Close()
2637 require.NoError(t, err)
2638
2639 err = c1.Ping(ctx)
2640 require.NoError(t, err)
2641
2642 c2, err := pgconn.Connect(ctx, connString)
2643 require.NoError(t, err)
2644 defer c2.Close(ctx)
2645
2646 _, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
2647 require.NoError(t, err)
2648
2649
2650 time.Sleep(500 * time.Millisecond)
2651
2652 err = c1.Ping(ctx)
2653 require.Error(t, err)
2654 }
2655
2656 func TestPipelinePrepare(t *testing.T) {
2657 t.Parallel()
2658
2659 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2660 defer cancel()
2661
2662 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2663 require.NoError(t, err)
2664 defer closeConn(t, pgConn)
2665
2666 result := pgConn.ExecParams(ctx, `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read()
2667 require.NoError(t, result.Err)
2668
2669 pipeline := pgConn.StartPipeline(ctx)
2670 pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
2671 pipeline.SendPrepare("selectText", "select $1::text as b", nil)
2672 pipeline.SendPrepare("selectNoParams", "select 42 as c", nil)
2673 pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil)
2674 pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil)
2675 err = pipeline.Sync()
2676 require.NoError(t, err)
2677
2678 results, err := pipeline.GetResults()
2679 require.NoError(t, err)
2680 sd, ok := results.(*pgconn.StatementDescription)
2681 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2682 require.Len(t, sd.Fields, 1)
2683 require.Equal(t, string(sd.Fields[0].Name), "a")
2684 require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
2685
2686 results, err = pipeline.GetResults()
2687 require.NoError(t, err)
2688 sd, ok = results.(*pgconn.StatementDescription)
2689 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2690 require.Len(t, sd.Fields, 1)
2691 require.Equal(t, string(sd.Fields[0].Name), "b")
2692 require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
2693
2694 results, err = pipeline.GetResults()
2695 require.NoError(t, err)
2696 sd, ok = results.(*pgconn.StatementDescription)
2697 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2698 require.Len(t, sd.Fields, 1)
2699 require.Equal(t, string(sd.Fields[0].Name), "c")
2700 require.Equal(t, []uint32{}, sd.ParamOIDs)
2701
2702 results, err = pipeline.GetResults()
2703 require.NoError(t, err)
2704 sd, ok = results.(*pgconn.StatementDescription)
2705 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2706 require.Len(t, sd.Fields, 0)
2707 require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
2708
2709 results, err = pipeline.GetResults()
2710 require.NoError(t, err)
2711 sd, ok = results.(*pgconn.StatementDescription)
2712 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2713 require.Len(t, sd.Fields, 0)
2714 require.Len(t, sd.ParamOIDs, 0)
2715
2716 results, err = pipeline.GetResults()
2717 require.NoError(t, err)
2718 _, ok = results.(*pgconn.PipelineSync)
2719 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
2720
2721 results, err = pipeline.GetResults()
2722 require.NoError(t, err)
2723 require.Nil(t, results)
2724
2725 err = pipeline.Close()
2726 require.NoError(t, err)
2727
2728 ensureConnValid(t, pgConn)
2729 }
2730
2731 func TestPipelinePrepareError(t *testing.T) {
2732 t.Parallel()
2733
2734 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2735 defer cancel()
2736
2737 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2738 require.NoError(t, err)
2739 defer closeConn(t, pgConn)
2740
2741 pipeline := pgConn.StartPipeline(ctx)
2742 pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
2743 pipeline.SendPrepare("selectError", "bad", nil)
2744 pipeline.SendPrepare("selectText", "select $1::text as b", nil)
2745 err = pipeline.Sync()
2746 require.NoError(t, err)
2747
2748 results, err := pipeline.GetResults()
2749 require.NoError(t, err)
2750 sd, ok := results.(*pgconn.StatementDescription)
2751 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2752 require.Len(t, sd.Fields, 1)
2753 require.Equal(t, string(sd.Fields[0].Name), "a")
2754 require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
2755
2756 results, err = pipeline.GetResults()
2757 var pgErr *pgconn.PgError
2758 require.ErrorAs(t, err, &pgErr)
2759 require.Nil(t, results)
2760
2761 results, err = pipeline.GetResults()
2762 require.NoError(t, err)
2763 _, ok = results.(*pgconn.PipelineSync)
2764 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
2765
2766 results, err = pipeline.GetResults()
2767 require.NoError(t, err)
2768 require.Nil(t, results)
2769
2770 err = pipeline.Close()
2771 require.NoError(t, err)
2772
2773 ensureConnValid(t, pgConn)
2774 }
2775
2776 func TestPipelinePrepareAndDeallocate(t *testing.T) {
2777 t.Parallel()
2778
2779 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2780 defer cancel()
2781
2782 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2783 require.NoError(t, err)
2784 defer closeConn(t, pgConn)
2785
2786 pipeline := pgConn.StartPipeline(ctx)
2787 pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
2788 pipeline.SendDeallocate("selectInt")
2789 err = pipeline.Sync()
2790 require.NoError(t, err)
2791
2792 results, err := pipeline.GetResults()
2793 require.NoError(t, err)
2794 sd, ok := results.(*pgconn.StatementDescription)
2795 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2796 require.Len(t, sd.Fields, 1)
2797 require.Equal(t, string(sd.Fields[0].Name), "a")
2798 require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
2799
2800 results, err = pipeline.GetResults()
2801 require.NoError(t, err)
2802 _, ok = results.(*pgconn.CloseComplete)
2803 require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
2804
2805 results, err = pipeline.GetResults()
2806 require.NoError(t, err)
2807 _, ok = results.(*pgconn.PipelineSync)
2808 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
2809
2810 results, err = pipeline.GetResults()
2811 require.NoError(t, err)
2812 require.Nil(t, results)
2813
2814 err = pipeline.Close()
2815 require.NoError(t, err)
2816
2817 ensureConnValid(t, pgConn)
2818 }
2819
2820 func TestPipelineQuery(t *testing.T) {
2821 t.Parallel()
2822
2823 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2824 defer cancel()
2825
2826 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2827 require.NoError(t, err)
2828 defer closeConn(t, pgConn)
2829
2830 pipeline := pgConn.StartPipeline(ctx)
2831 pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
2832 pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
2833 pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
2834 err = pipeline.Sync()
2835 require.NoError(t, err)
2836
2837 pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
2838 pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
2839 err = pipeline.Sync()
2840 require.NoError(t, err)
2841
2842 results, err := pipeline.GetResults()
2843 require.NoError(t, err)
2844 rr, ok := results.(*pgconn.ResultReader)
2845 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2846 readResult := rr.Read()
2847 require.NoError(t, readResult.Err)
2848 require.Len(t, readResult.Rows, 1)
2849 require.Len(t, readResult.Rows[0], 1)
2850 require.Equal(t, "1", string(readResult.Rows[0][0]))
2851
2852 results, err = pipeline.GetResults()
2853 require.NoError(t, err)
2854 rr, ok = results.(*pgconn.ResultReader)
2855 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2856 readResult = rr.Read()
2857 require.NoError(t, readResult.Err)
2858 require.Len(t, readResult.Rows, 1)
2859 require.Len(t, readResult.Rows[0], 1)
2860 require.Equal(t, "2", string(readResult.Rows[0][0]))
2861
2862 results, err = pipeline.GetResults()
2863 require.NoError(t, err)
2864 rr, ok = results.(*pgconn.ResultReader)
2865 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2866 readResult = rr.Read()
2867 require.NoError(t, readResult.Err)
2868 require.Len(t, readResult.Rows, 1)
2869 require.Len(t, readResult.Rows[0], 1)
2870 require.Equal(t, "3", string(readResult.Rows[0][0]))
2871
2872 results, err = pipeline.GetResults()
2873 require.NoError(t, err)
2874 _, ok = results.(*pgconn.PipelineSync)
2875 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
2876
2877 results, err = pipeline.GetResults()
2878 require.NoError(t, err)
2879 rr, ok = results.(*pgconn.ResultReader)
2880 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2881 readResult = rr.Read()
2882 require.NoError(t, readResult.Err)
2883 require.Len(t, readResult.Rows, 1)
2884 require.Len(t, readResult.Rows[0], 1)
2885 require.Equal(t, "4", string(readResult.Rows[0][0]))
2886
2887 results, err = pipeline.GetResults()
2888 require.NoError(t, err)
2889 rr, ok = results.(*pgconn.ResultReader)
2890 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2891 readResult = rr.Read()
2892 require.NoError(t, readResult.Err)
2893 require.Len(t, readResult.Rows, 1)
2894 require.Len(t, readResult.Rows[0], 1)
2895 require.Equal(t, "5", string(readResult.Rows[0][0]))
2896
2897 results, err = pipeline.GetResults()
2898 require.NoError(t, err)
2899 _, ok = results.(*pgconn.PipelineSync)
2900 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
2901
2902 results, err = pipeline.GetResults()
2903 require.NoError(t, err)
2904 require.Nil(t, results)
2905
2906 err = pipeline.Close()
2907 require.NoError(t, err)
2908
2909 ensureConnValid(t, pgConn)
2910 }
2911
2912 func TestPipelinePrepareQuery(t *testing.T) {
2913 t.Parallel()
2914
2915 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2916 defer cancel()
2917
2918 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2919 require.NoError(t, err)
2920 defer closeConn(t, pgConn)
2921
2922 pipeline := pgConn.StartPipeline(ctx)
2923 pipeline.SendPrepare("ps", "select $1::text as msg", nil)
2924 pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
2925 pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
2926 err = pipeline.Sync()
2927 require.NoError(t, err)
2928
2929 results, err := pipeline.GetResults()
2930 require.NoError(t, err)
2931 sd, ok := results.(*pgconn.StatementDescription)
2932 require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
2933 require.Len(t, sd.Fields, 1)
2934 require.Equal(t, string(sd.Fields[0].Name), "msg")
2935 require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
2936
2937 results, err = pipeline.GetResults()
2938 require.NoError(t, err)
2939 rr, ok := results.(*pgconn.ResultReader)
2940 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2941 readResult := rr.Read()
2942 require.NoError(t, readResult.Err)
2943 require.Len(t, readResult.Rows, 1)
2944 require.Len(t, readResult.Rows[0], 1)
2945 require.Equal(t, "hello", string(readResult.Rows[0][0]))
2946
2947 results, err = pipeline.GetResults()
2948 require.NoError(t, err)
2949 rr, ok = results.(*pgconn.ResultReader)
2950 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
2951 readResult = rr.Read()
2952 require.NoError(t, readResult.Err)
2953 require.Len(t, readResult.Rows, 1)
2954 require.Len(t, readResult.Rows[0], 1)
2955 require.Equal(t, "goodbye", string(readResult.Rows[0][0]))
2956
2957 results, err = pipeline.GetResults()
2958 require.NoError(t, err)
2959 _, ok = results.(*pgconn.PipelineSync)
2960 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
2961
2962 results, err = pipeline.GetResults()
2963 require.NoError(t, err)
2964 require.Nil(t, results)
2965
2966 err = pipeline.Close()
2967 require.NoError(t, err)
2968
2969 ensureConnValid(t, pgConn)
2970 }
2971
2972 func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
2973 t.Parallel()
2974
2975 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
2976 defer cancel()
2977
2978 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
2979 require.NoError(t, err)
2980 defer closeConn(t, pgConn)
2981
2982 pipeline := pgConn.StartPipeline(ctx)
2983 pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
2984 pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
2985 err = pipeline.Sync()
2986 require.NoError(t, err)
2987
2988 pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
2989 pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
2990 pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
2991 err = pipeline.Sync()
2992 require.NoError(t, err)
2993
2994 pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
2995 pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil)
2996 err = pipeline.Sync()
2997 require.NoError(t, err)
2998
2999 results, err := pipeline.GetResults()
3000 require.NoError(t, err)
3001 rr, ok := results.(*pgconn.ResultReader)
3002 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3003 readResult := rr.Read()
3004 require.NoError(t, readResult.Err)
3005 require.Len(t, readResult.Rows, 1)
3006 require.Len(t, readResult.Rows[0], 1)
3007 require.Equal(t, "1", string(readResult.Rows[0][0]))
3008
3009 results, err = pipeline.GetResults()
3010 require.NoError(t, err)
3011 rr, ok = results.(*pgconn.ResultReader)
3012 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3013 readResult = rr.Read()
3014 require.NoError(t, readResult.Err)
3015 require.Len(t, readResult.Rows, 1)
3016 require.Len(t, readResult.Rows[0], 1)
3017 require.Equal(t, "2", string(readResult.Rows[0][0]))
3018
3019 results, err = pipeline.GetResults()
3020 require.NoError(t, err)
3021 _, ok = results.(*pgconn.PipelineSync)
3022 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
3023
3024 results, err = pipeline.GetResults()
3025 require.NoError(t, err)
3026 rr, ok = results.(*pgconn.ResultReader)
3027 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3028 readResult = rr.Read()
3029 require.NoError(t, readResult.Err)
3030 require.Len(t, readResult.Rows, 1)
3031 require.Len(t, readResult.Rows[0], 1)
3032 require.Equal(t, "3", string(readResult.Rows[0][0]))
3033
3034 results, err = pipeline.GetResults()
3035 require.NoError(t, err)
3036 rr, ok = results.(*pgconn.ResultReader)
3037 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3038 readResult = rr.Read()
3039 var pgErr *pgconn.PgError
3040 require.ErrorAs(t, readResult.Err, &pgErr)
3041 require.Equal(t, "22012", pgErr.Code)
3042
3043 results, err = pipeline.GetResults()
3044 require.NoError(t, err)
3045 _, ok = results.(*pgconn.PipelineSync)
3046 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
3047
3048 results, err = pipeline.GetResults()
3049 require.NoError(t, err)
3050 rr, ok = results.(*pgconn.ResultReader)
3051 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3052 readResult = rr.Read()
3053 require.NoError(t, readResult.Err)
3054 require.Len(t, readResult.Rows, 1)
3055 require.Len(t, readResult.Rows[0], 1)
3056 require.Equal(t, "5", string(readResult.Rows[0][0]))
3057
3058 results, err = pipeline.GetResults()
3059 require.NoError(t, err)
3060 rr, ok = results.(*pgconn.ResultReader)
3061 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3062 readResult = rr.Read()
3063 require.NoError(t, readResult.Err)
3064 require.Len(t, readResult.Rows, 1)
3065 require.Len(t, readResult.Rows[0], 1)
3066 require.Equal(t, "6", string(readResult.Rows[0][0]))
3067
3068 results, err = pipeline.GetResults()
3069 require.NoError(t, err)
3070 _, ok = results.(*pgconn.PipelineSync)
3071 require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
3072
3073 err = pipeline.Close()
3074 require.NoError(t, err)
3075
3076 ensureConnValid(t, pgConn)
3077 }
3078
3079 func TestPipelineCloseReadsUnreadResults(t *testing.T) {
3080 t.Parallel()
3081
3082 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
3083 defer cancel()
3084
3085 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
3086 require.NoError(t, err)
3087 defer closeConn(t, pgConn)
3088
3089 pipeline := pgConn.StartPipeline(ctx)
3090 pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
3091 pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
3092 pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
3093 err = pipeline.Sync()
3094 require.NoError(t, err)
3095
3096 pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
3097 pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
3098 err = pipeline.Sync()
3099 require.NoError(t, err)
3100
3101 results, err := pipeline.GetResults()
3102 require.NoError(t, err)
3103 rr, ok := results.(*pgconn.ResultReader)
3104 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3105 readResult := rr.Read()
3106 require.NoError(t, readResult.Err)
3107 require.Len(t, readResult.Rows, 1)
3108 require.Len(t, readResult.Rows[0], 1)
3109 require.Equal(t, "1", string(readResult.Rows[0][0]))
3110
3111 err = pipeline.Close()
3112 require.NoError(t, err)
3113
3114 ensureConnValid(t, pgConn)
3115 }
3116
3117 func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
3118 t.Parallel()
3119
3120 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
3121 defer cancel()
3122
3123 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
3124 require.NoError(t, err)
3125 defer closeConn(t, pgConn)
3126
3127 pipeline := pgConn.StartPipeline(ctx)
3128 pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
3129 pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
3130 pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
3131 err = pipeline.Sync()
3132 require.NoError(t, err)
3133
3134 pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
3135 pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
3136
3137 results, err := pipeline.GetResults()
3138 require.NoError(t, err)
3139 rr, ok := results.(*pgconn.ResultReader)
3140 require.Truef(t, ok, "expected ResultReader, got: %#v", results)
3141 readResult := rr.Read()
3142 require.NoError(t, readResult.Err)
3143 require.Len(t, readResult.Rows, 1)
3144 require.Len(t, readResult.Rows[0], 1)
3145 require.Equal(t, "1", string(readResult.Rows[0][0]))
3146
3147 err = pipeline.Close()
3148 require.EqualError(t, err, "pipeline has unsynced requests")
3149 }
3150
3151 func TestConnOnPgError(t *testing.T) {
3152 t.Parallel()
3153
3154 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
3155 defer cancel()
3156
3157 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
3158 require.NoError(t, err)
3159 config.OnPgError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool {
3160 require.NotNil(t, c)
3161 require.NotNil(t, pgErr)
3162
3163 if pgErr.Code == "42P01" {
3164 return false
3165 }
3166 return true
3167 }
3168
3169 pgConn, err := pgconn.ConnectConfig(ctx, config)
3170 require.NoError(t, err)
3171 defer closeConn(t, pgConn)
3172
3173 _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
3174 assert.NoError(t, err)
3175 assert.False(t, pgConn.IsClosed())
3176
3177 _, err = pgConn.Exec(ctx, "select 1/0").ReadAll()
3178 assert.Error(t, err)
3179 assert.False(t, pgConn.IsClosed())
3180
3181 _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll()
3182 assert.Error(t, err)
3183 assert.True(t, pgConn.IsClosed())
3184 }
3185
3186 func Example() {
3187 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
3188 defer cancel()
3189
3190 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
3191 if err != nil {
3192 log.Fatalln(err)
3193 }
3194 defer pgConn.Close(ctx)
3195
3196 result := pgConn.ExecParams(ctx, "select generate_series(1,3)", nil, nil, nil, nil).Read()
3197 if result.Err != nil {
3198 log.Fatalln(result.Err)
3199 }
3200
3201 for _, row := range result.Rows {
3202 fmt.Println(string(row[0]))
3203 }
3204
3205 fmt.Println(result.CommandTag)
3206
3207
3208
3209
3210
3211 }
3212
// GetSSLPassword returns the value of the PGX_SSL_PASSWORD environment
// variable, or "" when it is unset. The ctx parameter is not used.
func GetSSLPassword(ctx context.Context) string {
	return os.Getenv("PGX_SSL_PASSWORD")
}
3217
// rsaCertPEM is a PEM-encoded RSA certificate that the fake TLS server in
// TestSNISupport loads together with rsaKeyPEM via tls.X509KeyPair.
//
// Judging by the encoded DER, its subject and issuer are both CN=localhost
// (self-signed). NOTE(review): the encoded validity period appears to be
// 2022-08-15 .. 2023-08-15, i.e. already expired. That is only harmless while
// clients connect with sslmode=require, which presumably does not verify the
// certificate. Confirm before reusing it in a test that does verify.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
MlascjupnaptKX/wMA==
-----END CERTIFICATE-----
`
3238
// rsaKeyPEM is the RSA private key paired with rsaCertPEM; TestSNISupport loads
// the two together via tls.X509KeyPair.
//
// The PEM labels are spelled "TESTING KEY" in the source. testingKey rewrites
// them to "PRIVATE KEY" when this package-level variable is initialized.
// Presumably this keeps secret scanners from flagging a throwaway test key.
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
z3w+CgS20UrbLIR1YXfqUXge1g==
-----END TESTING KEY-----
`)
3268
// testingKey rewrites every "TESTING KEY" in s to "PRIVATE KEY", turning a
// "-----BEGIN TESTING KEY-----" PEM block into a regular private-key PEM block.
func testingKey(s string) string {
	const placeholder, label = "TESTING KEY", "PRIVATE KEY"
	return strings.Replace(s, placeholder, label, -1)
}
3270
3271 func TestSNISupport(t *testing.T) {
3272 t.Parallel()
3273 tests := []struct {
3274 name string
3275 sni_param string
3276 sni_set bool
3277 }{
3278 {
3279 name: "SNI is passed by default",
3280 sni_param: "",
3281 sni_set: true,
3282 },
3283 {
3284 name: "SNI is passed when asked for",
3285 sni_param: "sslsni=1",
3286 sni_set: true,
3287 },
3288 {
3289 name: "SNI is not passed when disabled",
3290 sni_param: "sslsni=0",
3291 sni_set: false,
3292 },
3293 }
3294 for _, tt := range tests {
3295 tt := tt
3296 t.Run(tt.name, func(t *testing.T) {
3297 t.Parallel()
3298
3299 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
3300 defer cancel()
3301
3302 ln, err := net.Listen("tcp", "127.0.0.1:")
3303 require.NoError(t, err)
3304 defer ln.Close()
3305
3306 serverErrChan := make(chan error, 1)
3307 serverSNINameChan := make(chan string, 1)
3308 defer close(serverErrChan)
3309 defer close(serverSNINameChan)
3310
3311 go func() {
3312 var sniHost string
3313
3314 conn, err := ln.Accept()
3315 if err != nil {
3316 serverErrChan <- err
3317 return
3318 }
3319 defer conn.Close()
3320
3321 err = conn.SetDeadline(time.Now().Add(5 * time.Second))
3322 if err != nil {
3323 serverErrChan <- err
3324 return
3325 }
3326
3327 backend := pgproto3.NewBackend(conn, conn)
3328 startupMessage, err := backend.ReceiveStartupMessage()
3329 if err != nil {
3330 serverErrChan <- err
3331 return
3332 }
3333
3334 switch startupMessage.(type) {
3335 case *pgproto3.SSLRequest:
3336 _, err = conn.Write([]byte("S"))
3337 if err != nil {
3338 serverErrChan <- err
3339 return
3340 }
3341 default:
3342 serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
3343 return
3344 }
3345
3346 cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
3347 if err != nil {
3348 serverErrChan <- err
3349 return
3350 }
3351
3352 srv := tls.Server(conn, &tls.Config{
3353 Certificates: []tls.Certificate{cert},
3354 GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
3355 sniHost = argHello.ServerName
3356 return nil, nil
3357 },
3358 })
3359 defer srv.Close()
3360
3361 if err := srv.Handshake(); err != nil {
3362 serverErrChan <- fmt.Errorf("handshake: %w", err)
3363 return
3364 }
3365
3366 srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
3367 srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
3368 srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
3369
3370 serverSNINameChan <- sniHost
3371 }()
3372
3373 _, port, _ := strings.Cut(ln.Addr().String(), ":")
3374 connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
3375 _, err = pgconn.Connect(ctx, connStr)
3376
3377 select {
3378 case sniHost := <-serverSNINameChan:
3379 if tt.sni_set {
3380 require.Equal(t, sniHost, "localhost")
3381 } else {
3382 require.Equal(t, sniHost, "")
3383 }
3384 case err = <-serverErrChan:
3385 t.Fatalf("server failed with error: %+v", err)
3386 case <-time.After(time.Millisecond * 100):
3387 t.Fatal("exceeded connection timeout without erroring out")
3388 }
3389 })
3390 }
3391 }
3392
3393
3394 func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
3395 t.Parallel()
3396
3397 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
3398 defer cancel()
3399
3400 steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
3401 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
3402 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
3403 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
3404 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
3405 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
3406 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
3407 steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
3408 {Name: []byte("mock")},
3409 }}))
3410 steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
3411
3412
3413 steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
3414
3415 script := &pgmock.Script{Steps: steps}
3416
3417 ln, err := net.Listen("tcp", "127.0.0.1:")
3418 require.NoError(t, err)
3419 defer ln.Close()
3420
3421 serverKeepAlive := make(chan struct{})
3422 defer close(serverKeepAlive)
3423
3424 serverErrChan := make(chan error, 1)
3425 go func() {
3426 defer close(serverErrChan)
3427
3428 conn, err := ln.Accept()
3429 if err != nil {
3430 serverErrChan <- err
3431 return
3432 }
3433 defer conn.Close()
3434
3435 err = conn.SetDeadline(time.Now().Add(59 * time.Second))
3436 if err != nil {
3437 serverErrChan <- err
3438 return
3439 }
3440
3441 err = script.Run(pgproto3.NewBackend(conn, conn))
3442 if err != nil {
3443 serverErrChan <- err
3444 return
3445 }
3446
3447 <-serverKeepAlive
3448 }()
3449
3450 parts := strings.Split(ln.Addr().String(), ":")
3451 host := parts[0]
3452 port := parts[1]
3453 connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
3454
3455 ctx, cancel = context.WithTimeout(ctx, 59*time.Second)
3456 defer cancel()
3457 conn, err := pgconn.Connect(ctx, connStr)
3458 require.NoError(t, err)
3459
3460 pipeline := conn.StartPipeline(ctx)
3461 pipeline.SendPrepare("s1", "select 1", nil)
3462 pipeline.SendPrepare("s2", "select 2", nil)
3463 pipeline.SendPrepare("s3", "select 3", nil)
3464 err = pipeline.Sync()
3465 require.NoError(t, err)
3466
3467 _, err = pipeline.GetResults()
3468 require.NoError(t, err)
3469 _, err = pipeline.GetResults()
3470 require.Error(t, err)
3471
3472 err = pipeline.Close()
3473 require.Error(t, err)
3474 }
3475
3476 func mustEncode(buf []byte, err error) []byte {
3477 if err != nil {
3478 panic(err)
3479 }
3480 return buf
3481 }
3482
View as plain text