1 package dbal
2
3 import (
4 "fmt"
5 "path/filepath"
6 "sort"
7 "strings"
8
9 "github.com/gobuffalo/packr"
10 "github.com/pkg/errors"
11 migrate "github.com/rubenv/sql-migrate"
12
13 "github.com/ory/x/logrusx"
14 )
15
// migrationBasePath is the directory name used both for the packr box and as
// the prefix under which every collected migration file is registered in it.
const (
	migrationBasePath = "/migrations/sql"
)

// migrationFile is one SQL migration read through an asset loader.
type migrationFile struct {
	// Filename is the base name of the file (used for sorting and as the box
	// entry name); Filepath is the full asset path it was loaded from.
	Filename, Filepath string
	// Content holds the raw bytes returned by the loader.
	Content []byte
}

// migrationFiles implements sort.Interface, ordering files by Filename using
// Go's byte-wise string comparison.
type migrationFiles []migrationFile

// Len returns the number of files in the collection.
func (mf migrationFiles) Len() int {
	return len(mf)
}

// Swap exchanges the files at positions i and j.
func (mf migrationFiles) Swap(i, j int) {
	mf[i], mf[j] = mf[j], mf[i]
}

// Less reports whether the Filename at i sorts before the Filename at j.
func (mf migrationFiles) Less(i, j int) bool {
	return mf[i].Filename < mf[j].Filename
}
29
30
// PackrMigrationSource wraps sql-migrate's PackrMigrationSource and adds the
// option of trimming the file extension from migration IDs in FindMigrations.
type PackrMigrationSource struct {
	*migrate.PackrMigrationSource
	// omitExtension, when true, makes FindMigrations strip the extension (as
	// reported by filepath.Ext) from every returned migration's Id.
	omitExtension bool
}
35
36
// FindMigrations returns the migrations discovered by the embedded sql-migrate
// source. Any error from the embedded source is returned unchanged.
//
// When omitExtension is set, each migration's Id has its extension (as
// determined by filepath.Ext, e.g. ".sql") trimmed in place before the slice
// is returned.
func (p PackrMigrationSource) FindMigrations() ([]*migrate.Migration, error) {
	found, err := p.PackrMigrationSource.FindMigrations()
	if err != nil {
		return nil, err
	}

	if !p.omitExtension {
		return found, nil
	}

	// The slice holds pointers, so updating each element's Id is enough; no
	// write-back into the slice is needed.
	for _, mig := range found {
		mig.Id = strings.TrimSuffix(mig.Id, filepath.Ext(mig.Id))
	}
	return found, nil
}
52
53
54 func FindMatchingTestMigrations(folder string, migrations map[string]*PackrMigrationSource, assetNames []string, asset func(string) ([]byte, error)) map[string]*PackrMigrationSource {
55 var testMigrations = map[string]*PackrMigrationSource{}
56 for name, migration := range migrations {
57 var filter []string
58 for _, file := range migration.PackrMigrationSource.Box.List() {
59 f := folder + strings.Replace(filepath.Base(file), ".sql", "_test.sql", 1)
60 filter = append(filter, f)
61 }
62 testMigrations[name] = NewMustPackerMigrationSource(logrusx.New("", ""), assetNames, asset, filter, true)
63 }
64
65 return testMigrations
66 }
67
68
// NewMustPackerMigrationSource is like NewPackerMigrationSource but does not
// return an error. On failure it logs the error together with its formatted
// stack trace via l.Fatal, which is expected to terminate the process.
//
// sources is the list of asset names to consider. See NewPackerMigrationSource
// for the meaning of loader, filters and omitExtension.
func NewMustPackerMigrationSource(l *logrusx.Logger, sources []string, loader func(string) ([]byte, error), filters []string, omitExtension bool) *PackrMigrationSource {
	m, err := NewPackerMigrationSource(l, sources, loader, filters, omitExtension)
	if err != nil {
		// With github.com/pkg/errors, the %+v verb includes the recorded stack.
		l.WithError(err).WithField("stack", fmt.Sprintf("%+v", err)).Fatal("Unable to set up migration source")
	}
	return m
}
76
77
// NewPackerMigrationSource collects SQL migration files into a packr box
// rooted at migrationBasePath and returns a migration source backed by it.
//
// An entry of sources is selected when both of the following hold:
//   - its extension is ".sql";
//   - its path contains at least one element of filters as a substring.
//
// An empty filters slice therefore selects nothing. Selected files are read
// through loader. Each file is added to the box under migrationBasePath joined
// with its base name only, so files from different directories that share a
// base name map to the same entry.
//
// If omitExtension is true, the returned source strips the file extension from
// migration IDs when FindMigrations is called. The loader error is wrapped with
// the failing file name and returned if loader fails for any selected file.
func NewPackerMigrationSource(l *logrusx.Logger, sources []string, loader func(string) ([]byte, error), filters []string, omitExtension bool) (*PackrMigrationSource, error) {
	b := packr.NewBox(migrationBasePath)
	var files migrationFiles

	for _, source := range sources {
		if filepath.Ext(source) != ".sql" {
			continue
		}

		var found bool
		for _, f := range filters {
			if strings.Contains(source, f) {
				found = true
				break
			}
		}

		if !found {
			l.WithField("file", source).WithField("filters", fmt.Sprintf("%v", filters)).Debug("Ignoring file because path does not match filters")
			continue
		}

		l.WithField("file", source).Debug("Processing sql migration file")

		body, err := loader(source)
		if err != nil {
			return nil, errors.Wrapf(err, "unable to load migration file %s", source)
		}

		files = append(files, migrationFile{
			Filename: filepath.Base(source),
			Filepath: source,
			Content:  body,
		})
	}

	// A stable sort keeps files that share a base name in their original
	// sources order. Because such files map to the same box entry, the order of
	// the AddBytes calls below is now deterministic: the file appearing later in
	// sources is added last. sort.Sort gave no such guarantee. (Assumes packr's
	// AddBytes overwrites an existing entry; verify.)
	sort.Stable(files)

	for _, f := range files {
		b.AddBytes(filepath.ToSlash(filepath.Join(migrationBasePath, f.Filename)), f.Content)
	}

	return &PackrMigrationSource{
		PackrMigrationSource: &migrate.PackrMigrationSource{
			Box: b,
			Dir: migrationBasePath,
		},
		omitExtension: omitExtension,
	}, nil
}
128
View as plain text