1 package postgres_test
2
3 import (
4 "database/sql"
5 "fmt"
6 "os"
7 "testing"
8 "time"
9
10 "github.com/doug-martin/goqu/v9"
11
12 "github.com/lib/pq"
13 "github.com/stretchr/testify/suite"
14 )
15
// Fixtures and connection defaults for the postgres integration suite.
const (
	// schema is executed before every test by SetupTest. It drops and
	// recreates the "entry" table and seeds ten rows with int 0..9; the
	// float, string and bytes columns hold int/10 printed with six decimals,
	// bool alternates starting with TRUE, and time advances one hour per row
	// starting at 2015-02-22 18:19:55.
	schema = `
DROP TABLE IF EXISTS "entry";
CREATE TABLE "entry" (
"id" SERIAL PRIMARY KEY NOT NULL,
"int" INT NOT NULL UNIQUE,
"float" NUMERIC NOT NULL ,
"string" VARCHAR(45) NOT NULL ,
"time" TIMESTAMP NOT NULL ,
"bool" BOOL NOT NULL ,
"bytes" VARCHAR(45) NOT NULL);
INSERT INTO "entry" ("int", "float", "string", "time", "bool", "bytes") VALUES
(0, 0.000000, '0.000000', '2015-02-22T18:19:55.000000000-00:00', TRUE, '0.000000'),
(1, 0.100000, '0.100000', '2015-02-22T19:19:55.000000000-00:00', FALSE, '0.100000'),
(2, 0.200000, '0.200000', '2015-02-22T20:19:55.000000000-00:00', TRUE, '0.200000'),
(3, 0.300000, '0.300000', '2015-02-22T21:19:55.000000000-00:00', FALSE, '0.300000'),
(4, 0.400000, '0.400000', '2015-02-22T22:19:55.000000000-00:00', TRUE, '0.400000'),
(5, 0.500000, '0.500000', '2015-02-22T23:19:55.000000000-00:00', FALSE, '0.500000'),
(6, 0.600000, '0.600000', '2015-02-23T00:19:55.000000000-00:00', TRUE, '0.600000'),
(7, 0.700000, '0.700000', '2015-02-23T01:19:55.000000000-00:00', FALSE, '0.700000'),
(8, 0.800000, '0.800000', '2015-02-23T02:19:55.000000000-00:00', TRUE, '0.800000'),
(9, 0.900000, '0.900000', '2015-02-23T03:19:55.000000000-00:00', FALSE, '0.900000');
`

	// defaultDBURI is the connection URL used when PG_URI is unset.
	defaultDBURI = "postgres://postgres:@localhost:5435/goqupostgres?sslmode=disable"
)
40
41 type (
42 postgresTest struct {
43 suite.Suite
44 db *goqu.Database
45 }
46 entry struct {
47 ID uint32 `db:"id" goqu:"skipinsert,skipupdate"`
48 Int int `db:"int"`
49 Float float64 `db:"float"`
50 String string `db:"string"`
51 Time time.Time `db:"time"`
52 Bool bool `db:"bool"`
53 Bytes []byte `db:"bytes"`
54 }
55 entryTestCase struct {
56 ds *goqu.SelectDataset
57 len int
58 check func(entry entry, index int)
59 err string
60 }
61 )
62
63 func (pt *postgresTest) assertEntries(cases ...entryTestCase) {
64 for i, c := range cases {
65 var entries []entry
66 err := c.ds.ScanStructs(&entries)
67 if c.err == "" {
68 pt.NoError(err, "test case %d failed", i)
69 } else {
70 pt.EqualError(err, c.err, "test case %d failed", i)
71 }
72 pt.Len(entries, c.len)
73 for index, entry := range entries {
74 c.check(entry, index)
75 }
76 }
77 }
78
79 func (pt *postgresTest) SetupSuite() {
80 dbURI := os.Getenv("PG_URI")
81 if dbURI == "" {
82 dbURI = defaultDBURI
83 }
84 uri, err := pq.ParseURL(dbURI)
85 if err != nil {
86 panic(err)
87 }
88 db, err := sql.Open("postgres", uri)
89 if err != nil {
90 panic(err)
91 }
92 pt.db = goqu.New("postgres", db)
93 }
94
95 func (pt *postgresTest) SetupTest() {
96 if _, err := pt.db.Exec(schema); err != nil {
97 panic(err)
98 }
99 }
100
101 func (pt *postgresTest) TestToSQL() {
102 ds := pt.db.From("entry")
103 s, _, err := ds.Select("id", "float", "string", "time", "bool").ToSQL()
104 pt.NoError(err)
105 pt.Equal(`SELECT "id", "float", "string", "time", "bool" FROM "entry"`, s)
106
107 s, _, err = ds.Where(goqu.C("int").Eq(10)).ToSQL()
108 pt.NoError(err)
109 pt.Equal(`SELECT * FROM "entry" WHERE ("int" = 10)`, s)
110
111 s, args, err := ds.Prepared(true).Where(goqu.L("? = ?", goqu.C("int"), 10)).ToSQL()
112 pt.NoError(err)
113 pt.Equal([]interface{}{int64(10)}, args)
114 pt.Equal(`SELECT * FROM "entry" WHERE "int" = $1`, s)
115 }
116
117 func (pt *postgresTest) TestQuery() {
118 ds := pt.db.From("entry")
119 floatVal := float64(0)
120 baseDate, err := time.Parse(time.RFC3339Nano, "2015-02-22T18:19:55.000000000-00:00")
121 pt.NoError(err)
122 baseDate = baseDate.UTC()
123 pt.assertEntries(
124 entryTestCase{ds: ds.Order(goqu.C("id").Asc()), len: 10, check: func(entry entry, index int) {
125 f := fmt.Sprintf("%f", floatVal)
126 pt.Equal(uint32(index+1), entry.ID)
127 pt.Equal(index, entry.Int)
128 pt.Equal(f, fmt.Sprintf("%f", entry.Float))
129 pt.Equal(f, entry.String)
130 pt.Equal([]byte(f), entry.Bytes)
131 pt.Equal(index%2 == 0, entry.Bool)
132 pt.Equal(baseDate.Add(time.Duration(index)*time.Hour).Unix(), entry.Time.Unix())
133 floatVal += float64(0.1)
134 }},
135 entryTestCase{ds: ds.Where(goqu.C("bool").IsTrue()).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
136 pt.True(entry.Bool)
137 }},
138 entryTestCase{ds: ds.Where(goqu.C("int").Gt(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
139 pt.True(entry.Int > 4)
140 }},
141 entryTestCase{ds: ds.Where(goqu.C("int").Gte(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
142 pt.True(entry.Int >= 5)
143 }},
144 entryTestCase{ds: ds.Where(goqu.C("int").Lt(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
145 pt.True(entry.Int < 5)
146 }},
147 entryTestCase{ds: ds.Where(goqu.C("int").Lte(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
148 pt.True(entry.Int <= 4)
149 }},
150 entryTestCase{ds: ds.Where(goqu.C("int").Between(goqu.Range(3, 6))).Order(goqu.C("id").Asc()), len: 4, check: func(entry entry, _ int) {
151 pt.True(entry.Int >= 3)
152 pt.True(entry.Int <= 6)
153 }},
154 entryTestCase{ds: ds.Where(goqu.C("string").Eq("0.100000")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
155 pt.Equal(entry.String, "0.100000")
156 }},
157 entryTestCase{ds: ds.Where(goqu.C("string").Like("0.1%")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
158 pt.Equal(entry.String, "0.100000")
159 }},
160 entryTestCase{ds: ds.Where(goqu.C("string").NotLike("0.1%")).Order(goqu.C("id").Asc()), len: 9, check: func(entry entry, _ int) {
161 pt.NotEqual(entry.String, "0.100000")
162 }},
163 entryTestCase{ds: ds.Where(goqu.C("string").IsNull()).Order(goqu.C("id").Asc()), len: 0, check: func(entry entry, _ int) {
164 pt.Fail("Should not have returned any records")
165 }},
166 )
167 }
168
169 func (pt *postgresTest) TestQuery_Prepared() {
170 ds := pt.db.From("entry").Prepared(true)
171 floatVal := float64(0)
172 baseDate, err := time.Parse(time.RFC3339Nano, "2015-02-22T18:19:55.000000000-00:00")
173 pt.NoError(err)
174 baseDate = baseDate.UTC()
175 pt.assertEntries(
176 entryTestCase{ds: ds.Order(goqu.C("id").Asc()), len: 10, check: func(entry entry, index int) {
177 f := fmt.Sprintf("%f", floatVal)
178 pt.Equal(uint32(index+1), entry.ID)
179 pt.Equal(index, entry.Int)
180 pt.Equal(f, fmt.Sprintf("%f", entry.Float))
181 pt.Equal(f, entry.String)
182 pt.Equal([]byte(f), entry.Bytes)
183 pt.Equal(index%2 == 0, entry.Bool)
184 pt.Equal(baseDate.Add(time.Duration(index)*time.Hour).Unix(), entry.Time.Unix())
185 floatVal += float64(0.1)
186 }},
187 entryTestCase{ds: ds.Where(goqu.C("bool").IsTrue()).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
188 pt.True(entry.Bool)
189 }},
190 entryTestCase{ds: ds.Where(goqu.C("int").Gt(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
191 pt.True(entry.Int > 4)
192 }},
193 entryTestCase{ds: ds.Where(goqu.C("int").Gte(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
194 pt.True(entry.Int >= 5)
195 }},
196 entryTestCase{ds: ds.Where(goqu.C("int").Lt(5)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
197 pt.True(entry.Int < 5)
198 }},
199 entryTestCase{ds: ds.Where(goqu.C("int").Lte(4)).Order(goqu.C("id").Asc()), len: 5, check: func(entry entry, _ int) {
200 pt.True(entry.Int <= 4)
201 }},
202 entryTestCase{ds: ds.Where(goqu.C("int").Between(goqu.Range(3, 6))).Order(goqu.C("id").Asc()), len: 4, check: func(entry entry, _ int) {
203 pt.True(entry.Int >= 3)
204 pt.True(entry.Int <= 6)
205 }},
206 entryTestCase{ds: ds.Where(goqu.C("string").Eq("0.100000")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
207 pt.Equal(entry.String, "0.100000")
208 }},
209 entryTestCase{ds: ds.Where(goqu.C("string").Like("0.1%")).Order(goqu.C("id").Asc()), len: 1, check: func(entry entry, _ int) {
210 pt.Equal(entry.String, "0.100000")
211 }},
212 entryTestCase{ds: ds.Where(goqu.C("string").NotLike("0.1%")).Order(goqu.C("id").Asc()), len: 9, check: func(entry entry, _ int) {
213 pt.NotEqual(entry.String, "0.100000")
214 }},
215 entryTestCase{ds: ds.Where(goqu.C("string").IsNull()).Order(goqu.C("id").Asc()), len: 0, check: func(entry entry, _ int) {
216 pt.Fail("Should not have returned any records")
217 }},
218 )
219 }
220
221 func (pt *postgresTest) TestQuery_ValueExpressions() {
222 type wrappedEntry struct {
223 entry
224 BoolValue bool `db:"bool_value"`
225 }
226 expectedDate, err := time.Parse(time.RFC3339Nano, "2015-02-22T19:19:55.000000000-00:00")
227 pt.NoError(err)
228 ds := pt.db.From("entry").Select(goqu.Star(), goqu.V(true).As("bool_value")).Where(goqu.Ex{"int": 1})
229 var we wrappedEntry
230 found, err := ds.ScanStruct(&we)
231 pt.NoError(err)
232 pt.True(found)
233 pt.Equal(1, we.Int)
234 pt.Equal(0.100000, we.Float)
235 pt.Equal("0.100000", we.String)
236 pt.Equal(expectedDate.Unix(), we.Time.Unix())
237 pt.Equal(false, we.Bool)
238 pt.Equal([]byte("0.100000"), we.Bytes)
239 pt.True(we.BoolValue)
240 }
241
242 func (pt *postgresTest) TestCount() {
243 ds := pt.db.From("entry")
244 count, err := ds.Count()
245 pt.NoError(err)
246 pt.Equal(int64(10), count)
247 count, err = ds.Where(goqu.C("int").Gt(4)).Count()
248 pt.NoError(err)
249 pt.Equal(int64(5), count)
250 count, err = ds.Where(goqu.C("int").Gte(4)).Count()
251 pt.NoError(err)
252 pt.Equal(int64(6), count)
253 count, err = ds.Where(goqu.C("string").Like("0.1%")).Count()
254 pt.NoError(err)
255 pt.Equal(int64(1), count)
256 count, err = ds.Where(goqu.C("string").IsNull()).Count()
257 pt.NoError(err)
258 pt.Equal(int64(0), count)
259 }
260
261 func (pt *postgresTest) TestInsert() {
262 ds := pt.db.From("entry")
263 now := time.Now()
264 e := entry{Int: 10, Float: 1.000000, String: "1.000000", Time: now, Bool: true, Bytes: []byte("1.000000")}
265 _, err := ds.Insert().Rows(e).Executor().Exec()
266 pt.NoError(err)
267
268 var insertedEntry entry
269 found, err := ds.Where(goqu.C("int").Eq(10)).ScanStruct(&insertedEntry)
270 pt.NoError(err)
271 pt.True(found)
272 pt.True(insertedEntry.ID > 0)
273
274 entries := []entry{
275 {Int: 11, Float: 1.100000, String: "1.100000", Time: now, Bool: false, Bytes: []byte("1.100000")},
276 {Int: 12, Float: 1.200000, String: "1.200000", Time: now, Bool: true, Bytes: []byte("1.200000")},
277 {Int: 13, Float: 1.300000, String: "1.300000", Time: now, Bool: false, Bytes: []byte("1.300000")},
278 {Int: 14, Float: 1.400000, String: "1.400000", Time: now, Bool: true, Bytes: []byte("1.400000")},
279 }
280 _, err = ds.Insert().Rows(entries).Executor().Exec()
281 pt.NoError(err)
282
283 var newEntries []entry
284
285 pt.NoError(ds.Where(goqu.C("int").In([]uint32{11, 12, 13, 14})).ScanStructs(&newEntries))
286 pt.Len(newEntries, 4)
287 for i, e := range newEntries {
288 pt.Equal(entries[i].Int, e.Int)
289 pt.Equal(entries[i].Float, e.Float)
290 pt.Equal(entries[i].String, e.String)
291 pt.Equal(entries[i].Time.Unix(), e.Time.Unix())
292 pt.Equal(entries[i].Bool, e.Bool)
293 pt.Equal(entries[i].Bytes, e.Bytes)
294 }
295
296 _, err = ds.Insert().Rows(
297 entry{Int: 15, Float: 1.500000, String: "1.500000", Time: now, Bool: false, Bytes: []byte("1.500000")},
298 entry{Int: 16, Float: 1.600000, String: "1.600000", Time: now, Bool: true, Bytes: []byte("1.600000")},
299 entry{Int: 17, Float: 1.700000, String: "1.700000", Time: now, Bool: false, Bytes: []byte("1.700000")},
300 entry{Int: 18, Float: 1.800000, String: "1.800000", Time: now, Bool: true, Bytes: []byte("1.800000")},
301 ).Executor().Exec()
302 pt.NoError(err)
303
304 newEntries = newEntries[0:0]
305 pt.NoError(ds.Where(goqu.C("int").In([]uint32{15, 16, 17, 18})).ScanStructs(&newEntries))
306 pt.Len(newEntries, 4)
307 }
308
309 func (pt *postgresTest) TestInsertReturning() {
310 ds := pt.db.From("entry")
311 now := time.Now()
312 e := entry{Int: 10, Float: 1.000000, String: "1.000000", Time: now, Bool: true, Bytes: []byte("1.000000")}
313 found, err := ds.Insert().Rows(e).Returning(goqu.Star()).Executor().ScanStruct(&e)
314 pt.NoError(err)
315 pt.True(found)
316 pt.True(e.ID > 0)
317
318 var ids []uint32
319 pt.NoError(ds.Insert().Rows([]entry{
320 {Int: 11, Float: 1.100000, String: "1.100000", Time: now, Bool: false, Bytes: []byte("1.100000")},
321 {Int: 12, Float: 1.200000, String: "1.200000", Time: now, Bool: true, Bytes: []byte("1.200000")},
322 {Int: 13, Float: 1.300000, String: "1.300000", Time: now, Bool: false, Bytes: []byte("1.300000")},
323 {Int: 14, Float: 1.400000, String: "1.400000", Time: now, Bool: true, Bytes: []byte("1.400000")},
324 }).Returning("id").Executor().ScanVals(&ids))
325 pt.Len(ids, 4)
326 for _, id := range ids {
327 pt.True(id > 0)
328 }
329
330 var ints []int64
331 pt.NoError(ds.Insert().Rows(
332 entry{Int: 15, Float: 1.500000, String: "1.500000", Time: now, Bool: false, Bytes: []byte("1.500000")},
333 entry{Int: 16, Float: 1.600000, String: "1.600000", Time: now, Bool: true, Bytes: []byte("1.600000")},
334 entry{Int: 17, Float: 1.700000, String: "1.700000", Time: now, Bool: false, Bytes: []byte("1.700000")},
335 entry{Int: 18, Float: 1.800000, String: "1.800000", Time: now, Bool: true, Bytes: []byte("1.800000")},
336 ).Returning("int").Executor().ScanVals(&ints))
337 pt.True(found)
338 pt.Equal(ints, []int64{15, 16, 17, 18})
339 }
340
341 func (pt *postgresTest) TestUpdate() {
342 ds := pt.db.From("entry")
343 var e entry
344 found, err := ds.Where(goqu.C("int").Eq(9)).Select("id").ScanStruct(&e)
345 pt.NoError(err)
346 pt.True(found)
347 e.Int = 11
348 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Update().Set(e).Executor().Exec()
349 pt.NoError(err)
350
351 count, err := ds.Where(goqu.C("int").Eq(11)).Count()
352 pt.NoError(err)
353 pt.Equal(int64(1), count)
354
355 var id uint32
356 found, err = ds.Where(goqu.C("int").Eq(11)).
357 Update().
358 Set(goqu.Record{"int": 9}).
359 Returning("id").Executor().ScanVal(&id)
360 pt.NoError(err)
361 pt.True(found)
362 pt.Equal(id, e.ID)
363 }
364
365 func (pt *postgresTest) TestUpdateSQL_multipleTables() {
366 ds := pt.db.Update("test")
367 updateSQL, _, err := ds.
368 Set(goqu.Record{"foo": "bar"}).
369 From("test_2").
370 Where(goqu.I("test.id").Eq(goqu.I("test_2.test_id"))).
371 ToSQL()
372 pt.NoError(err)
373 pt.Equal(`UPDATE "test" SET "foo"='bar' FROM "test_2" WHERE ("test"."id" = "test_2"."test_id")`, updateSQL)
374 }
375
376 func (pt *postgresTest) TestDelete() {
377 ds := pt.db.From("entry")
378 var e entry
379 found, err := ds.Where(goqu.C("int").Eq(9)).Select("id").ScanStruct(&e)
380 pt.NoError(err)
381 pt.True(found)
382 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Delete().Executor().Exec()
383 pt.NoError(err)
384
385 count, err := ds.Count()
386 pt.NoError(err)
387 pt.Equal(int64(9), count)
388
389 var id uint32
390 found, err = ds.Where(goqu.C("id").Eq(e.ID)).ScanVal(&id)
391 pt.NoError(err)
392 pt.False(found)
393
394 e = entry{}
395 found, err = ds.Where(goqu.C("int").Eq(8)).Select("id").ScanStruct(&e)
396 pt.NoError(err)
397 pt.True(found)
398 pt.NotEqual(e.ID, int64(0))
399
400 id = 0
401 _, err = ds.Where(goqu.C("id").Eq(e.ID)).Delete().Returning("id").Executor().ScanVal(&id)
402 pt.NoError(err)
403 pt.Equal(id, e.ID)
404 }
405
406 func (pt *postgresTest) TestInsert_OnConflict() {
407 ds := pt.db.From("entry")
408 now := time.Now()
409
410
411 e := entry{Int: 10, Float: 1.100000, String: "1.100000", Time: now, Bool: false, Bytes: []byte("1.100000")}
412 _, err := ds.Insert().Rows(e).OnConflict(goqu.DoNothing()).Executor().Exec()
413 pt.NoError(err)
414
415
416 e = entry{Int: 10, Float: 2.100000, String: "2.100000", Time: now.Add(time.Hour * 100), Bool: false, Bytes: []byte("2.100000")}
417 _, err = ds.Insert().Rows(e).OnConflict(goqu.DoNothing()).Executor().Exec()
418 pt.NoError(err)
419
420
421 var entryActual entry
422 e2 := entry{Int: 0, String: "2.000000"}
423 _, err = ds.Insert().
424 Rows(e2).
425 OnConflict(goqu.DoUpdate("int", goqu.Record{"string": "upsert"})).
426 Executor().Exec()
427 pt.NoError(err)
428 _, err = ds.Where(goqu.C("int").Eq(0)).ScanStruct(&entryActual)
429 pt.NoError(err)
430 pt.Equal("upsert", entryActual.String)
431
432
433 entries := []entry{
434 {Int: 1, Float: 6.100000, String: "6.100000", Time: now, Bytes: []byte("6.100000")},
435 {Int: 2, Float: 7.200000, String: "7.200000", Time: now, Bytes: []byte("7.200000")},
436 }
437 _, err = ds.Insert().
438 Rows(entries).
439 OnConflict(goqu.DoUpdate("int", goqu.Record{"string": "upsert"}).Where(goqu.I("excluded.int").Eq(2))).
440 Executor().
441 Exec()
442 pt.NoError(err)
443
444 var entry8, entry9 entry
445 _, err = ds.Where(goqu.Ex{"int": 1}).ScanStruct(&entry8)
446 pt.NoError(err)
447 pt.Equal("0.100000", entry8.String)
448
449 _, err = ds.Where(goqu.Ex{"int": 2}).ScanStruct(&entry9)
450 pt.NoError(err)
451 pt.Equal("upsert", entry9.String)
452 }
453
454 func (pt *postgresTest) TestWindowFunction() {
455 ds := pt.db.From("entry").
456 Select("int", goqu.ROW_NUMBER().OverName(goqu.I("w")).As("id")).
457 Window(goqu.W("w").OrderBy(goqu.I("int").Desc()))
458
459 var entries []entry
460 pt.NoError(ds.ScanStructs(&entries))
461
462 pt.Equal([]entry{
463 {Int: 9, ID: 1},
464 {Int: 8, ID: 2},
465 {Int: 7, ID: 3},
466 {Int: 6, ID: 4},
467 {Int: 5, ID: 5},
468 {Int: 4, ID: 6},
469 {Int: 3, ID: 7},
470 {Int: 2, ID: 8},
471 {Int: 1, ID: 9},
472 {Int: 0, ID: 10},
473 }, entries)
474 }
475
476 func (pt *postgresTest) TestOrderByFunction() {
477 ds := pt.db.From("entry").
478 Select(goqu.ROW_NUMBER().Over(goqu.W()).As("id")).Window().Order(goqu.ROW_NUMBER().Over(goqu.W()).Desc())
479
480 var entries []entry
481 pt.NoError(ds.ScanStructs(&entries))
482
483 pt.Equal([]entry{
484 {ID: 10},
485 {ID: 9},
486 {ID: 8},
487 {ID: 7},
488 {ID: 6},
489 {ID: 5},
490 {ID: 4},
491 {ID: 3},
492 {ID: 2},
493 {ID: 1},
494 }, entries)
495 }
496
497 func TestPostgresSuite(t *testing.T) {
498 suite.Run(t, new(postgresTest))
499 }
500
View as plain text