1 package embeddedpostgres
2
3 import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "net"
8 "os"
9 "os/user"
10 "path"
11 "path/filepath"
12 "strings"
13 "sync"
14 "testing"
15 "time"
16
17 "github.com/stretchr/testify/assert"
18 "github.com/stretchr/testify/require"
19 )
20
21 func Test_DefaultConfig(t *testing.T) {
22 defer verifyLeak(t)
23
24 database := NewDatabase()
25 if err := database.Start(); err != nil {
26 shutdownDBAndFail(t, err, database)
27 }
28
29 db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
30 if err != nil {
31 shutdownDBAndFail(t, err, database)
32 }
33
34 if err = db.Ping(); err != nil {
35 shutdownDBAndFail(t, err, database)
36 }
37
38 if err := db.Close(); err != nil {
39 shutdownDBAndFail(t, err, database)
40 }
41
42 if err := database.Stop(); err != nil {
43 shutdownDBAndFail(t, err, database)
44 }
45 }
46
47 func Test_ErrorWhenPortAlreadyTaken(t *testing.T) {
48 listener, err := net.Listen("tcp", "localhost:9887")
49 if err != nil {
50 panic(err)
51 }
52
53 defer func() {
54 if err := listener.Close(); err != nil {
55 panic(err)
56 }
57 }()
58
59 database := NewDatabase(DefaultConfig().
60 Port(9887))
61
62 err = database.Start()
63
64 assert.EqualError(t, err, "process already listening on port 9887")
65 }
66
67 func Test_ErrorWhenRemoteFetchError(t *testing.T) {
68 database := NewDatabase()
69 database.cacheLocator = func() (string, bool) {
70 return "", false
71 }
72 database.remoteFetchStrategy = func() error {
73 return errors.New("did not work")
74 }
75
76 err := database.Start()
77
78 assert.EqualError(t, err, "did not work")
79 }
80
81 func Test_ErrorWhenUnableToUnArchiveFile_WrongFormat(t *testing.T) {
82 jarFile, cleanUp := createTempZipArchive()
83 defer cleanUp()
84
85 database := NewDatabase(DefaultConfig().
86 Username("gin").
87 Password("wine").
88 Database("beer").
89 StartTimeout(10 * time.Second))
90
91 database.cacheLocator = func() (string, bool) {
92 return jarFile, true
93 }
94
95 err := database.Start()
96
97 if err == nil {
98 if err := database.Stop(); err != nil {
99 panic(err)
100 }
101 }
102
103 assert.EqualError(t, err, fmt.Sprintf(`unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, xz: file format not recognized`, jarFile, filepath.Join(filepath.Dir(jarFile), "extracted")))
104 }
105
106 func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
107 jarFile, cleanUp := createTempXzArchive()
108 defer cleanUp()
109
110 extractPath, err := os.MkdirTemp(filepath.Dir(jarFile), "extract")
111 if err != nil {
112 panic(err)
113 }
114
115 database := NewDatabase(DefaultConfig().
116 Username("gin").
117 Password("wine").
118 Database("beer").
119 RuntimePath(extractPath).
120 StartTimeout(10 * time.Second))
121
122 database.cacheLocator = func() (string, bool) {
123 return jarFile, true
124 }
125
126 database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger *os.File) error {
127 return errors.New("ah it did not work")
128 }
129
130 err = database.Start()
131
132 if err == nil {
133 if err := database.Stop(); err != nil {
134 panic(err)
135 }
136 }
137
138 assert.EqualError(t, err, "ah it did not work")
139 }
140
141 func Test_ErrorWhenUnableToCreateDatabase(t *testing.T) {
142 jarFile, cleanUp := createTempXzArchive()
143
144 defer cleanUp()
145
146 extractPath, err := os.MkdirTemp(filepath.Dir(jarFile), "extract")
147
148 if err != nil {
149 panic(err)
150 }
151
152 database := NewDatabase(DefaultConfig().
153 Username("gin").
154 Password("wine").
155 Database("beer").
156 RuntimePath(extractPath).
157 StartTimeout(10 * time.Second))
158
159 database.createDatabase = func(port uint32, username, password, database string) error {
160 return errors.New("ah noes")
161 }
162
163 err = database.Start()
164
165 if err == nil {
166 if err := database.Stop(); err != nil {
167 panic(err)
168 }
169 }
170
171 assert.EqualError(t, err, "ah noes")
172 }
173
174 func Test_TimesOutWhenCannotStart(t *testing.T) {
175 database := NewDatabase(DefaultConfig().
176 Database("something-fancy").
177 StartTimeout(500 * time.Millisecond))
178
179 database.createDatabase = func(port uint32, username, password, database string) error {
180 return nil
181 }
182
183 err := database.Start()
184
185 assert.EqualError(t, err, "timed out waiting for database to become available")
186 }
187
188 func Test_ErrorWhenStopCalledBeforeStart(t *testing.T) {
189 database := NewDatabase()
190
191 err := database.Stop()
192
193 assert.EqualError(t, err, "server has not been started")
194 }
195
196 func Test_ErrorWhenStartCalledWhenAlreadyStarted(t *testing.T) {
197 database := NewDatabase()
198
199 defer func() {
200 if err := database.Stop(); err != nil {
201 t.Fatal(err)
202 }
203 }()
204
205 err := database.Start()
206 assert.NoError(t, err)
207
208 err = database.Start()
209 assert.EqualError(t, err, "server is already started")
210 }
211
212 func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
213 jarFile, cleanUp := createTempXzArchive()
214
215 defer cleanUp()
216
217 extractPath, err := os.MkdirTemp(filepath.Dir(jarFile), "extract")
218 if err != nil {
219 panic(err)
220 }
221
222 database := NewDatabase(DefaultConfig().
223 RuntimePath(extractPath))
224
225 database.cacheLocator = func() (string, bool) {
226 return jarFile, true
227 }
228
229 database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger *os.File) error {
230 _, _ = logger.Write([]byte("ah it did not work"))
231 return nil
232 }
233
234 err = database.Start()
235
236 assert.EqualError(t, err, fmt.Sprintf("could not start postgres using %s/bin/pg_ctl start -w -D %s/data -o -p 5432:\nah it did not work", extractPath, extractPath))
237 }
238
239 func Test_CustomConfig(t *testing.T) {
240 tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
241 if err != nil {
242 panic(err)
243 }
244
245 defer func() {
246 if err := os.RemoveAll(tempDir); err != nil {
247 panic(err)
248 }
249 }()
250
251 database := NewDatabase(DefaultConfig().
252 Username("gin").
253 Password("wine").
254 Database("beer").
255 Version(V15).
256 RuntimePath(tempDir).
257 Port(9876).
258 StartTimeout(10 * time.Second).
259 Locale("C").
260 Logger(nil))
261
262 if err := database.Start(); err != nil {
263 shutdownDBAndFail(t, err, database)
264 }
265
266 db, err := sql.Open("postgres", "host=localhost port=9876 user=gin password=wine dbname=beer sslmode=disable")
267 if err != nil {
268 shutdownDBAndFail(t, err, database)
269 }
270
271 if err = db.Ping(); err != nil {
272 shutdownDBAndFail(t, err, database)
273 }
274
275 if err := db.Close(); err != nil {
276 shutdownDBAndFail(t, err, database)
277 }
278
279 if err := database.Stop(); err != nil {
280 shutdownDBAndFail(t, err, database)
281 }
282 }
283
284 func Test_CustomLog(t *testing.T) {
285 tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
286 if err != nil {
287 panic(err)
288 }
289
290 defer func() {
291 if err := os.RemoveAll(tempDir); err != nil {
292 panic(err)
293 }
294 }()
295
296 logger := customLogger{}
297
298 database := NewDatabase(DefaultConfig().
299 Logger(&logger))
300
301 if err := database.Start(); err != nil {
302 shutdownDBAndFail(t, err, database)
303 }
304
305 db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
306 if err != nil {
307 shutdownDBAndFail(t, err, database)
308 }
309
310 if err = db.Ping(); err != nil {
311 shutdownDBAndFail(t, err, database)
312 }
313
314 if err := db.Close(); err != nil {
315 shutdownDBAndFail(t, err, database)
316 }
317
318 if err := database.Stop(); err != nil {
319 shutdownDBAndFail(t, err, database)
320 }
321
322 current, err := user.Current()
323
324 lines := strings.Split(string(logger.logLines), "\n")
325
326 assert.NoError(t, err)
327 assert.Contains(t, lines, fmt.Sprintf("The files belonging to this database system will be owned by user \"%s\".", current.Username))
328 assert.Contains(t, lines, "syncing data to disk ... ok")
329 assert.Contains(t, lines, "server stopped")
330 assert.Less(t, len(lines), 55)
331 assert.Greater(t, len(lines), 40)
332 }
333
334 func Test_CustomLocaleConfig(t *testing.T) {
335
336 database := NewDatabase(DefaultConfig().Locale("C"))
337 if err := database.Start(); err != nil {
338 shutdownDBAndFail(t, err, database)
339 }
340
341 db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
342 if err != nil {
343 shutdownDBAndFail(t, err, database)
344 }
345
346 if err = db.Ping(); err != nil {
347 shutdownDBAndFail(t, err, database)
348 }
349
350 if err := db.Close(); err != nil {
351 shutdownDBAndFail(t, err, database)
352 }
353
354 if err := database.Stop(); err != nil {
355 shutdownDBAndFail(t, err, database)
356 }
357 }
358
359 func Test_ConcurrentStart(t *testing.T) {
360 var wg sync.WaitGroup
361
362 database := NewDatabase()
363 cacheLocation, _ := database.cacheLocator()
364 err := os.RemoveAll(cacheLocation)
365 require.NoError(t, err)
366
367 port := 5432
368 for i := 1; i <= 3; i++ {
369 port = port + 1
370 wg.Add(1)
371
372 go func(p int) {
373 defer wg.Done()
374 tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
375 if err != nil {
376 panic(err)
377 }
378
379 defer func() {
380 if err := os.RemoveAll(tempDir); err != nil {
381 panic(err)
382 }
383 }()
384
385 database := NewDatabase(DefaultConfig().
386 RuntimePath(tempDir).
387 Port(uint32(p)))
388
389 if err := database.Start(); err != nil {
390 shutdownDBAndFail(t, err, database)
391 }
392
393 db, err := sql.Open(
394 "postgres",
395 fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", p),
396 )
397 if err != nil {
398 shutdownDBAndFail(t, err, database)
399 }
400
401 if err = db.Ping(); err != nil {
402 shutdownDBAndFail(t, err, database)
403 }
404
405 if err := db.Close(); err != nil {
406 shutdownDBAndFail(t, err, database)
407 }
408
409 if err := database.Stop(); err != nil {
410 shutdownDBAndFail(t, err, database)
411 }
412
413 }(port)
414 }
415
416 wg.Wait()
417 }
418
419 func Test_CustomStartParameters(t *testing.T) {
420 database := NewDatabase(DefaultConfig().StartParameters(map[string]string{
421 "max_connections": "101",
422 "shared_buffers": "16 MB",
423 }))
424 if err := database.Start(); err != nil {
425 shutdownDBAndFail(t, err, database)
426 }
427
428 db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
429 if err != nil {
430 shutdownDBAndFail(t, err, database)
431 }
432
433 if err := db.Ping(); err != nil {
434 shutdownDBAndFail(t, err, database)
435 }
436
437 row := db.QueryRow("SHOW max_connections")
438 var res string
439 if err := row.Scan(&res); err != nil {
440 shutdownDBAndFail(t, err, database)
441 }
442 assert.Equal(t, "101", res)
443
444 if err := db.Close(); err != nil {
445 shutdownDBAndFail(t, err, database)
446 }
447
448 if err := database.Stop(); err != nil {
449 shutdownDBAndFail(t, err, database)
450 }
451 }
452
453 func Test_CanStartAndStopTwice(t *testing.T) {
454 database := NewDatabase()
455
456 if err := database.Start(); err != nil {
457 shutdownDBAndFail(t, err, database)
458 }
459
460 db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
461 if err != nil {
462 shutdownDBAndFail(t, err, database)
463 }
464
465 if err = db.Ping(); err != nil {
466 shutdownDBAndFail(t, err, database)
467 }
468
469 if err := db.Close(); err != nil {
470 shutdownDBAndFail(t, err, database)
471 }
472
473 if err := database.Stop(); err != nil {
474 shutdownDBAndFail(t, err, database)
475 }
476
477 if err := database.Start(); err != nil {
478 shutdownDBAndFail(t, err, database)
479 }
480
481 db, err = sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
482 if err != nil {
483 shutdownDBAndFail(t, err, database)
484 }
485
486 if err = db.Ping(); err != nil {
487 shutdownDBAndFail(t, err, database)
488 }
489
490 if err := db.Close(); err != nil {
491 shutdownDBAndFail(t, err, database)
492 }
493
494 if err := database.Stop(); err != nil {
495 shutdownDBAndFail(t, err, database)
496 }
497 }
498
499 func Test_ReuseData(t *testing.T) {
500 tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
501 if err != nil {
502 panic(err)
503 }
504
505 defer func() {
506 if err := os.RemoveAll(tempDir); err != nil {
507 panic(err)
508 }
509 }()
510
511 database := NewDatabase(DefaultConfig().DataPath(tempDir))
512
513 if err := database.Start(); err != nil {
514 shutdownDBAndFail(t, err, database)
515 }
516
517 db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
518 if err != nil {
519 shutdownDBAndFail(t, err, database)
520 }
521
522 if _, err = db.Exec("CREATE TABLE test(id serial, value text, PRIMARY KEY(id))"); err != nil {
523 shutdownDBAndFail(t, err, database)
524 }
525
526 if _, err = db.Exec("INSERT INTO test (value) VALUES ('foobar')"); err != nil {
527 shutdownDBAndFail(t, err, database)
528 }
529
530 if err := db.Close(); err != nil {
531 shutdownDBAndFail(t, err, database)
532 }
533
534 if err := database.Stop(); err != nil {
535 shutdownDBAndFail(t, err, database)
536 }
537
538 database = NewDatabase(DefaultConfig().DataPath(tempDir))
539
540 if err := database.Start(); err != nil {
541 shutdownDBAndFail(t, err, database)
542 }
543
544 db, err = sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
545 if err != nil {
546 shutdownDBAndFail(t, err, database)
547 }
548
549 if rows, err := db.Query("SELECT * FROM test"); err != nil {
550 shutdownDBAndFail(t, err, database)
551 } else {
552 if !rows.Next() {
553 shutdownDBAndFail(t, errors.New("no row from db"), database)
554 }
555
556 var (
557 id int64
558 value string
559 )
560 if err := rows.Scan(&id, &value); err != nil {
561 shutdownDBAndFail(t, err, database)
562 }
563 if value != "foobar" {
564 shutdownDBAndFail(t, errors.New("wrong value from db"), database)
565 }
566 }
567
568 if err := db.Close(); err != nil {
569 shutdownDBAndFail(t, err, database)
570 }
571
572 if err := database.Stop(); err != nil {
573 shutdownDBAndFail(t, err, database)
574 }
575 }
576
577 func Test_CustomBinariesRepo(t *testing.T) {
578 tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
579 if err != nil {
580 panic(err)
581 }
582
583 defer func() {
584 if err := os.RemoveAll(tempDir); err != nil {
585 panic(err)
586 }
587 }()
588
589 database := NewDatabase(DefaultConfig().
590 Username("gin").
591 Password("wine").
592 Database("beer").
593 Version(V15).
594 RuntimePath(tempDir).
595 BinaryRepositoryURL("https://repo.maven.apache.org/maven2").
596 Port(9876).
597 StartTimeout(10 * time.Second).
598 Locale("C").
599 Logger(nil))
600
601 if err := database.Start(); err != nil {
602 shutdownDBAndFail(t, err, database)
603 }
604
605 db, err := sql.Open("postgres", "host=localhost port=9876 user=gin password=wine dbname=beer sslmode=disable")
606 if err != nil {
607 shutdownDBAndFail(t, err, database)
608 }
609
610 if err = db.Ping(); err != nil {
611 shutdownDBAndFail(t, err, database)
612 }
613
614 if err := db.Close(); err != nil {
615 shutdownDBAndFail(t, err, database)
616 }
617
618 if err := database.Stop(); err != nil {
619 shutdownDBAndFail(t, err, database)
620 }
621 }
622
623 func Test_CachePath(t *testing.T) {
624 cacheTempDir, err := os.MkdirTemp("", "prepare_database_test_cache")
625 if err != nil {
626 panic(err)
627 }
628
629 defer func() {
630 if err := os.RemoveAll(cacheTempDir); err != nil {
631 panic(err)
632 }
633 }()
634
635 database := NewDatabase(DefaultConfig().
636 CachePath(cacheTempDir))
637
638 if err := database.Start(); err != nil {
639 shutdownDBAndFail(t, err, database)
640 }
641
642 if err := database.Stop(); err != nil {
643 shutdownDBAndFail(t, err, database)
644 }
645 }
646
647 func Test_CustomBinariesLocation(t *testing.T) {
648 tempDir, err := os.MkdirTemp("", "prepare_database_test")
649 if err != nil {
650 panic(err)
651 }
652
653 defer func() {
654 if err := os.RemoveAll(tempDir); err != nil {
655 panic(err)
656 }
657 }()
658
659 database := NewDatabase(DefaultConfig().
660 BinariesPath(tempDir))
661
662 if err := database.Start(); err != nil {
663 shutdownDBAndFail(t, err, database)
664 }
665
666 if err := database.Stop(); err != nil {
667 shutdownDBAndFail(t, err, database)
668 }
669
670
671 cacheLocation, _ := database.cacheLocator()
672 if err := os.RemoveAll(cacheLocation); err != nil {
673 panic(err)
674 }
675
676 if err := database.Start(); err != nil {
677 shutdownDBAndFail(t, err, database)
678 }
679
680 if err := database.Stop(); err != nil {
681 shutdownDBAndFail(t, err, database)
682 }
683 }
684
685 func Test_PrefetchedBinaries(t *testing.T) {
686 binTempDir, err := os.MkdirTemp("", "prepare_database_test_bin")
687 if err != nil {
688 panic(err)
689 }
690
691 runtimeTempDir, err := os.MkdirTemp("", "prepare_database_test_runtime")
692 if err != nil {
693 panic(err)
694 }
695
696 defer func() {
697 if err := os.RemoveAll(binTempDir); err != nil {
698 panic(err)
699 }
700
701 if err := os.RemoveAll(runtimeTempDir); err != nil {
702 panic(err)
703 }
704 }()
705
706 database := NewDatabase(DefaultConfig().
707 BinariesPath(binTempDir).
708 RuntimePath(runtimeTempDir))
709
710
711 if err := database.remoteFetchStrategy(); err != nil {
712 panic(err)
713 }
714
715 cacheLocation, _ := database.cacheLocator()
716 if err := decompressTarXz(defaultTarReader, cacheLocation, binTempDir); err != nil {
717 panic(err)
718 }
719
720
721 database.cacheLocator = func() (string, bool) {
722 return "", false
723 }
724 database.remoteFetchStrategy = func() error {
725 return errors.New("did not work")
726 }
727
728 if err := database.Start(); err != nil {
729 shutdownDBAndFail(t, err, database)
730 }
731
732 if err := database.Stop(); err != nil {
733 shutdownDBAndFail(t, err, database)
734 }
735 }
736
737 func Test_RunningInParallel(t *testing.T) {
738 tempPath, err := os.MkdirTemp("", "parallel_tests_path")
739 if err != nil {
740 panic(err)
741 }
742
743 waitGroup := sync.WaitGroup{}
744 waitGroup.Add(2)
745
746 runTestWithPortAndPath := func(port uint32, path string) {
747 defer waitGroup.Done()
748
749 database := NewDatabase(DefaultConfig().Port(port).RuntimePath(path))
750 if err := database.Start(); err != nil {
751 shutdownDBAndFail(t, err, database)
752 }
753
754 db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port))
755 if err != nil {
756 shutdownDBAndFail(t, err, database)
757 }
758
759 if err = db.Ping(); err != nil {
760 shutdownDBAndFail(t, err, database)
761 }
762
763 if err := db.Close(); err != nil {
764 shutdownDBAndFail(t, err, database)
765 }
766
767 if err := database.Stop(); err != nil {
768 shutdownDBAndFail(t, err, database)
769 }
770 }
771
772 go runTestWithPortAndPath(8765, path.Join(tempPath, "1"))
773 go runTestWithPortAndPath(8766, path.Join(tempPath, "2"))
774
775 waitGroup.Wait()
776 }
777
View as plain text