...

Source file src/github.com/jackc/pgx/v4/pgxpool/common_test.go

Documentation: github.com/jackc/pgx/v4/pgxpool

     1  package pgxpool_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/jackc/pgx/v4/pgxpool"
     9  
    10  	"github.com/jackc/pgconn"
    11  	"github.com/jackc/pgx/v4"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  // Conn.Release is an asynchronous process that returns immediately. There is no signal when the actual work is
    17  // completed. To test something that relies on the actual work for Conn.Release being completed we must simply wait.
    18  // This function wraps the sleep so there is more meaning for the callers.
    19  func waitForReleaseToComplete() {
    20  	time.Sleep(500 * time.Millisecond)
    21  }
    22  
    23  type execer interface {
    24  	Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
    25  }
    26  
    27  func testExec(t *testing.T, db execer) {
    28  	results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'")
    29  	require.NoError(t, err)
    30  	assert.EqualValues(t, "SET", results)
    31  }
    32  
    33  type queryer interface {
    34  	Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
    35  }
    36  
    37  func testQuery(t *testing.T, db queryer) {
    38  	var sum, rowCount int32
    39  
    40  	rows, err := db.Query(context.Background(), "select generate_series(1,$1)", 10)
    41  	require.NoError(t, err)
    42  
    43  	for rows.Next() {
    44  		var n int32
    45  		rows.Scan(&n)
    46  		sum += n
    47  		rowCount++
    48  	}
    49  
    50  	assert.NoError(t, rows.Err())
    51  	assert.Equal(t, int32(10), rowCount)
    52  	assert.Equal(t, int32(55), sum)
    53  }
    54  
    55  type queryRower interface {
    56  	QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
    57  }
    58  
    59  func testQueryRow(t *testing.T, db queryRower) {
    60  	var what, who string
    61  	err := db.QueryRow(context.Background(), "select 'hello', $1::text", "world").Scan(&what, &who)
    62  	assert.NoError(t, err)
    63  	assert.Equal(t, "hello", what)
    64  	assert.Equal(t, "world", who)
    65  }
    66  
    67  type sendBatcher interface {
    68  	SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
    69  }
    70  
    71  func testSendBatch(t *testing.T, db sendBatcher) {
    72  	batch := &pgx.Batch{}
    73  	batch.Queue("select 1")
    74  	batch.Queue("select 2")
    75  
    76  	br := db.SendBatch(context.Background(), batch)
    77  
    78  	var err error
    79  	var n int32
    80  	err = br.QueryRow().Scan(&n)
    81  	assert.NoError(t, err)
    82  	assert.EqualValues(t, 1, n)
    83  
    84  	err = br.QueryRow().Scan(&n)
    85  	assert.NoError(t, err)
    86  	assert.EqualValues(t, 2, n)
    87  
    88  	err = br.Close()
    89  	assert.NoError(t, err)
    90  }
    91  
    92  type copyFromer interface {
    93  	CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
    94  }
    95  
    96  func testCopyFrom(t *testing.T, db interface {
    97  	execer
    98  	queryer
    99  	copyFromer
   100  }) {
   101  	_, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
   102  	require.NoError(t, err)
   103  
   104  	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
   105  
   106  	inputRows := [][]interface{}{
   107  		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
   108  		{nil, nil, nil, nil, nil, nil, nil},
   109  	}
   110  
   111  	copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
   112  	assert.NoError(t, err)
   113  	assert.EqualValues(t, len(inputRows), copyCount)
   114  
   115  	rows, err := db.Query(context.Background(), "select * from foo")
   116  	assert.NoError(t, err)
   117  
   118  	var outputRows [][]interface{}
   119  	for rows.Next() {
   120  		row, err := rows.Values()
   121  		if err != nil {
   122  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   123  		}
   124  		outputRows = append(outputRows, row)
   125  	}
   126  
   127  	assert.NoError(t, rows.Err())
   128  	assert.Equal(t, inputRows, outputRows)
   129  }
   130  
   131  func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) {
   132  	if !assert.NotNil(t, expected) {
   133  		return
   134  	}
   135  	if !assert.NotNil(t, actual) {
   136  		return
   137  	}
   138  
   139  	assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
   140  
   141  	// Can't test function equality, so just test that they are set or not.
   142  	assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
   143  	assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
   144  	assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
   145  
   146  	assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)
   147  	assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
   148  	assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
   149  	assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
   150  	assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
   151  	assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName)
   152  
   153  	assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)
   154  }
   155  
   156  func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
   157  	if !assert.NotNil(t, expected) {
   158  		return
   159  	}
   160  	if !assert.NotNil(t, actual) {
   161  		return
   162  	}
   163  
   164  	assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
   165  	assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
   166  	assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
   167  
   168  	// Can't test function equality, so just test that they are set or not.
   169  	assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName)
   170  
   171  	assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName)
   172  
   173  	assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
   174  	assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
   175  	assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
   176  	assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
   177  	assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
   178  	assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
   179  	assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
   180  
   181  	// Can't test function equality, so just test that they are set or not.
   182  	assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
   183  	assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
   184  
   185  	if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
   186  		if expected.TLSConfig != nil {
   187  			assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
   188  			assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
   189  		}
   190  	}
   191  
   192  	if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
   193  		for i := range expected.Fallbacks {
   194  			assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
   195  			assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
   196  
   197  			if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
   198  				if expected.Fallbacks[i].TLSConfig != nil {
   199  					assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
   200  					assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
   201  				}
   202  			}
   203  		}
   204  	}
   205  }
   206  

View as plain text