...

Source file src/github.com/jackc/pgx/v5/helper_test.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"os"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  
    10  	"github.com/jackc/pgx/v5"
    11  	"github.com/jackc/pgx/v5/pgconn"
    12  	"github.com/jackc/pgx/v5/pgxtest"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  var defaultConnTestRunner pgxtest.ConnTestRunner
    17  
    18  func init() {
    19  	defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
    20  	defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
    21  		config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    22  		require.NoError(t, err)
    23  		return config
    24  	}
    25  }
    26  
    27  func mustConnectString(t testing.TB, connString string) *pgx.Conn {
    28  	conn, err := pgx.Connect(context.Background(), connString)
    29  	if err != nil {
    30  		t.Fatalf("Unable to establish connection: %v", err)
    31  	}
    32  	return conn
    33  }
    34  
    35  func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
    36  	config, err := pgx.ParseConfig(connString)
    37  	require.Nil(t, err)
    38  	return config
    39  }
    40  
    41  func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
    42  	conn, err := pgx.ConnectConfig(context.Background(), config)
    43  	if err != nil {
    44  		t.Fatalf("Unable to establish connection: %v", err)
    45  	}
    46  	return conn
    47  }
    48  
    49  func closeConn(t testing.TB, conn *pgx.Conn) {
    50  	err := conn.Close(context.Background())
    51  	if err != nil {
    52  		t.Fatalf("conn.Close unexpectedly failed: %v", err)
    53  	}
    54  }
    55  
    56  func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
    57  	var err error
    58  	if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
    59  		t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
    60  	}
    61  	return
    62  }
    63  
    64  // Do a simple query to ensure the connection is still usable
    65  func ensureConnValid(t testing.TB, conn *pgx.Conn) {
    66  	var sum, rowCount int32
    67  
    68  	rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
    69  	if err != nil {
    70  		t.Fatalf("conn.Query failed: %v", err)
    71  	}
    72  	defer rows.Close()
    73  
    74  	for rows.Next() {
    75  		var n int32
    76  		rows.Scan(&n)
    77  		sum += n
    78  		rowCount++
    79  	}
    80  
    81  	if rows.Err() != nil {
    82  		t.Fatalf("conn.Query failed: %v", rows.Err())
    83  	}
    84  
    85  	if rowCount != 10 {
    86  		t.Error("Select called onDataRow wrong number of times")
    87  	}
    88  	if sum != 55 {
    89  		t.Error("Wrong values returned")
    90  	}
    91  }
    92  
    93  func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
    94  	if !assert.NotNil(t, expected) {
    95  		return
    96  	}
    97  	if !assert.NotNil(t, actual) {
    98  		return
    99  	}
   100  
   101  	assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
   102  	assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
   103  	assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
   104  	assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
   105  	assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
   106  	assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
   107  	assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
   108  	assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
   109  	assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
   110  	assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
   111  	assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
   112  	assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
   113  
   114  	// Can't test function equality, so just test that they are set or not.
   115  	assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
   116  	assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
   117  
   118  	if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
   119  		if expected.TLSConfig != nil {
   120  			assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
   121  			assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
   122  		}
   123  	}
   124  
   125  	if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
   126  		for i := range expected.Fallbacks {
   127  			assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
   128  			assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
   129  
   130  			if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
   131  				if expected.Fallbacks[i].TLSConfig != nil {
   132  					assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
   133  					assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
   134  				}
   135  			}
   136  		}
   137  	}
   138  }
   139  

View as plain text