1 package pgtype
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/binary"
7 "errors"
8 "fmt"
9 "strings"
10 "unicode"
11 "unicode/utf8"
12
13 "github.com/jackc/pgio"
14 )
15
16
17
18 type Hstore struct {
19 Map map[string]Text
20 Status Status
21 }
22
23 func (dst *Hstore) Set(src interface{}) error {
24 if src == nil {
25 *dst = Hstore{Status: Null}
26 return nil
27 }
28
29 if value, ok := src.(interface{ Get() interface{} }); ok {
30 value2 := value.Get()
31 if value2 != value {
32 return dst.Set(value2)
33 }
34 }
35
36 switch value := src.(type) {
37 case map[string]string:
38 m := make(map[string]Text, len(value))
39 for k, v := range value {
40 m[k] = Text{String: v, Status: Present}
41 }
42 *dst = Hstore{Map: m, Status: Present}
43 case map[string]*string:
44 m := make(map[string]Text, len(value))
45 for k, v := range value {
46 if v == nil {
47 m[k] = Text{Status: Null}
48 } else {
49 m[k] = Text{String: *v, Status: Present}
50 }
51 }
52 *dst = Hstore{Map: m, Status: Present}
53 case map[string]Text:
54 *dst = Hstore{Map: value, Status: Present}
55 default:
56 return fmt.Errorf("cannot convert %v to Hstore", src)
57 }
58
59 return nil
60 }
61
62 func (dst Hstore) Get() interface{} {
63 switch dst.Status {
64 case Present:
65 return dst.Map
66 case Null:
67 return nil
68 default:
69 return dst.Status
70 }
71 }
72
73 func (src *Hstore) AssignTo(dst interface{}) error {
74 switch src.Status {
75 case Present:
76 switch v := dst.(type) {
77 case *map[string]string:
78 *v = make(map[string]string, len(src.Map))
79 for k, val := range src.Map {
80 if val.Status != Present {
81 return fmt.Errorf("cannot decode %#v into %T", src, dst)
82 }
83 (*v)[k] = val.String
84 }
85 return nil
86 case *map[string]*string:
87 *v = make(map[string]*string, len(src.Map))
88 for k, val := range src.Map {
89 switch val.Status {
90 case Null:
91 (*v)[k] = nil
92 case Present:
93 str := val.String
94 (*v)[k] = &str
95 default:
96 return fmt.Errorf("cannot decode %#v into %T", src, dst)
97 }
98 }
99 return nil
100 default:
101 if nextDst, retry := GetAssignToDstType(dst); retry {
102 return src.AssignTo(nextDst)
103 }
104 return fmt.Errorf("unable to assign to %T", dst)
105 }
106 case Null:
107 return NullAssignTo(dst)
108 }
109
110 return fmt.Errorf("cannot decode %#v into %T", src, dst)
111 }
112
113 func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error {
114 if src == nil {
115 *dst = Hstore{Status: Null}
116 return nil
117 }
118
119 keys, values, err := parseHstore(string(src))
120 if err != nil {
121 return err
122 }
123
124 m := make(map[string]Text, len(keys))
125 for i := range keys {
126 m[keys[i]] = values[i]
127 }
128
129 *dst = Hstore{Map: m, Status: Present}
130 return nil
131 }
132
133 func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error {
134 if src == nil {
135 *dst = Hstore{Status: Null}
136 return nil
137 }
138
139 rp := 0
140
141 if len(src[rp:]) < 4 {
142 return fmt.Errorf("hstore incomplete %v", src)
143 }
144 pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
145 rp += 4
146
147 m := make(map[string]Text, pairCount)
148
149 for i := 0; i < pairCount; i++ {
150 if len(src[rp:]) < 4 {
151 return fmt.Errorf("hstore incomplete %v", src)
152 }
153 keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
154 rp += 4
155
156 if len(src[rp:]) < keyLen {
157 return fmt.Errorf("hstore incomplete %v", src)
158 }
159 key := string(src[rp : rp+keyLen])
160 rp += keyLen
161
162 if len(src[rp:]) < 4 {
163 return fmt.Errorf("hstore incomplete %v", src)
164 }
165 valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
166 rp += 4
167
168 var valueBuf []byte
169 if valueLen >= 0 {
170 valueBuf = src[rp : rp+valueLen]
171 rp += valueLen
172 }
173
174 var value Text
175 err := value.DecodeBinary(ci, valueBuf)
176 if err != nil {
177 return err
178 }
179 m[key] = value
180 }
181
182 *dst = Hstore{Map: m, Status: Present}
183
184 return nil
185 }
186
187 func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
188 switch src.Status {
189 case Null:
190 return nil, nil
191 case Undefined:
192 return nil, errUndefined
193 }
194
195 firstPair := true
196
197 inElemBuf := make([]byte, 0, 32)
198 for k, v := range src.Map {
199 if firstPair {
200 firstPair = false
201 } else {
202 buf = append(buf, ',')
203 }
204
205 buf = append(buf, quoteHstoreElementIfNeeded(k)...)
206 buf = append(buf, "=>"...)
207
208 elemBuf, err := v.EncodeText(ci, inElemBuf)
209 if err != nil {
210 return nil, err
211 }
212
213 if elemBuf == nil {
214 buf = append(buf, "NULL"...)
215 } else {
216 buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...)
217 }
218 }
219
220 return buf, nil
221 }
222
223 func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
224 switch src.Status {
225 case Null:
226 return nil, nil
227 case Undefined:
228 return nil, errUndefined
229 }
230
231 buf = pgio.AppendInt32(buf, int32(len(src.Map)))
232
233 var err error
234 for k, v := range src.Map {
235 buf = pgio.AppendInt32(buf, int32(len(k)))
236 buf = append(buf, k...)
237
238 sp := len(buf)
239 buf = pgio.AppendInt32(buf, -1)
240
241 elemBuf, err := v.EncodeText(ci, buf)
242 if err != nil {
243 return nil, err
244 }
245 if elemBuf != nil {
246 buf = elemBuf
247 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
248 }
249 }
250
251 return buf, err
252 }
253
// quoteHstoreReplacer backslash-escapes the two characters that are special
// inside a double-quoted hstore key or value: backslash and double quote.
var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteHstoreElement returns src wrapped in double quotes, with any backslash
// or double quote inside it escaped by a backslash.
func quoteHstoreElement(src string) string {
	return `"` + quoteHstoreReplacer.Replace(src) + `"`
}

