1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "strings"
9
10 "github.com/jackc/pgx/v5/internal/pgio"
11 )
12
13 type HstoreScanner interface {
14 ScanHstore(v Hstore) error
15 }
16
17 type HstoreValuer interface {
18 HstoreValue() (Hstore, error)
19 }
20
21
22
// Hstore is an hstore value: a map from string keys to nullable string values.
// A nil *string value represents SQL NULL for that key.
type Hstore map[string]*string

// ScanHstore implements HstoreScanner by replacing the receiver with src.
// It never fails.
func (h *Hstore) ScanHstore(src Hstore) error {
	*h = src
	return nil
}

// HstoreValue implements HstoreValuer by returning the receiver as-is.
// It never fails.
func (h Hstore) HstoreValue() (Hstore, error) {
	return h, nil
}
33
34
35 func (h *Hstore) Scan(src any) error {
36 if src == nil {
37 *h = nil
38 return nil
39 }
40
41 switch src := src.(type) {
42 case string:
43 return scanPlanTextAnyToHstoreScanner{}.scanString(src, h)
44 }
45
46 return fmt.Errorf("cannot scan %T", src)
47 }
48
49
50 func (h Hstore) Value() (driver.Value, error) {
51 if h == nil {
52 return nil, nil
53 }
54
55 buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil)
56 if err != nil {
57 return nil, err
58 }
59 return string(buf), err
60 }
61
62 type HstoreCodec struct{}
63
64 func (HstoreCodec) FormatSupported(format int16) bool {
65 return format == TextFormatCode || format == BinaryFormatCode
66 }
67
68 func (HstoreCodec) PreferredFormat() int16 {
69 return BinaryFormatCode
70 }
71
72 func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
73 if _, ok := value.(HstoreValuer); !ok {
74 return nil
75 }
76
77 switch format {
78 case BinaryFormatCode:
79 return encodePlanHstoreCodecBinary{}
80 case TextFormatCode:
81 return encodePlanHstoreCodecText{}
82 }
83
84 return nil
85 }
86
87 type encodePlanHstoreCodecBinary struct{}
88
89 func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
90 hstore, err := value.(HstoreValuer).HstoreValue()
91 if err != nil {
92 return nil, err
93 }
94
95 if hstore == nil {
96 return nil, nil
97 }
98
99 buf = pgio.AppendInt32(buf, int32(len(hstore)))
100
101 for k, v := range hstore {
102 buf = pgio.AppendInt32(buf, int32(len(k)))
103 buf = append(buf, k...)
104
105 if v == nil {
106 buf = pgio.AppendInt32(buf, -1)
107 } else {
108 buf = pgio.AppendInt32(buf, int32(len(*v)))
109 buf = append(buf, (*v)...)
110 }
111 }
112
113 return buf, nil
114 }
115
116 type encodePlanHstoreCodecText struct{}
117
118 func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
119 hstore, err := value.(HstoreValuer).HstoreValue()
120 if err != nil {
121 return nil, err
122 }
123
124 if len(hstore) == 0 {
125
126
127
128
129 if hstore == nil {
130 return nil, nil
131 }
132 return []byte{}, nil
133 }
134
135 firstPair := true
136
137 for k, v := range hstore {
138 if firstPair {
139 firstPair = false
140 } else {
141 buf = append(buf, ',', ' ')
142 }
143
144
145
146
147 buf = append(buf, '"')
148 buf = append(buf, quoteArrayReplacer.Replace(k)...)
149 buf = append(buf, '"')
150 buf = append(buf, "=>"...)
151
152 if v == nil {
153 buf = append(buf, "NULL"...)
154 } else {
155 buf = append(buf, '"')
156 buf = append(buf, quoteArrayReplacer.Replace(*v)...)
157 buf = append(buf, '"')
158 }
159 }
160
161 return buf, nil
162 }
163
164 func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
165
166 switch format {
167 case BinaryFormatCode:
168 switch target.(type) {
169 case HstoreScanner:
170 return scanPlanBinaryHstoreToHstoreScanner{}
171 }
172 case TextFormatCode:
173 switch target.(type) {
174 case HstoreScanner:
175 return scanPlanTextAnyToHstoreScanner{}
176 }
177 }
178
179 return nil
180 }
181
182 type scanPlanBinaryHstoreToHstoreScanner struct{}
183
184 func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
185 scanner := (dst).(HstoreScanner)
186
187 if src == nil {
188 return scanner.ScanHstore(Hstore(nil))
189 }
190
191 rp := 0
192
193 const uint32Len = 4
194 if len(src[rp:]) < uint32Len {
195 return fmt.Errorf("hstore incomplete %v", src)
196 }
197 pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
198 rp += uint32Len
199
200 hstore := make(Hstore, pairCount)
201
202 valueStrings := make([]string, pairCount)
203
204 for i := 0; i < pairCount; i++ {
205 if len(src[rp:]) < uint32Len {
206 return fmt.Errorf("hstore incomplete %v", src)
207 }
208 keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
209 rp += uint32Len
210
211 if len(src[rp:]) < keyLen {
212 return fmt.Errorf("hstore incomplete %v", src)
213 }
214 key := string(src[rp : rp+keyLen])
215 rp += keyLen
216
217 if len(src[rp:]) < uint32Len {
218 return fmt.Errorf("hstore incomplete %v", src)
219 }
220 valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
221 rp += 4
222
223 if valueLen >= 0 {
224 valueStrings[i] = string(src[rp : rp+valueLen])
225 rp += valueLen
226
227 hstore[key] = &valueStrings[i]
228 } else {
229 hstore[key] = nil
230 }
231 }
232
233 return scanner.ScanHstore(hstore)
234 }
235
236 type scanPlanTextAnyToHstoreScanner struct{}
237
238 func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error {
239 scanner := (dst).(HstoreScanner)
240
241 if src == nil {
242 return scanner.ScanHstore(Hstore(nil))
243 }
244 return s.scanString(string(src), scanner)
245 }
246
247
248 func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error {
249 hstore, err := parseHstore(src)
250 if err != nil {
251 return err
252 }
253 return scanner.ScanHstore(hstore)
254 }
255
256 func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
257 return codecDecodeToTextFormat(c, m, oid, format, src)
258 }
259
260 func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
261 if src == nil {
262 return nil, nil
263 }
264
265 var hstore Hstore
266 err := codecScan(c, m, oid, format, src, &hstore)
267 if err != nil {
268 return nil, err
269 }
270 return hstore, nil
271 }
272
// hstoreParser is a byte cursor over hstore text input.
type hstoreParser struct {
	str           string // text being parsed
	pos           int    // index of the next unread byte in str
	nextBackslash int    // cached index of a backslash in str, or -1; kept current by consumeDoubleQuoted
}

