1 package pgconn_test
2
3 import (
4 "bytes"
5 "compress/gzip"
6 "context"
7 "crypto/tls"
8 "errors"
9 "fmt"
10 "io"
11 "io/ioutil"
12 "log"
13 "math"
14 "net"
15 "os"
16 "strconv"
17 "strings"
18 "testing"
19 "time"
20
21 "github.com/jackc/pgconn"
22 "github.com/jackc/pgmock"
23 "github.com/jackc/pgproto3/v2"
24 "github.com/stretchr/testify/assert"
25 "github.com/stretchr/testify/require"
26 )
27
28 func TestConnect(t *testing.T) {
29 tests := []struct {
30 name string
31 env string
32 }{
33 {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
34 {"TCP", "PGX_TEST_TCP_CONN_STRING"},
35 {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
36 {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
37 {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
38 }
39
40 for _, tt := range tests {
41 tt := tt
42 t.Run(tt.name, func(t *testing.T) {
43 connString := os.Getenv(tt.env)
44 if connString == "" {
45 t.Skipf("Skipping due to missing environment variable %v", tt.env)
46 }
47
48 conn, err := pgconn.Connect(context.Background(), connString)
49 require.NoError(t, err)
50
51 closeConn(t, conn)
52 })
53 }
54 }
55
56 func TestConnectWithOptions(t *testing.T) {
57 tests := []struct {
58 name string
59 env string
60 }{
61 {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
62 {"TCP", "PGX_TEST_TCP_CONN_STRING"},
63 {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
64 {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
65 {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
66 }
67
68 for _, tt := range tests {
69 tt := tt
70 t.Run(tt.name, func(t *testing.T) {
71 connString := os.Getenv(tt.env)
72 if connString == "" {
73 t.Skipf("Skipping due to missing environment variable %v", tt.env)
74 }
75 var sslOptions pgconn.ParseConfigOptions
76 sslOptions.GetSSLPassword = GetSSLPassword
77 conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
78 require.NoError(t, err)
79
80 closeConn(t, conn)
81 })
82 }
83 }
84
85
86
87 func TestConnectTLS(t *testing.T) {
88 t.Parallel()
89
90 connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
91 if connString == "" {
92 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
93 }
94
95 var conn *pgconn.PgConn
96 var err error
97
98 var sslOptions pgconn.ParseConfigOptions
99 sslOptions.GetSSLPassword = GetSSLPassword
100 config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
101 require.Nil(t, err)
102
103 conn, err = pgconn.ConnectConfig(context.Background(), config)
104 require.NoError(t, err)
105
106 if _, ok := conn.Conn().(*tls.Conn); !ok {
107 t.Error("not a TLS connection")
108 }
109
110 closeConn(t, conn)
111 }
112
113 type pgmockWaitStep time.Duration
114
115 func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
116 time.Sleep(time.Duration(s))
117 return nil
118 }
119
120 func TestConnectTimeout(t *testing.T) {
121 t.Parallel()
122 tests := []struct {
123 name string
124 connect func(connStr string) error
125 }{
126 {
127 name: "via context that times out",
128 connect: func(connStr string) error {
129 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
130 defer cancel()
131 _, err := pgconn.Connect(ctx, connStr)
132 return err
133 },
134 },
135 {
136 name: "via config ConnectTimeout",
137 connect: func(connStr string) error {
138 conf, err := pgconn.ParseConfig(connStr)
139 require.NoError(t, err)
140 conf.ConnectTimeout = time.Microsecond * 50
141 _, err = pgconn.ConnectConfig(context.Background(), conf)
142 return err
143 },
144 },
145 }
146 for _, tt := range tests {
147 tt := tt
148 t.Run(tt.name, func(t *testing.T) {
149 t.Parallel()
150 script := &pgmock.Script{
151 Steps: []pgmock.Step{
152 pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
153 pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
154 pgmockWaitStep(time.Millisecond * 500),
155 pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
156 pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
157 },
158 }
159
160 ln, err := net.Listen("tcp", "127.0.0.1:")
161 require.NoError(t, err)
162 defer ln.Close()
163
164 serverErrChan := make(chan error, 1)
165 go func() {
166 defer close(serverErrChan)
167
168 conn, err := ln.Accept()
169 if err != nil {
170 serverErrChan <- err
171 return
172 }
173 defer conn.Close()
174
175 err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
176 if err != nil {
177 serverErrChan <- err
178 return
179 }
180
181 err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
182 if err != nil {
183 serverErrChan <- err
184 return
185 }
186 }()
187
188 parts := strings.Split(ln.Addr().String(), ":")
189 host := parts[0]
190 port := parts[1]
191 connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
192 tooLate := time.Now().Add(time.Millisecond * 500)
193
194 err = tt.connect(connStr)
195 require.True(t, pgconn.Timeout(err), err)
196 require.True(t, time.Now().Before(tooLate))
197 })
198 }
199 }
200
201 func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
202 t.Parallel()
203 tests := []struct {
204 name string
205 connect func(connStr string) error
206 }{
207 {
208 name: "via context that times out",
209 connect: func(connStr string) error {
210 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
211 defer cancel()
212 _, err := pgconn.Connect(ctx, connStr)
213 return err
214 },
215 },
216 {
217 name: "via config ConnectTimeout",
218 connect: func(connStr string) error {
219 conf, err := pgconn.ParseConfig(connStr)
220 require.NoError(t, err)
221 conf.ConnectTimeout = time.Millisecond * 10
222 _, err = pgconn.ConnectConfig(context.Background(), conf)
223 return err
224 },
225 },
226 }
227 for _, tt := range tests {
228 tt := tt
229 t.Run(tt.name, func(t *testing.T) {
230 t.Parallel()
231 ln, err := net.Listen("tcp", "127.0.0.1:")
232 require.NoError(t, err)
233 defer ln.Close()
234
235 serverErrChan := make(chan error)
236 defer close(serverErrChan)
237 go func() {
238 conn, err := ln.Accept()
239 if err != nil {
240 serverErrChan <- err
241 return
242 }
243 defer conn.Close()
244
245 var buf []byte
246 _, err = conn.Read(buf)
247 if err != nil {
248 serverErrChan <- err
249 return
250 }
251
252
253 time.Sleep(time.Minute)
254 }()
255
256 parts := strings.Split(ln.Addr().String(), ":")
257 host := parts[0]
258 port := parts[1]
259 connStr := fmt.Sprintf("host=%s port=%s", host, port)
260
261 errChan := make(chan error)
262 go func() {
263 err := tt.connect(connStr)
264 errChan <- err
265 }()
266
267 select {
268 case err = <-errChan:
269 require.True(t, pgconn.Timeout(err), err)
270 case err = <-serverErrChan:
271 t.Fatalf("server failed with error: %s", err)
272 case <-time.After(time.Millisecond * 100):
273 t.Fatal("exceeded connection timeout without erroring out")
274 }
275 })
276 }
277 }
278
279 func TestConnectInvalidUser(t *testing.T) {
280 t.Parallel()
281
282 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
283 if connString == "" {
284 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
285 }
286
287 config, err := pgconn.ParseConfig(connString)
288 require.NoError(t, err)
289
290 config.User = "pgxinvalidusertest"
291
292 _, err = pgconn.ConnectConfig(context.Background(), config)
293 require.Error(t, err)
294 pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
295 if !ok {
296 t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
297 }
298 if pgErr.Code != "28000" && pgErr.Code != "28P01" {
299 t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
300 }
301 }
302
303 func TestConnectWithConnectionRefused(t *testing.T) {
304 t.Parallel()
305
306
307 conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1")
308 if err == nil {
309 conn.Close(context.Background())
310 t.Fatal("Expected error establishing connection to bad port")
311 }
312 }
313
314 func TestConnectCustomDialer(t *testing.T) {
315 t.Parallel()
316
317 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
318 require.NoError(t, err)
319
320 dialed := false
321 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
322 dialed = true
323 return net.Dial(network, address)
324 }
325
326 conn, err := pgconn.ConnectConfig(context.Background(), config)
327 require.NoError(t, err)
328 require.True(t, dialed)
329 closeConn(t, conn)
330 }
331
332 func TestConnectCustomLookup(t *testing.T) {
333 t.Parallel()
334
335 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
336 if connString == "" {
337 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
338 }
339
340 config, err := pgconn.ParseConfig(connString)
341 require.NoError(t, err)
342
343 looked := false
344 config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
345 looked = true
346 return net.LookupHost(host)
347 }
348
349 conn, err := pgconn.ConnectConfig(context.Background(), config)
350 require.NoError(t, err)
351 require.True(t, looked)
352 closeConn(t, conn)
353 }
354
355 func TestConnectCustomLookupWithPort(t *testing.T) {
356 t.Parallel()
357
358 connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
359 if connString == "" {
360 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
361 }
362
363 config, err := pgconn.ParseConfig(connString)
364 require.NoError(t, err)
365
366 origPort := config.Port
367
368 config.Port = 0
369
370 looked := false
371 config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
372 looked = true
373 addrs, err := net.LookupHost(host)
374 if err != nil {
375 return nil, err
376 }
377 for i := range addrs {
378 addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
379 }
380 return addrs, nil
381 }
382
383 conn, err := pgconn.ConnectConfig(context.Background(), config)
384 require.NoError(t, err)
385 require.True(t, looked)
386 closeConn(t, conn)
387 }
388
389 func TestConnectWithRuntimeParams(t *testing.T) {
390 t.Parallel()
391
392 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
393 require.NoError(t, err)
394
395 config.RuntimeParams = map[string]string{
396 "application_name": "pgxtest",
397 "search_path": "myschema",
398 }
399
400 conn, err := pgconn.ConnectConfig(context.Background(), config)
401 require.NoError(t, err)
402 defer closeConn(t, conn)
403
404 result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read()
405 require.Nil(t, result.Err)
406 assert.Equal(t, 1, len(result.Rows))
407 assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
408
409 result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read()
410 require.Nil(t, result.Err)
411 assert.Equal(t, 1, len(result.Rows))
412 assert.Equal(t, "myschema", string(result.Rows[0][0]))
413 }
414
415 func TestConnectWithFallback(t *testing.T) {
416 t.Parallel()
417
418 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
419 require.NoError(t, err)
420
421
422 config.Fallbacks = append([]*pgconn.FallbackConfig{
423 &pgconn.FallbackConfig{
424 Host: config.Host,
425 Port: config.Port,
426 TLSConfig: config.TLSConfig,
427 },
428 }, config.Fallbacks...)
429
430
431 config.Host = "localhost"
432 config.Port = 1
433
434
435 config.Fallbacks = append([]*pgconn.FallbackConfig{
436 &pgconn.FallbackConfig{
437 Host: "localhost",
438 Port: 1,
439 TLSConfig: config.TLSConfig,
440 },
441 }, config.Fallbacks...)
442
443 conn, err := pgconn.ConnectConfig(context.Background(), config)
444 require.NoError(t, err)
445 closeConn(t, conn)
446 }
447
448 func TestConnectWithValidateConnect(t *testing.T) {
449 t.Parallel()
450
451 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
452 require.NoError(t, err)
453
454 dialCount := 0
455 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
456 dialCount++
457 return net.Dial(network, address)
458 }
459
460 acceptConnCount := 0
461 config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
462 acceptConnCount++
463 if acceptConnCount < 2 {
464 return errors.New("reject first conn")
465 }
466 return nil
467 }
468
469
470 config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
471 Host: config.Host,
472 Port: config.Port,
473 TLSConfig: config.TLSConfig,
474 })
475
476
477 config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
478
479 conn, err := pgconn.ConnectConfig(context.Background(), config)
480 require.NoError(t, err)
481 closeConn(t, conn)
482
483 assert.True(t, dialCount > 1)
484 assert.True(t, acceptConnCount > 1)
485 }
486
487 func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
488 t.Parallel()
489
490 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
491 require.NoError(t, err)
492
493 config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
494 config.RuntimeParams["default_transaction_read_only"] = "on"
495
496 ctx, cancel := context.WithCancel(context.Background())
497 defer cancel()
498
499 conn, err := pgconn.ConnectConfig(ctx, config)
500 if !assert.NotNil(t, err) {
501 conn.Close(ctx)
502 }
503 }
504
505 func TestConnectWithAfterConnect(t *testing.T) {
506 t.Parallel()
507
508 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
509 require.NoError(t, err)
510
511 config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
512 _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
513 return err
514 }
515
516 conn, err := pgconn.ConnectConfig(context.Background(), config)
517 require.NoError(t, err)
518
519 results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
520 require.NoError(t, err)
521 defer closeConn(t, conn)
522
523 assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
524 }
525
526 func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
527 t.Parallel()
528
529 config := &pgconn.Config{}
530
531 require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) })
532 }
533
534 func TestConnPrepareSyntaxError(t *testing.T) {
535 t.Parallel()
536
537 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
538 require.NoError(t, err)
539 defer closeConn(t, pgConn)
540
541 psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil)
542 require.Nil(t, psd)
543 require.NotNil(t, err)
544
545 ensureConnValid(t, pgConn)
546 }
547
548 func TestConnPrepareContextPrecanceled(t *testing.T) {
549 t.Parallel()
550
551 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
552 require.NoError(t, err)
553 defer closeConn(t, pgConn)
554
555 ctx, cancel := context.WithCancel(context.Background())
556 cancel()
557 psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
558 assert.Nil(t, psd)
559 assert.Error(t, err)
560 assert.True(t, errors.Is(err, context.Canceled))
561 assert.True(t, pgconn.SafeToRetry(err))
562
563 ensureConnValid(t, pgConn)
564 }
565
566 func TestConnExec(t *testing.T) {
567 t.Parallel()
568
569 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
570 require.NoError(t, err)
571 defer closeConn(t, pgConn)
572
573 results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
574 assert.NoError(t, err)
575
576 assert.Len(t, results, 1)
577 assert.Nil(t, results[0].Err)
578 assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
579 assert.Len(t, results[0].Rows, 1)
580 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
581
582 ensureConnValid(t, pgConn)
583 }
584
585 func TestConnExecEmpty(t *testing.T) {
586 t.Parallel()
587
588 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
589 require.NoError(t, err)
590 defer closeConn(t, pgConn)
591
592 multiResult := pgConn.Exec(context.Background(), ";")
593
594 resultCount := 0
595 for multiResult.NextResult() {
596 resultCount++
597 multiResult.ResultReader().Close()
598 }
599 assert.Equal(t, 0, resultCount)
600 err = multiResult.Close()
601 assert.NoError(t, err)
602
603 ensureConnValid(t, pgConn)
604 }
605
606 func TestConnExecMultipleQueries(t *testing.T) {
607 t.Parallel()
608
609 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
610 require.NoError(t, err)
611 defer closeConn(t, pgConn)
612
613 results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll()
614 assert.NoError(t, err)
615
616 assert.Len(t, results, 2)
617
618 assert.Nil(t, results[0].Err)
619 assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
620 assert.Len(t, results[0].Rows, 1)
621 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
622
623 assert.Nil(t, results[1].Err)
624 assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
625 assert.Len(t, results[1].Rows, 1)
626 assert.Equal(t, "1", string(results[1].Rows[0][0]))
627
628 ensureConnValid(t, pgConn)
629 }
630
631 func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
632 t.Parallel()
633
634 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
635 require.NoError(t, err)
636 defer closeConn(t, pgConn)
637
638 mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")
639
640 require.True(t, mrr.NextResult())
641 require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
642 assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name)
643 _, err = mrr.ResultReader().Close()
644 require.NoError(t, err)
645
646 require.True(t, mrr.NextResult())
647 require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
648 assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name)
649 _, err = mrr.ResultReader().Close()
650 require.NoError(t, err)
651
652 require.False(t, mrr.NextResult())
653
654 require.NoError(t, mrr.Close())
655
656 ensureConnValid(t, pgConn)
657 }
658
659 func TestConnExecMultipleQueriesError(t *testing.T) {
660 t.Parallel()
661
662 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
663 require.NoError(t, err)
664 defer closeConn(t, pgConn)
665
666 results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll()
667 require.NotNil(t, err)
668 if pgErr, ok := err.(*pgconn.PgError); ok {
669 assert.Equal(t, "22012", pgErr.Code)
670 } else {
671 t.Errorf("unexpected error: %v", err)
672 }
673
674 if pgConn.ParameterStatus("crdb_version") != "" {
675
676 require.Len(t, results, 2)
677 assert.Len(t, results[0].Rows, 1)
678 assert.Equal(t, "1", string(results[0].Rows[0][0]))
679 assert.Len(t, results[1].Rows, 0)
680 } else {
681
682 require.Len(t, results, 1)
683 assert.Len(t, results[0].Rows, 1)
684 assert.Equal(t, "1", string(results[0].Rows[0][0]))
685 }
686
687 ensureConnValid(t, pgConn)
688 }
689
690 func TestConnExecDeferredError(t *testing.T) {
691 t.Parallel()
692
693 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
694 require.NoError(t, err)
695 defer closeConn(t, pgConn)
696
697 if pgConn.ParameterStatus("crdb_version") != "" {
698 t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
699 }
700
701 setupSQL := `create temporary table t (
702 id text primary key,
703 n int not null,
704 unique (n) deferrable initially deferred
705 );
706
707 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
708
709 _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
710 assert.NoError(t, err)
711
712 _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
713 require.NotNil(t, err)
714
715 var pgErr *pgconn.PgError
716 require.True(t, errors.As(err, &pgErr))
717 require.Equal(t, "23505", pgErr.Code)
718
719 ensureConnValid(t, pgConn)
720 }
721
722 func TestConnExecContextCanceled(t *testing.T) {
723 t.Parallel()
724
725 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
726 require.NoError(t, err)
727 defer closeConn(t, pgConn)
728
729 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
730 defer cancel()
731 multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")
732
733 for multiResult.NextResult() {
734 }
735 err = multiResult.Close()
736 assert.True(t, pgconn.Timeout(err))
737 assert.ErrorIs(t, err, context.DeadlineExceeded)
738 assert.True(t, pgConn.IsClosed())
739 select {
740 case <-pgConn.CleanupDone():
741 case <-time.After(5 * time.Second):
742 t.Fatal("Connection cleanup exceeded maximum time")
743 }
744 }
745
746 func TestConnExecContextPrecanceled(t *testing.T) {
747 t.Parallel()
748
749 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
750 require.NoError(t, err)
751 defer closeConn(t, pgConn)
752
753 ctx, cancel := context.WithCancel(context.Background())
754 cancel()
755 _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
756 assert.Error(t, err)
757 assert.True(t, errors.Is(err, context.Canceled))
758 assert.True(t, pgconn.SafeToRetry(err))
759
760 ensureConnValid(t, pgConn)
761 }
762
763 func TestConnExecParams(t *testing.T) {
764 t.Parallel()
765
766 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
767 require.NoError(t, err)
768 defer closeConn(t, pgConn)
769
770 result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
771 require.Len(t, result.FieldDescriptions(), 1)
772 assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
773
774 rowCount := 0
775 for result.NextRow() {
776 rowCount += 1
777 assert.Equal(t, "Hello, world", string(result.Values()[0]))
778 }
779 assert.Equal(t, 1, rowCount)
780 commandTag, err := result.Close()
781 assert.Equal(t, "SELECT 1", string(commandTag))
782 assert.NoError(t, err)
783
784 ensureConnValid(t, pgConn)
785 }
786
787 func TestConnExecParamsDeferredError(t *testing.T) {
788 t.Parallel()
789
790 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
791 require.NoError(t, err)
792 defer closeConn(t, pgConn)
793
794 if pgConn.ParameterStatus("crdb_version") != "" {
795 t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
796 }
797
798 setupSQL := `create temporary table t (
799 id text primary key,
800 n int not null,
801 unique (n) deferrable initially deferred
802 );
803
804 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
805
806 _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
807 assert.NoError(t, err)
808
809 result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
810 require.NotNil(t, result.Err)
811 var pgErr *pgconn.PgError
812 require.True(t, errors.As(result.Err, &pgErr))
813 require.Equal(t, "23505", pgErr.Code)
814
815 ensureConnValid(t, pgConn)
816 }
817
818 func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
819 t.Parallel()
820
821 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
822 require.NoError(t, err)
823 defer closeConn(t, pgConn)
824
825 paramCount := math.MaxUint16
826 params := make([]string, 0, paramCount)
827 args := make([][]byte, 0, paramCount)
828 for i := 0; i < paramCount; i++ {
829 params = append(params, fmt.Sprintf("($%d::text)", i+1))
830 args = append(args, []byte(strconv.Itoa(i)))
831 }
832 sql := "values" + strings.Join(params, ", ")
833
834 result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
835 require.NoError(t, result.Err)
836 require.Len(t, result.Rows, paramCount)
837
838 ensureConnValid(t, pgConn)
839 }
840
841 func TestConnExecParamsTooManyParams(t *testing.T) {
842 t.Parallel()
843
844 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
845 require.NoError(t, err)
846 defer closeConn(t, pgConn)
847
848 paramCount := math.MaxUint16 + 1
849 params := make([]string, 0, paramCount)
850 args := make([][]byte, 0, paramCount)
851 for i := 0; i < paramCount; i++ {
852 params = append(params, fmt.Sprintf("($%d::text)", i+1))
853 args = append(args, []byte(strconv.Itoa(i)))
854 }
855 sql := "values" + strings.Join(params, ", ")
856
857 result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
858 require.Error(t, result.Err)
859 require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
860
861 ensureConnValid(t, pgConn)
862 }
863
864 func TestConnExecParamsCanceled(t *testing.T) {
865 t.Parallel()
866
867 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
868 require.NoError(t, err)
869 defer closeConn(t, pgConn)
870
871 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
872 defer cancel()
873 result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
874 rowCount := 0
875 for result.NextRow() {
876 rowCount += 1
877 }
878 assert.Equal(t, 0, rowCount)
879 commandTag, err := result.Close()
880 assert.Equal(t, pgconn.CommandTag(nil), commandTag)
881 assert.True(t, pgconn.Timeout(err))
882 assert.ErrorIs(t, err, context.DeadlineExceeded)
883
884 assert.True(t, pgConn.IsClosed())
885 select {
886 case <-pgConn.CleanupDone():
887 case <-time.After(5 * time.Second):
888 t.Fatal("Connection cleanup exceeded maximum time")
889 }
890 }
891
892 func TestConnExecParamsPrecanceled(t *testing.T) {
893 t.Parallel()
894
895 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
896 require.NoError(t, err)
897 defer closeConn(t, pgConn)
898
899 ctx, cancel := context.WithCancel(context.Background())
900 cancel()
901 result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
902 require.Error(t, result.Err)
903 assert.True(t, errors.Is(result.Err, context.Canceled))
904 assert.True(t, pgconn.SafeToRetry(result.Err))
905
906 ensureConnValid(t, pgConn)
907 }
908
909 func TestConnExecParamsEmptySQL(t *testing.T) {
910 t.Parallel()
911
912 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
913 defer cancel()
914
915 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
916 require.NoError(t, err)
917 defer closeConn(t, pgConn)
918
919 result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
920 assert.Nil(t, result.CommandTag)
921 assert.Len(t, result.Rows, 0)
922 assert.NoError(t, result.Err)
923
924 ensureConnValid(t, pgConn)
925 }
926
927
928 func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
929 t.Parallel()
930
931 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
932 require.NoError(t, err)
933 defer closeConn(t, pgConn)
934
935 result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
936 require.Len(t, result.FieldDescriptions(), 1)
937 assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
938
939 rowCount := 0
940 for result.NextRow() {
941 rowCount += 1
942 assert.Equal(t, "Hello, world", string(result.Values()[0]))
943 assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0]))
944 }
945 assert.Equal(t, 1, rowCount)
946 commandTag, err := result.Close()
947 assert.Equal(t, "SELECT 1", string(commandTag))
948 assert.NoError(t, err)
949
950 ensureConnValid(t, pgConn)
951 }
952
953 func TestConnExecPrepared(t *testing.T) {
954 t.Parallel()
955
956 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
957 require.NoError(t, err)
958 defer closeConn(t, pgConn)
959
960 psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
961 require.NoError(t, err)
962 require.NotNil(t, psd)
963 assert.Len(t, psd.ParamOIDs, 1)
964 assert.Len(t, psd.Fields, 1)
965
966 result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
967 require.Len(t, result.FieldDescriptions(), 1)
968 assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
969
970 rowCount := 0
971 for result.NextRow() {
972 rowCount += 1
973 assert.Equal(t, "Hello, world", string(result.Values()[0]))
974 }
975 assert.Equal(t, 1, rowCount)
976 commandTag, err := result.Close()
977 assert.Equal(t, "SELECT 1", string(commandTag))
978 assert.NoError(t, err)
979
980 ensureConnValid(t, pgConn)
981 }
982
983 func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
984 t.Parallel()
985
986 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
987 require.NoError(t, err)
988 defer closeConn(t, pgConn)
989
990 paramCount := math.MaxUint16
991 params := make([]string, 0, paramCount)
992 args := make([][]byte, 0, paramCount)
993 for i := 0; i < paramCount; i++ {
994 params = append(params, fmt.Sprintf("($%d::text)", i+1))
995 args = append(args, []byte(strconv.Itoa(i)))
996 }
997 sql := "values" + strings.Join(params, ", ")
998
999 psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
1000 require.NoError(t, err)
1001 require.NotNil(t, psd)
1002 assert.Len(t, psd.ParamOIDs, paramCount)
1003 assert.Len(t, psd.Fields, 1)
1004
1005 result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
1006 require.NoError(t, result.Err)
1007 require.Len(t, result.Rows, paramCount)
1008
1009 ensureConnValid(t, pgConn)
1010 }
1011
1012 func TestConnExecPreparedTooManyParams(t *testing.T) {
1013 t.Parallel()
1014
1015 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1016 require.NoError(t, err)
1017 defer closeConn(t, pgConn)
1018
1019 paramCount := math.MaxUint16 + 1
1020 params := make([]string, 0, paramCount)
1021 args := make([][]byte, 0, paramCount)
1022 for i := 0; i < paramCount; i++ {
1023 params = append(params, fmt.Sprintf("($%d::text)", i+1))
1024 args = append(args, []byte(strconv.Itoa(i)))
1025 }
1026 sql := "values" + strings.Join(params, ", ")
1027
1028 psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
1029 if pgConn.ParameterStatus("crdb_version") != "" {
1030
1031 require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
1032 } else {
1033
1034 require.NoError(t, err)
1035 require.NotNil(t, psd)
1036 assert.Len(t, psd.ParamOIDs, paramCount)
1037 assert.Len(t, psd.Fields, 1)
1038
1039 result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
1040 require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
1041 }
1042
1043 ensureConnValid(t, pgConn)
1044 }
1045
1046 func TestConnExecPreparedCanceled(t *testing.T) {
1047 t.Parallel()
1048
1049 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1050 require.NoError(t, err)
1051 defer closeConn(t, pgConn)
1052
1053 _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
1054 require.NoError(t, err)
1055
1056 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
1057 defer cancel()
1058 result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
1059 rowCount := 0
1060 for result.NextRow() {
1061 rowCount += 1
1062 }
1063 assert.Equal(t, 0, rowCount)
1064 commandTag, err := result.Close()
1065 assert.Equal(t, pgconn.CommandTag(nil), commandTag)
1066 assert.True(t, pgconn.Timeout(err))
1067 assert.True(t, pgConn.IsClosed())
1068 select {
1069 case <-pgConn.CleanupDone():
1070 case <-time.After(5 * time.Second):
1071 t.Fatal("Connection cleanup exceeded maximum time")
1072 }
1073 }
1074
1075 func TestConnExecPreparedPrecanceled(t *testing.T) {
1076 t.Parallel()
1077
1078 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1079 require.NoError(t, err)
1080 defer closeConn(t, pgConn)
1081
1082 _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
1083 require.NoError(t, err)
1084
1085 ctx, cancel := context.WithCancel(context.Background())
1086 cancel()
1087 result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
1088 require.Error(t, result.Err)
1089 assert.True(t, errors.Is(result.Err, context.Canceled))
1090 assert.True(t, pgconn.SafeToRetry(result.Err))
1091
1092 ensureConnValid(t, pgConn)
1093 }
1094
1095 func TestConnExecPreparedEmptySQL(t *testing.T) {
1096 t.Parallel()
1097
1098 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
1099 defer cancel()
1100
1101 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
1102 require.NoError(t, err)
1103 defer closeConn(t, pgConn)
1104
1105 _, err = pgConn.Prepare(ctx, "ps1", "", nil)
1106 require.NoError(t, err)
1107
1108 result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
1109 assert.Nil(t, result.CommandTag)
1110 assert.Len(t, result.Rows, 0)
1111 assert.NoError(t, result.Err)
1112
1113 ensureConnValid(t, pgConn)
1114 }
1115
1116 func TestConnExecBatch(t *testing.T) {
1117 t.Parallel()
1118
1119 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1120 require.NoError(t, err)
1121 defer closeConn(t, pgConn)
1122
1123 _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
1124 require.NoError(t, err)
1125
1126 batch := &pgconn.Batch{}
1127
1128 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
1129 batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
1130 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
1131 results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
1132 require.NoError(t, err)
1133 require.Len(t, results, 3)
1134
1135 require.Len(t, results[0].Rows, 1)
1136 require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
1137 assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
1138
1139 require.Len(t, results[1].Rows, 1)
1140 require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
1141 assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
1142
1143 require.Len(t, results[2].Rows, 1)
1144 require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
1145 assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
1146 }
1147
1148 func TestConnExecBatchDeferredError(t *testing.T) {
1149 t.Parallel()
1150
1151 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1152 require.NoError(t, err)
1153 defer closeConn(t, pgConn)
1154
1155 if pgConn.ParameterStatus("crdb_version") != "" {
1156 t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
1157 }
1158
1159 setupSQL := `create temporary table t (
1160 id text primary key,
1161 n int not null,
1162 unique (n) deferrable initially deferred
1163 );
1164
1165 insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
1166
1167 _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
1168 require.NoError(t, err)
1169
1170 batch := &pgconn.Batch{}
1171
1172 batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
1173 _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
1174 require.NotNil(t, err)
1175 var pgErr *pgconn.PgError
1176 require.True(t, errors.As(err, &pgErr))
1177 require.Equal(t, "23505", pgErr.Code)
1178
1179 ensureConnValid(t, pgConn)
1180 }
1181
1182 func TestConnExecBatchPrecanceled(t *testing.T) {
1183 t.Parallel()
1184
1185 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1186 require.NoError(t, err)
1187 defer closeConn(t, pgConn)
1188
1189 _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
1190 require.NoError(t, err)
1191
1192 batch := &pgconn.Batch{}
1193
1194 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
1195 batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
1196 batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
1197
1198 ctx, cancel := context.WithCancel(context.Background())
1199 cancel()
1200 _, err = pgConn.ExecBatch(ctx, batch).ReadAll()
1201 require.Error(t, err)
1202 assert.True(t, errors.Is(err, context.Canceled))
1203 assert.True(t, pgconn.SafeToRetry(err))
1204
1205 ensureConnValid(t, pgConn)
1206 }
1207
1208
1209
1210
1211 func TestConnExecBatchHuge(t *testing.T) {
1212 if testing.Short() {
1213 t.Skip("skipping test in short mode.")
1214 }
1215
1216 t.Parallel()
1217
1218 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1219 require.NoError(t, err)
1220 defer closeConn(t, pgConn)
1221
1222 batch := &pgconn.Batch{}
1223
1224 queryCount := 100000
1225 args := make([]string, queryCount)
1226
1227 for i := range args {
1228 args[i] = strconv.Itoa(i)
1229 batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil)
1230 }
1231
1232 results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
1233 require.NoError(t, err)
1234 require.Len(t, results, queryCount)
1235
1236 for i := range args {
1237 require.Len(t, results[i].Rows, 1)
1238 require.Equal(t, args[i], string(results[i].Rows[0][0]))
1239 assert.Equal(t, "SELECT 1", string(results[i].CommandTag))
1240 }
1241 }
1242
1243 func TestConnExecBatchImplicitTransaction(t *testing.T) {
1244 t.Parallel()
1245
1246 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1247 require.NoError(t, err)
1248 defer closeConn(t, pgConn)
1249
1250 if pgConn.ParameterStatus("crdb_version") != "" {
1251 t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)")
1252 }
1253
1254 _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll()
1255 require.NoError(t, err)
1256
1257 batch := &pgconn.Batch{}
1258
1259 batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil)
1260 batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil)
1261 batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil)
1262 batch.ExecParams("select 1/0", nil, nil, nil, nil)
1263 _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
1264 require.Error(t, err)
1265
1266 result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read()
1267 require.Equal(t, "0", string(result.Rows[0][0]))
1268 }
1269
1270 func TestConnLocking(t *testing.T) {
1271 t.Parallel()
1272
1273 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1274 require.NoError(t, err)
1275 defer closeConn(t, pgConn)
1276
1277 mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
1278 _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
1279 assert.Error(t, err)
1280 assert.Equal(t, "conn busy", err.Error())
1281 assert.True(t, pgconn.SafeToRetry(err))
1282
1283 results, err := mrr.ReadAll()
1284 assert.NoError(t, err)
1285 assert.Len(t, results, 1)
1286 assert.Nil(t, results[0].Err)
1287 assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
1288 assert.Len(t, results[0].Rows, 1)
1289 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
1290
1291 ensureConnValid(t, pgConn)
1292 }
1293
1294 func TestCommandTag(t *testing.T) {
1295 t.Parallel()
1296
1297 var tests = []struct {
1298 commandTag pgconn.CommandTag
1299 rowsAffected int64
1300 isInsert bool
1301 isUpdate bool
1302 isDelete bool
1303 isSelect bool
1304 }{
1305 {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true},
1306 {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true},
1307 {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true},
1308 {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true},
1309 {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true},
1310 {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true},
1311 {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true},
1312 {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true},
1313 {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0},
1314 {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0},
1315 {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0},
1316 }
1317
1318 for i, tt := range tests {
1319 ct := tt.commandTag
1320 assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
1321 assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
1322 assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
1323 assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
1324 assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
1325 }
1326 }
1327
1328 func TestConnOnNotice(t *testing.T) {
1329 t.Parallel()
1330
1331 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
1332 require.NoError(t, err)
1333
1334 var msg string
1335 config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) {
1336 msg = notice.Message
1337 }
1338 config.RuntimeParams["client_min_messages"] = "notice"
1339
1340 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
1341 require.NoError(t, err)
1342 defer closeConn(t, pgConn)
1343
1344 if pgConn.ParameterStatus("crdb_version") != "" {
1345 t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)")
1346 }
1347
1348 multiResult := pgConn.Exec(context.Background(), `do $$
1349 begin
1350 raise notice 'hello, world';
1351 end$$;`)
1352 err = multiResult.Close()
1353 require.NoError(t, err)
1354 assert.Equal(t, "hello, world", msg)
1355
1356 ensureConnValid(t, pgConn)
1357 }
1358
1359 func TestConnOnNotification(t *testing.T) {
1360 t.Parallel()
1361
1362 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
1363 require.NoError(t, err)
1364
1365 var msg string
1366 config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
1367 msg = n.Payload
1368 }
1369
1370 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
1371 require.NoError(t, err)
1372 defer closeConn(t, pgConn)
1373
1374 if pgConn.ParameterStatus("crdb_version") != "" {
1375 t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
1376 }
1377
1378 _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll()
1379 require.NoError(t, err)
1380
1381 notifier, err := pgconn.ConnectConfig(context.Background(), config)
1382 require.NoError(t, err)
1383 defer closeConn(t, notifier)
1384 _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll()
1385 require.NoError(t, err)
1386
1387 _, err = pgConn.Exec(context.Background(), "select 1").ReadAll()
1388 require.NoError(t, err)
1389
1390 assert.Equal(t, "bar", msg)
1391
1392 ensureConnValid(t, pgConn)
1393 }
1394
1395 func TestConnWaitForNotification(t *testing.T) {
1396 t.Parallel()
1397
1398 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
1399 require.NoError(t, err)
1400
1401 var msg string
1402 config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
1403 msg = n.Payload
1404 }
1405
1406 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
1407 require.NoError(t, err)
1408 defer closeConn(t, pgConn)
1409
1410 if pgConn.ParameterStatus("crdb_version") != "" {
1411 t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
1412 }
1413
1414 _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll()
1415 require.NoError(t, err)
1416
1417 notifier, err := pgconn.ConnectConfig(context.Background(), config)
1418 require.NoError(t, err)
1419 defer closeConn(t, notifier)
1420 _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll()
1421 require.NoError(t, err)
1422
1423 err = pgConn.WaitForNotification(context.Background())
1424 require.NoError(t, err)
1425
1426 assert.Equal(t, "bar", msg)
1427
1428 ensureConnValid(t, pgConn)
1429 }
1430
1431 func TestConnWaitForNotificationPrecanceled(t *testing.T) {
1432 t.Parallel()
1433
1434 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
1435 require.NoError(t, err)
1436
1437 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
1438 require.NoError(t, err)
1439 defer closeConn(t, pgConn)
1440
1441 ctx, cancel := context.WithCancel(context.Background())
1442 cancel()
1443 err = pgConn.WaitForNotification(ctx)
1444 require.ErrorIs(t, err, context.Canceled)
1445
1446 ensureConnValid(t, pgConn)
1447 }
1448
1449 func TestConnWaitForNotificationTimeout(t *testing.T) {
1450 t.Parallel()
1451
1452 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
1453 require.NoError(t, err)
1454
1455 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
1456 require.NoError(t, err)
1457 defer closeConn(t, pgConn)
1458
1459 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
1460 err = pgConn.WaitForNotification(ctx)
1461 cancel()
1462 assert.True(t, pgconn.Timeout(err))
1463 assert.ErrorIs(t, err, context.DeadlineExceeded)
1464
1465 ensureConnValid(t, pgConn)
1466 }
1467
1468 func TestConnCopyToSmall(t *testing.T) {
1469 t.Parallel()
1470
1471 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1472 require.NoError(t, err)
1473 defer closeConn(t, pgConn)
1474
1475 if pgConn.ParameterStatus("crdb_version") != "" {
1476 t.Skip("Server does support COPY TO")
1477 }
1478
1479 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1480 a int2,
1481 b int4,
1482 c int8,
1483 d varchar,
1484 e text,
1485 f date,
1486 g json
1487 )`).ReadAll()
1488 require.NoError(t, err)
1489
1490 _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
1491 require.NoError(t, err)
1492
1493 _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
1494 require.NoError(t, err)
1495
1496 inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
1497 "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
1498
1499 outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
1500
1501 res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
1502 require.NoError(t, err)
1503
1504 assert.Equal(t, int64(2), res.RowsAffected())
1505 assert.Equal(t, inputBytes, outputWriter.Bytes())
1506
1507 ensureConnValid(t, pgConn)
1508 }
1509
1510 func TestConnCopyToLarge(t *testing.T) {
1511 t.Parallel()
1512
1513 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1514 require.NoError(t, err)
1515 defer closeConn(t, pgConn)
1516
1517 if pgConn.ParameterStatus("crdb_version") != "" {
1518 t.Skip("Server does support COPY TO")
1519 }
1520
1521 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1522 a int2,
1523 b int4,
1524 c int8,
1525 d varchar,
1526 e text,
1527 f date,
1528 g json,
1529 h bytea
1530 )`).ReadAll()
1531 require.NoError(t, err)
1532
1533 inputBytes := make([]byte, 0)
1534
1535 for i := 0; i < 1000; i++ {
1536 _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
1537 require.NoError(t, err)
1538 inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
1539 }
1540
1541 outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
1542
1543 res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
1544 require.NoError(t, err)
1545
1546 assert.Equal(t, int64(1000), res.RowsAffected())
1547 assert.Equal(t, inputBytes, outputWriter.Bytes())
1548
1549 ensureConnValid(t, pgConn)
1550 }
1551
1552 func TestConnCopyToQueryError(t *testing.T) {
1553 t.Parallel()
1554
1555 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1556 require.NoError(t, err)
1557 defer closeConn(t, pgConn)
1558
1559 outputWriter := bytes.NewBuffer(make([]byte, 0))
1560
1561 res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout")
1562 require.Error(t, err)
1563 assert.IsType(t, &pgconn.PgError{}, err)
1564 assert.Equal(t, int64(0), res.RowsAffected())
1565
1566 ensureConnValid(t, pgConn)
1567 }
1568
1569 func TestConnCopyToCanceled(t *testing.T) {
1570 t.Parallel()
1571
1572 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1573 require.NoError(t, err)
1574 defer closeConn(t, pgConn)
1575
1576 if pgConn.ParameterStatus("crdb_version") != "" {
1577 t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
1578 }
1579
1580 outputWriter := &bytes.Buffer{}
1581
1582 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
1583 defer cancel()
1584 res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
1585 assert.Error(t, err)
1586 assert.Equal(t, pgconn.CommandTag(nil), res)
1587
1588 assert.True(t, pgConn.IsClosed())
1589 select {
1590 case <-pgConn.CleanupDone():
1591 case <-time.After(5 * time.Second):
1592 t.Fatal("Connection cleanup exceeded maximum time")
1593 }
1594 }
1595
1596 func TestConnCopyToPrecanceled(t *testing.T) {
1597 t.Parallel()
1598
1599 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1600 require.NoError(t, err)
1601 defer closeConn(t, pgConn)
1602
1603 outputWriter := &bytes.Buffer{}
1604
1605 ctx, cancel := context.WithCancel(context.Background())
1606 cancel()
1607 res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
1608 require.Error(t, err)
1609 assert.True(t, errors.Is(err, context.Canceled))
1610 assert.True(t, pgconn.SafeToRetry(err))
1611 assert.Equal(t, pgconn.CommandTag(nil), res)
1612
1613 ensureConnValid(t, pgConn)
1614 }
1615
1616 func TestConnCopyFrom(t *testing.T) {
1617 t.Parallel()
1618
1619 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1620 require.NoError(t, err)
1621 defer closeConn(t, pgConn)
1622
1623 if pgConn.ParameterStatus("crdb_version") != "" {
1624 t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
1625 }
1626
1627 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1628 a int4,
1629 b varchar
1630 )`).ReadAll()
1631 require.NoError(t, err)
1632
1633 srcBuf := &bytes.Buffer{}
1634
1635 inputRows := [][][]byte{}
1636 for i := 0; i < 1000; i++ {
1637 a := strconv.Itoa(i)
1638 b := "foo " + a + " bar"
1639 inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
1640 _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
1641 require.NoError(t, err)
1642 }
1643
1644 ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)")
1645 require.NoError(t, err)
1646 assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
1647
1648 result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
1649 require.NoError(t, result.Err)
1650
1651 assert.Equal(t, inputRows, result.Rows)
1652
1653 ensureConnValid(t, pgConn)
1654 }
1655
1656 func TestConnCopyFromCanceled(t *testing.T) {
1657 t.Parallel()
1658
1659 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1660 require.NoError(t, err)
1661 defer closeConn(t, pgConn)
1662
1663 if pgConn.ParameterStatus("crdb_version") != "" {
1664 t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
1665 }
1666
1667 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1668 a int4,
1669 b varchar
1670 )`).ReadAll()
1671 require.NoError(t, err)
1672
1673 r, w := io.Pipe()
1674 go func() {
1675 for i := 0; i < 1000000; i++ {
1676 a := strconv.Itoa(i)
1677 b := "foo " + a + " bar"
1678 _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
1679 if err != nil {
1680 return
1681 }
1682 time.Sleep(time.Microsecond)
1683 }
1684 }()
1685
1686 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
1687 ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
1688 cancel()
1689 assert.Equal(t, int64(0), ct.RowsAffected())
1690 assert.Error(t, err)
1691
1692 assert.True(t, pgConn.IsClosed())
1693 select {
1694 case <-pgConn.CleanupDone():
1695 case <-time.After(5 * time.Second):
1696 t.Fatal("Connection cleanup exceeded maximum time")
1697 }
1698 }
1699
1700 func TestConnCopyFromPrecanceled(t *testing.T) {
1701 t.Parallel()
1702
1703 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1704 require.NoError(t, err)
1705 defer closeConn(t, pgConn)
1706
1707 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1708 a int4,
1709 b varchar
1710 )`).ReadAll()
1711 require.NoError(t, err)
1712
1713 r, w := io.Pipe()
1714 go func() {
1715 for i := 0; i < 1000000; i++ {
1716 a := strconv.Itoa(i)
1717 b := "foo " + a + " bar"
1718 _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
1719 if err != nil {
1720 return
1721 }
1722 time.Sleep(time.Microsecond)
1723 }
1724 }()
1725
1726 ctx, cancel := context.WithCancel(context.Background())
1727 cancel()
1728 ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
1729 require.Error(t, err)
1730 assert.True(t, errors.Is(err, context.Canceled))
1731 assert.True(t, pgconn.SafeToRetry(err))
1732 assert.Equal(t, pgconn.CommandTag(nil), ct)
1733
1734 ensureConnValid(t, pgConn)
1735 }
1736
1737 func TestConnCopyFromGzipReader(t *testing.T) {
1738 t.Parallel()
1739
1740 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1741 require.NoError(t, err)
1742 defer closeConn(t, pgConn)
1743
1744 if pgConn.ParameterStatus("crdb_version") != "" {
1745 t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
1746 }
1747
1748 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1749 a int4,
1750 b varchar
1751 )`).ReadAll()
1752 require.NoError(t, err)
1753
1754 f, err := ioutil.TempFile("", "*")
1755 require.NoError(t, err)
1756
1757 gw := gzip.NewWriter(f)
1758
1759 inputRows := [][][]byte{}
1760 for i := 0; i < 1000; i++ {
1761 a := strconv.Itoa(i)
1762 b := "foo " + a + " bar"
1763 inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
1764 _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
1765 require.NoError(t, err)
1766 }
1767
1768 err = gw.Close()
1769 require.NoError(t, err)
1770
1771 _, err = f.Seek(0, 0)
1772 require.NoError(t, err)
1773
1774 gr, err := gzip.NewReader(f)
1775 require.NoError(t, err)
1776
1777 ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
1778 require.NoError(t, err)
1779 assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
1780
1781 err = gr.Close()
1782 require.NoError(t, err)
1783
1784 err = f.Close()
1785 require.NoError(t, err)
1786
1787 err = os.Remove(f.Name())
1788 require.NoError(t, err)
1789
1790 result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
1791 require.NoError(t, result.Err)
1792
1793 assert.Equal(t, inputRows, result.Rows)
1794
1795 ensureConnValid(t, pgConn)
1796 }
1797
1798 func TestConnCopyFromQuerySyntaxError(t *testing.T) {
1799 t.Parallel()
1800
1801 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1802 require.NoError(t, err)
1803 defer closeConn(t, pgConn)
1804
1805 _, err = pgConn.Exec(context.Background(), `create temporary table foo(
1806 a int4,
1807 b varchar
1808 )`).ReadAll()
1809 require.NoError(t, err)
1810
1811 srcBuf := &bytes.Buffer{}
1812
1813 res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout")
1814 require.Error(t, err)
1815 assert.IsType(t, &pgconn.PgError{}, err)
1816 assert.Equal(t, int64(0), res.RowsAffected())
1817
1818 ensureConnValid(t, pgConn)
1819 }
1820
1821 func TestConnCopyFromQueryNoTableError(t *testing.T) {
1822 t.Parallel()
1823
1824 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1825 require.NoError(t, err)
1826 defer closeConn(t, pgConn)
1827
1828 srcBuf := &bytes.Buffer{}
1829
1830 res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout")
1831 require.Error(t, err)
1832 assert.IsType(t, &pgconn.PgError{}, err)
1833 assert.Equal(t, int64(0), res.RowsAffected())
1834
1835 ensureConnValid(t, pgConn)
1836 }
1837
1838
1839 func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
1840 t.Parallel()
1841
1842 ctx := context.Background()
1843 pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
1844 require.NoError(t, err)
1845 defer closeConn(t, pgConn)
1846
1847 if pgConn.ParameterStatus("crdb_version") != "" {
1848 t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)")
1849 }
1850
1851 _, err = pgConn.Exec(ctx, `create temporary table sentences(
1852 t text,
1853 ts tsvector
1854 )`).ReadAll()
1855 require.NoError(t, err)
1856
1857 _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
1858 begin
1859 new.ts := to_tsvector(new.t);
1860 return new;
1861 end
1862 $$ language plpgsql;`).ReadAll()
1863 require.NoError(t, err)
1864
1865 _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
1866 require.NoError(t, err)
1867
1868 longString := make([]byte, 10001)
1869 for i := range longString {
1870 longString[i] = 'x'
1871 }
1872
1873 buf := &bytes.Buffer{}
1874 for i := 0; i < 1000; i++ {
1875 buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
1876 }
1877
1878 _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
1879 require.NoError(t, err)
1880 }
1881
1882 func TestConnEscapeString(t *testing.T) {
1883 t.Parallel()
1884
1885 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1886 require.NoError(t, err)
1887 defer closeConn(t, pgConn)
1888
1889 tests := []struct {
1890 in string
1891 out string
1892 }{
1893 {in: "", out: ""},
1894 {in: "42", out: "42"},
1895 {in: "'", out: "''"},
1896 {in: "hi'there", out: "hi''there"},
1897 {in: "'hi there'", out: "''hi there''"},
1898 }
1899
1900 for i, tt := range tests {
1901 value, err := pgConn.EscapeString(tt.in)
1902 if assert.NoErrorf(t, err, "%d.", i) {
1903 assert.Equalf(t, tt.out, value, "%d.", i)
1904 }
1905 }
1906
1907 ensureConnValid(t, pgConn)
1908 }
1909
1910 func TestConnCancelRequest(t *testing.T) {
1911 t.Parallel()
1912
1913 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1914 require.NoError(t, err)
1915 defer closeConn(t, pgConn)
1916
1917 if pgConn.ParameterStatus("crdb_version") != "" {
1918 t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
1919 }
1920
1921 multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)")
1922
1923
1924
1925
1926 time.Sleep(50 * time.Millisecond)
1927
1928 err = pgConn.CancelRequest(context.Background())
1929 require.NoError(t, err)
1930
1931 for multiResult.NextResult() {
1932 }
1933 err = multiResult.Close()
1934
1935 require.IsType(t, &pgconn.PgError{}, err)
1936 require.Equal(t, "57014", err.(*pgconn.PgError).Code)
1937
1938 ensureConnValid(t, pgConn)
1939 }
1940
1941
1942 func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
1943 t.Parallel()
1944
1945 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1946 require.NoError(t, err)
1947 defer closeConn(t, pgConn)
1948
1949 pid := pgConn.PID()
1950
1951 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
1952 defer cancel()
1953 multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)")
1954
1955 for multiResult.NextResult() {
1956 }
1957 err = multiResult.Close()
1958 assert.True(t, pgconn.Timeout(err))
1959 assert.True(t, pgConn.IsClosed())
1960 select {
1961 case <-pgConn.CleanupDone():
1962 case <-time.After(5 * time.Second):
1963 t.Fatal("Connection cleanup exceeded maximum time")
1964 }
1965
1966 otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
1967 require.NoError(t, err)
1968 defer closeConn(t, otherConn)
1969
1970 ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
1971 defer cancel()
1972
1973 for {
1974 result := otherConn.ExecParams(ctx,
1975 `select 1 from pg_stat_activity where pid=$1`,
1976 [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))},
1977 nil,
1978 nil,
1979 nil,
1980 ).Read()
1981 require.NoError(t, result.Err)
1982
1983 if len(result.Rows) == 0 {
1984 break
1985 }
1986 }
1987 }
1988
1989 func TestConnSendBytesAndReceiveMessage(t *testing.T) {
1990 t.Parallel()
1991
1992 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
1993 defer cancel()
1994
1995 config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
1996 require.NoError(t, err)
1997 config.RuntimeParams["client_min_messages"] = "notice"
1998
1999 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
2000 require.NoError(t, err)
2001 defer closeConn(t, pgConn)
2002
2003 queryMsg := pgproto3.Query{String: "select 42"}
2004 buf, err := queryMsg.Encode(nil)
2005 require.NoError(t, err)
2006
2007 err = pgConn.SendBytes(ctx, buf)
2008 require.NoError(t, err)
2009
2010 msg, err := pgConn.ReceiveMessage(ctx)
2011 require.NoError(t, err)
2012 _, ok := msg.(*pgproto3.RowDescription)
2013 require.True(t, ok)
2014
2015 msg, err = pgConn.ReceiveMessage(ctx)
2016 require.NoError(t, err)
2017 _, ok = msg.(*pgproto3.DataRow)
2018 require.True(t, ok)
2019
2020 msg, err = pgConn.ReceiveMessage(ctx)
2021 require.NoError(t, err)
2022 _, ok = msg.(*pgproto3.CommandComplete)
2023 require.True(t, ok)
2024
2025 msg, err = pgConn.ReceiveMessage(ctx)
2026 require.NoError(t, err)
2027 _, ok = msg.(*pgproto3.ReadyForQuery)
2028 require.True(t, ok)
2029
2030 ensureConnValid(t, pgConn)
2031 }
2032
2033 func TestHijackAndConstruct(t *testing.T) {
2034 t.Parallel()
2035
2036 origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
2037 require.NoError(t, err)
2038
2039 hc, err := origConn.Hijack()
2040 require.NoError(t, err)
2041
2042 _, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
2043 require.Error(t, err)
2044
2045 newConn, err := pgconn.Construct(hc)
2046 require.NoError(t, err)
2047
2048 defer closeConn(t, newConn)
2049
2050 results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
2051 assert.NoError(t, err)
2052
2053 assert.Len(t, results, 1)
2054 assert.Nil(t, results[0].Err)
2055 assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
2056 assert.Len(t, results[0].Rows, 1)
2057 assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
2058
2059 ensureConnValid(t, newConn)
2060 }
2061
2062 func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) {
2063 t.Parallel()
2064
2065 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
2066 require.NoError(t, err)
2067
2068 ctx, _ := context.WithCancel(context.Background())
2069 pgConn.Exec(ctx, "select n from generate_series(1,10) n")
2070
2071 closeCtx, _ := context.WithCancel(context.Background())
2072 pgConn.Close(closeCtx)
2073 select {
2074 case <-pgConn.CleanupDone():
2075 case <-time.After(5 * time.Second):
2076 t.Fatal("Connection cleanup exceeded maximum time")
2077 }
2078 }
2079
2080
2081 func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
2082 t.Parallel()
2083
2084 steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
2085 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
2086 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{}))
2087 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
2088 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{}))
2089 steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{}))
2090 steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
2091 {Name: []byte("mock")},
2092 }}))
2093 steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}))
2094 steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
2095
2096 script := &pgmock.Script{Steps: steps}
2097
2098 ln, err := net.Listen("tcp", "127.0.0.1:")
2099 require.NoError(t, err)
2100 defer ln.Close()
2101
2102 serverErrChan := make(chan error, 1)
2103 go func() {
2104 defer close(serverErrChan)
2105
2106 conn, err := ln.Accept()
2107 if err != nil {
2108 serverErrChan <- err
2109 return
2110 }
2111 defer conn.Close()
2112
2113 err = conn.SetDeadline(time.Now().Add(5 * time.Second))
2114 if err != nil {
2115 serverErrChan <- err
2116 return
2117 }
2118
2119 err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
2120 if err != nil {
2121 serverErrChan <- err
2122 return
2123 }
2124 }()
2125
2126 parts := strings.Split(ln.Addr().String(), ":")
2127 host := parts[0]
2128 port := parts[1]
2129 connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
2130
2131 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
2132 defer cancel()
2133 conn, err := pgconn.Connect(ctx, connStr)
2134 require.NoError(t, err)
2135
2136 rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil)
2137
2138 for rr.NextRow() {
2139 }
2140
2141 _, err = rr.Close()
2142 require.Error(t, err)
2143 }
2144
2145 func Example() {
2146 pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
2147 if err != nil {
2148 log.Fatalln(err)
2149 }
2150 defer pgConn.Close(context.Background())
2151
2152 result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read()
2153 if result.Err != nil {
2154 log.Fatalln(result.Err)
2155 }
2156
2157 for _, row := range result.Rows {
2158 fmt.Println(string(row[0]))
2159 }
2160
2161 fmt.Println(result.CommandTag)
2162
2163
2164
2165
2166
2167 }
2168
// GetSSLPassword returns the value of the PGX_SSL_PASSWORD environment
// variable, or "" when it is unset. The context is not used; it is accepted so
// the function can be assigned to ParseConfigOptions.GetSSLPassword.
func GetSSLPassword(ctx context.Context) string {
	const passwordEnvVar = "PGX_SSL_PASSWORD"
	password := os.Getenv(passwordEnvVar)
	return password
}
2173
// rsaCertPEM is a self-signed X.509 certificate (issuer and subject are both
// CN=localhost). TestSNISupport uses it, together with rsaKeyPEM, as the
// server certificate of its mock TLS listener.
//
// NOTE(review): the validity period encoded in this certificate appears to
// run from 2022-08-15 to 2023-08-15. It keeps working only because the client
// connects with sslmode=require, which presumably does not verify the server
// certificate. Confirm this before reusing it with a verifying sslmode, or
// regenerate it.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
MlascjupnaptKX/wMA==
-----END CERTIFICATE-----
`
2194
// rsaKeyPEM is the RSA private key that goes with rsaCertPEM, stored as a
// PKCS#8 PEM block. The PEM label is written as "TESTING KEY" in the source.
// testingKey rewrites it to "PRIVATE KEY" when this package-level variable is
// initialized, presumably to keep a literal private-key header out of the
// repository (e.g. for secret scanners). TODO confirm intent.
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
z3w+CgS20UrbLIR1YXfqUXge1g==
-----END TESTING KEY-----
`)
2224
// testingKey turns the "TESTING KEY" PEM label used in source literals into
// the real "PRIVATE KEY" label. Every occurrence is replaced; all other text
// is returned unchanged.
func testingKey(s string) string {
	const placeholderLabel, pemLabel = "TESTING KEY", "PRIVATE KEY"
	return strings.Replace(s, placeholderLabel, pemLabel, -1)
}
2226
2227 func TestSNISupport(t *testing.T) {
2228 t.Parallel()
2229 tests := []struct {
2230 name string
2231 sni_param string
2232 sni_set bool
2233 }{
2234 {
2235 name: "SNI is passed by default",
2236 sni_param: "",
2237 sni_set: true,
2238 },
2239 {
2240 name: "SNI is passed when asked for",
2241 sni_param: "sslsni=1",
2242 sni_set: true,
2243 },
2244 {
2245 name: "SNI is not passed when disabled",
2246 sni_param: "sslsni=0",
2247 sni_set: false,
2248 },
2249 }
2250 for _, tt := range tests {
2251 tt := tt
2252 t.Run(tt.name, func(t *testing.T) {
2253 t.Parallel()
2254
2255 ln, err := net.Listen("tcp", "127.0.0.1:")
2256 require.NoError(t, err)
2257 defer ln.Close()
2258
2259 serverErrChan := make(chan error, 1)
2260 serverSNINameChan := make(chan string, 1)
2261 defer close(serverErrChan)
2262 defer close(serverSNINameChan)
2263
2264 go func() {
2265 var sniHost string
2266
2267 conn, err := ln.Accept()
2268 if err != nil {
2269 serverErrChan <- err
2270 return
2271 }
2272 defer conn.Close()
2273
2274 err = conn.SetDeadline(time.Now().Add(5 * time.Second))
2275 if err != nil {
2276 serverErrChan <- err
2277 return
2278 }
2279
2280 backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
2281 startupMessage, err := backend.ReceiveStartupMessage()
2282 if err != nil {
2283 serverErrChan <- err
2284 return
2285 }
2286
2287 switch startupMessage.(type) {
2288 case *pgproto3.SSLRequest:
2289 _, err = conn.Write([]byte("S"))
2290 if err != nil {
2291 serverErrChan <- err
2292 return
2293 }
2294 default:
2295 serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
2296 return
2297 }
2298
2299 cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
2300 if err != nil {
2301 serverErrChan <- err
2302 return
2303 }
2304
2305 srv := tls.Server(conn, &tls.Config{
2306 Certificates: []tls.Certificate{cert},
2307 GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
2308 sniHost = argHello.ServerName
2309 return nil, nil
2310 },
2311 })
2312 defer srv.Close()
2313
2314 if err := srv.Handshake(); err != nil {
2315 serverErrChan <- fmt.Errorf("handshake: %v", err)
2316 return
2317 }
2318
2319 srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
2320 srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
2321 srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
2322
2323 serverSNINameChan <- sniHost
2324 }()
2325
2326 port := strings.Split(ln.Addr().String(), ":")[1]
2327 connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
2328 _, err = pgconn.Connect(context.Background(), connStr)
2329
2330 select {
2331 case sniHost := <-serverSNINameChan:
2332 if tt.sni_set {
2333 require.Equal(t, sniHost, "localhost")
2334 } else {
2335 require.Equal(t, sniHost, "")
2336 }
2337 case err = <-serverErrChan:
2338 t.Fatalf("server failed with error: %+v", err)
2339 case <-time.After(time.Millisecond * 100):
2340 t.Fatal("exceeded connection timeout without erroring out")
2341 }
2342 })
2343 }
2344 }
2345
// delayedReader wraps another io.Reader and sleeps for one millisecond before
// every Read. Presumably this simulates a slow input source; TestCopyFrom
// feeds it to CopyFrom.
type delayedReader struct {
	r io.Reader // the reader that actually supplies the bytes
}

