1 package seededpostgres
2
3 import (
4 "database/sql"
5 "fmt"
6 "net"
7 "os"
8 "path"
9
10 edgesql "edge-infra.dev/pkg/edge/api/sql"
11 "edge-infra.dev/pkg/edge/api/sql/plugin"
12 "edge-infra.dev/pkg/lib/build/bazel"
13 "edge-infra.dev/pkg/lib/compression"
14 "edge-infra.dev/pkg/lib/gcp/cloudsql"
15 "edge-infra.dev/pkg/lib/logging"
16
17 "github.com/bazelbuild/rules_go/go/runfiles"
18 embeddedpostgres "github.com/fergusstrange/embedded-postgres"
19 "github.com/golang-migrate/migrate/v4/database/postgres"
20 )
21
// PostgresVersion is the Postgres release that New and NewWithUser pass to
// embedded-postgres when configuring each server. It is a package-level
// variable read at construction time, so it can be changed before calling
// those constructors.
//
// NOTE(review): under Bazel the binaries are extracted from the postgres.txz
// runfile; presumably this version must match the one in that archive —
// confirm before changing it.
var PostgresVersion = embeddedpostgres.V14
23
24
25 type SeededPostgres struct {
26 dbname string
27 username string
28 password string
29 port int
30 tempDir string
31 ep *embeddedpostgres.EmbeddedPostgres
32 }
33
34
35 func New() (*SeededPostgres, error) {
36 return NewWithUser("postgres", "postgres", "postgres")
37 }
38
39 func NewWithUser(dbname, username, password string) (*SeededPostgres, error) {
40 if dbname == "" || username == "" || password == "" {
41 return nil, fmt.Errorf("NewWithUser arguments must not be empty: dbname=%q username=%q password=%q", dbname, username, password)
42 }
43
44 var cfg = embeddedpostgres.DefaultConfig()
45 cfg = cfg.Version(PostgresVersion)
46 cfg = cfg.Database(dbname)
47 cfg = cfg.Username(username)
48 cfg = cfg.Password(password)
49
50 var port, err = findUnusedPort()
51 if err != nil {
52 return nil, err
53 }
54 cfg = cfg.Port(uint32(port))
55
56 var tempDir string
57 if bazel.IsBazelTest() || bazel.IsBazelRun() {
58 embeddedTxzFile, err := runfiles.Rlocation(path.Join("edge_infra", "hack", "tools", "postgres.txz"))
59 if err != nil {
60 return nil, err
61 }
62 tempDir, err = bazel.NewTestTmpDir("edge-infra-api-test-*")
63 if err != nil {
64 return nil, err
65 }
66 err = compression.DecompressTarXz(embeddedTxzFile, tempDir)
67 if err != nil {
68 return nil, err
69 }
70
71 cfg = cfg.RuntimePath(path.Join(tempDir, "runtime"))
72 cfg = cfg.BinariesPath(tempDir)
73 }
74
75 var sp = &SeededPostgres{
76 dbname: dbname,
77 username: username,
78 password: password,
79 port: port,
80 tempDir: tempDir,
81 ep: embeddedpostgres.NewDatabase(cfg),
82 }
83
84 err = sp.ep.Start()
85 if err != nil {
86 _ = sp.Close()
87 return nil, err
88 }
89
90 db, err := sp.DB()
91 if err != nil {
92 _ = sp.Close()
93 return nil, err
94 }
95 defer db.Close()
96
97 driver, err := postgres.WithInstance(db, &postgres.Config{})
98 if err != nil {
99 _ = sp.Close()
100 return nil, err
101 }
102 defer driver.Close()
103
104 _, err = db.Exec("CREATE EXTENSION IF NOT EXISTS \"pgcrypto\"")
105 if err != nil {
106 _ = sp.Close()
107 return nil, err
108 }
109
110 var logger = logging.NewLogger().WithName("seededpostgres")
111 var pluginConfig = &plugin.Config{
112 MigrationAction: "up",
113 Ordered: true,
114 TestMode: true,
115 Data: Seed,
116 }
117
118 err = edgesql.SetupEdgeTables(pluginConfig, driver, logger, db)
119 if err != nil {
120 _ = sp.Close()
121 return nil, err
122 }
123 return sp, nil
124 }
125
126
127 func (sp *SeededPostgres) Close() error {
128 var errStop error
129 if sp.ep != nil {
130
131 errStop = sp.ep.Stop()
132 }
133
134 errDeleteTempDir := os.RemoveAll(sp.tempDir)
135 if errDeleteTempDir != nil {
136 return errDeleteTempDir
137 }
138
139 return errStop
140 }
141
142
143 func (sp *SeededPostgres) DB() (*sql.DB, error) {
144 return sp.EdgePostgres().NewConnection()
145 }
146
147 func (sp *SeededPostgres) EdgePostgres() *cloudsql.EdgePostgres {
148 return cloudsql.PostgresConnection("localhost", fmt.Sprint(sp.port)).DBName(sp.dbname).Username(sp.username).Password(sp.password)
149 }
150
151 func (sp *SeededPostgres) Port() int {
152 return sp.port
153 }
154
// findUnusedPort asks the kernel for an ephemeral TCP port on 127.0.0.1 and
// releases it immediately so the caller can bind it. The port is not
// reserved after return, so another process could take it before the caller
// binds it.
//
// If closing the probe listener fails, the port is still returned alongside
// that error.
func findUnusedPort() (int, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	port := ln.Addr().(*net.TCPAddr).Port
	if closeErr := ln.Close(); closeErr != nil {
		return port, closeErr
	}
	return port, nil
}
166
View as plain text