1 package pgx
2
3 import (
4 "context"
5 "strconv"
6 "strings"
7 "unicode/utf8"
8 )
9
10
11
12
13
14
15
16
17
18
19
// NamedArgs maps argument names to values. When used through RewriteQuery,
// each @name placeholder in the SQL text is replaced by a positional $N
// placeholder and bound to the value stored under name.
type NamedArgs map[string]any
21
22
23 func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
24 l := &sqlLexer{
25 src: sql,
26 stateFn: rawState,
27 nameToOrdinal: make(map[namedArg]int, len(na)),
28 }
29
30 for l.stateFn != nil {
31 l.stateFn = l.stateFn(l)
32 }
33
34 sb := strings.Builder{}
35 for _, p := range l.parts {
36 switch p := p.(type) {
37 case string:
38 sb.WriteString(p)
39 case namedArg:
40 sb.WriteRune('$')
41 sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
42 }
43 }
44
45 newArgs = make([]any, len(l.nameToOrdinal))
46 for name, ordinal := range l.nameToOrdinal {
47 newArgs[ordinal-1] = na[string(name)]
48 }
49
50 return sb.String(), newArgs, nil
51 }
52
// namedArg is the name of a named argument, without its leading '@', as
// recorded by the lexer in sqlLexer.parts. Having a distinct type lets
// RewriteQuery tell placeholder parts apart from literal SQL text (string).
type namedArg string
54
// sqlLexer is the state of the small hand-written scanner that RewriteQuery
// runs over the SQL text to locate @name placeholders.
//
// State functions advance pos over src. Completed pieces are appended to
// parts, either as literal SQL text (string) or as an argument name
// (namedArg), after which start is moved up to pos.
type sqlLexer struct {
	src     string  // SQL text being scanned.
	start   int     // Byte offset where the piece not yet appended to parts begins.
	pos     int     // Byte offset of the next rune to decode.
	nested  int     // Depth of /* */ comments nested inside the outermost one.
	stateFn stateFn // State to run next; nil stops the lexer.
	parts   []any   // Output: string for SQL text, namedArg for a placeholder name.

	nameToOrdinal map[namedArg]int // 1-based $N ordinal of each distinct name, by first appearance.
}
65
// stateFn is one state of the lexer: it consumes input from the lexer and
// returns the next state to run, or nil when scanning is finished.
type stateFn func(*sqlLexer) stateFn
67
68 func rawState(l *sqlLexer) stateFn {
69 for {
70 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
71 l.pos += width
72
73 switch r {
74 case 'e', 'E':
75 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
76 if nextRune == '\'' {
77 l.pos += width
78 return escapeStringState
79 }
80 case '\'':
81 return singleQuoteState
82 case '"':
83 return doubleQuoteState
84 case '@':
85 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
86 if isLetter(nextRune) || nextRune == '_' {
87 if l.pos-l.start > 0 {
88 l.parts = append(l.parts, l.src[l.start:l.pos-width])
89 }
90 l.start = l.pos
91 return namedArgState
92 }
93 case '-':
94 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
95 if nextRune == '-' {
96 l.pos += width
97 return oneLineCommentState
98 }
99 case '/':
100 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
101 if nextRune == '*' {
102 l.pos += width
103 return multilineCommentState
104 }
105 case utf8.RuneError:
106 if l.pos-l.start > 0 {
107 l.parts = append(l.parts, l.src[l.start:l.pos])
108 l.start = l.pos
109 }
110 return nil
111 }
112 }
113 }
114
// isLetter reports whether r is an ASCII letter (a-z or A-Z).
// Non-ASCII letters are not accepted.
func isLetter(r rune) bool {
	switch {
	case 'a' <= r && r <= 'z':
		return true
	case 'A' <= r && r <= 'Z':
		return true
	default:
		return false
	}
}
118
119 func namedArgState(l *sqlLexer) stateFn {
120 for {
121 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
122 l.pos += width
123
124 if r == utf8.RuneError {
125 if l.pos-l.start > 0 {
126 na := namedArg(l.src[l.start:l.pos])
127 if _, found := l.nameToOrdinal[na]; !found {
128 l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
129 }
130 l.parts = append(l.parts, na)
131 l.start = l.pos
132 }
133 return nil
134 } else if !(isLetter(r) || (r >= '0' && r <= '9') || r == '_') {
135 l.pos -= width
136 na := namedArg(l.src[l.start:l.pos])
137 if _, found := l.nameToOrdinal[na]; !found {
138 l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
139 }
140 l.parts = append(l.parts, namedArg(na))
141 l.start = l.pos
142 return rawState
143 }
144 }
145 }
146
147 func singleQuoteState(l *sqlLexer) stateFn {
148 for {
149 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
150 l.pos += width
151
152 switch r {
153 case '\'':
154 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
155 if nextRune != '\'' {
156 return rawState
157 }
158 l.pos += width
159 case utf8.RuneError:
160 if l.pos-l.start > 0 {
161 l.parts = append(l.parts, l.src[l.start:l.pos])
162 l.start = l.pos
163 }
164 return nil
165 }
166 }
167 }
168
169 func doubleQuoteState(l *sqlLexer) stateFn {
170 for {
171 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
172 l.pos += width
173
174 switch r {
175 case '"':
176 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
177 if nextRune != '"' {
178 return rawState
179 }
180 l.pos += width
181 case utf8.RuneError:
182 if l.pos-l.start > 0 {
183 l.parts = append(l.parts, l.src[l.start:l.pos])
184 l.start = l.pos
185 }
186 return nil
187 }
188 }
189 }
190
191 func escapeStringState(l *sqlLexer) stateFn {
192 for {
193 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
194 l.pos += width
195
196 switch r {
197 case '\\':
198 _, width = utf8.DecodeRuneInString(l.src[l.pos:])
199 l.pos += width
200 case '\'':
201 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
202 if nextRune != '\'' {
203 return rawState
204 }
205 l.pos += width
206 case utf8.RuneError:
207 if l.pos-l.start > 0 {
208 l.parts = append(l.parts, l.src[l.start:l.pos])
209 l.start = l.pos
210 }
211 return nil
212 }
213 }
214 }
215
216 func oneLineCommentState(l *sqlLexer) stateFn {
217 for {
218 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
219 l.pos += width
220
221 switch r {
222 case '\\':
223 _, width = utf8.DecodeRuneInString(l.src[l.pos:])
224 l.pos += width
225 case '\n', '\r':
226 return rawState
227 case utf8.RuneError:
228 if l.pos-l.start > 0 {
229 l.parts = append(l.parts, l.src[l.start:l.pos])
230 l.start = l.pos
231 }
232 return nil
233 }
234 }
235 }
236
237 func multilineCommentState(l *sqlLexer) stateFn {
238 for {
239 r, width := utf8.DecodeRuneInString(l.src[l.pos:])
240 l.pos += width
241
242 switch r {
243 case '/':
244 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
245 if nextRune == '*' {
246 l.pos += width
247 l.nested++
248 }
249 case '*':
250 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
251 if nextRune != '/' {
252 continue
253 }
254
255 l.pos += width
256 if l.nested == 0 {
257 return rawState
258 }
259 l.nested--
260
261 case utf8.RuneError:
262 if l.pos-l.start > 0 {
263 l.parts = append(l.parts, l.src[l.start:l.pos])
264 l.start = l.pos
265 }
266 return nil
267 }
268 }
269 }
270
View as plain text