...

Source file src/github.com/golang-migrate/migrate/v4/database/sqlserver/sqlserver.go

Documentation: github.com/golang-migrate/migrate/v4/database/sqlserver

     1  package sqlserver
     2  
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	nurl "net/url"
	"strconv"
	"strings"

	"go.uber.org/atomic"

	"github.com/Azure/go-autorest/autorest/adal"
	mssql "github.com/denisenkom/go-mssqldb" // mssql support
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	"github.com/hashicorp/go-multierror"
)
    21  
    22  func init() {
    23  	database.Register("sqlserver", &SQLServer{})
    24  }
    25  
    26  // DefaultMigrationsTable is the name of the migrations table in the database
    27  var DefaultMigrationsTable = "schema_migrations"
    28  
    29  var (
    30  	ErrNilConfig                 = fmt.Errorf("no config")
    31  	ErrNoDatabaseName            = fmt.Errorf("no database name")
    32  	ErrNoSchema                  = fmt.Errorf("no schema")
    33  	ErrDatabaseDirty             = fmt.Errorf("database is dirty")
    34  	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
    35  )
    36  
    37  var lockErrorMap = map[mssql.ReturnStatus]string{
    38  	-1:   "The lock request timed out.",
    39  	-2:   "The lock request was canceled.",
    40  	-3:   "The lock request was chosen as a deadlock victim.",
    41  	-999: "Parameter validation or other call error.",
    42  }
    43  
// Config for database
//
// Any empty field is filled in place by WithInstance before the driver is built.
type Config struct {
	// MigrationsTable is the table holding the current version and dirty flag.
	// WithInstance substitutes DefaultMigrationsTable when it is empty.
	MigrationsTable string
	// DatabaseName is combined with SchemaName to derive the advisory lock ID
	// used by Lock/Unlock. WithInstance queries SELECT DB_NAME() when it is empty.
	DatabaseName    string
	// SchemaName is combined with DatabaseName to derive the advisory lock ID.
	// WithInstance queries SELECT SCHEMA_NAME() when it is empty.
	SchemaName      string
}
    50  
