...

Source file src/github.com/lib/pq/conn_test.go

Documentation: github.com/lib/pq

     1  package pq
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"os"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  type Fatalistic interface {
    19  	Fatal(args ...interface{})
    20  }
    21  
    22  func forceBinaryParameters() bool {
    23  	bp := os.Getenv("PQTEST_BINARY_PARAMETERS")
    24  	if bp == "yes" {
    25  		return true
    26  	} else if bp == "" || bp == "no" {
    27  		return false
    28  	} else {
    29  		panic("unexpected value for PQTEST_BINARY_PARAMETERS")
    30  	}
    31  }
    32  
    33  func testConninfo(conninfo string) string {
    34  	defaultTo := func(envvar string, value string) {
    35  		if os.Getenv(envvar) == "" {
    36  			os.Setenv(envvar, value)
    37  		}
    38  	}
    39  	defaultTo("PGDATABASE", "pqgotest")
    40  	defaultTo("PGSSLMODE", "disable")
    41  	defaultTo("PGCONNECT_TIMEOUT", "20")
    42  
    43  	if forceBinaryParameters() &&
    44  		!strings.HasPrefix(conninfo, "postgres://") &&
    45  		!strings.HasPrefix(conninfo, "postgresql://") {
    46  		conninfo += " binary_parameters=yes"
    47  	}
    48  	return conninfo
    49  }
    50  
    51  func openTestConnConninfo(conninfo string) (*sql.DB, error) {
    52  	return sql.Open("postgres", testConninfo(conninfo))
    53  }
    54  
    55  func openTestConn(t Fatalistic) *sql.DB {
    56  	conn, err := openTestConnConninfo("")
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  
    61  	return conn
    62  }
    63  
    64  func getServerVersion(t *testing.T, db *sql.DB) int {
    65  	var version int
    66  	err := db.QueryRow("SHOW server_version_num").Scan(&version)
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	return version
    71  }
    72  
    73  func TestReconnect(t *testing.T) {
    74  	db1 := openTestConn(t)
    75  	defer db1.Close()
    76  	tx, err := db1.Begin()
    77  	if err != nil {
    78  		t.Fatal(err)
    79  	}
    80  	var pid1 int
    81  	err = tx.QueryRow("SELECT pg_backend_pid()").Scan(&pid1)
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	db2 := openTestConn(t)
    86  	defer db2.Close()
    87  	_, err = db2.Exec("SELECT pg_terminate_backend($1)", pid1)
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  	// The rollback will probably "fail" because we just killed
    92  	// its connection above
    93  	_ = tx.Rollback()
    94  
    95  	const expected int = 42
    96  	var result int
    97  	err = db1.QueryRow(fmt.Sprintf("SELECT %d", expected)).Scan(&result)
    98  	if err != nil {
    99  		t.Fatal(err)
   100  	}
   101  	if result != expected {
   102  		t.Errorf("got %v; expected %v", result, expected)
   103  	}
   104  }
   105  
   106  func TestCommitInFailedTransaction(t *testing.T) {
   107  	db := openTestConn(t)
   108  	defer db.Close()
   109  
   110  	txn, err := db.Begin()
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  	rows, err := txn.Query("SELECT error")
   115  	if err == nil {
   116  		rows.Close()
   117  		t.Fatal("expected failure")
   118  	}
   119  	err = txn.Commit()
   120  	if err != ErrInFailedTransaction {
   121  		t.Fatalf("expected ErrInFailedTransaction; got %#v", err)
   122  	}
   123  }
   124  
   125  func TestOpenURL(t *testing.T) {
   126  	testURL := func(url string) {
   127  		db, err := openTestConnConninfo(url)
   128  		if err != nil {
   129  			t.Fatal(err)
   130  		}
   131  		defer db.Close()
   132  		// database/sql might not call our Open at all unless we do something with
   133  		// the connection
   134  		txn, err := db.Begin()
   135  		if err != nil {
   136  			t.Fatal(err)
   137  		}
   138  		txn.Rollback()
   139  	}
   140  	testURL("postgres://")
   141  	testURL("postgresql://")
   142  }
   143  
   144  const pgpassFile = "/tmp/pqgotest_pgpass"
   145  
   146  func TestPgpass(t *testing.T) {
   147  	testAssert := func(conninfo string, expected string, reason string) {
   148  		conn, err := openTestConnConninfo(conninfo)
   149  		if err != nil {
   150  			t.Fatal(err)
   151  		}
   152  		defer conn.Close()
   153  
   154  		txn, err := conn.Begin()
   155  		if err != nil {
   156  			if expected != "fail" {
   157  				t.Fatalf(reason, err)
   158  			}
   159  			return
   160  		}
   161  		rows, err := txn.Query("SELECT USER")
   162  		if err != nil {
   163  			txn.Rollback()
   164  			if expected != "fail" {
   165  				t.Fatalf(reason, err)
   166  			}
   167  		} else {
   168  			rows.Close()
   169  			if expected != "ok" {
   170  				t.Fatalf(reason, err)
   171  			}
   172  		}
   173  		txn.Rollback()
   174  	}
   175  	testAssert("", "ok", "missing .pgpass, unexpected error %#v")
   176  	os.Setenv("PGPASSFILE", pgpassFile)
   177  	testAssert("host=/tmp", "fail", ", unexpected error %#v")
   178  	os.Remove(pgpassFile)
   179  	pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644)
   180  	if err != nil {
   181  		t.Fatalf("Unexpected error writing pgpass file %#v", err)
   182  	}
   183  	_, err = pgpass.WriteString(`# comment
   184  server:5432:some_db:some_user:pass_A
   185  *:5432:some_db:some_user:pass_B
   186  localhost:*:*:*:pass_C
   187  *:*:*:*:pass_fallback
   188  `)
   189  	if err != nil {
   190  		t.Fatalf("Unexpected error writing pgpass file %#v", err)
   191  	}
   192  	pgpass.Close()
   193  
   194  	assertPassword := func(extra values, expected string) {
   195  		o := values{
   196  			"host":               "localhost",
   197  			"sslmode":            "disable",
   198  			"connect_timeout":    "20",
   199  			"user":               "majid",
   200  			"port":               "5432",
   201  			"extra_float_digits": "2",
   202  			"dbname":             "pqgotest",
   203  			"client_encoding":    "UTF8",
   204  			"datestyle":          "ISO, MDY",
   205  		}
   206  		for k, v := range extra {
   207  			o[k] = v
   208  		}
   209  		(&conn{}).handlePgpass(o)
   210  		if pw := o["password"]; pw != expected {
   211  			t.Fatalf("For %v expected %s got %s", extra, expected, pw)
   212  		}
   213  	}
   214  	// wrong permissions for the pgpass file means it should be ignored
   215  	assertPassword(values{"host": "example.com", "user": "foo"}, "")
   216  	// fix the permissions and check if it has taken effect
   217  	os.Chmod(pgpassFile, 0600)
   218  	assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A")
   219  	assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback")
   220  	assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B")
   221  	// localhost also matches the default "" and UNIX sockets
   222  	assertPassword(values{"host": "", "user": "some_user"}, "pass_C")
   223  	assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C")
   224  	// cleanup
   225  	os.Remove(pgpassFile)
   226  	os.Setenv("PGPASSFILE", "")
   227  }
   228  
   229  func TestExec(t *testing.T) {
   230  	db := openTestConn(t)
   231  	defer db.Close()
   232  
   233  	_, err := db.Exec("CREATE TEMP TABLE temp (a int)")
   234  	if err != nil {
   235  		t.Fatal(err)
   236  	}
   237  
   238  	r, err := db.Exec("INSERT INTO temp VALUES (1)")
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  
   243  	if n, _ := r.RowsAffected(); n != 1 {
   244  		t.Fatalf("expected 1 row affected, not %d", n)
   245  	}
   246  
   247  	r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3)
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  
   252  	if n, _ := r.RowsAffected(); n != 3 {
   253  		t.Fatalf("expected 3 rows affected, not %d", n)
   254  	}
   255  
   256  	// SELECT doesn't send the number of returned rows in the command tag
   257  	// before 9.0
   258  	if getServerVersion(t, db) >= 90000 {
   259  		r, err = db.Exec("SELECT g FROM generate_series(1, 2) g")
   260  		if err != nil {
   261  			t.Fatal(err)
   262  		}
   263  		if n, _ := r.RowsAffected(); n != 2 {
   264  			t.Fatalf("expected 2 rows affected, not %d", n)
   265  		}
   266  
   267  		r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3)
   268  		if err != nil {
   269  			t.Fatal(err)
   270  		}
   271  		if n, _ := r.RowsAffected(); n != 3 {
   272  			t.Fatalf("expected 3 rows affected, not %d", n)
   273  		}
   274  	}
   275  }
   276  
   277  func TestStatment(t *testing.T) {
   278  	db := openTestConn(t)
   279  	defer db.Close()
   280  
   281  	st, err := db.Prepare("SELECT 1")
   282  	if err != nil {
   283  		t.Fatal(err)
   284  	}
   285  
   286  	st1, err := db.Prepare("SELECT 2")
   287  	if err != nil {
   288  		t.Fatal(err)
   289  	}
   290  
   291  	r, err := st.Query()
   292  	if err != nil {
   293  		t.Fatal(err)
   294  	}
   295  	defer r.Close()
   296  
   297  	if !r.Next() {
   298  		t.Fatal("expected row")
   299  	}
   300  
   301  	var i int
   302  	err = r.Scan(&i)
   303  	if err != nil {
   304  		t.Fatal(err)
   305  	}
   306  
   307  	if i != 1 {
   308  		t.Fatalf("expected 1, got %d", i)
   309  	}
   310  
   311  	// st1
   312  
   313  	r1, err := st1.Query()
   314  	if err != nil {
   315  		t.Fatal(err)
   316  	}
   317  	defer r1.Close()
   318  
   319  	if !r1.Next() {
   320  		if r.Err() != nil {
   321  			t.Fatal(r1.Err())
   322  		}
   323  		t.Fatal("expected row")
   324  	}
   325  
   326  	err = r1.Scan(&i)
   327  	if err != nil {
   328  		t.Fatal(err)
   329  	}
   330  
   331  	if i != 2 {
   332  		t.Fatalf("expected 2, got %d", i)
   333  	}
   334  }
   335  
   336  func TestRowsCloseBeforeDone(t *testing.T) {
   337  	db := openTestConn(t)
   338  	defer db.Close()
   339  
   340  	r, err := db.Query("SELECT 1")
   341  	if err != nil {
   342  		t.Fatal(err)
   343  	}
   344  
   345  	err = r.Close()
   346  	if err != nil {
   347  		t.Fatal(err)
   348  	}
   349  
   350  	if r.Next() {
   351  		t.Fatal("unexpected row")
   352  	}
   353  
   354  	if r.Err() != nil {
   355  		t.Fatal(r.Err())
   356  	}
   357  }
   358  
   359  func TestParameterCountMismatch(t *testing.T) {
   360  	db := openTestConn(t)
   361  	defer db.Close()
   362  
   363  	var notused int
   364  	err := db.QueryRow("SELECT false", 1).Scan(&notused)
   365  	if err == nil {
   366  		t.Fatal("expected err")
   367  	}
   368  	// make sure we clean up correctly
   369  	err = db.QueryRow("SELECT 1").Scan(&notused)
   370  	if err != nil {
   371  		t.Fatal(err)
   372  	}
   373  
   374  	err = db.QueryRow("SELECT $1").Scan(&notused)
   375  	if err == nil {
   376  		t.Fatal("expected err")
   377  	}
   378  	// make sure we clean up correctly
   379  	err = db.QueryRow("SELECT 1").Scan(&notused)
   380  	if err != nil {
   381  		t.Fatal(err)
   382  	}
   383  }
   384  
   385  // Test that EmptyQueryResponses are handled correctly.
   386  func TestEmptyQuery(t *testing.T) {
   387  	db := openTestConn(t)
   388  	defer db.Close()
   389  
   390  	res, err := db.Exec("")
   391  	if err != nil {
   392  		t.Fatal(err)
   393  	}
   394  	if _, err := res.RowsAffected(); err != errNoRowsAffected {
   395  		t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
   396  	}
   397  	if _, err := res.LastInsertId(); err != errNoLastInsertID {
   398  		t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
   399  	}
   400  	rows, err := db.Query("")
   401  	if err != nil {
   402  		t.Fatal(err)
   403  	}
   404  	cols, err := rows.Columns()
   405  	if err != nil {
   406  		t.Fatal(err)
   407  	}
   408  	if len(cols) != 0 {
   409  		t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
   410  	}
   411  	if rows.Next() {
   412  		t.Fatal("unexpected row")
   413  	}
   414  	if rows.Err() != nil {
   415  		t.Fatal(rows.Err())
   416  	}
   417  
   418  	stmt, err := db.Prepare("")
   419  	if err != nil {
   420  		t.Fatal(err)
   421  	}
   422  	res, err = stmt.Exec()
   423  	if err != nil {
   424  		t.Fatal(err)
   425  	}
   426  	if _, err := res.RowsAffected(); err != errNoRowsAffected {
   427  		t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
   428  	}
   429  	if _, err := res.LastInsertId(); err != errNoLastInsertID {
   430  		t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
   431  	}
   432  	rows, err = stmt.Query()
   433  	if err != nil {
   434  		t.Fatal(err)
   435  	}
   436  	cols, err = rows.Columns()
   437  	if err != nil {
   438  		t.Fatal(err)
   439  	}
   440  	if len(cols) != 0 {
   441  		t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
   442  	}
   443  	if rows.Next() {
   444  		t.Fatal("unexpected row")
   445  	}
   446  	if rows.Err() != nil {
   447  		t.Fatal(rows.Err())
   448  	}
   449  }
   450  
   451  // Test that rows.Columns() is correct even if there are no result rows.
   452  func TestEmptyResultSetColumns(t *testing.T) {
   453  	db := openTestConn(t)
   454  	defer db.Close()
   455  
   456  	rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar WHERE FALSE")
   457  	if err != nil {
   458  		t.Fatal(err)
   459  	}
   460  	cols, err := rows.Columns()
   461  	if err != nil {
   462  		t.Fatal(err)
   463  	}
   464  	if len(cols) != 2 {
   465  		t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
   466  	}
   467  	if rows.Next() {
   468  		t.Fatal("unexpected row")
   469  	}
   470  	if rows.Err() != nil {
   471  		t.Fatal(rows.Err())
   472  	}
   473  	if cols[0] != "a" || cols[1] != "bar" {
   474  		t.Fatalf("unexpected Columns result %v", cols)
   475  	}
   476  
   477  	stmt, err := db.Prepare("SELECT $1::int AS a, text 'bar' AS bar WHERE FALSE")
   478  	if err != nil {
   479  		t.Fatal(err)
   480  	}
   481  	rows, err = stmt.Query(1)
   482  	if err != nil {
   483  		t.Fatal(err)
   484  	}
   485  	cols, err = rows.Columns()
   486  	if err != nil {
   487  		t.Fatal(err)
   488  	}
   489  	if len(cols) != 2 {
   490  		t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
   491  	}
   492  	if rows.Next() {
   493  		t.Fatal("unexpected row")
   494  	}
   495  	if rows.Err() != nil {
   496  		t.Fatal(rows.Err())
   497  	}
   498  	if cols[0] != "a" || cols[1] != "bar" {
   499  		t.Fatalf("unexpected Columns result %v", cols)
   500  	}
   501  
   502  }
   503  
   504  func TestEncodeDecode(t *testing.T) {
   505  	db := openTestConn(t)
   506  	defer db.Close()
   507  
   508  	q := `
   509  		SELECT
   510  			E'\\000\\001\\002'::bytea,
   511  			'foobar'::text,
   512  			NULL::integer,
   513  			'2000-1-1 01:02:03.04-7'::timestamptz,
   514  			0::boolean,
   515  			123,
   516  			-321,
   517  			3.14::float8
   518  		WHERE
   519  			    E'\\000\\001\\002'::bytea = $1
   520  			AND 'foobar'::text = $2
   521  			AND $3::integer is NULL
   522  	`
   523  	// AND '2000-1-1 12:00:00.000000-7'::timestamp = $3
   524  
   525  	exp1 := []byte{0, 1, 2}
   526  	exp2 := "foobar"
   527  
   528  	r, err := db.Query(q, exp1, exp2, nil)
   529  	if err != nil {
   530  		t.Fatal(err)
   531  	}
   532  	defer r.Close()
   533  
   534  	if !r.Next() {
   535  		if r.Err() != nil {
   536  			t.Fatal(r.Err())
   537  		}
   538  		t.Fatal("expected row")
   539  	}
   540  
   541  	var got1 []byte
   542  	var got2 string
   543  	var got3 = sql.NullInt64{Valid: true}
   544  	var got4 time.Time
   545  	var got5, got6, got7, got8 interface{}
   546  
   547  	err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7, &got8)
   548  	if err != nil {
   549  		t.Fatal(err)
   550  	}
   551  
   552  	if !reflect.DeepEqual(exp1, got1) {
   553  		t.Errorf("expected %q byte: %q", exp1, got1)
   554  	}
   555  
   556  	if !reflect.DeepEqual(exp2, got2) {
   557  		t.Errorf("expected %q byte: %q", exp2, got2)
   558  	}
   559  
   560  	if got3.Valid {
   561  		t.Fatal("expected invalid")
   562  	}
   563  
   564  	if got4.Year() != 2000 {
   565  		t.Fatal("wrong year")
   566  	}
   567  
   568  	if got5 != false {
   569  		t.Fatalf("expected false, got %q", got5)
   570  	}
   571  
   572  	if got6 != int64(123) {
   573  		t.Fatalf("expected 123, got %d", got6)
   574  	}
   575  
   576  	if got7 != int64(-321) {
   577  		t.Fatalf("expected -321, got %d", got7)
   578  	}
   579  
   580  	if got8 != float64(3.14) {
   581  		t.Fatalf("expected 3.14, got %f", got8)
   582  	}
   583  }
   584  
   585  func TestNoData(t *testing.T) {
   586  	db := openTestConn(t)
   587  	defer db.Close()
   588  
   589  	st, err := db.Prepare("SELECT 1 WHERE true = false")
   590  	if err != nil {
   591  		t.Fatal(err)
   592  	}
   593  	defer st.Close()
   594  
   595  	r, err := st.Query()
   596  	if err != nil {
   597  		t.Fatal(err)
   598  	}
   599  	defer r.Close()
   600  
   601  	if r.Next() {
   602  		if r.Err() != nil {
   603  			t.Fatal(r.Err())
   604  		}
   605  		t.Fatal("unexpected row")
   606  	}
   607  
   608  	_, err = db.Query("SELECT * FROM nonexistenttable WHERE age=$1", 20)
   609  	if err == nil {
   610  		t.Fatal("Should have raised an error on non existent table")
   611  	}
   612  
   613  	_, err = db.Query("SELECT * FROM nonexistenttable")
   614  	if err == nil {
   615  		t.Fatal("Should have raised an error on non existent table")
   616  	}
   617  }
   618  
   619  func TestErrorDuringStartup(t *testing.T) {
   620  	// Don't use the normal connection setup, this is intended to
   621  	// blow up in the startup packet from a non-existent user.
   622  	db, err := openTestConnConninfo("user=thisuserreallydoesntexist")
   623  	if err != nil {
   624  		t.Fatal(err)
   625  	}
   626  	defer db.Close()
   627  
   628  	_, err = db.Begin()
   629  	if err == nil {
   630  		t.Fatal("expected error")
   631  	}
   632  
   633  	e, ok := err.(*Error)
   634  	if !ok {
   635  		t.Fatalf("expected Error, got %#v", err)
   636  	} else if e.Code.Name() != "invalid_authorization_specification" && e.Code.Name() != "invalid_password" {
   637  		t.Fatalf("expected invalid_authorization_specification or invalid_password, got %s (%+v)", e.Code.Name(), err)
   638  	}
   639  }
   640  
   641  type testConn struct {
   642  	closed bool
   643  	net.Conn
   644  }
   645  
   646  func (c *testConn) Close() error {
   647  	c.closed = true
   648  	return c.Conn.Close()
   649  }
   650  
   651  type testDialer struct {
   652  	conns []*testConn
   653  }
   654  
   655  func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) {
   656  	c, err := net.Dial(ntw, addr)
   657  	if err != nil {
   658  		return nil, err
   659  	}
   660  	tc := &testConn{Conn: c}
   661  	d.conns = append(d.conns, tc)
   662  	return tc, nil
   663  }
   664  
   665  func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
   666  	c, err := net.DialTimeout(ntw, addr, timeout)
   667  	if err != nil {
   668  		return nil, err
   669  	}
   670  	tc := &testConn{Conn: c}
   671  	d.conns = append(d.conns, tc)
   672  	return tc, nil
   673  }
   674  
   675  func TestErrorDuringStartupClosesConn(t *testing.T) {
   676  	// Don't use the normal connection setup, this is intended to
   677  	// blow up in the startup packet from a non-existent user.
   678  	var d testDialer
   679  	c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist"))
   680  	if err == nil {
   681  		c.Close()
   682  		t.Fatal("expected dial error")
   683  	}
   684  	if len(d.conns) != 1 {
   685  		t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1)
   686  	}
   687  	if !d.conns[0].closed {
   688  		t.Error("connection leaked")
   689  	}
   690  }
   691  
   692  func TestBadConn(t *testing.T) {
   693  	var err error
   694  
   695  	cn := conn{}
   696  	func() {
   697  		defer cn.errRecover(&err)
   698  		panic(io.EOF)
   699  	}()
   700  	if err != driver.ErrBadConn {
   701  		t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
   702  	}
   703  	if err := cn.err.get(); err != driver.ErrBadConn {
   704  		t.Fatalf("expected driver.ErrBadConn, got %#v", err)
   705  	}
   706  
   707  	cn = conn{}
   708  	func() {
   709  		defer cn.errRecover(&err)
   710  		e := &Error{Severity: Efatal}
   711  		panic(e)
   712  	}()
   713  	if err != driver.ErrBadConn {
   714  		t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
   715  	}
   716  	if err := cn.err.get(); err != driver.ErrBadConn {
   717  		t.Fatalf("expected driver.ErrBadConn, got %#v", err)
   718  	}
   719  }
   720  
   721  // TestCloseBadConn tests that the underlying connection can be closed with
   722  // Close after an error.
   723  func TestCloseBadConn(t *testing.T) {
   724  	host := os.Getenv("PGHOST")
   725  	if host == "" {
   726  		host = "localhost"
   727  	}
   728  	port := os.Getenv("PGPORT")
   729  	if port == "" {
   730  		port = "5432"
   731  	}
   732  	nc, err := net.Dial("tcp", host+":"+port)
   733  	if err != nil {
   734  		t.Fatal(err)
   735  	}
   736  	cn := conn{c: nc}
   737  	func() {
   738  		defer cn.errRecover(&err)
   739  		panic(io.EOF)
   740  	}()
   741  	// Verify we can write before closing.
   742  	if _, err := nc.Write(nil); err != nil {
   743  		t.Fatal(err)
   744  	}
   745  	// First close should close the connection.
   746  	if err := cn.Close(); err != nil {
   747  		t.Fatal(err)
   748  	}
   749  
   750  	// During the Go 1.9 cycle, https://github.com/golang/go/commit/3792db5
   751  	// changed this error from
   752  	//
   753  	// net.errClosing = errors.New("use of closed network connection")
   754  	//
   755  	// to
   756  	//
   757  	// internal/poll.ErrClosing = errors.New("use of closed file or network connection")
   758  	const errClosing = "use of closed"
   759  
   760  	// Verify write after closing fails.
   761  	if _, err := nc.Write(nil); err == nil {
   762  		t.Fatal("expected error")
   763  	} else if !strings.Contains(err.Error(), errClosing) {
   764  		t.Fatalf("expected %s error, got %s", errClosing, err)
   765  	}
   766  	// Verify second close fails.
   767  	if err := cn.Close(); err == nil {
   768  		t.Fatal("expected error")
   769  	} else if !strings.Contains(err.Error(), errClosing) {
   770  		t.Fatalf("expected %s error, got %s", errClosing, err)
   771  	}
   772  }
   773  
   774  func TestErrorOnExec(t *testing.T) {
   775  	db := openTestConn(t)
   776  	defer db.Close()
   777  
   778  	txn, err := db.Begin()
   779  	if err != nil {
   780  		t.Fatal(err)
   781  	}
   782  	defer txn.Rollback()
   783  
   784  	_, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)")
   785  	if err != nil {
   786  		t.Fatal(err)
   787  	}
   788  
   789  	_, err = txn.Exec("INSERT INTO foo VALUES (0), (0)")
   790  	if err == nil {
   791  		t.Fatal("Should have raised error")
   792  	}
   793  
   794  	e, ok := err.(*Error)
   795  	if !ok {
   796  		t.Fatalf("expected Error, got %#v", err)
   797  	} else if e.Code.Name() != "unique_violation" {
   798  		t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err)
   799  	}
   800  }
   801  
   802  func TestErrorOnQuery(t *testing.T) {
   803  	db := openTestConn(t)
   804  	defer db.Close()
   805  
   806  	txn, err := db.Begin()
   807  	if err != nil {
   808  		t.Fatal(err)
   809  	}
   810  	defer txn.Rollback()
   811  
   812  	_, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)")
   813  	if err != nil {
   814  		t.Fatal(err)
   815  	}
   816  
   817  	_, err = txn.Query("INSERT INTO foo VALUES (0), (0)")
   818  	if err == nil {
   819  		t.Fatal("Should have raised error")
   820  	}
   821  
   822  	e, ok := err.(*Error)
   823  	if !ok {
   824  		t.Fatalf("expected Error, got %#v", err)
   825  	} else if e.Code.Name() != "unique_violation" {
   826  		t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err)
   827  	}
   828  }
   829  
   830  func TestErrorOnQueryRowSimpleQuery(t *testing.T) {
   831  	db := openTestConn(t)
   832  	defer db.Close()
   833  
   834  	txn, err := db.Begin()
   835  	if err != nil {
   836  		t.Fatal(err)
   837  	}
   838  	defer txn.Rollback()
   839  
   840  	_, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)")
   841  	if err != nil {
   842  		t.Fatal(err)
   843  	}
   844  
   845  	var v int
   846  	err = txn.QueryRow("INSERT INTO foo VALUES (0), (0)").Scan(&v)
   847  	if err == nil {
   848  		t.Fatal("Should have raised error")
   849  	}
   850  
   851  	e, ok := err.(*Error)
   852  	if !ok {
   853  		t.Fatalf("expected Error, got %#v", err)
   854  	} else if e.Code.Name() != "unique_violation" {
   855  		t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err)
   856  	}
   857  }
   858  
   859  // Test the QueryRow bug workarounds in stmt.exec() and simpleQuery()
   860  func TestQueryRowBugWorkaround(t *testing.T) {
   861  	db := openTestConn(t)
   862  	defer db.Close()
   863  
   864  	// stmt.exec()
   865  	_, err := db.Exec("CREATE TEMP TABLE notnulltemp (a varchar(10) not null)")
   866  	if err != nil {
   867  		t.Fatal(err)
   868  	}
   869  
   870  	var a string
   871  	err = db.QueryRow("INSERT INTO notnulltemp(a) values($1) RETURNING a", nil).Scan(&a)
   872  	if err == sql.ErrNoRows {
   873  		t.Fatalf("expected constraint violation error; got: %v", err)
   874  	}
   875  	pge, ok := err.(*Error)
   876  	if !ok {
   877  		t.Fatalf("expected *Error; got: %#v", err)
   878  	}
   879  	if pge.Code.Name() != "not_null_violation" {
   880  		t.Fatalf("expected not_null_violation; got: %s (%+v)", pge.Code.Name(), err)
   881  	}
   882  
   883  	// Test workaround in simpleQuery()
   884  	tx, err := db.Begin()
   885  	if err != nil {
   886  		t.Fatalf("unexpected error %s in Begin", err)
   887  	}
   888  	defer tx.Rollback()
   889  
   890  	_, err = tx.Exec("SET LOCAL check_function_bodies TO FALSE")
   891  	if err != nil {
   892  		t.Fatalf("could not disable check_function_bodies: %s", err)
   893  	}
   894  	_, err = tx.Exec(`
   895  CREATE OR REPLACE FUNCTION bad_function()
   896  RETURNS integer
   897  -- hack to prevent the function from being inlined
   898  SET check_function_bodies TO TRUE
   899  AS $$
   900  	SELECT text 'bad'
   901  $$ LANGUAGE sql`)
   902  	if err != nil {
   903  		t.Fatalf("could not create function: %s", err)
   904  	}
   905  
   906  	err = tx.QueryRow("SELECT * FROM bad_function()").Scan(&a)
   907  	if err == nil {
   908  		t.Fatalf("expected error")
   909  	}
   910  	pge, ok = err.(*Error)
   911  	if !ok {
   912  		t.Fatalf("expected *Error; got: %#v", err)
   913  	}
   914  	if pge.Code.Name() != "invalid_function_definition" {
   915  		t.Fatalf("expected invalid_function_definition; got: %s (%+v)", pge.Code.Name(), err)
   916  	}
   917  
   918  	err = tx.Rollback()
   919  	if err != nil {
   920  		t.Fatalf("unexpected error %s in Rollback", err)
   921  	}
   922  
   923  	// Also test that simpleQuery()'s workaround works when the query fails
   924  	// after a row has been received.
   925  	rows, err := db.Query(`
   926  select
   927  	(select generate_series(1, ss.i))
   928  from (select gs.i
   929        from generate_series(1, 2) gs(i)
   930        order by gs.i limit 2) ss`)
   931  	if err != nil {
   932  		t.Fatalf("query failed: %s", err)
   933  	}
   934  	if !rows.Next() {
   935  		t.Fatalf("expected at least one result row; got %s", rows.Err())
   936  	}
   937  	var i int
   938  	err = rows.Scan(&i)
   939  	if err != nil {
   940  		t.Fatalf("rows.Scan() failed: %s", err)
   941  	}
   942  	if i != 1 {
   943  		t.Fatalf("unexpected value for i: %d", i)
   944  	}
   945  	if rows.Next() {
   946  		t.Fatalf("unexpected row")
   947  	}
   948  	pge, ok = rows.Err().(*Error)
   949  	if !ok {
   950  		t.Fatalf("expected *Error; got: %#v", err)
   951  	}
   952  	if pge.Code.Name() != "cardinality_violation" {
   953  		t.Fatalf("expected cardinality_violation; got: %s (%+v)", pge.Code.Name(), rows.Err())
   954  	}
   955  }
   956  
   957  func TestSimpleQuery(t *testing.T) {
   958  	db := openTestConn(t)
   959  	defer db.Close()
   960  
   961  	r, err := db.Query("select 1")
   962  	if err != nil {
   963  		t.Fatal(err)
   964  	}
   965  	defer r.Close()
   966  
   967  	if !r.Next() {
   968  		t.Fatal("expected row")
   969  	}
   970  }
   971  
   972  func TestBindError(t *testing.T) {
   973  	db := openTestConn(t)
   974  	defer db.Close()
   975  
   976  	_, err := db.Exec("create temp table test (i integer)")
   977  	if err != nil {
   978  		t.Fatal(err)
   979  	}
   980  
   981  	_, err = db.Query("select * from test where i=$1", "hhh")
   982  	if err == nil {
   983  		t.Fatal("expected an error")
   984  	}
   985  
   986  	// Should not get error here
   987  	r, err := db.Query("select * from test where i=$1", 1)
   988  	if err != nil {
   989  		t.Fatal(err)
   990  	}
   991  	defer r.Close()
   992  }
   993  
   994  func TestParseErrorInExtendedQuery(t *testing.T) {
   995  	db := openTestConn(t)
   996  	defer db.Close()
   997  
   998  	_, err := db.Query("PARSE_ERROR $1", 1)
   999  	pqErr, _ := err.(*Error)
  1000  	// Expecting a syntax error.
  1001  	if err == nil || pqErr == nil || pqErr.Code != "42601" {
  1002  		t.Fatalf("expected syntax error, got %s", err)
  1003  	}
  1004  
  1005  	rows, err := db.Query("SELECT 1")
  1006  	if err != nil {
  1007  		t.Fatal(err)
  1008  	}
  1009  	rows.Close()
  1010  }
  1011  
  1012  // TestReturning tests that an INSERT query using the RETURNING clause returns a row.
  1013  func TestReturning(t *testing.T) {
  1014  	db := openTestConn(t)
  1015  	defer db.Close()
  1016  
  1017  	_, err := db.Exec("CREATE TEMP TABLE distributors (did integer default 0, dname text)")
  1018  	if err != nil {
  1019  		t.Fatal(err)
  1020  	}
  1021  
  1022  	rows, err := db.Query("INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') " +
  1023  		"RETURNING did;")
  1024  	if err != nil {
  1025  		t.Fatal(err)
  1026  	}
  1027  	if !rows.Next() {
  1028  		t.Fatal("no rows")
  1029  	}
  1030  	var did int
  1031  	err = rows.Scan(&did)
  1032  	if err != nil {
  1033  		t.Fatal(err)
  1034  	}
  1035  	if did != 0 {
  1036  		t.Fatalf("bad value for did: got %d, want %d", did, 0)
  1037  	}
  1038  
  1039  	if rows.Next() {
  1040  		t.Fatal("unexpected next row")
  1041  	}
  1042  	err = rows.Err()
  1043  	if err != nil {
  1044  		t.Fatal(err)
  1045  	}
  1046  }
  1047  
  1048  func TestIssue186(t *testing.T) {
  1049  	db := openTestConn(t)
  1050  	defer db.Close()
  1051  
  1052  	// Exec() a query which returns results
  1053  	_, err := db.Exec("VALUES (1), (2), (3)")
  1054  	if err != nil {
  1055  		t.Fatal(err)
  1056  	}
  1057  
  1058  	_, err = db.Exec("VALUES ($1), ($2), ($3)", 1, 2, 3)
  1059  	if err != nil {
  1060  		t.Fatal(err)
  1061  	}
  1062  
  1063  	// Query() a query which doesn't return any results
  1064  	txn, err := db.Begin()
  1065  	if err != nil {
  1066  		t.Fatal(err)
  1067  	}
  1068  	defer txn.Rollback()
  1069  
  1070  	rows, err := txn.Query("CREATE TEMP TABLE foo(f1 int)")
  1071  	if err != nil {
  1072  		t.Fatal(err)
  1073  	}
  1074  	if err = rows.Close(); err != nil {
  1075  		t.Fatal(err)
  1076  	}
  1077  
  1078  	// small trick to get NoData from a parameterized query
  1079  	_, err = txn.Exec("CREATE RULE nodata AS ON INSERT TO foo DO INSTEAD NOTHING")
  1080  	if err != nil {
  1081  		t.Fatal(err)
  1082  	}
  1083  	rows, err = txn.Query("INSERT INTO foo VALUES ($1)", 1)
  1084  	if err != nil {
  1085  		t.Fatal(err)
  1086  	}
  1087  	if err = rows.Close(); err != nil {
  1088  		t.Fatal(err)
  1089  	}
  1090  }
  1091  
  1092  func TestIssue196(t *testing.T) {
  1093  	db := openTestConn(t)
  1094  	defer db.Close()
  1095  
  1096  	row := db.QueryRow("SELECT float4 '0.10000122' = $1, float8 '35.03554004971999' = $2",
  1097  		float32(0.10000122), float64(35.03554004971999))
  1098  
  1099  	var float4match, float8match bool
  1100  	err := row.Scan(&float4match, &float8match)
  1101  	if err != nil {
  1102  		t.Fatal(err)
  1103  	}
  1104  	if !float4match {
  1105  		t.Errorf("Expected float4 fidelity to be maintained; got no match")
  1106  	}
  1107  	if !float8match {
  1108  		t.Errorf("Expected float8 fidelity to be maintained; got no match")
  1109  	}
  1110  }
  1111  
  1112  // Test that any CommandComplete messages sent before the query results are
  1113  // ignored.
  1114  func TestIssue282(t *testing.T) {
  1115  	db := openTestConn(t)
  1116  	defer db.Close()
  1117  
  1118  	var searchPath string
  1119  	err := db.QueryRow(`
  1120  		SET LOCAL search_path TO pg_catalog;
  1121  		SET LOCAL search_path TO pg_catalog;
  1122  		SHOW search_path`).Scan(&searchPath)
  1123  	if err != nil {
  1124  		t.Fatal(err)
  1125  	}
  1126  	if searchPath != "pg_catalog" {
  1127  		t.Fatalf("unexpected search_path %s", searchPath)
  1128  	}
  1129  }
  1130  
  1131  func TestReadFloatPrecision(t *testing.T) {
  1132  	db := openTestConn(t)
  1133  	defer db.Close()
  1134  
  1135  	row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999', float4 '1.2'")
  1136  	var float4val float32
  1137  	var float8val float64
  1138  	var float4val2 float64
  1139  	err := row.Scan(&float4val, &float8val, &float4val2)
  1140  	if err != nil {
  1141  		t.Fatal(err)
  1142  	}
  1143  	if float4val != float32(0.10000122) {
  1144  		t.Errorf("Expected float4 fidelity to be maintained; got no match")
  1145  	}
  1146  	if float8val != float64(35.03554004971999) {
  1147  		t.Errorf("Expected float8 fidelity to be maintained; got no match")
  1148  	}
  1149  	if float4val2 != float64(1.2) {
  1150  		t.Errorf("Expected float4 fidelity into a float64 to be maintained; got no match")
  1151  	}
  1152  }
  1153  
  1154  func TestXactMultiStmt(t *testing.T) {
  1155  	// minified test case based on bug reports from
  1156  	// pico303@gmail.com and rangelspam@gmail.com
  1157  	t.Skip("Skipping failing test")
  1158  	db := openTestConn(t)
  1159  	defer db.Close()
  1160  
  1161  	tx, err := db.Begin()
  1162  	if err != nil {
  1163  		t.Fatal(err)
  1164  	}
  1165  	defer tx.Commit()
  1166  
  1167  	rows, err := tx.Query("select 1")
  1168  	if err != nil {
  1169  		t.Fatal(err)
  1170  	}
  1171  
  1172  	if rows.Next() {
  1173  		var val int32
  1174  		if err = rows.Scan(&val); err != nil {
  1175  			t.Fatal(err)
  1176  		}
  1177  	} else {
  1178  		t.Fatal("Expected at least one row in first query in xact")
  1179  	}
  1180  
  1181  	rows2, err := tx.Query("select 2")
  1182  	if err != nil {
  1183  		t.Fatal(err)
  1184  	}
  1185  
  1186  	if rows2.Next() {
  1187  		var val2 int32
  1188  		if err := rows2.Scan(&val2); err != nil {
  1189  			t.Fatal(err)
  1190  		}
  1191  	} else {
  1192  		t.Fatal("Expected at least one row in second query in xact")
  1193  	}
  1194  
  1195  	if err = rows.Err(); err != nil {
  1196  		t.Fatal(err)
  1197  	}
  1198  
  1199  	if err = rows2.Err(); err != nil {
  1200  		t.Fatal(err)
  1201  	}
  1202  
  1203  	if err = tx.Commit(); err != nil {
  1204  		t.Fatal(err)
  1205  	}
  1206  }
  1207  
  1208  var envParseTests = []struct {
  1209  	Expected map[string]string
  1210  	Env      []string
  1211  }{
  1212  	{
  1213  		Env:      []string{"PGDATABASE=hello", "PGUSER=goodbye"},
  1214  		Expected: map[string]string{"dbname": "hello", "user": "goodbye"},
  1215  	},
  1216  	{
  1217  		Env:      []string{"PGDATESTYLE=ISO, MDY"},
  1218  		Expected: map[string]string{"datestyle": "ISO, MDY"},
  1219  	},
  1220  	{
  1221  		Env:      []string{"PGCONNECT_TIMEOUT=30"},
  1222  		Expected: map[string]string{"connect_timeout": "30"},
  1223  	},
  1224  }
  1225  
  1226  func TestParseEnviron(t *testing.T) {
  1227  	for i, tt := range envParseTests {
  1228  		results := parseEnviron(tt.Env)
  1229  		if !reflect.DeepEqual(tt.Expected, results) {
  1230  			t.Errorf("%d: Expected: %#v Got: %#v", i, tt.Expected, results)
  1231  		}
  1232  	}
  1233  }
  1234  
  1235  func TestParseComplete(t *testing.T) {
  1236  	tpc := func(commandTag string, command string, affectedRows int64, shouldFail bool) {
  1237  		defer func() {
  1238  			if p := recover(); p != nil {
  1239  				if !shouldFail {
  1240  					t.Error(p)
  1241  				}
  1242  			}
  1243  		}()
  1244  		cn := &conn{}
  1245  		res, c := cn.parseComplete(commandTag)
  1246  		if c != command {
  1247  			t.Errorf("Expected %v, got %v", command, c)
  1248  		}
  1249  		n, err := res.RowsAffected()
  1250  		if err != nil {
  1251  			t.Fatal(err)
  1252  		}
  1253  		if n != affectedRows {
  1254  			t.Errorf("Expected %d, got %d", affectedRows, n)
  1255  		}
  1256  	}
  1257  
  1258  	tpc("ALTER TABLE", "ALTER TABLE", 0, false)
  1259  	tpc("INSERT 0 1", "INSERT", 1, false)
  1260  	tpc("UPDATE 100", "UPDATE", 100, false)
  1261  	tpc("SELECT 100", "SELECT", 100, false)
  1262  	tpc("FETCH 100", "FETCH", 100, false)
  1263  	// allow COPY (and others) without row count
  1264  	tpc("COPY", "COPY", 0, false)
  1265  	// don't fail on command tags we don't recognize
  1266  	tpc("UNKNOWNCOMMANDTAG", "UNKNOWNCOMMANDTAG", 0, false)
  1267  
  1268  	// failure cases
  1269  	tpc("INSERT 1", "", 0, true)   // missing oid
  1270  	tpc("UPDATE 0 1", "", 0, true) // too many numbers
  1271  	tpc("SELECT foo", "", 0, true) // invalid row count
  1272  }
  1273  
  1274  // Test interface conformance.
  1275  var (
  1276  	_ driver.ExecerContext  = (*conn)(nil)
  1277  	_ driver.QueryerContext = (*conn)(nil)
  1278  )
  1279  
  1280  func TestNullAfterNonNull(t *testing.T) {
  1281  	db := openTestConn(t)
  1282  	defer db.Close()
  1283  
  1284  	r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer")
  1285  	if err != nil {
  1286  		t.Fatal(err)
  1287  	}
  1288  
  1289  	var n sql.NullInt64
  1290  
  1291  	if !r.Next() {
  1292  		if r.Err() != nil {
  1293  			t.Fatal(err)
  1294  		}
  1295  		t.Fatal("expected row")
  1296  	}
  1297  
  1298  	if err := r.Scan(&n); err != nil {
  1299  		t.Fatal(err)
  1300  	}
  1301  
  1302  	if n.Int64 != 9 {
  1303  		t.Fatalf("expected 2, not %d", n.Int64)
  1304  	}
  1305  
  1306  	if !r.Next() {
  1307  		if r.Err() != nil {
  1308  			t.Fatal(err)
  1309  		}
  1310  		t.Fatal("expected row")
  1311  	}
  1312  
  1313  	if err := r.Scan(&n); err != nil {
  1314  		t.Fatal(err)
  1315  	}
  1316  
  1317  	if n.Valid {
  1318  		t.Fatal("expected n to be invalid")
  1319  	}
  1320  
  1321  	if n.Int64 != 0 {
  1322  		t.Fatalf("expected n to 2, not %d", n.Int64)
  1323  	}
  1324  }
  1325  
  1326  func Test64BitErrorChecking(t *testing.T) {
  1327  	defer func() {
  1328  		if err := recover(); err != nil {
  1329  			t.Fatal("panic due to 0xFFFFFFFF != -1 " +
  1330  				"when int is 64 bits")
  1331  		}
  1332  	}()
  1333  
  1334  	db := openTestConn(t)
  1335  	defer db.Close()
  1336  
  1337  	r, err := db.Query(`SELECT *
  1338  FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`)
  1339  
  1340  	if err != nil {
  1341  		t.Fatal(err)
  1342  	}
  1343  
  1344  	defer r.Close()
  1345  
  1346  	for r.Next() {
  1347  	}
  1348  }
  1349  
  1350  func TestCommit(t *testing.T) {
  1351  	db := openTestConn(t)
  1352  	defer db.Close()
  1353  
  1354  	_, err := db.Exec("CREATE TEMP TABLE temp (a int)")
  1355  	if err != nil {
  1356  		t.Fatal(err)
  1357  	}
  1358  	sqlInsert := "INSERT INTO temp VALUES (1)"
  1359  	sqlSelect := "SELECT * FROM temp"
  1360  	tx, err := db.Begin()
  1361  	if err != nil {
  1362  		t.Fatal(err)
  1363  	}
  1364  	_, err = tx.Exec(sqlInsert)
  1365  	if err != nil {
  1366  		t.Fatal(err)
  1367  	}
  1368  	err = tx.Commit()
  1369  	if err != nil {
  1370  		t.Fatal(err)
  1371  	}
  1372  	var i int
  1373  	err = db.QueryRow(sqlSelect).Scan(&i)
  1374  	if err != nil {
  1375  		t.Fatal(err)
  1376  	}
  1377  	if i != 1 {
  1378  		t.Fatalf("expected 1, got %d", i)
  1379  	}
  1380  }
  1381  
  1382  func TestErrorClass(t *testing.T) {
  1383  	db := openTestConn(t)
  1384  	defer db.Close()
  1385  
  1386  	_, err := db.Query("SELECT int 'notint'")
  1387  	if err == nil {
  1388  		t.Fatal("expected error")
  1389  	}
  1390  	pge, ok := err.(*Error)
  1391  	if !ok {
  1392  		t.Fatalf("expected *pq.Error, got %#+v", err)
  1393  	}
  1394  	if pge.Code.Class() != "22" {
  1395  		t.Fatalf("expected class 28, got %v", pge.Code.Class())
  1396  	}
  1397  	if pge.Code.Class().Name() != "data_exception" {
  1398  		t.Fatalf("expected data_exception, got %v", pge.Code.Class().Name())
  1399  	}
  1400  }
  1401  
  1402  func TestParseOpts(t *testing.T) {
  1403  	tests := []struct {
  1404  		in       string
  1405  		expected values
  1406  		valid    bool
  1407  	}{
  1408  		{"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
  1409  		{"dbname=hello user=goodbye  ", values{"dbname": "hello", "user": "goodbye"}, true},
  1410  		{"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
  1411  		{"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
  1412  		{"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
  1413  		{"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, true},
  1414  		{"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, true},
  1415  		{"dbname=hello user=''", values{"dbname": "hello", "user": ""}, true},
  1416  		{"user='' dbname=hello", values{"dbname": "hello", "user": ""}, true},
  1417  		// The last option value is an empty string if there's no non-whitespace after its =
  1418  		{"dbname=hello user=   ", values{"dbname": "hello", "user": ""}, true},
  1419  
  1420  		// The parser ignores spaces after = and interprets the next set of non-whitespace characters as the value.
  1421  		{"user= password=foo", values{"user": "password=foo"}, true},
  1422  
  1423  		// Backslash escapes next char
  1424  		{`user=a\ \'\\b`, values{"user": `a '\b`}, true},
  1425  		{`user='a \'b'`, values{"user": `a 'b`}, true},
  1426  
  1427  		// Incomplete escape
  1428  		{`user=x\`, values{}, false},
  1429  
  1430  		// No '=' after the key
  1431  		{"postgre://marko@internet", values{}, false},
  1432  		{"dbname user=goodbye", values{}, false},
  1433  		{"user=foo blah", values{}, false},
  1434  		{"user=foo blah   ", values{}, false},
  1435  
  1436  		// Unterminated quoted value
  1437  		{"dbname=hello user='unterminated", values{}, false},
  1438  	}
  1439  
  1440  	for _, test := range tests {
  1441  		o := make(values)
  1442  		err := parseOpts(test.in, o)
  1443  
  1444  		switch {
  1445  		case err != nil && test.valid:
  1446  			t.Errorf("%q got unexpected error: %s", test.in, err)
  1447  		case err == nil && test.valid && !reflect.DeepEqual(test.expected, o):
  1448  			t.Errorf("%q got: %#v want: %#v", test.in, o, test.expected)
  1449  		case err == nil && !test.valid:
  1450  			t.Errorf("%q expected an error", test.in)
  1451  		}
  1452  	}
  1453  }
  1454  
  1455  func TestRuntimeParameters(t *testing.T) {
  1456  	tests := []struct {
  1457  		conninfo string
  1458  		param    string
  1459  		expected string
  1460  		success  bool
  1461  	}{
  1462  		// invalid parameter
  1463  		{"DOESNOTEXIST=foo", "", "", false},
  1464  		// we can only work with a specific value for these two
  1465  		{"client_encoding=SQL_ASCII", "", "", false},
  1466  		{"datestyle='ISO, YDM'", "", "", false},
  1467  		// "options" should work exactly as it does in libpq
  1468  		{"options='-c search_path=pqgotest'", "search_path", "pqgotest", true},
  1469  		// pq should override client_encoding in this case
  1470  		{"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", true},
  1471  		// allow client_encoding to be set explicitly
  1472  		{"client_encoding=UTF8", "client_encoding", "UTF8", true},
  1473  		// test a runtime parameter not supported by libpq
  1474  		{"work_mem='139kB'", "work_mem", "139kB", true},
  1475  		// test fallback_application_name
  1476  		{"application_name=foo fallback_application_name=bar", "application_name", "foo", true},
  1477  		{"application_name='' fallback_application_name=bar", "application_name", "", true},
  1478  		{"fallback_application_name=bar", "application_name", "bar", true},
  1479  	}
  1480  
  1481  	for _, test := range tests {
  1482  		db, err := openTestConnConninfo(test.conninfo)
  1483  		if err != nil {
  1484  			t.Fatal(err)
  1485  		}
  1486  
  1487  		// application_name didn't exist before 9.0
  1488  		if test.param == "application_name" && getServerVersion(t, db) < 90000 {
  1489  			db.Close()
  1490  			continue
  1491  		}
  1492  
  1493  		tryGetParameterValue := func() (value string, success bool) {
  1494  			defer db.Close()
  1495  			row := db.QueryRow("SELECT current_setting($1)", test.param)
  1496  			err = row.Scan(&value)
  1497  			if err != nil {
  1498  				return "", false
  1499  			}
  1500  			return value, true
  1501  		}
  1502  
  1503  		value, success := tryGetParameterValue()
  1504  		if success != test.success && !test.success {
  1505  			t.Fatalf("%v: unexpected error: %v", test.conninfo, err)
  1506  		}
  1507  		if success != test.success {
  1508  			t.Fatalf("unexpected outcome %v (was expecting %v) for conninfo \"%s\"",
  1509  				success, test.success, test.conninfo)
  1510  		}
  1511  		if value != test.expected {
  1512  			t.Fatalf("bad value for %s: got %s, want %s with conninfo \"%s\"",
  1513  				test.param, value, test.expected, test.conninfo)
  1514  		}
  1515  	}
  1516  }
  1517  
  1518  func TestIsUTF8(t *testing.T) {
  1519  	var cases = []struct {
  1520  		name string
  1521  		want bool
  1522  	}{
  1523  		{"unicode", true},
  1524  		{"utf-8", true},
  1525  		{"utf_8", true},
  1526  		{"UTF-8", true},
  1527  		{"UTF8", true},
  1528  		{"utf8", true},
  1529  		{"u n ic_ode", true},
  1530  		{"ut_f%8", true},
  1531  		{"ubf8", false},
  1532  		{"punycode", false},
  1533  	}
  1534  
  1535  	for _, test := range cases {
  1536  		if g := isUTF8(test.name); g != test.want {
  1537  			t.Errorf("isUTF8(%q) = %v want %v", test.name, g, test.want)
  1538  		}
  1539  	}
  1540  }
  1541  
  1542  func TestQuoteIdentifier(t *testing.T) {
  1543  	var cases = []struct {
  1544  		input string
  1545  		want  string
  1546  	}{
  1547  		{`foo`, `"foo"`},
  1548  		{`foo bar baz`, `"foo bar baz"`},
  1549  		{`foo"bar`, `"foo""bar"`},
  1550  		{"foo\x00bar", `"foo"`},
  1551  		{"\x00foo", `""`},
  1552  	}
  1553  
  1554  	for _, test := range cases {
  1555  		got := QuoteIdentifier(test.input)
  1556  		if got != test.want {
  1557  			t.Errorf("QuoteIdentifier(%q) = %v want %v", test.input, got, test.want)
  1558  		}
  1559  	}
  1560  }
  1561  
  1562  func TestQuoteLiteral(t *testing.T) {
  1563  	var cases = []struct {
  1564  		input string
  1565  		want  string
  1566  	}{
  1567  		{`foo`, `'foo'`},
  1568  		{`foo bar baz`, `'foo bar baz'`},
  1569  		{`foo'bar`, `'foo''bar'`},
  1570  		{`foo\bar`, ` E'foo\\bar'`},
  1571  		{`foo\ba'r`, ` E'foo\\ba''r'`},
  1572  		{`foo"bar`, `'foo"bar'`},
  1573  		{`foo\x00bar`, ` E'foo\\x00bar'`},
  1574  		{`\x00foo`, ` E'\\x00foo'`},
  1575  		{`'`, `''''`},
  1576  		{`''`, `''''''`},
  1577  		{`\`, ` E'\\'`},
  1578  		{`'abc'; DROP TABLE users;`, `'''abc''; DROP TABLE users;'`},
  1579  		{`\'`, ` E'\\'''`},
  1580  		{`E'\''`, ` E'E''\\'''''`},
  1581  		{`e'\''`, ` E'e''\\'''''`},
  1582  		{`E'\'abc\'; DROP TABLE users;'`, ` E'E''\\''abc\\''; DROP TABLE users;'''`},
  1583  		{`e'\'abc\'; DROP TABLE users;'`, ` E'e''\\''abc\\''; DROP TABLE users;'''`},
  1584  	}
  1585  
  1586  	for _, test := range cases {
  1587  		got := QuoteLiteral(test.input)
  1588  		if got != test.want {
  1589  			t.Errorf("QuoteLiteral(%q) = %v want %v", test.input, got, test.want)
  1590  		}
  1591  	}
  1592  }
  1593  
  1594  func TestRowsResultTag(t *testing.T) {
  1595  	type ResultTag interface {
  1596  		Result() driver.Result
  1597  		Tag() string
  1598  	}
  1599  
  1600  	tests := []struct {
  1601  		query string
  1602  		tag   string
  1603  		ra    int64
  1604  	}{
  1605  		{
  1606  			query: "CREATE TEMP TABLE temp (a int)",
  1607  			tag:   "CREATE TABLE",
  1608  		},
  1609  		{
  1610  			query: "INSERT INTO temp VALUES (1), (2)",
  1611  			tag:   "INSERT",
  1612  			ra:    2,
  1613  		},
  1614  		{
  1615  			query: "SELECT 1",
  1616  		},
  1617  		// A SELECT anywhere should take precedent.
  1618  		{
  1619  			query: "SELECT 1; INSERT INTO temp VALUES (1), (2)",
  1620  		},
  1621  		{
  1622  			query: "INSERT INTO temp VALUES (1), (2); SELECT 1",
  1623  		},
  1624  		// Multiple statements that don't return rows should return the last tag.
  1625  		{
  1626  			query: "CREATE TEMP TABLE t (a int); DROP TABLE t",
  1627  			tag:   "DROP TABLE",
  1628  		},
  1629  		// Ensure a rows-returning query in any position among various tags-returing
  1630  		// statements will prefer the rows.
  1631  		{
  1632  			query: "SELECT 1; CREATE TEMP TABLE t (a int); DROP TABLE t",
  1633  		},
  1634  		{
  1635  			query: "CREATE TEMP TABLE t (a int); SELECT 1; DROP TABLE t",
  1636  		},
  1637  		{
  1638  			query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1",
  1639  		},
  1640  	}
  1641  
  1642  	// If this is the only test run, this will correct the connection string.
  1643  	openTestConn(t).Close()
  1644  
  1645  	conn, err := Open("")
  1646  	if err != nil {
  1647  		t.Fatal(err)
  1648  	}
  1649  	defer conn.Close()
  1650  	q := conn.(driver.QueryerContext)
  1651  
  1652  	for _, test := range tests {
  1653  		if rows, err := q.QueryContext(context.Background(), test.query, nil); err != nil {
  1654  			t.Fatalf("%s: %s", test.query, err)
  1655  		} else {
  1656  			r := rows.(ResultTag)
  1657  			if tag := r.Tag(); tag != test.tag {
  1658  				t.Fatalf("%s: unexpected tag %q", test.query, tag)
  1659  			}
  1660  			res := r.Result()
  1661  			if ra, _ := res.RowsAffected(); ra != test.ra {
  1662  				t.Fatalf("%s: unexpected rows affected: %d", test.query, ra)
  1663  			}
  1664  			rows.Close()
  1665  		}
  1666  	}
  1667  }
  1668  
  1669  // TestQuickClose tests that closing a query early allows a subsequent query to work.
  1670  func TestQuickClose(t *testing.T) {
  1671  	db := openTestConn(t)
  1672  	defer db.Close()
  1673  
  1674  	tx, err := db.Begin()
  1675  	if err != nil {
  1676  		t.Fatal(err)
  1677  	}
  1678  	rows, err := tx.Query("SELECT 1; SELECT 2;")
  1679  	if err != nil {
  1680  		t.Fatal(err)
  1681  	}
  1682  	if err := rows.Close(); err != nil {
  1683  		t.Fatal(err)
  1684  	}
  1685  
  1686  	var id int
  1687  	if err := tx.QueryRow("SELECT 3").Scan(&id); err != nil {
  1688  		t.Fatal(err)
  1689  	}
  1690  	if id != 3 {
  1691  		t.Fatalf("unexpected %d", id)
  1692  	}
  1693  	if err := tx.Commit(); err != nil {
  1694  		t.Fatal(err)
  1695  	}
  1696  }
  1697  
  1698  func TestMultipleResult(t *testing.T) {
  1699  	db := openTestConn(t)
  1700  	defer db.Close()
  1701  
  1702  	rows, err := db.Query(`
  1703  		begin;
  1704  			select * from information_schema.tables limit 1;
  1705  			select * from information_schema.columns limit 2;
  1706  		commit;
  1707  	`)
  1708  	if err != nil {
  1709  		t.Fatal(err)
  1710  	}
  1711  	type set struct {
  1712  		cols     []string
  1713  		rowCount int
  1714  	}
  1715  	buf := []*set{}
  1716  	for {
  1717  		cols, err := rows.Columns()
  1718  		if err != nil {
  1719  			t.Fatal(err)
  1720  		}
  1721  		s := &set{
  1722  			cols: cols,
  1723  		}
  1724  		buf = append(buf, s)
  1725  
  1726  		for rows.Next() {
  1727  			s.rowCount++
  1728  		}
  1729  		if !rows.NextResultSet() {
  1730  			break
  1731  		}
  1732  	}
  1733  	if len(buf) != 2 {
  1734  		t.Fatalf("got %d sets, expected 2", len(buf))
  1735  	}
  1736  	if len(buf[0].cols) == len(buf[1].cols) || len(buf[1].cols) == 0 {
  1737  		t.Fatal("invalid cols size, expected different column count and greater then zero")
  1738  	}
  1739  	if buf[0].rowCount != 1 || buf[1].rowCount != 2 {
  1740  		t.Fatal("incorrect number of rows returned")
  1741  	}
  1742  }
  1743  
  1744  func TestMultipleEmptyResult(t *testing.T) {
  1745  	db := openTestConn(t)
  1746  	defer db.Close()
  1747  
  1748  	rows, err := db.Query("select 1 where false; select 2")
  1749  	if err != nil {
  1750  		t.Fatal(err)
  1751  	}
  1752  	defer rows.Close()
  1753  
  1754  	for rows.Next() {
  1755  		t.Fatal("unexpected row")
  1756  	}
  1757  	if !rows.NextResultSet() {
  1758  		t.Fatal("expected more result sets", rows.Err())
  1759  	}
  1760  	for rows.Next() {
  1761  		var i int
  1762  		if err := rows.Scan(&i); err != nil {
  1763  			t.Fatal(err)
  1764  		}
  1765  		if i != 2 {
  1766  			t.Fatalf("expected 2, got %d", i)
  1767  		}
  1768  	}
  1769  	if rows.NextResultSet() {
  1770  		t.Fatal("unexpected result set")
  1771  	}
  1772  }
  1773  
  1774  func TestCopyInStmtAffectedRows(t *testing.T) {
  1775  	db := openTestConn(t)
  1776  	defer db.Close()
  1777  
  1778  	_, err := db.Exec("CREATE TEMP TABLE temp (a int)")
  1779  	if err != nil {
  1780  		t.Fatal(err)
  1781  	}
  1782  
  1783  	txn, err := db.BeginTx(context.TODO(), nil)
  1784  	if err != nil {
  1785  		t.Fatal(err)
  1786  	}
  1787  
  1788  	copyStmt, err := txn.Prepare(CopyIn("temp", "a"))
  1789  	if err != nil {
  1790  		t.Fatal(err)
  1791  	}
  1792  
  1793  	res, err := copyStmt.Exec()
  1794  	if err != nil {
  1795  		t.Fatal(err)
  1796  	}
  1797  
  1798  	res.RowsAffected()
  1799  	res.LastInsertId()
  1800  }
  1801  
  1802  func TestConnPrepareContext(t *testing.T) {
  1803  	db := openTestConn(t)
  1804  	defer db.Close()
  1805  
  1806  	tests := []struct {
  1807  		name string
  1808  		ctx  func() (context.Context, context.CancelFunc)
  1809  		sql  string
  1810  		err  error
  1811  	}{
  1812  		{
  1813  			name: "context.Background",
  1814  			ctx: func() (context.Context, context.CancelFunc) {
  1815  				return context.Background(), nil
  1816  			},
  1817  			sql: "SELECT 1",
  1818  			err: nil,
  1819  		},
  1820  		{
  1821  			name: "context.WithTimeout exceeded",
  1822  			ctx: func() (context.Context, context.CancelFunc) {
  1823  				return context.WithTimeout(context.Background(), -time.Minute)
  1824  			},
  1825  			sql: "SELECT 1",
  1826  			err: context.DeadlineExceeded,
  1827  		},
  1828  		{
  1829  			name: "context.WithTimeout",
  1830  			ctx: func() (context.Context, context.CancelFunc) {
  1831  				return context.WithTimeout(context.Background(), time.Minute)
  1832  			},
  1833  			sql: "SELECT 1",
  1834  			err: nil,
  1835  		},
  1836  	}
  1837  	for _, tt := range tests {
  1838  		t.Run(tt.name, func(t *testing.T) {
  1839  			ctx, cancel := tt.ctx()
  1840  			if cancel != nil {
  1841  				defer cancel()
  1842  			}
  1843  			_, err := db.PrepareContext(ctx, tt.sql)
  1844  			switch {
  1845  			case (err != nil) != (tt.err != nil):
  1846  				t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err)
  1847  			case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
  1848  				t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error())
  1849  			}
  1850  		})
  1851  	}
  1852  }
  1853  
  1854  func TestStmtQueryContext(t *testing.T) {
  1855  	db := openTestConn(t)
  1856  	defer db.Close()
  1857  
  1858  	tests := []struct {
  1859  		name           string
  1860  		ctx            func() (context.Context, context.CancelFunc)
  1861  		sql            string
  1862  		cancelExpected bool
  1863  	}{
  1864  		{
  1865  			name: "context.Background",
  1866  			ctx: func() (context.Context, context.CancelFunc) {
  1867  				return context.Background(), nil
  1868  			},
  1869  			sql:            "SELECT pg_sleep(1);",
  1870  			cancelExpected: false,
  1871  		},
  1872  		{
  1873  			name: "context.WithTimeout exceeded",
  1874  			ctx: func() (context.Context, context.CancelFunc) {
  1875  				return context.WithTimeout(context.Background(), 1*time.Second)
  1876  			},
  1877  			sql:            "SELECT pg_sleep(10);",
  1878  			cancelExpected: true,
  1879  		},
  1880  		{
  1881  			name: "context.WithTimeout",
  1882  			ctx: func() (context.Context, context.CancelFunc) {
  1883  				return context.WithTimeout(context.Background(), time.Minute)
  1884  			},
  1885  			sql:            "SELECT pg_sleep(1);",
  1886  			cancelExpected: false,
  1887  		},
  1888  	}
  1889  	for _, tt := range tests {
  1890  		t.Run(tt.name, func(t *testing.T) {
  1891  			ctx, cancel := tt.ctx()
  1892  			if cancel != nil {
  1893  				defer cancel()
  1894  			}
  1895  			stmt, err := db.PrepareContext(ctx, tt.sql)
  1896  			if err != nil {
  1897  				t.Fatal(err)
  1898  			}
  1899  			_, err = stmt.QueryContext(ctx)
  1900  			pgErr := (*Error)(nil)
  1901  			switch {
  1902  			case (err != nil) != tt.cancelExpected:
  1903  				t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
  1904  			case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
  1905  				t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
  1906  			}
  1907  		})
  1908  	}
  1909  }
  1910  
  1911  func TestStmtExecContext(t *testing.T) {
  1912  	db := openTestConn(t)
  1913  	defer db.Close()
  1914  
  1915  	tests := []struct {
  1916  		name           string
  1917  		ctx            func() (context.Context, context.CancelFunc)
  1918  		sql            string
  1919  		cancelExpected bool
  1920  	}{
  1921  		{
  1922  			name: "context.Background",
  1923  			ctx: func() (context.Context, context.CancelFunc) {
  1924  				return context.Background(), nil
  1925  			},
  1926  			sql:            "SELECT pg_sleep(1);",
  1927  			cancelExpected: false,
  1928  		},
  1929  		{
  1930  			name: "context.WithTimeout exceeded",
  1931  			ctx: func() (context.Context, context.CancelFunc) {
  1932  				return context.WithTimeout(context.Background(), 1*time.Second)
  1933  			},
  1934  			sql:            "SELECT pg_sleep(10);",
  1935  			cancelExpected: true,
  1936  		},
  1937  		{
  1938  			name: "context.WithTimeout",
  1939  			ctx: func() (context.Context, context.CancelFunc) {
  1940  				return context.WithTimeout(context.Background(), time.Minute)
  1941  			},
  1942  			sql:            "SELECT pg_sleep(1);",
  1943  			cancelExpected: false,
  1944  		},
  1945  	}
  1946  	for _, tt := range tests {
  1947  		t.Run(tt.name, func(t *testing.T) {
  1948  			ctx, cancel := tt.ctx()
  1949  			if cancel != nil {
  1950  				defer cancel()
  1951  			}
  1952  			stmt, err := db.PrepareContext(ctx, tt.sql)
  1953  			if err != nil {
  1954  				t.Fatal(err)
  1955  			}
  1956  			_, err = stmt.ExecContext(ctx)
  1957  			pgErr := (*Error)(nil)
  1958  			switch {
  1959  			case (err != nil) != tt.cancelExpected:
  1960  				t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
  1961  			case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
  1962  				t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
  1963  			}
  1964  		})
  1965  	}
  1966  }
  1967  

View as plain text