...

Source file src/github.com/golang-migrate/migrate/v4/database/mongodb/mongodb.go

Documentation: github.com/golang-migrate/migrate/v4/database/mongodb

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/cenkalti/backoff/v4"
     7  	"github.com/golang-migrate/migrate/v4/database"
     8  	"github.com/hashicorp/go-multierror"
     9  	"go.mongodb.org/mongo-driver/bson"
    10  	"go.mongodb.org/mongo-driver/mongo"
    11  	"go.mongodb.org/mongo-driver/mongo/options"
    12  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    13  	"go.uber.org/atomic"
    14  	"io"
    15  	"io/ioutil"
    16  	"net/url"
    17  	os "os"
    18  	"strconv"
    19  	"time"
    20  )
    21  
    22  func init() {
    23  	db := Mongo{}
    24  	database.Register("mongodb", &db)
    25  	database.Register("mongodb+srv", &db)
    26  }
    27  
    28  var DefaultMigrationsCollection = "schema_migrations"
    29  
    30  const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
    31  const lockKeyUniqueValue = 0                             // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
    32  const DefaultLockTimeout = 15                            // the default maximum time to wait for a lock to be released.
    33  const DefaultLockTimeoutInterval = 10                    // the default maximum intervals time for the locking timout.
    34  const DefaultAdvisoryLockingFlag = true                  // the default value for the advisory locking feature flag. Default is true.
    35  const LockIndexName = "lock_unique_key"                  // the name of the index which adds unique constraint to the locking_key field.
    36  const contextWaitTimeout = 5 * time.Second               // how long to wait for the request to mongo to block/wait for.
    37  
    38  var (
    39  	ErrNoDatabaseName = fmt.Errorf("no database name")
    40  	ErrNilConfig      = fmt.Errorf("no config")
    41  )
    42  
    43  type Mongo struct {
    44  	client   *mongo.Client
    45  	db       *mongo.Database
    46  	config   *Config
    47  	isLocked atomic.Bool
    48  }
    49  
    50  type Locking struct {
    51  	CollectionName string
    52  	Timeout        int
    53  	Enabled        bool
    54  	Interval       int
    55  }
    56  type Config struct {
    57  	DatabaseName         string
    58  	MigrationsCollection string
    59  	TransactionMode      bool
    60  	Locking              Locking
    61  }
    62  type versionInfo struct {
    63  	Version int  `bson:"version"`
    64  	Dirty   bool `bson:"dirty"`
    65  }
    66  
    67  type lockObj struct {
    68  	Key       int       `bson:"locking_key"`
    69  	Pid       int       `bson:"pid"`
    70  	Hostname  string    `bson:"hostname"`
    71  	CreatedAt time.Time `bson:"created_at"`
    72  }
    73  type findFilter struct {
    74  	Key int `bson:"locking_key"`
    75  }
    76  
    77  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    78  	if config == nil {
    79  		return nil, ErrNilConfig
    80  	}
    81  	if len(config.DatabaseName) == 0 {
    82  		return nil, ErrNoDatabaseName
    83  	}
    84  	if len(config.MigrationsCollection) == 0 {
    85  		config.MigrationsCollection = DefaultMigrationsCollection
    86  	}
    87  	if len(config.Locking.CollectionName) == 0 {
    88  		config.Locking.CollectionName = DefaultLockingCollection
    89  	}
    90  	if config.Locking.Timeout <= 0 {
    91  		config.Locking.Timeout = DefaultLockTimeout
    92  	}
    93  	if config.Locking.Interval <= 0 {
    94  		config.Locking.Interval = DefaultLockTimeoutInterval
    95  	}
    96  
    97  	mc := &Mongo{
    98  		client: instance,
    99  		db:     instance.Database(config.DatabaseName),
   100  		config: config,
   101  	}
   102  
   103  	if mc.config.Locking.Enabled {
   104  		if err := mc.ensureLockTable(); err != nil {
   105  			return nil, err
   106  		}
   107  	}
   108  	if err := mc.ensureVersionTable(); err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	return mc, nil
   113  }
   114  
   115  func (m *Mongo) Open(dsn string) (database.Driver, error) {
   116  	//connstring is experimental package, but it used for parse connection string in mongo.Connect function
   117  	uri, err := connstring.Parse(dsn)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	if len(uri.Database) == 0 {
   122  		return nil, ErrNoDatabaseName
   123  	}
   124  	unknown := url.Values(uri.UnknownOptions)
   125  
   126  	migrationsCollection := unknown.Get("x-migrations-collection")
   127  	lockCollection := unknown.Get("x-advisory-lock-collection")
   128  	transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	if err = client.Ping(context.TODO(), nil); err != nil {
   150  		return nil, err
   151  	}
   152  	mc, err := WithInstance(client, &Config{
   153  		DatabaseName:         uri.Database,
   154  		MigrationsCollection: migrationsCollection,
   155  		TransactionMode:      transactionMode,
   156  		Locking: Locking{
   157  			CollectionName: lockCollection,
   158  			Timeout:        lockingTimout,
   159  			Enabled:        advisoryLockingFlag,
   160  			Interval:       maxLockingIntervals,
   161  		},
   162  	})
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	return mc, nil
   167  }
   168  
   169  //Parse the url param, convert it to boolean
   170  // returns error if param invalid. returns defaultValue if param not present
   171  func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
   172  
   173  	// if parameter passed, parse it (otherwise return default value)
   174  	if urlParam != "" {
   175  		result, err := strconv.ParseBool(urlParam)
   176  		if err != nil {
   177  			return false, err
   178  		}
   179  		return result, nil
   180  	}
   181  
   182  	// if no url Param passed, return default value
   183  	return defaultValue, nil
   184  }
   185  
   186  //Parse the url param, convert it to int
   187  // returns error if param invalid. returns defaultValue if param not present
   188  func parseInt(urlParam string, defaultValue int) (int, error) {
   189  
   190  	// if parameter passed, parse it (otherwise return default value)
   191  	if urlParam != "" {
   192  		result, err := strconv.Atoi(urlParam)
   193  		if err != nil {
   194  			return -1, err
   195  		}
   196  		return result, nil
   197  	}
   198  
   199  	// if no url Param passed, return default value
   200  	return defaultValue, nil
   201  }
   202  func (m *Mongo) SetVersion(version int, dirty bool) error {
   203  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   204  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   205  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   206  	}
   207  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   208  	if err != nil {
   209  		return &database.Error{OrigErr: err, Err: "save version failed"}
   210  	}
   211  	return nil
   212  }
   213  
   214  func (m *Mongo) Version() (version int, dirty bool, err error) {
   215  	var versionInfo versionInfo
   216  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   217  	switch {
   218  	case err == mongo.ErrNoDocuments:
   219  		return database.NilVersion, false, nil
   220  	case err != nil:
   221  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   222  	default:
   223  		return versionInfo.Version, versionInfo.Dirty, nil
   224  	}
   225  }
   226  
   227  func (m *Mongo) Run(migration io.Reader) error {
   228  	migr, err := ioutil.ReadAll(migration)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	var cmds []bson.D
   233  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   234  	if err != nil {
   235  		return fmt.Errorf("unmarshaling json error: %s", err)
   236  	}
   237  	if m.config.TransactionMode {
   238  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   239  			return err
   240  		}
   241  	} else {
   242  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   243  			return err
   244  		}
   245  	}
   246  	return nil
   247  }
   248  
   249  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   250  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   251  		if err := sessionContext.StartTransaction(); err != nil {
   252  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   253  		}
   254  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   255  			//When command execution is failed, it's aborting transaction
   256  			//If you tried to call abortTransaction, it`s return error that transaction already aborted
   257  			return err
   258  		}
   259  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   260  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   261  		}
   262  		return nil
   263  	})
   264  	if err != nil {
   265  		return err
   266  	}
   267  	return nil
   268  }
   269  
   270  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   271  	for _, cmd := range cmds {
   272  		err := m.db.RunCommand(ctx, cmd).Err()
   273  		if err != nil {
   274  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   275  		}
   276  	}
   277  	return nil
   278  }
   279  
   280  func (m *Mongo) Close() error {
   281  	return m.client.Disconnect(context.TODO())
   282  }
   283  
   284  func (m *Mongo) Drop() error {
   285  	return m.db.Drop(context.TODO())
   286  }
   287  
   288  func (m *Mongo) ensureLockTable() error {
   289  	indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
   290  
   291  	indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
   292  	_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
   293  		Options: indexOptions,
   294  		Keys:    findFilter{Key: -1},
   295  	})
   296  	if err != nil {
   297  		return err
   298  	}
   299  	return nil
   300  }
   301  
   302  // ensureVersionTable checks if versions table exists and, if not, creates it.
   303  // Note that this function locks the database, which deviates from the usual
   304  // convention of "caller locks" in the MongoDb type.
   305  func (m *Mongo) ensureVersionTable() (err error) {
   306  	if err = m.Lock(); err != nil {
   307  		return err
   308  	}
   309  
   310  	defer func() {
   311  		if e := m.Unlock(); e != nil {
   312  			if err == nil {
   313  				err = e
   314  			} else {
   315  				err = multierror.Append(err, e)
   316  			}
   317  		}
   318  	}()
   319  
   320  	if err != nil {
   321  		return err
   322  	}
   323  	if _, _, err = m.Version(); err != nil {
   324  		return err
   325  	}
   326  	return nil
   327  }
   328  
   329  // Utilizes advisory locking on the config.LockingCollection collection
   330  // This uses a unique index on the `locking_key` field.
   331  func (m *Mongo) Lock() error {
   332  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   333  		if !m.config.Locking.Enabled {
   334  			return nil
   335  		}
   336  
   337  		pid := os.Getpid()
   338  		hostname, err := os.Hostname()
   339  		if err != nil {
   340  			hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
   341  		}
   342  
   343  		newLockObj := lockObj{
   344  			Key:       lockKeyUniqueValue,
   345  			Pid:       pid,
   346  			Hostname:  hostname,
   347  			CreatedAt: time.Now(),
   348  		}
   349  		operation := func() error {
   350  			timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
   351  			_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
   352  			defer cancelFunc()
   353  			return err
   354  		}
   355  		exponentialBackOff := backoff.NewExponentialBackOff()
   356  		duration := time.Duration(m.config.Locking.Timeout) * time.Second
   357  		exponentialBackOff.MaxElapsedTime = duration
   358  		exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
   359  
   360  		err = backoff.Retry(operation, exponentialBackOff)
   361  		if err != nil {
   362  			return database.ErrLocked
   363  		}
   364  
   365  		return nil
   366  	})
   367  }
   368  
   369  func (m *Mongo) Unlock() error {
   370  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   371  		if !m.config.Locking.Enabled {
   372  			return nil
   373  		}
   374  
   375  		filter := findFilter{
   376  			Key: lockKeyUniqueValue,
   377  		}
   378  
   379  		ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
   380  		_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
   381  		defer cancel()
   382  
   383  		if err != nil {
   384  			return err
   385  		}
   386  		return nil
   387  	})
   388  }
   389  

View as plain text