// SQL Server connection
//
// SQLServer is the database.Driver implementation registered under "sqlserver".
type SQLServer struct {
	// Locking and unlocking need to use the same connection: the application
	// lock is taken with @LockOwner = 'Session', so it belongs to the session
	// that acquired it. Every statement the driver issues runs on this conn.
	conn     *sql.Conn
	// db is the pool conn was taken from; Close closes both.
	db       *sql.DB
	// isLocked tracks whether this instance holds the lock; Lock and Unlock
	// flip it through database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    61  
    62  // WithInstance returns a database instance from an already created database connection.
    63  //
    64  // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
    65  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    66  	if config == nil {
    67  		return nil, ErrNilConfig
    68  	}
    69  
    70  	if err := instance.Ping(); err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	if config.DatabaseName == "" {
    75  		query := `SELECT DB_NAME()`
    76  		var databaseName string
    77  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    78  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    79  		}
    80  
    81  		if len(databaseName) == 0 {
    82  			return nil, ErrNoDatabaseName
    83  		}
    84  
    85  		config.DatabaseName = databaseName
    86  	}
    87  
    88  	if config.SchemaName == "" {
    89  		query := `SELECT SCHEMA_NAME()`
    90  		var schemaName string
    91  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    92  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    93  		}
    94  
    95  		if len(schemaName) == 0 {
    96  			return nil, ErrNoSchema
    97  		}
    98  
    99  		config.SchemaName = schemaName
   100  	}
   101  
   102  	if len(config.MigrationsTable) == 0 {
   103  		config.MigrationsTable = DefaultMigrationsTable
   104  	}
   105  
   106  	conn, err := instance.Conn(context.Background())
   107  
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	ss := &SQLServer{
   113  		conn:   conn,
   114  		db:     instance,
   115  		config: config,
   116  	}
   117  
   118  	if err := ss.ensureVersionTable(); err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	return ss, nil
   123  }
   124  
   125  // Open a connection to the database.
   126  func (ss *SQLServer) Open(url string) (database.Driver, error) {
   127  	purl, err := nurl.Parse(url)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	useMsiParam := purl.Query().Get("useMsi")
   133  	useMsi := false
   134  	if len(useMsiParam) > 0 {
   135  		useMsi, err = strconv.ParseBool(useMsiParam)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  
   141  	if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
   142  		return nil, ErrMultipleAuthOptionsPassed
   143  	}
   144  
   145  	filteredURL := migrate.FilterCustomQuery(purl).String()
   146  
   147  	var db *sql.DB
   148  	if useMsi {
   149  		resource := getAADResourceFromServerUri(purl)
   150  		tokenProvider, err := getMSITokenProvider(resource)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  
   155  		connector, err := mssql.NewAccessTokenConnector(
   156  			filteredURL, tokenProvider)
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  
   161  		db = sql.OpenDB(connector)
   162  
   163  	} else {
   164  		db, err = sql.Open("sqlserver", filteredURL)
   165  		if err != nil {
   166  			return nil, err
   167  		}
   168  	}
   169  
   170  	migrationsTable := purl.Query().Get("x-migrations-table")
   171  
   172  	px, err := WithInstance(db, &Config{
   173  		DatabaseName:    purl.Path,
   174  		MigrationsTable: migrationsTable,
   175  	})
   176  
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	return px, nil
   182  }
   183  
   184  // Close the database connection
   185  func (ss *SQLServer) Close() error {
   186  	connErr := ss.conn.Close()
   187  	dbErr := ss.db.Close()
   188  	if connErr != nil || dbErr != nil {
   189  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   190  	}
   191  	return nil
   192  }
   193  
   194  // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
   195  func (ss *SQLServer) Lock() error {
   196  	return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
   197  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   198  		if err != nil {
   199  			return err
   200  		}
   201  
   202  		// This will either obtain the lock immediately and return true,
   203  		// or return false if the lock cannot be acquired immediately.
   204  		// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
   205  		query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
   206  
   207  		var status mssql.ReturnStatus
   208  		if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
   209  			return nil
   210  		} else if err != nil {
   211  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   212  		} else {
   213  			return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
   214  		}
   215  	})
   216  }
   217  
   218  // Unlock froms the migration lock from the database
   219  func (ss *SQLServer) Unlock() error {
   220  	return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
   221  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   222  		if err != nil {
   223  			return err
   224  		}
   225  
   226  		// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
   227  		query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
   228  		if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
   229  			return &database.Error{OrigErr: err, Query: []byte(query)}
   230  		}
   231  
   232  		return nil
   233  	})
   234  }
   235  
   236  // Run the migrations for the database
   237  func (ss *SQLServer) Run(migration io.Reader) error {
   238  	migr, err := ioutil.ReadAll(migration)
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	// run migration
   244  	query := string(migr[:])
   245  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   246  		if msErr, ok := err.(mssql.Error); ok {
   247  			message := fmt.Sprintf("migration failed: %s", msErr.Message)
   248  			if msErr.ProcName != "" {
   249  				message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
   250  			}
   251  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
   252  		}
   253  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   254  	}
   255  
   256  	return nil
   257  }
   258  
   259  // SetVersion for the current database
   260  func (ss *SQLServer) SetVersion(version int, dirty bool) error {
   261  
   262  	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
   263  	if err != nil {
   264  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   265  	}
   266  
   267  	query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
   268  	if _, err := tx.Exec(query); err != nil {
   269  		if errRollback := tx.Rollback(); errRollback != nil {
   270  			err = multierror.Append(err, errRollback)
   271  		}
   272  		return &database.Error{OrigErr: err, Query: []byte(query)}
   273  	}
   274  
   275  	// Also re-write the schema version for nil dirty versions to prevent
   276  	// empty schema version for failed down migration on the first migration
   277  	// See: https://github.com/golang-migrate/migrate/issues/330
   278  	if version >= 0 || (version == database.NilVersion && dirty) {
   279  		var dirtyBit int
   280  		if dirty {
   281  			dirtyBit = 1
   282  		}
   283  		query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
   284  		if _, err := tx.Exec(query, version, dirtyBit); err != nil {
   285  			if errRollback := tx.Rollback(); errRollback != nil {
   286  				err = multierror.Append(err, errRollback)
   287  			}
   288  			return &database.Error{OrigErr: err, Query: []byte(query)}
   289  		}
   290  	}
   291  
   292  	if err := tx.Commit(); err != nil {
   293  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   294  	}
   295  
   296  	return nil
   297  }
   298  
   299  // Version of the current database state
   300  func (ss *SQLServer) Version() (version int, dirty bool, err error) {
   301  	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
   302  	err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   303  	switch {
   304  	case err == sql.ErrNoRows:
   305  		return database.NilVersion, false, nil
   306  
   307  	case err != nil:
   308  		// FIXME: convert to MSSQL error
   309  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   310  
   311  	default:
   312  		return version, dirty, nil
   313  	}
   314  }
   315  
   316  // Drop all tables from the database.
   317  func (ss *SQLServer) Drop() error {
   318  
   319  	// drop all referential integrity constraints
   320  	query := `
   321  	DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
   322  
   323  	SET @Cursor = CURSOR FAST_FORWARD FOR
   324  	SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
   325  	FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
   326  	LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
   327  
   328  	OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
   329  
   330  	WHILE (@@FETCH_STATUS = 0)
   331  	BEGIN
   332  	Exec sp_executesql @Sql
   333  	FETCH NEXT FROM @Cursor INTO @Sql
   334  	END
   335  
   336  	CLOSE @Cursor DEALLOCATE @Cursor`
   337  
   338  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   339  		return &database.Error{OrigErr: err, Query: []byte(query)}
   340  	}
   341  
   342  	// drop the tables
   343  	query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
   344  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   345  		return &database.Error{OrigErr: err, Query: []byte(query)}
   346  	}
   347  
   348  	return nil
   349  }
   350  
   351  func (ss *SQLServer) ensureVersionTable() (err error) {
   352  	if err = ss.Lock(); err != nil {
   353  		return err
   354  	}
   355  
   356  	defer func() {
   357  		if e := ss.Unlock(); e != nil {
   358  			if err == nil {
   359  				err = e
   360  			} else {
   361  				err = multierror.Append(err, e)
   362  			}
   363  		}
   364  	}()
   365  
   366  	query := `IF NOT EXISTS
   367  	(SELECT *
   368  		 FROM sysobjects
   369  		WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
   370  			AND OBJECTPROPERTY(id, N'IsUserTable') = 1
   371  	)
   372  	CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
   373  
   374  	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
   375  		return &database.Error{OrigErr: err, Query: []byte(query)}
   376  	}
   377  
   378  	return nil
   379  }
   380  
   381  func getMSITokenProvider(resource string) (func() (string, error), error) {
   382  	msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  
   387  	return func() (string, error) {
   388  		err := msi.EnsureFresh()
   389  		if err != nil {
   390  			return "", err
   391  		}
   392  		token := msi.OAuthToken()
   393  		return token, nil
   394  	}, nil
   395  }
   396  
   397  // The sql server resource can change across clouds so get it
   398  // dynamically based on the server uri.
   399  // ex. <server name>.database.windows.net -> https://database.windows.net
   400  func getAADResourceFromServerUri(purl *nurl.URL) string {
   401  	return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], "."))
   402  }
   403  

View as plain text