1 package db
2
3 import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "reflect"
9 "regexp"
10
11 "github.com/go-sql-driver/mysql"
12 "github.com/letsencrypt/borp"
13 )
14
15
16
17
// ErrDatabaseOp wraps an underlying error together with the database
// operation that produced it and, when known, the table (or Go type) the
// operation was performed on. The wrapped error stays reachable through
// errors.Is and errors.As via Unwrap.
type ErrDatabaseOp struct {
	// Op is a short description of the operation, e.g. "select" or
	// "begin transaction".
	Op string
	// Table names the table or Go type involved; left empty when not
	// applicable, in which case Error omits it.
	Table string
	// Err is the underlying error returned by the database layer.
	Err error
}

// Error renders the value as "failed to <op> <table>: <err>", dropping the
// table segment when Table is empty.
func (e ErrDatabaseOp) Error() string {
	subject := e.Op
	if e.Table != "" {
		subject += " " + e.Table
	}
	return fmt.Sprintf("failed to %s: %s", subject, e.Err)
}

// Unwrap exposes the underlying error to errors.Is and errors.As.
func (e ErrDatabaseOp) Unwrap() error { return e.Err }
45
46
47
48
// IsNoRows reports whether err is, or wraps, sql.ErrNoRows. A nil error is
// never considered a "no rows" result.
func IsNoRows(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, sql.ErrNoRows)
}
52
53
54
55
56 func IsDuplicate(err error) bool {
57 var dbErr *mysql.MySQLError
58 return errors.As(err, &dbErr) && dbErr.Number == 1062
59 }
60
61
62
63 type WrappedMap struct {
64 dbMap *borp.DbMap
65 }
66
67 func NewWrappedMap(dbMap *borp.DbMap) *WrappedMap {
68 return &WrappedMap{dbMap: dbMap}
69 }
70
71 func (m *WrappedMap) TableFor(t reflect.Type, checkPK bool) (*borp.TableMap, error) {
72 return m.dbMap.TableFor(t, checkPK)
73 }
74
75 func (m *WrappedMap) Get(ctx context.Context, holder interface{}, keys ...interface{}) (interface{}, error) {
76 return WrappedExecutor{sqlExecutor: m.dbMap}.Get(ctx, holder, keys...)
77 }
78
79 func (m *WrappedMap) Insert(ctx context.Context, list ...interface{}) error {
80 return WrappedExecutor{sqlExecutor: m.dbMap}.Insert(ctx, list...)
81 }
82
83 func (m *WrappedMap) Update(ctx context.Context, list ...interface{}) (int64, error) {
84 return WrappedExecutor{sqlExecutor: m.dbMap}.Update(ctx, list...)
85 }
86
87 func (m *WrappedMap) Delete(ctx context.Context, list ...interface{}) (int64, error) {
88 return WrappedExecutor{sqlExecutor: m.dbMap}.Delete(ctx, list...)
89 }
90
91 func (m *WrappedMap) Select(ctx context.Context, holder interface{}, query string, args ...interface{}) ([]interface{}, error) {
92 return WrappedExecutor{sqlExecutor: m.dbMap}.Select(ctx, holder, query, args...)
93 }
94
95 func (m *WrappedMap) SelectOne(ctx context.Context, holder interface{}, query string, args ...interface{}) error {
96 return WrappedExecutor{sqlExecutor: m.dbMap}.SelectOne(ctx, holder, query, args...)
97 }
98
99 func (m *WrappedMap) SelectNullInt(ctx context.Context, query string, args ...interface{}) (sql.NullInt64, error) {
100 return WrappedExecutor{sqlExecutor: m.dbMap}.SelectNullInt(ctx, query, args...)
101 }
102
103 func (m *WrappedMap) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
104 return WrappedExecutor{sqlExecutor: m.dbMap}.QueryContext(ctx, query, args...)
105 }
106
107 func (m *WrappedMap) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
108 return WrappedExecutor{sqlExecutor: m.dbMap}.QueryRowContext(ctx, query, args...)
109 }
110
111 func (m *WrappedMap) SelectStr(ctx context.Context, query string, args ...interface{}) (string, error) {
112 return WrappedExecutor{sqlExecutor: m.dbMap}.SelectStr(ctx, query, args...)
113 }
114
115 func (m *WrappedMap) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
116 return WrappedExecutor{sqlExecutor: m.dbMap}.ExecContext(ctx, query, args...)
117 }
118
119 func (m *WrappedMap) BeginTx(ctx context.Context) (Transaction, error) {
120 tx, err := m.dbMap.BeginTx(ctx)
121 if err != nil {
122 return tx, ErrDatabaseOp{
123 Op: "begin transaction",
124 Err: err,
125 }
126 }
127 return WrappedTransaction{
128 transaction: tx,
129 }, err
130 }
131
132
133
134
135 type WrappedTransaction struct {
136 transaction *borp.Transaction
137 }
138
139 func (tx WrappedTransaction) Commit() error {
140 return tx.transaction.Commit()
141 }
142
143 func (tx WrappedTransaction) Rollback() error {
144 return tx.transaction.Rollback()
145 }
146
147 func (tx WrappedTransaction) Get(ctx context.Context, holder interface{}, keys ...interface{}) (interface{}, error) {
148 return (WrappedExecutor{sqlExecutor: tx.transaction}).Get(ctx, holder, keys...)
149 }
150
151 func (tx WrappedTransaction) Insert(ctx context.Context, list ...interface{}) error {
152 return (WrappedExecutor{sqlExecutor: tx.transaction}).Insert(ctx, list...)
153 }
154
155 func (tx WrappedTransaction) Update(ctx context.Context, list ...interface{}) (int64, error) {
156 return (WrappedExecutor{sqlExecutor: tx.transaction}).Update(ctx, list...)
157 }
158
159 func (tx WrappedTransaction) Delete(ctx context.Context, list ...interface{}) (int64, error) {
160 return (WrappedExecutor{sqlExecutor: tx.transaction}).Delete(ctx, list...)
161 }
162
163 func (tx WrappedTransaction) Select(ctx context.Context, holder interface{}, query string, args ...interface{}) ([]interface{}, error) {
164 return (WrappedExecutor{sqlExecutor: tx.transaction}).Select(ctx, holder, query, args...)
165 }
166
167 func (tx WrappedTransaction) SelectOne(ctx context.Context, holder interface{}, query string, args ...interface{}) error {
168 return (WrappedExecutor{sqlExecutor: tx.transaction}).SelectOne(ctx, holder, query, args...)
169 }
170
171 func (tx WrappedTransaction) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
172 return (WrappedExecutor{sqlExecutor: tx.transaction}).QueryContext(ctx, query, args...)
173 }
174
175 func (tx WrappedTransaction) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
176 return (WrappedExecutor{sqlExecutor: tx.transaction}).ExecContext(ctx, query, args...)
177 }
178
179
180
181
182 type WrappedExecutor struct {
183 sqlExecutor borp.SqlExecutor
184 }
185
186 func errForOp(operation string, err error, list []interface{}) ErrDatabaseOp {
187 table := "unknown"
188 if len(list) > 0 {
189 table = fmt.Sprintf("%T", list[0])
190 }
191 return ErrDatabaseOp{
192 Op: operation,
193 Table: table,
194 Err: err,
195 }
196 }
197
198 func errForQuery(query, operation string, err error, list []interface{}) ErrDatabaseOp {
199
200 table := tableFromQuery(query)
201 if table == "" && len(list) > 0 {
202
203
204
205 table = fmt.Sprintf("%T (unknown table)", list[0])
206 } else if table == "" {
207
208
209 table = "unknown table"
210 }
211
212 return ErrDatabaseOp{
213 Op: operation,
214 Table: table,
215 Err: err,
216 }
217 }
218
219 func (we WrappedExecutor) Get(ctx context.Context, holder interface{}, keys ...interface{}) (interface{}, error) {
220 res, err := we.sqlExecutor.Get(ctx, holder, keys...)
221 if err != nil {
222 return res, errForOp("get", err, []interface{}{holder})
223 }
224 return res, err
225 }
226
227 func (we WrappedExecutor) Insert(ctx context.Context, list ...interface{}) error {
228 err := we.sqlExecutor.Insert(ctx, list...)
229 if err != nil {
230 return errForOp("insert", err, list)
231 }
232 return nil
233 }
234
235 func (we WrappedExecutor) Update(ctx context.Context, list ...interface{}) (int64, error) {
236 updatedRows, err := we.sqlExecutor.Update(ctx, list...)
237 if err != nil {
238 return updatedRows, errForOp("update", err, list)
239 }
240 return updatedRows, err
241 }
242
243 func (we WrappedExecutor) Delete(ctx context.Context, list ...interface{}) (int64, error) {
244 deletedRows, err := we.sqlExecutor.Delete(ctx, list...)
245 if err != nil {
246 return deletedRows, errForOp("delete", err, list)
247 }
248 return deletedRows, err
249 }
250
251 func (we WrappedExecutor) Select(ctx context.Context, holder interface{}, query string, args ...interface{}) ([]interface{}, error) {
252 result, err := we.sqlExecutor.Select(ctx, holder, query, args...)
253 if err != nil {
254 return result, errForQuery(query, "select", err, []interface{}{holder})
255 }
256 return result, err
257 }
258
259 func (we WrappedExecutor) SelectOne(ctx context.Context, holder interface{}, query string, args ...interface{}) error {
260 err := we.sqlExecutor.SelectOne(ctx, holder, query, args...)
261 if err != nil {
262 return errForQuery(query, "select one", err, []interface{}{holder})
263 }
264 return nil
265 }
266
267 func (we WrappedExecutor) SelectNullInt(ctx context.Context, query string, args ...interface{}) (sql.NullInt64, error) {
268 rows, err := we.sqlExecutor.SelectNullInt(ctx, query, args...)
269 if err != nil {
270 return sql.NullInt64{}, errForQuery(query, "select", err, nil)
271 }
272 return rows, nil
273 }
274
275 func (we WrappedExecutor) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
276
277
278 return we.sqlExecutor.QueryRowContext(ctx, query, args...)
279 }
280
281 func (we WrappedExecutor) SelectStr(ctx context.Context, query string, args ...interface{}) (string, error) {
282 str, err := we.sqlExecutor.SelectStr(ctx, query, args...)
283 if err != nil {
284 return "", errForQuery(query, "select", err, nil)
285 }
286 return str, nil
287 }
288
289 func (we WrappedExecutor) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
290 rows, err := we.sqlExecutor.QueryContext(ctx, query, args...)
291 if err != nil {
292 return nil, errForQuery(query, "select", err, nil)
293 }
294 return rows, nil
295 }
296
var (
	// selectTableRegexp captures the table name that follows FROM in a SELECT
	// whose column list contains only letters, digits, ':', '.', '(', ')',
	// ',', ' ', '_', '*' and backticks. The capture may contain backticks or
	// commas.
	selectTableRegexp = regexp.MustCompile(`(?i)^\s*select\s+[a-z\d:\.\(\), \_\*` + "`" + `]+\s+from\s+([a-z\d\_,` + "`" + `]+)`)

	// The three patterns below use a capture class that allows spaces and
	// letters, so the regexp backtracks to find the keyword that ends the
	// table name. That keyword is anchored with \b. Without it, a column name
	// that merely starts with the keyword (e.g. "settled" or "whereabouts")
	// would be matched instead. The capture would then wrongly include the
	// real keyword, e.g. "foo SET" instead of "foo".

	// insertTableRegexp captures the table name from INSERT INTO <t> SET ...
	// or INSERT INTO <t> (columns) ... statements.
	insertTableRegexp = regexp.MustCompile(`(?i)^\s*insert\s+into\s+([a-z\d \_,` + "`" + `]+)\s+(?:set\b|\()`)

	// updateTableRegexp captures the table name(s) from UPDATE <t> SET ...
	updateTableRegexp = regexp.MustCompile(`(?i)^\s*update\s+([a-z\d \_,` + "`" + `]+)\s+set\b`)

	// deleteTableRegexp captures the table name from DELETE FROM <t> WHERE ...
	// A DELETE without a WHERE clause is not matched.
	deleteTableRegexp = regexp.MustCompile(`(?i)^\s*delete\s+from\s+([a-z\d \_,` + "`" + `]+)\s+where\b`)

	// tableRegexps lists the patterns tableFromQuery tries, in order. The
	// first one that matches supplies the table name. Insert/update/delete
	// patterns are included because raw statements of those kinds can be run
	// through ExecContext.
	tableRegexps = []*regexp.Regexp{
		selectTableRegexp,
		insertTableRegexp,
		updateTableRegexp,
		deleteTableRegexp,
	}
)
320
321
322
323
324 func tableFromQuery(query string) string {
325 for _, r := range tableRegexps {
326 if matches := r.FindStringSubmatch(query); len(matches) >= 2 {
327 return matches[1]
328 }
329 }
330 return ""
331 }
332
333 func (we WrappedExecutor) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
334 res, err := we.sqlExecutor.ExecContext(ctx, query, args...)
335 if err != nil {
336 return res, errForQuery(query, "exec", err, args)
337 }
338 return res, nil
339 }
340
View as plain text