1 package cli
2
3 import (
4 "flag"
5 "fmt"
6 "os"
7 "os/signal"
8 "strconv"
9 "strings"
10 "syscall"
11 "time"
12
13 "github.com/golang-migrate/migrate/v4"
14 "github.com/golang-migrate/migrate/v4/database"
15 "github.com/golang-migrate/migrate/v4/source"
16 )
17
// Defaults used by the "create" sub-command when -format / -tz are not given.
const (
	// defaultTimeFormat is the Go time layout used for timestamped versions.
	defaultTimeFormat = "20060102150405"
	// defaultTimezone names the location timestamps are rendered in.
	defaultTimezone = "UTC"
)

// Usage texts, one per sub-command. They are embedded in the top-level help
// printed by flag.Usage and shown individually by each sub-command's -help.
const (
	createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME
Create a set of timestamped up/down migrations titled NAME, in directory D with extension E.
Use -seq option to generate sequential up/down migrations with N digits.
Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error.
Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC).
`

	gotoUsage = `goto V Migrate to version V`

	upUsage = `up [N] Apply all or N up migrations`

	downUsage = `down [N] [-all] Apply all or N down migrations
Use -all to apply all down migrations`

	dropUsage = `drop [-f] Drop everything inside database
Use -f to bypass confirmation`

	forceUsage = `force V Set version V but don't run migration (ignores dirty state)`
)
35
// handleSubCmdHelp implements a sub-command's -help flag. When help is true it
// writes usage to standard error, prints the defaults of every flag registered
// on flagSet, and terminates the process with exit status 0. When help is
// false it returns without touching flagSet.
func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) {
	if !help {
		return
	}

	fmt.Fprintln(os.Stderr, usage)
	flagSet.PrintDefaults()
	os.Exit(0)
}
43
// newFlagSetWithHelp creates a flag set called name that exits the process on
// a parse error and has a boolean -help flag already registered. It returns
// the flag set together with the pointer that receives the -help value.
func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) {
	var (
		fs   = flag.NewFlagSet(name, flag.ExitOnError)
		help = new(bool)
	)
	fs.BoolVar(help, "help", false, "Print help information")
	return fs, help
}
49
50
51 var log = &Log{}
52
// printUsageAndExit prints the top-level usage text through flag.Usage and
// terminates the process. It is used when no command or an unknown command is
// given.
func printUsageAndExit() {
	// Status 2 is the same status the flag package uses when a flag set
	// configured with flag.ExitOnError fails to parse its arguments.
	const exitStatus = 2

	flag.Usage()
	os.Exit(exitStatus)
}
60
61
62 func Main(version string) {
63 helpPtr := flag.Bool("help", false, "")
64 versionPtr := flag.Bool("version", false, "")
65 verbosePtr := flag.Bool("verbose", false, "")
66 prefetchPtr := flag.Uint("prefetch", 10, "")
67 lockTimeoutPtr := flag.Uint("lock-timeout", 15, "")
68 pathPtr := flag.String("path", "", "")
69 databasePtr := flag.String("database", "", "")
70 sourcePtr := flag.String("source", "", "")
71
72 flag.Usage = func() {
73 fmt.Fprintf(os.Stderr,
74 `Usage: migrate OPTIONS COMMAND [arg...]
75 migrate [ -version | -help ]
76
77 Options:
78 -source Location of the migrations (driver://url)
79 -path Shorthand for -source=file://path
80 -database Run migrations against this database (driver://url)
81 -prefetch N Number of migrations to load in advance before executing (default 10)
82 -lock-timeout N Allow N seconds to acquire database lock (default 15)
83 -verbose Print verbose logging
84 -version Print version
85 -help Print usage
86
87 Commands:
88 %s
89 %s
90 %s
91 %s
92 %s
93 %s
94 version Print current migration version
95
96 Source drivers: `+strings.Join(source.List(), ", ")+`
97 Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage)
98 }
99
100 flag.Parse()
101
102
103 log.verbose = *verbosePtr
104
105
106 if *versionPtr {
107 fmt.Fprintln(os.Stderr, version)
108 os.Exit(0)
109 }
110
111
112 if *helpPtr {
113 flag.Usage()
114 os.Exit(0)
115 }
116
117
118 if *sourcePtr == "" && *pathPtr != "" {
119 *sourcePtr = fmt.Sprintf("file://%v", *pathPtr)
120 }
121
122
123
124
125 migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr)
126 defer func() {
127 if migraterErr == nil {
128 if _, err := migrater.Close(); err != nil {
129 log.Println(err)
130 }
131 }
132 }()
133 if migraterErr == nil {
134 migrater.Log = log
135 migrater.PrefetchMigrations = *prefetchPtr
136 migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second
137
138
139 signals := make(chan os.Signal, 1)
140 signal.Notify(signals, syscall.SIGINT)
141 go func() {
142 for range signals {
143 log.Println("Stopping after this running migration ...")
144 migrater.GracefulStop <- true
145 return
146 }
147 }()
148 }
149
150 startTime := time.Now()
151
152 if len(flag.Args()) < 1 {
153 printUsageAndExit()
154 }
155 args := flag.Args()[1:]
156
157 switch flag.Arg(0) {
158 case "create":
159
160 seq := false
161 seqDigits := 6
162
163 createFlagSet, help := newFlagSetWithHelp("create")
164 extPtr := createFlagSet.String("ext", "", "File extension")
165 dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)")
166 formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`)
167 timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`)
168 createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)")
169 createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)")
170
171 if err := createFlagSet.Parse(args); err != nil {
172 log.fatalErr(err)
173 }
174
175 handleSubCmdHelp(*help, createUsage, createFlagSet)
176
177 if createFlagSet.NArg() == 0 {
178 log.fatal("error: please specify name")
179 }
180 name := createFlagSet.Arg(0)
181
182 if *extPtr == "" {
183 log.fatal("error: -ext flag must be specified")
184 }
185
186 timezone, err := time.LoadLocation(*timezoneName)
187 if err != nil {
188 log.fatal(err)
189 }
190
191 if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil {
192 log.fatalErr(err)
193 }
194
195 case "goto":
196
197 gotoSet, helpPtr := newFlagSetWithHelp("goto")
198
199 if err := gotoSet.Parse(args); err != nil {
200 log.fatalErr(err)
201 }
202
203 handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet)
204
205 if migraterErr != nil {
206 log.fatalErr(migraterErr)
207 }
208
209 if gotoSet.NArg() == 0 {
210 log.fatal("error: please specify version argument V")
211 }
212
213 v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64)
214 if err != nil {
215 log.fatal("error: can't read version argument V")
216 }
217
218 if err := gotoCmd(migrater, uint(v)); err != nil {
219 log.fatalErr(err)
220 }
221
222 if log.verbose {
223 log.Println("Finished after", time.Since(startTime))
224 }
225
226 case "up":
227 upSet, helpPtr := newFlagSetWithHelp("up")
228
229 if err := upSet.Parse(args); err != nil {
230 log.fatalErr(err)
231 }
232
233 handleSubCmdHelp(*helpPtr, upUsage, upSet)
234
235 if migraterErr != nil {
236 log.fatalErr(migraterErr)
237 }
238
239 limit := -1
240 if upSet.NArg() > 0 {
241 n, err := strconv.ParseUint(upSet.Arg(0), 10, 64)
242 if err != nil {
243 log.fatal("error: can't read limit argument N")
244 }
245 limit = int(n)
246 }
247
248 if err := upCmd(migrater, limit); err != nil {
249 log.fatalErr(err)
250 }
251
252 if log.verbose {
253 log.Println("Finished after", time.Since(startTime))
254 }
255
256 case "down":
257 downFlagSet, helpPtr := newFlagSetWithHelp("down")
258 applyAll := downFlagSet.Bool("all", false, "Apply all down migrations")
259
260 if err := downFlagSet.Parse(args); err != nil {
261 log.fatalErr(err)
262 }
263
264 handleSubCmdHelp(*helpPtr, downUsage, downFlagSet)
265
266 if migraterErr != nil {
267 log.fatalErr(migraterErr)
268 }
269
270 downArgs := downFlagSet.Args()
271 num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs)
272 if err != nil {
273 log.fatalErr(err)
274 }
275 if needsConfirm {
276 log.Println("Are you sure you want to apply all down migrations? [y/N]")
277 var response string
278 fmt.Scanln(&response)
279 response = strings.ToLower(strings.TrimSpace(response))
280
281 if response == "y" {
282 log.Println("Applying all down migrations")
283 } else {
284 log.fatal("Not applying all down migrations")
285 }
286 }
287
288 if err := downCmd(migrater, num); err != nil {
289 log.fatalErr(err)
290 }
291
292 if log.verbose {
293 log.Println("Finished after", time.Since(startTime))
294 }
295
296 case "drop":
297 dropFlagSet, help := newFlagSetWithHelp("drop")
298 forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt")
299
300 if err := dropFlagSet.Parse(args); err != nil {
301 log.fatalErr(err)
302 }
303
304 handleSubCmdHelp(*help, dropUsage, dropFlagSet)
305
306 if !*forceDrop {
307 log.Println("Are you sure you want to drop the entire database schema? [y/N]")
308 var response string
309 fmt.Scanln(&response)
310 response = strings.ToLower(strings.TrimSpace(response))
311
312 if response == "y" {
313 log.Println("Dropping the entire database schema")
314 } else {
315 log.fatal("Aborted dropping the entire database schema")
316 }
317 }
318
319 if migraterErr != nil {
320 log.fatalErr(migraterErr)
321 }
322
323 if err := dropCmd(migrater); err != nil {
324 log.fatalErr(err)
325 }
326
327 if log.verbose {
328 log.Println("Finished after", time.Since(startTime))
329 }
330
331 case "force":
332 forceSet, helpPtr := newFlagSetWithHelp("force")
333
334 if err := forceSet.Parse(args); err != nil {
335 log.fatalErr(err)
336 }
337
338 handleSubCmdHelp(*helpPtr, forceUsage, forceSet)
339
340 if migraterErr != nil {
341 log.fatalErr(migraterErr)
342 }
343
344 if forceSet.NArg() == 0 {
345 log.fatal("error: please specify version argument V")
346 }
347
348 v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64)
349 if err != nil {
350 log.fatal("error: can't read version argument V")
351 }
352
353 if v < -1 {
354 log.fatal("error: argument V must be >= -1")
355 }
356
357 if err := forceCmd(migrater, int(v)); err != nil {
358 log.fatalErr(err)
359 }
360
361 if log.verbose {
362 log.Println("Finished after", time.Since(startTime))
363 }
364
365 case "version":
366 if migraterErr != nil {
367 log.fatalErr(migraterErr)
368 }
369
370 if err := versionCmd(migrater); err != nil {
371 log.fatalErr(err)
372 }
373
374 default:
375 printUsageAndExit()
376 }
377 }
378
View as plain text