1 package mysql
2
3 import (
4 "context"
5 "crypto/ed25519"
6 "crypto/x509"
7 "database/sql"
8 sqldriver "database/sql/driver"
9 "encoding/pem"
10 "errors"
11 "fmt"
12 "io/ioutil"
13 "log"
14 "math/big"
15 "math/rand"
16 "net/url"
17 "os"
18 "strconv"
19 "testing"
20 )
21
22 import (
23 "github.com/dhui/dktest"
24 "github.com/go-sql-driver/mysql"
25 "github.com/stretchr/testify/assert"
26 )
27
28 import (
29 "github.com/golang-migrate/migrate/v4"
30 dt "github.com/golang-migrate/migrate/v4/database/testing"
31 "github.com/golang-migrate/migrate/v4/dktesting"
32 _ "github.com/golang-migrate/migrate/v4/source/file"
33 )
34
35 const defaultPort = 3306
36
37 var (
38 opts = dktest.Options{
39 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
40 PortRequired: true, ReadyFunc: isReady,
41 }
42 optsAnsiQuotes = dktest.Options{
43 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
44 PortRequired: true, ReadyFunc: isReady,
45 Cmd: []string{"--sql-mode=ANSI_QUOTES"},
46 }
47
48 specs = []dktesting.ContainerSpec{
49 {ImageName: "mysql:5.5", Options: opts},
50 {ImageName: "mysql:5.6", Options: opts},
51 {ImageName: "mysql:5.7", Options: opts},
52 {ImageName: "mysql:8", Options: opts},
53 }
54 specsAnsiQuotes = []dktesting.ContainerSpec{
55 {ImageName: "mysql:5.5", Options: optsAnsiQuotes},
56 {ImageName: "mysql:5.6", Options: optsAnsiQuotes},
57 {ImageName: "mysql:5.7", Options: optsAnsiQuotes},
58 {ImageName: "mysql:8", Options: optsAnsiQuotes},
59 }
60 )
61
62 func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
63 ip, port, err := c.Port(defaultPort)
64 if err != nil {
65 return false
66 }
67
68 db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
69 if err != nil {
70 return false
71 }
72 defer func() {
73 if err := db.Close(); err != nil {
74 log.Println("close error:", err)
75 }
76 }()
77 if err = db.PingContext(ctx); err != nil {
78 switch err {
79 case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
80 return false
81 default:
82 fmt.Println(err)
83 }
84 return false
85 }
86
87 return true
88 }
89
90 func Test(t *testing.T) {
91
92
93 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
94 ip, port, err := c.Port(defaultPort)
95 if err != nil {
96 t.Fatal(err)
97 }
98
99 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
100 p := &Mysql{}
101 d, err := p.Open(addr)
102 if err != nil {
103 t.Fatal(err)
104 }
105 defer func() {
106 if err := d.Close(); err != nil {
107 t.Error(err)
108 }
109 }()
110 dt.Test(t, d, []byte("SELECT 1"))
111
112
113 if err := d.(*Mysql).ensureVersionTable(); err != nil {
114 t.Fatal(err)
115 }
116
117 if err := d.(*Mysql).ensureVersionTable(); err != nil {
118 t.Fatal(err)
119 }
120 })
121 }
122
123 func TestMigrate(t *testing.T) {
124
125
126 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
127 ip, port, err := c.Port(defaultPort)
128 if err != nil {
129 t.Fatal(err)
130 }
131
132 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
133 p := &Mysql{}
134 d, err := p.Open(addr)
135 if err != nil {
136 t.Fatal(err)
137 }
138 defer func() {
139 if err := d.Close(); err != nil {
140 t.Error(err)
141 }
142 }()
143
144 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
145 if err != nil {
146 t.Fatal(err)
147 }
148 dt.TestMigrate(t, m)
149
150
151 if err := d.(*Mysql).ensureVersionTable(); err != nil {
152 t.Fatal(err)
153 }
154
155 if err := d.(*Mysql).ensureVersionTable(); err != nil {
156 t.Fatal(err)
157 }
158 })
159 }
160
161 func TestMigrateAnsiQuotes(t *testing.T) {
162
163
164 dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
165 ip, port, err := c.Port(defaultPort)
166 if err != nil {
167 t.Fatal(err)
168 }
169
170 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
171 p := &Mysql{}
172 d, err := p.Open(addr)
173 if err != nil {
174 t.Fatal(err)
175 }
176 defer func() {
177 if err := d.Close(); err != nil {
178 t.Error(err)
179 }
180 }()
181
182 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
183 if err != nil {
184 t.Fatal(err)
185 }
186 dt.TestMigrate(t, m)
187
188
189 if err := d.(*Mysql).ensureVersionTable(); err != nil {
190 t.Fatal(err)
191 }
192
193 if err := d.(*Mysql).ensureVersionTable(); err != nil {
194 t.Fatal(err)
195 }
196 })
197 }
198
199 func TestLockWorks(t *testing.T) {
200 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
201 ip, port, err := c.Port(defaultPort)
202 if err != nil {
203 t.Fatal(err)
204 }
205
206 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
207 p := &Mysql{}
208 d, err := p.Open(addr)
209 if err != nil {
210 t.Fatal(err)
211 }
212 dt.Test(t, d, []byte("SELECT 1"))
213
214 ms := d.(*Mysql)
215
216 err = ms.Lock()
217 if err != nil {
218 t.Fatal(err)
219 }
220 err = ms.Unlock()
221 if err != nil {
222 t.Fatal(err)
223 }
224
225
226 err = ms.Lock()
227 if err != nil {
228 t.Fatal(err)
229 }
230 err = ms.Unlock()
231 if err != nil {
232 t.Fatal(err)
233 }
234 })
235 }
236
237 func TestNoLockParamValidation(t *testing.T) {
238 ip := "127.0.0.1"
239 port := 3306
240 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
241 p := &Mysql{}
242 _, err := p.Open(addr + "?x-no-lock=not-a-bool")
243 if !errors.Is(err, strconv.ErrSyntax) {
244 t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
245 }
246 }
247
248 func TestNoLockWorks(t *testing.T) {
249 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
250 ip, port, err := c.Port(defaultPort)
251 if err != nil {
252 t.Fatal(err)
253 }
254
255 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
256 p := &Mysql{}
257 d, err := p.Open(addr)
258 if err != nil {
259 t.Fatal(err)
260 }
261
262 lock := d.(*Mysql)
263
264 p = &Mysql{}
265 d, err = p.Open(addr + "?x-no-lock=true")
266 if err != nil {
267 t.Fatal(err)
268 }
269
270 noLock := d.(*Mysql)
271
272
273 if err = lock.Lock(); err != nil {
274 t.Fatal(err)
275 }
276 if err = noLock.Lock(); err != nil {
277 t.Fatal(err)
278 }
279 if err = lock.Unlock(); err != nil {
280 t.Fatal(err)
281 }
282 if err = noLock.Unlock(); err != nil {
283 t.Fatal(err)
284 }
285 })
286 }
287
288 func TestExtractCustomQueryParams(t *testing.T) {
289 testcases := []struct {
290 name string
291 config *mysql.Config
292 expectedParams map[string]string
293 expectedCustomParams map[string]string
294 expectedErr error
295 }{
296 {name: "nil config", expectedErr: ErrNilConfig},
297 {
298 name: "no params",
299 config: mysql.NewConfig(),
300 expectedCustomParams: map[string]string{},
301 },
302 {
303 name: "no custom params",
304 config: &mysql.Config{Params: map[string]string{"hello": "world"}},
305 expectedParams: map[string]string{"hello": "world"},
306 expectedCustomParams: map[string]string{},
307 },
308 {
309 name: "one param, one custom param",
310 config: &mysql.Config{
311 Params: map[string]string{"hello": "world", "x-foo": "bar"},
312 },
313 expectedParams: map[string]string{"hello": "world"},
314 expectedCustomParams: map[string]string{"x-foo": "bar"},
315 },
316 {
317 name: "multiple params, multiple custom params",
318 config: &mysql.Config{
319 Params: map[string]string{
320 "hello": "world",
321 "x-foo": "bar",
322 "dead": "beef",
323 "x-cat": "hat",
324 },
325 },
326 expectedParams: map[string]string{"hello": "world", "dead": "beef"},
327 expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
328 },
329 }
330 for _, tc := range testcases {
331 t.Run(tc.name, func(t *testing.T) {
332 customParams, err := extractCustomQueryParams(tc.config)
333 if tc.config != nil {
334 assert.Equal(t, tc.expectedParams, tc.config.Params,
335 "Expected config params have custom params properly removed")
336 }
337 assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
338 assert.Equal(t, tc.expectedCustomParams, customParams,
339 "Expected custom params to be properly extracted")
340 })
341 }
342 }
343
// createTmpCert writes a self-signed, PEM-encoded certificate to a new
// temporary file and returns the file's path.
//
// The file is removed through t.Cleanup when the test finishes. The ed25519
// key and the certificate are generated from a math/rand source seeded with 0,
// so the output is only suitable as a test fixture. Any failure aborts the
// test via t.Fatal.
func createTmpCert(t *testing.T) string {
	tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert")
	if err != nil {
		t.Fatal("Failed to create temp cert file:", err)
	}
	t.Cleanup(func() {
		if err := os.Remove(tmpCertFile.Name()); err != nil {
			t.Log("Failed to cleanup temp cert file:", err)
		}
	})

	r := rand.New(rand.NewSource(0))
	pub, priv, err := ed25519.GenerateKey(r)
	if err != nil {
		t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
	}
	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(0),
	}
	// The template doubles as its own parent, which makes the certificate
	// self-signed.
	derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
	if err != nil {
		t.Fatal("Failed to generate temp cert file:", err)
	}
	if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		t.Fatal("Failed to encode temp cert file:", err)
	}
	if err := tmpCertFile.Close(); err != nil {
		t.Fatal("Failed to close temp cert file:", err)
	}
	return tmpCertFile.Name()
}
375
376 func TestURLToMySQLConfig(t *testing.T) {
377 tmpCertFilename := createTmpCert(t)
378 tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
379
380 testcases := []struct {
381 name string
382 urlStr string
383 expectedDSN string
384 }{
385 {name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
386 expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
387 {name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
388 expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
389 {name: "only user - with encoded :",
390 urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
391 expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
392 {name: "only user - with encoded @",
393 urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
394 expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
395 {name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
396 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
397
398
399
400
401 {name: "user/password - user with encoded @",
402 urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
403 expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
404 {name: "user/password - password with encoded :",
405 urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
406 expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
407 {name: "user/password - password with encoded @",
408 urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
409 expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
410 {name: "custom tls",
411 urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
412 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
413 }
414 for _, tc := range testcases {
415 t.Run(tc.name, func(t *testing.T) {
416 config, err := urlToMySQLConfig(tc.urlStr)
417 if err != nil {
418 t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
419 }
420 dsn := config.FormatDSN()
421 if dsn != tc.expectedDSN {
422 t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
423 }
424 })
425 }
426 }
427
View as plain text