1 package cli
2
3 import (
4 "errors"
5 "fmt"
6 "os"
7 "path/filepath"
8 "strconv"
9 "strings"
10 "time"
11
12 "github.com/golang-migrate/migrate/v4"
13 _ "github.com/golang-migrate/migrate/v4/database/stub"
14 _ "github.com/golang-migrate/migrate/v4/source/file"
15 )
16
// Sentinel errors returned by the "create" command helpers in this file.
var (
	// errInvalidSequenceWidth is returned by nextSeqVersion when seqDigits <= 0.
	errInvalidSequenceWidth error = errors.New("Digits must be positive")
	// errIncompatibleSeqAndFormat is returned by createCmd when seq is combined
	// with a format other than defaultTimeFormat.
	errIncompatibleSeqAndFormat error = errors.New("The seq and format options are mutually exclusive")
	// errInvalidTimeFormat is returned by timeVersion for an empty format.
	errInvalidTimeFormat error = errors.New("Time format may not be empty")
)
22
23 func nextSeqVersion(matches []string, seqDigits int) (string, error) {
24 if seqDigits <= 0 {
25 return "", errInvalidSequenceWidth
26 }
27
28 nextSeq := uint64(1)
29
30 if len(matches) > 0 {
31 filename := matches[len(matches)-1]
32 matchSeqStr := filepath.Base(filename)
33 idx := strings.Index(matchSeqStr, "_")
34
35 if idx < 1 {
36 return "", fmt.Errorf("Malformed migration filename: %s", filename)
37 }
38
39 var err error
40 matchSeqStr = matchSeqStr[0:idx]
41 nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64)
42
43 if err != nil {
44 return "", err
45 }
46
47 nextSeq++
48 }
49
50 version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits)
51
52 if len(version) > seqDigits {
53 return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits)
54 }
55
56 return version, nil
57 }
58
59 func timeVersion(startTime time.Time, format string) (version string, err error) {
60 switch format {
61 case "":
62 err = errInvalidTimeFormat
63 case "unix":
64 version = strconv.FormatInt(startTime.Unix(), 10)
65 case "unixNano":
66 version = strconv.FormatInt(startTime.UnixNano(), 10)
67 default:
68 version = startTime.Format(format)
69 }
70
71 return
72 }
73
74
75 func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error {
76 if seq && format != defaultTimeFormat {
77 return errIncompatibleSeqAndFormat
78 }
79
80 var version string
81 var err error
82
83 dir = filepath.Clean(dir)
84 ext = "." + strings.TrimPrefix(ext, ".")
85
86 if seq {
87 matches, err := filepath.Glob(filepath.Join(dir, "*"+ext))
88
89 if err != nil {
90 return err
91 }
92
93 version, err = nextSeqVersion(matches, seqDigits)
94
95 if err != nil {
96 return err
97 }
98 } else {
99 version, err = timeVersion(startTime, format)
100
101 if err != nil {
102 return err
103 }
104 }
105
106 versionGlob := filepath.Join(dir, version+"_*"+ext)
107 matches, err := filepath.Glob(versionGlob)
108
109 if err != nil {
110 return err
111 }
112
113 if len(matches) > 0 {
114 return fmt.Errorf("duplicate migration version: %s", version)
115 }
116
117 if err = os.MkdirAll(dir, os.ModePerm); err != nil {
118 return err
119 }
120
121 for _, direction := range []string{"up", "down"} {
122 basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext)
123 filename := filepath.Join(dir, basename)
124
125 if err = createFile(filename); err != nil {
126 return err
127 }
128
129 if print {
130 absPath, _ := filepath.Abs(filename)
131 log.Println(absPath)
132 }
133 }
134
135 return nil
136 }
137
// createFile creates filename as a new, empty file and closes it again.
//
// os.O_EXCL makes the call fail when filename already exists, so an existing
// migration is never truncated or overwritten; that error is returned as-is,
// as is any error from closing the new file.
func createFile(filename string) error {
	const flags = os.O_CREATE | os.O_EXCL | os.O_RDWR
	file, err := os.OpenFile(filename, flags, 0666)
	if err != nil {
		return err
	}
	return file.Close()
}
149
150 func gotoCmd(m *migrate.Migrate, v uint) error {
151 if err := m.Migrate(v); err != nil {
152 if err != migrate.ErrNoChange {
153 return err
154 }
155 log.Println(err)
156 }
157 return nil
158 }
159
160 func upCmd(m *migrate.Migrate, limit int) error {
161 if limit >= 0 {
162 if err := m.Steps(limit); err != nil {
163 if err != migrate.ErrNoChange {
164 return err
165 }
166 log.Println(err)
167 }
168 } else {
169 if err := m.Up(); err != nil {
170 if err != migrate.ErrNoChange {
171 return err
172 }
173 log.Println(err)
174 }
175 }
176 return nil
177 }
178
179 func downCmd(m *migrate.Migrate, limit int) error {
180 if limit >= 0 {
181 if err := m.Steps(-limit); err != nil {
182 if err != migrate.ErrNoChange {
183 return err
184 }
185 log.Println(err)
186 }
187 } else {
188 if err := m.Down(); err != nil {
189 if err != migrate.ErrNoChange {
190 return err
191 }
192 log.Println(err)
193 }
194 }
195 return nil
196 }
197
198 func dropCmd(m *migrate.Migrate) error {
199 if err := m.Drop(); err != nil {
200 return err
201 }
202 return nil
203 }
204
205 func forceCmd(m *migrate.Migrate, v int) error {
206 if err := m.Force(v); err != nil {
207 return err
208 }
209 return nil
210 }
211
212 func versionCmd(m *migrate.Migrate) error {
213 v, dirty, err := m.Version()
214 if err != nil {
215 return err
216 }
217 if dirty {
218 log.Printf("%v (dirty)\n", v)
219 } else {
220 log.Println(v)
221 }
222 return nil
223 }
224
225
226
// numDownMigrationsFromArgs interprets the arguments of the "down" command.
//
// It returns the number of down migrations to apply (-1 meaning all of them),
// a flag that is true only when neither -all nor a count was given
// (presumably so the caller can ask for confirmation before applying
// everything — verify against the caller), and an error for invalid input:
//   - applyAll with no args: -1, false; applyAll combined with args is an error.
//   - no args: -1, true.
//   - one arg N: N must be a non-negative decimal integer that fits in an int.
//   - more than one arg: error.
func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) {
	if applyAll {
		if len(args) > 0 {
			return 0, false, errors.New("-all cannot be used with other arguments")
		}
		return -1, false, nil
	}

	switch len(args) {
	case 0:
		return -1, true, nil
	case 1:
		const maxInt = int(^uint(0) >> 1)
		// ParseUint rejects signs, so N is never negative. Values above
		// maxInt must be rejected too: int(n) would wrap to a negative limit,
		// which downCmd treats as "apply all down migrations".
		n, err := strconv.ParseUint(args[0], 10, 64)
		if err != nil || n > uint64(maxInt) {
			return 0, false, errors.New("can't read limit argument N")
		}
		return int(n), false, nil
	default:
		return 0, false, errors.New("too many arguments")
	}
}
249
View as plain text