// quoteHstoreElementIfNeeded returns src unchanged when it can be written as a
// bare hstore word, and quoted via quoteHstoreElement otherwise.
//
// Quoting is applied to:
//   - the empty string;
//   - any case variant of "null", which would otherwise read back as a NULL
//     value;
//   - strings containing whitespace or any of the characters {},"\=>.
//
// Whitespace of any ASCII kind (not just the space character) must be quoted
// because unquoted whitespace acts as a separator in hstore input.
func quoteHstoreElementIfNeeded(src string) string {
	if src == "" ||
		(len(src) == 4 && strings.EqualFold(src, "null")) ||
		strings.ContainsAny(src, " \t\n\r\v\f{},\"\\=>") {
		return quoteHstoreElement(src)
	}
	return src
}
266
// States of the parseHstore state machine.
const (
	hsPre  = iota // before the first key: expecting its opening '"'
	hsKey         // inside a quoted key
	hsSep         // after a key's closing '"': expecting "=>"
	hsVal         // inside a quoted value
	hsNul         // after "=>N": expecting the remaining "ULL"
	hsNext        // after a complete value: expecting ',' and the next key
)
275
// hstoreParser is a forward-only rune cursor over the text being parsed.
type hstoreParser struct {
	str string // the complete input
	pos int    // byte offset of the next rune to read
}

// newHSP returns a cursor positioned at the start of in.
func newHSP(in string) *hstoreParser {
	return &hstoreParser{str: in}
}

// Consume returns the rune at the cursor and advances past it. Once the input
// is exhausted it returns (0, true) and leaves the cursor where it is.
func (p *hstoreParser) Consume() (r rune, end bool) {
	if p.pos >= len(p.str) {
		return 0, true
	}
	ch, size := utf8.DecodeRuneInString(p.str[p.pos:])
	p.pos += size
	return ch, false
}

