...

Source file src/github.com/jackc/pgconn/pgconn_test.go

Documentation: github.com/jackc/pgconn

     1  package pgconn_test
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"log"
    13  	"math"
    14  	"net"
    15  	"os"
    16  	"strconv"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/jackc/pgconn"
    22  	"github.com/jackc/pgmock"
    23  	"github.com/jackc/pgproto3/v2"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  func TestConnect(t *testing.T) {
    29  	tests := []struct {
    30  		name string
    31  		env  string
    32  	}{
    33  		{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
    34  		{"TCP", "PGX_TEST_TCP_CONN_STRING"},
    35  		{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
    36  		{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
    37  		{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
    38  	}
    39  
    40  	for _, tt := range tests {
    41  		tt := tt
    42  		t.Run(tt.name, func(t *testing.T) {
    43  			connString := os.Getenv(tt.env)
    44  			if connString == "" {
    45  				t.Skipf("Skipping due to missing environment variable %v", tt.env)
    46  			}
    47  
    48  			conn, err := pgconn.Connect(context.Background(), connString)
    49  			require.NoError(t, err)
    50  
    51  			closeConn(t, conn)
    52  		})
    53  	}
    54  }
    55  
    56  func TestConnectWithOptions(t *testing.T) {
    57  	tests := []struct {
    58  		name string
    59  		env  string
    60  	}{
    61  		{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
    62  		{"TCP", "PGX_TEST_TCP_CONN_STRING"},
    63  		{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
    64  		{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
    65  		{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
    66  	}
    67  
    68  	for _, tt := range tests {
    69  		tt := tt
    70  		t.Run(tt.name, func(t *testing.T) {
    71  			connString := os.Getenv(tt.env)
    72  			if connString == "" {
    73  				t.Skipf("Skipping due to missing environment variable %v", tt.env)
    74  			}
    75  			var sslOptions pgconn.ParseConfigOptions
    76  			sslOptions.GetSSLPassword = GetSSLPassword
    77  			conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
    78  			require.NoError(t, err)
    79  
    80  			closeConn(t, conn)
    81  		})
    82  	}
    83  }
    84  
    85  // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
    86  // connection.
    87  func TestConnectTLS(t *testing.T) {
    88  	t.Parallel()
    89  
    90  	connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
    91  	if connString == "" {
    92  		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
    93  	}
    94  
    95  	var conn *pgconn.PgConn
    96  	var err error
    97  
    98  	var sslOptions pgconn.ParseConfigOptions
    99  	sslOptions.GetSSLPassword = GetSSLPassword
   100  	config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
   101  	require.Nil(t, err)
   102  
   103  	conn, err = pgconn.ConnectConfig(context.Background(), config)
   104  	require.NoError(t, err)
   105  
   106  	if _, ok := conn.Conn().(*tls.Conn); !ok {
   107  		t.Error("not a TLS connection")
   108  	}
   109  
   110  	closeConn(t, conn)
   111  }
   112  
   113  type pgmockWaitStep time.Duration
   114  
   115  func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
   116  	time.Sleep(time.Duration(s))
   117  	return nil
   118  }
   119  
   120  func TestConnectTimeout(t *testing.T) {
   121  	t.Parallel()
   122  	tests := []struct {
   123  		name    string
   124  		connect func(connStr string) error
   125  	}{
   126  		{
   127  			name: "via context that times out",
   128  			connect: func(connStr string) error {
   129  				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
   130  				defer cancel()
   131  				_, err := pgconn.Connect(ctx, connStr)
   132  				return err
   133  			},
   134  		},
   135  		{
   136  			name: "via config ConnectTimeout",
   137  			connect: func(connStr string) error {
   138  				conf, err := pgconn.ParseConfig(connStr)
   139  				require.NoError(t, err)
   140  				conf.ConnectTimeout = time.Microsecond * 50
   141  				_, err = pgconn.ConnectConfig(context.Background(), conf)
   142  				return err
   143  			},
   144  		},
   145  	}
   146  	for _, tt := range tests {
   147  		tt := tt
   148  		t.Run(tt.name, func(t *testing.T) {
   149  			t.Parallel()
   150  			script := &pgmock.Script{
   151  				Steps: []pgmock.Step{
   152  					pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
   153  					pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
   154  					pgmockWaitStep(time.Millisecond * 500),
   155  					pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
   156  					pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
   157  				},
   158  			}
   159  
   160  			ln, err := net.Listen("tcp", "127.0.0.1:")
   161  			require.NoError(t, err)
   162  			defer ln.Close()
   163  
   164  			serverErrChan := make(chan error, 1)
   165  			go func() {
   166  				defer close(serverErrChan)
   167  
   168  				conn, err := ln.Accept()
   169  				if err != nil {
   170  					serverErrChan <- err
   171  					return
   172  				}
   173  				defer conn.Close()
   174  
   175  				err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
   176  				if err != nil {
   177  					serverErrChan <- err
   178  					return
   179  				}
   180  
   181  				err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
   182  				if err != nil {
   183  					serverErrChan <- err
   184  					return
   185  				}
   186  			}()
   187  
   188  			parts := strings.Split(ln.Addr().String(), ":")
   189  			host := parts[0]
   190  			port := parts[1]
   191  			connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
   192  			tooLate := time.Now().Add(time.Millisecond * 500)
   193  
   194  			err = tt.connect(connStr)
   195  			require.True(t, pgconn.Timeout(err), err)
   196  			require.True(t, time.Now().Before(tooLate))
   197  		})
   198  	}
   199  }
   200  
   201  func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
   202  	t.Parallel()
   203  	tests := []struct {
   204  		name    string
   205  		connect func(connStr string) error
   206  	}{
   207  		{
   208  			name: "via context that times out",
   209  			connect: func(connStr string) error {
   210  				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
   211  				defer cancel()
   212  				_, err := pgconn.Connect(ctx, connStr)
   213  				return err
   214  			},
   215  		},
   216  		{
   217  			name: "via config ConnectTimeout",
   218  			connect: func(connStr string) error {
   219  				conf, err := pgconn.ParseConfig(connStr)
   220  				require.NoError(t, err)
   221  				conf.ConnectTimeout = time.Millisecond * 10
   222  				_, err = pgconn.ConnectConfig(context.Background(), conf)
   223  				return err
   224  			},
   225  		},
   226  	}
   227  	for _, tt := range tests {
   228  		tt := tt
   229  		t.Run(tt.name, func(t *testing.T) {
   230  			t.Parallel()
   231  			ln, err := net.Listen("tcp", "127.0.0.1:")
   232  			require.NoError(t, err)
   233  			defer ln.Close()
   234  
   235  			serverErrChan := make(chan error)
   236  			defer close(serverErrChan)
   237  			go func() {
   238  				conn, err := ln.Accept()
   239  				if err != nil {
   240  					serverErrChan <- err
   241  					return
   242  				}
   243  				defer conn.Close()
   244  
   245  				var buf []byte
   246  				_, err = conn.Read(buf)
   247  				if err != nil {
   248  					serverErrChan <- err
   249  					return
   250  				}
   251  
   252  				// Sleeping to hang the TLS handshake.
   253  				time.Sleep(time.Minute)
   254  			}()
   255  
   256  			parts := strings.Split(ln.Addr().String(), ":")
   257  			host := parts[0]
   258  			port := parts[1]
   259  			connStr := fmt.Sprintf("host=%s port=%s", host, port)
   260  
   261  			errChan := make(chan error)
   262  			go func() {
   263  				err := tt.connect(connStr)
   264  				errChan <- err
   265  			}()
   266  
   267  			select {
   268  			case err = <-errChan:
   269  				require.True(t, pgconn.Timeout(err), err)
   270  			case err = <-serverErrChan:
   271  				t.Fatalf("server failed with error: %s", err)
   272  			case <-time.After(time.Millisecond * 100):
   273  				t.Fatal("exceeded connection timeout without erroring out")
   274  			}
   275  		})
   276  	}
   277  }
   278  
   279  func TestConnectInvalidUser(t *testing.T) {
   280  	t.Parallel()
   281  
   282  	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
   283  	if connString == "" {
   284  		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
   285  	}
   286  
   287  	config, err := pgconn.ParseConfig(connString)
   288  	require.NoError(t, err)
   289  
   290  	config.User = "pgxinvalidusertest"
   291  
   292  	_, err = pgconn.ConnectConfig(context.Background(), config)
   293  	require.Error(t, err)
   294  	pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
   295  	if !ok {
   296  		t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
   297  	}
   298  	if pgErr.Code != "28000" && pgErr.Code != "28P01" {
   299  		t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
   300  	}
   301  }
   302  
   303  func TestConnectWithConnectionRefused(t *testing.T) {
   304  	t.Parallel()
   305  
   306  	// Presumably nothing is listening on 127.0.0.1:1
   307  	conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1")
   308  	if err == nil {
   309  		conn.Close(context.Background())
   310  		t.Fatal("Expected error establishing connection to bad port")
   311  	}
   312  }
   313  
   314  func TestConnectCustomDialer(t *testing.T) {
   315  	t.Parallel()
   316  
   317  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
   318  	require.NoError(t, err)
   319  
   320  	dialed := false
   321  	config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
   322  		dialed = true
   323  		return net.Dial(network, address)
   324  	}
   325  
   326  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   327  	require.NoError(t, err)
   328  	require.True(t, dialed)
   329  	closeConn(t, conn)
   330  }
   331  
   332  func TestConnectCustomLookup(t *testing.T) {
   333  	t.Parallel()
   334  
   335  	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
   336  	if connString == "" {
   337  		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
   338  	}
   339  
   340  	config, err := pgconn.ParseConfig(connString)
   341  	require.NoError(t, err)
   342  
   343  	looked := false
   344  	config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
   345  		looked = true
   346  		return net.LookupHost(host)
   347  	}
   348  
   349  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   350  	require.NoError(t, err)
   351  	require.True(t, looked)
   352  	closeConn(t, conn)
   353  }
   354  
   355  func TestConnectCustomLookupWithPort(t *testing.T) {
   356  	t.Parallel()
   357  
   358  	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
   359  	if connString == "" {
   360  		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
   361  	}
   362  
   363  	config, err := pgconn.ParseConfig(connString)
   364  	require.NoError(t, err)
   365  
   366  	origPort := config.Port
   367  	// Chnage the config an invalid port so it will fail if used
   368  	config.Port = 0
   369  
   370  	looked := false
   371  	config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
   372  		looked = true
   373  		addrs, err := net.LookupHost(host)
   374  		if err != nil {
   375  			return nil, err
   376  		}
   377  		for i := range addrs {
   378  			addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
   379  		}
   380  		return addrs, nil
   381  	}
   382  
   383  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   384  	require.NoError(t, err)
   385  	require.True(t, looked)
   386  	closeConn(t, conn)
   387  }
   388  
   389  func TestConnectWithRuntimeParams(t *testing.T) {
   390  	t.Parallel()
   391  
   392  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
   393  	require.NoError(t, err)
   394  
   395  	config.RuntimeParams = map[string]string{
   396  		"application_name": "pgxtest",
   397  		"search_path":      "myschema",
   398  	}
   399  
   400  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   401  	require.NoError(t, err)
   402  	defer closeConn(t, conn)
   403  
   404  	result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read()
   405  	require.Nil(t, result.Err)
   406  	assert.Equal(t, 1, len(result.Rows))
   407  	assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
   408  
   409  	result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read()
   410  	require.Nil(t, result.Err)
   411  	assert.Equal(t, 1, len(result.Rows))
   412  	assert.Equal(t, "myschema", string(result.Rows[0][0]))
   413  }
   414  
   415  func TestConnectWithFallback(t *testing.T) {
   416  	t.Parallel()
   417  
   418  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
   419  	require.NoError(t, err)
   420  
   421  	// Prepend current primary config to fallbacks
   422  	config.Fallbacks = append([]*pgconn.FallbackConfig{
   423  		&pgconn.FallbackConfig{
   424  			Host:      config.Host,
   425  			Port:      config.Port,
   426  			TLSConfig: config.TLSConfig,
   427  		},
   428  	}, config.Fallbacks...)
   429  
   430  	// Make primary config bad
   431  	config.Host = "localhost"
   432  	config.Port = 1 // presumably nothing listening here
   433  
   434  	// Prepend bad first fallback
   435  	config.Fallbacks = append([]*pgconn.FallbackConfig{
   436  		&pgconn.FallbackConfig{
   437  			Host:      "localhost",
   438  			Port:      1,
   439  			TLSConfig: config.TLSConfig,
   440  		},
   441  	}, config.Fallbacks...)
   442  
   443  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   444  	require.NoError(t, err)
   445  	closeConn(t, conn)
   446  }
   447  
   448  func TestConnectWithValidateConnect(t *testing.T) {
   449  	t.Parallel()
   450  
   451  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
   452  	require.NoError(t, err)
   453  
   454  	dialCount := 0
   455  	config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
   456  		dialCount++
   457  		return net.Dial(network, address)
   458  	}
   459  
   460  	acceptConnCount := 0
   461  	config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
   462  		acceptConnCount++
   463  		if acceptConnCount < 2 {
   464  			return errors.New("reject first conn")
   465  		}
   466  		return nil
   467  	}
   468  
   469  	// Append current primary config to fallbacks
   470  	config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
   471  		Host:      config.Host,
   472  		Port:      config.Port,
   473  		TLSConfig: config.TLSConfig,
   474  	})
   475  
   476  	// Repeat fallbacks
   477  	config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
   478  
   479  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   480  	require.NoError(t, err)
   481  	closeConn(t, conn)
   482  
   483  	assert.True(t, dialCount > 1)
   484  	assert.True(t, acceptConnCount > 1)
   485  }
   486  
   487  func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
   488  	t.Parallel()
   489  
   490  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
   491  	require.NoError(t, err)
   492  
   493  	config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
   494  	config.RuntimeParams["default_transaction_read_only"] = "on"
   495  
   496  	ctx, cancel := context.WithCancel(context.Background())
   497  	defer cancel()
   498  
   499  	conn, err := pgconn.ConnectConfig(ctx, config)
   500  	if !assert.NotNil(t, err) {
   501  		conn.Close(ctx)
   502  	}
   503  }
   504  
   505  func TestConnectWithAfterConnect(t *testing.T) {
   506  	t.Parallel()
   507  
   508  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
   509  	require.NoError(t, err)
   510  
   511  	config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
   512  		_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
   513  		return err
   514  	}
   515  
   516  	conn, err := pgconn.ConnectConfig(context.Background(), config)
   517  	require.NoError(t, err)
   518  
   519  	results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
   520  	require.NoError(t, err)
   521  	defer closeConn(t, conn)
   522  
   523  	assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
   524  }
   525  
   526  func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
   527  	t.Parallel()
   528  
   529  	config := &pgconn.Config{}
   530  
   531  	require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) })
   532  }
   533  
   534  func TestConnPrepareSyntaxError(t *testing.T) {
   535  	t.Parallel()
   536  
   537  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   538  	require.NoError(t, err)
   539  	defer closeConn(t, pgConn)
   540  
   541  	psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil)
   542  	require.Nil(t, psd)
   543  	require.NotNil(t, err)
   544  
   545  	ensureConnValid(t, pgConn)
   546  }
   547  
   548  func TestConnPrepareContextPrecanceled(t *testing.T) {
   549  	t.Parallel()
   550  
   551  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   552  	require.NoError(t, err)
   553  	defer closeConn(t, pgConn)
   554  
   555  	ctx, cancel := context.WithCancel(context.Background())
   556  	cancel()
   557  	psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
   558  	assert.Nil(t, psd)
   559  	assert.Error(t, err)
   560  	assert.True(t, errors.Is(err, context.Canceled))
   561  	assert.True(t, pgconn.SafeToRetry(err))
   562  
   563  	ensureConnValid(t, pgConn)
   564  }
   565  
   566  func TestConnExec(t *testing.T) {
   567  	t.Parallel()
   568  
   569  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   570  	require.NoError(t, err)
   571  	defer closeConn(t, pgConn)
   572  
   573  	results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
   574  	assert.NoError(t, err)
   575  
   576  	assert.Len(t, results, 1)
   577  	assert.Nil(t, results[0].Err)
   578  	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
   579  	assert.Len(t, results[0].Rows, 1)
   580  	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
   581  
   582  	ensureConnValid(t, pgConn)
   583  }
   584  
   585  func TestConnExecEmpty(t *testing.T) {
   586  	t.Parallel()
   587  
   588  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   589  	require.NoError(t, err)
   590  	defer closeConn(t, pgConn)
   591  
   592  	multiResult := pgConn.Exec(context.Background(), ";")
   593  
   594  	resultCount := 0
   595  	for multiResult.NextResult() {
   596  		resultCount++
   597  		multiResult.ResultReader().Close()
   598  	}
   599  	assert.Equal(t, 0, resultCount)
   600  	err = multiResult.Close()
   601  	assert.NoError(t, err)
   602  
   603  	ensureConnValid(t, pgConn)
   604  }
   605  
   606  func TestConnExecMultipleQueries(t *testing.T) {
   607  	t.Parallel()
   608  
   609  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   610  	require.NoError(t, err)
   611  	defer closeConn(t, pgConn)
   612  
   613  	results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll()
   614  	assert.NoError(t, err)
   615  
   616  	assert.Len(t, results, 2)
   617  
   618  	assert.Nil(t, results[0].Err)
   619  	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
   620  	assert.Len(t, results[0].Rows, 1)
   621  	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
   622  
   623  	assert.Nil(t, results[1].Err)
   624  	assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
   625  	assert.Len(t, results[1].Rows, 1)
   626  	assert.Equal(t, "1", string(results[1].Rows[0][0]))
   627  
   628  	ensureConnValid(t, pgConn)
   629  }
   630  
   631  func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
   632  	t.Parallel()
   633  
   634  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   635  	require.NoError(t, err)
   636  	defer closeConn(t, pgConn)
   637  
   638  	mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")
   639  
   640  	require.True(t, mrr.NextResult())
   641  	require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
   642  	assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name)
   643  	_, err = mrr.ResultReader().Close()
   644  	require.NoError(t, err)
   645  
   646  	require.True(t, mrr.NextResult())
   647  	require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
   648  	assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name)
   649  	_, err = mrr.ResultReader().Close()
   650  	require.NoError(t, err)
   651  
   652  	require.False(t, mrr.NextResult())
   653  
   654  	require.NoError(t, mrr.Close())
   655  
   656  	ensureConnValid(t, pgConn)
   657  }
   658  
   659  func TestConnExecMultipleQueriesError(t *testing.T) {
   660  	t.Parallel()
   661  
   662  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   663  	require.NoError(t, err)
   664  	defer closeConn(t, pgConn)
   665  
   666  	results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll()
   667  	require.NotNil(t, err)
   668  	if pgErr, ok := err.(*pgconn.PgError); ok {
   669  		assert.Equal(t, "22012", pgErr.Code)
   670  	} else {
   671  		t.Errorf("unexpected error: %v", err)
   672  	}
   673  
   674  	if pgConn.ParameterStatus("crdb_version") != "" {
   675  		// CockroachDB starts the second query result set and then sends the divide by zero error.
   676  		require.Len(t, results, 2)
   677  		assert.Len(t, results[0].Rows, 1)
   678  		assert.Equal(t, "1", string(results[0].Rows[0][0]))
   679  		assert.Len(t, results[1].Rows, 0)
   680  	} else {
   681  		// PostgreSQL sends the divide by zero and never sends the second query result set.
   682  		require.Len(t, results, 1)
   683  		assert.Len(t, results[0].Rows, 1)
   684  		assert.Equal(t, "1", string(results[0].Rows[0][0]))
   685  	}
   686  
   687  	ensureConnValid(t, pgConn)
   688  }
   689  
   690  func TestConnExecDeferredError(t *testing.T) {
   691  	t.Parallel()
   692  
   693  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   694  	require.NoError(t, err)
   695  	defer closeConn(t, pgConn)
   696  
   697  	if pgConn.ParameterStatus("crdb_version") != "" {
   698  		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   699  	}
   700  
   701  	setupSQL := `create temporary table t (
   702  		id text primary key,
   703  		n int not null,
   704  		unique (n) deferrable initially deferred
   705  	);
   706  
   707  	insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
   708  
   709  	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
   710  	assert.NoError(t, err)
   711  
   712  	_, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
   713  	require.NotNil(t, err)
   714  
   715  	var pgErr *pgconn.PgError
   716  	require.True(t, errors.As(err, &pgErr))
   717  	require.Equal(t, "23505", pgErr.Code)
   718  
   719  	ensureConnValid(t, pgConn)
   720  }
   721  
   722  func TestConnExecContextCanceled(t *testing.T) {
   723  	t.Parallel()
   724  
   725  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   726  	require.NoError(t, err)
   727  	defer closeConn(t, pgConn)
   728  
   729  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   730  	defer cancel()
   731  	multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")
   732  
   733  	for multiResult.NextResult() {
   734  	}
   735  	err = multiResult.Close()
   736  	assert.True(t, pgconn.Timeout(err))
   737  	assert.ErrorIs(t, err, context.DeadlineExceeded)
   738  	assert.True(t, pgConn.IsClosed())
   739  	select {
   740  	case <-pgConn.CleanupDone():
   741  	case <-time.After(5 * time.Second):
   742  		t.Fatal("Connection cleanup exceeded maximum time")
   743  	}
   744  }
   745  
   746  func TestConnExecContextPrecanceled(t *testing.T) {
   747  	t.Parallel()
   748  
   749  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   750  	require.NoError(t, err)
   751  	defer closeConn(t, pgConn)
   752  
   753  	ctx, cancel := context.WithCancel(context.Background())
   754  	cancel()
   755  	_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
   756  	assert.Error(t, err)
   757  	assert.True(t, errors.Is(err, context.Canceled))
   758  	assert.True(t, pgconn.SafeToRetry(err))
   759  
   760  	ensureConnValid(t, pgConn)
   761  }
   762  
   763  func TestConnExecParams(t *testing.T) {
   764  	t.Parallel()
   765  
   766  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   767  	require.NoError(t, err)
   768  	defer closeConn(t, pgConn)
   769  
   770  	result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
   771  	require.Len(t, result.FieldDescriptions(), 1)
   772  	assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
   773  
   774  	rowCount := 0
   775  	for result.NextRow() {
   776  		rowCount += 1
   777  		assert.Equal(t, "Hello, world", string(result.Values()[0]))
   778  	}
   779  	assert.Equal(t, 1, rowCount)
   780  	commandTag, err := result.Close()
   781  	assert.Equal(t, "SELECT 1", string(commandTag))
   782  	assert.NoError(t, err)
   783  
   784  	ensureConnValid(t, pgConn)
   785  }
   786  
   787  func TestConnExecParamsDeferredError(t *testing.T) {
   788  	t.Parallel()
   789  
   790  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   791  	require.NoError(t, err)
   792  	defer closeConn(t, pgConn)
   793  
   794  	if pgConn.ParameterStatus("crdb_version") != "" {
   795  		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
   796  	}
   797  
   798  	setupSQL := `create temporary table t (
   799  		id text primary key,
   800  		n int not null,
   801  		unique (n) deferrable initially deferred
   802  	);
   803  
   804  	insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
   805  
   806  	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
   807  	assert.NoError(t, err)
   808  
   809  	result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
   810  	require.NotNil(t, result.Err)
   811  	var pgErr *pgconn.PgError
   812  	require.True(t, errors.As(result.Err, &pgErr))
   813  	require.Equal(t, "23505", pgErr.Code)
   814  
   815  	ensureConnValid(t, pgConn)
   816  }
   817  
   818  func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
   819  	t.Parallel()
   820  
   821  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   822  	require.NoError(t, err)
   823  	defer closeConn(t, pgConn)
   824  
   825  	paramCount := math.MaxUint16
   826  	params := make([]string, 0, paramCount)
   827  	args := make([][]byte, 0, paramCount)
   828  	for i := 0; i < paramCount; i++ {
   829  		params = append(params, fmt.Sprintf("($%d::text)", i+1))
   830  		args = append(args, []byte(strconv.Itoa(i)))
   831  	}
   832  	sql := "values" + strings.Join(params, ", ")
   833  
   834  	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
   835  	require.NoError(t, result.Err)
   836  	require.Len(t, result.Rows, paramCount)
   837  
   838  	ensureConnValid(t, pgConn)
   839  }
   840  
   841  func TestConnExecParamsTooManyParams(t *testing.T) {
   842  	t.Parallel()
   843  
   844  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   845  	require.NoError(t, err)
   846  	defer closeConn(t, pgConn)
   847  
   848  	paramCount := math.MaxUint16 + 1
   849  	params := make([]string, 0, paramCount)
   850  	args := make([][]byte, 0, paramCount)
   851  	for i := 0; i < paramCount; i++ {
   852  		params = append(params, fmt.Sprintf("($%d::text)", i+1))
   853  		args = append(args, []byte(strconv.Itoa(i)))
   854  	}
   855  	sql := "values" + strings.Join(params, ", ")
   856  
   857  	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
   858  	require.Error(t, result.Err)
   859  	require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
   860  
   861  	ensureConnValid(t, pgConn)
   862  }
   863  
   864  func TestConnExecParamsCanceled(t *testing.T) {
   865  	t.Parallel()
   866  
   867  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   868  	require.NoError(t, err)
   869  	defer closeConn(t, pgConn)
   870  
   871  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   872  	defer cancel()
   873  	result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
   874  	rowCount := 0
   875  	for result.NextRow() {
   876  		rowCount += 1
   877  	}
   878  	assert.Equal(t, 0, rowCount)
   879  	commandTag, err := result.Close()
   880  	assert.Equal(t, pgconn.CommandTag(nil), commandTag)
   881  	assert.True(t, pgconn.Timeout(err))
   882  	assert.ErrorIs(t, err, context.DeadlineExceeded)
   883  
   884  	assert.True(t, pgConn.IsClosed())
   885  	select {
   886  	case <-pgConn.CleanupDone():
   887  	case <-time.After(5 * time.Second):
   888  		t.Fatal("Connection cleanup exceeded maximum time")
   889  	}
   890  }
   891  
   892  func TestConnExecParamsPrecanceled(t *testing.T) {
   893  	t.Parallel()
   894  
   895  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   896  	require.NoError(t, err)
   897  	defer closeConn(t, pgConn)
   898  
   899  	ctx, cancel := context.WithCancel(context.Background())
   900  	cancel()
   901  	result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
   902  	require.Error(t, result.Err)
   903  	assert.True(t, errors.Is(result.Err, context.Canceled))
   904  	assert.True(t, pgconn.SafeToRetry(result.Err))
   905  
   906  	ensureConnValid(t, pgConn)
   907  }
   908  
   909  func TestConnExecParamsEmptySQL(t *testing.T) {
   910  	t.Parallel()
   911  
   912  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   913  	defer cancel()
   914  
   915  	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
   916  	require.NoError(t, err)
   917  	defer closeConn(t, pgConn)
   918  
   919  	result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
   920  	assert.Nil(t, result.CommandTag)
   921  	assert.Len(t, result.Rows, 0)
   922  	assert.NoError(t, result.Err)
   923  
   924  	ensureConnValid(t, pgConn)
   925  }
   926  
   927  // https://github.com/jackc/pgx/issues/859
   928  func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
   929  	t.Parallel()
   930  
   931  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   932  	require.NoError(t, err)
   933  	defer closeConn(t, pgConn)
   934  
   935  	result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
   936  	require.Len(t, result.FieldDescriptions(), 1)
   937  	assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
   938  
   939  	rowCount := 0
   940  	for result.NextRow() {
   941  		rowCount += 1
   942  		assert.Equal(t, "Hello, world", string(result.Values()[0]))
   943  		assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0]))
   944  	}
   945  	assert.Equal(t, 1, rowCount)
   946  	commandTag, err := result.Close()
   947  	assert.Equal(t, "SELECT 1", string(commandTag))
   948  	assert.NoError(t, err)
   949  
   950  	ensureConnValid(t, pgConn)
   951  }
   952  
   953  func TestConnExecPrepared(t *testing.T) {
   954  	t.Parallel()
   955  
   956  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   957  	require.NoError(t, err)
   958  	defer closeConn(t, pgConn)
   959  
   960  	psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
   961  	require.NoError(t, err)
   962  	require.NotNil(t, psd)
   963  	assert.Len(t, psd.ParamOIDs, 1)
   964  	assert.Len(t, psd.Fields, 1)
   965  
   966  	result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
   967  	require.Len(t, result.FieldDescriptions(), 1)
   968  	assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)
   969  
   970  	rowCount := 0
   971  	for result.NextRow() {
   972  		rowCount += 1
   973  		assert.Equal(t, "Hello, world", string(result.Values()[0]))
   974  	}
   975  	assert.Equal(t, 1, rowCount)
   976  	commandTag, err := result.Close()
   977  	assert.Equal(t, "SELECT 1", string(commandTag))
   978  	assert.NoError(t, err)
   979  
   980  	ensureConnValid(t, pgConn)
   981  }
   982  
   983  func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
   984  	t.Parallel()
   985  
   986  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
   987  	require.NoError(t, err)
   988  	defer closeConn(t, pgConn)
   989  
   990  	paramCount := math.MaxUint16
   991  	params := make([]string, 0, paramCount)
   992  	args := make([][]byte, 0, paramCount)
   993  	for i := 0; i < paramCount; i++ {
   994  		params = append(params, fmt.Sprintf("($%d::text)", i+1))
   995  		args = append(args, []byte(strconv.Itoa(i)))
   996  	}
   997  	sql := "values" + strings.Join(params, ", ")
   998  
   999  	psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
  1000  	require.NoError(t, err)
  1001  	require.NotNil(t, psd)
  1002  	assert.Len(t, psd.ParamOIDs, paramCount)
  1003  	assert.Len(t, psd.Fields, 1)
  1004  
  1005  	result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
  1006  	require.NoError(t, result.Err)
  1007  	require.Len(t, result.Rows, paramCount)
  1008  
  1009  	ensureConnValid(t, pgConn)
  1010  }
  1011  
  1012  func TestConnExecPreparedTooManyParams(t *testing.T) {
  1013  	t.Parallel()
  1014  
  1015  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1016  	require.NoError(t, err)
  1017  	defer closeConn(t, pgConn)
  1018  
  1019  	paramCount := math.MaxUint16 + 1
  1020  	params := make([]string, 0, paramCount)
  1021  	args := make([][]byte, 0, paramCount)
  1022  	for i := 0; i < paramCount; i++ {
  1023  		params = append(params, fmt.Sprintf("($%d::text)", i+1))
  1024  		args = append(args, []byte(strconv.Itoa(i)))
  1025  	}
  1026  	sql := "values" + strings.Join(params, ", ")
  1027  
  1028  	psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
  1029  	if pgConn.ParameterStatus("crdb_version") != "" {
  1030  		// CockroachDB rejects preparing a statement with more than 65535 parameters.
  1031  		require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
  1032  	} else {
  1033  		// PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol.
  1034  		require.NoError(t, err)
  1035  		require.NotNil(t, psd)
  1036  		assert.Len(t, psd.ParamOIDs, paramCount)
  1037  		assert.Len(t, psd.Fields, 1)
  1038  
  1039  		result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
  1040  		require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
  1041  	}
  1042  
  1043  	ensureConnValid(t, pgConn)
  1044  }
  1045  
  1046  func TestConnExecPreparedCanceled(t *testing.T) {
  1047  	t.Parallel()
  1048  
  1049  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1050  	require.NoError(t, err)
  1051  	defer closeConn(t, pgConn)
  1052  
  1053  	_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
  1054  	require.NoError(t, err)
  1055  
  1056  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
  1057  	defer cancel()
  1058  	result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
  1059  	rowCount := 0
  1060  	for result.NextRow() {
  1061  		rowCount += 1
  1062  	}
  1063  	assert.Equal(t, 0, rowCount)
  1064  	commandTag, err := result.Close()
  1065  	assert.Equal(t, pgconn.CommandTag(nil), commandTag)
  1066  	assert.True(t, pgconn.Timeout(err))
  1067  	assert.True(t, pgConn.IsClosed())
  1068  	select {
  1069  	case <-pgConn.CleanupDone():
  1070  	case <-time.After(5 * time.Second):
  1071  		t.Fatal("Connection cleanup exceeded maximum time")
  1072  	}
  1073  }
  1074  
  1075  func TestConnExecPreparedPrecanceled(t *testing.T) {
  1076  	t.Parallel()
  1077  
  1078  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1079  	require.NoError(t, err)
  1080  	defer closeConn(t, pgConn)
  1081  
  1082  	_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
  1083  	require.NoError(t, err)
  1084  
  1085  	ctx, cancel := context.WithCancel(context.Background())
  1086  	cancel()
  1087  	result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
  1088  	require.Error(t, result.Err)
  1089  	assert.True(t, errors.Is(result.Err, context.Canceled))
  1090  	assert.True(t, pgconn.SafeToRetry(result.Err))
  1091  
  1092  	ensureConnValid(t, pgConn)
  1093  }
  1094  
  1095  func TestConnExecPreparedEmptySQL(t *testing.T) {
  1096  	t.Parallel()
  1097  
  1098  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
  1099  	defer cancel()
  1100  
  1101  	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
  1102  	require.NoError(t, err)
  1103  	defer closeConn(t, pgConn)
  1104  
  1105  	_, err = pgConn.Prepare(ctx, "ps1", "", nil)
  1106  	require.NoError(t, err)
  1107  
  1108  	result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
  1109  	assert.Nil(t, result.CommandTag)
  1110  	assert.Len(t, result.Rows, 0)
  1111  	assert.NoError(t, result.Err)
  1112  
  1113  	ensureConnValid(t, pgConn)
  1114  }
  1115  
  1116  func TestConnExecBatch(t *testing.T) {
  1117  	t.Parallel()
  1118  
  1119  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1120  	require.NoError(t, err)
  1121  	defer closeConn(t, pgConn)
  1122  
  1123  	_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
  1124  	require.NoError(t, err)
  1125  
  1126  	batch := &pgconn.Batch{}
  1127  
  1128  	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
  1129  	batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
  1130  	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
  1131  	results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
  1132  	require.NoError(t, err)
  1133  	require.Len(t, results, 3)
  1134  
  1135  	require.Len(t, results[0].Rows, 1)
  1136  	require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
  1137  	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
  1138  
  1139  	require.Len(t, results[1].Rows, 1)
  1140  	require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
  1141  	assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
  1142  
  1143  	require.Len(t, results[2].Rows, 1)
  1144  	require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
  1145  	assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
  1146  }
  1147  
  1148  func TestConnExecBatchDeferredError(t *testing.T) {
  1149  	t.Parallel()
  1150  
  1151  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1152  	require.NoError(t, err)
  1153  	defer closeConn(t, pgConn)
  1154  
  1155  	if pgConn.ParameterStatus("crdb_version") != "" {
  1156  		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
  1157  	}
  1158  
  1159  	setupSQL := `create temporary table t (
  1160  		id text primary key,
  1161  		n int not null,
  1162  		unique (n) deferrable initially deferred
  1163  	);
  1164  
  1165  	insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
  1166  
  1167  	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
  1168  	require.NoError(t, err)
  1169  
  1170  	batch := &pgconn.Batch{}
  1171  
  1172  	batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
  1173  	_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
  1174  	require.NotNil(t, err)
  1175  	var pgErr *pgconn.PgError
  1176  	require.True(t, errors.As(err, &pgErr))
  1177  	require.Equal(t, "23505", pgErr.Code)
  1178  
  1179  	ensureConnValid(t, pgConn)
  1180  }
  1181  
  1182  func TestConnExecBatchPrecanceled(t *testing.T) {
  1183  	t.Parallel()
  1184  
  1185  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1186  	require.NoError(t, err)
  1187  	defer closeConn(t, pgConn)
  1188  
  1189  	_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
  1190  	require.NoError(t, err)
  1191  
  1192  	batch := &pgconn.Batch{}
  1193  
  1194  	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
  1195  	batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
  1196  	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
  1197  
  1198  	ctx, cancel := context.WithCancel(context.Background())
  1199  	cancel()
  1200  	_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
  1201  	require.Error(t, err)
  1202  	assert.True(t, errors.Is(err, context.Canceled))
  1203  	assert.True(t, pgconn.SafeToRetry(err))
  1204  
  1205  	ensureConnValid(t, pgConn)
  1206  }
  1207  
  1208  // Without concurrent reading and writing large batches can deadlock.
  1209  //
  1210  // See https://github.com/jackc/pgx/issues/374.
  1211  func TestConnExecBatchHuge(t *testing.T) {
  1212  	if testing.Short() {
  1213  		t.Skip("skipping test in short mode.")
  1214  	}
  1215  
  1216  	t.Parallel()
  1217  
  1218  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1219  	require.NoError(t, err)
  1220  	defer closeConn(t, pgConn)
  1221  
  1222  	batch := &pgconn.Batch{}
  1223  
  1224  	queryCount := 100000
  1225  	args := make([]string, queryCount)
  1226  
  1227  	for i := range args {
  1228  		args[i] = strconv.Itoa(i)
  1229  		batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil)
  1230  	}
  1231  
  1232  	results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
  1233  	require.NoError(t, err)
  1234  	require.Len(t, results, queryCount)
  1235  
  1236  	for i := range args {
  1237  		require.Len(t, results[i].Rows, 1)
  1238  		require.Equal(t, args[i], string(results[i].Rows[0][0]))
  1239  		assert.Equal(t, "SELECT 1", string(results[i].CommandTag))
  1240  	}
  1241  }
  1242  
  1243  func TestConnExecBatchImplicitTransaction(t *testing.T) {
  1244  	t.Parallel()
  1245  
  1246  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1247  	require.NoError(t, err)
  1248  	defer closeConn(t, pgConn)
  1249  
  1250  	if pgConn.ParameterStatus("crdb_version") != "" {
  1251  		t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)")
  1252  	}
  1253  
  1254  	_, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll()
  1255  	require.NoError(t, err)
  1256  
  1257  	batch := &pgconn.Batch{}
  1258  
  1259  	batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil)
  1260  	batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil)
  1261  	batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil)
  1262  	batch.ExecParams("select 1/0", nil, nil, nil, nil)
  1263  	_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
  1264  	require.Error(t, err)
  1265  
  1266  	result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read()
  1267  	require.Equal(t, "0", string(result.Rows[0][0]))
  1268  }
  1269  
  1270  func TestConnLocking(t *testing.T) {
  1271  	t.Parallel()
  1272  
  1273  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1274  	require.NoError(t, err)
  1275  	defer closeConn(t, pgConn)
  1276  
  1277  	mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
  1278  	_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
  1279  	assert.Error(t, err)
  1280  	assert.Equal(t, "conn busy", err.Error())
  1281  	assert.True(t, pgconn.SafeToRetry(err))
  1282  
  1283  	results, err := mrr.ReadAll()
  1284  	assert.NoError(t, err)
  1285  	assert.Len(t, results, 1)
  1286  	assert.Nil(t, results[0].Err)
  1287  	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
  1288  	assert.Len(t, results[0].Rows, 1)
  1289  	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
  1290  
  1291  	ensureConnValid(t, pgConn)
  1292  }
  1293  
  1294  func TestCommandTag(t *testing.T) {
  1295  	t.Parallel()
  1296  
  1297  	var tests = []struct {
  1298  		commandTag   pgconn.CommandTag
  1299  		rowsAffected int64
  1300  		isInsert     bool
  1301  		isUpdate     bool
  1302  		isDelete     bool
  1303  		isSelect     bool
  1304  	}{
  1305  		{commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true},
  1306  		{commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true},
  1307  		{commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true},
  1308  		{commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true},
  1309  		{commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true},
  1310  		{commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true},
  1311  		{commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true},
  1312  		{commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true},
  1313  		{commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0},
  1314  		{commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0},
  1315  		{commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0},
  1316  	}
  1317  
  1318  	for i, tt := range tests {
  1319  		ct := tt.commandTag
  1320  		assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag)
  1321  		assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag)
  1322  		assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag)
  1323  		assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag)
  1324  		assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
  1325  	}
  1326  }
  1327  
  1328  func TestConnOnNotice(t *testing.T) {
  1329  	t.Parallel()
  1330  
  1331  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
  1332  	require.NoError(t, err)
  1333  
  1334  	var msg string
  1335  	config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) {
  1336  		msg = notice.Message
  1337  	}
  1338  	config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect.
  1339  
  1340  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  1341  	require.NoError(t, err)
  1342  	defer closeConn(t, pgConn)
  1343  
  1344  	if pgConn.ParameterStatus("crdb_version") != "" {
  1345  		t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)")
  1346  	}
  1347  
  1348  	multiResult := pgConn.Exec(context.Background(), `do $$
  1349  begin
  1350    raise notice 'hello, world';
  1351  end$$;`)
  1352  	err = multiResult.Close()
  1353  	require.NoError(t, err)
  1354  	assert.Equal(t, "hello, world", msg)
  1355  
  1356  	ensureConnValid(t, pgConn)
  1357  }
  1358  
  1359  func TestConnOnNotification(t *testing.T) {
  1360  	t.Parallel()
  1361  
  1362  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
  1363  	require.NoError(t, err)
  1364  
  1365  	var msg string
  1366  	config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
  1367  		msg = n.Payload
  1368  	}
  1369  
  1370  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  1371  	require.NoError(t, err)
  1372  	defer closeConn(t, pgConn)
  1373  
  1374  	if pgConn.ParameterStatus("crdb_version") != "" {
  1375  		t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
  1376  	}
  1377  
  1378  	_, err = pgConn.Exec(context.Background(), "listen foo").ReadAll()
  1379  	require.NoError(t, err)
  1380  
  1381  	notifier, err := pgconn.ConnectConfig(context.Background(), config)
  1382  	require.NoError(t, err)
  1383  	defer closeConn(t, notifier)
  1384  	_, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll()
  1385  	require.NoError(t, err)
  1386  
  1387  	_, err = pgConn.Exec(context.Background(), "select 1").ReadAll()
  1388  	require.NoError(t, err)
  1389  
  1390  	assert.Equal(t, "bar", msg)
  1391  
  1392  	ensureConnValid(t, pgConn)
  1393  }
  1394  
  1395  func TestConnWaitForNotification(t *testing.T) {
  1396  	t.Parallel()
  1397  
  1398  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
  1399  	require.NoError(t, err)
  1400  
  1401  	var msg string
  1402  	config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
  1403  		msg = n.Payload
  1404  	}
  1405  
  1406  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  1407  	require.NoError(t, err)
  1408  	defer closeConn(t, pgConn)
  1409  
  1410  	if pgConn.ParameterStatus("crdb_version") != "" {
  1411  		t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
  1412  	}
  1413  
  1414  	_, err = pgConn.Exec(context.Background(), "listen foo").ReadAll()
  1415  	require.NoError(t, err)
  1416  
  1417  	notifier, err := pgconn.ConnectConfig(context.Background(), config)
  1418  	require.NoError(t, err)
  1419  	defer closeConn(t, notifier)
  1420  	_, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll()
  1421  	require.NoError(t, err)
  1422  
  1423  	err = pgConn.WaitForNotification(context.Background())
  1424  	require.NoError(t, err)
  1425  
  1426  	assert.Equal(t, "bar", msg)
  1427  
  1428  	ensureConnValid(t, pgConn)
  1429  }
  1430  
  1431  func TestConnWaitForNotificationPrecanceled(t *testing.T) {
  1432  	t.Parallel()
  1433  
  1434  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
  1435  	require.NoError(t, err)
  1436  
  1437  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  1438  	require.NoError(t, err)
  1439  	defer closeConn(t, pgConn)
  1440  
  1441  	ctx, cancel := context.WithCancel(context.Background())
  1442  	cancel()
  1443  	err = pgConn.WaitForNotification(ctx)
  1444  	require.ErrorIs(t, err, context.Canceled)
  1445  
  1446  	ensureConnValid(t, pgConn)
  1447  }
  1448  
  1449  func TestConnWaitForNotificationTimeout(t *testing.T) {
  1450  	t.Parallel()
  1451  
  1452  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
  1453  	require.NoError(t, err)
  1454  
  1455  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  1456  	require.NoError(t, err)
  1457  	defer closeConn(t, pgConn)
  1458  
  1459  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
  1460  	err = pgConn.WaitForNotification(ctx)
  1461  	cancel()
  1462  	assert.True(t, pgconn.Timeout(err))
  1463  	assert.ErrorIs(t, err, context.DeadlineExceeded)
  1464  
  1465  	ensureConnValid(t, pgConn)
  1466  }
  1467  
  1468  func TestConnCopyToSmall(t *testing.T) {
  1469  	t.Parallel()
  1470  
  1471  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1472  	require.NoError(t, err)
  1473  	defer closeConn(t, pgConn)
  1474  
  1475  	if pgConn.ParameterStatus("crdb_version") != "" {
  1476  		t.Skip("Server does support COPY TO")
  1477  	}
  1478  
  1479  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1480  		a int2,
  1481  		b int4,
  1482  		c int8,
  1483  		d varchar,
  1484  		e text,
  1485  		f date,
  1486  		g json
  1487  	)`).ReadAll()
  1488  	require.NoError(t, err)
  1489  
  1490  	_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
  1491  	require.NoError(t, err)
  1492  
  1493  	_, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
  1494  	require.NoError(t, err)
  1495  
  1496  	inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
  1497  		"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
  1498  
  1499  	outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
  1500  
  1501  	res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
  1502  	require.NoError(t, err)
  1503  
  1504  	assert.Equal(t, int64(2), res.RowsAffected())
  1505  	assert.Equal(t, inputBytes, outputWriter.Bytes())
  1506  
  1507  	ensureConnValid(t, pgConn)
  1508  }
  1509  
  1510  func TestConnCopyToLarge(t *testing.T) {
  1511  	t.Parallel()
  1512  
  1513  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1514  	require.NoError(t, err)
  1515  	defer closeConn(t, pgConn)
  1516  
  1517  	if pgConn.ParameterStatus("crdb_version") != "" {
  1518  		t.Skip("Server does support COPY TO")
  1519  	}
  1520  
  1521  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1522  		a int2,
  1523  		b int4,
  1524  		c int8,
  1525  		d varchar,
  1526  		e text,
  1527  		f date,
  1528  		g json,
  1529  		h bytea
  1530  	)`).ReadAll()
  1531  	require.NoError(t, err)
  1532  
  1533  	inputBytes := make([]byte, 0)
  1534  
  1535  	for i := 0; i < 1000; i++ {
  1536  		_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
  1537  		require.NoError(t, err)
  1538  		inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
  1539  	}
  1540  
  1541  	outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
  1542  
  1543  	res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
  1544  	require.NoError(t, err)
  1545  
  1546  	assert.Equal(t, int64(1000), res.RowsAffected())
  1547  	assert.Equal(t, inputBytes, outputWriter.Bytes())
  1548  
  1549  	ensureConnValid(t, pgConn)
  1550  }
  1551  
  1552  func TestConnCopyToQueryError(t *testing.T) {
  1553  	t.Parallel()
  1554  
  1555  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1556  	require.NoError(t, err)
  1557  	defer closeConn(t, pgConn)
  1558  
  1559  	outputWriter := bytes.NewBuffer(make([]byte, 0))
  1560  
  1561  	res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout")
  1562  	require.Error(t, err)
  1563  	assert.IsType(t, &pgconn.PgError{}, err)
  1564  	assert.Equal(t, int64(0), res.RowsAffected())
  1565  
  1566  	ensureConnValid(t, pgConn)
  1567  }
  1568  
  1569  func TestConnCopyToCanceled(t *testing.T) {
  1570  	t.Parallel()
  1571  
  1572  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1573  	require.NoError(t, err)
  1574  	defer closeConn(t, pgConn)
  1575  
  1576  	if pgConn.ParameterStatus("crdb_version") != "" {
  1577  		t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
  1578  	}
  1579  
  1580  	outputWriter := &bytes.Buffer{}
  1581  
  1582  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
  1583  	defer cancel()
  1584  	res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
  1585  	assert.Error(t, err)
  1586  	assert.Equal(t, pgconn.CommandTag(nil), res)
  1587  
  1588  	assert.True(t, pgConn.IsClosed())
  1589  	select {
  1590  	case <-pgConn.CleanupDone():
  1591  	case <-time.After(5 * time.Second):
  1592  		t.Fatal("Connection cleanup exceeded maximum time")
  1593  	}
  1594  }
  1595  
  1596  func TestConnCopyToPrecanceled(t *testing.T) {
  1597  	t.Parallel()
  1598  
  1599  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1600  	require.NoError(t, err)
  1601  	defer closeConn(t, pgConn)
  1602  
  1603  	outputWriter := &bytes.Buffer{}
  1604  
  1605  	ctx, cancel := context.WithCancel(context.Background())
  1606  	cancel()
  1607  	res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
  1608  	require.Error(t, err)
  1609  	assert.True(t, errors.Is(err, context.Canceled))
  1610  	assert.True(t, pgconn.SafeToRetry(err))
  1611  	assert.Equal(t, pgconn.CommandTag(nil), res)
  1612  
  1613  	ensureConnValid(t, pgConn)
  1614  }
  1615  
  1616  func TestConnCopyFrom(t *testing.T) {
  1617  	t.Parallel()
  1618  
  1619  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1620  	require.NoError(t, err)
  1621  	defer closeConn(t, pgConn)
  1622  
  1623  	if pgConn.ParameterStatus("crdb_version") != "" {
  1624  		t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
  1625  	}
  1626  
  1627  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1628  		a int4,
  1629  		b varchar
  1630  	)`).ReadAll()
  1631  	require.NoError(t, err)
  1632  
  1633  	srcBuf := &bytes.Buffer{}
  1634  
  1635  	inputRows := [][][]byte{}
  1636  	for i := 0; i < 1000; i++ {
  1637  		a := strconv.Itoa(i)
  1638  		b := "foo " + a + " bar"
  1639  		inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
  1640  		_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
  1641  		require.NoError(t, err)
  1642  	}
  1643  
  1644  	ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)")
  1645  	require.NoError(t, err)
  1646  	assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
  1647  
  1648  	result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
  1649  	require.NoError(t, result.Err)
  1650  
  1651  	assert.Equal(t, inputRows, result.Rows)
  1652  
  1653  	ensureConnValid(t, pgConn)
  1654  }
  1655  
  1656  func TestConnCopyFromCanceled(t *testing.T) {
  1657  	t.Parallel()
  1658  
  1659  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1660  	require.NoError(t, err)
  1661  	defer closeConn(t, pgConn)
  1662  
  1663  	if pgConn.ParameterStatus("crdb_version") != "" {
  1664  		t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
  1665  	}
  1666  
  1667  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1668  		a int4,
  1669  		b varchar
  1670  	)`).ReadAll()
  1671  	require.NoError(t, err)
  1672  
  1673  	r, w := io.Pipe()
  1674  	go func() {
  1675  		for i := 0; i < 1000000; i++ {
  1676  			a := strconv.Itoa(i)
  1677  			b := "foo " + a + " bar"
  1678  			_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
  1679  			if err != nil {
  1680  				return
  1681  			}
  1682  			time.Sleep(time.Microsecond)
  1683  		}
  1684  	}()
  1685  
  1686  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
  1687  	ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
  1688  	cancel()
  1689  	assert.Equal(t, int64(0), ct.RowsAffected())
  1690  	assert.Error(t, err)
  1691  
  1692  	assert.True(t, pgConn.IsClosed())
  1693  	select {
  1694  	case <-pgConn.CleanupDone():
  1695  	case <-time.After(5 * time.Second):
  1696  		t.Fatal("Connection cleanup exceeded maximum time")
  1697  	}
  1698  }
  1699  
  1700  func TestConnCopyFromPrecanceled(t *testing.T) {
  1701  	t.Parallel()
  1702  
  1703  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1704  	require.NoError(t, err)
  1705  	defer closeConn(t, pgConn)
  1706  
  1707  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1708  		a int4,
  1709  		b varchar
  1710  	)`).ReadAll()
  1711  	require.NoError(t, err)
  1712  
  1713  	r, w := io.Pipe()
  1714  	go func() {
  1715  		for i := 0; i < 1000000; i++ {
  1716  			a := strconv.Itoa(i)
  1717  			b := "foo " + a + " bar"
  1718  			_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
  1719  			if err != nil {
  1720  				return
  1721  			}
  1722  			time.Sleep(time.Microsecond)
  1723  		}
  1724  	}()
  1725  
  1726  	ctx, cancel := context.WithCancel(context.Background())
  1727  	cancel()
  1728  	ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
  1729  	require.Error(t, err)
  1730  	assert.True(t, errors.Is(err, context.Canceled))
  1731  	assert.True(t, pgconn.SafeToRetry(err))
  1732  	assert.Equal(t, pgconn.CommandTag(nil), ct)
  1733  
  1734  	ensureConnValid(t, pgConn)
  1735  }
  1736  
  1737  func TestConnCopyFromGzipReader(t *testing.T) {
  1738  	t.Parallel()
  1739  
  1740  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1741  	require.NoError(t, err)
  1742  	defer closeConn(t, pgConn)
  1743  
  1744  	if pgConn.ParameterStatus("crdb_version") != "" {
  1745  		t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
  1746  	}
  1747  
  1748  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1749  		a int4,
  1750  		b varchar
  1751  	)`).ReadAll()
  1752  	require.NoError(t, err)
  1753  
  1754  	f, err := ioutil.TempFile("", "*")
  1755  	require.NoError(t, err)
  1756  
  1757  	gw := gzip.NewWriter(f)
  1758  
  1759  	inputRows := [][][]byte{}
  1760  	for i := 0; i < 1000; i++ {
  1761  		a := strconv.Itoa(i)
  1762  		b := "foo " + a + " bar"
  1763  		inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
  1764  		_, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
  1765  		require.NoError(t, err)
  1766  	}
  1767  
  1768  	err = gw.Close()
  1769  	require.NoError(t, err)
  1770  
  1771  	_, err = f.Seek(0, 0)
  1772  	require.NoError(t, err)
  1773  
  1774  	gr, err := gzip.NewReader(f)
  1775  	require.NoError(t, err)
  1776  
  1777  	ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
  1778  	require.NoError(t, err)
  1779  	assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
  1780  
  1781  	err = gr.Close()
  1782  	require.NoError(t, err)
  1783  
  1784  	err = f.Close()
  1785  	require.NoError(t, err)
  1786  
  1787  	err = os.Remove(f.Name())
  1788  	require.NoError(t, err)
  1789  
  1790  	result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
  1791  	require.NoError(t, result.Err)
  1792  
  1793  	assert.Equal(t, inputRows, result.Rows)
  1794  
  1795  	ensureConnValid(t, pgConn)
  1796  }
  1797  
  1798  func TestConnCopyFromQuerySyntaxError(t *testing.T) {
  1799  	t.Parallel()
  1800  
  1801  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1802  	require.NoError(t, err)
  1803  	defer closeConn(t, pgConn)
  1804  
  1805  	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
  1806  		a int4,
  1807  		b varchar
  1808  	)`).ReadAll()
  1809  	require.NoError(t, err)
  1810  
  1811  	srcBuf := &bytes.Buffer{}
  1812  
  1813  	res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout")
  1814  	require.Error(t, err)
  1815  	assert.IsType(t, &pgconn.PgError{}, err)
  1816  	assert.Equal(t, int64(0), res.RowsAffected())
  1817  
  1818  	ensureConnValid(t, pgConn)
  1819  }
  1820  
  1821  func TestConnCopyFromQueryNoTableError(t *testing.T) {
  1822  	t.Parallel()
  1823  
  1824  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1825  	require.NoError(t, err)
  1826  	defer closeConn(t, pgConn)
  1827  
  1828  	srcBuf := &bytes.Buffer{}
  1829  
  1830  	res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout")
  1831  	require.Error(t, err)
  1832  	assert.IsType(t, &pgconn.PgError{}, err)
  1833  	assert.Equal(t, int64(0), res.RowsAffected())
  1834  
  1835  	ensureConnValid(t, pgConn)
  1836  }
  1837  
  1838  // https://github.com/jackc/pgconn/issues/21
  1839  func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
  1840  	t.Parallel()
  1841  
  1842  	ctx := context.Background()
  1843  	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
  1844  	require.NoError(t, err)
  1845  	defer closeConn(t, pgConn)
  1846  
  1847  	if pgConn.ParameterStatus("crdb_version") != "" {
  1848  		t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)")
  1849  	}
  1850  
  1851  	_, err = pgConn.Exec(ctx, `create temporary table sentences(
  1852  		t text,
  1853  		ts tsvector
  1854  	)`).ReadAll()
  1855  	require.NoError(t, err)
  1856  
  1857  	_, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
  1858  	begin
  1859  	  new.ts := to_tsvector(new.t);
  1860  		return new;
  1861  	end
  1862  	$$ language plpgsql;`).ReadAll()
  1863  	require.NoError(t, err)
  1864  
  1865  	_, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
  1866  	require.NoError(t, err)
  1867  
  1868  	longString := make([]byte, 10001)
  1869  	for i := range longString {
  1870  		longString[i] = 'x'
  1871  	}
  1872  
  1873  	buf := &bytes.Buffer{}
  1874  	for i := 0; i < 1000; i++ {
  1875  		buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
  1876  	}
  1877  
  1878  	_, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
  1879  	require.NoError(t, err)
  1880  }
  1881  
  1882  func TestConnEscapeString(t *testing.T) {
  1883  	t.Parallel()
  1884  
  1885  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1886  	require.NoError(t, err)
  1887  	defer closeConn(t, pgConn)
  1888  
  1889  	tests := []struct {
  1890  		in  string
  1891  		out string
  1892  	}{
  1893  		{in: "", out: ""},
  1894  		{in: "42", out: "42"},
  1895  		{in: "'", out: "''"},
  1896  		{in: "hi'there", out: "hi''there"},
  1897  		{in: "'hi there'", out: "''hi there''"},
  1898  	}
  1899  
  1900  	for i, tt := range tests {
  1901  		value, err := pgConn.EscapeString(tt.in)
  1902  		if assert.NoErrorf(t, err, "%d.", i) {
  1903  			assert.Equalf(t, tt.out, value, "%d.", i)
  1904  		}
  1905  	}
  1906  
  1907  	ensureConnValid(t, pgConn)
  1908  }
  1909  
  1910  func TestConnCancelRequest(t *testing.T) {
  1911  	t.Parallel()
  1912  
  1913  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1914  	require.NoError(t, err)
  1915  	defer closeConn(t, pgConn)
  1916  
  1917  	if pgConn.ParameterStatus("crdb_version") != "" {
  1918  		t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
  1919  	}
  1920  
  1921  	multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)")
  1922  
  1923  	// This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a
  1924  	// response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a
  1925  	// few milliseconds.
  1926  	time.Sleep(50 * time.Millisecond)
  1927  
  1928  	err = pgConn.CancelRequest(context.Background())
  1929  	require.NoError(t, err)
  1930  
  1931  	for multiResult.NextResult() {
  1932  	}
  1933  	err = multiResult.Close()
  1934  
  1935  	require.IsType(t, &pgconn.PgError{}, err)
  1936  	require.Equal(t, "57014", err.(*pgconn.PgError).Code)
  1937  
  1938  	ensureConnValid(t, pgConn)
  1939  }
  1940  
  1941  // https://github.com/jackc/pgx/issues/659
  1942  func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
  1943  	t.Parallel()
  1944  
  1945  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1946  	require.NoError(t, err)
  1947  	defer closeConn(t, pgConn)
  1948  
  1949  	pid := pgConn.PID()
  1950  
  1951  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
  1952  	defer cancel()
  1953  	multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)")
  1954  
  1955  	for multiResult.NextResult() {
  1956  	}
  1957  	err = multiResult.Close()
  1958  	assert.True(t, pgconn.Timeout(err))
  1959  	assert.True(t, pgConn.IsClosed())
  1960  	select {
  1961  	case <-pgConn.CleanupDone():
  1962  	case <-time.After(5 * time.Second):
  1963  		t.Fatal("Connection cleanup exceeded maximum time")
  1964  	}
  1965  
  1966  	otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  1967  	require.NoError(t, err)
  1968  	defer closeConn(t, otherConn)
  1969  
  1970  	ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
  1971  	defer cancel()
  1972  
  1973  	for {
  1974  		result := otherConn.ExecParams(ctx,
  1975  			`select 1 from pg_stat_activity where pid=$1`,
  1976  			[][]byte{[]byte(strconv.FormatInt(int64(pid), 10))},
  1977  			nil,
  1978  			nil,
  1979  			nil,
  1980  		).Read()
  1981  		require.NoError(t, result.Err)
  1982  
  1983  		if len(result.Rows) == 0 {
  1984  			break
  1985  		}
  1986  	}
  1987  }
  1988  
  1989  func TestConnSendBytesAndReceiveMessage(t *testing.T) {
  1990  	t.Parallel()
  1991  
  1992  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
  1993  	defer cancel()
  1994  
  1995  	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
  1996  	require.NoError(t, err)
  1997  	config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect.
  1998  
  1999  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  2000  	require.NoError(t, err)
  2001  	defer closeConn(t, pgConn)
  2002  
  2003  	queryMsg := pgproto3.Query{String: "select 42"}
  2004  	buf, err := queryMsg.Encode(nil)
  2005  	require.NoError(t, err)
  2006  
  2007  	err = pgConn.SendBytes(ctx, buf)
  2008  	require.NoError(t, err)
  2009  
  2010  	msg, err := pgConn.ReceiveMessage(ctx)
  2011  	require.NoError(t, err)
  2012  	_, ok := msg.(*pgproto3.RowDescription)
  2013  	require.True(t, ok)
  2014  
  2015  	msg, err = pgConn.ReceiveMessage(ctx)
  2016  	require.NoError(t, err)
  2017  	_, ok = msg.(*pgproto3.DataRow)
  2018  	require.True(t, ok)
  2019  
  2020  	msg, err = pgConn.ReceiveMessage(ctx)
  2021  	require.NoError(t, err)
  2022  	_, ok = msg.(*pgproto3.CommandComplete)
  2023  	require.True(t, ok)
  2024  
  2025  	msg, err = pgConn.ReceiveMessage(ctx)
  2026  	require.NoError(t, err)
  2027  	_, ok = msg.(*pgproto3.ReadyForQuery)
  2028  	require.True(t, ok)
  2029  
  2030  	ensureConnValid(t, pgConn)
  2031  }
  2032  
  2033  func TestHijackAndConstruct(t *testing.T) {
  2034  	t.Parallel()
  2035  
  2036  	origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  2037  	require.NoError(t, err)
  2038  
  2039  	hc, err := origConn.Hijack()
  2040  	require.NoError(t, err)
  2041  
  2042  	_, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
  2043  	require.Error(t, err)
  2044  
  2045  	newConn, err := pgconn.Construct(hc)
  2046  	require.NoError(t, err)
  2047  
  2048  	defer closeConn(t, newConn)
  2049  
  2050  	results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
  2051  	assert.NoError(t, err)
  2052  
  2053  	assert.Len(t, results, 1)
  2054  	assert.Nil(t, results[0].Err)
  2055  	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
  2056  	assert.Len(t, results[0].Rows, 1)
  2057  	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
  2058  
  2059  	ensureConnValid(t, newConn)
  2060  }
  2061  
  2062  func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) {
  2063  	t.Parallel()
  2064  
  2065  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  2066  	require.NoError(t, err)
  2067  
  2068  	ctx, _ := context.WithCancel(context.Background())
  2069  	pgConn.Exec(ctx, "select n from generate_series(1,10) n")
  2070  
  2071  	closeCtx, _ := context.WithCancel(context.Background())
  2072  	pgConn.Close(closeCtx)
  2073  	select {
  2074  	case <-pgConn.CleanupDone():
  2075  	case <-time.After(5 * time.Second):
  2076  		t.Fatal("Connection cleanup exceeded maximum time")
  2077  	}
  2078  }
  2079  
  2080  // https://github.com/jackc/pgx/issues/800
  2081  func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
  2082  	t.Parallel()
  2083  
  2084  	steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
  2085  	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
  2086  	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{}))
  2087  	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
  2088  	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{}))
  2089  	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{}))
  2090  	steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
  2091  		{Name: []byte("mock")},
  2092  	}}))
  2093  	steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}))
  2094  	steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
  2095  
  2096  	script := &pgmock.Script{Steps: steps}
  2097  
  2098  	ln, err := net.Listen("tcp", "127.0.0.1:")
  2099  	require.NoError(t, err)
  2100  	defer ln.Close()
  2101  
  2102  	serverErrChan := make(chan error, 1)
  2103  	go func() {
  2104  		defer close(serverErrChan)
  2105  
  2106  		conn, err := ln.Accept()
  2107  		if err != nil {
  2108  			serverErrChan <- err
  2109  			return
  2110  		}
  2111  		defer conn.Close()
  2112  
  2113  		err = conn.SetDeadline(time.Now().Add(5 * time.Second))
  2114  		if err != nil {
  2115  			serverErrChan <- err
  2116  			return
  2117  		}
  2118  
  2119  		err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
  2120  		if err != nil {
  2121  			serverErrChan <- err
  2122  			return
  2123  		}
  2124  	}()
  2125  
  2126  	parts := strings.Split(ln.Addr().String(), ":")
  2127  	host := parts[0]
  2128  	port := parts[1]
  2129  	connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
  2130  
  2131  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  2132  	defer cancel()
  2133  	conn, err := pgconn.Connect(ctx, connStr)
  2134  	require.NoError(t, err)
  2135  
  2136  	rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil)
  2137  
  2138  	for rr.NextRow() {
  2139  	}
  2140  
  2141  	_, err = rr.Close()
  2142  	require.Error(t, err)
  2143  }
  2144  
  2145  func Example() {
  2146  	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
  2147  	if err != nil {
  2148  		log.Fatalln(err)
  2149  	}
  2150  	defer pgConn.Close(context.Background())
  2151  
  2152  	result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read()
  2153  	if result.Err != nil {
  2154  		log.Fatalln(result.Err)
  2155  	}
  2156  
  2157  	for _, row := range result.Rows {
  2158  		fmt.Println(string(row[0]))
  2159  	}
  2160  
  2161  	fmt.Println(result.CommandTag)
  2162  	// Output:
  2163  	// 1
  2164  	// 2
  2165  	// 3
  2166  	// SELECT 3
  2167  }
  2168  
  2169  func GetSSLPassword(ctx context.Context) string {
  2170  	connString := os.Getenv("PGX_SSL_PASSWORD")
  2171  	return connString
  2172  }
  2173  
// rsaCertPEM is the PEM-encoded certificate that the fake TLS server in
// TestSNISupport presents; it is paired with rsaKeyPEM via tls.X509KeyPair.
// Decoding the DER shows a self-signed certificate whose subject and issuer
// are both CN=localhost, with validity 2022-08-15 to 2023-08-15 (UTC).
// NOTE(review): the certificate has expired. TestSNISupport connects with
// sslmode=require, which presumably does not verify the server certificate, so
// the expiry does not affect that test — confirm before reusing this
// certificate in a test that does verify it.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
MlascjupnaptKX/wMA==
-----END CERTIFICATE-----
`
  2194  
