1 package mongodb
2
3 import (
4 "context"
5 "fmt"
6 "github.com/cenkalti/backoff/v4"
7 "github.com/golang-migrate/migrate/v4/database"
8 "github.com/hashicorp/go-multierror"
9 "go.mongodb.org/mongo-driver/bson"
10 "go.mongodb.org/mongo-driver/mongo"
11 "go.mongodb.org/mongo-driver/mongo/options"
12 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
13 "go.uber.org/atomic"
14 "io"
15 "io/ioutil"
16 "net/url"
17 os "os"
18 "strconv"
19 "time"
20 )
21
22 func init() {
23 db := Mongo{}
24 database.Register("mongodb", &db)
25 database.Register("mongodb+srv", &db)
26 }
27
// DefaultMigrationsCollection is the collection name WithInstance falls back
// to when Config.MigrationsCollection is empty. It is declared as a variable
// rather than a constant.
var DefaultMigrationsCollection = "schema_migrations"

const (
	// DefaultLockingCollection is the collection name WithInstance falls back
	// to when Config.Locking.CollectionName is empty.
	DefaultLockingCollection = "migrate_advisory_lock"

	// lockKeyUniqueValue is the locking_key value written by Lock and matched
	// by Unlock.
	lockKeyUniqueValue = 0

	// DefaultLockTimeout is the fallback for Locking.Timeout; Lock multiplies
	// that setting by time.Second.
	DefaultLockTimeout = 15

	// DefaultLockTimeoutInterval is the fallback for Locking.Interval; Lock
	// multiplies that setting by time.Second.
	DefaultLockTimeoutInterval = 10

	// DefaultAdvisoryLockingFlag is the value Open uses for advisory locking
	// when the x-advisory-locking URL parameter is absent.
	DefaultAdvisoryLockingFlag = true

	// LockIndexName is the name given to the unique index created on the
	// locking collection.
	LockIndexName = "lock_unique_key"

	// contextWaitTimeout bounds each individual lock insert and unlock delete.
	contextWaitTimeout = 5 * time.Second
)
37
// ErrNoDatabaseName is returned by WithInstance when Config.DatabaseName is
// empty and by Open when the connection string carries no database.
var ErrNoDatabaseName = fmt.Errorf("no database name")

