1
2
3
4
5 package migrate
6
7 import (
8 "errors"
9 "fmt"
10 "os"
11 "sync"
12 "time"
13
14 "github.com/hashicorp/go-multierror"
15
16 "github.com/golang-migrate/migrate/v4/database"
17 iurl "github.com/golang-migrate/migrate/v4/internal/url"
18 "github.com/golang-migrate/migrate/v4/source"
19 )
20
21
22
23
24
25
// DefaultPrefetchMigrations is the PrefetchMigrations value given to every new
// Migrate instance: how many migrations may be read ahead of the one being run.
var DefaultPrefetchMigrations uint = 10
27
28
// DefaultLockTimeout is the LockTimeout value given to every new Migrate
// instance: how long acquiring the database lock may take before failing.
var DefaultLockTimeout time.Duration = 15 * time.Second
30
// Sentinel errors returned by Migrate operations; compare with errors.Is.
var (
	// ErrNoChange means the requested operation would not change the version.
	ErrNoChange = errors.New("no change")
	// ErrNilVersion means no migration has been applied yet.
	ErrNilVersion = errors.New("no migration")
	// ErrInvalidVersion is returned by Force for versions below -1.
	ErrInvalidVersion = errors.New("version must be >= -1")
	// ErrLocked means this Migrate instance already holds the database lock.
	ErrLocked = errors.New("database locked")
	// ErrLockTimeout means the database lock was not acquired within LockTimeout.
	ErrLockTimeout = errors.New("timeout: can't acquire database lock")
)
38
39
40
// ErrShortLimit is sent when fewer migrations are available than a Steps call
// asked for. Short is how many of the requested steps could not be taken.
type ErrShortLimit struct {
	Short uint
}

// Error implements the error interface.
func (l ErrShortLimit) Error() string {
	return fmt.Sprintf("limit %v short", l.Short)
}
49
// ErrDirty is returned when the database reports a dirty version, i.e. a
// previous migration was marked as started but never marked as finished.
// Version is the dirty version; Force can be used to clear it.
type ErrDirty struct {
	Version int
}

