...

Source file src/github.com/ory/x/popx/migrator.go

Documentation: github.com/ory/x/popx

     1  package popx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"path/filepath"
    10  	"regexp"
    11  	"sort"
    12  	"strings"
    13  	"text/tabwriter"
    14  	"time"
    15  
    16  	"github.com/ory/x/cmdx"
    17  
    18  	"github.com/ory/x/tracing"
    19  
    20  	"github.com/opentracing/opentracing-go"
    21  	"github.com/opentracing/opentracing-go/log"
    22  
    23  	"github.com/gobuffalo/pop/v5"
    24  
    25  	"github.com/ory/x/logrusx"
    26  
    27  	"github.com/pkg/errors"
    28  )
    29  
// Migration states reported in MigrationStatus.State by Status.
const (
	// Pending is the state of a migration whose version is not recorded in the migration table.
	Pending = "Pending"
	// Applied is the state of a migration whose version (or 14-character legacy version) is recorded.
	Applied = "Applied"
)

// mrx matches names of the form <digits>_<name>[.<suffix>].<up|down>.<sql|fizz>.
// NOTE(review): presumably used elsewhere in the package to parse migration file
// names, with the optional suffix acting as a dialect qualifier — confirm.
var mrx = regexp.MustCompile(`^(\d+)_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.(sql|fizz)$`)
    36  
    37  // NewMigrator returns a new "blank" migrator. It is recommended
    38  // to use something like MigrationBox or FileMigrator. A "blank"
    39  // Migrator should only be used as the basis for a new type of
    40  // migration system.
    41  func NewMigrator(c *pop.Connection, l *logrusx.Logger, tracer *tracing.Tracer, perMigrationTimeout time.Duration) *Migrator {
    42  	return &Migrator{
    43  		Connection: c,
    44  		l:          l,
    45  		Migrations: map[string]Migrations{
    46  			"up":   {},
    47  			"down": {},
    48  		},
    49  		tracer:              tracer,
    50  		PerMigrationTimeout: perMigrationTimeout,
    51  	}
    52  }
    53  
