1 package pq
2
3 import (
4 "context"
5 "database/sql"
6 "database/sql/driver"
7 "errors"
8 "fmt"
9 "io"
10 "net"
11 "os"
12 "reflect"
13 "strings"
14 "testing"
15 "time"
16 )
17
// Fatalistic is the minimal test-reporting surface needed by helpers that
// only have to abort on an unrecoverable error; *testing.T satisfies it.
type Fatalistic interface {
	// Fatal reports v and stops the calling test.
	Fatal(v ...interface{})
}
21
// forceBinaryParameters reports whether PQTEST_BINARY_PARAMETERS asks the
// test suite to run with binary parameter encoding. "yes" enables it;
// "no" or an empty/unset variable disables it. Any other value is treated
// as a misconfigured environment and causes a panic.
func forceBinaryParameters() bool {
	switch setting := os.Getenv("PQTEST_BINARY_PARAMETERS"); setting {
	case "yes":
		return true
	case "", "no":
		return false
	default:
		panic("unexpected value for PQTEST_BINARY_PARAMETERS")
	}
}
32
33 func testConninfo(conninfo string) string {
34 defaultTo := func(envvar string, value string) {
35 if os.Getenv(envvar) == "" {
36 os.Setenv(envvar, value)
37 }
38 }
39 defaultTo("PGDATABASE", "pqgotest")
40 defaultTo("PGSSLMODE", "disable")
41 defaultTo("PGCONNECT_TIMEOUT", "20")
42
43 if forceBinaryParameters() &&
44 !strings.HasPrefix(conninfo, "postgres://") &&
45 !strings.HasPrefix(conninfo, "postgresql://") {
46 conninfo += " binary_parameters=yes"
47 }
48 return conninfo
49 }
50
51 func openTestConnConninfo(conninfo string) (*sql.DB, error) {
52 return sql.Open("postgres", testConninfo(conninfo))
53 }
54
55 func openTestConn(t Fatalistic) *sql.DB {
56 conn, err := openTestConnConninfo("")
57 if err != nil {
58 t.Fatal(err)
59 }
60
61 return conn
62 }
63
// getServerVersion returns the server's numeric version (server_version_num,
// for example 90600). It fails the test if the query cannot be run.
func getServerVersion(t *testing.T, db *sql.DB) int {
	var num int
	if err := db.QueryRow("SHOW server_version_num").Scan(&num); err != nil {
		t.Fatal(err)
	}
	return num
}
72
73 func TestReconnect(t *testing.T) {
74 db1 := openTestConn(t)
75 defer db1.Close()
76 tx, err := db1.Begin()
77 if err != nil {
78 t.Fatal(err)
79 }
80 var pid1 int
81 err = tx.QueryRow("SELECT pg_backend_pid()").Scan(&pid1)
82 if err != nil {
83 t.Fatal(err)
84 }
85 db2 := openTestConn(t)
86 defer db2.Close()
87 _, err = db2.Exec("SELECT pg_terminate_backend($1)", pid1)
88 if err != nil {
89 t.Fatal(err)
90 }
91
92
93 _ = tx.Rollback()
94
95 const expected int = 42
96 var result int
97 err = db1.QueryRow(fmt.Sprintf("SELECT %d", expected)).Scan(&result)
98 if err != nil {
99 t.Fatal(err)
100 }
101 if result != expected {
102 t.Errorf("got %v; expected %v", result, expected)
103 }
104 }
105
106 func TestCommitInFailedTransaction(t *testing.T) {
107 db := openTestConn(t)
108 defer db.Close()
109
110 txn, err := db.Begin()
111 if err != nil {
112 t.Fatal(err)
113 }
114 rows, err := txn.Query("SELECT error")
115 if err == nil {
116 rows.Close()
117 t.Fatal("expected failure")
118 }
119 err = txn.Commit()
120 if err != ErrInFailedTransaction {
121 t.Fatalf("expected ErrInFailedTransaction; got %#v", err)
122 }
123 }
124
125 func TestOpenURL(t *testing.T) {
126 testURL := func(url string) {
127 db, err := openTestConnConninfo(url)
128 if err != nil {
129 t.Fatal(err)
130 }
131 defer db.Close()
132
133
134 txn, err := db.Begin()
135 if err != nil {
136 t.Fatal(err)
137 }
138 txn.Rollback()
139 }
140 testURL("postgres://")
141 testURL("postgresql://")
142 }
143
const (
	// pgpassFile is the scratch path TestPgpass writes its pgpass fixture
	// to and points PGPASSFILE at.
	pgpassFile = "/tmp/pqgotest_pgpass"
)
145
146 func TestPgpass(t *testing.T) {
147 testAssert := func(conninfo string, expected string, reason string) {
148 conn, err := openTestConnConninfo(conninfo)
149 if err != nil {
150 t.Fatal(err)
151 }
152 defer conn.Close()
153
154 txn, err := conn.Begin()
155 if err != nil {
156 if expected != "fail" {
157 t.Fatalf(reason, err)
158 }
159 return
160 }
161 rows, err := txn.Query("SELECT USER")
162 if err != nil {
163 txn.Rollback()
164 if expected != "fail" {
165 t.Fatalf(reason, err)
166 }
167 } else {
168 rows.Close()
169 if expected != "ok" {
170 t.Fatalf(reason, err)
171 }
172 }
173 txn.Rollback()
174 }
175 testAssert("", "ok", "missing .pgpass, unexpected error %#v")
176 os.Setenv("PGPASSFILE", pgpassFile)
177 testAssert("host=/tmp", "fail", ", unexpected error %#v")
178 os.Remove(pgpassFile)
179 pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644)
180 if err != nil {
181 t.Fatalf("Unexpected error writing pgpass file %#v", err)
182 }
183 _, err = pgpass.WriteString(`# comment
184 server:5432:some_db:some_user:pass_A
185 *:5432:some_db:some_user:pass_B
186 localhost:*:*:*:pass_C
187 *:*:*:*:pass_fallback
188 `)
189 if err != nil {
190 t.Fatalf("Unexpected error writing pgpass file %#v", err)
191 }
192 pgpass.Close()
193
194 assertPassword := func(extra values, expected string) {
195 o := values{
196 "host": "localhost",
197 "sslmode": "disable",
198 "connect_timeout": "20",
199 "user": "majid",
200 "port": "5432",
201 "extra_float_digits": "2",
202 "dbname": "pqgotest",
203 "client_encoding": "UTF8",
204 "datestyle": "ISO, MDY",
205 }
206 for k, v := range extra {
207 o[k] = v
208 }
209 (&conn{}).handlePgpass(o)
210 if pw := o["password"]; pw != expected {
211 t.Fatalf("For %v expected %s got %s", extra, expected, pw)
212 }
213 }
214
215 assertPassword(values{"host": "example.com", "user": "foo"}, "")
216
217 os.Chmod(pgpassFile, 0600)
218 assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A")
219 assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback")
220 assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B")
221
222 assertPassword(values{"host": "", "user": "some_user"}, "pass_C")
223 assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C")
224
225 os.Remove(pgpassFile)
226 os.Setenv("PGPASSFILE", "")
227 }
228
229 func TestExec(t *testing.T) {
230 db := openTestConn(t)
231 defer db.Close()
232
233 _, err := db.Exec("CREATE TEMP TABLE temp (a int)")
234 if err != nil {
235 t.Fatal(err)
236 }
237
238 r, err := db.Exec("INSERT INTO temp VALUES (1)")
239 if err != nil {
240 t.Fatal(err)
241 }
242
243 if n, _ := r.RowsAffected(); n != 1 {
244 t.Fatalf("expected 1 row affected, not %d", n)
245 }
246
247 r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3)
248 if err != nil {
249 t.Fatal(err)
250 }
251
252 if n, _ := r.RowsAffected(); n != 3 {
253 t.Fatalf("expected 3 rows affected, not %d", n)
254 }
255
256
257
258 if getServerVersion(t, db) >= 90000 {
259 r, err = db.Exec("SELECT g FROM generate_series(1, 2) g")
260 if err != nil {
261 t.Fatal(err)
262 }
263 if n, _ := r.RowsAffected(); n != 2 {
264 t.Fatalf("expected 2 rows affected, not %d", n)
265 }
266
267 r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3)
268 if err != nil {
269 t.Fatal(err)
270 }
271 if n, _ := r.RowsAffected(); n != 3 {
272 t.Fatalf("expected 3 rows affected, not %d", n)
273 }
274 }
275 }
276
277 func TestStatment(t *testing.T) {
278 db := openTestConn(t)
279 defer db.Close()
280
281 st, err := db.Prepare("SELECT 1")
282 if err != nil {
283 t.Fatal(err)
284 }
285
286 st1, err := db.Prepare("SELECT 2")
287 if err != nil {
288 t.Fatal(err)
289 }
290
291 r, err := st.Query()
292 if err != nil {
293 t.Fatal(err)
294 }
295 defer r.Close()
296
297 if !r.Next() {
298 t.Fatal("expected row")
299 }
300
301 var i int
302 err = r.Scan(&i)
303 if err != nil {
304 t.Fatal(err)
305 }
306
307 if i != 1 {
308 t.Fatalf("expected 1, got %d", i)
309 }
310
311
312
313 r1, err := st1.Query()
314 if err != nil {
315 t.Fatal(err)
316 }
317 defer r1.Close()
318
319 if !r1.Next() {
320 if r.Err() != nil {
321 t.Fatal(r1.Err())
322 }
323 t.Fatal("expected row")
324 }
325
326 err = r1.Scan(&i)
327 if err != nil {
328 t.Fatal(err)
329 }
330
331 if i != 2 {
332 t.Fatalf("expected 2, got %d", i)
333 }
334 }
335
336 func TestRowsCloseBeforeDone(t *testing.T) {
337 db := openTestConn(t)
338 defer db.Close()
339
340 r, err := db.Query("SELECT 1")
341 if err != nil {
342 t.Fatal(err)
343 }
344
345 err = r.Close()
346 if err != nil {
347 t.Fatal(err)
348 }
349
350 if r.Next() {
351 t.Fatal("unexpected row")
352 }
353
354 if r.Err() != nil {
355 t.Fatal(r.Err())
356 }
357 }
358
359 func TestParameterCountMismatch(t *testing.T) {
360 db := openTestConn(t)
361 defer db.Close()
362
363 var notused int
364 err := db.QueryRow("SELECT false", 1).Scan(¬used)
365 if err == nil {
366 t.Fatal("expected err")
367 }
368
369 err = db.QueryRow("SELECT 1").Scan(¬used)
370 if err != nil {
371 t.Fatal(err)
372 }
373
374 err = db.QueryRow("SELECT $1").Scan(¬used)
375 if err == nil {
376 t.Fatal("expected err")
377 }
378
379 err = db.QueryRow("SELECT 1").Scan(¬used)
380 if err != nil {
381 t.Fatal(err)
382 }
383 }
384
385
386 func TestEmptyQuery(t *testing.T) {
387 db := openTestConn(t)
388 defer db.Close()
389
390 res, err := db.Exec("")
391 if err != nil {
392 t.Fatal(err)
393 }
394 if _, err := res.RowsAffected(); err != errNoRowsAffected {
395 t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
396 }
397 if _, err := res.LastInsertId(); err != errNoLastInsertID {
398 t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
399 }
400 rows, err := db.Query("")
401 if err != nil {
402 t.Fatal(err)
403 }
404 cols, err := rows.Columns()
405 if err != nil {
406 t.Fatal(err)
407 }
408 if len(cols) != 0 {
409 t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
410 }
411 if rows.Next() {
412 t.Fatal("unexpected row")
413 }
414 if rows.Err() != nil {
415 t.Fatal(rows.Err())
416 }
417
418 stmt, err := db.Prepare("")
419 if err != nil {
420 t.Fatal(err)
421 }
422 res, err = stmt.Exec()
423 if err != nil {
424 t.Fatal(err)
425 }
426 if _, err := res.RowsAffected(); err != errNoRowsAffected {
427 t.Fatalf("expected %s, got %v", errNoRowsAffected, err)
428 }
429 if _, err := res.LastInsertId(); err != errNoLastInsertID {
430 t.Fatalf("expected %s, got %v", errNoLastInsertID, err)
431 }
432 rows, err = stmt.Query()
433 if err != nil {
434 t.Fatal(err)
435 }
436 cols, err = rows.Columns()
437 if err != nil {
438 t.Fatal(err)
439 }
440 if len(cols) != 0 {
441 t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
442 }
443 if rows.Next() {
444 t.Fatal("unexpected row")
445 }
446 if rows.Err() != nil {
447 t.Fatal(rows.Err())
448 }
449 }
450
451
452 func TestEmptyResultSetColumns(t *testing.T) {
453 db := openTestConn(t)
454 defer db.Close()
455
456 rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar WHERE FALSE")
457 if err != nil {
458 t.Fatal(err)
459 }
460 cols, err := rows.Columns()
461 if err != nil {
462 t.Fatal(err)
463 }
464 if len(cols) != 2 {
465 t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
466 }
467 if rows.Next() {
468 t.Fatal("unexpected row")
469 }
470 if rows.Err() != nil {
471 t.Fatal(rows.Err())
472 }
473 if cols[0] != "a" || cols[1] != "bar" {
474 t.Fatalf("unexpected Columns result %v", cols)
475 }
476
477 stmt, err := db.Prepare("SELECT $1::int AS a, text 'bar' AS bar WHERE FALSE")
478 if err != nil {
479 t.Fatal(err)
480 }
481 rows, err = stmt.Query(1)
482 if err != nil {
483 t.Fatal(err)
484 }
485 cols, err = rows.Columns()
486 if err != nil {
487 t.Fatal(err)
488 }
489 if len(cols) != 2 {
490 t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols))
491 }
492 if rows.Next() {
493 t.Fatal("unexpected row")
494 }
495 if rows.Err() != nil {
496 t.Fatal(rows.Err())
497 }
498 if cols[0] != "a" || cols[1] != "bar" {
499 t.Fatalf("unexpected Columns result %v", cols)
500 }
501
502 }
503
504 func TestEncodeDecode(t *testing.T) {
505 db := openTestConn(t)
506 defer db.Close()
507
508 q := `
509 SELECT
510 E'\\000\\001\\002'::bytea,
511 'foobar'::text,
512 NULL::integer,
513 '2000-1-1 01:02:03.04-7'::timestamptz,
514 0::boolean,
515 123,
516 -321,
517 3.14::float8
518 WHERE
519 E'\\000\\001\\002'::bytea = $1
520 AND 'foobar'::text = $2
521 AND $3::integer is NULL
522 `
523
524
525 exp1 := []byte{0, 1, 2}
526 exp2 := "foobar"
527
528 r, err := db.Query(q, exp1, exp2, nil)
529 if err != nil {
530 t.Fatal(err)
531 }
532 defer r.Close()
533
534 if !r.Next() {
535 if r.Err() != nil {
536 t.Fatal(r.Err())
537 }
538 t.Fatal("expected row")
539 }
540
541 var got1 []byte
542 var got2 string
543 var got3 = sql.NullInt64{Valid: true}
544 var got4 time.Time
545 var got5, got6, got7, got8 interface{}
546
547 err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7, &got8)
548 if err != nil {
549 t.Fatal(err)
550 }
551
552 if !reflect.DeepEqual(exp1, got1) {
553 t.Errorf("expected %q byte: %q", exp1, got1)
554 }
555
556 if !reflect.DeepEqual(exp2, got2) {
557 t.Errorf("expected %q byte: %q", exp2, got2)
558 }
559
560 if got3.Valid {
561 t.Fatal("expected invalid")
562 }
563
564 if got4.Year() != 2000 {
565 t.Fatal("wrong year")
566 }
567
568 if got5 != false {
569 t.Fatalf("expected false, got %q", got5)
570 }
571
572 if got6 != int64(123) {
573 t.Fatalf("expected 123, got %d", got6)
574 }
575
576 if got7 != int64(-321) {
577 t.Fatalf("expected -321, got %d", got7)
578 }
579
580 if got8 != float64(3.14) {
581 t.Fatalf("expected 3.14, got %f", got8)
582 }
583 }
584
585 func TestNoData(t *testing.T) {
586 db := openTestConn(t)
587 defer db.Close()
588
589 st, err := db.Prepare("SELECT 1 WHERE true = false")
590 if err != nil {
591 t.Fatal(err)
592 }
593 defer st.Close()
594
595 r, err := st.Query()
596 if err != nil {
597 t.Fatal(err)
598 }
599 defer r.Close()
600
601 if r.Next() {
602 if r.Err() != nil {
603 t.Fatal(r.Err())
604 }
605 t.Fatal("unexpected row")
606 }
607
608 _, err = db.Query("SELECT * FROM nonexistenttable WHERE age=$1", 20)
609 if err == nil {
610 t.Fatal("Should have raised an error on non existent table")
611 }
612
613 _, err = db.Query("SELECT * FROM nonexistenttable")
614 if err == nil {
615 t.Fatal("Should have raised an error on non existent table")
616 }
617 }
618
619 func TestErrorDuringStartup(t *testing.T) {
620
621
622 db, err := openTestConnConninfo("user=thisuserreallydoesntexist")
623 if err != nil {
624 t.Fatal(err)
625 }
626 defer db.Close()
627
628 _, err = db.Begin()
629 if err == nil {
630 t.Fatal("expected error")
631 }
632
633 e, ok := err.(*Error)
634 if !ok {
635 t.Fatalf("expected Error, got %#v", err)
636 } else if e.Code.Name() != "invalid_authorization_specification" && e.Code.Name() != "invalid_password" {
637 t.Fatalf("expected invalid_authorization_specification or invalid_password, got %s (%+v)", e.Code.Name(), err)
638 }
639 }
640
// testConn wraps a net.Conn and records whether Close was called, so tests
// can detect leaked connections.
type testConn struct {
	closed bool
	net.Conn
}

