...

Source file src/github.com/jackc/pgx/v5/pgxtest/pgxtest.go

Documentation: github.com/jackc/pgx/v5/pgxtest

     1  // Package pgxtest provides utilities for testing pgx and packages that integrate with pgx.
     2  package pgxtest
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/jackc/pgx/v5"
    13  )
    14  
// AllQueryExecModes is a slice of the query exec modes that are tested when a caller of RunWithQueryExecModes or
// RunValueRoundTripTests passes nil for modes.
var AllQueryExecModes = []pgx.QueryExecMode{
	pgx.QueryExecModeCacheStatement,
	pgx.QueryExecModeCacheDescribe,
	pgx.QueryExecModeDescribeExec,
	pgx.QueryExecModeExec,
	pgx.QueryExecModeSimpleProtocol,
}
    22  
// KnownOIDQueryExecModes is a slice of all query exec modes where the param and result OIDs are known before sending
// the query. Every element also appears in AllQueryExecModes.
var KnownOIDQueryExecModes = []pgx.QueryExecMode{
	pgx.QueryExecModeCacheStatement,
	pgx.QueryExecModeCacheDescribe,
	pgx.QueryExecModeDescribeExec,
}
    29  
// ConnTestRunner controls how a *pgx.Conn is created and closed by tests. All fields are required: RunTest calls each
// of them without checking for nil. Use DefaultConnTestRunner to get a ConnTestRunner with reasonable default values.
type ConnTestRunner struct {
	// CreateConfig returns a *pgx.ConnConfig suitable for use with pgx.ConnectConfig.
	CreateConfig func(ctx context.Context, t testing.TB) *pgx.ConnConfig

	// AfterConnect is called after conn is established. It allows for arbitrary connection setup before a test begins.
	AfterConnect func(ctx context.Context, t testing.TB, conn *pgx.Conn)

	// AfterTest is called after the test is run. It allows for validating the state of the connection before it is closed.
	AfterTest func(ctx context.Context, t testing.TB, conn *pgx.Conn)

	// CloseConn closes conn. RunTest defers this call immediately after the connection is established.
	CloseConn func(ctx context.Context, t testing.TB, conn *pgx.Conn)
}
    45  
    46  // DefaultConnTestRunner returns a new ConnTestRunner with all fields set to reasonable default values.
    47  func DefaultConnTestRunner() ConnTestRunner {
    48  	return ConnTestRunner{
    49  		CreateConfig: func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
    50  			config, err := pgx.ParseConfig("")
    51  			if err != nil {
    52  				t.Fatalf("ParseConfig failed: %v", err)
    53  			}
    54  			return config
    55  		},
    56  		AfterConnect: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {},
    57  		AfterTest:    func(ctx context.Context, t testing.TB, conn *pgx.Conn) {},
    58  		CloseConn: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
    59  			err := conn.Close(ctx)
    60  			if err != nil {
    61  				t.Errorf("Close failed: %v", err)
    62  			}
    63  		},
    64  	}
    65  }
    66  
    67  func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) {
    68  	t.Helper()
    69  
    70  	config := ctr.CreateConfig(ctx, t)
    71  	conn, err := pgx.ConnectConfig(ctx, config)
    72  	if err != nil {
    73  		t.Fatalf("ConnectConfig failed: %v", err)
    74  	}
    75  	defer ctr.CloseConn(ctx, t, conn)
    76  
    77  	ctr.AfterConnect(ctx, t, conn)
    78  	f(ctx, t, conn)
    79  	ctr.AfterTest(ctx, t, conn)
    80  }
    81  
    82  // RunWithQueryExecModes runs a f in a new test for each element of modes with a new connection created using connector.
    83  // If modes is nil all pgx.QueryExecModes are tested.
    84  func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) {
    85  	if modes == nil {
    86  		modes = AllQueryExecModes
    87  	}
    88  
    89  	for _, mode := range modes {
    90  		ctrWithMode := ctr
    91  		ctrWithMode.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
    92  			config := ctr.CreateConfig(ctx, t)
    93  			config.DefaultQueryExecMode = mode
    94  			return config
    95  		}
    96  
    97  		t.Run(mode.String(),
    98  			func(t *testing.T) {
    99  				ctrWithMode.RunTest(ctx, t, f)
   100  			},
   101  		)
   102  	}
   103  }
   104  
// ValueRoundTripTest describes a single case for RunValueRoundTripTests.
//
// Param is passed as the query argument. Result is the destination passed to Scan; if it is a pointer, the value it
// points to is what Test receives, otherwise Result itself is. Test returns true when the value it is given is the
// expected one; a false return is reported as a test error.
type ValueRoundTripTest struct {
	Param  any
	Result any
	Test   func(any) bool
}
   110  
   111  func RunValueRoundTripTests(
   112  	ctx context.Context,
   113  	t testing.TB,
   114  	ctr ConnTestRunner,
   115  	modes []pgx.QueryExecMode,
   116  	pgTypeName string,
   117  	tests []ValueRoundTripTest,
   118  ) {
   119  	t.Helper()
   120  
   121  	if modes == nil {
   122  		modes = AllQueryExecModes
   123  	}
   124  
   125  	ctr.RunTest(ctx, t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   126  		t.Helper()
   127  
   128  		sql := fmt.Sprintf("select $1::%s", pgTypeName)
   129  
   130  		for i, tt := range tests {
   131  			for _, mode := range modes {
   132  				err := conn.QueryRow(ctx, sql, mode, tt.Param).Scan(tt.Result)
   133  				if err != nil {
   134  					t.Errorf("%d. %v: %v", i, mode, err)
   135  				}
   136  
   137  				result := reflect.ValueOf(tt.Result)
   138  				if result.Kind() == reflect.Ptr {
   139  					result = result.Elem()
   140  				}
   141  
   142  				if !tt.Test(result.Interface()) {
   143  					t.Errorf("%d. %v: unexpected result for %v: %v", i, mode, tt.Param, result.Interface())
   144  				}
   145  			}
   146  		}
   147  	})
   148  }
   149  
   150  // SkipCockroachDB calls Skip on t with msg if the connection is to a CockroachDB server.
   151  func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) {
   152  	if conn.PgConn().ParameterStatus("crdb_version") != "" {
   153  		t.Skip(msg)
   154  	}
   155  }
   156  
   157  func SkipPostgreSQLVersionLessThan(t testing.TB, conn *pgx.Conn, minVersion int64) {
   158  	serverVersionStr := conn.PgConn().ParameterStatus("server_version")
   159  	serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
   160  	// if not PostgreSQL do nothing
   161  	if serverVersionStr == "" {
   162  		return
   163  	}
   164  
   165  	serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
   166  	if err != nil {
   167  		t.Fatalf("postgres version parsed failed: %s", err)
   168  	}
   169  
   170  	if serverVersion < minVersion {
   171  		t.Skipf("Test requires PostgreSQL v%d+", minVersion)
   172  	}
   173  }
   174  

View as plain text