1 package pgtype
2
3 import (
4 "encoding/binary"
5 "errors"
6 "fmt"
7 "reflect"
8 "strings"
9
10 "github.com/jackc/pgio"
11 )
12
// CompositeTypeField describes one attribute of a composite type: the name it
// is reported under and the OID of its data type.
type CompositeTypeField struct {
	// Name is used as the map key for this attribute by CompositeType.Get.
	Name string
	// OID selects the attribute's data type; NewCompositeType looks it up in
	// the ConnInfo and EncodeBinary writes it in front of each field.
	OID uint32
}
17
18 type CompositeType struct {
19 status Status
20
21 typeName string
22
23 fields []CompositeTypeField
24 valueTranscoders []ValueTranscoder
25 }
26
27
28
29 func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
30 valueTranscoders := make([]ValueTranscoder, len(fields))
31
32 for i := range fields {
33 dt, ok := ci.DataTypeForOID(fields[i].OID)
34 if !ok {
35 return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID)
36 }
37
38 value := NewValue(dt.Value)
39 valueTranscoder, ok := value.(ValueTranscoder)
40 if !ok {
41 return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
42 }
43
44 valueTranscoders[i] = valueTranscoder
45 }
46
47 return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
48 }
49
50
51
52 func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
53 if len(fields) != len(values) {
54 return nil, errors.New("fields and valueTranscoders must have same length")
55 }
56
57 return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
58 }
59
60 func (src CompositeType) Get() interface{} {
61 switch src.status {
62 case Present:
63 results := make(map[string]interface{}, len(src.valueTranscoders))
64 for i := range src.valueTranscoders {
65 results[src.fields[i].Name] = src.valueTranscoders[i].Get()
66 }
67 return results
68 case Null:
69 return nil
70 default:
71 return src.status
72 }
73 }
74
75 func (ct *CompositeType) NewTypeValue() Value {
76 a := &CompositeType{
77 typeName: ct.typeName,
78 fields: ct.fields,
79 valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
80 }
81
82 for i := range ct.valueTranscoders {
83 a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
84 }
85
86 return a
87 }
88
89 func (ct *CompositeType) TypeName() string {
90 return ct.typeName
91 }
92
93 func (ct *CompositeType) Fields() []CompositeTypeField {
94 return ct.fields
95 }
96
97 func (dst *CompositeType) Set(src interface{}) error {
98 if src == nil {
99 dst.status = Null
100 return nil
101 }
102
103 switch value := src.(type) {
104 case []interface{}:
105 if len(value) != len(dst.valueTranscoders) {
106 return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
107 }
108 for i, v := range value {
109 if err := dst.valueTranscoders[i].Set(v); err != nil {
110 return err
111 }
112 }
113 dst.status = Present
114 case *[]interface{}:
115 if value == nil {
116 dst.status = Null
117 return nil
118 }
119 return dst.Set(*value)
120 default:
121 return fmt.Errorf("Can not convert %v to Composite", src)
122 }
123
124 return nil
125 }
126
127
128 func (src CompositeType) AssignTo(dst interface{}) error {
129 switch src.status {
130 case Present:
131 switch v := dst.(type) {
132 case []interface{}:
133 if len(v) != len(src.valueTranscoders) {
134 return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
135 }
136 for i := range src.valueTranscoders {
137 if v[i] == nil {
138 continue
139 }
140
141 err := assignToOrSet(src.valueTranscoders[i], v[i])
142 if err != nil {
143 return fmt.Errorf("unable to assign to dst[%d]: %v", i, err)
144 }
145 }
146 return nil
147 case *[]interface{}:
148 return src.AssignTo(*v)
149 default:
150 if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct {
151 return err
152 }
153
154 if nextDst, retry := GetAssignToDstType(dst); retry {
155 return src.AssignTo(nextDst)
156 }
157 return fmt.Errorf("unable to assign to %T", dst)
158 }
159 case Null:
160 return NullAssignTo(dst)
161 }
162 return fmt.Errorf("cannot decode %#v into %T", src, dst)
163 }
164
165 func assignToOrSet(src Value, dst interface{}) error {
166 assignToErr := src.AssignTo(dst)
167 if assignToErr != nil {
168
169 setSucceeded := false
170 if setter, ok := dst.(Value); ok {
171 err := setter.Set(src.Get())
172 setSucceeded = err == nil
173 }
174 if !setSucceeded {
175 return assignToErr
176 }
177 }
178
179 return nil
180 }
181
182 func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
183 dstValue := reflect.ValueOf(dst)
184 if dstValue.Kind() != reflect.Ptr {
185 return false, nil
186 }
187
188 if dstValue.IsNil() {
189 return false, nil
190 }
191
192 dstElemValue := dstValue.Elem()
193 dstElemType := dstElemValue.Type()
194
195 if dstElemType.Kind() != reflect.Struct {
196 return false, nil
197 }
198
199 exportedFields := make([]int, 0, dstElemType.NumField())
200 for i := 0; i < dstElemType.NumField(); i++ {
201 sf := dstElemType.Field(i)
202 if sf.PkgPath == "" {
203 exportedFields = append(exportedFields, i)
204 }
205 }
206
207 if len(exportedFields) != len(src.valueTranscoders) {
208 return false, nil
209 }
210
211 for i := range exportedFields {
212 err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
213 if err != nil {
214 return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
215 }
216 }
217
218 return true, nil
219 }
220
221 func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
222 switch src.status {
223 case Null:
224 return nil, nil
225 case Undefined:
226 return nil, errUndefined
227 }
228
229 b := NewCompositeBinaryBuilder(ci, buf)
230 for i := range src.valueTranscoders {
231 b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
232 }
233
234 return b.Finish()
235 }
236
237
238
239
240
241 func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
242 if buf == nil {
243 dst.status = Null
244 return nil
245 }
246
247 scanner := NewCompositeBinaryScanner(ci, buf)
248
249 for _, f := range dst.valueTranscoders {
250 scanner.ScanDecoder(f)
251 }
252
253 if scanner.Err() != nil {
254 return scanner.Err()
255 }
256
257 dst.status = Present
258
259 return nil
260 }
261
262 func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
263 if buf == nil {
264 dst.status = Null
265 return nil
266 }
267
268 scanner := NewCompositeTextScanner(ci, buf)
269
270 for _, f := range dst.valueTranscoders {
271 scanner.ScanDecoder(f)
272 }
273
274 if scanner.Err() != nil {
275 return scanner.Err()
276 }
277
278 dst.status = Present
279
280 return nil
281 }
282
283 func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
284 switch src.status {
285 case Null:
286 return nil, nil
287 case Undefined:
288 return nil, errUndefined
289 }
290
291 b := NewCompositeTextBuilder(ci, buf)
292 for _, f := range src.valueTranscoders {
293 b.AppendEncoder(f)
294 }
295
296 return b.Finish()
297 }
298
299 type CompositeBinaryScanner struct {
300 ci *ConnInfo
301 rp int
302 src []byte
303
304 fieldCount int32
305 fieldBytes []byte
306 fieldOID uint32
307 err error
308 }
309
310
311 func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
312 rp := 0
313 if len(src[rp:]) < 4 {
314 return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
315 }
316
317 fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
318 rp += 4
319
320 return &CompositeBinaryScanner{
321 ci: ci,
322 rp: rp,
323 src: src,
324 fieldCount: fieldCount,
325 }
326 }
327
328
329 func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) {
330 if cfs.err != nil {
331 return
332 }
333
334 if cfs.Next() {
335 cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes)
336 } else {
337 cfs.err = errors.New("read past end of composite")
338 }
339 }
340
341
342 func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) {
343 if cfs.err != nil {
344 return
345 }
346
347 if cfs.Next() {
348 cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d)
349 } else {
350 cfs.err = errors.New("read past end of composite")
351 }
352 }
353
354
355
356 func (cfs *CompositeBinaryScanner) Next() bool {
357 if cfs.err != nil {
358 return false
359 }
360
361 if cfs.rp == len(cfs.src) {
362 return false
363 }
364
365 if len(cfs.src[cfs.rp:]) < 8 {
366 cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
367 return false
368 }
369 cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
370 cfs.rp += 4
371
372 fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
373 cfs.rp += 4
374
375 if fieldLen >= 0 {
376 if len(cfs.src[cfs.rp:]) < fieldLen {
377 cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
378 return false
379 }
380 cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
381 cfs.rp += fieldLen
382 } else {
383 cfs.fieldBytes = nil
384 }
385
386 return true
387 }
388
389 func (cfs *CompositeBinaryScanner) FieldCount() int {
390 return int(cfs.fieldCount)
391 }
392
393
394 func (cfs *CompositeBinaryScanner) Bytes() []byte {
395 return cfs.fieldBytes
396 }
397
398
399 func (cfs *CompositeBinaryScanner) OID() uint32 {
400 return cfs.fieldOID
401 }
402
403
404 func (cfs *CompositeBinaryScanner) Err() error {
405 return cfs.err
406 }
407
408 type CompositeTextScanner struct {
409 ci *ConnInfo
410 rp int
411 src []byte
412
413 fieldBytes []byte
414 err error
415 }
416
417
418 func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
419 if len(src) < 2 {
420 return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
421 }
422
423 if src[0] != '(' {
424 return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
425 }
426
427 if src[len(src)-1] != ')' {
428 return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
429 }
430
431 return &CompositeTextScanner{
432 ci: ci,
433 rp: 1,
434 src: src,
435 }
436 }
437
438
439 func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) {
440 if cfs.err != nil {
441 return
442 }
443
444 if cfs.Next() {
445 cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes)
446 } else {
447 cfs.err = errors.New("read past end of composite")
448 }
449 }
450
451
452 func (cfs *CompositeTextScanner) ScanValue(d interface{}) {
453 if cfs.err != nil {
454 return
455 }
456
457 if cfs.Next() {
458 cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d)
459 } else {
460 cfs.err = errors.New("read past end of composite")
461 }
462 }
463
464
465
466 func (cfs *CompositeTextScanner) Next() bool {
467 if cfs.err != nil {
468 return false
469 }
470
471 if cfs.rp == len(cfs.src) {
472 return false
473 }
474
475 switch cfs.src[cfs.rp] {
476 case ',', ')':
477 cfs.rp++
478 cfs.fieldBytes = nil
479 return true
480 case '"':
481 cfs.rp++
482 cfs.fieldBytes = make([]byte, 0, 16)
483 for {
484 ch := cfs.src[cfs.rp]
485
486 if ch == '"' {
487 cfs.rp++
488 if cfs.src[cfs.rp] == '"' {
489 cfs.fieldBytes = append(cfs.fieldBytes, '"')
490 cfs.rp++
491 } else {
492 break
493 }
494 } else if ch == '\\' {
495 cfs.rp++
496 cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
497 cfs.rp++
498 } else {
499 cfs.fieldBytes = append(cfs.fieldBytes, ch)
500 cfs.rp++
501 }
502 }
503 cfs.rp++
504 return true
505 default:
506 start := cfs.rp
507 for {
508 ch := cfs.src[cfs.rp]
509 if ch == ',' || ch == ')' {
510 break
511 }
512 cfs.rp++
513 }
514 cfs.fieldBytes = cfs.src[start:cfs.rp]
515 cfs.rp++
516 return true
517 }
518 }
519
520
521 func (cfs *CompositeTextScanner) Bytes() []byte {
522 return cfs.fieldBytes
523 }
524
525
526 func (cfs *CompositeTextScanner) Err() error {
527 return cfs.err
528 }
529
530 type CompositeBinaryBuilder struct {
531 ci *ConnInfo
532 buf []byte
533 startIdx int
534 fieldCount uint32
535 err error
536 }
537
538 func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder {
539 startIdx := len(buf)
540 buf = append(buf, 0, 0, 0, 0)
541 return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx}
542 }
543
544 func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
545 if b.err != nil {
546 return
547 }
548
549 dt, ok := b.ci.DataTypeForOID(oid)
550 if !ok {
551 b.err = fmt.Errorf("unknown data type for OID: %d", oid)
552 return
553 }
554
555 err := dt.Value.Set(field)
556 if err != nil {
557 b.err = err
558 return
559 }
560
561 binaryEncoder, ok := dt.Value.(BinaryEncoder)
562 if !ok {
563 b.err = fmt.Errorf("unable to encode binary for OID: %d", oid)
564 return
565 }
566
567 b.AppendEncoder(oid, binaryEncoder)
568 }
569
570 func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) {
571 if b.err != nil {
572 return
573 }
574
575 b.buf = pgio.AppendUint32(b.buf, oid)
576 lengthPos := len(b.buf)
577 b.buf = pgio.AppendInt32(b.buf, -1)
578 fieldBuf, err := field.EncodeBinary(b.ci, b.buf)
579 if err != nil {
580 b.err = err
581 return
582 }
583 if fieldBuf != nil {
584 binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
585 b.buf = fieldBuf
586 }
587
588 b.fieldCount++
589 }
590
591 func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
592 if b.err != nil {
593 return nil, b.err
594 }
595
596 binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
597 return b.buf, nil
598 }
599
600 type CompositeTextBuilder struct {
601 ci *ConnInfo
602 buf []byte
603 startIdx int
604 fieldCount uint32
605 err error
606 fieldBuf [32]byte
607 }
608
609 func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
610 buf = append(buf, '(')
611 return &CompositeTextBuilder{ci: ci, buf: buf}
612 }
613
614 func (b *CompositeTextBuilder) AppendValue(field interface{}) {
615 if b.err != nil {
616 return
617 }
618
619 if field == nil {
620 b.buf = append(b.buf, ',')
621 return
622 }
623
624 dt, ok := b.ci.DataTypeForValue(field)
625 if !ok {
626 b.err = fmt.Errorf("unknown data type for field: %v", field)
627 return
628 }
629
630 err := dt.Value.Set(field)
631 if err != nil {
632 b.err = err
633 return
634 }
635
636 textEncoder, ok := dt.Value.(TextEncoder)
637 if !ok {
638 b.err = fmt.Errorf("unable to encode text for value: %v", field)
639 return
640 }
641
642 b.AppendEncoder(textEncoder)
643 }
644
645 func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) {
646 if b.err != nil {
647 return
648 }
649
650 fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0])
651 if err != nil {
652 b.err = err
653 return
654 }
655 if fieldBuf != nil {
656 b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
657 }
658
659 b.buf = append(b.buf, ',')
660 }
661
662 func (b *CompositeTextBuilder) Finish() ([]byte, error) {
663 if b.err != nil {
664 return nil, b.err
665 }
666
667 b.buf[len(b.buf)-1] = ')'
668 return b.buf, nil
669 }
670
// quoteCompositeReplacer backslash-escapes the two characters that are special
// inside a quoted composite field: backslash and double quote.
var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// quoteCompositeField wraps src in double quotes, escaping any backslashes and
// double quotes it contains.
func quoteCompositeField(src string) string {
	var sb strings.Builder
	sb.Grow(len(src) + 2)
	sb.WriteByte('"')
	// Writes to a strings.Builder cannot fail, so the result is ignored.
	_, _ = quoteCompositeReplacer.WriteString(&sb, src)
	sb.WriteByte('"')
	return sb.String()
}

// quoteCompositeFieldIfNeeded returns src quoted when writing it bare would be
// ambiguous in a text record. That is the case when it is empty (a bare empty
// field is read back as NULL), has a leading or trailing space, or contains
// any of ( ) , " \. Otherwise src is returned unchanged.
func quoteCompositeFieldIfNeeded(src string) string {
	needsQuotes := len(src) == 0 ||
		src[0] == ' ' ||
		src[len(src)-1] == ' ' ||
		strings.ContainsAny(src, `(),"\`)
	if !needsQuotes {
		return src
	}
	return quoteCompositeField(src)
}
683
View as plain text