1 package popx
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "io"
8 "os"
9 "path/filepath"
10 "regexp"
11 "sort"
12 "strings"
13 "text/tabwriter"
14 "time"
15
16 "github.com/ory/x/cmdx"
17
18 "github.com/ory/x/tracing"
19
20 "github.com/opentracing/opentracing-go"
21 "github.com/opentracing/opentracing-go/log"
22
23 "github.com/gobuffalo/pop/v5"
24
25 "github.com/ory/x/logrusx"
26
27 "github.com/pkg/errors"
28 )
29
// Migration states reported in MigrationStatus.State.
const (
	// Pending marks a migration whose version is not recorded in the migration table.
	Pending = "Pending"
	// Applied marks a migration whose version (or legacy version) is recorded.
	Applied = "Applied"
)

// mrx matches migration file names of the form
// <digits>_<name>[.<dialect>].<up|down>.<sql|fizz>, capturing the version, the
// name, the optional dot-prefixed dialect, the direction and the file type.
var mrx = regexp.MustCompile(`^(\d+)_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.(sql|fizz)$`)
36
37
38
39
40
41 func NewMigrator(c *pop.Connection, l *logrusx.Logger, tracer *tracing.Tracer, perMigrationTimeout time.Duration) *Migrator {
42 return &Migrator{
43 Connection: c,
44 l: l,
45 Migrations: map[string]Migrations{
46 "up": {},
47 "down": {},
48 },
49 tracer: tracer,
50 PerMigrationTimeout: perMigrationTimeout,
51 }
52 }
53
54
55
56
57
58 type Migrator struct {
59 Connection *pop.Connection
60 SchemaPath string
61 Migrations map[string]Migrations
62 l *logrusx.Logger
63 PerMigrationTimeout time.Duration
64 tracer *tracing.Tracer
65 }
66
67 func (m *Migrator) MigrationIsCompatible(dialect string, mi Migration) bool {
68 if mi.DBType == "all" || mi.DBType == dialect {
69 return true
70 }
71 return false
72 }
73
74
75 func (m *Migrator) Up(ctx context.Context) error {
76 _, err := m.UpTo(ctx, 0)
77 return err
78 }
79
80
81
82 func (m *Migrator) UpTo(ctx context.Context, step int) (applied int, err error) {
83 span, ctx := m.startSpan(ctx, MigrationUpOpName)
84 defer span.Finish()
85 span.LogFields(log.Int("up_to_step", step))
86
87 c := m.Connection.WithContext(ctx)
88 err = m.exec(ctx, func() error {
89 mtn := m.migrationTableName(ctx, c)
90 mfs := m.Migrations["up"].SortAndFilter(c.Dialect.Name())
91 for _, mi := range mfs {
92 exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
93 if err != nil {
94 return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
95 }
96
97 if exists {
98 m.l.WithField("version", mi.Version).Debug("Migration has already been applied, skipping.")
99 continue
100 }
101
102 if len(mi.Version) > 14 {
103 m.l.WithField("version", mi.Version).Debug("Migration has not been applied but it might be a legacy migration, investigating.")
104
105 legacyVersion := mi.Version[:14]
106 exists, err = c.Where("version = ?", legacyVersion).Exists(mtn)
107 if err != nil {
108 return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
109 }
110
111 if exists {
112 m.l.WithField("version", mi.Version).WithField("legacy_version", legacyVersion).WithField("migration_table", mtn).Debug("Migration has already been applied in a legacy migration run. Updating version in migration table.")
113 if err := m.isolatedTransaction(ctx, "init-migrate", func(tx *pop.Tx) error {
114
115
116
117
118
119
120
121
122
123 _, err := tx.Exec(tx.Rebind(fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", mtn)), mi.Version)
124 return errors.Wrapf(err, "problem inserting migration version %s", mi.Version)
125 }); err != nil {
126 return err
127 }
128 continue
129 }
130 }
131
132 m.l.WithField("version", mi.Version).Debug("Migration has not yet been applied, running migration.")
133
134 if err = m.isolatedTransaction(ctx, "up", func(tx *pop.Tx) error {
135 if err := mi.Run(c, tx); err != nil {
136 return err
137 }
138
139
140 if _, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ('%s')", mtn, mi.Version)); err != nil {
141 return errors.Wrapf(err, "problem inserting migration version %s", mi.Version)
142 }
143 return nil
144 }); err != nil {
145 return err
146 }
147
148 m.l.Debugf("> %s", mi.Name)
149 applied++
150 if step > 0 && applied >= step {
151 break
152 }
153 }
154 if applied == 0 {
155 m.l.Debugf("Migrations already up to date, nothing to apply")
156 } else {
157 m.l.Debugf("Successfully applied %d migrations.", applied)
158 }
159 return nil
160 })
161 return
162 }
163
164
165
166 func (m *Migrator) Down(ctx context.Context, step int) error {
167 span, ctx := m.startSpan(ctx, MigrationDownOpName)
168 defer span.Finish()
169
170 c := m.Connection.WithContext(ctx)
171 return m.exec(ctx, func() error {
172 mtn := m.migrationTableName(ctx, c)
173 count, err := c.Count(mtn)
174 if err != nil {
175 return errors.Wrap(err, "migration down: unable count existing migration")
176 }
177 mfs := m.Migrations["down"].SortAndFilter(c.Dialect.Name(), sort.Reverse)
178
179 if len(mfs) > count {
180 mfs = mfs[len(mfs)-count:]
181 }
182
183 if step > 0 && len(mfs) >= step {
184 mfs = mfs[:step]
185 }
186 for _, mi := range mfs {
187 exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
188 if err != nil {
189 return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
190 }
191
192 if !exists && len(mi.Version) > 14 {
193 legacyVersion := mi.Version[:14]
194 legacyVersionExists, err := c.Where("version = ?", legacyVersion).Exists(mtn)
195 if err != nil {
196 return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
197 }
198
199 if !legacyVersionExists {
200 return errors.Wrapf(err, "problem checking for migration version %s", legacyVersion)
201 }
202 } else if !exists {
203 return errors.Errorf("migration version %s does not exist", mi.Version)
204 }
205
206 err = m.isolatedTransaction(ctx, "down", func(tx *pop.Tx) error {
207 err := mi.Run(c, tx)
208 if err != nil {
209 return err
210 }
211
212
213 if _, err = tx.Exec(tx.Rebind(fmt.Sprintf("DELETE FROM %s WHERE version = ?", mtn)), mi.Version); err != nil {
214 return errors.Wrapf(err, "problem deleting migration version %s", mi.Version)
215 }
216
217 return nil
218 })
219 if err != nil {
220 return err
221 }
222
223 m.l.Debugf("< %s", mi.Name)
224 }
225 return nil
226 })
227 }
228
229
230 func (m *Migrator) Reset(ctx context.Context) error {
231 err := m.Down(ctx, -1)
232 if err != nil {
233 return err
234 }
235 return m.Up(ctx)
236 }
237
238 func (m *Migrator) createTransactionalMigrationTable(ctx context.Context, c *pop.Connection, l *logrusx.Logger) error {
239 mtn := m.migrationTableName(ctx, c)
240 unprefixedMtn := m.migrationTableName(ctx, c)
241
242 if err := m.execMigrationTransaction(ctx, c, []string{
243 fmt.Sprintf(`CREATE TABLE %s (version VARCHAR (48) NOT NULL, version_self INT NOT NULL DEFAULT 0)`, mtn),
244 fmt.Sprintf(`CREATE UNIQUE INDEX %s_version_idx ON %s (version)`, unprefixedMtn, mtn),
245 fmt.Sprintf(`CREATE INDEX %s_version_self_idx ON %s (version_self)`, unprefixedMtn, mtn),
246 }); err != nil {
247 return err
248 }
249
250 l.WithField("migration_table", mtn).Debug("Transactional migration table created successfully.")
251
252 return nil
253 }
254
255 func (m *Migrator) migrateToTransactionalMigrationTable(ctx context.Context, c *pop.Connection, l *logrusx.Logger) error {
256
257 mtn := m.migrationTableName(ctx, c)
258 unprefixedMtn := m.migrationTableName(ctx, c)
259
260 withOn := fmt.Sprintf(" ON %s", mtn)
261 if c.Dialect.Name() != "mysql" {
262 withOn = ""
263 }
264
265 interimTable := fmt.Sprintf("%s_transactional", mtn)
266 workload := [][]string{
267 {
268 fmt.Sprintf(`DROP INDEX %s_version_idx%s`, unprefixedMtn, withOn),
269 fmt.Sprintf(`CREATE TABLE %s (version VARCHAR (48) NOT NULL, version_self INT NOT NULL DEFAULT 0)`, interimTable),
270 fmt.Sprintf(`CREATE UNIQUE INDEX %s_version_idx ON %s (version)`, unprefixedMtn, interimTable),
271 fmt.Sprintf(`CREATE INDEX %s_version_self_idx ON %s (version_self)`, unprefixedMtn, interimTable),
272
273 fmt.Sprintf(`INSERT INTO %s (version) SELECT version FROM %s`, interimTable, mtn),
274 fmt.Sprintf(`ALTER TABLE %s RENAME TO %s_pop_legacy`, mtn, mtn),
275 },
276 {
277 fmt.Sprintf(`ALTER TABLE %s RENAME TO %s`, interimTable, mtn),
278 },
279 }
280
281 if err := m.execMigrationTransaction(ctx, c, workload...); err != nil {
282 return err
283 }
284
285 l.WithField("migration_table", mtn).Debug("Successfully migrated legacy schema_migration to new transactional schema_migration table.")
286
287 return nil
288 }
289
290 func (m *Migrator) isolatedTransaction(ctx context.Context, direction string, fn func(tx *pop.Tx) error) error {
291 span, ctx := m.startSpan(ctx, MigrationRunTransactionOpName)
292 defer span.Finish()
293 span.SetTag("migration_direction", direction)
294
295 if m.PerMigrationTimeout > 0 {
296 var cancel context.CancelFunc
297 ctx, cancel = context.WithTimeout(ctx, m.PerMigrationTimeout)
298 defer cancel()
299 }
300
301 c := m.Connection.WithContext(ctx)
302 tx, dberr := c.Store.TransactionContextOptions(ctx, &sql.TxOptions{
303 Isolation: sql.LevelSerializable,
304 ReadOnly: false,
305 })
306 if dberr != nil {
307 return dberr
308 }
309
310 err := fn(tx)
311 if err != nil {
312 dberr = tx.Rollback()
313 } else {
314 dberr = tx.Commit()
315 }
316
317 if dberr != nil {
318 return errors.Wrap(dberr, "error committing or rolling back transaction")
319 }
320
321 return err
322 }
323
324 func (m *Migrator) execMigrationTransaction(ctx context.Context, c *pop.Connection, transactions ...[]string) error {
325 for _, statements := range transactions {
326 if err := m.isolatedTransaction(ctx, "init", func(tx *pop.Tx) error {
327 for _, statement := range statements {
328 if _, err := tx.ExecContext(ctx, statement); err != nil {
329 return errors.Wrapf(err, "unable to execute statement: %s", statement)
330 }
331 }
332 return nil
333 }); err != nil {
334 return err
335 }
336 }
337
338 return nil
339 }
340
341
342
343 func (m *Migrator) CreateSchemaMigrations(ctx context.Context) error {
344 span, ctx := m.startSpan(ctx, MigrationInitOpName)
345 defer span.Finish()
346
347 c := m.Connection.WithContext(ctx)
348
349 mtn := m.migrationTableName(ctx, c)
350 m.l.WithField("migration_table", mtn).Debug("Checking if legacy migration table exists.")
351 _, err := c.Store.Exec(fmt.Sprintf("select version from %s", mtn))
352 if err != nil {
353 m.l.WithError(err).WithField("migration_table", mtn).Debug("An error occurred while checking for the legacy migration table, maybe it does not exist yet? Trying to create.")
354
355 return m.createTransactionalMigrationTable(ctx, c, m.l)
356 }
357
358 m.l.WithField("migration_table", mtn).Debug("A migration table exists, checking if it is a transactional migration table.")
359 _, err = c.Store.Exec(fmt.Sprintf("select version, version_self from %s", mtn))
360 if err != nil {
361 m.l.WithError(err).WithField("migration_table", mtn).Debug("An error occurred while checking for the transactional migration table, maybe it does not exist yet? Trying to create.")
362 return m.migrateToTransactionalMigrationTable(ctx, c, m.l)
363 }
364
365 m.l.WithField("migration_table", mtn).Debug("Migration tables exist and are up to date.")
366 return nil
367 }
368
// MigrationStatus describes one migration and whether it has been applied.
type MigrationStatus struct {
	// State is Pending or Applied.
	State string `json:"state"`
	// Version is the migration's version string.
	Version string `json:"version"`
	// Name is the migration's name.
	Name string `json:"name"`
}

// MigrationStatuses is an ordered list of migration statuses.
type MigrationStatuses []MigrationStatus
376
377 var _ cmdx.Table = (MigrationStatuses)(nil)
378
379 func (m MigrationStatuses) Header() []string {
380 return []string{"Version", "Name", "Status"}
381 }
382
383 func (m MigrationStatuses) Table() [][]string {
384 t := make([][]string, len(m))
385 for i, s := range m {
386 t[i] = []string{s.Version, s.Name, s.State}
387 }
388 return t
389 }
390
391 func (m MigrationStatuses) Interface() interface{} {
392 return m
393 }
394
395 func (m MigrationStatuses) Len() int {
396 return len(m)
397 }
398
399 func (m MigrationStatuses) IDs() []string {
400 ids := make([]string, len(m))
401 for i, s := range m {
402 ids[i] = s.Version
403 }
404 return ids
405 }
406
407
408 func (m MigrationStatuses) Write(out io.Writer) error {
409 w := tabwriter.NewWriter(out, 0, 0, 3, ' ', tabwriter.TabIndent)
410 _, _ = fmt.Fprintln(w, "Version\tName\tStatus\t")
411
412 for _, mm := range m {
413 _, _ = fmt.Fprintf(w, "%s\t%s\t%s\t\n", mm.Version, mm.Name, mm.State)
414 }
415
416 return w.Flush()
417 }
418
419 func (m MigrationStatuses) HasPending() bool {
420 for _, mm := range m {
421 if mm.State == Pending {
422 return true
423 }
424 }
425 return false
426 }
427
428 func (m *Migrator) migrationTableName(ctx context.Context, con *pop.Connection) string {
429 return con.MigrationTableName()
430 }
431
// errIsTableNotFound reports whether err looks like a "table does not exist"
// error from SQLite ("no such table:"), MySQL (error 1146, in both the
// "Error 1146:" and the newer "Error 1146 (42S02):" formats) or PostgreSQL /
// CockroachDB (SQLSTATE 42P01). Matching is by substring so wrapped errors are
// recognised too. A nil error is never a table-not-found error.
func errIsTableNotFound(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "no such table:") ||
		strings.Contains(msg, "Error 1146:") ||
		strings.Contains(msg, "Error 1146 (") ||
		strings.Contains(msg, "SQLSTATE 42P01")
}
437
438
439 func (m *Migrator) Status(ctx context.Context) (MigrationStatuses, error) {
440 span, ctx := m.startSpan(ctx, MigrationStatusOpName)
441 defer span.Finish()
442
443 con := m.Connection.WithContext(ctx)
444
445 migrations := m.Migrations["up"].SortAndFilter(con.Dialect.Name())
446
447 if len(migrations) == 0 {
448 return nil, errors.Errorf("unable to find any migrations for dialect: %s", con.Dialect.Name())
449 }
450
451 statuses := make(MigrationStatuses, len(migrations))
452 for k, mf := range migrations {
453 statuses[k] = MigrationStatus{
454 State: Pending,
455 Version: mf.Version,
456 Name: mf.Name,
457 }
458
459 exists, err := con.Where("version = ?", mf.Version).Exists(con.MigrationTableName())
460 if err != nil {
461 if errIsTableNotFound(err) {
462 continue
463 } else {
464 return nil, errors.Wrapf(err, "problem with migration")
465 }
466 }
467
468 if exists {
469 statuses[k].State = Applied
470 } else if len(mf.Version) > 14 {
471 mtn := m.migrationTableName(ctx, con)
472 legacyVersion := mf.Version[:14]
473 exists, err = con.Where("version = ?", legacyVersion).Exists(mtn)
474 if err != nil {
475 return nil, errors.Wrapf(err, "problem checking for migration version %s", legacyVersion)
476 }
477
478 if exists {
479 statuses[k].State = Applied
480 }
481 }
482 }
483
484 return statuses, nil
485 }
486
487
488
489 func (m *Migrator) DumpMigrationSchema(ctx context.Context) error {
490 if m.SchemaPath == "" {
491 return nil
492 }
493 c := m.Connection.WithContext(ctx)
494 schema := filepath.Join(m.SchemaPath, "schema.sql")
495 f, err := os.Create(schema)
496 if err != nil {
497 return err
498 }
499 err = c.Dialect.DumpSchema(f)
500 if err != nil {
501 os.RemoveAll(schema)
502 return err
503 }
504 return nil
505 }
506
507 func (m *Migrator) wrapSpan(ctx context.Context, opName string, f func(ctx context.Context, span opentracing.Span) error) error {
508 span, ctx := m.startSpan(ctx, opName)
509 defer span.Finish()
510
511 return f(ctx, span)
512 }
513
514 func (m *Migrator) startSpan(ctx context.Context, opName string) (opentracing.Span, context.Context) {
515 tracer := opentracing.GlobalTracer()
516 if m.tracer.IsLoaded() {
517 tracer = m.tracer.Tracer()
518
519 }
520
521 span, ctx := opentracing.StartSpanFromContextWithTracer(ctx, tracer, opName)
522 span.SetTag("component", "github.com/ory/x/popx")
523
524 span.LogFields()
525 return span, ctx
526 }
527
528 func (m *Migrator) exec(ctx context.Context, fn func() error) error {
529 now := time.Now()
530 defer func() {
531 err := m.DumpMigrationSchema(ctx)
532 if err != nil {
533 m.l.WithError(err).Warn("Migrator: unable to dump schema")
534 }
535 }()
536 defer m.printTimer(now)
537
538 err := m.CreateSchemaMigrations(ctx)
539 if err != nil {
540 return errors.Wrap(err, "migrator: problem creating schema migrations")
541 }
542
543 if m.Connection.Dialect.Name() == "sqlite3" {
544 if err := m.Connection.RawQuery("PRAGMA foreign_keys=OFF").Exec(); err != nil {
545 return err
546 }
547 }
548
549 if err := fn(); err != nil {
550 return err
551 }
552
553 if m.Connection.Dialect.Name() == "sqlite3" {
554 if err := m.Connection.RawQuery("PRAGMA foreign_keys=ON").Exec(); err != nil {
555 return err
556 }
557 }
558
559 return nil
560 }
561
562 func (m *Migrator) printTimer(timerStart time.Time) {
563 diff := time.Since(timerStart).Seconds()
564 if diff > 60 {
565 m.l.Debugf("%.4f minutes", diff/60)
566 } else {
567 m.l.Debugf("%.4f seconds", diff)
568 }
569 }
570
View as plain text