// Migrator forms the basis of all migrations systems.
// It does the actual heavy lifting of running migrations.
// When building a new migration system, you should embed this
// type into your migrator.
type Migrator struct {
	// Connection is the database connection migrations are run against.
	Connection          *pop.Connection
	// SchemaPath is the directory DumpMigrationSchema writes schema.sql into;
	// when empty, no schema dump is written.
	SchemaPath          string
	// Migrations holds the registered migrations keyed by direction ("up", "down").
	Migrations          map[string]Migrations
	// l receives the migrator's debug and warning output.
	l                   *logrusx.Logger
	// PerMigrationTimeout, when > 0, bounds the context of every isolated transaction.
	PerMigrationTimeout time.Duration
	// tracer is used for spans when loaded; otherwise the global opentracing tracer is used.
	tracer              *tracing.Tracer
}
    66  
    67  func (m *Migrator) MigrationIsCompatible(dialect string, mi Migration) bool {
    68  	if mi.DBType == "all" || mi.DBType == dialect {
    69  		return true
    70  	}
    71  	return false
    72  }
    73  
    74  // Up runs pending "up" migrations and applies them to the database.
    75  func (m *Migrator) Up(ctx context.Context) error {
    76  	_, err := m.UpTo(ctx, 0)
    77  	return err
    78  }
    79  
    80  // UpTo runs up to step "up" migrations and applies them to the database.
    81  // If step <= 0 all pending migrations are run.
    82  func (m *Migrator) UpTo(ctx context.Context, step int) (applied int, err error) {
    83  	span, ctx := m.startSpan(ctx, MigrationUpOpName)
    84  	defer span.Finish()
    85  	span.LogFields(log.Int("up_to_step", step))
    86  
    87  	c := m.Connection.WithContext(ctx)
    88  	err = m.exec(ctx, func() error {
    89  		mtn := m.migrationTableName(ctx, c)
    90  		mfs := m.Migrations["up"].SortAndFilter(c.Dialect.Name())
    91  		for _, mi := range mfs {
    92  			exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
    93  			if err != nil {
    94  				return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
    95  			}
    96  
    97  			if exists {
    98  				m.l.WithField("version", mi.Version).Debug("Migration has already been applied, skipping.")
    99  				continue
   100  			}
   101  
   102  			if len(mi.Version) > 14 {
   103  				m.l.WithField("version", mi.Version).Debug("Migration has not been applied but it might be a legacy migration, investigating.")
   104  
   105  				legacyVersion := mi.Version[:14]
   106  				exists, err = c.Where("version = ?", legacyVersion).Exists(mtn)
   107  				if err != nil {
   108  					return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
   109  				}
   110  
   111  				if exists {
   112  					m.l.WithField("version", mi.Version).WithField("legacy_version", legacyVersion).WithField("migration_table", mtn).Debug("Migration has already been applied in a legacy migration run. Updating version in migration table.")
   113  					if err := m.isolatedTransaction(ctx, "init-migrate", func(tx *pop.Tx) error {
   114  						// We do not want to remove the legacy migration version or subsequent migrations might be applied twice.
   115  						//
   116  						// Do not activate the following - it is just for reference.
   117  						//
   118  						// if _, err := tx.Store.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", mtn), legacyVersion); err != nil {
   119  						//	return errors.Wrapf(err, "problem removing legacy version %s", mi.Version)
   120  						// }
   121  
   122  						// #nosec G201 - mtn is a system-wide const
   123  						_, err := tx.Exec(tx.Rebind(fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", mtn)), mi.Version)
   124  						return errors.Wrapf(err, "problem inserting migration version %s", mi.Version)
   125  					}); err != nil {
   126  						return err
   127  					}
   128  					continue
   129  				}
   130  			}
   131  
   132  			m.l.WithField("version", mi.Version).Debug("Migration has not yet been applied, running migration.")
   133  
   134  			if err = m.isolatedTransaction(ctx, "up", func(tx *pop.Tx) error {
   135  				if err := mi.Run(c, tx); err != nil {
   136  					return err
   137  				}
   138  
   139  				// #nosec G201 - mtn is a system-wide const
   140  				if _, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ('%s')", mtn, mi.Version)); err != nil {
   141  					return errors.Wrapf(err, "problem inserting migration version %s", mi.Version)
   142  				}
   143  				return nil
   144  			}); err != nil {
   145  				return err
   146  			}
   147  
   148  			m.l.Debugf("> %s", mi.Name)
   149  			applied++
   150  			if step > 0 && applied >= step {
   151  				break
   152  			}
   153  		}
   154  		if applied == 0 {
   155  			m.l.Debugf("Migrations already up to date, nothing to apply")
   156  		} else {
   157  			m.l.Debugf("Successfully applied %d migrations.", applied)
   158  		}
   159  		return nil
   160  	})
   161  	return
   162  }
   163  
   164  // Down runs pending "down" migrations and rolls back the
   165  // database by the specified number of steps.
   166  func (m *Migrator) Down(ctx context.Context, step int) error {
   167  	span, ctx := m.startSpan(ctx, MigrationDownOpName)
   168  	defer span.Finish()
   169  
   170  	c := m.Connection.WithContext(ctx)
   171  	return m.exec(ctx, func() error {
   172  		mtn := m.migrationTableName(ctx, c)
   173  		count, err := c.Count(mtn)
   174  		if err != nil {
   175  			return errors.Wrap(err, "migration down: unable count existing migration")
   176  		}
   177  		mfs := m.Migrations["down"].SortAndFilter(c.Dialect.Name(), sort.Reverse)
   178  		// skip all ran migration
   179  		if len(mfs) > count {
   180  			mfs = mfs[len(mfs)-count:]
   181  		}
   182  		// run only required steps
   183  		if step > 0 && len(mfs) >= step {
   184  			mfs = mfs[:step]
   185  		}
   186  		for _, mi := range mfs {
   187  			exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
   188  			if err != nil {
   189  				return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
   190  			}
   191  
   192  			if !exists && len(mi.Version) > 14 {
   193  				legacyVersion := mi.Version[:14]
   194  				legacyVersionExists, err := c.Where("version = ?", legacyVersion).Exists(mtn)
   195  				if err != nil {
   196  					return errors.Wrapf(err, "problem checking for migration version %s", mi.Version)
   197  				}
   198  
   199  				if !legacyVersionExists {
   200  					return errors.Wrapf(err, "problem checking for migration version %s", legacyVersion)
   201  				}
   202  			} else if !exists {
   203  				return errors.Errorf("migration version %s does not exist", mi.Version)
   204  			}
   205  
   206  			err = m.isolatedTransaction(ctx, "down", func(tx *pop.Tx) error {
   207  				err := mi.Run(c, tx)
   208  				if err != nil {
   209  					return err
   210  				}
   211  
   212  				// #nosec G201 - mtn is a system-wide const
   213  				if _, err = tx.Exec(tx.Rebind(fmt.Sprintf("DELETE FROM %s WHERE version = ?", mtn)), mi.Version); err != nil {
   214  					return errors.Wrapf(err, "problem deleting migration version %s", mi.Version)
   215  				}
   216  
   217  				return nil
   218  			})
   219  			if err != nil {
   220  				return err
   221  			}
   222  
   223  			m.l.Debugf("< %s", mi.Name)
   224  		}
   225  		return nil
   226  	})
   227  }
   228  
   229  // Reset the database by running the down migrations followed by the up migrations.
   230  func (m *Migrator) Reset(ctx context.Context) error {
   231  	err := m.Down(ctx, -1)
   232  	if err != nil {
   233  		return err
   234  	}
   235  	return m.Up(ctx)
   236  }
   237  
   238  func (m *Migrator) createTransactionalMigrationTable(ctx context.Context, c *pop.Connection, l *logrusx.Logger) error {
   239  	mtn := m.migrationTableName(ctx, c)
   240  	unprefixedMtn := m.migrationTableName(ctx, c)
   241  
   242  	if err := m.execMigrationTransaction(ctx, c, []string{
   243  		fmt.Sprintf(`CREATE TABLE %s (version VARCHAR (48) NOT NULL, version_self INT NOT NULL DEFAULT 0)`, mtn),
   244  		fmt.Sprintf(`CREATE UNIQUE INDEX %s_version_idx ON %s (version)`, unprefixedMtn, mtn),
   245  		fmt.Sprintf(`CREATE INDEX %s_version_self_idx ON %s (version_self)`, unprefixedMtn, mtn),
   246  	}); err != nil {
   247  		return err
   248  	}
   249  
   250  	l.WithField("migration_table", mtn).Debug("Transactional migration table created successfully.")
   251  
   252  	return nil
   253  }
   254  
   255  func (m *Migrator) migrateToTransactionalMigrationTable(ctx context.Context, c *pop.Connection, l *logrusx.Logger) error {
   256  	// This means the new pop migrator has also not yet been applied, do that now.
   257  	mtn := m.migrationTableName(ctx, c)
   258  	unprefixedMtn := m.migrationTableName(ctx, c)
   259  
   260  	withOn := fmt.Sprintf(" ON %s", mtn)
   261  	if c.Dialect.Name() != "mysql" {
   262  		withOn = ""
   263  	}
   264  
   265  	interimTable := fmt.Sprintf("%s_transactional", mtn)
   266  	workload := [][]string{
   267  		{
   268  			fmt.Sprintf(`DROP INDEX %s_version_idx%s`, unprefixedMtn, withOn),
   269  			fmt.Sprintf(`CREATE TABLE %s (version VARCHAR (48) NOT NULL, version_self INT NOT NULL DEFAULT 0)`, interimTable),
   270  			fmt.Sprintf(`CREATE UNIQUE INDEX %s_version_idx ON %s (version)`, unprefixedMtn, interimTable),
   271  			fmt.Sprintf(`CREATE INDEX %s_version_self_idx ON %s (version_self)`, unprefixedMtn, interimTable),
   272  			// #nosec G201 - mtn is a system-wide const
   273  			fmt.Sprintf(`INSERT INTO %s (version) SELECT version FROM %s`, interimTable, mtn),
   274  			fmt.Sprintf(`ALTER TABLE %s RENAME TO %s_pop_legacy`, mtn, mtn),
   275  		},
   276  		{
   277  			fmt.Sprintf(`ALTER TABLE %s RENAME TO %s`, interimTable, mtn),
   278  		},
   279  	}
   280  
   281  	if err := m.execMigrationTransaction(ctx, c, workload...); err != nil {
   282  		return err
   283  	}
   284  
   285  	l.WithField("migration_table", mtn).Debug("Successfully migrated legacy schema_migration to new transactional schema_migration table.")
   286  
   287  	return nil
   288  }
   289  
   290  func (m *Migrator) isolatedTransaction(ctx context.Context, direction string, fn func(tx *pop.Tx) error) error {
   291  	span, ctx := m.startSpan(ctx, MigrationRunTransactionOpName)
   292  	defer span.Finish()
   293  	span.SetTag("migration_direction", direction)
   294  
   295  	if m.PerMigrationTimeout > 0 {
   296  		var cancel context.CancelFunc
   297  		ctx, cancel = context.WithTimeout(ctx, m.PerMigrationTimeout)
   298  		defer cancel()
   299  	}
   300  
   301  	c := m.Connection.WithContext(ctx)
   302  	tx, dberr := c.Store.TransactionContextOptions(ctx, &sql.TxOptions{
   303  		Isolation: sql.LevelSerializable,
   304  		ReadOnly:  false,
   305  	})
   306  	if dberr != nil {
   307  		return dberr
   308  	}
   309  
   310  	err := fn(tx)
   311  	if err != nil {
   312  		dberr = tx.Rollback()
   313  	} else {
   314  		dberr = tx.Commit()
   315  	}
   316  
   317  	if dberr != nil {
   318  		return errors.Wrap(dberr, "error committing or rolling back transaction")
   319  	}
   320  
   321  	return err
   322  }
   323  
   324  func (m *Migrator) execMigrationTransaction(ctx context.Context, c *pop.Connection, transactions ...[]string) error {
   325  	for _, statements := range transactions {
   326  		if err := m.isolatedTransaction(ctx, "init", func(tx *pop.Tx) error {
   327  			for _, statement := range statements {
   328  				if _, err := tx.ExecContext(ctx, statement); err != nil {
   329  					return errors.Wrapf(err, "unable to execute statement: %s", statement)
   330  				}
   331  			}
   332  			return nil
   333  		}); err != nil {
   334  			return err
   335  		}
   336  	}
   337  
   338  	return nil
   339  }
   340  
   341  // CreateSchemaMigrations sets up a table to track migrations. This is an idempotent
   342  // operation.
   343  func (m *Migrator) CreateSchemaMigrations(ctx context.Context) error {
   344  	span, ctx := m.startSpan(ctx, MigrationInitOpName)
   345  	defer span.Finish()
   346  
   347  	c := m.Connection.WithContext(ctx)
   348  
   349  	mtn := m.migrationTableName(ctx, c)
   350  	m.l.WithField("migration_table", mtn).Debug("Checking if legacy migration table exists.")
   351  	_, err := c.Store.Exec(fmt.Sprintf("select version from %s", mtn))
   352  	if err != nil {
   353  		m.l.WithError(err).WithField("migration_table", mtn).Debug("An error occurred while checking for the legacy migration table, maybe it does not exist yet? Trying to create.")
   354  		// This means that the legacy pop migrator has not yet been applied
   355  		return m.createTransactionalMigrationTable(ctx, c, m.l)
   356  	}
   357  
   358  	m.l.WithField("migration_table", mtn).Debug("A migration table exists, checking if it is a transactional migration table.")
   359  	_, err = c.Store.Exec(fmt.Sprintf("select version, version_self from %s", mtn))
   360  	if err != nil {
   361  		m.l.WithError(err).WithField("migration_table", mtn).Debug("An error occurred while checking for the transactional migration table, maybe it does not exist yet? Trying to create.")
   362  		return m.migrateToTransactionalMigrationTable(ctx, c, m.l)
   363  	}
   364  
   365  	m.l.WithField("migration_table", mtn).Debug("Migration tables exist and are up to date.")
   366  	return nil
   367  }
   368  
// MigrationStatus describes one migration as reported by Status.
type MigrationStatus struct {
	// State is Pending or Applied.
	State   string `json:"state"`
	// Version is the migration's version identifier.
	Version string `json:"version"`
	// Name is the migration's name.
	Name    string `json:"name"`
}
   374  
// MigrationStatuses is the list of migration statuses returned by Status.
type MigrationStatuses []MigrationStatus

// Compile-time check that MigrationStatuses implements cmdx.Table.
var _ cmdx.Table = (MigrationStatuses)(nil)
   378  
   379  func (m MigrationStatuses) Header() []string {
   380  	return []string{"Version", "Name", "Status"}
   381  }
   382  
   383  func (m MigrationStatuses) Table() [][]string {
   384  	t := make([][]string, len(m))
   385  	for i, s := range m {
   386  		t[i] = []string{s.Version, s.Name, s.State}
   387  	}
   388  	return t
   389  }
   390  
   391  func (m MigrationStatuses) Interface() interface{} {
   392  	return m
   393  }
   394  
   395  func (m MigrationStatuses) Len() int {
   396  	return len(m)
   397  }
   398  
   399  func (m MigrationStatuses) IDs() []string {
   400  	ids := make([]string, len(m))
   401  	for i, s := range m {
   402  		ids[i] = s.Version
   403  	}
   404  	return ids
   405  }
   406  
   407  // In the context of a cobra.Command, use cmdx.PrintTable instead.
   408  func (m MigrationStatuses) Write(out io.Writer) error {
   409  	w := tabwriter.NewWriter(out, 0, 0, 3, ' ', tabwriter.TabIndent)
   410  	_, _ = fmt.Fprintln(w, "Version\tName\tStatus\t")
   411  
   412  	for _, mm := range m {
   413  		_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t\n", mm.Version, mm.Name, mm.State)
   414  	}
   415  
   416  	return w.Flush()
   417  }
   418  
   419  func (m MigrationStatuses) HasPending() bool {
   420  	for _, mm := range m {
   421  		if mm.State == Pending {
   422  			return true
   423  		}
   424  	}
   425  	return false
   426  }
   427  
   428  func (m *Migrator) migrationTableName(ctx context.Context, con *pop.Connection) string {
   429  	return con.MigrationTableName()
   430  }
   431  
   432  func errIsTableNotFound(err error) bool {
   433  	return strings.HasPrefix(err.Error(), "no such table:") || // sqlite
   434  		strings.HasPrefix(err.Error(), "Error 1146:") || // MySQL
   435  		strings.Contains(err.Error(), "SQLSTATE 42P01") // PostgreSQL / CockroachDB
   436  }
   437  
   438  // Status prints out the status of applied/pending migrations.
   439  func (m *Migrator) Status(ctx context.Context) (MigrationStatuses, error) {
   440  	span, ctx := m.startSpan(ctx, MigrationStatusOpName)
   441  	defer span.Finish()
   442  
   443  	con := m.Connection.WithContext(ctx)
   444  
   445  	migrations := m.Migrations["up"].SortAndFilter(con.Dialect.Name())
   446  
   447  	if len(migrations) == 0 {
   448  		return nil, errors.Errorf("unable to find any migrations for dialect: %s", con.Dialect.Name())
   449  	}
   450  
   451  	statuses := make(MigrationStatuses, len(migrations))
   452  	for k, mf := range migrations {
   453  		statuses[k] = MigrationStatus{
   454  			State:   Pending,
   455  			Version: mf.Version,
   456  			Name:    mf.Name,
   457  		}
   458  
   459  		exists, err := con.Where("version = ?", mf.Version).Exists(con.MigrationTableName())
   460  		if err != nil {
   461  			if errIsTableNotFound(err) {
   462  				continue
   463  			} else {
   464  				return nil, errors.Wrapf(err, "problem with migration")
   465  			}
   466  		}
   467  
   468  		if exists {
   469  			statuses[k].State = Applied
   470  		} else if len(mf.Version) > 14 {
   471  			mtn := m.migrationTableName(ctx, con)
   472  			legacyVersion := mf.Version[:14]
   473  			exists, err = con.Where("version = ?", legacyVersion).Exists(mtn)
   474  			if err != nil {
   475  				return nil, errors.Wrapf(err, "problem checking for migration version %s", legacyVersion)
   476  			}
   477  
   478  			if exists {
   479  				statuses[k].State = Applied
   480  			}
   481  		}
   482  	}
   483  
   484  	return statuses, nil
   485  }
   486  
   487  // DumpMigrationSchema will generate a file of the current database schema
   488  // based on the value of Migrator.SchemaPath
   489  func (m *Migrator) DumpMigrationSchema(ctx context.Context) error {
   490  	if m.SchemaPath == "" {
   491  		return nil
   492  	}
   493  	c := m.Connection.WithContext(ctx)
   494  	schema := filepath.Join(m.SchemaPath, "schema.sql")
   495  	f, err := os.Create(schema)
   496  	if err != nil {
   497  		return err
   498  	}
   499  	err = c.Dialect.DumpSchema(f)
   500  	if err != nil {
   501  		os.RemoveAll(schema)
   502  		return err
   503  	}
   504  	return nil
   505  }
   506  
   507  func (m *Migrator) wrapSpan(ctx context.Context, opName string, f func(ctx context.Context, span opentracing.Span) error) error {
   508  	span, ctx := m.startSpan(ctx, opName)
   509  	defer span.Finish()
   510  
   511  	return f(ctx, span)
   512  }
   513  
   514  func (m *Migrator) startSpan(ctx context.Context, opName string) (opentracing.Span, context.Context) {
   515  	tracer := opentracing.GlobalTracer()
   516  	if m.tracer.IsLoaded() {
   517  		tracer = m.tracer.Tracer()
   518  
   519  	}
   520  
   521  	span, ctx := opentracing.StartSpanFromContextWithTracer(ctx, tracer, opName)
   522  	span.SetTag("component", "github.com/ory/x/popx")
   523  
   524  	span.LogFields()
   525  	return span, ctx
   526  }
   527  
   528  func (m *Migrator) exec(ctx context.Context, fn func() error) error {
   529  	now := time.Now()
   530  	defer func() {
   531  		err := m.DumpMigrationSchema(ctx)
   532  		if err != nil {
   533  			m.l.WithError(err).Warn("Migrator: unable to dump schema")
   534  		}
   535  	}()
   536  	defer m.printTimer(now)
   537  
   538  	err := m.CreateSchemaMigrations(ctx)
   539  	if err != nil {
   540  		return errors.Wrap(err, "migrator: problem creating schema migrations")
   541  	}
   542  
   543  	if m.Connection.Dialect.Name() == "sqlite3" {
   544  		if err := m.Connection.RawQuery("PRAGMA foreign_keys=OFF").Exec(); err != nil {
   545  			return err
   546  		}
   547  	}
   548  
   549  	if err := fn(); err != nil {
   550  		return err
   551  	}
   552  
   553  	if m.Connection.Dialect.Name() == "sqlite3" {
   554  		if err := m.Connection.RawQuery("PRAGMA foreign_keys=ON").Exec(); err != nil {
   555  			return err
   556  		}
   557  	}
   558  
   559  	return nil
   560  }
   561  
   562  func (m *Migrator) printTimer(timerStart time.Time) {
   563  	diff := time.Since(timerStart).Seconds()
   564  	if diff > 60 {
   565  		m.l.Debugf("%.4f minutes", diff/60)
   566  	} else {
   567  		m.l.Debugf("%.4f seconds", diff)
   568  	}
   569  }
   570  

View as plain text