1 package sqlserver_test
2
3 import (
4 "database/sql"
5 "fmt"
6 "os"
7 "testing"
8 "time"
9
10 "github.com/doug-martin/goqu/v9/dialect/mysql"
11
12 _ "github.com/denisenkom/go-mssqldb"
13 "github.com/doug-martin/goqu/v9"
14 _ "github.com/doug-martin/goqu/v9/dialect/sqlserver"
15 "github.com/stretchr/testify/suite"
16 )
17
// SQL used by SetupTest to rebuild the "entry" table before every test.
//
// NOTE(review): createTable/dropTable quote identifiers with double quotes
// while insertDefaultRecords uses [brackets]; the double-quoted form
// presumably relies on QUOTED_IDENTIFIER being ON for the connection — confirm.
const (
	// dropTable removes the "entry" table if a previous run left it behind.
	dropTable = "DROP TABLE IF EXISTS \"entry\";"
	// createTable creates "entry" with an IDENTITY primary key "id", a UNIQUE
	// "int" column, and one column per scalar type exercised by the tests.
	createTable = "CREATE TABLE \"entry\" (" +
		"\"id\" INT NOT NULL IDENTITY(1,1)," +
		"\"int\" INT NOT NULL UNIQUE," +
		"\"float\" FLOAT NOT NULL ," +
		"\"string\" VARCHAR(255) NOT NULL ," +
		"\"time\" DATETIME NOT NULL ," +
		"\"bool\" BIT NOT NULL ," +
		"\"bytes\" VARBINARY(100) NOT NULL ," +
		"PRIMARY KEY (\"id\") );"
	// insertDefaultRecords inserts ten rows with int 0..9. For row i:
	//   - float is i/10, and string is that value formatted to six decimals;
	//   - time is 2015-02-22 18:19:55 plus i hours;
	//   - bool is 1 for even i and 0 for odd i;
	//   - bytes is the string value converted to BINARY(8).
	insertDefaultRecords = "INSERT INTO [entry] ([int], [float], [string], [time], [bool], [bytes]) VALUES" +
		"(0, 0.000000, '0.000000', '2015-02-22 18:19:55', 1, CONVERT(BINARY(8), '0.000000'))," +
		"(1, 0.100000, '0.100000', '2015-02-22 19:19:55', 0, CONVERT(BINARY(8), '0.100000'))," +
		"(2, 0.200000, '0.200000', '2015-02-22 20:19:55', 1, CONVERT(BINARY(8), '0.200000'))," +
		"(3, 0.300000, '0.300000', '2015-02-22 21:19:55', 0, CONVERT(BINARY(8), '0.300000'))," +
		"(4, 0.400000, '0.400000', '2015-02-22 22:19:55', 1, CONVERT(BINARY(8), '0.400000'))," +
		"(5, 0.500000, '0.500000', '2015-02-22 23:19:55', 0, CONVERT(BINARY(8), '0.500000'))," +
		"(6, 0.600000, '0.600000', '2015-02-23 00:19:55', 1, CONVERT(BINARY(8), '0.600000'))," +
		"(7, 0.700000, '0.700000', '2015-02-23 01:19:55', 0, CONVERT(BINARY(8), '0.700000'))," +
		"(8, 0.800000, '0.800000', '2015-02-23 02:19:55', 1, CONVERT(BINARY(8), '0.800000'))," +
		"(9, 0.900000, '0.900000', '2015-02-23 03:19:55', 0, CONVERT(BINARY(8), '0.900000'));"
)
41
// defaultDBURI is the connection string SetupSuite uses when the
// SQLSERVER_URI environment variable is empty or unset.
const defaultDBURI = "sqlserver://sa:qwe123QWE@127.0.0.1:1433?database=master&connection+timeout=30"
43
44 type (
45 sqlserverTest struct {
46 suite.Suite
47 db *goqu.Database
48 }
49 entry struct {
50 ID uint32 `db:"id" goqu:"skipinsert,skipupdate"`
51 Int int `db:"int"`
52 Float float64 `db:"float"`
53 String string `db:"string"`
54 Time time.Time `db:"time"`
55 Bool bool `db:"bool"`
56 Bytes []byte `db:"bytes"`
57 }
58 entryTestCase struct {
59 ds *goqu.SelectDataset
60 len int
61 check func(entry entry, index int)
62 err string
63 }
64 )
65
66 func (sst *sqlserverTest) assertEntries(cases ...entryTestCase) {
67 for i, c := range cases {
68 var entries []entry
69 err := c.ds.ScanStructs(&entries)
70 if c.err == "" {
71 sst.NoError(err, "test case %d failed", i)
72 } else {
73 sst.EqualError(err, c.err, "test case %d failed", i)
74 }
75 sst.Len(entries, c.len)
76 for index, entry := range entries {
77 c.check(entry, index)
78 }
79 }
80 }
81
82 func (sst *sqlserverTest) SetupSuite() {
83 dbURI := os.Getenv("SQLSERVER_URI")
84 if dbURI == "" {
85 dbURI = defaultDBURI
86 }
87 db, err := sql.Open("sqlserver", dbURI)
88 if err != nil {
89 panic(err.Error())
90 }
91 sst.db = goqu.New("sqlserver", db)
92 }
93
94 func (sst *sqlserverTest) SetupTest() {
95 if _, err := sst.db.Exec(dropTable); err != nil {
96 panic(err)
97 }
98 if _, err := sst.db.Exec(createTable); err != nil {
99 panic(err)
100 }
101 if _, err := sst.db.Exec(insertDefaultRecords); err != nil {
102 panic(err)
103 }
104 }
105
106 func (sst *sqlserverTest) TestToSQL() {
107 ds := sst.db.From("entry")
108 s, _, err := ds.Select("id", "float", "string", "time", "bool").ToSQL()
109 sst.NoError(err)
110 sst.Equal("SELECT \"id\", \"float\", \"string\", \"time\", \"bool\" FROM \"entry\"", s)
111
112 s, _, err = ds.Where(goqu.C("int").Eq(10)).ToSQL()
113 sst.NoError(err)
114 sst.Equal("SELECT * FROM \"entry\" WHERE (\"int\" = 10)", s)
115
116 s, args, err := ds.Prepared(true).Where(goqu.L("? = ?", goqu.C("int"), 10)).ToSQL()
117 sst.NoError(err)
118 sst.Equal([]interface{}{int64(10)}, args)
119 sst.Equal("SELECT * FROM \"entry\" WHERE \"int\" = @p1", s)
120 }
121
122 func (sst *sqlserverTest) TestQuery() {
123 ds := sst.db.From("entry")
124 floatVal := float64(0)
125 baseDate, err := time.Parse(
126 "2006-01-02 15:04:05",
127 "2015-02-22 18:19:55",
128 )
129 sst.NoError(err)
130 sst.assertEntries(
131 entryTestCase{ds: ds.Order(goqu.C("id").Asc()), len: 10, check: func(entry entry, index int) {
132 f := fmt.Sprintf("%f", floatVal)
133 sst.Equal(uint32(index+1), entry.ID)
134 sst.Equal(index, entry.Int)
135 sst.Equal(f, fmt.Sprintf("%f", entry.Float))
136 sst.Equal(f, entry.String)
137 sst.Equal([]byte(f), entry.Bytes)
138 sst.Equal(index%2 == 0, entry.Bool)
139 sst.Equal(baseDate.Add(time.Duration(index)*time.Hour).Unix(), entry.Time.Unix())
140 floatVal += float64(0.1)
141 }},
142 entryTestCase{
143 ds: ds.Where(goqu.C("bool").IsTrue()).Order(goqu.C("id").Asc()),
144 err: "goqu: boolean data type is not supported by dialect \"sqlserver\"",
145 },
146 entryTestCase{ds: ds.Where(goqu.C("int").Gt(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
147 sst.True(entry.Int > 4)
148 }},
149 entryTestCase{ds: ds.Where(goqu.C("int").Gte(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
150 sst.True(entry.Int >= 5)
151 }},
152 entryTestCase{ds: ds.Where(goqu.C("int").Lt(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
153 sst.True(entry.Int < 5)
154 }},
155 entryTestCase{ds: ds.Where(goqu.C("int").Lte(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
156 sst.True(entry.Int <= 4)
157 }},
158 entryTestCase{ds: ds.Where(goqu.C("int").Between(goqu.Range(3, 6))).Order(goqu.C("id").Asc()), len: 4, check: func(entry entry, _ int) {
159 sst.True(entry.Int >= 3)
160 sst.True(entry.Int <= 6)
161 }},
162 entryTestCase{ds: ds.Where(goqu.C("string").Eq("0.100000")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
163 sst.Equal(entry.String, "0.100000")
164 }},
165 entryTestCase{ds: ds.Where(goqu.C("string").Like("0.1%")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
166 sst.Equal(entry.String, "0.100000")
167 }},
168 entryTestCase{ds: ds.Where(goqu.C("string").NotLike("0.1%")).Order(goqu.C("id").Asc()), len: 9, check: func(entry entry, _ int) {
169 sst.NotEqual(entry.String, "0.100000")
170 }},
171 entryTestCase{ds: ds.Where(goqu.C("string").IsNull()).Order(goqu.C("id").Asc()), len: 0, check: func(entry entry, _ int) {
172 sst.Fail("Should not have returned any records")
173 }},
174 )
175 }
176
177 func (sst *sqlserverTest) TestQuery_Prepared() {
178 ds := sst.db.From("entry").Prepared(true)
179 floatVal := float64(0)
180 baseDate, err := time.Parse(
181 "2006-01-02 15:04:05",
182 "2015-02-22 18:19:55",
183 )
184 sst.NoError(err)
185 sst.assertEntries(
186 entryTestCase{ds: ds.Order(goqu.C("id").Asc()), len: 10, check: func(entry entry, index int) {
187 f := fmt.Sprintf("%f", floatVal)
188 sst.Equal(uint32(index+1), entry.ID)
189 sst.Equal(index, entry.Int)
190 sst.Equal(f, fmt.Sprintf("%f", entry.Float))
191 sst.Equal(f, entry.String)
192 sst.Equal([]byte(f), entry.Bytes)
193 sst.Equal(index%2 == 0, entry.Bool)
194 sst.Equal(baseDate.Add(time.Duration(index)*time.Hour).Unix(), entry.Time.Unix())
195 floatVal += float64(0.1)
196 }},
197 entryTestCase{
198 ds: ds.Where(goqu.C("bool").IsTrue()).Order(goqu.C("id").Asc()),
199 err: "goqu: boolean data type is not supported by dialect \"sqlserver\"",
200 },
201 entryTestCase{ds: ds.Where(goqu.C("int").Gt(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
202 sst.True(entry.Int > 4)
203 }},
204 entryTestCase{ds: ds.Where(goqu.C("int").Gte(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
205 sst.True(entry.Int >= 5)
206 }},
207 entryTestCase{ds: ds.Where(goqu.C("int").Lt(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
208 sst.True(entry.Int < 5)
209 }},
210 entryTestCase{ds: ds.Where(goqu.C("int").Lte(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
211 sst.True(entry.Int <= 4)
212 }},
213 entryTestCase{ds: ds.Where(goqu.C("int").Between(goqu.Range(3, 6))).Order(goqu.C("id").Asc()), len: 4, check: func(entry entry, _ int) {
214 sst.True(entry.Int >= 3)
215 sst.True(entry.Int <= 6)
216 }},
217 entryTestCase{ds: ds.Where(goqu.C("string").Eq("0.100000")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
218 sst.Equal(entry.String, "0.100000")
219 }},
220 entryTestCase{ds: ds.Where(goqu.C("string").Like("0.1%")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
221 sst.Equal(entry.String, "0.100000")
222 }},
223 entryTestCase{ds: ds.Where(goqu.C("string").NotLike("0.1%")).Order(goqu.C("id").Asc()), len: 9, check: func(entry entry, _ int) {
224 sst.NotEqual(entry.String, "0.100000")
225 }},
226 entryTestCase{ds: ds.Where(goqu.C("string").IsNull()).Order(goqu.C("id").Asc()), len: 0, check: func(entry entry, _ int) {
227 sst.Fail("Should not have returned any records")
228 }},
229 )
230 }
231
232 func (sst *sqlserverTest) TestQuery_ValueExpressions() {
233 type wrappedEntry struct {
234 entry
235 BoolValue bool `db:"bool_value"`
236 }
237 expectedDate, err := time.Parse("2006-01-02 15:04:05", "2015-02-22 19:19:55")
238 sst.NoError(err)
239 ds := sst.db.From("entry").Select(goqu.Star(), goqu.V(true).As("bool_value")).Where(goqu.Ex{"int": 1})
240 var we wrappedEntry
241 found, err := ds.ScanStruct(&we)
242 sst.NoError(err)
243 sst.True(found)
244 sst.Equal(wrappedEntry{
245 entry{2, 1, 0.100000, "0.100000", expectedDate, false, []byte("0.100000")},
246 true,
247 }, we)
248 }
249
250 func (sst *sqlserverTest) TestCount() {
251 ds := sst.db.From("entry")
252 count, err := ds.Count()
253 sst.NoError(err)
254 sst.Equal(int64(10), count)
255 count, err = ds.Where(goqu.C("int").Gt(4)).Count()
256 sst.NoError(err)
257 sst.Equal(int64(5), count)
258 count, err = ds.Where(goqu.C("int").Gte(4)).Count()
259 sst.NoError(err)
260 sst.Equal(int64(6), count)
261 count, err = ds.Where(goqu.C("string").Like("0.1%")).Count()
262 sst.NoError(err)
263 sst.Equal(int64(1), count)
264 count, err = ds.Where(goqu.C("string").IsNull()).Count()
265 sst.NoError(err)
266 sst.Equal(int64(0), count)
267 }
268
269 func (sst *sqlserverTest) TestLimitOffset() {
270 ds := sst.db.From("entry").Where(goqu.C("id").Gte(1)).Limit(1)
271 var e entry
272 found, err := ds.ScanStruct(&e)
273 sst.NoError(err)
274 sst.True(found)
275 sst.Equal(uint32(1), e.ID)
276
277 ds = sst.db.From("entry").Where(goqu.C("id").Gte(1)).Order(goqu.C("id").Desc()).Limit(1)
278 found, err = ds.ScanStruct(&e)
279 sst.NoError(err)
280 sst.True(found)
281 sst.Equal(uint32(10), e.ID)
282
283 ds = sst.db.From("entry").Where(goqu.C("id").Gte(1)).Order(goqu.C("id").Asc()).Offset(1).Limit(1)
284 found, err = ds.ScanStruct(&e)
285 sst.NoError(err)
286 sst.True(found)
287 sst.Equal(uint32(2), e.ID)
288 }
289
290 func (sst *sqlserverTest) TestLimitOffsetParameterized() {
291 ds := sst.db.From("entry").Prepared(true).Where(goqu.C("id").Gte(1)).Limit(1)
292 var e entry
293 found, err := ds.ScanStruct(&e)
294 sst.NoError(err)
295 sst.True(found)
296 sst.Equal(uint32(1), e.ID)
297
298 ds = sst.db.From("entry").Prepared(true).Where(goqu.C("id").Gte(1)).Order(goqu.C("id").Desc()).Limit(1)
299 found, err = ds.ScanStruct(&e)
300 sst.NoError(err)
301 sst.True(found)
302 sst.Equal(uint32(10), e.ID)
303
304 ds = sst.db.From("entry").Prepared(true).Where(goqu.C("id").Gte(1)).Order(goqu.C("id").Asc()).Offset(1).Limit(1)
305 found, err = ds.ScanStruct(&e)
306 sst.NoError(err)
307 sst.True(found)
308 sst.Equal(uint32(2), e.ID)
309 }
310
311 func (sst *sqlserverTest) TestInsert() {
312 ds := sst.db.From("entry")
313 now := time.Now()
314 _, err := ds.Insert().Rows(goqu.Record{
315 "Int": 10,
316 "Float": 1.00000,
317 "String": "1.000000",
318 "Time": now,
319 "Bool": true,
320 "Bytes": goqu.Cast(goqu.V([]byte("1.000000")), "BINARY(8)"),
321 }).Executor().Exec()
322 sst.NoError(err)
323
324 var insertedEntry entry
325 found, err := ds.Where(goqu.C("int").Eq(10)).ScanStruct(&insertedEntry)
326 sst.NoError(err)
327 sst.True(found)
328 sst.True(insertedEntry.ID > 0)
329
330 entries := []goqu.Record{
331 {
332 "Int": 11, "Float": 1.100000, "String": "1.100000", "Time": now,
333 "Bool": false, "Bytes": goqu.Cast(goqu.V([]byte("1.100000")), "BINARY(8)"),
334 },
335 {
336 "Int": 12, "Float": 1.200000, "String": "1.200000", "Time": now,
337 "Bool": true, "Bytes": goqu.Cast(goqu.V([]byte("1.200000")), "BINARY(8)"),
338 },
339 {
340 "Int": 13, "Float": 1.300000, "String": "1.300000", "Time": now,
341 "Bool": false, "Bytes": goqu.Cast(goqu.V([]byte("1.300000")), "BINARY(8)"),
342 },
343 {
344 "Int": 14, "Float": 1.400000, "String": "1.400000", "Time": now,
345 "Bool": true, "Bytes": goqu.Cast(goqu.V([]byte("1.400000")), "BINARY(8)"),
346 },
347 }
348 _, err = ds.Insert().Rows(entries).Executor().Exec()
349 sst.NoError(err)
350
351 var newEntries []entry
352 sst.NoError(ds.Where(goqu.C("int").In([]uint32{11, 12, 13, 14})).ScanStructs(&newEntries))
353 sst.Len(newEntries, 4)
354 for i, e := range newEntries {
355 sst.Equal(entries[i]["Int"], e.Int)
356 sst.Equal(entries[i]["Float"], e.Float)
357 sst.Equal(entries[i]["String"], e.String)
358 sst.Equal(
359 entries[i]["Time"].(time.Time).UTC().Format(mysql.DialectOptions().TimeFormat),
360 e.Time.Format(mysql.DialectOptions().TimeFormat),
361 )
362 sst.Equal(entries[i]["Bool"], e.Bool)
363 sst.Equal([]byte(entries[i]["String"].(string)), e.Bytes)
364 }
365
366 _, err = ds.Insert().Rows(
367 goqu.Record{
368 "Int": 15, "Float": 1.500000, "String": "1.500000", "Time": now,
369 "Bool": false, "Bytes": goqu.Cast(goqu.V([]byte("1.500000")), "BINARY(8)"),
370 },
371 goqu.Record{
372 "Int": 16, "Float": 1.600000, "String": "1.600000", "Time": now,
373 "Bool": true, "Bytes": goqu.Cast(goqu.V([]byte("1.600000")), "BINARY(8)"),
374 },
375 goqu.Record{
376 "Int": 17, "Float": 1.700000, "String": "1.700000", "Time": now,
377 "Bool": false, "Bytes": goqu.Cast(goqu.V([]byte("1.700000")), "BINARY(8)"),
378 },
379 goqu.Record{
380 "Int": 18, "Float": 1.800000, "String": "1.800000", "Time": now,
381 "Bool": true, "Bytes": goqu.Cast(goqu.V([]byte("1.800000")), "BINARY(8)"),
382 },
383 ).Executor().Exec()
384 sst.NoError(err)
385
386 newEntries = newEntries[0:0]
387 sst.NoError(ds.Where(goqu.C("int").In([]uint32{15, 16, 17, 18})).ScanStructs(&newEntries))
388 sst.Len(newEntries, 4)
389 }
390
391 func (sst *sqlserverTest) TestInsertReturningProducesError() {
392 ds := sst.db.From("entry")
393 now := time.Now()
394 e := entry{Int: 10, Float: 1.000000, String: "1.000000", Time: now, Bool: true, Bytes: []byte("1.000000")}
395 _, err := ds.Insert().Rows(e).Returning(goqu.Star()).Executor().ScanStruct(&e)
396 sst.Error(err)
397 }
398
399 func (sst *sqlserverTest) TestUpdate() {
400 ds := sst.db.From("entry")
401 var e entry
402 found, err := ds.Where(goqu.C("int").Eq(9)).Select("id").ScanStruct(&e)
403 sst.NoError(err)
404 sst.True(found)
405 e.Int = 11
406 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Update().Set(goqu.Record{"Int": e.Int}).Executor().Exec()
407 sst.NoError(err)
408
409 count, err := ds.Where(goqu.C("int").Eq(11)).Count()
410 sst.NoError(err)
411 sst.Equal(int64(1), count)
412 }
413
414 func (sst *sqlserverTest) TestUpdateReturning() {
415 ds := sst.db.From("entry")
416 var id uint32
417 _, err := ds.Where(goqu.C("int").Eq(11)).
418 Update().
419 Set(goqu.Record{"int": 9}).
420 Returning("id").
421 Executor().ScanVal(&id)
422 sst.Error(err)
423 sst.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=sqlserver]")
424 }
425
426 func (sst *sqlserverTest) TestDelete() {
427 ds := sst.db.From("entry")
428 var e entry
429 found, err := ds.Where(goqu.C("int").Eq(9)).Select("id").ScanStruct(&e)
430 sst.NoError(err)
431 sst.True(found)
432 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Delete().Executor().Exec()
433 sst.NoError(err)
434
435 count, err := ds.Count()
436 sst.NoError(err)
437 sst.Equal(int64(9), count)
438
439 var id uint32
440 found, err = ds.Where(goqu.C("id").Eq(e.ID)).ScanVal(&id)
441 sst.NoError(err)
442 sst.False(found)
443
444 e = entry{}
445 found, err = ds.Where(goqu.C("int").Eq(8)).Select("id").ScanStruct(&e)
446 sst.NoError(err)
447 sst.True(found)
448 sst.NotEqual(0, e.ID)
449
450 id = 0
451 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Delete().Returning("id").Executor().ScanVal(&id)
452 sst.EqualError(err, "goqu: dialect does not support RETURNING clause [dialect=sqlserver]")
453 }
454
455 func (sst *sqlserverTest) TestInsertIgnoreNotSupported() {
456 ds := sst.db.From("entry")
457 now := time.Now()
458
459
460 entries := []goqu.Record{
461 {
462 "Int": 8, "Float": 6.100000, "String": "6.100000", "Bool": false, "Time": now,
463 "Bytes": goqu.Cast(goqu.V([]byte("6.100000")), "BINARY(8)"),
464 },
465 {
466 "Int": 9, "Float": 7.200000, "String": "7.200000", "Bool": false, "Time": now,
467 "Bytes": goqu.Cast(goqu.V([]byte("7.200000")), "BINARY(8)"),
468 },
469 {
470 "Int": 10, "Float": 7.200000, "String": "7.200000", "Bool": false, "Time": now,
471 "Bytes": goqu.Cast(goqu.V([]byte("7.200000")), "BINARY(8)"),
472 },
473 }
474 _, err := ds.Insert().Rows(entries).OnConflict(goqu.DoNothing()).Executor().Exec()
475 sst.Error(err)
476 sst.Contains(err.Error(), "Cannot insert duplicate key in object 'dbo.entry'. The duplicate key value is (8)")
477
478 count, err := ds.Count()
479 sst.NoError(err)
480 sst.Equal(count, int64(10))
481 }
482
483 func TestSqlServerSuite(t *testing.T) {
484 suite.Run(t, new(sqlserverTest))
485 }
486
View as plain text