...

Source file src/github.com/jackc/pgtype/testutil/testutil.go

Documentation: github.com/jackc/pgtype/testutil

     1  package testutil
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"os"
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/jackc/pgtype"
    12  	"github.com/jackc/pgx/v4"
    13  	_ "github.com/jackc/pgx/v4/stdlib"
    14  	_ "github.com/lib/pq"
    15  )
    16  
    17  func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB {
    18  	var sqlDriverName string
    19  	switch driverName {
    20  	case "github.com/lib/pq":
    21  		sqlDriverName = "postgres"
    22  	case "github.com/jackc/pgx/stdlib":
    23  		sqlDriverName = "pgx"
    24  	default:
    25  		t.Fatalf("Unknown driver %v", driverName)
    26  	}
    27  
    28  	db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE"))
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  
    33  	return db
    34  }
    35  
    36  func MustConnectPgx(t testing.TB) *pgx.Conn {
    37  	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  
    42  	return conn
    43  }
    44  
    45  func MustClose(t testing.TB, conn interface {
    46  	Close() error
    47  }) {
    48  	err := conn.Close()
    49  	if err != nil {
    50  		t.Fatal(err)
    51  	}
    52  }
    53  
    54  func MustCloseContext(t testing.TB, conn interface {
    55  	Close(context.Context) error
    56  }) {
    57  	err := conn.Close(context.Background())
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  }
    62  
    63  type forceTextEncoder struct {
    64  	e pgtype.TextEncoder
    65  }
    66  
    67  func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
    68  	return f.e.EncodeText(ci, buf)
    69  }
    70  
    71  type forceBinaryEncoder struct {
    72  	e pgtype.BinaryEncoder
    73  }
    74  
    75  func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
    76  	return f.e.EncodeBinary(ci, buf)
    77  }
    78  
    79  func ForceEncoder(e interface{}, formatCode int16) interface{} {
    80  	switch formatCode {
    81  	case pgx.TextFormatCode:
    82  		if e, ok := e.(pgtype.TextEncoder); ok {
    83  			return forceTextEncoder{e: e}
    84  		}
    85  	case pgx.BinaryFormatCode:
    86  		if e, ok := e.(pgtype.BinaryEncoder); ok {
    87  			return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)}
    88  		}
    89  	}
    90  	return nil
    91  }
    92  
    93  func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) {
    94  	TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool {
    95  		return reflect.DeepEqual(a, b)
    96  	})
    97  }
    98  
    99  func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
   100  	TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
   101  	for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
   102  		TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
   103  	}
   104  }
   105  
   106  func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
   107  	conn := MustConnectPgx(t)
   108  	defer MustCloseContext(t, conn)
   109  
   110  	_, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName))
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	formats := []struct {
   116  		name       string
   117  		formatCode int16
   118  	}{
   119  		{name: "TextFormat", formatCode: pgx.TextFormatCode},
   120  		{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
   121  	}
   122  
   123  	for i, v := range values {
   124  		for _, paramFormat := range formats {
   125  			for _, resultFormat := range formats {
   126  				vEncoder := ForceEncoder(v, paramFormat.formatCode)
   127  				if vEncoder == nil {
   128  					t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name)
   129  					continue
   130  				}
   131  				switch resultFormat.formatCode {
   132  				case pgx.TextFormatCode:
   133  					if _, ok := v.(pgtype.TextEncoder); !ok {
   134  						t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
   135  						continue
   136  					}
   137  				case pgx.BinaryFormatCode:
   138  					if _, ok := v.(pgtype.BinaryEncoder); !ok {
   139  						t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name)
   140  						continue
   141  					}
   142  				}
   143  
   144  				// Dereference value if it is a pointer
   145  				derefV := v
   146  				refVal := reflect.ValueOf(v)
   147  				if refVal.Kind() == reflect.Ptr {
   148  					derefV = refVal.Elem().Interface()
   149  				}
   150  
   151  				result := reflect.New(reflect.TypeOf(derefV))
   152  
   153  				err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface())
   154  				if err != nil {
   155  					t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err)
   156  				}
   157  
   158  				if !eqFunc(result.Elem().Interface(), derefV) {
   159  					t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface())
   160  				}
   161  			}
   162  		}
   163  	}
   164  }
   165  
   166  func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
   167  	conn := MustConnectDatabaseSQL(t, driverName)
   168  	defer MustClose(t, conn)
   169  
   170  	ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName))
   171  	if err != nil {
   172  		t.Fatal(err)
   173  	}
   174  
   175  	for i, v := range values {
   176  		// Dereference value if it is a pointer
   177  		derefV := v
   178  		refVal := reflect.ValueOf(v)
   179  		if refVal.Kind() == reflect.Ptr {
   180  			derefV = refVal.Elem().Interface()
   181  		}
   182  
   183  		result := reflect.New(reflect.TypeOf(derefV))
   184  		err := ps.QueryRow(v).Scan(result.Interface())
   185  		if err != nil {
   186  			t.Errorf("%v %d: %v", driverName, i, err)
   187  		}
   188  
   189  		if !eqFunc(result.Elem().Interface(), derefV) {
   190  			t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
   191  		}
   192  	}
   193  }
   194  
   195  type NormalizeTest struct {
   196  	SQL   string
   197  	Value interface{}
   198  }
   199  
   200  func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) {
   201  	TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool {
   202  		return reflect.DeepEqual(a, b)
   203  	})
   204  }
   205  
   206  func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
   207  	TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc)
   208  	for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
   209  		TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc)
   210  	}
   211  }
   212  
   213  func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
   214  	conn := MustConnectPgx(t)
   215  	defer MustCloseContext(t, conn)
   216  
   217  	formats := []struct {
   218  		name       string
   219  		formatCode int16
   220  	}{
   221  		{name: "TextFormat", formatCode: pgx.TextFormatCode},
   222  		{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
   223  	}
   224  
   225  	for i, tt := range tests {
   226  		for _, fc := range formats {
   227  			psName := fmt.Sprintf("test%d", i)
   228  			_, err := conn.Prepare(context.Background(), psName, tt.SQL)
   229  			if err != nil {
   230  				t.Fatal(err)
   231  			}
   232  
   233  			queryResultFormats := pgx.QueryResultFormats{fc.formatCode}
   234  			if ForceEncoder(tt.Value, fc.formatCode) == nil {
   235  				t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name)
   236  				continue
   237  			}
   238  			// Dereference value if it is a pointer
   239  			derefV := tt.Value
   240  			refVal := reflect.ValueOf(tt.Value)
   241  			if refVal.Kind() == reflect.Ptr {
   242  				derefV = refVal.Elem().Interface()
   243  			}
   244  
   245  			result := reflect.New(reflect.TypeOf(derefV))
   246  			err = conn.QueryRow(context.Background(), psName, queryResultFormats).Scan(result.Interface())
   247  			if err != nil {
   248  				t.Errorf("%v %d: %v", fc.name, i, err)
   249  			}
   250  
   251  			if !eqFunc(result.Elem().Interface(), derefV) {
   252  				t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface())
   253  			}
   254  		}
   255  	}
   256  }
   257  
   258  func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) {
   259  	conn := MustConnectDatabaseSQL(t, driverName)
   260  	defer MustClose(t, conn)
   261  
   262  	for i, tt := range tests {
   263  		ps, err := conn.Prepare(tt.SQL)
   264  		if err != nil {
   265  			t.Errorf("%d. %v", i, err)
   266  			continue
   267  		}
   268  
   269  		// Dereference value if it is a pointer
   270  		derefV := tt.Value
   271  		refVal := reflect.ValueOf(tt.Value)
   272  		if refVal.Kind() == reflect.Ptr {
   273  			derefV = refVal.Elem().Interface()
   274  		}
   275  
   276  		result := reflect.New(reflect.TypeOf(derefV))
   277  		err = ps.QueryRow().Scan(result.Interface())
   278  		if err != nil {
   279  			t.Errorf("%v %d: %v", driverName, i, err)
   280  		}
   281  
   282  		if !eqFunc(result.Elem().Interface(), derefV) {
   283  			t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface())
   284  		}
   285  	}
   286  }
   287  
   288  func TestGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) {
   289  	TestPgxGoZeroToNullConversion(t, pgTypeName, zero)
   290  	for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
   291  		TestDatabaseSQLGoZeroToNullConversion(t, driverName, pgTypeName, zero)
   292  	}
   293  }
   294  
   295  func TestNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) {
   296  	TestPgxNullToGoZeroConversion(t, pgTypeName, zero)
   297  	for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
   298  		TestDatabaseSQLNullToGoZeroConversion(t, driverName, pgTypeName, zero)
   299  	}
   300  }
   301  
   302  func TestPgxGoZeroToNullConversion(t testing.TB, pgTypeName string, zero interface{}) {
   303  	conn := MustConnectPgx(t)
   304  	defer MustCloseContext(t, conn)
   305  
   306  	_, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s is null", pgTypeName))
   307  	if err != nil {
   308  		t.Fatal(err)
   309  	}
   310  
   311  	formats := []struct {
   312  		name       string
   313  		formatCode int16
   314  	}{
   315  		{name: "TextFormat", formatCode: pgx.TextFormatCode},
   316  		{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
   317  	}
   318  
   319  	for _, paramFormat := range formats {
   320  		vEncoder := ForceEncoder(zero, paramFormat.formatCode)
   321  		if vEncoder == nil {
   322  			t.Logf("Skipping Param %s: %#v does not implement %v for encoding", paramFormat.name, zero, paramFormat.name)
   323  			continue
   324  		}
   325  
   326  		var result bool
   327  		err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(&result)
   328  		if err != nil {
   329  			t.Errorf("Param %s: %v", paramFormat.name, err)
   330  		}
   331  
   332  		if !result {
   333  			t.Errorf("Param %s: did not convert zero to null", paramFormat.name)
   334  		}
   335  	}
   336  }
   337  
   338  func TestPgxNullToGoZeroConversion(t testing.TB, pgTypeName string, zero interface{}) {
   339  	conn := MustConnectPgx(t)
   340  	defer MustCloseContext(t, conn)
   341  
   342  	_, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select null::%s", pgTypeName))
   343  	if err != nil {
   344  		t.Fatal(err)
   345  	}
   346  
   347  	formats := []struct {
   348  		name       string
   349  		formatCode int16
   350  	}{
   351  		{name: "TextFormat", formatCode: pgx.TextFormatCode},
   352  		{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
   353  	}
   354  
   355  	for _, resultFormat := range formats {
   356  
   357  		switch resultFormat.formatCode {
   358  		case pgx.TextFormatCode:
   359  			if _, ok := zero.(pgtype.TextEncoder); !ok {
   360  				t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name)
   361  				continue
   362  			}
   363  		case pgx.BinaryFormatCode:
   364  			if _, ok := zero.(pgtype.BinaryEncoder); !ok {
   365  				t.Logf("Skipping Result %s: %#v does not implement %v for decoding", resultFormat.name, zero, resultFormat.name)
   366  				continue
   367  			}
   368  		}
   369  
   370  		// Dereference value if it is a pointer
   371  		derefZero := zero
   372  		refVal := reflect.ValueOf(zero)
   373  		if refVal.Kind() == reflect.Ptr {
   374  			derefZero = refVal.Elem().Interface()
   375  		}
   376  
   377  		result := reflect.New(reflect.TypeOf(derefZero))
   378  
   379  		err := conn.QueryRow(context.Background(), "test").Scan(result.Interface())
   380  		if err != nil {
   381  			t.Errorf("Result %s: %v", resultFormat.name, err)
   382  		}
   383  
   384  		if !reflect.DeepEqual(result.Elem().Interface(), derefZero) {
   385  			t.Errorf("Result %s: did not convert null to zero", resultFormat.name)
   386  		}
   387  	}
   388  }
   389  
   390  func TestDatabaseSQLGoZeroToNullConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) {
   391  	conn := MustConnectDatabaseSQL(t, driverName)
   392  	defer MustClose(t, conn)
   393  
   394  	ps, err := conn.Prepare(fmt.Sprintf("select $1::%s is null", pgTypeName))
   395  	if err != nil {
   396  		t.Fatal(err)
   397  	}
   398  
   399  	var result bool
   400  	err = ps.QueryRow(zero).Scan(&result)
   401  	if err != nil {
   402  		t.Errorf("%v %v", driverName, err)
   403  	}
   404  
   405  	if !result {
   406  		t.Errorf("%v: did not convert zero to null", driverName)
   407  	}
   408  }
   409  
   410  func TestDatabaseSQLNullToGoZeroConversion(t testing.TB, driverName, pgTypeName string, zero interface{}) {
   411  	conn := MustConnectDatabaseSQL(t, driverName)
   412  	defer MustClose(t, conn)
   413  
   414  	ps, err := conn.Prepare(fmt.Sprintf("select null::%s", pgTypeName))
   415  	if err != nil {
   416  		t.Fatal(err)
   417  	}
   418  
   419  	// Dereference value if it is a pointer
   420  	derefZero := zero
   421  	refVal := reflect.ValueOf(zero)
   422  	if refVal.Kind() == reflect.Ptr {
   423  		derefZero = refVal.Elem().Interface()
   424  	}
   425  
   426  	result := reflect.New(reflect.TypeOf(derefZero))
   427  
   428  	err = ps.QueryRow().Scan(result.Interface())
   429  	if err != nil {
   430  		t.Errorf("%v %v", driverName, err)
   431  	}
   432  
   433  	if !reflect.DeepEqual(result.Elem().Interface(), derefZero) {
   434  		t.Errorf("%s: did not convert null to zero", driverName)
   435  	}
   436  }
   437  

View as plain text