// rsaKeyPEM is the PKCS#8 RSA private key for rsaCertPEM. The fake TLS server
// in TestSNISupport loads the two together with tls.X509KeyPair. The literal
// uses a "TESTING KEY" PEM label, which testingKey rewrites to "PRIVATE KEY"
// when the package is initialized. Presumably this keeps secret scanners from
// flagging a literal private key in the source. Test-only material; never use
// it for anything real.
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
z3w+CgS20UrbLIR1YXfqUXge1g==
-----END TESTING KEY-----
`)
  2224  
  2225  func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
  2226  
  2227  func TestSNISupport(t *testing.T) {
  2228  	t.Parallel()
  2229  	tests := []struct {
  2230  		name      string
  2231  		sni_param string
  2232  		sni_set   bool
  2233  	}{
  2234  		{
  2235  			name:      "SNI is passed by default",
  2236  			sni_param: "",
  2237  			sni_set:   true,
  2238  		},
  2239  		{
  2240  			name:      "SNI is passed when asked for",
  2241  			sni_param: "sslsni=1",
  2242  			sni_set:   true,
  2243  		},
  2244  		{
  2245  			name:      "SNI is not passed when disabled",
  2246  			sni_param: "sslsni=0",
  2247  			sni_set:   false,
  2248  		},
  2249  	}
  2250  	for _, tt := range tests {
  2251  		tt := tt
  2252  		t.Run(tt.name, func(t *testing.T) {
  2253  			t.Parallel()
  2254  
  2255  			ln, err := net.Listen("tcp", "127.0.0.1:")
  2256  			require.NoError(t, err)
  2257  			defer ln.Close()
  2258  
  2259  			serverErrChan := make(chan error, 1)
  2260  			serverSNINameChan := make(chan string, 1)
  2261  			defer close(serverErrChan)
  2262  			defer close(serverSNINameChan)
  2263  
  2264  			go func() {
  2265  				var sniHost string
  2266  
  2267  				conn, err := ln.Accept()
  2268  				if err != nil {
  2269  					serverErrChan <- err
  2270  					return
  2271  				}
  2272  				defer conn.Close()
  2273  
  2274  				err = conn.SetDeadline(time.Now().Add(5 * time.Second))
  2275  				if err != nil {
  2276  					serverErrChan <- err
  2277  					return
  2278  				}
  2279  
  2280  				backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)
  2281  				startupMessage, err := backend.ReceiveStartupMessage()
  2282  				if err != nil {
  2283  					serverErrChan <- err
  2284  					return
  2285  				}
  2286  
  2287  				switch startupMessage.(type) {
  2288  				case *pgproto3.SSLRequest:
  2289  					_, err = conn.Write([]byte("S"))
  2290  					if err != nil {
  2291  						serverErrChan <- err
  2292  						return
  2293  					}
  2294  				default:
  2295  					serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
  2296  					return
  2297  				}
  2298  
  2299  				cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
  2300  				if err != nil {
  2301  					serverErrChan <- err
  2302  					return
  2303  				}
  2304  
  2305  				srv := tls.Server(conn, &tls.Config{
  2306  					Certificates: []tls.Certificate{cert},
  2307  					GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
  2308  						sniHost = argHello.ServerName
  2309  						return nil, nil
  2310  					},
  2311  				})
  2312  				defer srv.Close()
  2313  
  2314  				if err := srv.Handshake(); err != nil {
  2315  					serverErrChan <- fmt.Errorf("handshake: %v", err)
  2316  					return
  2317  				}
  2318  
  2319  				srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
  2320  				srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
  2321  				srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
  2322  
  2323  				serverSNINameChan <- sniHost
  2324  			}()
  2325  
  2326  			port := strings.Split(ln.Addr().String(), ":")[1]
  2327  			connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
  2328  			_, err = pgconn.Connect(context.Background(), connStr)
  2329  
  2330  			select {
  2331  			case sniHost := <-serverSNINameChan:
  2332  				if tt.sni_set {
  2333  					require.Equal(t, sniHost, "localhost")
  2334  				} else {
  2335  					require.Equal(t, sniHost, "")
  2336  				}
  2337  			case err = <-serverErrChan:
  2338  				t.Fatalf("server failed with error: %+v", err)
  2339  			case <-time.After(time.Millisecond * 100):
  2340  				t.Fatal("exceeded connection timeout without erroring out")
  2341  			}
  2342  		})
  2343  	}
  2344  }
  2345  
  2346  type delayedReader struct {
  2347  	r io.Reader
  2348  }
  2349  
  2350  func (d delayedReader) Read(p []byte) (int, error) {
  2351  	// W/o sleep test passes, with sleep it fails.
  2352  	time.Sleep(time.Millisecond)
  2353  	return d.r.Read(p)
  2354  }
  2355  
  2356  func TestCopyFrom(t *testing.T) {
  2357  	connString := os.Getenv("PGX_TEST_CONN_STRING")
  2358  	if connString == "" {
  2359  		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_CONN_STRING")
  2360  	}
  2361  
  2362  	config, err := pgconn.ParseConfig(connString)
  2363  	require.NoError(t, err)
  2364  
  2365  	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
  2366  	require.NoError(t, err)
  2367  
  2368  	if pgConn.ParameterStatus("crdb_version") != "" {
  2369  		t.Skip("Server does support COPY FROM")
  2370  	}
  2371  
  2372  	setupSQL := `create temporary table t (
  2373  		id text primary key,
  2374  		n int not null
  2375  	);`
  2376  
  2377  	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
  2378  	assert.NoError(t, err)
  2379  
  2380  	r1 := delayedReader{r: strings.NewReader(`id	0\n`)}
  2381  	// Generate an error with a bogus COPY command
  2382  	_, err = pgConn.CopyFrom(context.Background(), r1, "COPY nosuchtable FROM STDIN ")
  2383  	assert.Error(t, err)
  2384  
  2385  	r2 := delayedReader{r: strings.NewReader(`id	0\n`)}
  2386  	_, err = pgConn.CopyFrom(context.Background(), r2, "COPY t FROM STDIN")
  2387  	assert.NoError(t, err)
  2388  }
  2389  
  2390  func mustEncode(buf []byte, err error) []byte {
  2391  	if err != nil {
  2392  		panic(err)
  2393  	}
  2394  	return buf
  2395  }
  2396  

View as plain text