// newHSP returns a parser positioned at the start of in, with nextBackslash set to
// the first backslash in in (or -1 if there is none).
func newHSP(in string) *hstoreParser {
	p := new(hstoreParser)
	p.str = in
	p.nextBackslash = strings.IndexByte(in, '\\')
	return p
}

// atEnd reports whether every byte of the input has been consumed.
func (p *hstoreParser) atEnd() bool {
	return !(p.pos < len(p.str))
}

// consume returns the next byte and advances past it. At end of input it returns
// (0, true) and leaves the position unchanged.
func (p *hstoreParser) consume() (b byte, end bool) {
	if p.atEnd() {
		return 0, true
	}
	b = p.str[p.pos]
	p.pos++
	return b, false
}
300
// unexpectedByteErr builds the parse error for finding actualB where expectedB was
// required. Each byte is shown both as a character and in Go-syntax form.
func unexpectedByteErr(actualB byte, expectedB byte) error {
	const format = "expected '%c' ('%#v'); found '%c' ('%#v')"
	return fmt.Errorf(format, expectedB, expectedB, actualB, actualB)
}
304
305
306 func (p *hstoreParser) consumeExpectedByte(expectedB byte) error {
307 nextB, end := p.consume()
308 if end {
309 return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB)
310 }
311 if nextB != expectedB {
312 return unexpectedByteErr(nextB, expectedB)
313 }
314 return nil
315 }
316
317
318
319 func (p *hstoreParser) consumeExpected2(one byte, two byte) error {
320 if p.pos+2 > len(p.str) {
321 return errors.New("unexpected end of string")
322 }
323 if p.str[p.pos] != one {
324 return unexpectedByteErr(p.str[p.pos], one)
325 }
326 if p.str[p.pos+1] != two {
327 return unexpectedByteErr(p.str[p.pos+1], two)
328 }
329 p.pos += 2
330 return nil
331 }
332
// errEOSInQuoted is returned when the input ends inside a double-quoted string,
// including right after a backslash.
var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`)
334
335
336
// consumeDoubleQuoted reads a double-quoted string whose opening '"' has already been
// consumed. It returns the unescaped contents and leaves the position just past the
// closing '"'. It also refreshes p.nextBackslash whenever it has to take the
// escape-handling path.
func (p *hstoreParser) consumeDoubleQuoted() (string, error) {
	// Locate the next '"' at or after the current position.
	nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"')
	if nextDoubleQuote == -1 {
		return "", errEOSInQuoted
	}
	nextDoubleQuote += p.pos
	if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote {
		// Fast path: no backslash precedes that '"'. It therefore closes the string,
		// and the contents need no unescaping.
		// strings.Clone copies the substring so the result does not retain the
		// whole input string.
		s := strings.Clone(p.str[p.pos:nextDoubleQuote])
		p.pos = nextDoubleQuote + 1
		return s, nil
	}

	// Slow path: a backslash comes first, so the string has escapes (and that '"' may
	// itself be escaped). Parse byte by byte from the first backslash.
	s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash)
	// Re-find the next backslash relative to the new position; the cached one has now
	// been consumed.
	p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\')
	if p.nextBackslash != -1 {
		p.nextBackslash += p.pos
	}
	return s, err
}
360
361
362
363
364 func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) {
365
366 var builder strings.Builder
367 builder.WriteString(p.str[p.pos:firstBackslash])
368
369
370 p.pos = firstBackslash
371
372
373 for {
374 nextB, end := p.consume()
375 if end {
376 return "", errEOSInQuoted
377 } else if nextB == '"' {
378 break
379 } else if nextB == '\\' {
380
381 nextB, end = p.consume()
382 if end {
383 return "", errEOSInQuoted
384 }
385 if !(nextB == '\\' || nextB == '"') {
386 return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB)
387 }
388 builder.WriteByte(nextB)
389 } else {
390
391 builder.WriteByte(nextB)
392 }
393 }
394 return builder.String(), nil
395 }
396
397
398 func (p *hstoreParser) consumePairSeparator() error {
399 return p.consumeExpected2(',', ' ')
400 }
401
402
403 func (p *hstoreParser) consumeKVSeparator() error {
404 return p.consumeExpected2('=', '>')
405 }
406
407
408 func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) {
409
410 if p.atEnd() {
411 return Text{}, errors.New("found end instead of value")
412 }
413 next := p.str[p.pos]
414 if next == 'N' {
415
416 err := p.consumeExpected2('N', 'U')
417 if err != nil {
418 return Text{}, err
419 }
420 err = p.consumeExpected2('L', 'L')
421 if err != nil {
422 return Text{}, err
423 }
424 return Text{String: "", Valid: false}, nil
425 } else if next != '"' {
426 return Text{}, unexpectedByteErr(next, '"')
427 }
428
429
430 p.pos += 1
431 s, err := p.consumeDoubleQuoted()
432 if err != nil {
433 return Text{}, err
434 }
435 return Text{String: s, Valid: true}, nil
436 }
437
438 func parseHstore(s string) (Hstore, error) {
439 p := newHSP(s)
440
441
442
443 numPairsEstimate := strings.Count(s, ">")
444
445 valueStrings := make([]string, 0, numPairsEstimate)
446 result := make(Hstore, numPairsEstimate)
447 first := true
448 for !p.atEnd() {
449 if !first {
450 err := p.consumePairSeparator()
451 if err != nil {
452 return nil, err
453 }
454 } else {
455 first = false
456 }
457
458 err := p.consumeExpectedByte('"')
459 if err != nil {
460 return nil, err
461 }
462
463 key, err := p.consumeDoubleQuoted()
464 if err != nil {
465 return nil, err
466 }
467
468 err = p.consumeKVSeparator()
469 if err != nil {
470 return nil, err
471 }
472
473 value, err := p.consumeDoubleQuotedOrNull()
474 if err != nil {
475 return nil, err
476 }
477 if value.Valid {
478 valueStrings = append(valueStrings, value.String)
479 result[key] = &valueStrings[len(valueStrings)-1]
480 } else {
481 result[key] = nil
482 }
483 }
484
485 return result, nil
486 }
487
View as plain text