// Error implements the error interface.
func (d ErrDirty) Error() string {
	return fmt.Sprintf("Dirty database version %v. Fix and force version.", d.Version)
}
57
58 type Migrate struct {
59 sourceName string
60 sourceDrv source.Driver
61 databaseName string
62 databaseDrv database.Driver
63
64
65 Log Logger
66
67
68
69
70 GracefulStop chan bool
71 isLockedMu *sync.Mutex
72
73 isGracefulStop bool
74 isLocked bool
75
76
77
78 PrefetchMigrations uint
79
80
81
82 LockTimeout time.Duration
83 }
84
85
86
87 func New(sourceURL, databaseURL string) (*Migrate, error) {
88 m := newCommon()
89
90 sourceName, err := iurl.SchemeFromURL(sourceURL)
91 if err != nil {
92 return nil, err
93 }
94 m.sourceName = sourceName
95
96 databaseName, err := iurl.SchemeFromURL(databaseURL)
97 if err != nil {
98 return nil, err
99 }
100 m.databaseName = databaseName
101
102 sourceDrv, err := source.Open(sourceURL)
103 if err != nil {
104 return nil, err
105 }
106 m.sourceDrv = sourceDrv
107
108 databaseDrv, err := database.Open(databaseURL)
109 if err != nil {
110 return nil, err
111 }
112 m.databaseDrv = databaseDrv
113
114 return m, nil
115 }
116
117
118
119
120
121 func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
122 m := newCommon()
123
124 sourceName, err := iurl.SchemeFromURL(sourceURL)
125 if err != nil {
126 return nil, err
127 }
128 m.sourceName = sourceName
129
130 m.databaseName = databaseName
131
132 sourceDrv, err := source.Open(sourceURL)
133 if err != nil {
134 return nil, err
135 }
136 m.sourceDrv = sourceDrv
137
138 m.databaseDrv = databaseInstance
139
140 return m, nil
141 }
142
143
144
145
146
147 func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
148 m := newCommon()
149
150 databaseName, err := iurl.SchemeFromURL(databaseURL)
151 if err != nil {
152 return nil, err
153 }
154 m.databaseName = databaseName
155
156 m.sourceName = sourceName
157
158 databaseDrv, err := database.Open(databaseURL)
159 if err != nil {
160 return nil, err
161 }
162 m.databaseDrv = databaseDrv
163
164 m.sourceDrv = sourceInstance
165
166 return m, nil
167 }
168
169
170
171
172
173 func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
174 m := newCommon()
175
176 m.sourceName = sourceName
177 m.databaseName = databaseName
178
179 m.sourceDrv = sourceInstance
180 m.databaseDrv = databaseInstance
181
182 return m, nil
183 }
184
185 func newCommon() *Migrate {
186 return &Migrate{
187 GracefulStop: make(chan bool, 1),
188 PrefetchMigrations: DefaultPrefetchMigrations,
189 LockTimeout: DefaultLockTimeout,
190 isLockedMu: &sync.Mutex{},
191 }
192 }
193
194
195 func (m *Migrate) Close() (source error, database error) {
196 databaseSrvClose := make(chan error)
197 sourceSrvClose := make(chan error)
198
199 m.logVerbosePrintf("Closing source and database\n")
200
201 go func() {
202 databaseSrvClose <- m.databaseDrv.Close()
203 }()
204
205 go func() {
206 sourceSrvClose <- m.sourceDrv.Close()
207 }()
208
209 return <-sourceSrvClose, <-databaseSrvClose
210 }
211
212
213
214 func (m *Migrate) Migrate(version uint) error {
215 if err := m.lock(); err != nil {
216 return err
217 }
218
219 curVersion, dirty, err := m.databaseDrv.Version()
220 if err != nil {
221 return m.unlockErr(err)
222 }
223
224 if dirty {
225 return m.unlockErr(ErrDirty{curVersion})
226 }
227
228 ret := make(chan interface{}, m.PrefetchMigrations)
229 go m.read(curVersion, int(version), ret)
230
231 return m.unlockErr(m.runMigrations(ret))
232 }
233
234
235
236 func (m *Migrate) Steps(n int) error {
237 if n == 0 {
238 return ErrNoChange
239 }
240
241 if err := m.lock(); err != nil {
242 return err
243 }
244
245 curVersion, dirty, err := m.databaseDrv.Version()
246 if err != nil {
247 return m.unlockErr(err)
248 }
249
250 if dirty {
251 return m.unlockErr(ErrDirty{curVersion})
252 }
253
254 ret := make(chan interface{}, m.PrefetchMigrations)
255
256 if n > 0 {
257 go m.readUp(curVersion, n, ret)
258 } else {
259 go m.readDown(curVersion, -n, ret)
260 }
261
262 return m.unlockErr(m.runMigrations(ret))
263 }
264
265
266
267 func (m *Migrate) Up() error {
268 if err := m.lock(); err != nil {
269 return err
270 }
271
272 curVersion, dirty, err := m.databaseDrv.Version()
273 if err != nil {
274 return m.unlockErr(err)
275 }
276
277 if dirty {
278 return m.unlockErr(ErrDirty{curVersion})
279 }
280
281 ret := make(chan interface{}, m.PrefetchMigrations)
282
283 go m.readUp(curVersion, -1, ret)
284 return m.unlockErr(m.runMigrations(ret))
285 }
286
287
288
289 func (m *Migrate) Down() error {
290 if err := m.lock(); err != nil {
291 return err
292 }
293
294 curVersion, dirty, err := m.databaseDrv.Version()
295 if err != nil {
296 return m.unlockErr(err)
297 }
298
299 if dirty {
300 return m.unlockErr(ErrDirty{curVersion})
301 }
302
303 ret := make(chan interface{}, m.PrefetchMigrations)
304 go m.readDown(curVersion, -1, ret)
305 return m.unlockErr(m.runMigrations(ret))
306 }
307
308
309 func (m *Migrate) Drop() error {
310 if err := m.lock(); err != nil {
311 return err
312 }
313 if err := m.databaseDrv.Drop(); err != nil {
314 return m.unlockErr(err)
315 }
316 return m.unlock()
317 }
318
319
320
321
322
323 func (m *Migrate) Run(migration ...*Migration) error {
324 if len(migration) == 0 {
325 return ErrNoChange
326 }
327
328 if err := m.lock(); err != nil {
329 return err
330 }
331
332 curVersion, dirty, err := m.databaseDrv.Version()
333 if err != nil {
334 return m.unlockErr(err)
335 }
336
337 if dirty {
338 return m.unlockErr(ErrDirty{curVersion})
339 }
340
341 ret := make(chan interface{}, m.PrefetchMigrations)
342
343 go func() {
344 defer close(ret)
345 for _, migr := range migration {
346 if m.PrefetchMigrations > 0 && migr.Body != nil {
347 m.logVerbosePrintf("Start buffering %v\n", migr.LogString())
348 } else {
349 m.logVerbosePrintf("Scheduled %v\n", migr.LogString())
350 }
351
352 ret <- migr
353 go func(migr *Migration) {
354 if err := migr.Buffer(); err != nil {
355 m.logErr(err)
356 }
357 }(migr)
358 }
359 }()
360
361 return m.unlockErr(m.runMigrations(ret))
362 }
363
364
365
366
367 func (m *Migrate) Force(version int) error {
368 if version < -1 {
369 return ErrInvalidVersion
370 }
371
372 if err := m.lock(); err != nil {
373 return err
374 }
375
376 if err := m.databaseDrv.SetVersion(version, false); err != nil {
377 return m.unlockErr(err)
378 }
379
380 return m.unlock()
381 }
382
383
384
385 func (m *Migrate) Version() (version uint, dirty bool, err error) {
386 v, d, err := m.databaseDrv.Version()
387 if err != nil {
388 return 0, false, err
389 }
390
391 if v == database.NilVersion {
392 return 0, false, ErrNilVersion
393 }
394
395 return suint(v), d, nil
396 }
397
398
399
400
401
402 func (m *Migrate) read(from int, to int, ret chan<- interface{}) {
403 defer close(ret)
404
405
406 if from >= 0 {
407 if err := m.versionExists(suint(from)); err != nil {
408 ret <- err
409 return
410 }
411 }
412
413
414 if to >= 0 {
415 if err := m.versionExists(suint(to)); err != nil {
416 ret <- err
417 return
418 }
419 }
420
421
422 if from == to {
423 ret <- ErrNoChange
424 return
425 }
426
427 if from < to {
428
429
430 if from == -1 {
431 firstVersion, err := m.sourceDrv.First()
432 if err != nil {
433 ret <- err
434 return
435 }
436
437 migr, err := m.newMigration(firstVersion, int(firstVersion))
438 if err != nil {
439 ret <- err
440 return
441 }
442
443 ret <- migr
444 go func() {
445 if err := migr.Buffer(); err != nil {
446 m.logErr(err)
447 }
448 }()
449
450 from = int(firstVersion)
451 }
452
453
454 for from < to {
455 if m.stop() {
456 return
457 }
458
459 next, err := m.sourceDrv.Next(suint(from))
460 if err != nil {
461 ret <- err
462 return
463 }
464
465 migr, err := m.newMigration(next, int(next))
466 if err != nil {
467 ret <- err
468 return
469 }
470
471 ret <- migr
472 go func() {
473 if err := migr.Buffer(); err != nil {
474 m.logErr(err)
475 }
476 }()
477
478 from = int(next)
479 }
480
481 } else {
482
483
484 for from > to && from >= 0 {
485 if m.stop() {
486 return
487 }
488
489 prev, err := m.sourceDrv.Prev(suint(from))
490 if errors.Is(err, os.ErrNotExist) && to == -1 {
491
492 migr, err := m.newMigration(suint(from), -1)
493 if err != nil {
494 ret <- err
495 return
496 }
497 ret <- migr
498 go func() {
499 if err := migr.Buffer(); err != nil {
500 m.logErr(err)
501 }
502 }()
503
504 return
505
506 } else if err != nil {
507 ret <- err
508 return
509 }
510
511 migr, err := m.newMigration(suint(from), int(prev))
512 if err != nil {
513 ret <- err
514 return
515 }
516
517 ret <- migr
518 go func() {
519 if err := migr.Buffer(); err != nil {
520 m.logErr(err)
521 }
522 }()
523
524 from = int(prev)
525 }
526 }
527 }
528
529
530
531
532
533
534 func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) {
535 defer close(ret)
536
537
538 if from >= 0 {
539 if err := m.versionExists(suint(from)); err != nil {
540 ret <- err
541 return
542 }
543 }
544
545 if limit == 0 {
546 ret <- ErrNoChange
547 return
548 }
549
550 count := 0
551 for count < limit || limit == -1 {
552 if m.stop() {
553 return
554 }
555
556
557 if from == -1 {
558 firstVersion, err := m.sourceDrv.First()
559 if err != nil {
560 ret <- err
561 return
562 }
563
564 migr, err := m.newMigration(firstVersion, int(firstVersion))
565 if err != nil {
566 ret <- err
567 return
568 }
569
570 ret <- migr
571 go func() {
572 if err := migr.Buffer(); err != nil {
573 m.logErr(err)
574 }
575 }()
576 from = int(firstVersion)
577 count++
578 continue
579 }
580
581
582 next, err := m.sourceDrv.Next(suint(from))
583 if errors.Is(err, os.ErrNotExist) {
584
585 if limit == -1 && count == 0 {
586 ret <- ErrNoChange
587 return
588 }
589
590
591 if limit == -1 {
592 return
593 }
594
595
596 if limit > 0 && count == 0 {
597 ret <- os.ErrNotExist
598 return
599 }
600
601
602 if count < limit {
603 ret <- ErrShortLimit{suint(limit - count)}
604 return
605 }
606 }
607 if err != nil {
608 ret <- err
609 return
610 }
611
612 migr, err := m.newMigration(next, int(next))
613 if err != nil {
614 ret <- err
615 return
616 }
617
618 ret <- migr
619 go func() {
620 if err := migr.Buffer(); err != nil {
621 m.logErr(err)
622 }
623 }()
624 from = int(next)
625 count++
626 }
627 }
628
629
630
631
632
633
634 func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
635 defer close(ret)
636
637
638 if from >= 0 {
639 if err := m.versionExists(suint(from)); err != nil {
640 ret <- err
641 return
642 }
643 }
644
645 if limit == 0 {
646 ret <- ErrNoChange
647 return
648 }
649
650
651 if from == -1 && limit == -1 {
652 ret <- ErrNoChange
653 return
654 }
655
656
657 if from == -1 && limit > 0 {
658 ret <- os.ErrNotExist
659 return
660 }
661
662 count := 0
663 for count < limit || limit == -1 {
664 if m.stop() {
665 return
666 }
667
668 prev, err := m.sourceDrv.Prev(suint(from))
669 if errors.Is(err, os.ErrNotExist) {
670
671 if limit == -1 || limit-count > 0 {
672 firstVersion, err := m.sourceDrv.First()
673 if err != nil {
674 ret <- err
675 return
676 }
677
678 migr, err := m.newMigration(firstVersion, -1)
679 if err != nil {
680 ret <- err
681 return
682 }
683 ret <- migr
684 go func() {
685 if err := migr.Buffer(); err != nil {
686 m.logErr(err)
687 }
688 }()
689 count++
690 }
691
692 if count < limit {
693 ret <- ErrShortLimit{suint(limit - count)}
694 }
695 return
696 }
697 if err != nil {
698 ret <- err
699 return
700 }
701
702 migr, err := m.newMigration(suint(from), int(prev))
703 if err != nil {
704 ret <- err
705 return
706 }
707
708 ret <- migr
709 go func() {
710 if err := migr.Buffer(); err != nil {
711 m.logErr(err)
712 }
713 }()
714 from = int(prev)
715 count++
716 }
717 }
718
719
720
721
722
723
724
725 func (m *Migrate) runMigrations(ret <-chan interface{}) error {
726 for r := range ret {
727
728 if m.stop() {
729 return nil
730 }
731
732 switch r := r.(type) {
733 case error:
734 return r
735
736 case *Migration:
737 migr := r
738
739
740 if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
741 return err
742 }
743
744 if migr.Body != nil {
745 m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
746 if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
747 return err
748 }
749 }
750
751
752 if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
753 return err
754 }
755
756 endTime := time.Now()
757 readTime := migr.FinishedReading.Sub(migr.StartedBuffering)
758 runTime := endTime.Sub(migr.FinishedReading)
759
760
761 if m.Log != nil {
762 if m.Log.Verbose() {
763 m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime)
764 } else {
765 m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime)
766 }
767 }
768
769 default:
770 return fmt.Errorf("unknown type: %T with value: %+v", r, r)
771 }
772 }
773 return nil
774 }
775
776
777
778 func (m *Migrate) versionExists(version uint) (result error) {
779
780 up, _, err := m.sourceDrv.ReadUp(version)
781 if err == nil {
782 defer func() {
783 if errClose := up.Close(); errClose != nil {
784 result = multierror.Append(result, errClose)
785 }
786 }()
787 }
788 if errors.Is(err, os.ErrExist) {
789 return nil
790 } else if !errors.Is(err, os.ErrNotExist) {
791 return err
792 }
793
794
795 down, _, err := m.sourceDrv.ReadDown(version)
796 if err == nil {
797 defer func() {
798 if errClose := down.Close(); errClose != nil {
799 result = multierror.Append(result, errClose)
800 }
801 }()
802 }
803 if errors.Is(err, os.ErrExist) {
804 return nil
805 } else if !errors.Is(err, os.ErrNotExist) {
806 return err
807 }
808
809 err = fmt.Errorf("no migration found for version %d: %w", version, err)
810 m.logErr(err)
811 return err
812 }
813
814
815
816
817 func (m *Migrate) stop() bool {
818 if m.isGracefulStop {
819 return true
820 }
821
822 select {
823 case <-m.GracefulStop:
824 m.isGracefulStop = true
825 return true
826
827 default:
828 return false
829 }
830 }
831
832
833
834 func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) {
835 var migr *Migration
836
837 if targetVersion >= int(version) {
838 r, identifier, err := m.sourceDrv.ReadUp(version)
839 if errors.Is(err, os.ErrNotExist) {
840
841 migr, err = NewMigration(nil, "", version, targetVersion)
842 if err != nil {
843 return nil, err
844 }
845
846 } else if err != nil {
847 return nil, err
848
849 } else {
850
851 migr, err = NewMigration(r, identifier, version, targetVersion)
852 if err != nil {
853 return nil, err
854 }
855 }
856
857 } else {
858 r, identifier, err := m.sourceDrv.ReadDown(version)
859 if errors.Is(err, os.ErrNotExist) {
860
861 migr, err = NewMigration(nil, "", version, targetVersion)
862 if err != nil {
863 return nil, err
864 }
865
866 } else if err != nil {
867 return nil, err
868
869 } else {
870
871 migr, err = NewMigration(r, identifier, version, targetVersion)
872 if err != nil {
873 return nil, err
874 }
875 }
876 }
877
878 if m.PrefetchMigrations > 0 && migr.Body != nil {
879 m.logVerbosePrintf("Start buffering %v\n", migr.LogString())
880 } else {
881 m.logVerbosePrintf("Scheduled %v\n", migr.LogString())
882 }
883
884 return migr, nil
885 }
886
887
888
889 func (m *Migrate) lock() error {
890 m.isLockedMu.Lock()
891 defer m.isLockedMu.Unlock()
892
893 if m.isLocked {
894 return ErrLocked
895 }
896
897
898 done := make(chan bool, 1)
899 defer func() {
900 done <- true
901 }()
902
903
904 errchan := make(chan error, 2)
905
906
907 timeout := time.After(m.LockTimeout)
908 go func() {
909 for {
910 select {
911 case <-done:
912 return
913 case <-timeout:
914 errchan <- ErrLockTimeout
915 return
916 }
917 }
918 }()
919
920
921 go func() {
922 if err := m.databaseDrv.Lock(); err != nil {
923 errchan <- err
924 } else {
925 errchan <- nil
926 }
927 }()
928
929
930 err := <-errchan
931 if err == nil {
932 m.isLocked = true
933 }
934 return err
935 }
936
937
938
939
940 func (m *Migrate) unlock() error {
941 m.isLockedMu.Lock()
942 defer m.isLockedMu.Unlock()
943
944 if err := m.databaseDrv.Unlock(); err != nil {
945
946 return err
947 }
948
949 m.isLocked = false
950 return nil
951 }
952
953
954
955 func (m *Migrate) unlockErr(prevErr error) error {
956 if err := m.unlock(); err != nil {
957 return multierror.Append(prevErr, err)
958 }
959 return prevErr
960 }
961
962
963 func (m *Migrate) logPrintf(format string, v ...interface{}) {
964 if m.Log != nil {
965 m.Log.Printf(format, v...)
966 }
967 }
968
969
970 func (m *Migrate) logVerbosePrintf(format string, v ...interface{}) {
971 if m.Log != nil && m.Log.Verbose() {
972 m.Log.Printf(format, v...)
973 }
974 }
975
976
977 func (m *Migrate) logErr(err error) {
978 if m.Log != nil {
979 m.Log.Printf("error: %v", err)
980 }
981 }
982
View as plain text