...

Source file src/github.com/jackc/pgtype/composite_fields_test.go

Documentation: github.com/jackc/pgtype

     1  package pgtype_test
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	"github.com/jackc/pgtype"
     8  	"github.com/jackc/pgtype/testutil"
     9  	"github.com/jackc/pgx/v4"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestCompositeFieldsDecode(t *testing.T) {
    15  	conn := testutil.MustConnectPgx(t)
    16  	defer testutil.MustCloseContext(t, conn)
    17  
    18  	formats := []int16{pgx.TextFormatCode, pgx.BinaryFormatCode}
    19  
    20  	// Assorted values
    21  	{
    22  		var a int32
    23  		var b string
    24  		var c float64
    25  
    26  		for _, format := range formats {
    27  			err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
    28  				pgtype.CompositeFields{&a, &b, &c},
    29  			)
    30  			if !assert.NoErrorf(t, err, "Format: %v", format) {
    31  				continue
    32  			}
    33  
    34  			assert.EqualValuesf(t, 1, a, "Format: %v", format)
    35  			assert.EqualValuesf(t, "hi", b, "Format: %v", format)
    36  			assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
    37  		}
    38  	}
    39  
    40  	// nulls, string "null", and empty string fields
    41  	{
    42  		var a pgtype.Text
    43  		var b string
    44  		var c pgtype.Text
    45  		var d string
    46  		var e pgtype.Text
    47  
    48  		for _, format := range formats {
    49  			err := conn.QueryRow(context.Background(), "select row(null,'null',null,'',null)", pgx.QueryResultFormats{format}).Scan(
    50  				pgtype.CompositeFields{&a, &b, &c, &d, &e},
    51  			)
    52  			if !assert.NoErrorf(t, err, "Format: %v", format) {
    53  				continue
    54  			}
    55  
    56  			assert.Nilf(t, a.Get(), "Format: %v", format)
    57  			assert.EqualValuesf(t, "null", b, "Format: %v", format)
    58  			assert.Nilf(t, c.Get(), "Format: %v", format)
    59  			assert.EqualValuesf(t, "", d, "Format: %v", format)
    60  			assert.Nilf(t, e.Get(), "Format: %v", format)
    61  		}
    62  	}
    63  
    64  	// null record
    65  	{
    66  		var a pgtype.Text
    67  		var b string
    68  		cf := pgtype.CompositeFields{&a, &b}
    69  
    70  		for _, format := range formats {
    71  			// Cannot scan nil into
    72  			err := conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
    73  				cf,
    74  			)
    75  			if assert.Errorf(t, err, "Format: %v", format) {
    76  				continue
    77  			}
    78  			assert.NotNilf(t, cf, "Format: %v", format)
    79  
    80  			// But can scan nil into *pgtype.CompositeFields
    81  			err = conn.QueryRow(context.Background(), "select null::record", pgx.QueryResultFormats{format}).Scan(
    82  				&cf,
    83  			)
    84  			if assert.Errorf(t, err, "Format: %v", format) {
    85  				continue
    86  			}
    87  			assert.Nilf(t, cf, "Format: %v", format)
    88  		}
    89  	}
    90  
    91  	// quotes and special characters
    92  	{
    93  		var a, b, c, d string
    94  
    95  		for _, format := range formats {
    96  			err := conn.QueryRow(context.Background(), `select row('"', 'foo bar', 'foo''bar', 'baz)bar')`, pgx.QueryResultFormats{format}).Scan(
    97  				pgtype.CompositeFields{&a, &b, &c, &d},
    98  			)
    99  			if !assert.NoErrorf(t, err, "Format: %v", format) {
   100  				continue
   101  			}
   102  
   103  			assert.Equalf(t, `"`, a, "Format: %v", format)
   104  			assert.Equalf(t, `foo bar`, b, "Format: %v", format)
   105  			assert.Equalf(t, `foo'bar`, c, "Format: %v", format)
   106  			assert.Equalf(t, `baz)bar`, d, "Format: %v", format)
   107  		}
   108  	}
   109  
   110  	// arrays
   111  	{
   112  		var a []string
   113  		var b []int64
   114  
   115  		for _, format := range formats {
   116  			err := conn.QueryRow(context.Background(), `select row(array['foo', 'bar', 'baz'], array[1,2,3])`, pgx.QueryResultFormats{format}).Scan(
   117  				pgtype.CompositeFields{&a, &b},
   118  			)
   119  			if !assert.NoErrorf(t, err, "Format: %v", format) {
   120  				continue
   121  			}
   122  
   123  			assert.EqualValuesf(t, []string{"foo", "bar", "baz"}, a, "Format: %v", format)
   124  			assert.EqualValuesf(t, []int64{1, 2, 3}, b, "Format: %v", format)
   125  		}
   126  	}
   127  
   128  	// Skip nil fields
   129  	{
   130  		var a int32
   131  		var c float64
   132  
   133  		for _, format := range formats {
   134  			err := conn.QueryRow(context.Background(), "select row(1,'hi',2.1)", pgx.QueryResultFormats{format}).Scan(
   135  				pgtype.CompositeFields{&a, nil, &c},
   136  			)
   137  			if !assert.NoErrorf(t, err, "Format: %v", format) {
   138  				continue
   139  			}
   140  
   141  			assert.EqualValuesf(t, 1, a, "Format: %v", format)
   142  			assert.EqualValuesf(t, 2.1, c, "Format: %v", format)
   143  		}
   144  	}
   145  }
   146  
   147  func TestCompositeFieldsEncode(t *testing.T) {
   148  	conn := testutil.MustConnectPgx(t)
   149  	defer testutil.MustCloseContext(t, conn)
   150  
   151  	_, err := conn.Exec(context.Background(), `drop type if exists cf_encode;
   152  
   153  create type cf_encode as (
   154  	a text,
   155    b int4,
   156  	c text,
   157  	d float8,
   158  	e text
   159  );`)
   160  	require.NoError(t, err)
   161  	defer conn.Exec(context.Background(), "drop type cf_encode")
   162  
   163  	// Use simple protocol to force text or binary encoding
   164  	simpleProtocols := []bool{true, false}
   165  
   166  	// Assorted values
   167  	{
   168  		var a string
   169  		var b int32
   170  		var c string
   171  		var d float64
   172  		var e string
   173  
   174  		for _, simpleProtocol := range simpleProtocols {
   175  			err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
   176  				pgtype.CompositeFields{"hi", int32(1), "ok", float64(2.1), "bye"},
   177  			).Scan(
   178  				pgtype.CompositeFields{&a, &b, &c, &d, &e},
   179  			)
   180  			if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
   181  				assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
   182  				assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
   183  				assert.EqualValuesf(t, "ok", c, "Simple Protocol: %v", simpleProtocol)
   184  				assert.EqualValuesf(t, 2.1, d, "Simple Protocol: %v", simpleProtocol)
   185  				assert.EqualValuesf(t, "bye", e, "Simple Protocol: %v", simpleProtocol)
   186  			}
   187  		}
   188  	}
   189  
   190  	// untyped nil
   191  	{
   192  		var a pgtype.Text
   193  		var b int32
   194  		var c string
   195  		var d pgtype.Float8
   196  		var e pgtype.Text
   197  
   198  		simpleProtocol := true
   199  		err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
   200  			pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
   201  		).Scan(
   202  			pgtype.CompositeFields{&a, &b, &c, &d, &e},
   203  		)
   204  		if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
   205  			assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
   206  			assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
   207  			assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
   208  			assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
   209  			assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
   210  		}
   211  
   212  		// untyped nil cannot be represented in binary format because CompositeFields does not know the PostgreSQL schema
   213  		// of the composite type.
   214  		simpleProtocol = false
   215  		err = conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
   216  			pgtype.CompositeFields{nil, int32(1), "null", nil, nil},
   217  		).Scan(
   218  			pgtype.CompositeFields{&a, &b, &c, &d, &e},
   219  		)
   220  		assert.Errorf(t, err, "Simple Protocol: %v", simpleProtocol)
   221  	}
   222  
   223  	// nulls, string "null", and empty string fields
   224  	{
   225  		var a pgtype.Text
   226  		var b int32
   227  		var c string
   228  		var d pgtype.Float8
   229  		var e pgtype.Text
   230  
   231  		for _, simpleProtocol := range simpleProtocols {
   232  			err := conn.QueryRow(context.Background(), "select $1::cf_encode", pgx.QuerySimpleProtocol(simpleProtocol),
   233  				pgtype.CompositeFields{&pgtype.Text{Status: pgtype.Null}, int32(1), "null", &pgtype.Float8{Status: pgtype.Null}, &pgtype.Text{Status: pgtype.Null}},
   234  			).Scan(
   235  				pgtype.CompositeFields{&a, &b, &c, &d, &e},
   236  			)
   237  			if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
   238  				assert.Nilf(t, a.Get(), "Simple Protocol: %v", simpleProtocol)
   239  				assert.EqualValuesf(t, 1, b, "Simple Protocol: %v", simpleProtocol)
   240  				assert.EqualValuesf(t, "null", c, "Simple Protocol: %v", simpleProtocol)
   241  				assert.Nilf(t, d.Get(), "Simple Protocol: %v", simpleProtocol)
   242  				assert.Nilf(t, e.Get(), "Simple Protocol: %v", simpleProtocol)
   243  			}
   244  		}
   245  	}
   246  
   247  	// quotes and special characters
   248  	{
   249  		var a string
   250  		var b int32
   251  		var c string
   252  		var d float64
   253  		var e string
   254  
   255  		for _, simpleProtocol := range simpleProtocols {
   256  			err := conn.QueryRow(
   257  				context.Background(),
   258  				`select $1::cf_encode`,
   259  				pgx.QuerySimpleProtocol(simpleProtocol),
   260  				pgtype.CompositeFields{`"`, int32(42), `foo'bar`, float64(1.2), `baz)bar`},
   261  			).Scan(
   262  				pgtype.CompositeFields{&a, &b, &c, &d, &e},
   263  			)
   264  			if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
   265  				assert.Equalf(t, `"`, a, "Simple Protocol: %v", simpleProtocol)
   266  				assert.Equalf(t, int32(42), b, "Simple Protocol: %v", simpleProtocol)
   267  				assert.Equalf(t, `foo'bar`, c, "Simple Protocol: %v", simpleProtocol)
   268  				assert.Equalf(t, float64(1.2), d, "Simple Protocol: %v", simpleProtocol)
   269  				assert.Equalf(t, `baz)bar`, e, "Simple Protocol: %v", simpleProtocol)
   270  			}
   271  		}
   272  	}
   273  }
   274  

View as plain text