1 package db
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "reflect"
8 "regexp"
9 "strings"
10 )
11
12
13
// mariaDBUnquotedIdentifierRE matches strings made only of the ASCII
// characters this package allows in an unquoted MariaDB identifier:
// digits, letters, '$' and '_'.
var mariaDBUnquotedIdentifierRE = regexp.MustCompile("^[0-9a-zA-Z$_]+$")

// validMariaDBUnquotedIdentifier returns nil if s is safe to interpolate into
// a MariaDB statement as an unquoted identifier, and a descriptive error
// otherwise.
//
// Besides the character-set check, it rejects two shapes that MariaDB would
// lex as numbers rather than identifiers:
//   - identifiers consisting solely of digits (e.g. "123");
//   - identifiers whose leading digit run is followed by 'e' or 'E'
//     (e.g. "1e5", "1E5"), which look like floating-point literals.
//
// NOTE(review): hex/bit literal prefixes such as "0x1f" or "0b1" are not
// rejected here; confirm whether MariaDB's lexer treats those as literals
// for the identifiers this package accepts.
func validMariaDBUnquotedIdentifier(s string) error {
	if !mariaDBUnquotedIdentifierRE.MatchString(s) {
		return fmt.Errorf("invalid MariaDB identifier %q", s)
	}

	// Length of the leading run of ASCII digits. The regex above guarantees
	// len(s) >= 1 and that every byte is ASCII.
	digits := 0
	for digits < len(s) && s[digits] >= '0' && s[digits] <= '9' {
		digits++
	}

	switch {
	case digits == len(s):
		return fmt.Errorf("MariaDB identifier contains only numerals: %q", s)
	case digits > 0 && (s[digits] == 'e' || s[digits] == 'E'):
		// Exponent markers are case-insensitive in numeric literals, so
		// "1E5" is just as ambiguous as "1e5".
		return fmt.Errorf("MariaDB identifier looks like floating point: %q", s)
	}
	return nil
}
38
39
40
41 func NewMappedSelector[T any](executor MappedExecutor) (MappedSelector[T], error) {
42 var throwaway T
43 t := reflect.TypeOf(throwaway)
44
45
46
47
48
49
50
51
52
53
54
55
56 columns := make([]string, 0)
57 seen := make(map[string]struct{})
58 for i := 0; i < t.NumField(); i++ {
59 field := t.Field(i)
60 if field.Anonymous {
61 return nil, fmt.Errorf("struct contains anonymous embedded struct %q", field.Name)
62 }
63 column := strings.ToLower(t.Field(i).Name)
64 err := validMariaDBUnquotedIdentifier(column)
65 if err != nil {
66 return nil, fmt.Errorf("struct field maps to unsafe db column name %q", column)
67 }
68 if _, found := seen[column]; found {
69 return nil, fmt.Errorf("struct fields map to duplicate column name %q", column)
70 }
71 seen[column] = struct{}{}
72 columns = append(columns, column)
73 }
74
75 return &mappedSelector[T]{wrapped: executor, columns: columns}, nil
76 }
77
78 type mappedSelector[T any] struct {
79 wrapped MappedExecutor
80 columns []string
81 }
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97 func (ts mappedSelector[T]) QueryContext(ctx context.Context, clauses string, args ...interface{}) (Rows[T], error) {
98
99 var throwaway T
100 tableMap, err := ts.wrapped.TableFor(reflect.TypeOf(throwaway), false)
101 if err != nil {
102 return nil, fmt.Errorf("database model type not mapped to table name: %w", err)
103 }
104
105 return ts.QueryFrom(ctx, tableMap.TableName, clauses, args...)
106 }
107
108
109
110
111
112
113
114
115 func (ts mappedSelector[T]) QueryFrom(ctx context.Context, tablename string, clauses string, args ...interface{}) (Rows[T], error) {
116 err := validMariaDBUnquotedIdentifier(tablename)
117 if err != nil {
118 return nil, err
119 }
120
121
122
123 query := fmt.Sprintf(
124 "SELECT %s FROM %s %s",
125 strings.Join(ts.columns, ", "),
126 tablename,
127 clauses,
128 )
129
130 r, err := ts.wrapped.QueryContext(ctx, query, args...)
131 if err != nil {
132 return nil, fmt.Errorf("reading db: %w", err)
133 }
134
135 return &rows[T]{wrapped: r, numCols: len(ts.columns)}, nil
136 }
137
138
139
// rows adapts a *sql.Rows into an iterator that yields values of type T.
type rows[T any] struct {
	// wrapped is the underlying result set.
	wrapped *sql.Rows
	// numCols is the number of selected columns; Get scans them into
	// fields 0..numCols-1 of T.
	numCols int
}
144
145
146
147 func (r rows[T]) Next() bool {
148 return r.wrapped.Next()
149 }
150
151
152
153
154 func (r rows[T]) Get() (*T, error) {
155 result := new(T)
156 v := reflect.ValueOf(result)
157
158
159
160
161 scanTargets := make([]interface{}, r.numCols)
162 for i := range scanTargets {
163 field := v.Elem().Field(i)
164 scanTargets[i] = field.Addr().Interface()
165 }
166
167 err := r.wrapped.Scan(scanTargets...)
168 if err != nil {
169 return nil, fmt.Errorf("reading db row: %w", err)
170 }
171
172 return result, nil
173 }
174
175
176
177 func (r rows[T]) Err() error {
178 return r.wrapped.Err()
179 }
180
181
182
183 func (r rows[T]) Close() error {
184 return r.wrapped.Close()
185 }
186
View as plain text