// ErrNilConfig is returned by WithInstance when it is handed a nil *Config.
var ErrNilConfig = fmt.Errorf("no config")
42
43 type Mongo struct {
44 client *mongo.Client
45 db *mongo.Database
46 config *Config
47 isLocked atomic.Bool
48 }
49
// Locking groups the advisory-lock settings of the driver.
//
// Field order is kept stable so unkeyed composite literals keep compiling.
type Locking struct {
	// CollectionName is the collection that stores the lock document.
	CollectionName string
	// Timeout, in seconds, caps the total time Lock keeps retrying.
	Timeout int
	// Enabled switches advisory locking on; when false Lock and Unlock only
	// toggle the in-process flag.
	Enabled bool
	// Interval, in seconds, caps the pause between two Lock retries.
	Interval int
}
56 type Config struct {
57 DatabaseName string
58 MigrationsCollection string
59 TransactionMode bool
60 Locking Locking
61 }
// versionInfo is the shape Version decodes the stored migration state into.
// Its BSON keys match the document SetVersion inserts.
type versionInfo struct {
	// Version is the stored migration version.
	Version int `bson:"version"`
	// Dirty is the stored dirty flag.
	Dirty bool `bson:"dirty"`
}
66
// lockObj is the document Lock inserts into the locking collection.
type lockObj struct {
	// Key is the value covered by the unique index; Lock always stores
	// lockKeyUniqueValue here.
	Key int `bson:"locking_key"`
	// Pid is the process id of the lock holder.
	Pid int `bson:"pid"`
	// Hostname is the host of the lock holder, or an error text when the
	// hostname could not be determined.
	Hostname string `bson:"hostname"`
	// CreatedAt is the time the lock document was built.
	CreatedAt time.Time `bson:"created_at"`
}
// findFilter is a one-field document keyed on locking_key. Unlock uses it as
// the delete filter and ensureLockTable uses it as the index key spec.
type findFilter struct {
	// Key is the locking_key value (or, for an index spec, its sort direction).
	Key int `bson:"locking_key"`
}
76
77 func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
78 if config == nil {
79 return nil, ErrNilConfig
80 }
81 if len(config.DatabaseName) == 0 {
82 return nil, ErrNoDatabaseName
83 }
84 if len(config.MigrationsCollection) == 0 {
85 config.MigrationsCollection = DefaultMigrationsCollection
86 }
87 if len(config.Locking.CollectionName) == 0 {
88 config.Locking.CollectionName = DefaultLockingCollection
89 }
90 if config.Locking.Timeout <= 0 {
91 config.Locking.Timeout = DefaultLockTimeout
92 }
93 if config.Locking.Interval <= 0 {
94 config.Locking.Interval = DefaultLockTimeoutInterval
95 }
96
97 mc := &Mongo{
98 client: instance,
99 db: instance.Database(config.DatabaseName),
100 config: config,
101 }
102
103 if mc.config.Locking.Enabled {
104 if err := mc.ensureLockTable(); err != nil {
105 return nil, err
106 }
107 }
108 if err := mc.ensureVersionTable(); err != nil {
109 return nil, err
110 }
111
112 return mc, nil
113 }
114
115 func (m *Mongo) Open(dsn string) (database.Driver, error) {
116
117 uri, err := connstring.Parse(dsn)
118 if err != nil {
119 return nil, err
120 }
121 if len(uri.Database) == 0 {
122 return nil, ErrNoDatabaseName
123 }
124 unknown := url.Values(uri.UnknownOptions)
125
126 migrationsCollection := unknown.Get("x-migrations-collection")
127 lockCollection := unknown.Get("x-advisory-lock-collection")
128 transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
129 if err != nil {
130 return nil, err
131 }
132 advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
133 if err != nil {
134 return nil, err
135 }
136 lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
137 if err != nil {
138 return nil, err
139 }
140 maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval)
141 if err != nil {
142 return nil, err
143 }
144 client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
145 if err != nil {
146 return nil, err
147 }
148
149 if err = client.Ping(context.TODO(), nil); err != nil {
150 return nil, err
151 }
152 mc, err := WithInstance(client, &Config{
153 DatabaseName: uri.Database,
154 MigrationsCollection: migrationsCollection,
155 TransactionMode: transactionMode,
156 Locking: Locking{
157 CollectionName: lockCollection,
158 Timeout: lockingTimout,
159 Enabled: advisoryLockingFlag,
160 Interval: maxLockingIntervals,
161 },
162 })
163 if err != nil {
164 return nil, err
165 }
166 return mc, nil
167 }
168
169
170
// parseBoolean interprets a URL parameter as a boolean.
//
// An empty urlParam yields defaultValue. Anything else is handed to
// strconv.ParseBool; if that fails the result is false together with the
// parse error.
func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
	if urlParam == "" {
		return defaultValue, nil
	}

	parsed, err := strconv.ParseBool(urlParam)
	if err != nil {
		return false, err
	}
	return parsed, nil
}
185
186
187
// parseInt interprets a URL parameter as a decimal integer.
//
// An empty urlParam yields defaultValue. Anything else is handed to
// strconv.Atoi; if that fails the result is -1 together with the parse error.
func parseInt(urlParam string, defaultValue int) (int, error) {
	if urlParam == "" {
		return defaultValue, nil
	}

	parsed, err := strconv.Atoi(urlParam)
	if err != nil {
		return -1, err
	}
	return parsed, nil
}
202 func (m *Mongo) SetVersion(version int, dirty bool) error {
203 migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
204 if err := migrationsCollection.Drop(context.TODO()); err != nil {
205 return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
206 }
207 _, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
208 if err != nil {
209 return &database.Error{OrigErr: err, Err: "save version failed"}
210 }
211 return nil
212 }
213
214 func (m *Mongo) Version() (version int, dirty bool, err error) {
215 var versionInfo versionInfo
216 err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
217 switch {
218 case err == mongo.ErrNoDocuments:
219 return database.NilVersion, false, nil
220 case err != nil:
221 return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
222 default:
223 return versionInfo.Version, versionInfo.Dirty, nil
224 }
225 }
226
227 func (m *Mongo) Run(migration io.Reader) error {
228 migr, err := ioutil.ReadAll(migration)
229 if err != nil {
230 return err
231 }
232 var cmds []bson.D
233 err = bson.UnmarshalExtJSON(migr, true, &cmds)
234 if err != nil {
235 return fmt.Errorf("unmarshaling json error: %s", err)
236 }
237 if m.config.TransactionMode {
238 if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
239 return err
240 }
241 } else {
242 if err := m.executeCommands(context.TODO(), cmds); err != nil {
243 return err
244 }
245 }
246 return nil
247 }
248
249 func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
250 err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
251 if err := sessionContext.StartTransaction(); err != nil {
252 return &database.Error{OrigErr: err, Err: "failed to start transaction"}
253 }
254 if err := m.executeCommands(sessionContext, cmds); err != nil {
255
256
257 return err
258 }
259 if err := sessionContext.CommitTransaction(sessionContext); err != nil {
260 return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
261 }
262 return nil
263 })
264 if err != nil {
265 return err
266 }
267 return nil
268 }
269
270 func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
271 for _, cmd := range cmds {
272 err := m.db.RunCommand(ctx, cmd).Err()
273 if err != nil {
274 return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
275 }
276 }
277 return nil
278 }
279
280 func (m *Mongo) Close() error {
281 return m.client.Disconnect(context.TODO())
282 }
283
284 func (m *Mongo) Drop() error {
285 return m.db.Drop(context.TODO())
286 }
287
288 func (m *Mongo) ensureLockTable() error {
289 indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
290
291 indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
292 _, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
293 Options: indexOptions,
294 Keys: findFilter{Key: -1},
295 })
296 if err != nil {
297 return err
298 }
299 return nil
300 }
301
302
303
304
305 func (m *Mongo) ensureVersionTable() (err error) {
306 if err = m.Lock(); err != nil {
307 return err
308 }
309
310 defer func() {
311 if e := m.Unlock(); e != nil {
312 if err == nil {
313 err = e
314 } else {
315 err = multierror.Append(err, e)
316 }
317 }
318 }()
319
320 if err != nil {
321 return err
322 }
323 if _, _, err = m.Version(); err != nil {
324 return err
325 }
326 return nil
327 }
328
329
330
331 func (m *Mongo) Lock() error {
332 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
333 if !m.config.Locking.Enabled {
334 return nil
335 }
336
337 pid := os.Getpid()
338 hostname, err := os.Hostname()
339 if err != nil {
340 hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
341 }
342
343 newLockObj := lockObj{
344 Key: lockKeyUniqueValue,
345 Pid: pid,
346 Hostname: hostname,
347 CreatedAt: time.Now(),
348 }
349 operation := func() error {
350 timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
351 _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
352 defer cancelFunc()
353 return err
354 }
355 exponentialBackOff := backoff.NewExponentialBackOff()
356 duration := time.Duration(m.config.Locking.Timeout) * time.Second
357 exponentialBackOff.MaxElapsedTime = duration
358 exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
359
360 err = backoff.Retry(operation, exponentialBackOff)
361 if err != nil {
362 return database.ErrLocked
363 }
364
365 return nil
366 })
367 }
368
369 func (m *Mongo) Unlock() error {
370 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
371 if !m.config.Locking.Enabled {
372 return nil
373 }
374
375 filter := findFilter{
376 Key: lockKeyUniqueValue,
377 }
378
379 ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
380 _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
381 defer cancel()
382
383 if err != nil {
384 return err
385 }
386 return nil
387 })
388 }
389
View as plain text