...

Source file src/github.com/jackc/pgx/v4/helper_test.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"os"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  
    10  	"github.com/jackc/pgconn"
    11  	"github.com/jackc/pgx/v4"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
    16  	t.Run("SimpleProto",
    17  		func(t *testing.T) {
    18  			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    19  			require.NoError(t, err)
    20  
    21  			config.PreferSimpleProtocol = true
    22  			conn, err := pgx.ConnectConfig(context.Background(), config)
    23  			require.NoError(t, err)
    24  			defer func() {
    25  				err := conn.Close(context.Background())
    26  				require.NoError(t, err)
    27  			}()
    28  
    29  			f(t, conn)
    30  
    31  			ensureConnValid(t, conn)
    32  		},
    33  	)
    34  
    35  	t.Run("DefaultProto",
    36  		func(t *testing.T) {
    37  			config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    38  			require.NoError(t, err)
    39  
    40  			conn, err := pgx.ConnectConfig(context.Background(), config)
    41  			require.NoError(t, err)
    42  			defer func() {
    43  				err := conn.Close(context.Background())
    44  				require.NoError(t, err)
    45  			}()
    46  
    47  			f(t, conn)
    48  
    49  			ensureConnValid(t, conn)
    50  		},
    51  	)
    52  }
    53  
    54  func mustConnectString(t testing.TB, connString string) *pgx.Conn {
    55  	conn, err := pgx.Connect(context.Background(), connString)
    56  	if err != nil {
    57  		t.Fatalf("Unable to establish connection: %v", err)
    58  	}
    59  	return conn
    60  }
    61  
    62  func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
    63  	config, err := pgx.ParseConfig(connString)
    64  	require.Nil(t, err)
    65  	return config
    66  }
    67  
    68  func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
    69  	conn, err := pgx.ConnectConfig(context.Background(), config)
    70  	if err != nil {
    71  		t.Fatalf("Unable to establish connection: %v", err)
    72  	}
    73  	return conn
    74  }
    75  
    76  func closeConn(t testing.TB, conn *pgx.Conn) {
    77  	err := conn.Close(context.Background())
    78  	if err != nil {
    79  		t.Fatalf("conn.Close unexpectedly failed: %v", err)
    80  	}
    81  }
    82  
    83  func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) {
    84  	var err error
    85  	if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
    86  		t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
    87  	}
    88  	return
    89  }
    90  
    91  // Do a simple query to ensure the connection is still usable
    92  func ensureConnValid(t *testing.T, conn *pgx.Conn) {
    93  	var sum, rowCount int32
    94  
    95  	rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
    96  	if err != nil {
    97  		t.Fatalf("conn.Query failed: %v", err)
    98  	}
    99  	defer rows.Close()
   100  
   101  	for rows.Next() {
   102  		var n int32
   103  		rows.Scan(&n)
   104  		sum += n
   105  		rowCount++
   106  	}
   107  
   108  	if rows.Err() != nil {
   109  		t.Fatalf("conn.Query failed: %v", err)
   110  	}
   111  
   112  	if rowCount != 10 {
   113  		t.Error("Select called onDataRow wrong number of times")
   114  	}
   115  	if sum != 55 {
   116  		t.Error("Wrong values returned")
   117  	}
   118  }
   119  
   120  func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
   121  	if !assert.NotNil(t, expected) {
   122  		return
   123  	}
   124  	if !assert.NotNil(t, actual) {
   125  		return
   126  	}
   127  
   128  	assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
   129  	assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
   130  	assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
   131  	// Can't test function equality, so just test that they are set or not.
   132  	assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
   133  	assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
   134  
   135  	assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
   136  	assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
   137  	assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
   138  	assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
   139  	assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
   140  	assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
   141  	assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
   142  
   143  	// Can't test function equality, so just test that they are set or not.
   144  	assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
   145  	assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
   146  
   147  	if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
   148  		if expected.TLSConfig != nil {
   149  			assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
   150  			assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
   151  		}
   152  	}
   153  
   154  	if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
   155  		for i := range expected.Fallbacks {
   156  			assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
   157  			assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
   158  
   159  			if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
   160  				if expected.Fallbacks[i].TLSConfig != nil {
   161  					assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
   162  					assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
   163  				}
   164  			}
   165  		}
   166  	}
   167  }
   168  
   169  func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) {
   170  	if conn.PgConn().ParameterStatus("crdb_version") != "" {
   171  		t.Skip(msg)
   172  	}
   173  }
   174  

View as plain text