...

Source file src/github.com/jackc/pgx/v4/values_test.go

Documentation: github.com/jackc/pgx/v4

     1  package pgx_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net"
     7  	"os"
     8  	"reflect"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/jackc/pgx/v4"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func TestDateTranscode(t *testing.T) {
    19  	t.Parallel()
    20  
    21  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
    22  		dates := []time.Time{
    23  			time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
    24  			time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC),
    25  			time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC),
    26  			time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC),
    27  			time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC),
    28  			time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC),
    29  			time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC),
    30  			time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC),
    31  			time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC),
    32  			time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC),
    33  			time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC),
    34  			time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC),
    35  			time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC),
    36  			time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC),
    37  			time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC),
    38  			time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC),
    39  			time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC),
    40  			time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC),
    41  		}
    42  
    43  		for _, actualDate := range dates {
    44  			var d time.Time
    45  
    46  			err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d)
    47  			if err != nil {
    48  				t.Fatalf("Unexpected failure on QueryRow Scan: %v", err)
    49  			}
    50  			if !actualDate.Equal(d) {
    51  				t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate)
    52  			}
    53  		}
    54  	})
    55  }
    56  
    57  func TestTimestampTzTranscode(t *testing.T) {
    58  	t.Parallel()
    59  
    60  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
    61  		inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local)
    62  
    63  		var outputTime time.Time
    64  
    65  		err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime)
    66  		if err != nil {
    67  			t.Fatalf("QueryRow Scan failed: %v", err)
    68  		}
    69  		if !inputTime.Equal(outputTime) {
    70  			t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime)
    71  		}
    72  	})
    73  }
    74  
    75  // TODO - move these tests to pgtype
    76  
    77  func TestJSONAndJSONBTranscode(t *testing.T) {
    78  	t.Parallel()
    79  
    80  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
    81  		for _, typename := range []string{"json", "jsonb"} {
    82  			if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok {
    83  				continue // No JSON/JSONB type -- must be running against old PostgreSQL
    84  			}
    85  
    86  			testJSONString(t, conn, typename)
    87  			testJSONStringPointer(t, conn, typename)
    88  		}
    89  	})
    90  }
    91  
    92  func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) {
    93  	t.Parallel()
    94  
    95  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
    96  	defer closeConn(t, conn)
    97  
    98  	for _, typename := range []string{"json", "jsonb"} {
    99  		if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok {
   100  			continue // No JSON/JSONB type -- must be running against old PostgreSQL
   101  		}
   102  		testJSONSingleLevelStringMap(t, conn, typename)
   103  		testJSONNestedMap(t, conn, typename)
   104  		testJSONStringArray(t, conn, typename)
   105  		testJSONInt64Array(t, conn, typename)
   106  		testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
   107  		testJSONStruct(t, conn, typename)
   108  	}
   109  
   110  }
   111  
   112  func testJSONString(t *testing.T, conn *pgx.Conn, typename string) {
   113  	input := `{"key": "value"}`
   114  	expectedOutput := map[string]string{"key": "value"}
   115  	var output map[string]string
   116  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   117  	if err != nil {
   118  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   119  		return
   120  	}
   121  
   122  	if !reflect.DeepEqual(expectedOutput, output) {
   123  		t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
   124  		return
   125  	}
   126  }
   127  
   128  func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) {
   129  	input := `{"key": "value"}`
   130  	expectedOutput := map[string]string{"key": "value"}
   131  	var output map[string]string
   132  	err := conn.QueryRow(context.Background(), "select $1::"+typename, &input).Scan(&output)
   133  	if err != nil {
   134  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   135  		return
   136  	}
   137  
   138  	if !reflect.DeepEqual(expectedOutput, output) {
   139  		t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output)
   140  		return
   141  	}
   142  }
   143  
   144  func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) {
   145  	input := map[string]string{"key": "value"}
   146  	var output map[string]string
   147  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   148  	if err != nil {
   149  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   150  		return
   151  	}
   152  
   153  	if !reflect.DeepEqual(input, output) {
   154  		t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output)
   155  		return
   156  	}
   157  }
   158  
   159  func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) {
   160  	input := map[string]interface{}{
   161  		"name":      "Uncanny",
   162  		"stats":     map[string]interface{}{"hp": float64(107), "maxhp": float64(150)},
   163  		"inventory": []interface{}{"phone", "key"},
   164  	}
   165  	var output map[string]interface{}
   166  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   167  	if err != nil {
   168  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   169  		return
   170  	}
   171  
   172  	if !reflect.DeepEqual(input, output) {
   173  		t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output)
   174  		return
   175  	}
   176  }
   177  
   178  func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) {
   179  	input := []string{"foo", "bar", "baz"}
   180  	var output []string
   181  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   182  	if err != nil {
   183  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   184  	}
   185  
   186  	if !reflect.DeepEqual(input, output) {
   187  		t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output)
   188  	}
   189  }
   190  
   191  func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) {
   192  	input := []int64{1, 2, 234432}
   193  	var output []int64
   194  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   195  	if err != nil {
   196  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   197  	}
   198  
   199  	if !reflect.DeepEqual(input, output) {
   200  		t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output)
   201  	}
   202  }
   203  
   204  func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) {
   205  	input := []int{1, 2, 234432}
   206  	var output []int16
   207  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   208  	if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
   209  		t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err)
   210  	}
   211  }
   212  
   213  func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) {
   214  	type person struct {
   215  		Name string `json:"name"`
   216  		Age  int    `json:"age"`
   217  	}
   218  
   219  	input := person{
   220  		Name: "John",
   221  		Age:  42,
   222  	}
   223  
   224  	var output person
   225  
   226  	err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
   227  	if err != nil {
   228  		t.Errorf("%s: QueryRow Scan failed: %v", typename, err)
   229  	}
   230  
   231  	if !reflect.DeepEqual(input, output) {
   232  		t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output)
   233  	}
   234  }
   235  
   236  func mustParseCIDR(t *testing.T, s string) *net.IPNet {
   237  	_, ipnet, err := net.ParseCIDR(s)
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  
   242  	return ipnet
   243  }
   244  
   245  func TestStringToNotTextTypeTranscode(t *testing.T) {
   246  	t.Parallel()
   247  
   248  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   249  		input := "01086ee0-4963-4e35-9116-30c173a8d0bd"
   250  
   251  		var output string
   252  		err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output)
   253  		if err != nil {
   254  			t.Fatal(err)
   255  		}
   256  		if input != output {
   257  			t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output)
   258  		}
   259  
   260  		err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output)
   261  		if err != nil {
   262  			t.Fatal(err)
   263  		}
   264  		if input != output {
   265  			t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output)
   266  		}
   267  	})
   268  }
   269  
   270  func TestInetCIDRTranscodeIPNet(t *testing.T) {
   271  	t.Parallel()
   272  
   273  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   274  		tests := []struct {
   275  			sql   string
   276  			value *net.IPNet
   277  		}{
   278  			{"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")},
   279  			{"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")},
   280  			{"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")},
   281  			{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
   282  			{"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")},
   283  			{"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")},
   284  			{"select $1::inet", mustParseCIDR(t, "::/128")},
   285  			{"select $1::inet", mustParseCIDR(t, "::/0")},
   286  			{"select $1::inet", mustParseCIDR(t, "::1/128")},
   287  			{"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
   288  			{"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")},
   289  			{"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")},
   290  			{"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")},
   291  			{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
   292  			{"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")},
   293  			{"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")},
   294  			{"select $1::cidr", mustParseCIDR(t, "::/128")},
   295  			{"select $1::cidr", mustParseCIDR(t, "::/0")},
   296  			{"select $1::cidr", mustParseCIDR(t, "::1/128")},
   297  			{"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
   298  		}
   299  
   300  		for i, tt := range tests {
   301  			if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") {
   302  				t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)")
   303  				continue
   304  			}
   305  
   306  			var actual net.IPNet
   307  
   308  			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
   309  			if err != nil {
   310  				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
   311  				continue
   312  			}
   313  
   314  			if actual.String() != tt.value.String() {
   315  				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
   316  			}
   317  		}
   318  	})
   319  }
   320  
   321  func TestInetCIDRTranscodeIP(t *testing.T) {
   322  	t.Parallel()
   323  
   324  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   325  		tests := []struct {
   326  			sql   string
   327  			value net.IP
   328  		}{
   329  			{"select $1::inet", net.ParseIP("0.0.0.0")},
   330  			{"select $1::inet", net.ParseIP("127.0.0.1")},
   331  			{"select $1::inet", net.ParseIP("12.34.56.0")},
   332  			{"select $1::inet", net.ParseIP("255.255.255.255")},
   333  			{"select $1::inet", net.ParseIP("::1")},
   334  			{"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")},
   335  			{"select $1::cidr", net.ParseIP("0.0.0.0")},
   336  			{"select $1::cidr", net.ParseIP("127.0.0.1")},
   337  			{"select $1::cidr", net.ParseIP("12.34.56.0")},
   338  			{"select $1::cidr", net.ParseIP("255.255.255.255")},
   339  			{"select $1::cidr", net.ParseIP("::1")},
   340  			{"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")},
   341  		}
   342  
   343  		for i, tt := range tests {
   344  			if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") {
   345  				t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)")
   346  				continue
   347  			}
   348  
   349  			var actual net.IP
   350  
   351  			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
   352  			if err != nil {
   353  				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
   354  				continue
   355  			}
   356  
   357  			if !actual.Equal(tt.value) {
   358  				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
   359  			}
   360  
   361  			ensureConnValid(t, conn)
   362  		}
   363  
   364  		failTests := []struct {
   365  			sql   string
   366  			value *net.IPNet
   367  		}{
   368  			{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
   369  			{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
   370  		}
   371  		for i, tt := range failTests {
   372  			var actual net.IP
   373  
   374  			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
   375  			if err == nil {
   376  				t.Errorf("%d. Expected failure but got none", i)
   377  				continue
   378  			}
   379  
   380  			ensureConnValid(t, conn)
   381  		}
   382  	})
   383  }
   384  
   385  func TestInetCIDRArrayTranscodeIPNet(t *testing.T) {
   386  	t.Parallel()
   387  
   388  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   389  		tests := []struct {
   390  			sql   string
   391  			value []*net.IPNet
   392  		}{
   393  			{
   394  				"select $1::inet[]",
   395  				[]*net.IPNet{
   396  					mustParseCIDR(t, "0.0.0.0/32"),
   397  					mustParseCIDR(t, "127.0.0.1/32"),
   398  					mustParseCIDR(t, "12.34.56.0/32"),
   399  					mustParseCIDR(t, "192.168.1.0/24"),
   400  					mustParseCIDR(t, "255.0.0.0/8"),
   401  					mustParseCIDR(t, "255.255.255.255/32"),
   402  					mustParseCIDR(t, "::/128"),
   403  					mustParseCIDR(t, "::/0"),
   404  					mustParseCIDR(t, "::1/128"),
   405  					mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
   406  				},
   407  			},
   408  			{
   409  				"select $1::cidr[]",
   410  				[]*net.IPNet{
   411  					mustParseCIDR(t, "0.0.0.0/32"),
   412  					mustParseCIDR(t, "127.0.0.1/32"),
   413  					mustParseCIDR(t, "12.34.56.0/32"),
   414  					mustParseCIDR(t, "192.168.1.0/24"),
   415  					mustParseCIDR(t, "255.0.0.0/8"),
   416  					mustParseCIDR(t, "255.255.255.255/32"),
   417  					mustParseCIDR(t, "::/128"),
   418  					mustParseCIDR(t, "::/0"),
   419  					mustParseCIDR(t, "::1/128"),
   420  					mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"),
   421  				},
   422  			},
   423  		}
   424  
   425  		for i, tt := range tests {
   426  			if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") {
   427  				t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)")
   428  				continue
   429  			}
   430  
   431  			var actual []*net.IPNet
   432  
   433  			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
   434  			if err != nil {
   435  				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
   436  				continue
   437  			}
   438  
   439  			if !reflect.DeepEqual(actual, tt.value) {
   440  				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
   441  			}
   442  
   443  			ensureConnValid(t, conn)
   444  		}
   445  	})
   446  }
   447  
   448  func TestInetCIDRArrayTranscodeIP(t *testing.T) {
   449  	t.Parallel()
   450  
   451  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   452  		tests := []struct {
   453  			sql   string
   454  			value []net.IP
   455  		}{
   456  			{
   457  				"select $1::inet[]",
   458  				[]net.IP{
   459  					net.ParseIP("0.0.0.0"),
   460  					net.ParseIP("127.0.0.1"),
   461  					net.ParseIP("12.34.56.0"),
   462  					net.ParseIP("255.255.255.255"),
   463  					net.ParseIP("2607:f8b0:4009:80b::200e"),
   464  				},
   465  			},
   466  			{
   467  				"select $1::cidr[]",
   468  				[]net.IP{
   469  					net.ParseIP("0.0.0.0"),
   470  					net.ParseIP("127.0.0.1"),
   471  					net.ParseIP("12.34.56.0"),
   472  					net.ParseIP("255.255.255.255"),
   473  					net.ParseIP("2607:f8b0:4009:80b::200e"),
   474  				},
   475  			},
   476  		}
   477  
   478  		for i, tt := range tests {
   479  			if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") {
   480  				t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)")
   481  				continue
   482  			}
   483  
   484  			var actual []net.IP
   485  
   486  			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
   487  			if err != nil {
   488  				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
   489  				continue
   490  			}
   491  
   492  			assert.Equal(t, len(tt.value), len(actual), "%d", i)
   493  			for j := range actual {
   494  				assert.True(t, actual[j].Equal(tt.value[j]), "%d", i)
   495  			}
   496  
   497  			ensureConnValid(t, conn)
   498  		}
   499  
   500  		failTests := []struct {
   501  			sql   string
   502  			value []*net.IPNet
   503  		}{
   504  			{
   505  				"select $1::inet[]",
   506  				[]*net.IPNet{
   507  					mustParseCIDR(t, "12.34.56.0/32"),
   508  					mustParseCIDR(t, "192.168.1.0/24"),
   509  				},
   510  			},
   511  			{
   512  				"select $1::cidr[]",
   513  				[]*net.IPNet{
   514  					mustParseCIDR(t, "12.34.56.0/32"),
   515  					mustParseCIDR(t, "192.168.1.0/24"),
   516  				},
   517  			},
   518  		}
   519  
   520  		for i, tt := range failTests {
   521  			var actual []net.IP
   522  
   523  			err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual)
   524  			if err == nil {
   525  				t.Errorf("%d. Expected failure but got none", i)
   526  				continue
   527  			}
   528  
   529  			ensureConnValid(t, conn)
   530  		}
   531  	})
   532  }
   533  
   534  func TestInetCIDRTranscodeWithJustIP(t *testing.T) {
   535  	t.Parallel()
   536  
   537  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   538  		tests := []struct {
   539  			sql   string
   540  			value string
   541  		}{
   542  			{"select $1::inet", "0.0.0.0/32"},
   543  			{"select $1::inet", "127.0.0.1/32"},
   544  			{"select $1::inet", "12.34.56.0/32"},
   545  			{"select $1::inet", "255.255.255.255/32"},
   546  			{"select $1::inet", "::/128"},
   547  			{"select $1::inet", "2607:f8b0:4009:80b::200e/128"},
   548  			{"select $1::cidr", "0.0.0.0/32"},
   549  			{"select $1::cidr", "127.0.0.1/32"},
   550  			{"select $1::cidr", "12.34.56.0/32"},
   551  			{"select $1::cidr", "255.255.255.255/32"},
   552  			{"select $1::cidr", "::/128"},
   553  			{"select $1::cidr", "2607:f8b0:4009:80b::200e/128"},
   554  		}
   555  
   556  		for i, tt := range tests {
   557  			if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") {
   558  				t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)")
   559  				continue
   560  			}
   561  
   562  			expected := mustParseCIDR(t, tt.value)
   563  			var actual net.IPNet
   564  
   565  			err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual)
   566  			if err != nil {
   567  				t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
   568  				continue
   569  			}
   570  
   571  			if actual.String() != expected.String() {
   572  				t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
   573  			}
   574  
   575  			ensureConnValid(t, conn)
   576  		}
   577  	})
   578  }
   579  
   580  func TestArrayDecoding(t *testing.T) {
   581  	t.Parallel()
   582  
   583  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   584  		tests := []struct {
   585  			sql    string
   586  			query  interface{}
   587  			scan   interface{}
   588  			assert func(*testing.T, interface{}, interface{})
   589  		}{
   590  			{
   591  				"select $1::bool[]", []bool{true, false, true}, &[]bool{},
   592  				func(t *testing.T, query, scan interface{}) {
   593  					if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
   594  						t.Errorf("failed to encode bool[]")
   595  					}
   596  				},
   597  			},
   598  			{
   599  				"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
   600  				func(t *testing.T, query, scan interface{}) {
   601  					if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
   602  						t.Errorf("failed to encode smallint[]")
   603  					}
   604  				},
   605  			},
   606  			{
   607  				"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
   608  				func(t *testing.T, query, scan interface{}) {
   609  					if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
   610  						t.Errorf("failed to encode smallint[]")
   611  					}
   612  				},
   613  			},
   614  			{
   615  				"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
   616  				func(t *testing.T, query, scan interface{}) {
   617  					if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
   618  						t.Errorf("failed to encode int[]")
   619  					}
   620  				},
   621  			},
   622  			{
   623  				"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
   624  				func(t *testing.T, query, scan interface{}) {
   625  					if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
   626  						t.Errorf("failed to encode int[]")
   627  					}
   628  				},
   629  			},
   630  			{
   631  				"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
   632  				func(t *testing.T, query, scan interface{}) {
   633  					if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
   634  						t.Errorf("failed to encode bigint[]")
   635  					}
   636  				},
   637  			},
   638  			{
   639  				"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
   640  				func(t *testing.T, query, scan interface{}) {
   641  					if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
   642  						t.Errorf("failed to encode bigint[]")
   643  					}
   644  				},
   645  			},
   646  			{
   647  				"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
   648  				func(t *testing.T, query, scan interface{}) {
   649  					if !reflect.DeepEqual(query, *(scan.(*[]string))) {
   650  						t.Errorf("failed to encode text[]")
   651  					}
   652  				},
   653  			},
   654  			{
   655  				"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
   656  				func(t *testing.T, query, scan interface{}) {
   657  					queryTimeSlice := query.([]time.Time)
   658  					scanTimeSlice := *(scan.(*[]time.Time))
   659  					require.Equal(t, len(queryTimeSlice), len(scanTimeSlice))
   660  					for i := range queryTimeSlice {
   661  						assert.Truef(t, queryTimeSlice[i].Equal(scanTimeSlice[i]), "%d", i)
   662  					}
   663  				},
   664  			},
   665  			{
   666  				"select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{},
   667  				func(t *testing.T, query, scan interface{}) {
   668  					queryBytesSliceSlice := query.([][]byte)
   669  					scanBytesSliceSlice := *(scan.(*[][]byte))
   670  					if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) {
   671  						t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice))
   672  					}
   673  					for i := range queryBytesSliceSlice {
   674  						qb := queryBytesSliceSlice[i]
   675  						sb := scanBytesSliceSlice[i]
   676  						if !bytes.Equal(qb, sb) {
   677  							t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb)
   678  						}
   679  					}
   680  				},
   681  			},
   682  		}
   683  
   684  		for i, tt := range tests {
   685  			err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan)
   686  			if err != nil {
   687  				t.Errorf(`%d. error reading array: %v`, i, err)
   688  				continue
   689  			}
   690  			tt.assert(t, tt.query, tt.scan)
   691  			ensureConnValid(t, conn)
   692  		}
   693  	})
   694  }
   695  
   696  func TestEmptyArrayDecoding(t *testing.T) {
   697  	t.Parallel()
   698  
   699  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   700  		var val []string
   701  
   702  		err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val)
   703  		if err != nil {
   704  			t.Errorf(`error reading array: %v`, err)
   705  		}
   706  		if len(val) != 0 {
   707  			t.Errorf("Expected 0 values, got %d", len(val))
   708  		}
   709  
   710  		var n, m int32
   711  
   712  		err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m)
   713  		if err != nil {
   714  			t.Errorf(`error reading array: %v`, err)
   715  		}
   716  		if len(val) != 0 {
   717  			t.Errorf("Expected 0 values, got %d", len(val))
   718  		}
   719  		if n != 1 {
   720  			t.Errorf("Expected n to be 1, but it was %d", n)
   721  		}
   722  		if m != 42 {
   723  			t.Errorf("Expected n to be 42, but it was %d", n)
   724  		}
   725  
   726  		rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]")
   727  		if err != nil {
   728  			t.Errorf(`error retrieving rows with array: %v`, err)
   729  		}
   730  		defer rows.Close()
   731  
   732  		for rows.Next() {
   733  			err = rows.Scan(&n, &val)
   734  			if err != nil {
   735  				t.Errorf(`error reading array: %v`, err)
   736  			}
   737  		}
   738  	})
   739  }
   740  
   741  func TestPointerPointer(t *testing.T) {
   742  	t.Parallel()
   743  
   744  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   745  		skipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types")
   746  
   747  		type allTypes struct {
   748  			s   *string
   749  			i16 *int16
   750  			i32 *int32
   751  			i64 *int64
   752  			f32 *float32
   753  			f64 *float64
   754  			b   *bool
   755  			t   *time.Time
   756  		}
   757  
   758  		var actual, zero, expected allTypes
   759  
   760  		{
   761  			s := "foo"
   762  			expected.s = &s
   763  			i16 := int16(1)
   764  			expected.i16 = &i16
   765  			i32 := int32(1)
   766  			expected.i32 = &i32
   767  			i64 := int64(1)
   768  			expected.i64 = &i64
   769  			f32 := float32(1.23)
   770  			expected.f32 = &f32
   771  			f64 := float64(1.23)
   772  			expected.f64 = &f64
   773  			b := true
   774  			expected.b = &b
   775  			t := time.Unix(123, 5000)
   776  			expected.t = &t
   777  		}
   778  
   779  		tests := []struct {
   780  			sql       string
   781  			queryArgs []interface{}
   782  			scanArgs  []interface{}
   783  			expected  allTypes
   784  		}{
   785  			{"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}},
   786  			{"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}},
   787  			{"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}},
   788  			{"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}},
   789  			{"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}},
   790  			{"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}},
   791  			{"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}},
   792  			{"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}},
   793  			{"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}},
   794  			{"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}},
   795  			{"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}},
   796  			{"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}},
   797  			{"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}},
   798  			{"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}},
   799  			{"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}},
   800  			{"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}},
   801  		}
   802  
   803  		for i, tt := range tests {
   804  			actual = zero
   805  
   806  			err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...)
   807  			if err != nil {
   808  				t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs)
   809  			}
   810  
   811  			assert.Equal(t, tt.expected.s, actual.s)
   812  			assert.Equal(t, tt.expected.i16, actual.i16)
   813  			assert.Equal(t, tt.expected.i32, actual.i32)
   814  			assert.Equal(t, tt.expected.i64, actual.i64)
   815  			assert.Equal(t, tt.expected.f32, actual.f32)
   816  			assert.Equal(t, tt.expected.f64, actual.f64)
   817  			assert.Equal(t, tt.expected.b, actual.b)
   818  			if tt.expected.t != nil || actual.t != nil {
   819  				assert.True(t, tt.expected.t.Equal(*actual.t))
   820  			}
   821  
   822  			ensureConnValid(t, conn)
   823  		}
   824  	})
   825  }
   826  
   827  func TestPointerPointerNonZero(t *testing.T) {
   828  	t.Parallel()
   829  
   830  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   831  		f := "foo"
   832  		dest := &f
   833  
   834  		err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest)
   835  		if err != nil {
   836  			t.Errorf("Unexpected failure scanning: %v", err)
   837  		}
   838  		if dest != nil {
   839  			t.Errorf("Expected dest to be nil, got %#v", dest)
   840  		}
   841  	})
   842  }
   843  
   844  func TestEncodeTypeRename(t *testing.T) {
   845  	t.Parallel()
   846  
   847  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   848  		type _int int
   849  		inInt := _int(1)
   850  		var outInt _int
   851  
   852  		type _int8 int8
   853  		inInt8 := _int8(2)
   854  		var outInt8 _int8
   855  
   856  		type _int16 int16
   857  		inInt16 := _int16(3)
   858  		var outInt16 _int16
   859  
   860  		type _int32 int32
   861  		inInt32 := _int32(4)
   862  		var outInt32 _int32
   863  
   864  		type _int64 int64
   865  		inInt64 := _int64(5)
   866  		var outInt64 _int64
   867  
   868  		type _uint uint
   869  		inUint := _uint(6)
   870  		var outUint _uint
   871  
   872  		type _uint8 uint8
   873  		inUint8 := _uint8(7)
   874  		var outUint8 _uint8
   875  
   876  		type _uint16 uint16
   877  		inUint16 := _uint16(8)
   878  		var outUint16 _uint16
   879  
   880  		type _uint32 uint32
   881  		inUint32 := _uint32(9)
   882  		var outUint32 _uint32
   883  
   884  		type _uint64 uint64
   885  		inUint64 := _uint64(10)
   886  		var outUint64 _uint64
   887  
   888  		type _string string
   889  		inString := _string("foo")
   890  		var outString _string
   891  
   892  		err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
   893  			inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString,
   894  		).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)
   895  		if err != nil {
   896  			t.Fatalf("Failed with type rename: %v", err)
   897  		}
   898  
   899  		if inInt != outInt {
   900  			t.Errorf("int rename: expected %v, got %v", inInt, outInt)
   901  		}
   902  
   903  		if inInt8 != outInt8 {
   904  			t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8)
   905  		}
   906  
   907  		if inInt16 != outInt16 {
   908  			t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16)
   909  		}
   910  
   911  		if inInt32 != outInt32 {
   912  			t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32)
   913  		}
   914  
   915  		if inInt64 != outInt64 {
   916  			t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64)
   917  		}
   918  
   919  		if inUint != outUint {
   920  			t.Errorf("uint rename: expected %v, got %v", inUint, outUint)
   921  		}
   922  
   923  		if inUint8 != outUint8 {
   924  			t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8)
   925  		}
   926  
   927  		if inUint16 != outUint16 {
   928  			t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16)
   929  		}
   930  
   931  		if inUint32 != outUint32 {
   932  			t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32)
   933  		}
   934  
   935  		if inUint64 != outUint64 {
   936  			t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64)
   937  		}
   938  
   939  		if inString != outString {
   940  			t.Errorf("string rename: expected %v, got %v", inString, outString)
   941  		}
   942  	})
   943  }
   944  
   945  func TestRowDecodeBinary(t *testing.T) {
   946  	t.Parallel()
   947  
   948  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   949  	defer closeConn(t, conn)
   950  
   951  	tests := []struct {
   952  		sql      string
   953  		expected []interface{}
   954  	}{
   955  		{
   956  			"select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)",
   957  			[]interface{}{
   958  				int32(1),
   959  				"cat",
   960  				time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(),
   961  			},
   962  		},
   963  		{
   964  			"select row(100.0::float, 1.09::float)",
   965  			[]interface{}{
   966  				float64(100),
   967  				float64(1.09),
   968  			},
   969  		},
   970  	}
   971  
   972  	for i, tt := range tests {
   973  		var actual []interface{}
   974  
   975  		err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual)
   976  		if err != nil {
   977  			t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
   978  			continue
   979  		}
   980  
   981  		for j := range tt.expected {
   982  			assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. [%d]", i, j)
   983  
   984  		}
   985  
   986  		ensureConnValid(t, conn)
   987  	}
   988  }
   989  
   990  // https://github.com/jackc/pgx/issues/810
   991  func TestRowsScanNilThenScanValue(t *testing.T) {
   992  	t.Parallel()
   993  
   994  	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
   995  		sql := `select null as a, null as b
   996  union
   997  select 1, 2
   998  order by a nulls first
   999  `
  1000  		rows, err := conn.Query(context.Background(), sql)
  1001  		require.NoError(t, err)
  1002  
  1003  		require.True(t, rows.Next())
  1004  
  1005  		err = rows.Scan(nil, nil)
  1006  		require.NoError(t, err)
  1007  
  1008  		require.True(t, rows.Next())
  1009  
  1010  		var a int
  1011  		var b int
  1012  		err = rows.Scan(&a, &b)
  1013  		require.NoError(t, err)
  1014  
  1015  		require.EqualValues(t, 1, a)
  1016  		require.EqualValues(t, 2, b)
  1017  
  1018  		rows.Close()
  1019  		require.NoError(t, rows.Err())
  1020  	})
  1021  }
  1022  
  1023  func TestScanIntoByteSlice(t *testing.T) {
  1024  	t.Parallel()
  1025  
  1026  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
  1027  	defer closeConn(t, conn)
  1028  	// Success cases
  1029  	for _, tt := range []struct {
  1030  		name             string
  1031  		sql              string
  1032  		resultFormatCode int16
  1033  		output           []byte
  1034  	}{
  1035  		{"int - text", "select 42", pgx.TextFormatCode, []byte("42")},
  1036  		{"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")},
  1037  		{"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")},
  1038  		{"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")},
  1039  		{"json - binary", "select '{}'::json", pgx.BinaryFormatCode, []byte("{}")},
  1040  		{"jsonb - text", "select '{}'::jsonb", pgx.TextFormatCode, []byte("{}")},
  1041  		{"jsonb - binary", "select '{}'::jsonb", pgx.BinaryFormatCode, []byte("{}")},
  1042  	} {
  1043  		t.Run(tt.name, func(t *testing.T) {
  1044  			var buf []byte
  1045  			err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{tt.resultFormatCode}).Scan(&buf)
  1046  			require.NoError(t, err)
  1047  			require.Equal(t, tt.output, buf)
  1048  		})
  1049  	}
  1050  
  1051  	// Failure cases
  1052  	for _, tt := range []struct {
  1053  		name string
  1054  		sql  string
  1055  		err  string
  1056  	}{
  1057  		{"int binary", "select 42", "can't scan into dest[0]: cannot assign 42 into *[]uint8"},
  1058  	} {
  1059  		t.Run(tt.name, func(t *testing.T) {
  1060  			var buf []byte
  1061  			err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&buf)
  1062  			require.EqualError(t, err, tt.err)
  1063  		})
  1064  	}
  1065  }
  1066  

View as plain text