1 package goqu
2
3 import (
4 "context"
5 "database/sql"
6 "sync"
7
8 "github.com/doug-martin/goqu/v9/exec"
9 )
10
11 type (
12 Logger interface {
13 Printf(format string, v ...interface{})
14 }
15
16
17 SQLDatabase interface {
18 Begin() (*sql.Tx, error)
19 BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
20 ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
21 PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
22 QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
23 QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
24 }
25
26
27 Database struct {
28 logger Logger
29 dialect string
30
31 Db SQLDatabase
32 qf exec.QueryFactory
33 qfOnce sync.Once
34 }
35 )
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65 func newDatabase(dialect string, db SQLDatabase) *Database {
66 return &Database{
67 logger: nil,
68 dialect: dialect,
69 Db: db,
70 qf: nil,
71 qfOnce: sync.Once{},
72 }
73 }
74
75
76 func (d *Database) Dialect() string {
77 return d.dialect
78 }
79
80
81 func (d *Database) Begin() (*TxDatabase, error) {
82 sqlTx, err := d.Db.Begin()
83 if err != nil {
84 return nil, err
85 }
86 tx := NewTx(d.dialect, sqlTx)
87 tx.Logger(d.logger)
88 return tx, nil
89 }
90
91
92 func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TxDatabase, error) {
93 sqlTx, err := d.Db.BeginTx(ctx, opts)
94 if err != nil {
95 return nil, err
96 }
97 tx := NewTx(d.dialect, sqlTx)
98 tx.Logger(d.logger)
99 return tx, nil
100 }
101
102
103 func (d *Database) WithTx(fn func(*TxDatabase) error) error {
104 tx, err := d.Begin()
105 if err != nil {
106 return err
107 }
108 return tx.Wrap(func() error { return fn(tx) })
109 }
110
111
112
113
114
115
116
117
118
119 func (d *Database) From(from ...interface{}) *SelectDataset {
120 return newDataset(d.dialect, d.queryFactory()).From(from...)
121 }
122
123 func (d *Database) Select(cols ...interface{}) *SelectDataset {
124 return newDataset(d.dialect, d.queryFactory()).Select(cols...)
125 }
126
127 func (d *Database) Update(table interface{}) *UpdateDataset {
128 return newUpdateDataset(d.dialect, d.queryFactory()).Table(table)
129 }
130
131 func (d *Database) Insert(table interface{}) *InsertDataset {
132 return newInsertDataset(d.dialect, d.queryFactory()).Into(table)
133 }
134
135 func (d *Database) Delete(table interface{}) *DeleteDataset {
136 return newDeleteDataset(d.dialect, d.queryFactory()).From(table)
137 }
138
139 func (d *Database) Truncate(table ...interface{}) *TruncateDataset {
140 return newTruncateDataset(d.dialect, d.queryFactory()).Table(table...)
141 }
142
143
144 func (d *Database) Logger(logger Logger) {
145 d.logger = logger
146 }
147
148
149 func (d *Database) Trace(op, sqlString string, args ...interface{}) {
150 if d.logger != nil {
151 if sqlString != "" {
152 if len(args) != 0 {
153 d.logger.Printf("[goqu] %s [query:=`%s` args:=%+v]", op, sqlString, args)
154 } else {
155 d.logger.Printf("[goqu] %s [query:=`%s`]", op, sqlString)
156 }
157 } else {
158 d.logger.Printf("[goqu] %s", op)
159 }
160 }
161 }
162
163
164
165
166
167
168 func (d *Database) Exec(query string, args ...interface{}) (sql.Result, error) {
169 return d.ExecContext(context.Background(), query, args...)
170 }
171
172
173
174
175
176
177 func (d *Database) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
178 d.Trace("EXEC", query, args...)
179 return d.Db.ExecContext(ctx, query, args...)
180 }
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207 func (d *Database) Prepare(query string) (*sql.Stmt, error) {
208 return d.PrepareContext(context.Background(), query)
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236 func (d *Database) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
237 d.Trace("PREPARE", query)
238 return d.Db.PrepareContext(ctx, query)
239 }
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263 func (d *Database) Query(query string, args ...interface{}) (*sql.Rows, error) {
264 return d.QueryContext(context.Background(), query, args...)
265 }
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289 func (d *Database) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
290 d.Trace("QUERY", query, args...)
291 return d.Db.QueryContext(ctx, query, args...)
292 }
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310 func (d *Database) QueryRow(query string, args ...interface{}) *sql.Row {
311 return d.QueryRowContext(context.Background(), query, args...)
312 }
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330 func (d *Database) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
331 d.Trace("QUERY ROW", query, args...)
332 return d.Db.QueryRowContext(ctx, query, args...)
333 }
334
335 func (d *Database) queryFactory() exec.QueryFactory {
336 d.qfOnce.Do(func() {
337 d.qf = exec.NewQueryFactory(d)
338 })
339 return d.qf
340 }
341
342
343
344
345
346
347
348
349
350 func (d *Database) ScanStructs(i interface{}, query string, args ...interface{}) error {
351 return d.ScanStructsContext(context.Background(), i, query, args...)
352 }
353
354
355
356
357
358
359
360
361
362 func (d *Database) ScanStructsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
363 return d.queryFactory().FromSQL(query, args...).ScanStructsContext(ctx, i)
364 }
365
366
367
368
369
370
371
372
373
374 func (d *Database) ScanStruct(i interface{}, query string, args ...interface{}) (bool, error) {
375 return d.ScanStructContext(context.Background(), i, query, args...)
376 }
377
378
379
380
381
382
383
384
385
386 func (d *Database) ScanStructContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
387 return d.queryFactory().FromSQL(query, args...).ScanStructContext(ctx, i)
388 }
389
390
391
392
393
394
395
396
397
398 func (d *Database) ScanVals(i interface{}, query string, args ...interface{}) error {
399 return d.ScanValsContext(context.Background(), i, query, args...)
400 }
401
402
403
404
405
406
407
408
409
410 func (d *Database) ScanValsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
411 return d.queryFactory().FromSQL(query, args...).ScanValsContext(ctx, i)
412 }
413
414
415
416
417
418
419
420
421
422 func (d *Database) ScanVal(i interface{}, query string, args ...interface{}) (bool, error) {
423 return d.ScanValContext(context.Background(), i, query, args...)
424 }
425
426
427
428
429
430
431
432
433
434 func (d *Database) ScanValContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
435 return d.queryFactory().FromSQL(query, args...).ScanValContext(ctx, i)
436 }
437
438
439 type (
440
441
442 SQLTx interface {
443 ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
444 PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
445 QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
446 QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
447 Commit() error
448 Rollback() error
449 }
450 TxDatabase struct {
451 logger Logger
452 dialect string
453 Tx SQLTx
454 qf exec.QueryFactory
455 qfOnce sync.Once
456 }
457 )
458
459
460 func NewTx(dialect string, tx SQLTx) *TxDatabase {
461 return &TxDatabase{dialect: dialect, Tx: tx}
462 }
463
464
465 func (td *TxDatabase) Dialect() string {
466 return td.dialect
467 }
468
469
470 func (td *TxDatabase) From(cols ...interface{}) *SelectDataset {
471 return newDataset(td.dialect, td.queryFactory()).From(cols...)
472 }
473
474 func (td *TxDatabase) Select(cols ...interface{}) *SelectDataset {
475 return newDataset(td.dialect, td.queryFactory()).Select(cols...)
476 }
477
478 func (td *TxDatabase) Update(table interface{}) *UpdateDataset {
479 return newUpdateDataset(td.dialect, td.queryFactory()).Table(table)
480 }
481
482 func (td *TxDatabase) Insert(table interface{}) *InsertDataset {
483 return newInsertDataset(td.dialect, td.queryFactory()).Into(table)
484 }
485
486 func (td *TxDatabase) Delete(table interface{}) *DeleteDataset {
487 return newDeleteDataset(td.dialect, td.queryFactory()).From(table)
488 }
489
490 func (td *TxDatabase) Truncate(table ...interface{}) *TruncateDataset {
491 return newTruncateDataset(td.dialect, td.queryFactory()).Table(table...)
492 }
493
494
495 func (td *TxDatabase) Logger(logger Logger) {
496 td.logger = logger
497 }
498
499 func (td *TxDatabase) Trace(op, sqlString string, args ...interface{}) {
500 if td.logger != nil {
501 if sqlString != "" {
502 if len(args) != 0 {
503 td.logger.Printf("[goqu - transaction] %s [query:=`%s` args:=%+v] ", op, sqlString, args)
504 } else {
505 td.logger.Printf("[goqu - transaction] %s [query:=`%s`] ", op, sqlString)
506 }
507 } else {
508 td.logger.Printf("[goqu - transaction] %s", op)
509 }
510 }
511 }
512
513
514 func (td *TxDatabase) Exec(query string, args ...interface{}) (sql.Result, error) {
515 return td.ExecContext(context.Background(), query, args...)
516 }
517
518
519 func (td *TxDatabase) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
520 td.Trace("EXEC", query, args...)
521 return td.Tx.ExecContext(ctx, query, args...)
522 }
523
524
525 func (td *TxDatabase) Prepare(query string) (*sql.Stmt, error) {
526 return td.PrepareContext(context.Background(), query)
527 }
528
529
530 func (td *TxDatabase) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
531 td.Trace("PREPARE", query)
532 return td.Tx.PrepareContext(ctx, query)
533 }
534
535
536 func (td *TxDatabase) Query(query string, args ...interface{}) (*sql.Rows, error) {
537 return td.QueryContext(context.Background(), query, args...)
538 }
539
540
541 func (td *TxDatabase) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
542 td.Trace("QUERY", query, args...)
543 return td.Tx.QueryContext(ctx, query, args...)
544 }
545
546
547 func (td *TxDatabase) QueryRow(query string, args ...interface{}) *sql.Row {
548 return td.QueryRowContext(context.Background(), query, args...)
549 }
550
551
552 func (td *TxDatabase) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
553 td.Trace("QUERY ROW", query, args...)
554 return td.Tx.QueryRowContext(ctx, query, args...)
555 }
556
557 func (td *TxDatabase) queryFactory() exec.QueryFactory {
558 td.qfOnce.Do(func() {
559 td.qf = exec.NewQueryFactory(td)
560 })
561 return td.qf
562 }
563
564
565 func (td *TxDatabase) ScanStructs(i interface{}, query string, args ...interface{}) error {
566 return td.ScanStructsContext(context.Background(), i, query, args...)
567 }
568
569
570 func (td *TxDatabase) ScanStructsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
571 return td.queryFactory().FromSQL(query, args...).ScanStructsContext(ctx, i)
572 }
573
574
575 func (td *TxDatabase) ScanStruct(i interface{}, query string, args ...interface{}) (bool, error) {
576 return td.ScanStructContext(context.Background(), i, query, args...)
577 }
578
579
580 func (td *TxDatabase) ScanStructContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
581 return td.queryFactory().FromSQL(query, args...).ScanStructContext(ctx, i)
582 }
583
584
585 func (td *TxDatabase) ScanVals(i interface{}, query string, args ...interface{}) error {
586 return td.ScanValsContext(context.Background(), i, query, args...)
587 }
588
589
590 func (td *TxDatabase) ScanValsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
591 return td.queryFactory().FromSQL(query, args...).ScanValsContext(ctx, i)
592 }
593
594
595 func (td *TxDatabase) ScanVal(i interface{}, query string, args ...interface{}) (bool, error) {
596 return td.ScanValContext(context.Background(), i, query, args...)
597 }
598
599
600 func (td *TxDatabase) ScanValContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
601 return td.queryFactory().FromSQL(query, args...).ScanValContext(ctx, i)
602 }
603
604
605 func (td *TxDatabase) Commit() error {
606 td.Trace("COMMIT", "")
607 return td.Tx.Commit()
608 }
609
610
611 func (td *TxDatabase) Rollback() error {
612 td.Trace("ROLLBACK", "")
613 return td.Tx.Rollback()
614 }
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631 func (td *TxDatabase) Wrap(fn func() error) (err error) {
632 defer func() {
633 if p := recover(); p != nil {
634 _ = td.Rollback()
635 panic(p)
636 }
637 if err != nil {
638 if rollbackErr := td.Rollback(); rollbackErr != nil {
639 err = rollbackErr
640 }
641 } else {
642 if commitErr := td.Commit(); commitErr != nil {
643 err = commitErr
644 }
645 }
646 }()
647 return fn()
648 }
649
View as plain text