1 package mysql_test
2
3 import (
4 "database/sql"
5 "fmt"
6 "os"
7 "strconv"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/doug-martin/goqu/v9"
13 "github.com/doug-martin/goqu/v9/dialect/mysql"
14 _ "github.com/go-sql-driver/mysql"
15 "github.com/stretchr/testify/suite"
16 )
17
// Fixture SQL and connection defaults for the suite. SetupTest runs dropTable,
// createTable and insertDefaultReords (in that order) before every test, so
// each test starts with exactly ten rows whose `int` column is 0..9.
const (
	// dropTable removes any `entry` table left over from a previous run.
	dropTable = "DROP TABLE IF EXISTS `entry`;"

	// createTable defines the `entry` table mirrored by the entry struct.
	// `id` is AUTO_INCREMENT and `int` is UNIQUE; the OnConflict/ignore
	// tests rely on that UNIQUE constraint to provoke conflicts.
	createTable = "CREATE TABLE `entry` (" +
		"`id` INT NOT NULL AUTO_INCREMENT ," +
		"`int` INT NOT NULL UNIQUE," +
		"`float` FLOAT NOT NULL ," +
		"`string` VARCHAR(255) NOT NULL ," +
		"`time` DATETIME NOT NULL ," +
		"`bool` TINYINT NOT NULL ," +
		"`bytes` BLOB NOT NULL ," +
		"PRIMARY KEY (`id`) );"

	// insertDefaultReords seeds rows i = 0..9: int=i, float/string/bytes hold
	// i/10 to six decimals, bool alternates TRUE/FALSE starting at TRUE, and
	// time is 2015-02-22 18:19:55 plus i hours.
	// NOTE(review): the name is misspelled ("Reords"); renaming it means
	// updating every reference in the package.
	insertDefaultReords = "INSERT INTO `entry` (`int`, `float`, `string`, `time`, `bool`, `bytes`) VALUES" +
		"(0, 0.000000, '0.000000', '2015-02-22 18:19:55', TRUE, '0.000000')," +
		"(1, 0.100000, '0.100000', '2015-02-22 19:19:55', FALSE, '0.100000')," +
		"(2, 0.200000, '0.200000', '2015-02-22 20:19:55', TRUE, '0.200000')," +
		"(3, 0.300000, '0.300000', '2015-02-22 21:19:55', FALSE, '0.300000')," +
		"(4, 0.400000, '0.400000', '2015-02-22 22:19:55', TRUE, '0.400000')," +
		"(5, 0.500000, '0.500000', '2015-02-22 23:19:55', FALSE, '0.500000')," +
		"(6, 0.600000, '0.600000', '2015-02-23 00:19:55', TRUE, '0.600000')," +
		"(7, 0.700000, '0.700000', '2015-02-23 01:19:55', FALSE, '0.700000')," +
		"(8, 0.800000, '0.800000', '2015-02-23 02:19:55', TRUE, '0.800000')," +
		"(9, 0.900000, '0.900000', '2015-02-23 03:19:55', FALSE, '0.900000');"

	// defaultDBURI is the DSN used when MYSQL_URI is unset or empty.
	// parseTime=true lets DATETIME columns scan into time.Time.
	defaultDBURI = "root@/goqumysql?parseTime=true"
)
43
44 type (
45 mysqlTest struct {
46 suite.Suite
47 db *goqu.Database
48 }
49 entry struct {
50 ID uint32 `db:"id" goqu:"skipinsert,skipupdate"`
51 Int int `db:"int"`
52 Float float64 `db:"float"`
53 String string `db:"string"`
54 Time time.Time `db:"time"`
55 Bool bool `db:"bool"`
56 Bytes []byte `db:"bytes"`
57 }
58 entryTestCase struct {
59 ds *goqu.SelectDataset
60 len int
61 check func(entry entry, index int)
62 err string
63 }
64 )
65
66 func (mt *mysqlTest) SetupSuite() {
67 dbURI := os.Getenv("MYSQL_URI")
68 if dbURI == "" {
69 dbURI = defaultDBURI
70 }
71 db, err := sql.Open("mysql", dbURI)
72 if err != nil {
73 panic(err.Error())
74 }
75 mt.db = goqu.New("mysql", db)
76 }
77
78 func (mt *mysqlTest) assertSQL(cases ...sqlTestCase) {
79 for i, c := range cases {
80 actualSQL, actualArgs, err := c.ds.ToSQL()
81 if c.err == "" {
82 mt.NoError(err, "test case %d failed", i)
83 } else {
84 mt.EqualError(err, c.err, "test case %d failed", i)
85 }
86 mt.Equal(c.sql, actualSQL, "test case %d failed", i)
87 if c.isPrepared && c.args != nil || len(c.args) > 0 {
88 mt.Equal(c.args, actualArgs, "test case %d failed", i)
89 } else {
90 mt.Empty(actualArgs, "test case %d failed", i)
91 }
92 }
93 }
94
95 func (mt *mysqlTest) assertEntries(cases ...entryTestCase) {
96 for i, c := range cases {
97 var entries []entry
98 err := c.ds.ScanStructs(&entries)
99 if c.err == "" {
100 mt.NoError(err, "test case %d failed", i)
101 } else {
102 mt.EqualError(err, c.err, "test case %d failed", i)
103 }
104 mt.Len(entries, c.len)
105 for index, entry := range entries {
106 c.check(entry, index)
107 }
108 }
109 }
110
111 func (mt *mysqlTest) SetupTest() {
112 if _, err := mt.db.Exec(dropTable); err != nil {
113 panic(err)
114 }
115 if _, err := mt.db.Exec(createTable); err != nil {
116 panic(err)
117 }
118 if _, err := mt.db.Exec(insertDefaultReords); err != nil {
119 panic(err)
120 }
121 }
122
123 func (mt *mysqlTest) TestToSQL() {
124 ds := mt.db.From("entry")
125 mt.assertSQL(
126 sqlTestCase{ds: ds.Select("id", "float", "string", "time", "bool"), sql: "SELECT `id`, `float`, `string`, `time`, `bool` FROM `entry`"},
127 sqlTestCase{ds: ds.Where(goqu.C("int").Eq(10)), sql: "SELECT * FROM `entry` WHERE (`int` = 10)"},
128 sqlTestCase{
129 ds: ds.Prepared(true).Where(goqu.L("? = ?", goqu.C("int"), 10)),
130 sql: "SELECT * FROM `entry` WHERE `int` = ?", args: []interface{}{int64(10)},
131 },
132 )
133 }
134
135 func (mt *mysqlTest) TestQuery() {
136 ds := mt.db.From("entry")
137 floatVal := float64(0)
138 baseDate, err := time.Parse(
139 "2006-01-02 15:04:05",
140 "2015-02-22 18:19:55",
141 )
142 mt.NoError(err)
143 mt.assertEntries(
144 entryTestCase{ds: ds.Order(goqu.C("id").Asc()), len: 10, check: func(entry entry, index int) {
145 f := fmt.Sprintf("%f", floatVal)
146 mt.Equal(uint32(index+1), entry.ID)
147 mt.Equal(index, entry.Int)
148 mt.Equal(f, fmt.Sprintf("%f", entry.Float))
149 mt.Equal(f, entry.String)
150 mt.Equal([]byte(f), entry.Bytes)
151 mt.Equal(index%2 == 0, entry.Bool)
152 mt.Equal(baseDate.Add(time.Duration(index)*time.Hour).Unix(), entry.Time.Unix())
153 floatVal += float64(0.1)
154 }},
155 entryTestCase{ds: ds.Where(goqu.C("bool").IsTrue()).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
156 mt.True(entry.Bool)
157 }},
158 entryTestCase{ds: ds.Where(goqu.C("int").Gt(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
159 mt.True(entry.Int > 4)
160 }},
161 entryTestCase{ds: ds.Where(goqu.C("int").Gte(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
162 mt.True(entry.Int >= 5)
163 }},
164 entryTestCase{ds: ds.Where(goqu.C("int").Lt(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
165 mt.True(entry.Int < 5)
166 }},
167 entryTestCase{ds: ds.Where(goqu.C("int").Lte(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
168 mt.True(entry.Int <= 4)
169 }},
170 entryTestCase{ds: ds.Where(goqu.C("int").Between(goqu.Range(3, 6))).Order(goqu.C("id").Asc()), len: 4, check: func(entry entry, _ int) {
171 mt.True(entry.Int >= 3)
172 mt.True(entry.Int <= 6)
173 }},
174 entryTestCase{ds: ds.Where(goqu.C("string").Eq("0.100000")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
175 mt.Equal(entry.String, "0.100000")
176 }},
177 entryTestCase{ds: ds.Where(goqu.C("string").Like("0.1%")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
178 mt.Equal(entry.String, "0.100000")
179 }},
180 entryTestCase{ds: ds.Where(goqu.C("string").NotLike("0.1%")).Order(goqu.C("id").Asc()), len: 9, check: func(entry entry, _ int) {
181 mt.NotEqual(entry.String, "0.100000")
182 }},
183 entryTestCase{ds: ds.Where(goqu.C("string").IsNull()).Order(goqu.C("id").Asc()), len: 0, check: func(entry entry, _ int) {
184 mt.Fail("Should not have returned any records")
185 }},
186 )
187 }
188
189 func (mt *mysqlTest) TestQuery_Prepared() {
190 ds := mt.db.From("entry").Prepared(true)
191 floatVal := float64(0)
192 baseDate, err := time.Parse(
193 "2006-01-02 15:04:05",
194 "2015-02-22 18:19:55",
195 )
196 mt.NoError(err)
197 mt.assertEntries(
198 entryTestCase{ds: ds.Order(goqu.C("id").Asc()), len: 10, check: func(entry entry, index int) {
199 f := fmt.Sprintf("%f", floatVal)
200 mt.Equal(uint32(index+1), entry.ID)
201 mt.Equal(index, entry.Int)
202 mt.Equal(f, fmt.Sprintf("%f", entry.Float))
203 mt.Equal(f, entry.String)
204 mt.Equal([]byte(f), entry.Bytes)
205 mt.Equal(index%2 == 0, entry.Bool)
206 mt.Equal(baseDate.Add(time.Duration(index)*time.Hour).Unix(), entry.Time.Unix())
207 floatVal += float64(0.1)
208 }},
209 entryTestCase{ds: ds.Where(goqu.C("bool").IsTrue()).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
210 mt.True(entry.Bool)
211 }},
212 entryTestCase{ds: ds.Where(goqu.C("int").Gt(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
213 mt.True(entry.Int > 4)
214 }},
215 entryTestCase{ds: ds.Where(goqu.C("int").Gte(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
216 mt.True(entry.Int >= 5)
217 }},
218 entryTestCase{ds: ds.Where(goqu.C("int").Lt(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
219 mt.True(entry.Int < 5)
220 }},
221 entryTestCase{ds: ds.Where(goqu.C("int").Lte(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
222 mt.True(entry.Int <= 4)
223 }},
224 entryTestCase{ds: ds.Where(goqu.C("int").Between(goqu.Range(3, 6))).Order(goqu.C("id").Asc()), len: 4, check: func(entry entry, _ int) {
225 mt.True(entry.Int >= 3)
226 mt.True(entry.Int <= 6)
227 }},
228 entryTestCase{ds: ds.Where(goqu.C("string").Eq("0.100000")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
229 mt.Equal(entry.String, "0.100000")
230 }},
231 entryTestCase{ds: ds.Where(goqu.C("string").Like("0.1%")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
232 mt.Equal(entry.String, "0.100000")
233 }},
234 entryTestCase{ds: ds.Where(goqu.C("string").NotLike("0.1%")).Order(goqu.C("id").Asc()), len: 9, check: func(entry entry, _ int) {
235 mt.NotEqual(entry.String, "0.100000")
236 }},
237 entryTestCase{ds: ds.Where(goqu.C("string").IsNull()).Order(goqu.C("id").Asc()), len: 0, check: func(entry entry, _ int) {
238 mt.Fail("Should not have returned any records")
239 }},
240 )
241 }
242
243 func (mt *mysqlTest) TestQuery_ValueExpressions() {
244 type wrappedEntry struct {
245 entry
246 BoolValue bool `db:"bool_value"`
247 }
248 expectedDate, err := time.Parse("2006-01-02 15:04:05", "2015-02-22 19:19:55")
249 mt.NoError(err)
250 ds := mt.db.From("entry").Select(goqu.Star(), goqu.V(true).As("bool_value")).Where(goqu.Ex{"int": 1})
251 var we wrappedEntry
252 found, err := ds.ScanStruct(&we)
253 mt.NoError(err)
254 mt.True(found)
255 mt.Equal(wrappedEntry{
256 entry{2, 1, 0.100000, "0.100000", expectedDate, false, []byte("0.100000")},
257 true,
258 }, we)
259 }
260
261 func (mt *mysqlTest) TestCount() {
262 ds := mt.db.From("entry")
263 count, err := ds.Count()
264 mt.NoError(err)
265 mt.Equal(int64(10), count)
266 count, err = ds.Where(goqu.C("int").Gt(4)).Count()
267 mt.NoError(err)
268 mt.Equal(int64(5), count)
269 count, err = ds.Where(goqu.C("int").Gte(4)).Count()
270 mt.NoError(err)
271 mt.Equal(int64(6), count)
272 count, err = ds.Where(goqu.C("string").Like("0.1%")).Count()
273 mt.NoError(err)
274 mt.Equal(int64(1), count)
275 count, err = ds.Where(goqu.C("string").IsNull()).Count()
276 mt.NoError(err)
277 mt.Equal(int64(0), count)
278 }
279
280 func (mt *mysqlTest) TestInsert() {
281 ds := mt.db.From("entry")
282 now := time.Now()
283 e := entry{Int: 10, Float: 1.000000, String: "1.000000", Time: now, Bool: true, Bytes: []byte("1.000000")}
284 _, err := ds.Insert().Rows(e).Executor().Exec()
285 mt.NoError(err)
286
287 var insertedEntry entry
288 found, err := ds.Where(goqu.C("int").Eq(10)).ScanStruct(&insertedEntry)
289 mt.NoError(err)
290 mt.True(found)
291 mt.True(insertedEntry.ID > 0)
292
293 entries := []entry{
294 {Int: 11, Float: 1.100000, String: "1.100000", Time: now, Bool: false, Bytes: []byte("1.100000")},
295 {Int: 12, Float: 1.200000, String: "1.200000", Time: now, Bool: true, Bytes: []byte("1.200000")},
296 {Int: 13, Float: 1.300000, String: "1.300000", Time: now, Bool: false, Bytes: []byte("1.300000")},
297 {Int: 14, Float: 1.400000, String: "1.400000", Time: now, Bool: true, Bytes: []byte("1.400000")},
298 }
299 _, err = ds.Insert().Rows(entries).Executor().Exec()
300 mt.NoError(err)
301
302 var newEntries []entry
303 mt.NoError(ds.Where(goqu.C("int").In([]uint32{11, 12, 13, 14})).ScanStructs(&newEntries))
304 mt.Len(newEntries, 4)
305 for i, e := range newEntries {
306 mt.Equal(entries[i].Int, e.Int)
307 mt.Equal(entries[i].Float, e.Float)
308 mt.Equal(entries[i].String, e.String)
309 mt.Equal(entries[i].Time.UTC().Format(mysql.DialectOptions().TimeFormat), e.Time.Format(mysql.DialectOptions().TimeFormat))
310 mt.Equal(entries[i].Bool, e.Bool)
311 mt.Equal(entries[i].Bytes, e.Bytes)
312 }
313
314 _, err = ds.Insert().Rows(
315 entry{Int: 15, Float: 1.500000, String: "1.500000", Time: now, Bool: false, Bytes: []byte("1.500000")},
316 entry{Int: 16, Float: 1.600000, String: "1.600000", Time: now, Bool: true, Bytes: []byte("1.600000")},
317 entry{Int: 17, Float: 1.700000, String: "1.700000", Time: now, Bool: false, Bytes: []byte("1.700000")},
318 entry{Int: 18, Float: 1.800000, String: "1.800000", Time: now, Bool: true, Bytes: []byte("1.800000")},
319 ).Executor().Exec()
320 mt.NoError(err)
321
322 newEntries = newEntries[0:0]
323 mt.NoError(ds.Where(goqu.C("int").In([]uint32{15, 16, 17, 18})).ScanStructs(&newEntries))
324 mt.Len(newEntries, 4)
325 }
326
327 func (mt *mysqlTest) TestInsertReturning() {
328 ds := mt.db.From("entry")
329 now := time.Now()
330 e := entry{Int: 10, Float: 1.000000, String: "1.000000", Time: now, Bool: true, Bytes: []byte("1.000000")}
331 _, err := ds.Insert().Rows(e).Returning(goqu.Star()).Executor().ScanStruct(&e)
332 mt.Error(err)
333 }
334
335 func (mt *mysqlTest) TestUpdate() {
336 ds := mt.db.From("entry")
337 var e entry
338 found, err := ds.Where(goqu.C("int").Eq(9)).Select("id").ScanStruct(&e)
339 mt.NoError(err)
340 mt.True(found)
341 e.Int = 11
342 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Update().Set(e).Executor().Exec()
343 mt.NoError(err)
344
345 count, err := ds.Where(goqu.C("int").Eq(11)).Count()
346 mt.NoError(err)
347 mt.Equal(int64(1), count)
348 }
349
350 func (mt *mysqlTest) TestUpdateReturning() {
351 ds := mt.db.From("entry")
352 var id uint32
353 _, err := ds.Where(goqu.C("int").Eq(11)).
354 Update().
355 Set(goqu.Record{"int": 9}).
356 Returning("id").
357 Executor().ScanVal(&id)
358 mt.Error(err)
359 mt.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=mysql]")
360 }
361
362 func (mt *mysqlTest) TestDelete() {
363 ds := mt.db.From("entry")
364 var e entry
365 found, err := ds.Where(goqu.C("int").Eq(9)).Select("id").ScanStruct(&e)
366 mt.NoError(err)
367 mt.True(found)
368 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Delete().Executor().Exec()
369 mt.NoError(err)
370
371 count, err := ds.Count()
372 mt.NoError(err)
373 mt.Equal(int64(9), count)
374
375 var id uint32
376 found, err = ds.Where(goqu.C("id").Eq(e.ID)).ScanVal(&id)
377 mt.NoError(err)
378 mt.False(found)
379
380 e = entry{}
381 found, err = ds.Where(goqu.C("int").Eq(8)).Select("id").ScanStruct(&e)
382 mt.NoError(err)
383 mt.True(found)
384 mt.NotEqual(0, e.ID)
385
386 id = 0
387 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Delete().Returning("id").Executor().ScanVal(&id)
388 mt.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=mysql]")
389 }
390
391 func (mt *mysqlTest) TestInsertIgnore() {
392 ds := mt.db.From("entry")
393 now := time.Now()
394
395
396 entries := []entry{
397 {Int: 8, Float: 6.100000, String: "6.100000", Time: now, Bytes: []byte("6.100000")},
398 {Int: 9, Float: 7.200000, String: "7.200000", Time: now, Bytes: []byte("7.200000")},
399 {Int: 10, Float: 7.200000, String: "7.200000", Time: now, Bytes: []byte("7.200000")},
400 }
401 _, err := ds.Insert().Rows(entries).OnConflict(goqu.DoNothing()).Executor().Exec()
402 mt.NoError(err)
403
404 count, err := ds.Count()
405 mt.NoError(err)
406 mt.Equal(count, int64(11))
407 }
408
409 func (mt *mysqlTest) TestInsert_OnConflict() {
410 ds := mt.db.From("entry")
411 now := time.Now()
412
413
414 e := entry{Int: 10, Float: 1.100000, String: "1.100000", Time: now, Bool: false, Bytes: []byte("1.100000")}
415 _, err := ds.Insert().Rows(e).OnConflict(goqu.DoNothing()).Executor().Exec()
416 mt.NoError(err)
417
418
419 e = entry{Int: 10, Float: 2.100000, String: "2.100000", Time: now.Add(time.Hour * 100), Bool: false, Bytes: []byte("2.100000")}
420 _, err = ds.Insert().Rows(e).OnConflict(goqu.DoNothing()).Executor().Exec()
421 mt.NoError(err)
422
423
424 var entryActual entry
425 e2 := entry{Int: 10, String: "2.000000"}
426 _, err = ds.Insert().
427 Rows(e2).
428 OnConflict(goqu.DoUpdate("int", goqu.Record{"string": "upsert"})).
429 Executor().Exec()
430 mt.NoError(err)
431 _, err = ds.Where(goqu.C("int").Eq(10)).ScanStruct(&entryActual)
432 mt.NoError(err)
433 mt.Equal("upsert", entryActual.String)
434
435
436 entries := []entry{
437 {Int: 8, Float: 6.100000, String: "6.100000", Time: now, Bytes: []byte("6.100000")},
438 {Int: 9, Float: 7.200000, String: "7.200000", Time: now, Bytes: []byte("7.200000")},
439 }
440 _, err = ds.Insert().
441 Rows(entries).
442 OnConflict(goqu.DoUpdate("int", goqu.Record{"string": "upsert"}).Where(goqu.C("int").Eq(9))).
443 Executor().Exec()
444 mt.EqualError(err, "goqu: dialect does not support upsert with where clause [dialect=mysql]")
445 }
446
447 func (mt *mysqlTest) TestWindowFunction() {
448 var version string
449 ok, err := mt.db.Select(goqu.Func("version")).ScanVal(&version)
450 mt.NoError(err)
451 mt.True(ok)
452
453 fields := strings.Split(version, ".")
454 mt.True(len(fields) > 0)
455 major, err := strconv.Atoi(fields[0])
456 mt.NoError(err)
457 if major < 8 {
458
459 fmt.Printf("SKIPPING MYSQL WINDOW FUNCTION TEST BECAUSE VERSION IS < 8 [mysql_version:=%d]\n", major)
460 return
461 }
462
463 ds := mt.db.From("entry").
464 Select("int", goqu.ROW_NUMBER().OverName(goqu.I("w")).As("id")).
465 Window(goqu.W("w").OrderBy(goqu.I("int").Desc()))
466
467 var entries []entry
468 mt.NoError(ds.WithDialect("mysql8").ScanStructs(&entries))
469
470 mt.Equal([]entry{
471 {Int: 9, ID: 1},
472 {Int: 8, ID: 2},
473 {Int: 7, ID: 3},
474 {Int: 6, ID: 4},
475 {Int: 5, ID: 5},
476 {Int: 4, ID: 6},
477 {Int: 3, ID: 7},
478 {Int: 2, ID: 8},
479 {Int: 1, ID: 9},
480 {Int: 0, ID: 10},
481 }, entries)
482
483 mt.Error(ds.WithDialect("mysql").ScanStructs(&entries), "goqu: adapter does not support window function clause")
484 }
485
486 func (mt *mysqlTest) TestInsertFromSelect() {
487 ds := mt.db.From("entry")
488
489 subquery := goqu.Select(
490 goqu.V(11),
491 goqu.V(11),
492 goqu.C("float"),
493 goqu.C("string"),
494 goqu.C("time"),
495 goqu.C("bool"),
496 goqu.C("bytes"),
497 ).From(goqu.T("entry")).Where(goqu.C("int").Eq(9))
498
499 query := ds.Insert().Cols().FromQuery(subquery)
500 _, _, err := query.ToSQL()
501
502 mt.NoError(err)
503 _, err = query.Executor().Exec()
504 mt.NoError(err)
505 }
506
507 func TestMysqlSuite(t *testing.T) {
508 suite.Run(t, new(mysqlTest))
509 }
510
View as plain text