1 package embeddedpostgres
2
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"time"

	"github.com/lib/pq"
)
15
// Format strings used when folding a database Close failure into an error result.
const (
	// fmtAfterError reports a later error (%v) that occurred after an earlier
	// one; only the earlier error (%w) is wrapped.
	fmtAfterError = "%v happened after error: %w"

	// fmtCloseDBConn wraps the error returned when closing a database handle.
	fmtCloseDBConn = "unable to close database connection: %w"
)
20
type (
	// initDatabase is the signature of a function that initialises a data
	// directory; defaultInitDatabase in this file has this signature.
	initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error

	// createDatabase is the signature of a function that creates a named
	// database on a running server; defaultCreateDatabase in this file has
	// this signature.
	createDatabase func(port uint32, username, password, database string) error
)
23
24 func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error {
25 passwordFile, err := createPasswordFile(runtimePath, password)
26 if err != nil {
27 return err
28 }
29
30 args := []string{
31 "-A", "password",
32 "-U", username,
33 "-D", pgDataDir,
34 fmt.Sprintf("--pwfile=%s", passwordFile),
35 }
36
37 if locale != "" {
38 args = append(args, fmt.Sprintf("--locale=%s", locale))
39 }
40
41 postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb")
42 postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...)
43 postgresInitDBProcess.Stderr = logger
44 postgresInitDBProcess.Stdout = logger
45
46 if err = postgresInitDBProcess.Run(); err != nil {
47 logContent, readLogsErr := readLogsOrTimeout(logger)
48 if readLogsErr != nil {
49 logContent = []byte(string(logContent) + " - " + readLogsErr.Error())
50 }
51 return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, string(logContent))
52 }
53
54 if err = os.Remove(passwordFile); err != nil {
55 return fmt.Errorf("unable to remove password file '%v': %w", passwordFile, err)
56 }
57
58 return nil
59 }
60
// createPasswordFile writes password, with permissions 0600, to a file named
// "pwfile" inside runtimePath and returns that file's path.
//
// NOTE(review): the returned error does not include the underlying write
// error; only the target path is reported.
func createPasswordFile(runtimePath, password string) (string, error) {
	location := filepath.Join(runtimePath, "pwfile")

	writeErr := os.WriteFile(location, []byte(password), 0600)
	if writeErr == nil {
		return location, nil
	}

	return "", fmt.Errorf("unable to write password file to %s", location)
}
69
70 func defaultCreateDatabase(port uint32, username, password, database string) (err error) {
71 if database == "postgres" {
72 return nil
73 }
74
75 conn, err := openDatabaseConnection(port, username, password, "postgres")
76 if err != nil {
77 return errorCustomDatabase(database, err)
78 }
79
80 db := sql.OpenDB(conn)
81 defer func() {
82 err = connectionClose(db, err)
83 }()
84
85 if _, err := db.Exec(fmt.Sprintf("CREATE DATABASE \"%s\"", database)); err != nil {
86 return errorCustomDatabase(database, err)
87 }
88
89 return nil
90 }
91
92
93 func connectionClose(db io.Closer, err error) error {
94 closeErr := db.Close()
95 if closeErr != nil {
96 closeErr = fmt.Errorf(fmtCloseDBConn, closeErr)
97
98 if err != nil {
99 err = fmt.Errorf(fmtAfterError, closeErr, err)
100 } else {
101 err = closeErr
102 }
103 }
104
105 return err
106 }
107
108 func healthCheckDatabaseOrTimeout(config Config) error {
109 healthCheckSignal := make(chan bool)
110
111 defer close(healthCheckSignal)
112
113 timeout, cancelFunc := context.WithTimeout(context.Background(), config.startTimeout)
114
115 defer cancelFunc()
116
117 go func() {
118 for timeout.Err() == nil {
119 if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil {
120 continue
121 }
122 healthCheckSignal <- true
123
124 break
125 }
126 }()
127
128 select {
129 case <-healthCheckSignal:
130 return nil
131 case <-timeout.Done():
132 return errors.New("timed out waiting for database to become available")
133 }
134 }
135
136 func healthCheckDatabase(port uint32, database, username, password string) (err error) {
137 conn, err := openDatabaseConnection(port, username, password, database)
138 if err != nil {
139 return err
140 }
141
142 db := sql.OpenDB(conn)
143 defer func() {
144 err = connectionClose(db, err)
145 }()
146
147 if _, err := db.Query("SELECT 1"); err != nil {
148 return err
149 }
150
151 return nil
152 }
153
154 func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) {
155 conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable",
156 port,
157 username,
158 password,
159 database))
160 if err != nil {
161 return nil, err
162 }
163
164 return conn, nil
165 }
166
// errorCustomDatabase wraps err with a message naming the database that could
// not be created. The cause is wrapped with %w, so callers can inspect it with
// errors.Is / errors.As. The message text is the same as formatting the cause
// with %s.
func errorCustomDatabase(database string, err error) error {
	return fmt.Errorf("unable to connect to create database with custom name %s with the following error: %w", database, err)
}
170
View as plain text