...

Source file src/github.com/golang-migrate/migrate/v4/database/ql/ql.go

Documentation: github.com/golang-migrate/migrate/v4/database/ql

     1  package ql
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"github.com/hashicorp/go-multierror"
     7  	"go.uber.org/atomic"
     8  	"io"
     9  	"io/ioutil"
    10  	"strings"
    11  
    12  	nurl "net/url"
    13  
    14  	"github.com/golang-migrate/migrate/v4"
    15  	"github.com/golang-migrate/migrate/v4/database"
    16  	_ "modernc.org/ql/driver"
    17  )
    18  
    19  func init() {
    20  	database.Register("ql", &Ql{})
    21  }
    22  
    23  var DefaultMigrationsTable = "schema_migrations"
    24  var (
    25  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    26  	ErrNilConfig      = fmt.Errorf("no config")
    27  	ErrNoDatabaseName = fmt.Errorf("no database name")
    28  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    29  )
    30  
    31  type Config struct {
    32  	MigrationsTable string
    33  	DatabaseName    string
    34  }
    35  
    36  type Ql struct {
    37  	db       *sql.DB
    38  	isLocked atomic.Bool
    39  
    40  	config *Config
    41  }
    42  
    43  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    44  	if config == nil {
    45  		return nil, ErrNilConfig
    46  	}
    47  
    48  	if err := instance.Ping(); err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	if len(config.MigrationsTable) == 0 {
    53  		config.MigrationsTable = DefaultMigrationsTable
    54  	}
    55  
    56  	mx := &Ql{
    57  		db:     instance,
    58  		config: config,
    59  	}
    60  	if err := mx.ensureVersionTable(); err != nil {
    61  		return nil, err
    62  	}
    63  	return mx, nil
    64  }
    65  
    66  // ensureVersionTable checks if versions table exists and, if not, creates it.
    67  // Note that this function locks the database, which deviates from the usual
    68  // convention of "caller locks" in the Ql type.
    69  func (m *Ql) ensureVersionTable() (err error) {
    70  	if err = m.Lock(); err != nil {
    71  		return err
    72  	}
    73  
    74  	defer func() {
    75  		if e := m.Unlock(); e != nil {
    76  			if err == nil {
    77  				err = e
    78  			} else {
    79  				err = multierror.Append(err, e)
    80  			}
    81  		}
    82  	}()
    83  
    84  	tx, err := m.db.Begin()
    85  	if err != nil {
    86  		return err
    87  	}
    88  	if _, err := tx.Exec(fmt.Sprintf(`
    89  	CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool);
    90  	CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    91  `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil {
    92  		if err := tx.Rollback(); err != nil {
    93  			return err
    94  		}
    95  		return err
    96  	}
    97  	if err := tx.Commit(); err != nil {
    98  		return err
    99  	}
   100  	return nil
   101  }
   102  
   103  func (m *Ql) Open(url string) (database.Driver, error) {
   104  	purl, err := nurl.Parse(url)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1)
   109  	db, err := sql.Open("ql", dbfile)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	migrationsTable := purl.Query().Get("x-migrations-table")
   114  	if len(migrationsTable) == 0 {
   115  		migrationsTable = DefaultMigrationsTable
   116  	}
   117  	mx, err := WithInstance(db, &Config{
   118  		DatabaseName:    purl.Path,
   119  		MigrationsTable: migrationsTable,
   120  	})
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	return mx, nil
   125  }
   126  func (m *Ql) Close() error {
   127  	return m.db.Close()
   128  }
   129  func (m *Ql) Drop() (err error) {
   130  	query := `SELECT Name FROM __Table`
   131  	tables, err := m.db.Query(query)
   132  	if err != nil {
   133  		return &database.Error{OrigErr: err, Query: []byte(query)}
   134  	}
   135  	defer func() {
   136  		if errClose := tables.Close(); errClose != nil {
   137  			err = multierror.Append(err, errClose)
   138  		}
   139  	}()
   140  
   141  	tableNames := make([]string, 0)
   142  	for tables.Next() {
   143  		var tableName string
   144  		if err := tables.Scan(&tableName); err != nil {
   145  			return err
   146  		}
   147  		if len(tableName) > 0 {
   148  			if !strings.HasPrefix(tableName, "__") {
   149  				tableNames = append(tableNames, tableName)
   150  			}
   151  		}
   152  	}
   153  	if err := tables.Err(); err != nil {
   154  		return &database.Error{OrigErr: err, Query: []byte(query)}
   155  	}
   156  
   157  	if len(tableNames) > 0 {
   158  		for _, t := range tableNames {
   159  			query := "DROP TABLE " + t
   160  			err = m.executeQuery(query)
   161  			if err != nil {
   162  				return &database.Error{OrigErr: err, Query: []byte(query)}
   163  			}
   164  		}
   165  	}
   166  
   167  	return nil
   168  }
   169  func (m *Ql) Lock() error {
   170  	if !m.isLocked.CAS(false, true) {
   171  		return database.ErrLocked
   172  	}
   173  	return nil
   174  }
   175  func (m *Ql) Unlock() error {
   176  	if !m.isLocked.CAS(true, false) {
   177  		return database.ErrNotLocked
   178  	}
   179  	return nil
   180  }
   181  func (m *Ql) Run(migration io.Reader) error {
   182  	migr, err := ioutil.ReadAll(migration)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	query := string(migr[:])
   187  
   188  	return m.executeQuery(query)
   189  }
   190  func (m *Ql) executeQuery(query string) error {
   191  	tx, err := m.db.Begin()
   192  	if err != nil {
   193  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   194  	}
   195  	if _, err := tx.Exec(query); err != nil {
   196  		if errRollback := tx.Rollback(); errRollback != nil {
   197  			err = multierror.Append(err, errRollback)
   198  		}
   199  		return &database.Error{OrigErr: err, Query: []byte(query)}
   200  	}
   201  	if err := tx.Commit(); err != nil {
   202  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   203  	}
   204  	return nil
   205  }
   206  func (m *Ql) SetVersion(version int, dirty bool) error {
   207  	tx, err := m.db.Begin()
   208  	if err != nil {
   209  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   210  	}
   211  
   212  	query := "TRUNCATE TABLE " + m.config.MigrationsTable
   213  	if _, err := tx.Exec(query); err != nil {
   214  		return &database.Error{OrigErr: err, Query: []byte(query)}
   215  	}
   216  
   217  	// Also re-write the schema version for nil dirty versions to prevent
   218  	// empty schema version for failed down migration on the first migration
   219  	// See: https://github.com/golang-migrate/migrate/issues/330
   220  	if version >= 0 || (version == database.NilVersion && dirty) {
   221  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`,
   222  			m.config.MigrationsTable)
   223  		if _, err := tx.Exec(query, version, dirty); err != nil {
   224  			if errRollback := tx.Rollback(); errRollback != nil {
   225  				err = multierror.Append(err, errRollback)
   226  			}
   227  			return &database.Error{OrigErr: err, Query: []byte(query)}
   228  		}
   229  	}
   230  
   231  	if err := tx.Commit(); err != nil {
   232  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   233  	}
   234  
   235  	return nil
   236  }
   237  
   238  func (m *Ql) Version() (version int, dirty bool, err error) {
   239  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   240  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   241  	if err != nil {
   242  		return database.NilVersion, false, nil
   243  	}
   244  	return version, dirty, nil
   245  }
   246  

View as plain text