1
2
3
4 package mysql
5
6 import (
7 "context"
8 "crypto/tls"
9 "crypto/x509"
10 "database/sql"
11 "fmt"
12 "go.uber.org/atomic"
13 "io"
14 "io/ioutil"
15 nurl "net/url"
16 "strconv"
17 "strings"
18
19 "github.com/go-sql-driver/mysql"
20 "github.com/golang-migrate/migrate/v4/database"
21 "github.com/hashicorp/go-multierror"
22 )
23
24 var _ database.Driver = (*Mysql)(nil)
25
26 func init() {
27 database.Register("mysql", &Mysql{})
28 }
29
var (
	// DefaultMigrationsTable is the version table name WithConnection falls
	// back to when Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"

	// ErrDatabaseDirty signals a dirty database state.
	// NOTE(review): not referenced by the code in this file — confirm usage.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")

	// ErrNilConfig is returned when a required config pointer is nil
	// (WithConnection, extractCustomQueryParams).
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithConnection when no database name
	// was configured and SELECT DATABASE() yields none.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrAppendPEM is returned when no certificate could be appended from the
	// file named by the x-tls-ca parameter.
	ErrAppendPEM = fmt.Errorf("failed to append PEM")

	// ErrTLSCertKeyConfig is returned when only one of x-tls-cert and
	// x-tls-key is provided.
	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
)
39
// Config holds the settings for a Mysql driver instance.
type Config struct {
	// MigrationsTable names the table holding the schema version.
	// WithConnection replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string

	// DatabaseName is the schema the migrations apply to. WithConnection
	// fills it from SELECT DATABASE() when it is empty.
	DatabaseName string

	// NoLock, when true, makes Lock and Unlock skip the MySQL
	// GET_LOCK/RELEASE_LOCK calls (only the in-process flag is toggled).
	NoLock bool
}
45
46 type Mysql struct {
47
48
49 conn *sql.Conn
50 db *sql.DB
51 isLocked atomic.Bool
52
53 config *Config
54 }
55
56
57 func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
58 if config == nil {
59 return nil, ErrNilConfig
60 }
61
62 if err := conn.PingContext(ctx); err != nil {
63 return nil, err
64 }
65
66 mx := &Mysql{
67 conn: conn,
68 db: nil,
69 config: config,
70 }
71
72 if config.DatabaseName == "" {
73 query := `SELECT DATABASE()`
74 var databaseName sql.NullString
75 if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
76 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
77 }
78
79 if len(databaseName.String) == 0 {
80 return nil, ErrNoDatabaseName
81 }
82
83 config.DatabaseName = databaseName.String
84 }
85
86 if len(config.MigrationsTable) == 0 {
87 config.MigrationsTable = DefaultMigrationsTable
88 }
89
90 if err := mx.ensureVersionTable(); err != nil {
91 return nil, err
92 }
93
94 return mx, nil
95 }
96
97
98 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
99 ctx := context.Background()
100
101 if err := instance.Ping(); err != nil {
102 return nil, err
103 }
104
105 conn, err := instance.Conn(ctx)
106 if err != nil {
107 return nil, err
108 }
109
110 mx, err := WithConnection(ctx, conn, config)
111 if err != nil {
112 return nil, err
113 }
114
115 mx.db = instance
116
117 return mx, nil
118 }
119
120
121
122 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
123 if c == nil {
124 return nil, ErrNilConfig
125 }
126 customQueryParams := map[string]string{}
127
128 for k, v := range c.Params {
129 if strings.HasPrefix(k, "x-") {
130 customQueryParams[k] = v
131 delete(c.Params, k)
132 }
133 }
134 return customQueryParams, nil
135 }
136
137 func urlToMySQLConfig(url string) (*mysql.Config, error) {
138
139
140
141
142
143
144
145
146 if idx := strings.LastIndex(url, "?"); idx > 0 {
147 rawParams := url[idx+1:]
148 parsedParams, err := nurl.ParseQuery(rawParams)
149 if err != nil {
150 return nil, err
151 }
152
153 ctls := parsedParams.Get("tls")
154 if len(ctls) > 0 {
155 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
156 rootCertPool := x509.NewCertPool()
157 pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
158 if err != nil {
159 return nil, err
160 }
161
162 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
163 return nil, ErrAppendPEM
164 }
165
166 clientCert := make([]tls.Certificate, 0, 1)
167 if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
168 if ccert == "" || ckey == "" {
169 return nil, ErrTLSCertKeyConfig
170 }
171 certs, err := tls.LoadX509KeyPair(ccert, ckey)
172 if err != nil {
173 return nil, err
174 }
175 clientCert = append(clientCert, certs)
176 }
177
178 insecureSkipVerify := false
179 insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
180 if len(insecureSkipVerifyStr) > 0 {
181 x, err := strconv.ParseBool(insecureSkipVerifyStr)
182 if err != nil {
183 return nil, err
184 }
185 insecureSkipVerify = x
186 }
187
188 err = mysql.RegisterTLSConfig(ctls, &tls.Config{
189 RootCAs: rootCertPool,
190 Certificates: clientCert,
191 InsecureSkipVerify: insecureSkipVerify,
192 })
193 if err != nil {
194 return nil, err
195 }
196 }
197 }
198 }
199
200 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
201 if err != nil {
202 return nil, err
203 }
204
205 config.MultiStatements = true
206
207
208
209
210 user, err := nurl.QueryUnescape(config.User)
211 if err != nil {
212 return nil, err
213 }
214 config.User = user
215
216 password, err := nurl.QueryUnescape(config.Passwd)
217 if err != nil {
218 return nil, err
219 }
220 config.Passwd = password
221
222 return config, nil
223 }
224
225 func (m *Mysql) Open(url string) (database.Driver, error) {
226 config, err := urlToMySQLConfig(url)
227 if err != nil {
228 return nil, err
229 }
230
231 customParams, err := extractCustomQueryParams(config)
232 if err != nil {
233 return nil, err
234 }
235
236 noLockParam, noLock := customParams["x-no-lock"], false
237 if noLockParam != "" {
238 noLock, err = strconv.ParseBool(noLockParam)
239 if err != nil {
240 return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
241 }
242 }
243
244 db, err := sql.Open("mysql", config.FormatDSN())
245 if err != nil {
246 return nil, err
247 }
248
249 mx, err := WithInstance(db, &Config{
250 DatabaseName: config.DBName,
251 MigrationsTable: customParams["x-migrations-table"],
252 NoLock: noLock,
253 })
254 if err != nil {
255 return nil, err
256 }
257
258 return mx, nil
259 }
260
261 func (m *Mysql) Close() error {
262 connErr := m.conn.Close()
263 var dbErr error
264 if m.db != nil {
265 dbErr = m.db.Close()
266 }
267
268 if connErr != nil || dbErr != nil {
269 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
270 }
271 return nil
272 }
273
274 func (m *Mysql) Lock() error {
275 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
276 if m.config.NoLock {
277 return nil
278 }
279 aid, err := database.GenerateAdvisoryLockId(
280 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
281 if err != nil {
282 return err
283 }
284
285 query := "SELECT GET_LOCK(?, 10)"
286 var success bool
287 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
288 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
289 }
290
291 if !success {
292 return database.ErrLocked
293 }
294
295 return nil
296 })
297 }
298
299 func (m *Mysql) Unlock() error {
300 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
301 if m.config.NoLock {
302 return nil
303 }
304
305 aid, err := database.GenerateAdvisoryLockId(
306 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
307 if err != nil {
308 return err
309 }
310
311 query := `SELECT RELEASE_LOCK(?)`
312 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
313 return &database.Error{OrigErr: err, Query: []byte(query)}
314 }
315
316
317
318
319
320 return nil
321 })
322 }
323
324 func (m *Mysql) Run(migration io.Reader) error {
325 migr, err := ioutil.ReadAll(migration)
326 if err != nil {
327 return err
328 }
329
330 query := string(migr[:])
331 if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
332 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
333 }
334
335 return nil
336 }
337
338 func (m *Mysql) SetVersion(version int, dirty bool) error {
339 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
340 if err != nil {
341 return &database.Error{OrigErr: err, Err: "transaction start failed"}
342 }
343
344 query := "TRUNCATE `" + m.config.MigrationsTable + "`"
345 if _, err := tx.ExecContext(context.Background(), query); err != nil {
346 if errRollback := tx.Rollback(); errRollback != nil {
347 err = multierror.Append(err, errRollback)
348 }
349 return &database.Error{OrigErr: err, Query: []byte(query)}
350 }
351
352
353
354
355 if version >= 0 || (version == database.NilVersion && dirty) {
356 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
357 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
358 if errRollback := tx.Rollback(); errRollback != nil {
359 err = multierror.Append(err, errRollback)
360 }
361 return &database.Error{OrigErr: err, Query: []byte(query)}
362 }
363 }
364
365 if err := tx.Commit(); err != nil {
366 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
367 }
368
369 return nil
370 }
371
372 func (m *Mysql) Version() (version int, dirty bool, err error) {
373 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
374 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
375 switch {
376 case err == sql.ErrNoRows:
377 return database.NilVersion, false, nil
378
379 case err != nil:
380 if e, ok := err.(*mysql.MySQLError); ok {
381 if e.Number == 0 {
382 return database.NilVersion, false, nil
383 }
384 }
385 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
386
387 default:
388 return version, dirty, nil
389 }
390 }
391
392 func (m *Mysql) Drop() (err error) {
393
394 query := `SHOW TABLES LIKE '%'`
395 tables, err := m.conn.QueryContext(context.Background(), query)
396 if err != nil {
397 return &database.Error{OrigErr: err, Query: []byte(query)}
398 }
399 defer func() {
400 if errClose := tables.Close(); errClose != nil {
401 err = multierror.Append(err, errClose)
402 }
403 }()
404
405
406 tableNames := make([]string, 0)
407 for tables.Next() {
408 var tableName string
409 if err := tables.Scan(&tableName); err != nil {
410 return err
411 }
412 if len(tableName) > 0 {
413 tableNames = append(tableNames, tableName)
414 }
415 }
416 if err := tables.Err(); err != nil {
417 return &database.Error{OrigErr: err, Query: []byte(query)}
418 }
419
420 if len(tableNames) > 0 {
421
422 query = `SET foreign_key_checks = 0`
423 if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
424 return &database.Error{OrigErr: err, Query: []byte(query)}
425 }
426
427 defer func() {
428
429 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
430 }()
431
432
433 for _, t := range tableNames {
434 query = "DROP TABLE IF EXISTS `" + t + "`"
435 if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
436 return &database.Error{OrigErr: err, Query: []byte(query)}
437 }
438 }
439 }
440
441 return nil
442 }
443
444
445
446
447 func (m *Mysql) ensureVersionTable() (err error) {
448 if err = m.Lock(); err != nil {
449 return err
450 }
451
452 defer func() {
453 if e := m.Unlock(); e != nil {
454 if err == nil {
455 err = e
456 } else {
457 err = multierror.Append(err, e)
458 }
459 }
460 }()
461
462
463 var result string
464 query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
465 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
466 if err != sql.ErrNoRows {
467 return &database.Error{OrigErr: err, Query: []byte(query)}
468 }
469 } else {
470 return nil
471 }
472
473
474 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
475 if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
476 return &database.Error{OrigErr: err, Query: []byte(query)}
477 }
478 return nil
479 }
480
481
482
483
// readBool parses input as a boolean using a fixed set of spellings:
// "1"/"0", and true/false written all-lower, all-upper or title case.
//
// valid reports whether input was one of those spellings. When valid is
// false, value is false too. urlToMySQLConfig uses it to tell a boolean tls
// parameter from a custom TLS config name.
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "1", "true", "TRUE", "True":
		value, valid = true, true
	case "0", "false", "FALSE", "False":
		value, valid = false, true
	}
	// Any other input leaves both results at their zero value.
	return value, valid
}
495
View as plain text