1
2
3 package sqlmock
4
5 import (
6 "database/sql"
7 "encoding/json"
8 "fmt"
9 "reflect"
10 "testing"
11 "time"
12 )
13
14 func TestQueryMultiRows(t *testing.T) {
15 t.Parallel()
16 db, mock, err := New()
17 if err != nil {
18 t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
19 }
20 defer db.Close()
21
22 rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
23 rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
24
25 mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
26 WithArgs(5).
27 WillReturnRows(rs1, rs2)
28
29 rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
30 if err != nil {
31 t.Errorf("error was not expected, but got: %v", err)
32 }
33 defer rows.Close()
34
35 if !rows.Next() {
36 t.Error("expected a row to be available in first result set")
37 }
38
39 var id int
40 var name string
41
42 err = rows.Scan(&id, &name)
43 if err != nil {
44 t.Errorf("error was not expected, but got: %v", err)
45 }
46
47 if id != 5 || name != "hello world" {
48 t.Errorf("unexpected row values id: %v name: %v", id, name)
49 }
50
51 if rows.Next() {
52 t.Error("was not expecting next row in first result set")
53 }
54
55 if !rows.NextResultSet() {
56 t.Error("had to have next result set")
57 }
58
59 if !rows.Next() {
60 t.Error("expected a row to be available in second result set")
61 }
62
63 err = rows.Scan(&name)
64 if err != nil {
65 t.Errorf("error was not expected, but got: %v", err)
66 }
67
68 if name != "gopher" {
69 t.Errorf("unexpected row name: %v", name)
70 }
71
72 if !rows.Next() {
73 t.Error("expected a row to be available in second result set")
74 }
75
76 err = rows.Scan(&name)
77 if err != nil {
78 t.Errorf("error was not expected, but got: %v", err)
79 }
80
81 if name != "john" {
82 t.Errorf("unexpected row name: %v", name)
83 }
84
85 if rows.Next() {
86 t.Error("expected next row to produce error")
87 }
88
89 if rows.Err() == nil {
90 t.Error("expected an error, but there was none")
91 }
92
93 if err := mock.ExpectationsWereMet(); err != nil {
94 t.Errorf("there were unfulfilled expectations: %s", err)
95 }
96 }
97
98 func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) {
99 t.Parallel()
100 replace := []byte(invalid)
101 rows := NewRows([]string{"raw"}).
102 AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
103 AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
104 scan := func(rs *sql.Rows) ([]byte, error) {
105 var raw sql.RawBytes
106 return raw, rs.Scan(&raw)
107 }
108 want := []struct {
109 Initial []byte
110 Replaced []byte
111 }{
112 {Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]},
113 {Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]},
114 }
115 queryRowBytesInvalidatedByNext(t, rows, scan, want)
116 }
117
118 func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) {
119 t.Parallel()
120 rows := NewRows([]string{"raw"}).
121 AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
122 AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
123 scan := func(rs *sql.Rows) ([]byte, error) {
124 var b []byte
125 return b, rs.Scan(&b)
126 }
127 want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
128 queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
129 }
130
131 func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) {
132 t.Parallel()
133 rows := NewRows([]string{"raw"}).
134 AddRow([]byte(`one binary value with some text!`)).
135 AddRow([]byte(`two binary value with even more text than the first one`))
136 scan := func(rs *sql.Rows) ([]byte, error) {
137 type customBytes []byte
138 var b customBytes
139 return b, rs.Scan(&b)
140 }
141 want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
142 queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
143 }
144
145 func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) {
146 t.Parallel()
147 rows := NewRows([]string{"raw"}).
148 AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
149 AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
150 scan := func(rs *sql.Rows) ([]byte, error) {
151 type customBytes []byte
152 var b customBytes
153 return b, rs.Scan(&b)
154 }
155 want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
156 queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
157 }
158
159 func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) {
160 t.Parallel()
161 rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
162 scan := func(rs *sql.Rows) ([]byte, error) {
163 type customBytes []byte
164 var b customBytes
165 return b, rs.Scan(&b)
166 }
167 queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
168 }
169
170 func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) {
171 t.Parallel()
172 replace := []byte(invalid)
173 rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
174 scan := func(rs *sql.Rows) ([]byte, error) {
175 var raw sql.RawBytes
176 return raw, rs.Scan(&raw)
177 }
178 want := struct {
179 Initial []byte
180 Replaced []byte
181 }{
182 Initial: []byte(`{"thing": "one", "thing2": "two"}`),
183 Replaced: replace[:len(replace)-6],
184 }
185 queryRowBytesInvalidatedByClose(t, rows, scan, want)
186 }
187
188 func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) {
189 t.Parallel()
190 rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
191 scan := func(rs *sql.Rows) ([]byte, error) {
192 var b []byte
193 return b, rs.Scan(&b)
194 }
195 queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
196 }
197
198 func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) {
199 t.Parallel()
200 rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
201 scan := func(rs *sql.Rows) ([]byte, error) {
202 type customBytes []byte
203 var b customBytes
204 return b, rs.Scan(&b)
205 }
206 queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
207 }
208
209 func TestNewColumnWithDefinition(t *testing.T) {
210 now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
211
212 t.Run("with one ResultSet", func(t *testing.T) {
213 db, mock, _ := New()
214 column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
215 column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
216 column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
217 rows := mock.NewRowsWithColumnDefinition(column1, column2, column3)
218 rows.AddRow("foo.bar", float64(10.123), now)
219
220 mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
221 isQuery := mQuery.WillReturnRows(rows)
222 isQueryClosed := mQuery.RowsWillBeClosed()
223 isDbClosed := mock.ExpectClose()
224
225 query, _ := db.Query("SELECT test, number, when from dummy")
226
227 if false == isQuery.fulfilled() {
228 t.Error("Query is not executed")
229 }
230
231 if query.Next() {
232 var test string
233 var number float64
234 var when time.Time
235
236 if queryError := query.Scan(&test, &number, &when); queryError != nil {
237 t.Error(queryError)
238 } else if test != "foo.bar" {
239 t.Error("field test is not 'foo.bar'")
240 } else if number != float64(10.123) {
241 t.Error("field number is not '10.123'")
242 } else if when != now {
243 t.Errorf("field when is not %v", now)
244 }
245
246 if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil {
247 t.Error(colTypErr)
248 } else if len(columnTypes) != 3 {
249 t.Error("number of columnTypes")
250 } else if name := columnTypes[0].Name(); name != "test" {
251 t.Errorf("field 'test' has a wrong name '%s'", name)
252 } else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" {
253 t.Errorf("field 'test' has a wrong db type '%s'", dbType)
254 } else if columnTypes[0].ScanType().Kind() != reflect.String {
255 t.Error("field 'test' has a wrong scanType")
256 } else if _, _, ok := columnTypes[0].DecimalSize(); ok {
257 t.Error("field 'test' should have not precision, scale")
258 } else if length, ok := columnTypes[0].Length(); length != 100 || !ok {
259 t.Errorf("field 'test' has a wrong length '%d'", length)
260 } else if name := columnTypes[1].Name(); name != "number" {
261 t.Errorf("field 'number' has a wrong name '%s'", name)
262 } else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" {
263 t.Errorf("field 'number' has a wrong db type '%s'", dbType)
264 } else if columnTypes[1].ScanType().Kind() != reflect.Float64 {
265 t.Error("field 'number' has a wrong scanType")
266 } else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
267 t.Error("field 'number' has a wrong precision, scale")
268 } else if _, ok := columnTypes[1].Length(); ok {
269 t.Error("field 'number' is not variable length type")
270 } else if _, ok := columnTypes[2].Nullable(); ok {
271 t.Error("field 'when' should have nullability unknown")
272 }
273 } else {
274 t.Error("no result set")
275 }
276
277 query.Close()
278 if false == isQueryClosed.fulfilled() {
279 t.Error("Query is not executed")
280 }
281
282 db.Close()
283 if false == isDbClosed.fulfilled() {
284 t.Error("Db is not closed")
285 }
286 })
287
288 t.Run("with more then one ResultSet", func(t *testing.T) {
289 db, mock, _ := New()
290 column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
291 column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
292 column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
293 rows1 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
294 rows1.AddRow("foo.bar", float64(10.123), now)
295 rows2 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
296 rows2.AddRow("bar.foo", float64(123.10), now.Add(time.Second*10))
297 rows3 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
298 rows3.AddRow("lollipop", float64(10.321), now.Add(time.Second*20))
299
300 mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
301 isQuery := mQuery.WillReturnRows(rows1, rows2, rows3)
302 isQueryClosed := mQuery.RowsWillBeClosed()
303 isDbClosed := mock.ExpectClose()
304
305 query, _ := db.Query("SELECT test, number, when from dummy")
306
307 if false == isQuery.fulfilled() {
308 t.Error("Query is not executed")
309 }
310
311 rowsSi := 0
312
313 for query.Next() {
314 var test string
315 var number float64
316 var when time.Time
317
318 if queryError := query.Scan(&test, &number, &when); queryError != nil {
319 t.Error(queryError)
320
321 } else if rowsSi == 0 && test != "foo.bar" {
322 t.Error("field test is not 'foo.bar'")
323 } else if rowsSi == 0 && number != float64(10.123) {
324 t.Error("field number is not '10.123'")
325 } else if rowsSi == 0 && when != now {
326 t.Errorf("field when is not %v", now)
327
328 } else if rowsSi == 1 && test != "bar.foo" {
329 t.Error("field test is not 'bar.bar'")
330 } else if rowsSi == 1 && number != float64(123.10) {
331 t.Error("field number is not '123.10'")
332 } else if rowsSi == 1 && when != now.Add(time.Second*10) {
333 t.Errorf("field when is not %v", now)
334
335 } else if rowsSi == 2 && test != "lollipop" {
336 t.Error("field test is not 'lollipop'")
337 } else if rowsSi == 2 && number != float64(10.321) {
338 t.Error("field number is not '10.321'")
339 } else if rowsSi == 2 && when != now.Add(time.Second*20) {
340 t.Errorf("field when is not %v", now)
341 }
342
343 rowsSi++
344
345 if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil {
346 t.Error(colTypErr)
347 } else if len(columnTypes) != 3 {
348 t.Error("number of columnTypes")
349 } else if name := columnTypes[0].Name(); name != "test" {
350 t.Errorf("field 'test' has a wrong name '%s'", name)
351 } else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" {
352 t.Errorf("field 'test' has a wrong db type '%s'", dbType)
353 } else if columnTypes[0].ScanType().Kind() != reflect.String {
354 t.Error("field 'test' has a wrong scanType")
355 } else if _, _, ok := columnTypes[0].DecimalSize(); ok {
356 t.Error("field 'test' should not have precision, scale")
357 } else if length, ok := columnTypes[0].Length(); length != 100 || !ok {
358 t.Errorf("field 'test' has a wrong length '%d'", length)
359 } else if name := columnTypes[1].Name(); name != "number" {
360 t.Errorf("field 'number' has a wrong name '%s'", name)
361 } else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" {
362 t.Errorf("field 'number' has a wrong db type '%s'", dbType)
363 } else if columnTypes[1].ScanType().Kind() != reflect.Float64 {
364 t.Error("field 'number' has a wrong scanType")
365 } else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
366 t.Error("field 'number' has a wrong precision, scale")
367 } else if _, ok := columnTypes[1].Length(); ok {
368 t.Error("field 'number' is not variable length type")
369 } else if _, ok := columnTypes[2].Nullable(); ok {
370 t.Error("field 'when' should have nullability unknown")
371 }
372 }
373 if rowsSi == 0 {
374 t.Error("no result set")
375 }
376
377 query.Close()
378 if false == isQueryClosed.fulfilled() {
379 t.Error("Query is not executed")
380 }
381
382 db.Close()
383 if false == isDbClosed.fulfilled() {
384 t.Error("Db is not closed")
385 }
386 })
387 }
388
View as plain text