...
1 package sqlmock
2
3 import (
4 "database/sql/driver"
5 "fmt"
6 "strings"
7 "sync"
8 "time"
9 )
10
11
12 type expectation interface {
13 fulfilled() bool
14 Lock()
15 Unlock()
16 String() string
17 }
18
19
20
21 type commonExpectation struct {
22 sync.Mutex
23 triggered bool
24 err error
25 }
26
27 func (e *commonExpectation) fulfilled() bool {
28 return e.triggered
29 }
30
31
32
33 type ExpectedClose struct {
34 commonExpectation
35 }
36
37
38 func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
39 e.err = err
40 return e
41 }
42
43
44 func (e *ExpectedClose) String() string {
45 msg := "ExpectedClose => expecting database Close"
46 if e.err != nil {
47 msg += fmt.Sprintf(", which should return error: %s", e.err)
48 }
49 return msg
50 }
51
52
53
54 type ExpectedBegin struct {
55 commonExpectation
56 delay time.Duration
57 }
58
59
60 func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
61 e.err = err
62 return e
63 }
64
65
66 func (e *ExpectedBegin) String() string {
67 msg := "ExpectedBegin => expecting database transaction Begin"
68 if e.err != nil {
69 msg += fmt.Sprintf(", which should return error: %s", e.err)
70 }
71 return msg
72 }
73
74
75
76 func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
77 e.delay = duration
78 return e
79 }
80
81
82
83 type ExpectedCommit struct {
84 commonExpectation
85 }
86
87
88 func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
89 e.err = err
90 return e
91 }
92
93
94 func (e *ExpectedCommit) String() string {
95 msg := "ExpectedCommit => expecting transaction Commit"
96 if e.err != nil {
97 msg += fmt.Sprintf(", which should return error: %s", e.err)
98 }
99 return msg
100 }
101
102
103
104 type ExpectedRollback struct {
105 commonExpectation
106 }
107
108
109 func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
110 e.err = err
111 return e
112 }
113
114
115 func (e *ExpectedRollback) String() string {
116 msg := "ExpectedRollback => expecting transaction Rollback"
117 if e.err != nil {
118 msg += fmt.Sprintf(", which should return error: %s", e.err)
119 }
120 return msg
121 }
122
123
124
125
126 type ExpectedQuery struct {
127 queryBasedExpectation
128 rows driver.Rows
129 delay time.Duration
130 rowsMustBeClosed bool
131 rowsWereClosed bool
132 }
133
134
135
136
137
138 func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
139 if e.noArgs {
140 panic("WithArgs() and WithoutArgs() must not be used together")
141 }
142 e.args = args
143 return e
144 }
145
146
147
148
149
150 func (e *ExpectedQuery) WithoutArgs() *ExpectedQuery {
151 if len(e.args) > 0 {
152 panic("WithoutArgs() and WithArgs() must not be used together")
153 }
154 e.noArgs = true
155 return e
156 }
157
158
159 func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery {
160 e.rowsMustBeClosed = true
161 return e
162 }
163
164
165 func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
166 e.err = err
167 return e
168 }
169
170
171
172 func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
173 e.delay = duration
174 return e
175 }
176
177
178 func (e *ExpectedQuery) String() string {
179 msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:"
180 msg += "\n - matches sql: '" + e.expectSQL + "'"
181
182 if len(e.args) == 0 {
183 msg += "\n - is without arguments"
184 } else {
185 msg += "\n - is with arguments:\n"
186 for i, arg := range e.args {
187 msg += fmt.Sprintf(" %d - %+v\n", i, arg)
188 }
189 msg = strings.TrimSpace(msg)
190 }
191
192 if e.rows != nil {
193 msg += fmt.Sprintf("\n - %s", e.rows)
194 }
195
196 if e.err != nil {
197 msg += fmt.Sprintf("\n - should return error: %s", e.err)
198 }
199
200 return msg
201 }
202
203
204
205 type ExpectedExec struct {
206 queryBasedExpectation
207 result driver.Result
208 delay time.Duration
209 }
210
211
212
213
214
215 func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
216 if len(e.args) > 0 {
217 panic("WithArgs() and WithoutArgs() must not be used together")
218 }
219 e.args = args
220 return e
221 }
222
223
224
225
226
227 func (e *ExpectedExec) WithoutArgs() *ExpectedExec {
228 if len(e.args) > 0 {
229 panic("WithoutArgs() and WithArgs() must not be used together")
230 }
231 e.noArgs = true
232 return e
233 }
234
235
236 func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
237 e.err = err
238 return e
239 }
240
241
242
243 func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
244 e.delay = duration
245 return e
246 }
247
248
249 func (e *ExpectedExec) String() string {
250 msg := "ExpectedExec => expecting Exec or ExecContext which:"
251 msg += "\n - matches sql: '" + e.expectSQL + "'"
252
253 if len(e.args) == 0 {
254 msg += "\n - is without arguments"
255 } else {
256 msg += "\n - is with arguments:\n"
257 var margs []string
258 for i, arg := range e.args {
259 margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg))
260 }
261 msg += strings.Join(margs, "\n")
262 }
263
264 if e.result != nil {
265 if res, ok := e.result.(*result); ok {
266 msg += "\n - should return Result having:"
267 msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID)
268 msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected)
269 if res.err != nil {
270 msg += fmt.Sprintf("\n Error: %s", res.err)
271 }
272 }
273 }
274
275 if e.err != nil {
276 msg += fmt.Sprintf("\n - should return error: %s", e.err)
277 }
278
279 return msg
280 }
281
282
283
284
285
286 func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
287 e.result = result
288 return e
289 }
290
291
292
293 type ExpectedPrepare struct {
294 commonExpectation
295 mock *sqlmock
296 expectSQL string
297 statement driver.Stmt
298 closeErr error
299 mustBeClosed bool
300 wasClosed bool
301 delay time.Duration
302 }
303
304
305 func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
306 e.err = err
307 return e
308 }
309
310
311 func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
312 e.closeErr = err
313 return e
314 }
315
316
317
318 func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare {
319 e.delay = duration
320 return e
321 }
322
323
324
325 func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
326 e.mustBeClosed = true
327 return e
328 }
329
330
331
332 func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
333 eq := &ExpectedQuery{}
334 eq.expectSQL = e.expectSQL
335 eq.converter = e.mock.converter
336 e.mock.expected = append(e.mock.expected, eq)
337 return eq
338 }
339
340
341
342 func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
343 eq := &ExpectedExec{}
344 eq.expectSQL = e.expectSQL
345 eq.converter = e.mock.converter
346 e.mock.expected = append(e.mock.expected, eq)
347 return eq
348 }
349
350
351 func (e *ExpectedPrepare) String() string {
352 msg := "ExpectedPrepare => expecting Prepare statement which:"
353 msg += "\n - matches sql: '" + e.expectSQL + "'"
354
355 if e.err != nil {
356 msg += fmt.Sprintf("\n - should return error: %s", e.err)
357 }
358
359 if e.closeErr != nil {
360 msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr)
361 }
362
363 return msg
364 }
365
366
367
368 type queryBasedExpectation struct {
369 commonExpectation
370 expectSQL string
371 converter driver.ValueConverter
372 args []driver.Value
373 noArgs bool
374 }
375
376
377
378 type ExpectedPing struct {
379 commonExpectation
380 delay time.Duration
381 }
382
383
384
385 func (e *ExpectedPing) WillDelayFor(duration time.Duration) *ExpectedPing {
386 e.delay = duration
387 return e
388 }
389
390
391 func (e *ExpectedPing) WillReturnError(err error) *ExpectedPing {
392 e.err = err
393 return e
394 }
395
396
397 func (e *ExpectedPing) String() string {
398 msg := "ExpectedPing => expecting database Ping"
399 if e.err != nil {
400 msg += fmt.Sprintf(", which should return error: %s", e.err)
401 }
402 return msg
403 }
404
View as plain text