// Read waits briefly, then forwards the call to the wrapped reader and returns
// its results unchanged.
func (dr delayedReader) Read(p []byte) (int, error) {
	const pause = time.Millisecond
	time.Sleep(pause)
	n, err := dr.r.Read(p)
	return n, err
}
2355
2356 func TestCopyFrom(t *testing.T) {
2357 connString := os.Getenv("PGX_TEST_CONN_STRING")
2358 if connString == "" {
2359 t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_CONN_STRING")
2360 }
2361
2362 config, err := pgconn.ParseConfig(connString)
2363 require.NoError(t, err)
2364
2365 pgConn, err := pgconn.ConnectConfig(context.Background(), config)
2366 require.NoError(t, err)
2367
2368 if pgConn.ParameterStatus("crdb_version") != "" {
2369 t.Skip("Server does support COPY FROM")
2370 }
2371
2372 setupSQL := `create temporary table t (
2373 id text primary key,
2374 n int not null
2375 );`
2376
2377 _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
2378 assert.NoError(t, err)
2379
2380 r1 := delayedReader{r: strings.NewReader(`id 0\n`)}
2381
2382 _, err = pgConn.CopyFrom(context.Background(), r1, "COPY nosuchtable FROM STDIN ")
2383 assert.Error(t, err)
2384
2385 r2 := delayedReader{r: strings.NewReader(`id 0\n`)}
2386 _, err = pgConn.CopyFrom(context.Background(), r2, "COPY t FROM STDIN")
2387 assert.NoError(t, err)
2388 }
2389
// mustEncode unwraps the (buf, err) pair returned by a pgproto3 Encode call.
// It returns buf when err is nil and panics with err otherwise, so encoding
// failures in test helpers surface immediately.
func mustEncode(buf []byte, err error) []byte {
	if err == nil {
		return buf
	}
	panic(err)
}
2396
View as plain text