// Peek returns the rune at the cursor without advancing. Once the input is
// exhausted it returns (0, true).
func (p *hstoreParser) Peek() (r rune, end bool) {
	if p.pos >= len(p.str) {
		return 0, true
	}
	ch, _ := utf8.DecodeRuneInString(p.str[p.pos:])
	return ch, false
}
306
307
308
309
310 func parseHstore(s string) (k []string, v []Text, err error) {
311 if s == "" {
312 return
313 }
314
315 buf := bytes.Buffer{}
316 keys := []string{}
317 values := []Text{}
318 p := newHSP(s)
319
320 r, end := p.Consume()
321 state := hsPre
322
323 for !end {
324 switch state {
325 case hsPre:
326 if r == '"' {
327 state = hsKey
328 } else {
329 err = errors.New("String does not begin with \"")
330 }
331 case hsKey:
332 switch r {
333 case '"':
334 keys = append(keys, buf.String())
335 buf = bytes.Buffer{}
336 state = hsSep
337 case '\\':
338 n, end := p.Consume()
339 switch {
340 case end:
341 err = errors.New("Found EOS in key, expecting character or \"")
342 case n == '"', n == '\\':
343 buf.WriteRune(n)
344 default:
345 buf.WriteRune(r)
346 buf.WriteRune(n)
347 }
348 default:
349 buf.WriteRune(r)
350 }
351 case hsSep:
352 if r == '=' {
353 r, end = p.Consume()
354 switch {
355 case end:
356 err = errors.New("Found EOS after '=', expecting '>'")
357 case r == '>':
358 r, end = p.Consume()
359 switch {
360 case end:
361 err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
362 case r == '"':
363 state = hsVal
364 case r == 'N':
365 state = hsNul
366 default:
367 err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
368 }
369 default:
370 err = fmt.Errorf("Invalid character after '=', expecting '>'")
371 }
372 } else {
373 err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
374 }
375 case hsVal:
376 switch r {
377 case '"':
378 values = append(values, Text{String: buf.String(), Status: Present})
379 buf = bytes.Buffer{}
380 state = hsNext
381 case '\\':
382 n, end := p.Consume()
383 switch {
384 case end:
385 err = errors.New("Found EOS in key, expecting character or \"")
386 case n == '"', n == '\\':
387 buf.WriteRune(n)
388 default:
389 buf.WriteRune(r)
390 buf.WriteRune(n)
391 }
392 default:
393 buf.WriteRune(r)
394 }
395 case hsNul:
396 nulBuf := make([]rune, 3)
397 nulBuf[0] = r
398 for i := 1; i < 3; i++ {
399 r, end = p.Consume()
400 if end {
401 err = errors.New("Found EOS in NULL value")
402 return
403 }
404 nulBuf[i] = r
405 }
406 if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
407 values = append(values, Text{Status: Null})
408 state = hsNext
409 } else {
410 err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
411 }
412 case hsNext:
413 if r == ',' {
414 r, end = p.Consume()
415 switch {
416 case end:
417 err = errors.New("Found EOS after ',', expecting space")
418 case (unicode.IsSpace(r)):
419 r, end = p.Consume()
420 state = hsKey
421 default:
422 err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
423 }
424 } else {
425 err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
426 }
427 }
428
429 if err != nil {
430 return
431 }
432 r, end = p.Consume()
433 }
434 if state != hsNext {
435 err = errors.New("Improperly formatted hstore")
436 return
437 }
438 k = keys
439 v = values
440 return
441 }
442
443
444 func (dst *Hstore) Scan(src interface{}) error {
445 if src == nil {
446 *dst = Hstore{Status: Null}
447 return nil
448 }
449
450 switch src := src.(type) {
451 case string:
452 return dst.DecodeText(nil, []byte(src))
453 case []byte:
454 srcCopy := make([]byte, len(src))
455 copy(srcCopy, src)
456 return dst.DecodeText(nil, srcCopy)
457 }
458
459 return fmt.Errorf("cannot scan %T", src)
460 }
461
462
463 func (src Hstore) Value() (driver.Value, error) {
464 return EncodeValueText(src)
465 }
466
View as plain text