1 package pgx
2
3
4
5 import (
6 "context"
7 "database/sql"
8 sqldriver "database/sql/driver"
9 "errors"
10 "fmt"
11 "log"
12
13 "io"
14 "strconv"
15 "strings"
16 "sync"
17 "testing"
18
19 "github.com/golang-migrate/migrate/v4"
20
21 "github.com/dhui/dktest"
22
23 "github.com/golang-migrate/migrate/v4/database"
24 dt "github.com/golang-migrate/migrate/v4/database/testing"
25 "github.com/golang-migrate/migrate/v4/dktesting"
26 _ "github.com/golang-migrate/migrate/v4/source/file"
27 )
28
29 const (
30 pgPassword = "postgres"
31 )
32
33 var (
34 opts = dktest.Options{
35 Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
36 PortRequired: true, ReadyFunc: isReady}
37
38 specs = []dktesting.ContainerSpec{
39 {ImageName: "postgres:9.5", Options: opts},
40 {ImageName: "postgres:9.6", Options: opts},
41 {ImageName: "postgres:10", Options: opts},
42 {ImageName: "postgres:11", Options: opts},
43 {ImageName: "postgres:12", Options: opts},
44 }
45 )
46
47 func pgConnectionString(host, port string, options ...string) string {
48 options = append(options, "sslmode=disable")
49 return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
50 }
51
52 func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
53 ip, port, err := c.FirstPort()
54 if err != nil {
55 return false
56 }
57
58 db, err := sql.Open("pgx", pgConnectionString(ip, port))
59 if err != nil {
60 return false
61 }
62 defer func() {
63 if err := db.Close(); err != nil {
64 log.Println("close error:", err)
65 }
66 }()
67 if err = db.PingContext(ctx); err != nil {
68 switch err {
69 case sqldriver.ErrBadConn, io.EOF:
70 return false
71 default:
72 log.Println(err)
73 }
74 return false
75 }
76
77 return true
78 }
79
80 func mustRun(t *testing.T, d database.Driver, statements []string) {
81 for _, statement := range statements {
82 if err := d.Run(strings.NewReader(statement)); err != nil {
83 t.Fatal(err)
84 }
85 }
86 }
87
88 func Test(t *testing.T) {
89 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
90 ip, port, err := c.FirstPort()
91 if err != nil {
92 t.Fatal(err)
93 }
94
95 addr := pgConnectionString(ip, port)
96 p := &Postgres{}
97 d, err := p.Open(addr)
98 if err != nil {
99 t.Fatal(err)
100 }
101 defer func() {
102 if err := d.Close(); err != nil {
103 t.Error(err)
104 }
105 }()
106 dt.Test(t, d, []byte("SELECT 1"))
107 })
108 }
109
110 func TestMigrate(t *testing.T) {
111 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
112 ip, port, err := c.FirstPort()
113 if err != nil {
114 t.Fatal(err)
115 }
116
117 addr := pgConnectionString(ip, port)
118 p := &Postgres{}
119 d, err := p.Open(addr)
120 if err != nil {
121 t.Fatal(err)
122 }
123 defer func() {
124 if err := d.Close(); err != nil {
125 t.Error(err)
126 }
127 }()
128 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
129 if err != nil {
130 t.Fatal(err)
131 }
132 dt.TestMigrate(t, m)
133 })
134 }
135
136 func TestMultipleStatements(t *testing.T) {
137 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
138 ip, port, err := c.FirstPort()
139 if err != nil {
140 t.Fatal(err)
141 }
142
143 addr := pgConnectionString(ip, port)
144 p := &Postgres{}
145 d, err := p.Open(addr)
146 if err != nil {
147 t.Fatal(err)
148 }
149 defer func() {
150 if err := d.Close(); err != nil {
151 t.Error(err)
152 }
153 }()
154 if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
155 t.Fatalf("expected err to be nil, got %v", err)
156 }
157
158
159 var exists bool
160 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
161 t.Fatal(err)
162 }
163 if !exists {
164 t.Fatalf("expected table bar to exist")
165 }
166 })
167 }
168
169 func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
170 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
171 ip, port, err := c.FirstPort()
172 if err != nil {
173 t.Fatal(err)
174 }
175
176 addr := pgConnectionString(ip, port, "x-multi-statement=true")
177 p := &Postgres{}
178 d, err := p.Open(addr)
179 if err != nil {
180 t.Fatal(err)
181 }
182 defer func() {
183 if err := d.Close(); err != nil {
184 t.Error(err)
185 }
186 }()
187 if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
188 t.Fatalf("expected err to be nil, got %v", err)
189 }
190
191
192 var exists bool
193 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
194 t.Fatal(err)
195 }
196 if !exists {
197 t.Fatalf("expected table bar to exist")
198 }
199 })
200 }
201
202 func TestErrorParsing(t *testing.T) {
203 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
204 ip, port, err := c.FirstPort()
205 if err != nil {
206 t.Fatal(err)
207 }
208
209 addr := pgConnectionString(ip, port)
210 p := &Postgres{}
211 d, err := p.Open(addr)
212 if err != nil {
213 t.Fatal(err)
214 }
215 defer func() {
216 if err := d.Close(); err != nil {
217 t.Error(err)
218 }
219 }()
220
221 wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
222 `(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
223 if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
224 t.Fatal("expected err but got nil")
225 } else if err.Error() != wantErr {
226 t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
227 }
228 })
229 }
230
231 func TestFilterCustomQuery(t *testing.T) {
232 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
233 ip, port, err := c.FirstPort()
234 if err != nil {
235 t.Fatal(err)
236 }
237
238 addr := pgConnectionString(ip, port, "x-custom=foobar")
239 p := &Postgres{}
240 d, err := p.Open(addr)
241 if err != nil {
242 t.Fatal(err)
243 }
244 defer func() {
245 if err := d.Close(); err != nil {
246 t.Error(err)
247 }
248 }()
249 })
250 }
251
252 func TestWithSchema(t *testing.T) {
253 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
254 ip, port, err := c.FirstPort()
255 if err != nil {
256 t.Fatal(err)
257 }
258
259 addr := pgConnectionString(ip, port)
260 p := &Postgres{}
261 d, err := p.Open(addr)
262 if err != nil {
263 t.Fatal(err)
264 }
265 defer func() {
266 if err := d.Close(); err != nil {
267 t.Fatal(err)
268 }
269 }()
270
271
272 if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
273 t.Fatal(err)
274 }
275 if err := d.SetVersion(1, false); err != nil {
276 t.Fatal(err)
277 }
278
279
280 d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
281 if err != nil {
282 t.Fatal(err)
283 }
284 defer func() {
285 if err := d2.Close(); err != nil {
286 t.Fatal(err)
287 }
288 }()
289
290 version, _, err := d2.Version()
291 if err != nil {
292 t.Fatal(err)
293 }
294 if version != database.NilVersion {
295 t.Fatal("expected NilVersion")
296 }
297
298
299 if err := d2.SetVersion(2, false); err != nil {
300 t.Fatal(err)
301 }
302 version, _, err = d2.Version()
303 if err != nil {
304 t.Fatal(err)
305 }
306 if version != 2 {
307 t.Fatal("expected version 2")
308 }
309
310
311 version, _, err = d.Version()
312 if err != nil {
313 t.Fatal(err)
314 }
315 if version != 1 {
316 t.Fatal("expected version 2")
317 }
318 })
319 }
320
321 func TestMigrationTableOption(t *testing.T) {
322 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
323 ip, port, err := c.FirstPort()
324 if err != nil {
325 t.Fatal(err)
326 }
327
328 addr := pgConnectionString(ip, port)
329 p := &Postgres{}
330 d, _ := p.Open(addr)
331 defer func() {
332 if err := d.Close(); err != nil {
333 t.Fatal(err)
334 }
335 }()
336
337
338 if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
339 t.Fatal(err)
340 }
341
342
343 wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
344 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
345 pgPassword, ip, port))
346 if (err != nil) && (err.Error() != wantErr) {
347 t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
348 }
349
350
351 wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
352 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
353 pgPassword, ip, port))
354 if (err != nil) && (err.Error() != wantErr) {
355 t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
356 }
357
358
359 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
360 pgPassword, ip, port))
361 if err != nil {
362 t.Fatal(err)
363 }
364
365
366 var exists bool
367 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
368 t.Fatal(err)
369 }
370 if !exists {
371 t.Fatalf("expected table migrate.schema_migrations to exist")
372 }
373
374 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
375 pgPassword, ip, port))
376 if err != nil {
377 t.Fatal(err)
378 }
379 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
380 t.Fatal(err)
381 }
382 if !exists {
383 t.Fatalf("expected table 'migrate.schema_migrations' to exist")
384 }
385
386 })
387 }
388
389 func TestFailToCreateTableWithoutPermissions(t *testing.T) {
390 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
391 ip, port, err := c.FirstPort()
392 if err != nil {
393 t.Fatal(err)
394 }
395
396 addr := pgConnectionString(ip, port)
397
398
399 p := &Postgres{}
400
401 d, err := p.Open(addr)
402
403 if err != nil {
404 t.Fatal(err)
405 }
406
407 defer func() {
408 if err := d.Close(); err != nil {
409 t.Error(err)
410 }
411 }()
412
413
414
415 mustRun(t, d, []string{
416 "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
417 "CREATE SCHEMA barfoo AUTHORIZATION postgres",
418 "GRANT USAGE ON SCHEMA barfoo TO not_owner",
419 "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
420 "REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
421 })
422
423
424 d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
425 pgPassword, ip, port))
426
427 defer func() {
428 if d2 == nil {
429 return
430 }
431 if err := d2.Close(); err != nil {
432 t.Fatal(err)
433 }
434 }()
435
436 var e *database.Error
437 if !errors.As(err, &e) || err == nil {
438 t.Fatal("Unexpected error, want permission denied error. Got: ", err)
439 }
440
441 if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
442 t.Fatal(e)
443 }
444
445
446 d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
447 pgPassword, ip, port))
448
449 if !errors.As(err, &e) || err == nil {
450 t.Fatal("Unexpected error, want permission denied error. Got: ", err)
451 }
452
453 if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
454 t.Fatal(e)
455 }
456 })
457 }
458
459 func TestCheckBeforeCreateTable(t *testing.T) {
460 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
461 ip, port, err := c.FirstPort()
462 if err != nil {
463 t.Fatal(err)
464 }
465
466 addr := pgConnectionString(ip, port)
467
468
469 p := &Postgres{}
470
471 d, err := p.Open(addr)
472
473 if err != nil {
474 t.Fatal(err)
475 }
476
477 defer func() {
478 if err := d.Close(); err != nil {
479 t.Error(err)
480 }
481 }()
482
483
484
485 mustRun(t, d, []string{
486 "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
487 "CREATE SCHEMA barfoo AUTHORIZATION postgres",
488 "GRANT USAGE ON SCHEMA barfoo TO not_owner",
489 "GRANT CREATE ON SCHEMA barfoo TO not_owner",
490 })
491
492
493 d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
494 pgPassword, ip, port))
495
496 if err != nil {
497 t.Fatal(err)
498 }
499
500 if err := d2.Close(); err != nil {
501 t.Fatal(err)
502 }
503
504
505 mustRun(t, d, []string{
506 "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
507 "REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
508 })
509
510
511 d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
512 pgPassword, ip, port))
513
514 if err != nil {
515 t.Fatal(err)
516 }
517
518 version, _, err := d3.Version()
519
520 if err != nil {
521 t.Fatal(err)
522 }
523
524 if version != database.NilVersion {
525 t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
526 }
527
528 defer func() {
529 if err := d3.Close(); err != nil {
530 t.Fatal(err)
531 }
532 }()
533 })
534 }
535
536 func TestParallelSchema(t *testing.T) {
537 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
538 ip, port, err := c.FirstPort()
539 if err != nil {
540 t.Fatal(err)
541 }
542
543 addr := pgConnectionString(ip, port)
544 p := &Postgres{}
545 d, err := p.Open(addr)
546 if err != nil {
547 t.Fatal(err)
548 }
549 defer func() {
550 if err := d.Close(); err != nil {
551 t.Error(err)
552 }
553 }()
554
555
556 if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
557 t.Fatal(err)
558 }
559 if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
560 t.Fatal(err)
561 }
562
563
564 dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
565 if err != nil {
566 t.Fatal(err)
567 }
568 defer func() {
569 if err := dfoo.Close(); err != nil {
570 t.Error(err)
571 }
572 }()
573
574 dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
575 if err != nil {
576 t.Fatal(err)
577 }
578 defer func() {
579 if err := dbar.Close(); err != nil {
580 t.Error(err)
581 }
582 }()
583
584 if err := dfoo.Lock(); err != nil {
585 t.Fatal(err)
586 }
587
588 if err := dbar.Lock(); err != nil {
589 t.Fatal(err)
590 }
591
592 if err := dbar.Unlock(); err != nil {
593 t.Fatal(err)
594 }
595
596 if err := dfoo.Unlock(); err != nil {
597 t.Fatal(err)
598 }
599 })
600 }
601
602 func TestPostgres_Lock(t *testing.T) {
603 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
604 ip, port, err := c.FirstPort()
605 if err != nil {
606 t.Fatal(err)
607 }
608
609 addr := pgConnectionString(ip, port)
610 p := &Postgres{}
611 d, err := p.Open(addr)
612 if err != nil {
613 t.Fatal(err)
614 }
615
616 dt.Test(t, d, []byte("SELECT 1"))
617
618 ps := d.(*Postgres)
619
620 err = ps.Lock()
621 if err != nil {
622 t.Fatal(err)
623 }
624
625 err = ps.Unlock()
626 if err != nil {
627 t.Fatal(err)
628 }
629
630 err = ps.Lock()
631 if err != nil {
632 t.Fatal(err)
633 }
634
635 err = ps.Unlock()
636 if err != nil {
637 t.Fatal(err)
638 }
639 })
640 }
641
642 func TestWithInstance_Concurrent(t *testing.T) {
643 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
644 ip, port, err := c.FirstPort()
645 if err != nil {
646 t.Fatal(err)
647 }
648
649
650 const concurrency = 30
651
652
653
654
655
656 db, err := sql.Open("pgx", pgConnectionString(ip, port))
657 if err != nil {
658 t.Fatal(err)
659 }
660 defer func() {
661 if err := db.Close(); err != nil {
662 t.Error(err)
663 }
664 }()
665
666 db.SetMaxIdleConns(concurrency)
667 db.SetMaxOpenConns(concurrency)
668
669 var wg sync.WaitGroup
670 defer wg.Wait()
671
672 wg.Add(concurrency)
673 for i := 0; i < concurrency; i++ {
674 go func(i int) {
675 defer wg.Done()
676 _, err := WithInstance(db, &Config{})
677 if err != nil {
678 t.Errorf("process %d error: %s", i, err)
679 }
680 }(i)
681 }
682 })
683 }
684 func Test_computeLineFromPos(t *testing.T) {
685 testcases := []struct {
686 pos int
687 wantLine uint
688 wantCol uint
689 input string
690 wantOk bool
691 }{
692 {
693 15, 2, 6, "SELECT *\nFROM foo", true,
694 },
695 {
696 16, 3, 6, "SELECT *\n\nFROM foo", true,
697 },
698 {
699 25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true,
700 },
701 {
702 27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true,
703 },
704 {
705 10, 2, 1, "SELECT *\nFROMM foo", true,
706 },
707 {
708 11, 3, 1, "SELECT *\n\nFROMM foo", true,
709 },
710 {
711 17, 2, 8, "SELECT *\nFROM foo", true,
712 },
713 {
714 18, 0, 0, "SELECT *\nFROM foo", false,
715 },
716 }
717 for i, tc := range testcases {
718 t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
719 run := func(crlf bool, nonASCII bool) {
720 var name string
721 if crlf {
722 name = "crlf"
723 } else {
724 name = "lf"
725 }
726 if nonASCII {
727 name += "-nonascii"
728 } else {
729 name += "-ascii"
730 }
731 t.Run(name, func(t *testing.T) {
732 input := tc.input
733 if crlf {
734 input = strings.Replace(input, "\n", "\r\n", -1)
735 }
736 if nonASCII {
737 input = strings.Replace(input, "FROM", "FRÖM", -1)
738 }
739 gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
740
741 if tc.wantOk {
742 t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
743 }
744
745 if gotOK != tc.wantOk {
746 t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
747 }
748 if gotLine != tc.wantLine {
749 t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
750 }
751 if gotCol != tc.wantCol {
752 t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
753 }
754 })
755 }
756 run(false, false)
757 run(true, false)
758 run(false, true)
759 run(true, true)
760 })
761 }
762 }
763
View as plain text