1 package postgres
2
3
4
5 import (
6 "context"
7 "database/sql"
8 sqldriver "database/sql/driver"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "strconv"
14 "strings"
15 "sync"
16 "testing"
17
18 "github.com/golang-migrate/migrate/v4"
19
20 "github.com/dhui/dktest"
21
22 "github.com/golang-migrate/migrate/v4/database"
23 dt "github.com/golang-migrate/migrate/v4/database/testing"
24 "github.com/golang-migrate/migrate/v4/dktesting"
25 _ "github.com/golang-migrate/migrate/v4/source/file"
26 )
27
// pgPassword is the superuser password given to every Postgres test container
// through POSTGRES_PASSWORD; the test connection strings reuse it.
const pgPassword = "postgres"
31
32 var (
33 opts = dktest.Options{
34 Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
35 PortRequired: true, ReadyFunc: isReady}
36
37 specs = []dktesting.ContainerSpec{
38 {ImageName: "postgres:9.5", Options: opts},
39 {ImageName: "postgres:9.6", Options: opts},
40 {ImageName: "postgres:10", Options: opts},
41 {ImageName: "postgres:11", Options: opts},
42 {ImageName: "postgres:12", Options: opts},
43 }
44 )
45
46 func pgConnectionString(host, port string, options ...string) string {
47 options = append(options, "sslmode=disable")
48 return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
49 }
50
51 func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
52 ip, port, err := c.FirstPort()
53 if err != nil {
54 return false
55 }
56
57 db, err := sql.Open("postgres", pgConnectionString(ip, port))
58 if err != nil {
59 return false
60 }
61 defer func() {
62 if err := db.Close(); err != nil {
63 log.Println("close error:", err)
64 }
65 }()
66 if err = db.PingContext(ctx); err != nil {
67 switch err {
68 case sqldriver.ErrBadConn, io.EOF:
69 return false
70 default:
71 log.Println(err)
72 }
73 return false
74 }
75
76 return true
77 }
78
79 func mustRun(t *testing.T, d database.Driver, statements []string) {
80 for _, statement := range statements {
81 if err := d.Run(strings.NewReader(statement)); err != nil {
82 t.Fatal(err)
83 }
84 }
85 }
86
87 func Test(t *testing.T) {
88 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
89 ip, port, err := c.FirstPort()
90 if err != nil {
91 t.Fatal(err)
92 }
93
94 addr := pgConnectionString(ip, port)
95 p := &Postgres{}
96 d, err := p.Open(addr)
97 if err != nil {
98 t.Fatal(err)
99 }
100 defer func() {
101 if err := d.Close(); err != nil {
102 t.Error(err)
103 }
104 }()
105 dt.Test(t, d, []byte("SELECT 1"))
106 })
107 }
108
109 func TestMigrate(t *testing.T) {
110 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
111 ip, port, err := c.FirstPort()
112 if err != nil {
113 t.Fatal(err)
114 }
115
116 addr := pgConnectionString(ip, port)
117 p := &Postgres{}
118 d, err := p.Open(addr)
119 if err != nil {
120 t.Fatal(err)
121 }
122 defer func() {
123 if err := d.Close(); err != nil {
124 t.Error(err)
125 }
126 }()
127 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d)
128 if err != nil {
129 t.Fatal(err)
130 }
131 dt.TestMigrate(t, m)
132 })
133 }
134
135 func TestMultipleStatements(t *testing.T) {
136 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
137 ip, port, err := c.FirstPort()
138 if err != nil {
139 t.Fatal(err)
140 }
141
142 addr := pgConnectionString(ip, port)
143 p := &Postgres{}
144 d, err := p.Open(addr)
145 if err != nil {
146 t.Fatal(err)
147 }
148 defer func() {
149 if err := d.Close(); err != nil {
150 t.Error(err)
151 }
152 }()
153 if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
154 t.Fatalf("expected err to be nil, got %v", err)
155 }
156
157
158 var exists bool
159 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
160 t.Fatal(err)
161 }
162 if !exists {
163 t.Fatalf("expected table bar to exist")
164 }
165 })
166 }
167
168 func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
169 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
170 ip, port, err := c.FirstPort()
171 if err != nil {
172 t.Fatal(err)
173 }
174
175 addr := pgConnectionString(ip, port, "x-multi-statement=true")
176 p := &Postgres{}
177 d, err := p.Open(addr)
178 if err != nil {
179 t.Fatal(err)
180 }
181 defer func() {
182 if err := d.Close(); err != nil {
183 t.Error(err)
184 }
185 }()
186 if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
187 t.Fatalf("expected err to be nil, got %v", err)
188 }
189
190
191 var exists bool
192 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
193 t.Fatal(err)
194 }
195 if !exists {
196 t.Fatalf("expected table bar to exist")
197 }
198 })
199 }
200
201 func TestErrorParsing(t *testing.T) {
202 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
203 ip, port, err := c.FirstPort()
204 if err != nil {
205 t.Fatal(err)
206 }
207
208 addr := pgConnectionString(ip, port)
209 p := &Postgres{}
210 d, err := p.Open(addr)
211 if err != nil {
212 t.Fatal(err)
213 }
214 defer func() {
215 if err := d.Close(); err != nil {
216 t.Error(err)
217 }
218 }()
219
220 wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
221 `(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")`
222 if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
223 t.Fatal("expected err but got nil")
224 } else if err.Error() != wantErr {
225 t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
226 }
227 })
228 }
229
230 func TestFilterCustomQuery(t *testing.T) {
231 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
232 ip, port, err := c.FirstPort()
233 if err != nil {
234 t.Fatal(err)
235 }
236
237 addr := fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-custom=foobar",
238 pgPassword, ip, port)
239 p := &Postgres{}
240 d, err := p.Open(addr)
241 if err != nil {
242 t.Fatal(err)
243 }
244 defer func() {
245 if err := d.Close(); err != nil {
246 t.Error(err)
247 }
248 }()
249 })
250 }
251
252 func TestWithSchema(t *testing.T) {
253 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
254 ip, port, err := c.FirstPort()
255 if err != nil {
256 t.Fatal(err)
257 }
258
259 addr := pgConnectionString(ip, port)
260 p := &Postgres{}
261 d, err := p.Open(addr)
262 if err != nil {
263 t.Fatal(err)
264 }
265 defer func() {
266 if err := d.Close(); err != nil {
267 t.Fatal(err)
268 }
269 }()
270
271
272 if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
273 t.Fatal(err)
274 }
275 if err := d.SetVersion(1, false); err != nil {
276 t.Fatal(err)
277 }
278
279
280 d2, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
281 pgPassword, ip, port))
282 if err != nil {
283 t.Fatal(err)
284 }
285 defer func() {
286 if err := d2.Close(); err != nil {
287 t.Fatal(err)
288 }
289 }()
290
291 version, _, err := d2.Version()
292 if err != nil {
293 t.Fatal(err)
294 }
295 if version != database.NilVersion {
296 t.Fatal("expected NilVersion")
297 }
298
299
300 if err := d2.SetVersion(2, false); err != nil {
301 t.Fatal(err)
302 }
303 version, _, err = d2.Version()
304 if err != nil {
305 t.Fatal(err)
306 }
307 if version != 2 {
308 t.Fatal("expected version 2")
309 }
310
311
312 version, _, err = d.Version()
313 if err != nil {
314 t.Fatal(err)
315 }
316 if version != 1 {
317 t.Fatal("expected version 2")
318 }
319 })
320 }
321
322 func TestMigrationTableOption(t *testing.T) {
323 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
324 ip, port, err := c.FirstPort()
325 if err != nil {
326 t.Fatal(err)
327 }
328
329 addr := pgConnectionString(ip, port)
330 p := &Postgres{}
331 d, _ := p.Open(addr)
332 defer func() {
333 if err := d.Close(); err != nil {
334 t.Fatal(err)
335 }
336 }()
337
338
339 if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
340 t.Fatal(err)
341 }
342
343
344 wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
345 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
346 pgPassword, ip, port))
347 if (err != nil) && (err.Error() != wantErr) {
348 t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
349 }
350
351
352 wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
353 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
354 pgPassword, ip, port))
355 if (err != nil) && (err.Error() != wantErr) {
356 t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
357 }
358
359
360 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
361 pgPassword, ip, port))
362 if err != nil {
363 t.Fatal(err)
364 }
365
366
367 var exists bool
368 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
369 t.Fatal(err)
370 }
371 if !exists {
372 t.Fatalf("expected table migrate.schema_migrations to exist")
373 }
374
375 d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
376 pgPassword, ip, port))
377 if err != nil {
378 t.Fatal(err)
379 }
380 if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
381 t.Fatal(err)
382 }
383 if !exists {
384 t.Fatalf("expected table 'migrate.schema_migrations' to exist")
385 }
386
387 })
388 }
389
390 func TestFailToCreateTableWithoutPermissions(t *testing.T) {
391 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
392 ip, port, err := c.FirstPort()
393 if err != nil {
394 t.Fatal(err)
395 }
396
397 addr := pgConnectionString(ip, port)
398
399
400 p := &Postgres{}
401
402 d, err := p.Open(addr)
403
404 if err != nil {
405 t.Fatal(err)
406 }
407
408 defer func() {
409 if err := d.Close(); err != nil {
410 t.Error(err)
411 }
412 }()
413
414
415
416 mustRun(t, d, []string{
417 "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
418 "CREATE SCHEMA barfoo AUTHORIZATION postgres",
419 "GRANT USAGE ON SCHEMA barfoo TO not_owner",
420 "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
421 "REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
422 })
423
424
425 d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
426 pgPassword, ip, port))
427
428 defer func() {
429 if d2 == nil {
430 return
431 }
432 if err := d2.Close(); err != nil {
433 t.Fatal(err)
434 }
435 }()
436
437 var e *database.Error
438 if !errors.As(err, &e) || err == nil {
439 t.Fatal("Unexpected error, want permission denied error. Got: ", err)
440 }
441
442 if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
443 t.Fatal(e)
444 }
445
446
447 d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
448 pgPassword, ip, port))
449
450 if !errors.As(err, &e) || err == nil {
451 t.Fatal("Unexpected error, want permission denied error. Got: ", err)
452 }
453
454 if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
455 t.Fatal(e)
456 }
457 })
458 }
459
460 func TestCheckBeforeCreateTable(t *testing.T) {
461 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
462 ip, port, err := c.FirstPort()
463 if err != nil {
464 t.Fatal(err)
465 }
466
467 addr := pgConnectionString(ip, port)
468
469
470 p := &Postgres{}
471
472 d, err := p.Open(addr)
473
474 if err != nil {
475 t.Fatal(err)
476 }
477
478 defer func() {
479 if err := d.Close(); err != nil {
480 t.Error(err)
481 }
482 }()
483
484
485
486 mustRun(t, d, []string{
487 "CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
488 "CREATE SCHEMA barfoo AUTHORIZATION postgres",
489 "GRANT USAGE ON SCHEMA barfoo TO not_owner",
490 "GRANT CREATE ON SCHEMA barfoo TO not_owner",
491 })
492
493
494 d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
495 pgPassword, ip, port))
496
497 if err != nil {
498 t.Fatal(err)
499 }
500
501 if err := d2.Close(); err != nil {
502 t.Fatal(err)
503 }
504
505
506 mustRun(t, d, []string{
507 "REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
508 "REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
509 })
510
511
512 d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
513 pgPassword, ip, port))
514
515 if err != nil {
516 t.Fatal(err)
517 }
518
519 version, _, err := d3.Version()
520
521 if err != nil {
522 t.Fatal(err)
523 }
524
525 if version != database.NilVersion {
526 t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
527 }
528
529 defer func() {
530 if err := d3.Close(); err != nil {
531 t.Fatal(err)
532 }
533 }()
534 })
535 }
536
537 func TestParallelSchema(t *testing.T) {
538 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
539 ip, port, err := c.FirstPort()
540 if err != nil {
541 t.Fatal(err)
542 }
543
544 addr := pgConnectionString(ip, port)
545 p := &Postgres{}
546 d, err := p.Open(addr)
547 if err != nil {
548 t.Fatal(err)
549 }
550 defer func() {
551 if err := d.Close(); err != nil {
552 t.Error(err)
553 }
554 }()
555
556
557 if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
558 t.Fatal(err)
559 }
560 if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
561 t.Fatal(err)
562 }
563
564
565 dfoo, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foo",
566 pgPassword, ip, port))
567 if err != nil {
568 t.Fatal(err)
569 }
570 defer func() {
571 if err := dfoo.Close(); err != nil {
572 t.Error(err)
573 }
574 }()
575
576 dbar, err := p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=bar",
577 pgPassword, ip, port))
578 if err != nil {
579 t.Fatal(err)
580 }
581 defer func() {
582 if err := dbar.Close(); err != nil {
583 t.Error(err)
584 }
585 }()
586
587 if err := dfoo.Lock(); err != nil {
588 t.Fatal(err)
589 }
590
591 if err := dbar.Lock(); err != nil {
592 t.Fatal(err)
593 }
594
595 if err := dbar.Unlock(); err != nil {
596 t.Fatal(err)
597 }
598
599 if err := dfoo.Unlock(); err != nil {
600 t.Fatal(err)
601 }
602 })
603 }
604
605 func TestPostgres_Lock(t *testing.T) {
606 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
607 ip, port, err := c.FirstPort()
608 if err != nil {
609 t.Fatal(err)
610 }
611
612 addr := pgConnectionString(ip, port)
613 p := &Postgres{}
614 d, err := p.Open(addr)
615 if err != nil {
616 t.Fatal(err)
617 }
618
619 dt.Test(t, d, []byte("SELECT 1"))
620
621 ps := d.(*Postgres)
622
623 err = ps.Lock()
624 if err != nil {
625 t.Fatal(err)
626 }
627
628 err = ps.Unlock()
629 if err != nil {
630 t.Fatal(err)
631 }
632
633 err = ps.Lock()
634 if err != nil {
635 t.Fatal(err)
636 }
637
638 err = ps.Unlock()
639 if err != nil {
640 t.Fatal(err)
641 }
642 })
643 }
644
645 func TestWithInstance_Concurrent(t *testing.T) {
646 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
647 ip, port, err := c.FirstPort()
648 if err != nil {
649 t.Fatal(err)
650 }
651
652
653 const concurrency = 30
654
655
656
657
658
659 db, err := sql.Open("postgres", pgConnectionString(ip, port))
660 if err != nil {
661 t.Fatal(err)
662 }
663 defer func() {
664 if err := db.Close(); err != nil {
665 t.Error(err)
666 }
667 }()
668
669 db.SetMaxIdleConns(concurrency)
670 db.SetMaxOpenConns(concurrency)
671
672 var wg sync.WaitGroup
673 defer wg.Wait()
674
675 wg.Add(concurrency)
676 for i := 0; i < concurrency; i++ {
677 go func(i int) {
678 defer wg.Done()
679 _, err := WithInstance(db, &Config{})
680 if err != nil {
681 t.Errorf("process %d error: %s", i, err)
682 }
683 }(i)
684 }
685 })
686 }
687 func Test_computeLineFromPos(t *testing.T) {
688 testcases := []struct {
689 pos int
690 wantLine uint
691 wantCol uint
692 input string
693 wantOk bool
694 }{
695 {
696 15, 2, 6, "SELECT *\nFROM foo", true,
697 },
698 {
699 16, 3, 6, "SELECT *\n\nFROM foo", true,
700 },
701 {
702 25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true,
703 },
704 {
705 27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true,
706 },
707 {
708 10, 2, 1, "SELECT *\nFROMM foo", true,
709 },
710 {
711 11, 3, 1, "SELECT *\n\nFROMM foo", true,
712 },
713 {
714 17, 2, 8, "SELECT *\nFROM foo", true,
715 },
716 {
717 18, 0, 0, "SELECT *\nFROM foo", false,
718 },
719 }
720 for i, tc := range testcases {
721 t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
722 run := func(crlf bool, nonASCII bool) {
723 var name string
724 if crlf {
725 name = "crlf"
726 } else {
727 name = "lf"
728 }
729 if nonASCII {
730 name += "-nonascii"
731 } else {
732 name += "-ascii"
733 }
734 t.Run(name, func(t *testing.T) {
735 input := tc.input
736 if crlf {
737 input = strings.Replace(input, "\n", "\r\n", -1)
738 }
739 if nonASCII {
740 input = strings.Replace(input, "FROM", "FRÖM", -1)
741 }
742 gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
743
744 if tc.wantOk {
745 t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
746 }
747
748 if gotOK != tc.wantOk {
749 t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
750 }
751 if gotLine != tc.wantLine {
752 t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
753 }
754 if gotCol != tc.wantCol {
755 t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
756 }
757 })
758 }
759 run(false, false)
760 run(true, false)
761 run(false, true)
762 run(true, true)
763 })
764 }
765 }
766
View as plain text