1 package cockroachdb
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "go.uber.org/atomic"
8 "io"
9 "io/ioutil"
10 nurl "net/url"
11 "regexp"
12 "strconv"
13 )
14
15 import (
16 "github.com/cockroachdb/cockroach-go/v2/crdb"
17 "github.com/hashicorp/go-multierror"
18 "github.com/lib/pq"
19 )
20
21 import (
22 "github.com/golang-migrate/migrate/v4"
23 "github.com/golang-migrate/migrate/v4/database"
24 )
25
26 func init() {
27 db := CockroachDb{}
28 database.Register("cockroach", &db)
29 database.Register("cockroachdb", &db)
30 database.Register("crdb-postgres", &db)
31 }
32
// Table names used when the caller leaves the corresponding Config field
// (or URL parameter) empty.
var (
	// DefaultMigrationsTable is the table holding the recorded (version, dirty) row.
	DefaultMigrationsTable = "schema_migrations"
	// DefaultLockTable is the table holding migration lock rows.
	DefaultLockTable = "schema_lock"
)

// Errors returned by WithInstance.
var (
	// ErrNilConfig is returned when WithInstance receives a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned when no database name was configured and
	// current_database() reported an empty name.
	ErrNoDatabaseName = fmt.Errorf("no database name")
)
40
// Config controls the table names the driver uses and how it acquires the
// migration lock. Empty fields are filled in by WithInstance.
type Config struct {
	// MigrationsTable names the table storing the migration version;
	// WithInstance substitutes DefaultMigrationsTable when it is empty.
	MigrationsTable string
	// LockTable names the table storing lock rows; WithInstance substitutes
	// DefaultLockTable when it is empty.
	LockTable string
	// ForceLock makes Lock skip its "already locked" rejection when a lock
	// row for this database is already present.
	ForceLock bool
	// DatabaseName is used to derive the lock id. When empty, WithInstance
	// fills it from SELECT current_database().
	DatabaseName string
}
47
48 type CockroachDb struct {
49 db *sql.DB
50 isLocked atomic.Bool
51
52
53 config *Config
54 }
55
56 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
57 if config == nil {
58 return nil, ErrNilConfig
59 }
60
61 if err := instance.Ping(); err != nil {
62 return nil, err
63 }
64
65 if config.DatabaseName == "" {
66 query := `SELECT current_database()`
67 var databaseName string
68 if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
69 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
70 }
71
72 if len(databaseName) == 0 {
73 return nil, ErrNoDatabaseName
74 }
75
76 config.DatabaseName = databaseName
77 }
78
79 if len(config.MigrationsTable) == 0 {
80 config.MigrationsTable = DefaultMigrationsTable
81 }
82
83 if len(config.LockTable) == 0 {
84 config.LockTable = DefaultLockTable
85 }
86
87 px := &CockroachDb{
88 db: instance,
89 config: config,
90 }
91
92
93 if err := px.ensureLockTable(); err != nil {
94 return nil, err
95 }
96
97 if err := px.ensureVersionTable(); err != nil {
98 return nil, err
99 }
100
101 return px, nil
102 }
103
104 func (c *CockroachDb) Open(url string) (database.Driver, error) {
105 purl, err := nurl.Parse(url)
106 if err != nil {
107 return nil, err
108 }
109
110
111
112 re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
113 connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
114
115 db, err := sql.Open("postgres", connectString)
116 if err != nil {
117 return nil, err
118 }
119
120 migrationsTable := purl.Query().Get("x-migrations-table")
121 if len(migrationsTable) == 0 {
122 migrationsTable = DefaultMigrationsTable
123 }
124
125 lockTable := purl.Query().Get("x-lock-table")
126 if len(lockTable) == 0 {
127 lockTable = DefaultLockTable
128 }
129
130 forceLockQuery := purl.Query().Get("x-force-lock")
131 forceLock, err := strconv.ParseBool(forceLockQuery)
132 if err != nil {
133 forceLock = false
134 }
135
136 px, err := WithInstance(db, &Config{
137 DatabaseName: purl.Path,
138 MigrationsTable: migrationsTable,
139 LockTable: lockTable,
140 ForceLock: forceLock,
141 })
142 if err != nil {
143 return nil, err
144 }
145
146 return px, nil
147 }
148
149 func (c *CockroachDb) Close() error {
150 return c.db.Close()
151 }
152
153
154
155 func (c *CockroachDb) Lock() error {
156 return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
157 return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
158 aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
159 if err != nil {
160 return err
161 }
162
163 query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
164 rows, err := tx.Query(query, aid)
165 if err != nil {
166 return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
167 }
168 defer func() {
169 if errClose := rows.Close(); errClose != nil {
170 err = multierror.Append(err, errClose)
171 }
172 }()
173
174
175 locked := rows.Next()
176 if locked && !c.config.ForceLock {
177 return database.ErrLocked
178 }
179
180 query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
181 if _, err := tx.Exec(query, aid); err != nil {
182 return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
183 }
184
185 return nil
186 })
187 })
188 }
189
190
191
192 func (c *CockroachDb) Unlock() error {
193 return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
194 aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
195 if err != nil {
196 return err
197 }
198
199
200
201 query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
202 if _, err := c.db.Exec(query, aid); err != nil {
203 if e, ok := err.(*pq.Error); ok {
204
205
206 if e.Code == "42P01" {
207
208 return nil
209 }
210 }
211
212 return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
213 }
214
215 return nil
216 })
217 }
218
219 func (c *CockroachDb) Run(migration io.Reader) error {
220 migr, err := ioutil.ReadAll(migration)
221 if err != nil {
222 return err
223 }
224
225
226 query := string(migr[:])
227 if _, err := c.db.Exec(query); err != nil {
228 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
229 }
230
231 return nil
232 }
233
234 func (c *CockroachDb) SetVersion(version int, dirty bool) error {
235 return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
236 if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
237 return err
238 }
239
240
241
242
243 if version >= 0 || (version == database.NilVersion && dirty) {
244 if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
245 return err
246 }
247 }
248
249 return nil
250 })
251 }
252
253 func (c *CockroachDb) Version() (version int, dirty bool, err error) {
254 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
255 err = c.db.QueryRow(query).Scan(&version, &dirty)
256
257 switch {
258 case err == sql.ErrNoRows:
259 return database.NilVersion, false, nil
260
261 case err != nil:
262 if e, ok := err.(*pq.Error); ok {
263
264
265 if e.Code == "42P01" {
266 return database.NilVersion, false, nil
267 }
268 }
269 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
270
271 default:
272 return version, dirty, nil
273 }
274 }
275
276 func (c *CockroachDb) Drop() (err error) {
277
278 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
279 tables, err := c.db.Query(query)
280 if err != nil {
281 return &database.Error{OrigErr: err, Query: []byte(query)}
282 }
283 defer func() {
284 if errClose := tables.Close(); errClose != nil {
285 err = multierror.Append(err, errClose)
286 }
287 }()
288
289
290 tableNames := make([]string, 0)
291 for tables.Next() {
292 var tableName string
293 if err := tables.Scan(&tableName); err != nil {
294 return err
295 }
296 if len(tableName) > 0 {
297 tableNames = append(tableNames, tableName)
298 }
299 }
300 if err := tables.Err(); err != nil {
301 return &database.Error{OrigErr: err, Query: []byte(query)}
302 }
303
304 if len(tableNames) > 0 {
305
306 for _, t := range tableNames {
307 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
308 if _, err := c.db.Exec(query); err != nil {
309 return &database.Error{OrigErr: err, Query: []byte(query)}
310 }
311 }
312 }
313
314 return nil
315 }
316
317
318
319
320 func (c *CockroachDb) ensureVersionTable() (err error) {
321 if err = c.Lock(); err != nil {
322 return err
323 }
324
325 defer func() {
326 if e := c.Unlock(); e != nil {
327 if err == nil {
328 err = e
329 } else {
330 err = multierror.Append(err, e)
331 }
332 }
333 }()
334
335
336 var count int
337 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
338 if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
339 return &database.Error{OrigErr: err, Query: []byte(query)}
340 }
341 if count == 1 {
342 return nil
343 }
344
345
346 query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
347 if _, err := c.db.Exec(query); err != nil {
348 return &database.Error{OrigErr: err, Query: []byte(query)}
349 }
350 return nil
351 }
352
353 func (c *CockroachDb) ensureLockTable() error {
354
355 var count int
356 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
357 if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
358 return &database.Error{OrigErr: err, Query: []byte(query)}
359 }
360 if count == 1 {
361 return nil
362 }
363
364
365 query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
366 if _, err := c.db.Exec(query); err != nil {
367 return &database.Error{OrigErr: err, Query: []byte(query)}
368 }
369
370 return nil
371 }
372
View as plain text