1
11 package sqlmock
12
13 import (
14 "database/sql"
15 "database/sql/driver"
16 "fmt"
17 "time"
18 )
19
20
21
22
23 type SqlmockCommon interface {
24
25
26
27 ExpectClose() *ExpectedClose
28
29
30
31 ExpectationsWereMet() error
32
33
34
35
36
37 ExpectPrepare(expectedSQL string) *ExpectedPrepare
38
39
40
41 ExpectQuery(expectedSQL string) *ExpectedQuery
42
43
44
45 ExpectExec(expectedSQL string) *ExpectedExec
46
47
48
49 ExpectBegin() *ExpectedBegin
50
51
52
53 ExpectCommit() *ExpectedCommit
54
55
56
57 ExpectRollback() *ExpectedRollback
58
59
60
61
62
63
64
65
66
67
68 ExpectPing() *ExpectedPing
69
70
71
72
73
74
75
76
77
78
79
80
81 MatchExpectationsInOrder(bool)
82
83
84
85
86 NewRows(columns []string) *Rows
87 }
88
89 type sqlmock struct {
90 ordered bool
91 dsn string
92 opened int
93 drv *mockDriver
94 converter driver.ValueConverter
95 queryMatcher QueryMatcher
96 monitorPings bool
97
98 expected []expectation
99 }
100
101 func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
102 db, err := sql.Open("sqlmock", c.dsn)
103 if err != nil {
104 return db, c, err
105 }
106 for _, option := range options {
107 err := option(c)
108 if err != nil {
109 return db, c, err
110 }
111 }
112 if c.converter == nil {
113 c.converter = driver.DefaultParameterConverter
114 }
115 if c.queryMatcher == nil {
116 c.queryMatcher = QueryMatcherRegexp
117 }
118
119 if c.monitorPings {
120
121
122
123
124 c.monitorPings = false
125 defer func() { c.monitorPings = true }()
126 }
127 return db, c, db.Ping()
128 }
129
130 func (c *sqlmock) ExpectClose() *ExpectedClose {
131 e := &ExpectedClose{}
132 c.expected = append(c.expected, e)
133 return e
134 }
135
136 func (c *sqlmock) MatchExpectationsInOrder(b bool) {
137 c.ordered = b
138 }
139
140
141
142
143
144 func (c *sqlmock) Close() error {
145 c.drv.Lock()
146 defer c.drv.Unlock()
147
148 c.opened--
149 if c.opened == 0 {
150 delete(c.drv.conns, c.dsn)
151 }
152
153 var expected *ExpectedClose
154 var fulfilled int
155 var ok bool
156 for _, next := range c.expected {
157 next.Lock()
158 if next.fulfilled() {
159 next.Unlock()
160 fulfilled++
161 continue
162 }
163
164 if expected, ok = next.(*ExpectedClose); ok {
165 break
166 }
167
168 next.Unlock()
169 if c.ordered {
170 return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next)
171 }
172 }
173
174 if expected == nil {
175 msg := "call to database Close was not expected"
176 if fulfilled == len(c.expected) {
177 msg = "all expectations were already fulfilled, " + msg
178 }
179 return fmt.Errorf(msg)
180 }
181
182 expected.triggered = true
183 expected.Unlock()
184 return expected.err
185 }
186
187 func (c *sqlmock) ExpectationsWereMet() error {
188 for _, e := range c.expected {
189 e.Lock()
190 fulfilled := e.fulfilled()
191 e.Unlock()
192
193 if !fulfilled {
194 return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
195 }
196
197
198 if prep, ok := e.(*ExpectedPrepare); ok {
199 if prep.mustBeClosed && !prep.wasClosed {
200 return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
201 }
202 }
203
204
205 if query, ok := e.(*ExpectedQuery); ok {
206 if query.rowsMustBeClosed && !query.rowsWereClosed {
207 return fmt.Errorf("expected query rows to be closed, but it was not: %s", query)
208 }
209 }
210 }
211 return nil
212 }
213
214
215 func (c *sqlmock) Begin() (driver.Tx, error) {
216 ex, err := c.begin()
217 if ex != nil {
218 time.Sleep(ex.delay)
219 }
220 if err != nil {
221 return nil, err
222 }
223
224 return c, nil
225 }
226
227 func (c *sqlmock) begin() (*ExpectedBegin, error) {
228 var expected *ExpectedBegin
229 var ok bool
230 var fulfilled int
231 for _, next := range c.expected {
232 next.Lock()
233 if next.fulfilled() {
234 next.Unlock()
235 fulfilled++
236 continue
237 }
238
239 if expected, ok = next.(*ExpectedBegin); ok {
240 break
241 }
242
243 next.Unlock()
244 if c.ordered {
245 return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next)
246 }
247 }
248 if expected == nil {
249 msg := "call to database transaction Begin was not expected"
250 if fulfilled == len(c.expected) {
251 msg = "all expectations were already fulfilled, " + msg
252 }
253 return nil, fmt.Errorf(msg)
254 }
255
256 expected.triggered = true
257 expected.Unlock()
258
259 return expected, expected.err
260 }
261
262 func (c *sqlmock) ExpectBegin() *ExpectedBegin {
263 e := &ExpectedBegin{}
264 c.expected = append(c.expected, e)
265 return e
266 }
267
268 func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec {
269 e := &ExpectedExec{}
270 e.expectSQL = expectedSQL
271 e.converter = c.converter
272 c.expected = append(c.expected, e)
273 return e
274 }
275
276
277 func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
278 ex, err := c.prepare(query)
279 if ex != nil {
280 time.Sleep(ex.delay)
281 }
282 if err != nil {
283 return nil, err
284 }
285
286 return &statement{c, ex, query}, nil
287 }
288
289 func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
290 var expected *ExpectedPrepare
291 var fulfilled int
292 var ok bool
293
294 for _, next := range c.expected {
295 next.Lock()
296 if next.fulfilled() {
297 next.Unlock()
298 fulfilled++
299 continue
300 }
301
302 if c.ordered {
303 if expected, ok = next.(*ExpectedPrepare); ok {
304 break
305 }
306
307 next.Unlock()
308 return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next)
309 }
310
311 if pr, ok := next.(*ExpectedPrepare); ok {
312 if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
313 expected = pr
314 break
315 }
316 }
317 next.Unlock()
318 }
319
320 if expected == nil {
321 msg := "call to Prepare '%s' query was not expected"
322 if fulfilled == len(c.expected) {
323 msg = "all expectations were already fulfilled, " + msg
324 }
325 return nil, fmt.Errorf(msg, query)
326 }
327 defer expected.Unlock()
328 if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
329 return nil, fmt.Errorf("Prepare: %v", err)
330 }
331
332 expected.triggered = true
333 return expected, expected.err
334 }
335
336 func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
337 e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
338 c.expected = append(c.expected, e)
339 return e
340 }
341
342 func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
343 e := &ExpectedQuery{}
344 e.expectSQL = expectedSQL
345 e.converter = c.converter
346 c.expected = append(c.expected, e)
347 return e
348 }
349
350 func (c *sqlmock) ExpectCommit() *ExpectedCommit {
351 e := &ExpectedCommit{}
352 c.expected = append(c.expected, e)
353 return e
354 }
355
356 func (c *sqlmock) ExpectRollback() *ExpectedRollback {
357 e := &ExpectedRollback{}
358 c.expected = append(c.expected, e)
359 return e
360 }
361
362
363 func (c *sqlmock) Commit() error {
364 var expected *ExpectedCommit
365 var fulfilled int
366 var ok bool
367 for _, next := range c.expected {
368 next.Lock()
369 if next.fulfilled() {
370 next.Unlock()
371 fulfilled++
372 continue
373 }
374
375 if expected, ok = next.(*ExpectedCommit); ok {
376 break
377 }
378
379 next.Unlock()
380 if c.ordered {
381 return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next)
382 }
383 }
384 if expected == nil {
385 msg := "call to Commit transaction was not expected"
386 if fulfilled == len(c.expected) {
387 msg = "all expectations were already fulfilled, " + msg
388 }
389 return fmt.Errorf(msg)
390 }
391
392 expected.triggered = true
393 expected.Unlock()
394 return expected.err
395 }
396
397
398 func (c *sqlmock) Rollback() error {
399 var expected *ExpectedRollback
400 var fulfilled int
401 var ok bool
402 for _, next := range c.expected {
403 next.Lock()
404 if next.fulfilled() {
405 next.Unlock()
406 fulfilled++
407 continue
408 }
409
410 if expected, ok = next.(*ExpectedRollback); ok {
411 break
412 }
413
414 next.Unlock()
415 if c.ordered {
416 return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next)
417 }
418 }
419 if expected == nil {
420 msg := "call to Rollback transaction was not expected"
421 if fulfilled == len(c.expected) {
422 msg = "all expectations were already fulfilled, " + msg
423 }
424 return fmt.Errorf(msg)
425 }
426
427 expected.triggered = true
428 expected.Unlock()
429 return expected.err
430 }
431
432
433
434
435 func (c *sqlmock) NewRows(columns []string) *Rows {
436 r := NewRows(columns)
437 r.converter = c.converter
438 return r
439 }
440
View as plain text