1
2
3 package sqlmock
4
5 import (
6 "context"
7 "database/sql/driver"
8 "errors"
9 "fmt"
10 "log"
11 "time"
12 )
13
14
15 type Sqlmock interface {
16
17 SqlmockCommon
18
19
20
21 NewRowsWithColumnDefinition(columns ...*Column) *Rows
22
23
24 NewColumn(name string) *Column
25 }
26
27
28
var (
	// ErrCancelled is returned by QueryContext, ExecContext, BeginTx,
	// PrepareContext and Ping when the supplied context is done before the
	// matched expectation's configured delay has elapsed.
	ErrCancelled = errors.New("canceling query due to user request")
)
30
31
32 func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
33 ex, err := c.query(query, args)
34 if ex != nil {
35 select {
36 case <-time.After(ex.delay):
37 if err != nil {
38 return nil, err
39 }
40 return ex.rows, nil
41 case <-ctx.Done():
42 return nil, ErrCancelled
43 }
44 }
45
46 return nil, err
47 }
48
49
50 func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
51 ex, err := c.exec(query, args)
52 if ex != nil {
53 select {
54 case <-time.After(ex.delay):
55 if err != nil {
56 return nil, err
57 }
58 return ex.result, nil
59 case <-ctx.Done():
60 return nil, ErrCancelled
61 }
62 }
63
64 return nil, err
65 }
66
67
68 func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
69 ex, err := c.begin()
70 if ex != nil {
71 select {
72 case <-time.After(ex.delay):
73 if err != nil {
74 return nil, err
75 }
76 return c, nil
77 case <-ctx.Done():
78 return nil, ErrCancelled
79 }
80 }
81
82 return nil, err
83 }
84
85
86 func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
87 ex, err := c.prepare(query)
88 if ex != nil {
89 select {
90 case <-time.After(ex.delay):
91 if err != nil {
92 return nil, err
93 }
94 return &statement{c, ex, query}, nil
95 case <-ctx.Done():
96 return nil, ErrCancelled
97 }
98 }
99
100 return nil, err
101 }
102
103
104 func (c *sqlmock) Ping(ctx context.Context) error {
105 if !c.monitorPings {
106 return nil
107 }
108
109 ex, err := c.ping()
110 if ex != nil {
111 select {
112 case <-ctx.Done():
113 return ErrCancelled
114 case <-time.After(ex.delay):
115 }
116 }
117
118 return err
119 }
120
121 func (c *sqlmock) ping() (*ExpectedPing, error) {
122 var expected *ExpectedPing
123 var fulfilled int
124 var ok bool
125 for _, next := range c.expected {
126 next.Lock()
127 if next.fulfilled() {
128 next.Unlock()
129 fulfilled++
130 continue
131 }
132
133 if expected, ok = next.(*ExpectedPing); ok {
134 break
135 }
136
137 next.Unlock()
138 if c.ordered {
139 return nil, fmt.Errorf("call to database Ping, was not expected, next expectation is: %s", next)
140 }
141 }
142
143 if expected == nil {
144 msg := "call to database Ping was not expected"
145 if fulfilled == len(c.expected) {
146 msg = "all expectations were already fulfilled, " + msg
147 }
148 return nil, fmt.Errorf(msg)
149 }
150
151 expected.triggered = true
152 expected.Unlock()
153 return expected, expected.err
154 }
155
156
157 func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
158 return stmt.conn.ExecContext(ctx, stmt.query, args)
159 }
160
161
162 func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
163 return stmt.conn.QueryContext(ctx, stmt.query, args)
164 }
165
166 func (c *sqlmock) ExpectPing() *ExpectedPing {
167 if !c.monitorPings {
168 log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.")
169 return nil
170 }
171 e := &ExpectedPing{}
172 c.expected = append(c.expected, e)
173 return e
174 }
175
176
177
178 func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
179 namedArgs := make([]driver.NamedValue, len(args))
180 for i, v := range args {
181 namedArgs[i] = driver.NamedValue{
182 Ordinal: i + 1,
183 Value: v,
184 }
185 }
186
187 ex, err := c.query(query, namedArgs)
188 if ex != nil {
189 time.Sleep(ex.delay)
190 }
191 if err != nil {
192 return nil, err
193 }
194
195 return ex.rows, nil
196 }
197
// query finds the expectation that satisfies a query call for the given SQL
// text and arguments, validates it, and marks it as triggered.
//
// Fulfilled expectations are skipped.
//   - Ordered mode (c.ordered): the first unfulfilled expectation must be an
//     *ExpectedQuery. Otherwise an error naming that expectation is returned.
//   - Unordered mode: the first unfulfilled *ExpectedQuery whose SQL passes
//     c.queryMatcher.Match and whose arguments pass attemptArgMatch is chosen.
//
// Locking: each expectation is locked while it is inspected. The chosen one
// is still locked when the loop exits via break, and is released by the
// deferred Unlock after the loop. All others are unlocked before moving on.
//
// Returns:
//   - (expectation, nil) on success.
//   - (expectation, expectation.err) when the expectation was configured with
//     an error. Callers (QueryContext, Query) still use the returned
//     expectation's delay.
//   - (nil, err) for no match, SQL/argument mismatch, or missing rows.
func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery, error) {
	var expected *ExpectedQuery
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}

		if c.ordered {
			// Only the next pending expectation may be consumed. Keep it
			// locked and validate SQL/args after the loop.
			if expected, ok = next.(*ExpectedQuery); ok {
				break
			}
			next.Unlock()
			return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
		}
		if qr, ok := next.(*ExpectedQuery); ok {
			if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
				next.Unlock()
				continue
			}
			if err := qr.attemptArgMatch(args); err == nil {
				// Match found: exit the for loop with qr still locked.
				expected = qr
				break
			}
		}
		next.Unlock()
	}

	if expected == nil {
		msg := "call to Query '%s' with args %+v was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query, args)
	}

	// Release the lock that was left held on the chosen expectation.
	defer expected.Unlock()

	// In ordered mode this is the first SQL check. In unordered mode the SQL
	// was already matched during the search.
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("Query: %v", err)
	}

	if err := expected.argsMatches(args); err != nil {
		return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err)
	}

	// triggered is set before the error/rows checks below, so it is set even
	// when a missing-rows error is returned.
	expected.triggered = true
	if expected.err != nil {
		return expected, expected.err
	}

	if expected.rows == nil {
		return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
	}
	return expected, nil
}
258
259
260
261 func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
262 namedArgs := make([]driver.NamedValue, len(args))
263 for i, v := range args {
264 namedArgs[i] = driver.NamedValue{
265 Ordinal: i + 1,
266 Value: v,
267 }
268 }
269
270 ex, err := c.exec(query, namedArgs)
271 if ex != nil {
272 time.Sleep(ex.delay)
273 }
274 if err != nil {
275 return nil, err
276 }
277
278 return ex.result, nil
279 }
280
// exec finds the expectation that satisfies an exec call for the given SQL
// text and arguments, validates it, and marks it as triggered.
//
// Fulfilled expectations are skipped.
//   - Ordered mode (c.ordered): the first unfulfilled expectation must be an
//     *ExpectedExec. Otherwise an error naming that expectation is returned.
//   - Unordered mode: the first unfulfilled *ExpectedExec whose SQL passes
//     c.queryMatcher.Match and whose arguments pass attemptArgMatch is chosen.
//
// Locking: each expectation is locked while it is inspected. The chosen one
// is still locked when the loop exits via break, and is released by the
// deferred Unlock after the loop. All others are unlocked before moving on.
//
// Returns:
//   - (expectation, nil) on success.
//   - (expectation, expectation.err) when the expectation was configured with
//     an error. Callers (ExecContext, Exec) still use the returned
//     expectation's delay.
//   - (nil, err) for no match, SQL/argument mismatch, or missing result.
func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, error) {
	var expected *ExpectedExec
	var fulfilled int
	var ok bool
	for _, next := range c.expected {
		next.Lock()
		if next.fulfilled() {
			next.Unlock()
			fulfilled++
			continue
		}

		if c.ordered {
			// Only the next pending expectation may be consumed. Keep it
			// locked and validate SQL/args after the loop.
			if expected, ok = next.(*ExpectedExec); ok {
				break
			}
			next.Unlock()
			return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
		}
		if exec, ok := next.(*ExpectedExec); ok {
			if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
				next.Unlock()
				continue
			}

			if err := exec.attemptArgMatch(args); err == nil {
				// Match found: exit the for loop with exec still locked.
				expected = exec
				break
			}
		}
		next.Unlock()
	}
	if expected == nil {
		msg := "call to ExecQuery '%s' with args %+v was not expected"
		if fulfilled == len(c.expected) {
			msg = "all expectations were already fulfilled, " + msg
		}
		return nil, fmt.Errorf(msg, query, args)
	}
	// Release the lock that was left held on the chosen expectation.
	defer expected.Unlock()

	// In ordered mode this is the first SQL check. In unordered mode the SQL
	// was already matched during the search.
	if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
		return nil, fmt.Errorf("ExecQuery: %v", err)
	}

	if err := expected.argsMatches(args); err != nil {
		return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err)
	}

	// triggered is set before the error/result checks below, so it is set
	// even when a missing-result error is returned.
	expected.triggered = true
	if expected.err != nil {
		return expected, expected.err
	}

	if expected.result == nil {
		return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
	}

	return expected, nil
}
341
342
343
344
345
346 func (c *sqlmock) NewRowsWithColumnDefinition(columns ...*Column) *Rows {
347 r := NewRowsWithColumnDefinition(columns...)
348 r.converter = c.converter
349 return r
350 }
351
352
353
354 func (c *sqlmock) NewColumn(name string) *Column {
355 return NewColumn(name)
356 }
357
View as plain text