...

Source file src/github.com/golang-migrate/migrate/v4/database/mysql/mysql.go

Documentation: github.com/golang-migrate/migrate/v4/database/mysql

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package mysql
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"database/sql"
    11  	"fmt"
    12  	"go.uber.org/atomic"
    13  	"io"
    14  	"io/ioutil"
    15  	nurl "net/url"
    16  	"strconv"
    17  	"strings"
    18  
    19  	"github.com/go-sql-driver/mysql"
    20  	"github.com/golang-migrate/migrate/v4/database"
    21  	"github.com/hashicorp/go-multierror"
    22  )
    23  
    24  var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
    25  
    26  func init() {
    27  	database.Register("mysql", &Mysql{})
    28  }
    29  
    30  var DefaultMigrationsTable = "schema_migrations"
    31  
    32  var (
    33  	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
    34  	ErrNilConfig        = fmt.Errorf("no config")
    35  	ErrNoDatabaseName   = fmt.Errorf("no database name")
    36  	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
    37  	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    38  )
    39  
    40  type Config struct {
    41  	MigrationsTable string
    42  	DatabaseName    string
    43  	NoLock          bool
    44  }
    45  
    46  type Mysql struct {
    47  	// mysql RELEASE_LOCK must be called from the same conn, so
    48  	// just do everything over a single conn anyway.
    49  	conn     *sql.Conn
    50  	db       *sql.DB
    51  	isLocked atomic.Bool
    52  
    53  	config *Config
    54  }
    55  
    56  // connection instance must have `multiStatements` set to true
    57  func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
    58  	if config == nil {
    59  		return nil, ErrNilConfig
    60  	}
    61  
    62  	if err := conn.PingContext(ctx); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	mx := &Mysql{
    67  		conn:   conn,
    68  		db:     nil,
    69  		config: config,
    70  	}
    71  
    72  	if config.DatabaseName == "" {
    73  		query := `SELECT DATABASE()`
    74  		var databaseName sql.NullString
    75  		if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
    76  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    77  		}
    78  
    79  		if len(databaseName.String) == 0 {
    80  			return nil, ErrNoDatabaseName
    81  		}
    82  
    83  		config.DatabaseName = databaseName.String
    84  	}
    85  
    86  	if len(config.MigrationsTable) == 0 {
    87  		config.MigrationsTable = DefaultMigrationsTable
    88  	}
    89  
    90  	if err := mx.ensureVersionTable(); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	return mx, nil
    95  }
    96  
    97  // instance must have `multiStatements` set to true
    98  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    99  	ctx := context.Background()
   100  
   101  	if err := instance.Ping(); err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	conn, err := instance.Conn(ctx)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	mx, err := WithConnection(ctx, conn, config)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	mx.db = instance
   116  
   117  	return mx, nil
   118  }
   119  
   120  // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
   121  // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
   122  func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
   123  	if c == nil {
   124  		return nil, ErrNilConfig
   125  	}
   126  	customQueryParams := map[string]string{}
   127  
   128  	for k, v := range c.Params {
   129  		if strings.HasPrefix(k, "x-") {
   130  			customQueryParams[k] = v
   131  			delete(c.Params, k)
   132  		}
   133  	}
   134  	return customQueryParams, nil
   135  }
   136  
   137  func urlToMySQLConfig(url string) (*mysql.Config, error) {
   138  	// Need to parse out custom TLS parameters and call
   139  	// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
   140  	// which consumes the registered tls.Config
   141  	// Fixes: https://github.com/golang-migrate/migrate/issues/411
   142  	//
   143  	// Can't use url.Parse() since it fails to parse MySQL DSNs
   144  	// mysql.ParseDSN() also searches for "?" to find query parameters:
   145  	// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
   146  	if idx := strings.LastIndex(url, "?"); idx > 0 {
   147  		rawParams := url[idx+1:]
   148  		parsedParams, err := nurl.ParseQuery(rawParams)
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  
   153  		ctls := parsedParams.Get("tls")
   154  		if len(ctls) > 0 {
   155  			if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   156  				rootCertPool := x509.NewCertPool()
   157  				pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
   158  				if err != nil {
   159  					return nil, err
   160  				}
   161  
   162  				if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   163  					return nil, ErrAppendPEM
   164  				}
   165  
   166  				clientCert := make([]tls.Certificate, 0, 1)
   167  				if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
   168  					if ccert == "" || ckey == "" {
   169  						return nil, ErrTLSCertKeyConfig
   170  					}
   171  					certs, err := tls.LoadX509KeyPair(ccert, ckey)
   172  					if err != nil {
   173  						return nil, err
   174  					}
   175  					clientCert = append(clientCert, certs)
   176  				}
   177  
   178  				insecureSkipVerify := false
   179  				insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
   180  				if len(insecureSkipVerifyStr) > 0 {
   181  					x, err := strconv.ParseBool(insecureSkipVerifyStr)
   182  					if err != nil {
   183  						return nil, err
   184  					}
   185  					insecureSkipVerify = x
   186  				}
   187  
   188  				err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   189  					RootCAs:            rootCertPool,
   190  					Certificates:       clientCert,
   191  					InsecureSkipVerify: insecureSkipVerify,
   192  				})
   193  				if err != nil {
   194  					return nil, err
   195  				}
   196  			}
   197  		}
   198  	}
   199  
   200  	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	config.MultiStatements = true
   206  
   207  	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
   208  	// net/url.Parse() would automatically unescape it for us.
   209  	// See: https://play.golang.org/p/q9j1io-YICQ
   210  	user, err := nurl.QueryUnescape(config.User)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	config.User = user
   215  
   216  	password, err := nurl.QueryUnescape(config.Passwd)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  	config.Passwd = password
   221  
   222  	return config, nil
   223  }
   224  
   225  func (m *Mysql) Open(url string) (database.Driver, error) {
   226  	config, err := urlToMySQLConfig(url)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	customParams, err := extractCustomQueryParams(config)
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	noLockParam, noLock := customParams["x-no-lock"], false
   237  	if noLockParam != "" {
   238  		noLock, err = strconv.ParseBool(noLockParam)
   239  		if err != nil {
   240  			return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
   241  		}
   242  	}
   243  
   244  	db, err := sql.Open("mysql", config.FormatDSN())
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	mx, err := WithInstance(db, &Config{
   250  		DatabaseName:    config.DBName,
   251  		MigrationsTable: customParams["x-migrations-table"],
   252  		NoLock:          noLock,
   253  	})
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	return mx, nil
   259  }
   260  
   261  func (m *Mysql) Close() error {
   262  	connErr := m.conn.Close()
   263  	var dbErr error
   264  	if m.db != nil {
   265  		dbErr = m.db.Close()
   266  	}
   267  
   268  	if connErr != nil || dbErr != nil {
   269  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   270  	}
   271  	return nil
   272  }
   273  
   274  func (m *Mysql) Lock() error {
   275  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   276  		if m.config.NoLock {
   277  			return nil
   278  		}
   279  		aid, err := database.GenerateAdvisoryLockId(
   280  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   281  		if err != nil {
   282  			return err
   283  		}
   284  
   285  		query := "SELECT GET_LOCK(?, 10)"
   286  		var success bool
   287  		if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   288  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   289  		}
   290  
   291  		if !success {
   292  			return database.ErrLocked
   293  		}
   294  
   295  		return nil
   296  	})
   297  }
   298  
   299  func (m *Mysql) Unlock() error {
   300  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   301  		if m.config.NoLock {
   302  			return nil
   303  		}
   304  
   305  		aid, err := database.GenerateAdvisoryLockId(
   306  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   307  		if err != nil {
   308  			return err
   309  		}
   310  
   311  		query := `SELECT RELEASE_LOCK(?)`
   312  		if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   313  			return &database.Error{OrigErr: err, Query: []byte(query)}
   314  		}
   315  
   316  		// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   317  		// in which case isLocked should be true until the timeout expires -- synchronizing
   318  		// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   319  
   320  		return nil
   321  	})
   322  }
   323  
   324  func (m *Mysql) Run(migration io.Reader) error {
   325  	migr, err := ioutil.ReadAll(migration)
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	query := string(migr[:])
   331  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   332  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   333  	}
   334  
   335  	return nil
   336  }
   337  
   338  func (m *Mysql) SetVersion(version int, dirty bool) error {
   339  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   340  	if err != nil {
   341  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   342  	}
   343  
   344  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   345  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   346  		if errRollback := tx.Rollback(); errRollback != nil {
   347  			err = multierror.Append(err, errRollback)
   348  		}
   349  		return &database.Error{OrigErr: err, Query: []byte(query)}
   350  	}
   351  
   352  	// Also re-write the schema version for nil dirty versions to prevent
   353  	// empty schema version for failed down migration on the first migration
   354  	// See: https://github.com/golang-migrate/migrate/issues/330
   355  	if version >= 0 || (version == database.NilVersion && dirty) {
   356  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   357  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   358  			if errRollback := tx.Rollback(); errRollback != nil {
   359  				err = multierror.Append(err, errRollback)
   360  			}
   361  			return &database.Error{OrigErr: err, Query: []byte(query)}
   362  		}
   363  	}
   364  
   365  	if err := tx.Commit(); err != nil {
   366  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   367  	}
   368  
   369  	return nil
   370  }
   371  
   372  func (m *Mysql) Version() (version int, dirty bool, err error) {
   373  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   374  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   375  	switch {
   376  	case err == sql.ErrNoRows:
   377  		return database.NilVersion, false, nil
   378  
   379  	case err != nil:
   380  		if e, ok := err.(*mysql.MySQLError); ok {
   381  			if e.Number == 0 {
   382  				return database.NilVersion, false, nil
   383  			}
   384  		}
   385  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   386  
   387  	default:
   388  		return version, dirty, nil
   389  	}
   390  }
   391  
   392  func (m *Mysql) Drop() (err error) {
   393  	// select all tables
   394  	query := `SHOW TABLES LIKE '%'`
   395  	tables, err := m.conn.QueryContext(context.Background(), query)
   396  	if err != nil {
   397  		return &database.Error{OrigErr: err, Query: []byte(query)}
   398  	}
   399  	defer func() {
   400  		if errClose := tables.Close(); errClose != nil {
   401  			err = multierror.Append(err, errClose)
   402  		}
   403  	}()
   404  
   405  	// delete one table after another
   406  	tableNames := make([]string, 0)
   407  	for tables.Next() {
   408  		var tableName string
   409  		if err := tables.Scan(&tableName); err != nil {
   410  			return err
   411  		}
   412  		if len(tableName) > 0 {
   413  			tableNames = append(tableNames, tableName)
   414  		}
   415  	}
   416  	if err := tables.Err(); err != nil {
   417  		return &database.Error{OrigErr: err, Query: []byte(query)}
   418  	}
   419  
   420  	if len(tableNames) > 0 {
   421  		// disable checking foreign key constraints until finished
   422  		query = `SET foreign_key_checks = 0`
   423  		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   424  			return &database.Error{OrigErr: err, Query: []byte(query)}
   425  		}
   426  
   427  		defer func() {
   428  			// enable foreign key checks
   429  			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
   430  		}()
   431  
   432  		// delete one by one ...
   433  		for _, t := range tableNames {
   434  			query = "DROP TABLE IF EXISTS `" + t + "`"
   435  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   436  				return &database.Error{OrigErr: err, Query: []byte(query)}
   437  			}
   438  		}
   439  	}
   440  
   441  	return nil
   442  }
   443  
   444  // ensureVersionTable checks if versions table exists and, if not, creates it.
   445  // Note that this function locks the database, which deviates from the usual
   446  // convention of "caller locks" in the Mysql type.
   447  func (m *Mysql) ensureVersionTable() (err error) {
   448  	if err = m.Lock(); err != nil {
   449  		return err
   450  	}
   451  
   452  	defer func() {
   453  		if e := m.Unlock(); e != nil {
   454  			if err == nil {
   455  				err = e
   456  			} else {
   457  				err = multierror.Append(err, e)
   458  			}
   459  		}
   460  	}()
   461  
   462  	// check if migration table exists
   463  	var result string
   464  	query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
   465  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   466  		if err != sql.ErrNoRows {
   467  			return &database.Error{OrigErr: err, Query: []byte(query)}
   468  		}
   469  	} else {
   470  		return nil
   471  	}
   472  
   473  	// if not, create the empty migration table
   474  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   475  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   476  		return &database.Error{OrigErr: err, Query: []byte(query)}
   477  	}
   478  	return nil
   479  }
   480  
   481  // Returns the bool value of the input.
   482  // The 2nd return value indicates if the input was a valid bool value
   483  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   484  func readBool(input string) (value bool, valid bool) {
   485  	switch input {
   486  	case "1", "true", "TRUE", "True":
   487  		return true, true
   488  	case "0", "false", "FALSE", "False":
   489  		return false, true
   490  	}
   491  
   492  	// Not a valid bool value
   493  	return
   494  }
   495  

View as plain text