// Close marks the connection as closed and closes the underlying net.Conn.
func (c *testConn) Close() error {
	c.closed = true
	return c.Conn.Close()
}

// testDialer dials real network connections and remembers every one it hands
// out, wrapped in a testConn. It is not safe for concurrent use.
type testDialer struct {
	conns []*testConn
}

// Dial connects with net.Dial and tracks the resulting connection.
func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) {
	nc, err := net.Dial(ntw, addr)
	if err != nil {
		return nil, err
	}
	return d.track(nc), nil
}

// DialTimeout connects with net.DialTimeout and tracks the connection.
func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
	nc, err := net.DialTimeout(ntw, addr, timeout)
	if err != nil {
		return nil, err
	}
	return d.track(nc), nil
}

// track wraps nc in a testConn and appends it to d.conns.
func (d *testDialer) track(nc net.Conn) *testConn {
	wrapped := &testConn{Conn: nc}
	d.conns = append(d.conns, wrapped)
	return wrapped
}
674
675 func TestErrorDuringStartupClosesConn(t *testing.T) {
676
677
678 var d testDialer
679 c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist"))
680 if err == nil {
681 c.Close()
682 t.Fatal("expected dial error")
683 }
684 if len(d.conns) != 1 {
685 t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1)
686 }
687 if !d.conns[0].closed {
688 t.Error("connection leaked")
689 }
690 }
691
692 func TestBadConn(t *testing.T) {
693 var err error
694
695 cn := conn{}
696 func() {
697 defer cn.errRecover(&err)
698 panic(io.EOF)
699 }()
700 if err != driver.ErrBadConn {
701 t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
702 }
703 if err := cn.err.get(); err != driver.ErrBadConn {
704 t.Fatalf("expected driver.ErrBadConn, got %#v", err)
705 }
706
707 cn = conn{}
708 func() {
709 defer cn.errRecover(&err)
710 e := &Error{Severity: Efatal}
711 panic(e)
712 }()
713 if err != driver.ErrBadConn {
714 t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
715 }
716 if err := cn.err.get(); err != driver.ErrBadConn {
717 t.Fatalf("expected driver.ErrBadConn, got %#v", err)
718 }
719 }
720
721
722
723 func TestCloseBadConn(t *testing.T) {
724 host := os.Getenv("PGHOST")
725 if host == "" {
726 host = "localhost"
727 }
728 port := os.Getenv("PGPORT")
729 if port == "" {
730 port = "5432"
731 }
732 nc, err := net.Dial("tcp", host+":"+port)
733 if err != nil {
734 t.Fatal(err)
735 }
736 cn := conn{c: nc}
737 func() {
738 defer cn.errRecover(&err)
739 panic(io.EOF)
740 }()
741
742 if _, err := nc.Write(nil); err != nil {
743 t.Fatal(err)
744 }
745
746 if err := cn.Close(); err != nil {
747 t.Fatal(err)
748 }
749
750
751
752
753
754
755
756
757
758 const errClosing = "use of closed"
759
760
761 if _, err := nc.Write(nil); err == nil {
762 t.Fatal("expected error")
763 } else if !strings.Contains(err.Error(), errClosing) {
764 t.Fatalf("expected %s error, got %s", errClosing, err)
765 }
766
767 if err := cn.Close(); err == nil {
768 t.Fatal("expected error")
769 } else if !strings.Contains(err.Error(), errClosing) {
770 t.Fatalf("expected %s error, got %s", errClosing, err)
771 }
772 }
773
774 func TestErrorOnExec(t *testing.T) {
775 db := openTestConn(t)
776 defer db.Close()
777
778 txn, err := db.Begin()
779 if err != nil {
780 t.Fatal(err)
781 }
782 defer txn.Rollback()
783
784 _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)")
785 if err != nil {
786 t.Fatal(err)
787 }
788
789 _, err = txn.Exec("INSERT INTO foo VALUES (0), (0)")
790 if err == nil {
791 t.Fatal("Should have raised error")
792 }
793
794 e, ok := err.(*Error)
795 if !ok {
796 t.Fatalf("expected Error, got %#v", err)
797 } else if e.Code.Name() != "unique_violation" {
798 t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err)
799 }
800 }
801
802 func TestErrorOnQuery(t *testing.T) {
803 db := openTestConn(t)
804 defer db.Close()
805
806 txn, err := db.Begin()
807 if err != nil {
808 t.Fatal(err)
809 }
810 defer txn.Rollback()
811
812 _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)")
813 if err != nil {
814 t.Fatal(err)
815 }
816
817 _, err = txn.Query("INSERT INTO foo VALUES (0), (0)")
818 if err == nil {
819 t.Fatal("Should have raised error")
820 }
821
822 e, ok := err.(*Error)
823 if !ok {
824 t.Fatalf("expected Error, got %#v", err)
825 } else if e.Code.Name() != "unique_violation" {
826 t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err)
827 }
828 }
829
830 func TestErrorOnQueryRowSimpleQuery(t *testing.T) {
831 db := openTestConn(t)
832 defer db.Close()
833
834 txn, err := db.Begin()
835 if err != nil {
836 t.Fatal(err)
837 }
838 defer txn.Rollback()
839
840 _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)")
841 if err != nil {
842 t.Fatal(err)
843 }
844
845 var v int
846 err = txn.QueryRow("INSERT INTO foo VALUES (0), (0)").Scan(&v)
847 if err == nil {
848 t.Fatal("Should have raised error")
849 }
850
851 e, ok := err.(*Error)
852 if !ok {
853 t.Fatalf("expected Error, got %#v", err)
854 } else if e.Code.Name() != "unique_violation" {
855 t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err)
856 }
857 }
858
859
860 func TestQueryRowBugWorkaround(t *testing.T) {
861 db := openTestConn(t)
862 defer db.Close()
863
864
865 _, err := db.Exec("CREATE TEMP TABLE notnulltemp (a varchar(10) not null)")
866 if err != nil {
867 t.Fatal(err)
868 }
869
870 var a string
871 err = db.QueryRow("INSERT INTO notnulltemp(a) values($1) RETURNING a", nil).Scan(&a)
872 if err == sql.ErrNoRows {
873 t.Fatalf("expected constraint violation error; got: %v", err)
874 }
875 pge, ok := err.(*Error)
876 if !ok {
877 t.Fatalf("expected *Error; got: %#v", err)
878 }
879 if pge.Code.Name() != "not_null_violation" {
880 t.Fatalf("expected not_null_violation; got: %s (%+v)", pge.Code.Name(), err)
881 }
882
883
884 tx, err := db.Begin()
885 if err != nil {
886 t.Fatalf("unexpected error %s in Begin", err)
887 }
888 defer tx.Rollback()
889
890 _, err = tx.Exec("SET LOCAL check_function_bodies TO FALSE")
891 if err != nil {
892 t.Fatalf("could not disable check_function_bodies: %s", err)
893 }
894 _, err = tx.Exec(`
895 CREATE OR REPLACE FUNCTION bad_function()
896 RETURNS integer
897 -- hack to prevent the function from being inlined
898 SET check_function_bodies TO TRUE
899 AS $$
900 SELECT text 'bad'
901 $$ LANGUAGE sql`)
902 if err != nil {
903 t.Fatalf("could not create function: %s", err)
904 }
905
906 err = tx.QueryRow("SELECT * FROM bad_function()").Scan(&a)
907 if err == nil {
908 t.Fatalf("expected error")
909 }
910 pge, ok = err.(*Error)
911 if !ok {
912 t.Fatalf("expected *Error; got: %#v", err)
913 }
914 if pge.Code.Name() != "invalid_function_definition" {
915 t.Fatalf("expected invalid_function_definition; got: %s (%+v)", pge.Code.Name(), err)
916 }
917
918 err = tx.Rollback()
919 if err != nil {
920 t.Fatalf("unexpected error %s in Rollback", err)
921 }
922
923
924
925 rows, err := db.Query(`
926 select
927 (select generate_series(1, ss.i))
928 from (select gs.i
929 from generate_series(1, 2) gs(i)
930 order by gs.i limit 2) ss`)
931 if err != nil {
932 t.Fatalf("query failed: %s", err)
933 }
934 if !rows.Next() {
935 t.Fatalf("expected at least one result row; got %s", rows.Err())
936 }
937 var i int
938 err = rows.Scan(&i)
939 if err != nil {
940 t.Fatalf("rows.Scan() failed: %s", err)
941 }
942 if i != 1 {
943 t.Fatalf("unexpected value for i: %d", i)
944 }
945 if rows.Next() {
946 t.Fatalf("unexpected row")
947 }
948 pge, ok = rows.Err().(*Error)
949 if !ok {
950 t.Fatalf("expected *Error; got: %#v", err)
951 }
952 if pge.Code.Name() != "cardinality_violation" {
953 t.Fatalf("expected cardinality_violation; got: %s (%+v)", pge.Code.Name(), rows.Err())
954 }
955 }
956
957 func TestSimpleQuery(t *testing.T) {
958 db := openTestConn(t)
959 defer db.Close()
960
961 r, err := db.Query("select 1")
962 if err != nil {
963 t.Fatal(err)
964 }
965 defer r.Close()
966
967 if !r.Next() {
968 t.Fatal("expected row")
969 }
970 }
971
972 func TestBindError(t *testing.T) {
973 db := openTestConn(t)
974 defer db.Close()
975
976 _, err := db.Exec("create temp table test (i integer)")
977 if err != nil {
978 t.Fatal(err)
979 }
980
981 _, err = db.Query("select * from test where i=$1", "hhh")
982 if err == nil {
983 t.Fatal("expected an error")
984 }
985
986
987 r, err := db.Query("select * from test where i=$1", 1)
988 if err != nil {
989 t.Fatal(err)
990 }
991 defer r.Close()
992 }
993
994 func TestParseErrorInExtendedQuery(t *testing.T) {
995 db := openTestConn(t)
996 defer db.Close()
997
998 _, err := db.Query("PARSE_ERROR $1", 1)
999 pqErr, _ := err.(*Error)
1000
1001 if err == nil || pqErr == nil || pqErr.Code != "42601" {
1002 t.Fatalf("expected syntax error, got %s", err)
1003 }
1004
1005 rows, err := db.Query("SELECT 1")
1006 if err != nil {
1007 t.Fatal(err)
1008 }
1009 rows.Close()
1010 }
1011
1012
1013 func TestReturning(t *testing.T) {
1014 db := openTestConn(t)
1015 defer db.Close()
1016
1017 _, err := db.Exec("CREATE TEMP TABLE distributors (did integer default 0, dname text)")
1018 if err != nil {
1019 t.Fatal(err)
1020 }
1021
1022 rows, err := db.Query("INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') " +
1023 "RETURNING did;")
1024 if err != nil {
1025 t.Fatal(err)
1026 }
1027 if !rows.Next() {
1028 t.Fatal("no rows")
1029 }
1030 var did int
1031 err = rows.Scan(&did)
1032 if err != nil {
1033 t.Fatal(err)
1034 }
1035 if did != 0 {
1036 t.Fatalf("bad value for did: got %d, want %d", did, 0)
1037 }
1038
1039 if rows.Next() {
1040 t.Fatal("unexpected next row")
1041 }
1042 err = rows.Err()
1043 if err != nil {
1044 t.Fatal(err)
1045 }
1046 }
1047
1048 func TestIssue186(t *testing.T) {
1049 db := openTestConn(t)
1050 defer db.Close()
1051
1052
1053 _, err := db.Exec("VALUES (1), (2), (3)")
1054 if err != nil {
1055 t.Fatal(err)
1056 }
1057
1058 _, err = db.Exec("VALUES ($1), ($2), ($3)", 1, 2, 3)
1059 if err != nil {
1060 t.Fatal(err)
1061 }
1062
1063
1064 txn, err := db.Begin()
1065 if err != nil {
1066 t.Fatal(err)
1067 }
1068 defer txn.Rollback()
1069
1070 rows, err := txn.Query("CREATE TEMP TABLE foo(f1 int)")
1071 if err != nil {
1072 t.Fatal(err)
1073 }
1074 if err = rows.Close(); err != nil {
1075 t.Fatal(err)
1076 }
1077
1078
1079 _, err = txn.Exec("CREATE RULE nodata AS ON INSERT TO foo DO INSTEAD NOTHING")
1080 if err != nil {
1081 t.Fatal(err)
1082 }
1083 rows, err = txn.Query("INSERT INTO foo VALUES ($1)", 1)
1084 if err != nil {
1085 t.Fatal(err)
1086 }
1087 if err = rows.Close(); err != nil {
1088 t.Fatal(err)
1089 }
1090 }
1091
1092 func TestIssue196(t *testing.T) {
1093 db := openTestConn(t)
1094 defer db.Close()
1095
1096 row := db.QueryRow("SELECT float4 '0.10000122' = $1, float8 '35.03554004971999' = $2",
1097 float32(0.10000122), float64(35.03554004971999))
1098
1099 var float4match, float8match bool
1100 err := row.Scan(&float4match, &float8match)
1101 if err != nil {
1102 t.Fatal(err)
1103 }
1104 if !float4match {
1105 t.Errorf("Expected float4 fidelity to be maintained; got no match")
1106 }
1107 if !float8match {
1108 t.Errorf("Expected float8 fidelity to be maintained; got no match")
1109 }
1110 }
1111
1112
1113
1114 func TestIssue282(t *testing.T) {
1115 db := openTestConn(t)
1116 defer db.Close()
1117
1118 var searchPath string
1119 err := db.QueryRow(`
1120 SET LOCAL search_path TO pg_catalog;
1121 SET LOCAL search_path TO pg_catalog;
1122 SHOW search_path`).Scan(&searchPath)
1123 if err != nil {
1124 t.Fatal(err)
1125 }
1126 if searchPath != "pg_catalog" {
1127 t.Fatalf("unexpected search_path %s", searchPath)
1128 }
1129 }
1130
1131 func TestReadFloatPrecision(t *testing.T) {
1132 db := openTestConn(t)
1133 defer db.Close()
1134
1135 row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999', float4 '1.2'")
1136 var float4val float32
1137 var float8val float64
1138 var float4val2 float64
1139 err := row.Scan(&float4val, &float8val, &float4val2)
1140 if err != nil {
1141 t.Fatal(err)
1142 }
1143 if float4val != float32(0.10000122) {
1144 t.Errorf("Expected float4 fidelity to be maintained; got no match")
1145 }
1146 if float8val != float64(35.03554004971999) {
1147 t.Errorf("Expected float8 fidelity to be maintained; got no match")
1148 }
1149 if float4val2 != float64(1.2) {
1150 t.Errorf("Expected float4 fidelity into a float64 to be maintained; got no match")
1151 }
1152 }
1153
1154 func TestXactMultiStmt(t *testing.T) {
1155
1156
1157 t.Skip("Skipping failing test")
1158 db := openTestConn(t)
1159 defer db.Close()
1160
1161 tx, err := db.Begin()
1162 if err != nil {
1163 t.Fatal(err)
1164 }
1165 defer tx.Commit()
1166
1167 rows, err := tx.Query("select 1")
1168 if err != nil {
1169 t.Fatal(err)
1170 }
1171
1172 if rows.Next() {
1173 var val int32
1174 if err = rows.Scan(&val); err != nil {
1175 t.Fatal(err)
1176 }
1177 } else {
1178 t.Fatal("Expected at least one row in first query in xact")
1179 }
1180
1181 rows2, err := tx.Query("select 2")
1182 if err != nil {
1183 t.Fatal(err)
1184 }
1185
1186 if rows2.Next() {
1187 var val2 int32
1188 if err := rows2.Scan(&val2); err != nil {
1189 t.Fatal(err)
1190 }
1191 } else {
1192 t.Fatal("Expected at least one row in second query in xact")
1193 }
1194
1195 if err = rows.Err(); err != nil {
1196 t.Fatal(err)
1197 }
1198
1199 if err = rows2.Err(); err != nil {
1200 t.Fatal(err)
1201 }
1202
1203 if err = tx.Commit(); err != nil {
1204 t.Fatal(err)
1205 }
1206 }
1207
// envParseTests pairs environment entries in "KEY=value" form with the
// connection options that parseEnviron is expected to derive from them.
var envParseTests = []struct {
	Expected map[string]string
	Env      []string
}{
	{
		Expected: map[string]string{"dbname": "hello", "user": "goodbye"},
		Env:      []string{"PGDATABASE=hello", "PGUSER=goodbye"},
	},
	{
		Expected: map[string]string{"datestyle": "ISO, MDY"},
		Env:      []string{"PGDATESTYLE=ISO, MDY"},
	},
	{
		Expected: map[string]string{"connect_timeout": "30"},
		Env:      []string{"PGCONNECT_TIMEOUT=30"},
	},
}
1225
1226 func TestParseEnviron(t *testing.T) {
1227 for i, tt := range envParseTests {
1228 results := parseEnviron(tt.Env)
1229 if !reflect.DeepEqual(tt.Expected, results) {
1230 t.Errorf("%d: Expected: %#v Got: %#v", i, tt.Expected, results)
1231 }
1232 }
1233 }
1234
1235 func TestParseComplete(t *testing.T) {
1236 tpc := func(commandTag string, command string, affectedRows int64, shouldFail bool) {
1237 defer func() {
1238 if p := recover(); p != nil {
1239 if !shouldFail {
1240 t.Error(p)
1241 }
1242 }
1243 }()
1244 cn := &conn{}
1245 res, c := cn.parseComplete(commandTag)
1246 if c != command {
1247 t.Errorf("Expected %v, got %v", command, c)
1248 }
1249 n, err := res.RowsAffected()
1250 if err != nil {
1251 t.Fatal(err)
1252 }
1253 if n != affectedRows {
1254 t.Errorf("Expected %d, got %d", affectedRows, n)
1255 }
1256 }
1257
1258 tpc("ALTER TABLE", "ALTER TABLE", 0, false)
1259 tpc("INSERT 0 1", "INSERT", 1, false)
1260 tpc("UPDATE 100", "UPDATE", 100, false)
1261 tpc("SELECT 100", "SELECT", 100, false)
1262 tpc("FETCH 100", "FETCH", 100, false)
1263
1264 tpc("COPY", "COPY", 0, false)
1265
1266 tpc("UNKNOWNCOMMANDTAG", "UNKNOWNCOMMANDTAG", 0, false)
1267
1268
1269 tpc("INSERT 1", "", 0, true)
1270 tpc("UPDATE 0 1", "", 0, true)
1271 tpc("SELECT foo", "", 0, true)
1272 }
1273
1274
1275 var (
1276 _ driver.ExecerContext = (*conn)(nil)
1277 _ driver.QueryerContext = (*conn)(nil)
1278 )
1279
1280 func TestNullAfterNonNull(t *testing.T) {
1281 db := openTestConn(t)
1282 defer db.Close()
1283
1284 r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer")
1285 if err != nil {
1286 t.Fatal(err)
1287 }
1288
1289 var n sql.NullInt64
1290
1291 if !r.Next() {
1292 if r.Err() != nil {
1293 t.Fatal(err)
1294 }
1295 t.Fatal("expected row")
1296 }
1297
1298 if err := r.Scan(&n); err != nil {
1299 t.Fatal(err)
1300 }
1301
1302 if n.Int64 != 9 {
1303 t.Fatalf("expected 2, not %d", n.Int64)
1304 }
1305
1306 if !r.Next() {
1307 if r.Err() != nil {
1308 t.Fatal(err)
1309 }
1310 t.Fatal("expected row")
1311 }
1312
1313 if err := r.Scan(&n); err != nil {
1314 t.Fatal(err)
1315 }
1316
1317 if n.Valid {
1318 t.Fatal("expected n to be invalid")
1319 }
1320
1321 if n.Int64 != 0 {
1322 t.Fatalf("expected n to 2, not %d", n.Int64)
1323 }
1324 }
1325
1326 func Test64BitErrorChecking(t *testing.T) {
1327 defer func() {
1328 if err := recover(); err != nil {
1329 t.Fatal("panic due to 0xFFFFFFFF != -1 " +
1330 "when int is 64 bits")
1331 }
1332 }()
1333
1334 db := openTestConn(t)
1335 defer db.Close()
1336
1337 r, err := db.Query(`SELECT *
1338 FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`)
1339
1340 if err != nil {
1341 t.Fatal(err)
1342 }
1343
1344 defer r.Close()
1345
1346 for r.Next() {
1347 }
1348 }
1349
1350 func TestCommit(t *testing.T) {
1351 db := openTestConn(t)
1352 defer db.Close()
1353
1354 _, err := db.Exec("CREATE TEMP TABLE temp (a int)")
1355 if err != nil {
1356 t.Fatal(err)
1357 }
1358 sqlInsert := "INSERT INTO temp VALUES (1)"
1359 sqlSelect := "SELECT * FROM temp"
1360 tx, err := db.Begin()
1361 if err != nil {
1362 t.Fatal(err)
1363 }
1364 _, err = tx.Exec(sqlInsert)
1365 if err != nil {
1366 t.Fatal(err)
1367 }
1368 err = tx.Commit()
1369 if err != nil {
1370 t.Fatal(err)
1371 }
1372 var i int
1373 err = db.QueryRow(sqlSelect).Scan(&i)
1374 if err != nil {
1375 t.Fatal(err)
1376 }
1377 if i != 1 {
1378 t.Fatalf("expected 1, got %d", i)
1379 }
1380 }
1381
1382 func TestErrorClass(t *testing.T) {
1383 db := openTestConn(t)
1384 defer db.Close()
1385
1386 _, err := db.Query("SELECT int 'notint'")
1387 if err == nil {
1388 t.Fatal("expected error")
1389 }
1390 pge, ok := err.(*Error)
1391 if !ok {
1392 t.Fatalf("expected *pq.Error, got %#+v", err)
1393 }
1394 if pge.Code.Class() != "22" {
1395 t.Fatalf("expected class 28, got %v", pge.Code.Class())
1396 }
1397 if pge.Code.Class().Name() != "data_exception" {
1398 t.Fatalf("expected data_exception, got %v", pge.Code.Class().Name())
1399 }
1400 }
1401
1402 func TestParseOpts(t *testing.T) {
1403 tests := []struct {
1404 in string
1405 expected values
1406 valid bool
1407 }{
1408 {"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
1409 {"dbname=hello user=goodbye ", values{"dbname": "hello", "user": "goodbye"}, true},
1410 {"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
1411 {"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
1412 {"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, true},
1413 {"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, true},
1414 {"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, true},
1415 {"dbname=hello user=''", values{"dbname": "hello", "user": ""}, true},
1416 {"user='' dbname=hello", values{"dbname": "hello", "user": ""}, true},
1417
1418 {"dbname=hello user= ", values{"dbname": "hello", "user": ""}, true},
1419
1420
1421 {"user= password=foo", values{"user": "password=foo"}, true},
1422
1423
1424 {`user=a\ \'\\b`, values{"user": `a '\b`}, true},
1425 {`user='a \'b'`, values{"user": `a 'b`}, true},
1426
1427
1428 {`user=x\`, values{}, false},
1429
1430
1431 {"postgre://marko@internet", values{}, false},
1432 {"dbname user=goodbye", values{}, false},
1433 {"user=foo blah", values{}, false},
1434 {"user=foo blah ", values{}, false},
1435
1436
1437 {"dbname=hello user='unterminated", values{}, false},
1438 }
1439
1440 for _, test := range tests {
1441 o := make(values)
1442 err := parseOpts(test.in, o)
1443
1444 switch {
1445 case err != nil && test.valid:
1446 t.Errorf("%q got unexpected error: %s", test.in, err)
1447 case err == nil && test.valid && !reflect.DeepEqual(test.expected, o):
1448 t.Errorf("%q got: %#v want: %#v", test.in, o, test.expected)
1449 case err == nil && !test.valid:
1450 t.Errorf("%q expected an error", test.in)
1451 }
1452 }
1453 }
1454
1455 func TestRuntimeParameters(t *testing.T) {
1456 tests := []struct {
1457 conninfo string
1458 param string
1459 expected string
1460 success bool
1461 }{
1462
1463 {"DOESNOTEXIST=foo", "", "", false},
1464
1465 {"client_encoding=SQL_ASCII", "", "", false},
1466 {"datestyle='ISO, YDM'", "", "", false},
1467
1468 {"options='-c search_path=pqgotest'", "search_path", "pqgotest", true},
1469
1470 {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", true},
1471
1472 {"client_encoding=UTF8", "client_encoding", "UTF8", true},
1473
1474 {"work_mem='139kB'", "work_mem", "139kB", true},
1475
1476 {"application_name=foo fallback_application_name=bar", "application_name", "foo", true},
1477 {"application_name='' fallback_application_name=bar", "application_name", "", true},
1478 {"fallback_application_name=bar", "application_name", "bar", true},
1479 }
1480
1481 for _, test := range tests {
1482 db, err := openTestConnConninfo(test.conninfo)
1483 if err != nil {
1484 t.Fatal(err)
1485 }
1486
1487
1488 if test.param == "application_name" && getServerVersion(t, db) < 90000 {
1489 db.Close()
1490 continue
1491 }
1492
1493 tryGetParameterValue := func() (value string, success bool) {
1494 defer db.Close()
1495 row := db.QueryRow("SELECT current_setting($1)", test.param)
1496 err = row.Scan(&value)
1497 if err != nil {
1498 return "", false
1499 }
1500 return value, true
1501 }
1502
1503 value, success := tryGetParameterValue()
1504 if success != test.success && !test.success {
1505 t.Fatalf("%v: unexpected error: %v", test.conninfo, err)
1506 }
1507 if success != test.success {
1508 t.Fatalf("unexpected outcome %v (was expecting %v) for conninfo \"%s\"",
1509 success, test.success, test.conninfo)
1510 }
1511 if value != test.expected {
1512 t.Fatalf("bad value for %s: got %s, want %s with conninfo \"%s\"",
1513 test.param, value, test.expected, test.conninfo)
1514 }
1515 }
1516 }
1517
1518 func TestIsUTF8(t *testing.T) {
1519 var cases = []struct {
1520 name string
1521 want bool
1522 }{
1523 {"unicode", true},
1524 {"utf-8", true},
1525 {"utf_8", true},
1526 {"UTF-8", true},
1527 {"UTF8", true},
1528 {"utf8", true},
1529 {"u n ic_ode", true},
1530 {"ut_f%8", true},
1531 {"ubf8", false},
1532 {"punycode", false},
1533 }
1534
1535 for _, test := range cases {
1536 if g := isUTF8(test.name); g != test.want {
1537 t.Errorf("isUTF8(%q) = %v want %v", test.name, g, test.want)
1538 }
1539 }
1540 }
1541
1542 func TestQuoteIdentifier(t *testing.T) {
1543 var cases = []struct {
1544 input string
1545 want string
1546 }{
1547 {`foo`, `"foo"`},
1548 {`foo bar baz`, `"foo bar baz"`},
1549 {`foo"bar`, `"foo""bar"`},
1550 {"foo\x00bar", `"foo"`},
1551 {"\x00foo", `""`},
1552 }
1553
1554 for _, test := range cases {
1555 got := QuoteIdentifier(test.input)
1556 if got != test.want {
1557 t.Errorf("QuoteIdentifier(%q) = %v want %v", test.input, got, test.want)
1558 }
1559 }
1560 }
1561
1562 func TestQuoteLiteral(t *testing.T) {
1563 var cases = []struct {
1564 input string
1565 want string
1566 }{
1567 {`foo`, `'foo'`},
1568 {`foo bar baz`, `'foo bar baz'`},
1569 {`foo'bar`, `'foo''bar'`},
1570 {`foo\bar`, ` E'foo\\bar'`},
1571 {`foo\ba'r`, ` E'foo\\ba''r'`},
1572 {`foo"bar`, `'foo"bar'`},
1573 {`foo\x00bar`, ` E'foo\\x00bar'`},
1574 {`\x00foo`, ` E'\\x00foo'`},
1575 {`'`, `''''`},
1576 {`''`, `''''''`},
1577 {`\`, ` E'\\'`},
1578 {`'abc'; DROP TABLE users;`, `'''abc''; DROP TABLE users;'`},
1579 {`\'`, ` E'\\'''`},
1580 {`E'\''`, ` E'E''\\'''''`},
1581 {`e'\''`, ` E'e''\\'''''`},
1582 {`E'\'abc\'; DROP TABLE users;'`, ` E'E''\\''abc\\''; DROP TABLE users;'''`},
1583 {`e'\'abc\'; DROP TABLE users;'`, ` E'e''\\''abc\\''; DROP TABLE users;'''`},
1584 }
1585
1586 for _, test := range cases {
1587 got := QuoteLiteral(test.input)
1588 if got != test.want {
1589 t.Errorf("QuoteLiteral(%q) = %v want %v", test.input, got, test.want)
1590 }
1591 }
1592 }
1593
1594 func TestRowsResultTag(t *testing.T) {
1595 type ResultTag interface {
1596 Result() driver.Result
1597 Tag() string
1598 }
1599
1600 tests := []struct {
1601 query string
1602 tag string
1603 ra int64
1604 }{
1605 {
1606 query: "CREATE TEMP TABLE temp (a int)",
1607 tag: "CREATE TABLE",
1608 },
1609 {
1610 query: "INSERT INTO temp VALUES (1), (2)",
1611 tag: "INSERT",
1612 ra: 2,
1613 },
1614 {
1615 query: "SELECT 1",
1616 },
1617
1618 {
1619 query: "SELECT 1; INSERT INTO temp VALUES (1), (2)",
1620 },
1621 {
1622 query: "INSERT INTO temp VALUES (1), (2); SELECT 1",
1623 },
1624
1625 {
1626 query: "CREATE TEMP TABLE t (a int); DROP TABLE t",
1627 tag: "DROP TABLE",
1628 },
1629
1630
1631 {
1632 query: "SELECT 1; CREATE TEMP TABLE t (a int); DROP TABLE t",
1633 },
1634 {
1635 query: "CREATE TEMP TABLE t (a int); SELECT 1; DROP TABLE t",
1636 },
1637 {
1638 query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1",
1639 },
1640 }
1641
1642
1643 openTestConn(t).Close()
1644
1645 conn, err := Open("")
1646 if err != nil {
1647 t.Fatal(err)
1648 }
1649 defer conn.Close()
1650 q := conn.(driver.QueryerContext)
1651
1652 for _, test := range tests {
1653 if rows, err := q.QueryContext(context.Background(), test.query, nil); err != nil {
1654 t.Fatalf("%s: %s", test.query, err)
1655 } else {
1656 r := rows.(ResultTag)
1657 if tag := r.Tag(); tag != test.tag {
1658 t.Fatalf("%s: unexpected tag %q", test.query, tag)
1659 }
1660 res := r.Result()
1661 if ra, _ := res.RowsAffected(); ra != test.ra {
1662 t.Fatalf("%s: unexpected rows affected: %d", test.query, ra)
1663 }
1664 rows.Close()
1665 }
1666 }
1667 }
1668
1669
1670 func TestQuickClose(t *testing.T) {
1671 db := openTestConn(t)
1672 defer db.Close()
1673
1674 tx, err := db.Begin()
1675 if err != nil {
1676 t.Fatal(err)
1677 }
1678 rows, err := tx.Query("SELECT 1; SELECT 2;")
1679 if err != nil {
1680 t.Fatal(err)
1681 }
1682 if err := rows.Close(); err != nil {
1683 t.Fatal(err)
1684 }
1685
1686 var id int
1687 if err := tx.QueryRow("SELECT 3").Scan(&id); err != nil {
1688 t.Fatal(err)
1689 }
1690 if id != 3 {
1691 t.Fatalf("unexpected %d", id)
1692 }
1693 if err := tx.Commit(); err != nil {
1694 t.Fatal(err)
1695 }
1696 }
1697
1698 func TestMultipleResult(t *testing.T) {
1699 db := openTestConn(t)
1700 defer db.Close()
1701
1702 rows, err := db.Query(`
1703 begin;
1704 select * from information_schema.tables limit 1;
1705 select * from information_schema.columns limit 2;
1706 commit;
1707 `)
1708 if err != nil {
1709 t.Fatal(err)
1710 }
1711 type set struct {
1712 cols []string
1713 rowCount int
1714 }
1715 buf := []*set{}
1716 for {
1717 cols, err := rows.Columns()
1718 if err != nil {
1719 t.Fatal(err)
1720 }
1721 s := &set{
1722 cols: cols,
1723 }
1724 buf = append(buf, s)
1725
1726 for rows.Next() {
1727 s.rowCount++
1728 }
1729 if !rows.NextResultSet() {
1730 break
1731 }
1732 }
1733 if len(buf) != 2 {
1734 t.Fatalf("got %d sets, expected 2", len(buf))
1735 }
1736 if len(buf[0].cols) == len(buf[1].cols) || len(buf[1].cols) == 0 {
1737 t.Fatal("invalid cols size, expected different column count and greater then zero")
1738 }
1739 if buf[0].rowCount != 1 || buf[1].rowCount != 2 {
1740 t.Fatal("incorrect number of rows returned")
1741 }
1742 }
1743
1744 func TestMultipleEmptyResult(t *testing.T) {
1745 db := openTestConn(t)
1746 defer db.Close()
1747
1748 rows, err := db.Query("select 1 where false; select 2")
1749 if err != nil {
1750 t.Fatal(err)
1751 }
1752 defer rows.Close()
1753
1754 for rows.Next() {
1755 t.Fatal("unexpected row")
1756 }
1757 if !rows.NextResultSet() {
1758 t.Fatal("expected more result sets", rows.Err())
1759 }
1760 for rows.Next() {
1761 var i int
1762 if err := rows.Scan(&i); err != nil {
1763 t.Fatal(err)
1764 }
1765 if i != 2 {
1766 t.Fatalf("expected 2, got %d", i)
1767 }
1768 }
1769 if rows.NextResultSet() {
1770 t.Fatal("unexpected result set")
1771 }
1772 }
1773
1774 func TestCopyInStmtAffectedRows(t *testing.T) {
1775 db := openTestConn(t)
1776 defer db.Close()
1777
1778 _, err := db.Exec("CREATE TEMP TABLE temp (a int)")
1779 if err != nil {
1780 t.Fatal(err)
1781 }
1782
1783 txn, err := db.BeginTx(context.TODO(), nil)
1784 if err != nil {
1785 t.Fatal(err)
1786 }
1787
1788 copyStmt, err := txn.Prepare(CopyIn("temp", "a"))
1789 if err != nil {
1790 t.Fatal(err)
1791 }
1792
1793 res, err := copyStmt.Exec()
1794 if err != nil {
1795 t.Fatal(err)
1796 }
1797
1798 res.RowsAffected()
1799 res.LastInsertId()
1800 }
1801
1802 func TestConnPrepareContext(t *testing.T) {
1803 db := openTestConn(t)
1804 defer db.Close()
1805
1806 tests := []struct {
1807 name string
1808 ctx func() (context.Context, context.CancelFunc)
1809 sql string
1810 err error
1811 }{
1812 {
1813 name: "context.Background",
1814 ctx: func() (context.Context, context.CancelFunc) {
1815 return context.Background(), nil
1816 },
1817 sql: "SELECT 1",
1818 err: nil,
1819 },
1820 {
1821 name: "context.WithTimeout exceeded",
1822 ctx: func() (context.Context, context.CancelFunc) {
1823 return context.WithTimeout(context.Background(), -time.Minute)
1824 },
1825 sql: "SELECT 1",
1826 err: context.DeadlineExceeded,
1827 },
1828 {
1829 name: "context.WithTimeout",
1830 ctx: func() (context.Context, context.CancelFunc) {
1831 return context.WithTimeout(context.Background(), time.Minute)
1832 },
1833 sql: "SELECT 1",
1834 err: nil,
1835 },
1836 }
1837 for _, tt := range tests {
1838 t.Run(tt.name, func(t *testing.T) {
1839 ctx, cancel := tt.ctx()
1840 if cancel != nil {
1841 defer cancel()
1842 }
1843 _, err := db.PrepareContext(ctx, tt.sql)
1844 switch {
1845 case (err != nil) != (tt.err != nil):
1846 t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err)
1847 case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
1848 t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error())
1849 }
1850 })
1851 }
1852 }
1853
1854 func TestStmtQueryContext(t *testing.T) {
1855 db := openTestConn(t)
1856 defer db.Close()
1857
1858 tests := []struct {
1859 name string
1860 ctx func() (context.Context, context.CancelFunc)
1861 sql string
1862 cancelExpected bool
1863 }{
1864 {
1865 name: "context.Background",
1866 ctx: func() (context.Context, context.CancelFunc) {
1867 return context.Background(), nil
1868 },
1869 sql: "SELECT pg_sleep(1);",
1870 cancelExpected: false,
1871 },
1872 {
1873 name: "context.WithTimeout exceeded",
1874 ctx: func() (context.Context, context.CancelFunc) {
1875 return context.WithTimeout(context.Background(), 1*time.Second)
1876 },
1877 sql: "SELECT pg_sleep(10);",
1878 cancelExpected: true,
1879 },
1880 {
1881 name: "context.WithTimeout",
1882 ctx: func() (context.Context, context.CancelFunc) {
1883 return context.WithTimeout(context.Background(), time.Minute)
1884 },
1885 sql: "SELECT pg_sleep(1);",
1886 cancelExpected: false,
1887 },
1888 }
1889 for _, tt := range tests {
1890 t.Run(tt.name, func(t *testing.T) {
1891 ctx, cancel := tt.ctx()
1892 if cancel != nil {
1893 defer cancel()
1894 }
1895 stmt, err := db.PrepareContext(ctx, tt.sql)
1896 if err != nil {
1897 t.Fatal(err)
1898 }
1899 _, err = stmt.QueryContext(ctx)
1900 pgErr := (*Error)(nil)
1901 switch {
1902 case (err != nil) != tt.cancelExpected:
1903 t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
1904 case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
1905 t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
1906 }
1907 })
1908 }
1909 }
1910
1911 func TestStmtExecContext(t *testing.T) {
1912 db := openTestConn(t)
1913 defer db.Close()
1914
1915 tests := []struct {
1916 name string
1917 ctx func() (context.Context, context.CancelFunc)
1918 sql string
1919 cancelExpected bool
1920 }{
1921 {
1922 name: "context.Background",
1923 ctx: func() (context.Context, context.CancelFunc) {
1924 return context.Background(), nil
1925 },
1926 sql: "SELECT pg_sleep(1);",
1927 cancelExpected: false,
1928 },
1929 {
1930 name: "context.WithTimeout exceeded",
1931 ctx: func() (context.Context, context.CancelFunc) {
1932 return context.WithTimeout(context.Background(), 1*time.Second)
1933 },
1934 sql: "SELECT pg_sleep(10);",
1935 cancelExpected: true,
1936 },
1937 {
1938 name: "context.WithTimeout",
1939 ctx: func() (context.Context, context.CancelFunc) {
1940 return context.WithTimeout(context.Background(), time.Minute)
1941 },
1942 sql: "SELECT pg_sleep(1);",
1943 cancelExpected: false,
1944 },
1945 }
1946 for _, tt := range tests {
1947 t.Run(tt.name, func(t *testing.T) {
1948 ctx, cancel := tt.ctx()
1949 if cancel != nil {
1950 defer cancel()
1951 }
1952 stmt, err := db.PrepareContext(ctx, tt.sql)
1953 if err != nil {
1954 t.Fatal(err)
1955 }
1956 _, err = stmt.ExecContext(ctx)
1957 pgErr := (*Error)(nil)
1958 switch {
1959 case (err != nil) != tt.cancelExpected:
1960 t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected)
1961 case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode):
1962 t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected)
1963 }
1964 })
1965 }
1966 }
1967
View as plain text