1 package sanitize
2
3 import (
4 "bytes"
5 "encoding/hex"
6 "fmt"
7 "strconv"
8 "strings"
9 "time"
10 "unicode/utf8"
11 )
12
13
14
// Part is one element of a parsed query. Sanitize understands two dynamic
// types: a string, which is copied to the output unchanged, and an int, which
// is a 1-based index into the arguments passed to Sanitize. Any other dynamic
// type makes Sanitize return an error.
type Part any
16
// Query is a SQL statement split into raw text and argument placeholders. It
// is built by NewQuery and rendered by Sanitize.
type Query struct {
	// Parts holds the pieces of the statement in the order they are written
	// to the output.
	Parts []Part
}
20
21
22
23
24
// replacementcharacterwidth is the encoded size in bytes of U+FFFD, the rune
// that utf8.RuneError stands for. When decoding fails, utf8.DecodeRuneInString
// reports RuneError with a width of 0 (empty input) or 1 (invalid encoding),
// so a RuneError whose width equals this constant is a genuine U+FFFD present
// in the input rather than a decoding failure.
const replacementcharacterwidth = 3
26
27 func (q *Query) Sanitize(args ...any) (string, error) {
28 argUse := make([]bool, len(args))
29 buf := &bytes.Buffer{}
30
31 for _, part := range q.Parts {
32 var str string
33 switch part := part.(type) {
34 case string:
35 str = part
36 case int:
37 argIdx := part - 1
38
39 if argIdx < 0 {
40 return "", fmt.Errorf("first sql argument must be > 0")
41 }
42
43 if argIdx >= len(args) {
44 return "", fmt.Errorf("insufficient arguments")
45 }
46 arg := args[argIdx]
47 switch arg := arg.(type) {
48 case nil:
49 str = "null"
50 case int64:
51 str = strconv.FormatInt(arg, 10)
52 case float64:
53 str = strconv.FormatFloat(arg, 'f', -1, 64)
54 case bool:
55 str = strconv.FormatBool(arg)
56 case []byte:
57 str = QuoteBytes(arg)
58 case string:
59 str = QuoteString(arg)
60 case time.Time:
61 str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
62 default:
63 return "", fmt.Errorf("invalid arg type: %T", arg)
64 }
65 argUse[argIdx] = true
66
67
68
69 str = " " + str + " "
70 default:
71 return "", fmt.Errorf("invalid Part type: %T", part)
72 }
73 buf.WriteString(str)
74 }
75
76 for i, used := range argUse {
77 if !used {
78 return "", fmt.Errorf("unused argument: %d", i)
79 }
80 }
81 return buf.String(), nil
82 }
83
84 func NewQuery(sql string) (*Query, error) {
85 l := &sqlLexer{
86 src: sql,
87 stateFn: rawState,
88 }
89
90 for l.stateFn != nil {
91 l.stateFn = l.stateFn(l)
92 }
93
94 query := &Query{Parts: l.parts}
95
96 return query, nil
97 }
98
// QuoteString returns str wrapped in single quotes, with every single quote
// inside str doubled. No other character is altered.
func QuoteString(str string) string {
	var b strings.Builder
	b.Grow(len(str) + 2)

	b.WriteByte('\'')
	for {
		i := strings.IndexByte(str, '\'')
		if i < 0 {
			break
		}
		// Copy up to and including the quote, then write it a second time.
		b.WriteString(str[:i+1])
		b.WriteByte('\'')
		str = str[i+1:]
	}
	b.WriteString(str)
	b.WriteByte('\'')

	return b.String()
}
102
// QuoteBytes returns buf as a single-quoted literal of the form '\x<hex>',
// where <hex> is the lowercase hexadecimal encoding of buf. A nil or empty
// buf yields '\x'.
func QuoteBytes(buf []byte) string {
	const prefix = `'\x`

	out := make([]byte, len(prefix)+hex.EncodedLen(len(buf))+1)
	n := copy(out, prefix)
	n += hex.Encode(out[n:], buf)
	out[n] = '\''

	return string(out)
}
106
// sqlLexer is the scanner state shared by the state functions that split a
// SQL string into Parts.
type sqlLexer struct {
	src     string  // SQL text being scanned
	start   int     // byte offset in src where the raw text not yet appended to parts begins
	pos     int     // byte offset in src of the next rune to read
	nested  int     // number of "/*" openers seen inside a block comment and not yet closed
	stateFn stateFn // state to run next; NewQuery stops when it is nil
	parts   []Part  // parts collected so far
}
115
// stateFn is one state of the lexer. It consumes input from the lexer it is
// given and returns the state to run next, or nil when scanning is finished.
type stateFn func(*sqlLexer) stateFn
117
118 func rawState(l *sqlLexer) stateFn {
119 for {
120 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
121 l.pos += width
122
123 switch r {
124 case 'e', 'E':
125 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
126 if nextRune == '\'' {
127 l.pos += width
128 return escapeStringState
129 }
130 case '\'':
131 return singleQuoteState
132 case '"':
133 return doubleQuoteState
134 case '$':
135 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
136 if '0' <= nextRune && nextRune <= '9' {
137 if l.pos-l.start > 0 {
138 l.parts = append(l.parts, l.src[l.start:l.pos-width])
139 }
140 l.start = l.pos
141 return placeholderState
142 }
143 case '-':
144 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
145 if nextRune == '-' {
146 l.pos += width
147 return oneLineCommentState
148 }
149 case '/':
150 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
151 if nextRune == '*' {
152 l.pos += width
153 return multilineCommentState
154 }
155 case utf8.RuneError:
156 if width != replacementcharacterwidth {
157 if l.pos-l.start > 0 {
158 l.parts = append(l.parts, l.src[l.start:l.pos])
159 l.start = l.pos
160 }
161 return nil
162 }
163 }
164 }
165 }
166
167 func singleQuoteState(l *sqlLexer) stateFn {
168 for {
169 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
170 l.pos += width
171
172 switch r {
173 case '\'':
174 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
175 if nextRune != '\'' {
176 return rawState
177 }
178 l.pos += width
179 case utf8.RuneError:
180 if width != replacementcharacterwidth {
181 if l.pos-l.start > 0 {
182 l.parts = append(l.parts, l.src[l.start:l.pos])
183 l.start = l.pos
184 }
185 return nil
186 }
187 }
188 }
189 }
190
191 func doubleQuoteState(l *sqlLexer) stateFn {
192 for {
193 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
194 l.pos += width
195
196 switch r {
197 case '"':
198 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
199 if nextRune != '"' {
200 return rawState
201 }
202 l.pos += width
203 case utf8.RuneError:
204 if width != replacementcharacterwidth {
205 if l.pos-l.start > 0 {
206 l.parts = append(l.parts, l.src[l.start:l.pos])
207 l.start = l.pos
208 }
209 return nil
210 }
211 }
212 }
213 }
214
215
216
217 func placeholderState(l *sqlLexer) stateFn {
218 num := 0
219
220 for {
221 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
222 l.pos += width
223
224 if '0' <= r && r <= '9' {
225 num *= 10
226 num += int(r - '0')
227 } else {
228 l.parts = append(l.parts, num)
229 l.pos -= width
230 l.start = l.pos
231 return rawState
232 }
233 }
234 }
235
236 func escapeStringState(l *sqlLexer) stateFn {
237 for {
238 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
239 l.pos += width
240
241 switch r {
242 case '\\':
243 _, width = utf8.DecodeRuneInString(l.src[l.pos:])
244 l.pos += width
245 case '\'':
246 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
247 if nextRune != '\'' {
248 return rawState
249 }
250 l.pos += width
251 case utf8.RuneError:
252 if width != replacementcharacterwidth {
253 if l.pos-l.start > 0 {
254 l.parts = append(l.parts, l.src[l.start:l.pos])
255 l.start = l.pos
256 }
257 return nil
258 }
259 }
260 }
261 }
262
263 func oneLineCommentState(l *sqlLexer) stateFn {
264 for {
265 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
266 l.pos += width
267
268 switch r {
269 case '\\':
270 _, width = utf8.DecodeRuneInString(l.src[l.pos:])
271 l.pos += width
272 case '\n', '\r':
273 return rawState
274 case utf8.RuneError:
275 if width != replacementcharacterwidth {
276 if l.pos-l.start > 0 {
277 l.parts = append(l.parts, l.src[l.start:l.pos])
278 l.start = l.pos
279 }
280 return nil
281 }
282 }
283 }
284 }
285
286 func multilineCommentState(l *sqlLexer) stateFn {
287 for {
288 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
289 l.pos += width
290
291 switch r {
292 case '/':
293 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
294 if nextRune == '*' {
295 l.pos += width
296 l.nested++
297 }
298 case '*':
299 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
300 if nextRune != '/' {
301 continue
302 }
303
304 l.pos += width
305 if l.nested == 0 {
306 return rawState
307 }
308 l.nested--
309
310 case utf8.RuneError:
311 if width != replacementcharacterwidth {
312 if l.pos-l.start > 0 {
313 l.parts = append(l.parts, l.src[l.start:l.pos])
314 l.start = l.pos
315 }
316 return nil
317 }
318 }
319 }
320 }
321
322
323
324
325 func SanitizeSQL(sql string, args ...any) (string, error) {
326 query, err := NewQuery(sql)
327 if err != nil {
328 return "", err
329 }
330 return query.Sanitize(args...)
331 }
332
View as plain text