1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "strings"
9
10 "github.com/jackc/pgx/v5/internal/pgio"
11 )
12
13
// CompositeIndexGetter is implemented by Go values that CompositeCodec can
// encode as a PostgreSQL composite. Fields are requested by their position in
// CompositeCodec.Fields.
type CompositeIndexGetter interface {
	// Index returns the Go value to encode for field i. A nil result is
	// encoded as a NULL field.
	Index(i int) any

	// IsNull reports whether the entire composite is NULL. When it returns
	// true, the encoder does not call Index at all.
	IsNull() bool
}
21
22
// CompositeIndexScanner is implemented by Go values that CompositeCodec can
// scan a PostgreSQL composite into. Fields are addressed by their position in
// CompositeCodec.Fields.
type CompositeIndexScanner interface {
	// ScanIndex returns the scan target for field i. Returning nil makes the
	// scan skip that field's data.
	ScanIndex(i int) any

	// ScanNull is called, instead of any ScanIndex call, when the entire
	// composite is NULL.
	ScanNull() error
}
30
// CompositeCodecField describes one field (attribute) of a composite type
// handled by CompositeCodec.
type CompositeCodecField struct {
	// Name is the field name; DecodeValue uses it as the key of the returned map.
	Name string
	// Type supplies the OID and Codec used when encoding and scanning this field.
	Type *Type
}
35
// CompositeCodec encodes and decodes PostgreSQL composite (row) values.
// Fields lists the composite's fields in order. A field's position in Fields
// is the index passed to CompositeIndexGetter.Index and
// CompositeIndexScanner.ScanIndex.
type CompositeCodec struct {
	Fields []CompositeCodecField
}
39
40 func (c *CompositeCodec) FormatSupported(format int16) bool {
41 for _, f := range c.Fields {
42 if !f.Type.Codec.FormatSupported(format) {
43 return false
44 }
45 }
46
47 return true
48 }
49
50 func (c *CompositeCodec) PreferredFormat() int16 {
51 if c.FormatSupported(BinaryFormatCode) {
52 return BinaryFormatCode
53 }
54 return TextFormatCode
55 }
56
57 func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
58 if _, ok := value.(CompositeIndexGetter); !ok {
59 return nil
60 }
61
62 switch format {
63 case BinaryFormatCode:
64 return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m}
65 case TextFormatCode:
66 return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m}
67 }
68
69 return nil
70 }
71
72 type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
73 cc *CompositeCodec
74 m *Map
75 }
76
77 func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
78 getter := value.(CompositeIndexGetter)
79
80 if getter.IsNull() {
81 return nil, nil
82 }
83
84 builder := NewCompositeBinaryBuilder(plan.m, buf)
85 for i, field := range plan.cc.Fields {
86 builder.AppendValue(field.Type.OID, getter.Index(i))
87 }
88
89 return builder.Finish()
90 }
91
92 type encodePlanCompositeCodecCompositeIndexGetterToText struct {
93 cc *CompositeCodec
94 m *Map
95 }
96
97 func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
98 getter := value.(CompositeIndexGetter)
99
100 if getter.IsNull() {
101 return nil, nil
102 }
103
104 b := NewCompositeTextBuilder(plan.m, buf)
105 for i, field := range plan.cc.Fields {
106 b.AppendValue(field.Type.OID, getter.Index(i))
107 }
108
109 return b.Finish()
110 }
111
112 func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
113 switch format {
114 case BinaryFormatCode:
115 switch target.(type) {
116 case CompositeIndexScanner:
117 return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m}
118 }
119 case TextFormatCode:
120 switch target.(type) {
121 case CompositeIndexScanner:
122 return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m}
123 }
124 }
125
126 return nil
127 }
128
129 type scanPlanBinaryCompositeToCompositeIndexScanner struct {
130 cc *CompositeCodec
131 m *Map
132 }
133
134 func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
135 targetScanner := (target).(CompositeIndexScanner)
136
137 if src == nil {
138 return targetScanner.ScanNull()
139 }
140
141 scanner := NewCompositeBinaryScanner(plan.m, src)
142 for i, field := range plan.cc.Fields {
143 if scanner.Next() {
144 fieldTarget := targetScanner.ScanIndex(i)
145 if fieldTarget != nil {
146 fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget)
147 if fieldPlan == nil {
148 return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID)
149 }
150
151 err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
152 if err != nil {
153 return err
154 }
155 }
156 } else {
157 return errors.New("read past end of composite")
158 }
159 }
160
161 if err := scanner.Err(); err != nil {
162 return err
163 }
164
165 return nil
166 }
167
168 type scanPlanTextCompositeToCompositeIndexScanner struct {
169 cc *CompositeCodec
170 m *Map
171 }
172
173 func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
174 targetScanner := (target).(CompositeIndexScanner)
175
176 if src == nil {
177 return targetScanner.ScanNull()
178 }
179
180 scanner := NewCompositeTextScanner(plan.m, src)
181 for i, field := range plan.cc.Fields {
182 if scanner.Next() {
183 fieldTarget := targetScanner.ScanIndex(i)
184 if fieldTarget != nil {
185 fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget)
186 if fieldPlan == nil {
187 return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID)
188 }
189
190 err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
191 if err != nil {
192 return err
193 }
194 }
195 } else {
196 return errors.New("read past end of composite")
197 }
198 }
199
200 if err := scanner.Err(); err != nil {
201 return err
202 }
203
204 return nil
205 }
206
207 func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
208 if src == nil {
209 return nil, nil
210 }
211
212 switch format {
213 case TextFormatCode:
214 return string(src), nil
215 case BinaryFormatCode:
216 buf := make([]byte, len(src))
217 copy(buf, src)
218 return buf, nil
219 default:
220 return nil, fmt.Errorf("unknown format code %d", format)
221 }
222 }
223
224 func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
225 if src == nil {
226 return nil, nil
227 }
228
229 switch format {
230 case TextFormatCode:
231 scanner := NewCompositeTextScanner(m, src)
232 values := make(map[string]any, len(c.Fields))
233 for i := 0; scanner.Next() && i < len(c.Fields); i++ {
234 var v any
235 fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v)
236 if fieldPlan == nil {
237 return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v)
238 }
239
240 err := fieldPlan.Scan(scanner.Bytes(), &v)
241 if err != nil {
242 return nil, err
243 }
244
245 values[c.Fields[i].Name] = v
246 }
247
248 if err := scanner.Err(); err != nil {
249 return nil, err
250 }
251
252 return values, nil
253 case BinaryFormatCode:
254 scanner := NewCompositeBinaryScanner(m, src)
255 values := make(map[string]any, len(c.Fields))
256 for i := 0; scanner.Next() && i < len(c.Fields); i++ {
257 var v any
258 fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v)
259 if fieldPlan == nil {
260 return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v)
261 }
262
263 err := fieldPlan.Scan(scanner.Bytes(), &v)
264 if err != nil {
265 return nil, err
266 }
267
268 values[c.Fields[i].Name] = v
269 }
270
271 if err := scanner.Err(); err != nil {
272 return nil, err
273 }
274
275 return values, nil
276 default:
277 return nil, fmt.Errorf("unknown format code %d", format)
278 }
279
280 }
281
282 type CompositeBinaryScanner struct {
283 m *Map
284 rp int
285 src []byte
286
287 fieldCount int32
288 fieldBytes []byte
289 fieldOID uint32
290 err error
291 }
292
293
294 func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner {
295 rp := 0
296 if len(src[rp:]) < 4 {
297 return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
298 }
299
300 fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
301 rp += 4
302
303 return &CompositeBinaryScanner{
304 m: m,
305 rp: rp,
306 src: src,
307 fieldCount: fieldCount,
308 }
309 }
310
311
312
313 func (cfs *CompositeBinaryScanner) Next() bool {
314 if cfs.err != nil {
315 return false
316 }
317
318 if cfs.rp == len(cfs.src) {
319 return false
320 }
321
322 if len(cfs.src[cfs.rp:]) < 8 {
323 cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
324 return false
325 }
326 cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
327 cfs.rp += 4
328
329 fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
330 cfs.rp += 4
331
332 if fieldLen >= 0 {
333 if len(cfs.src[cfs.rp:]) < fieldLen {
334 cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
335 return false
336 }
337 cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
338 cfs.rp += fieldLen
339 } else {
340 cfs.fieldBytes = nil
341 }
342
343 return true
344 }
345
346 func (cfs *CompositeBinaryScanner) FieldCount() int {
347 return int(cfs.fieldCount)
348 }
349
350
351 func (cfs *CompositeBinaryScanner) Bytes() []byte {
352 return cfs.fieldBytes
353 }
354
355
356 func (cfs *CompositeBinaryScanner) OID() uint32 {
357 return cfs.fieldOID
358 }
359
360
361 func (cfs *CompositeBinaryScanner) Err() error {
362 return cfs.err
363 }
364
365 type CompositeTextScanner struct {
366 m *Map
367 rp int
368 src []byte
369
370 fieldBytes []byte
371 err error
372 }
373
374
375 func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner {
376 if len(src) < 2 {
377 return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
378 }
379
380 if src[0] != '(' {
381 return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
382 }
383
384 if src[len(src)-1] != ')' {
385 return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
386 }
387
388 return &CompositeTextScanner{
389 m: m,
390 rp: 1,
391 src: src,
392 }
393 }
394
395
396
397 func (cfs *CompositeTextScanner) Next() bool {
398 if cfs.err != nil {
399 return false
400 }
401
402 if cfs.rp == len(cfs.src) {
403 return false
404 }
405
406 switch cfs.src[cfs.rp] {
407 case ',', ')':
408 cfs.rp++
409 cfs.fieldBytes = nil
410 return true
411 case '"':
412 cfs.rp++
413 cfs.fieldBytes = make([]byte, 0, 16)
414 for {
415 ch := cfs.src[cfs.rp]
416
417 if ch == '"' {
418 cfs.rp++
419 if cfs.src[cfs.rp] == '"' {
420 cfs.fieldBytes = append(cfs.fieldBytes, '"')
421 cfs.rp++
422 } else {
423 break
424 }
425 } else if ch == '\\' {
426 cfs.rp++
427 cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
428 cfs.rp++
429 } else {
430 cfs.fieldBytes = append(cfs.fieldBytes, ch)
431 cfs.rp++
432 }
433 }
434 cfs.rp++
435 return true
436 default:
437 start := cfs.rp
438 for {
439 ch := cfs.src[cfs.rp]
440 if ch == ',' || ch == ')' {
441 break
442 }
443 cfs.rp++
444 }
445 cfs.fieldBytes = cfs.src[start:cfs.rp]
446 cfs.rp++
447 return true
448 }
449 }
450
451
452 func (cfs *CompositeTextScanner) Bytes() []byte {
453 return cfs.fieldBytes
454 }
455
456
457 func (cfs *CompositeTextScanner) Err() error {
458 return cfs.err
459 }
460
461 type CompositeBinaryBuilder struct {
462 m *Map
463 buf []byte
464 startIdx int
465 fieldCount uint32
466 err error
467 }
468
469 func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder {
470 startIdx := len(buf)
471 buf = append(buf, 0, 0, 0, 0)
472 return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx}
473 }
474
475 func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) {
476 if b.err != nil {
477 return
478 }
479
480 if field == nil {
481 b.buf = pgio.AppendUint32(b.buf, oid)
482 b.buf = pgio.AppendInt32(b.buf, -1)
483 b.fieldCount++
484 return
485 }
486
487 plan := b.m.PlanEncode(oid, BinaryFormatCode, field)
488 if plan == nil {
489 b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
490 return
491 }
492
493 b.buf = pgio.AppendUint32(b.buf, oid)
494 lengthPos := len(b.buf)
495 b.buf = pgio.AppendInt32(b.buf, -1)
496 fieldBuf, err := plan.Encode(field, b.buf)
497 if err != nil {
498 b.err = err
499 return
500 }
501 if fieldBuf != nil {
502 binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
503 b.buf = fieldBuf
504 }
505
506 b.fieldCount++
507 }
508
509 func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
510 if b.err != nil {
511 return nil, b.err
512 }
513
514 binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
515 return b.buf, nil
516 }
517
518 type CompositeTextBuilder struct {
519 m *Map
520 buf []byte
521 startIdx int
522 fieldCount uint32
523 err error
524 fieldBuf [32]byte
525 }
526
527 func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder {
528 buf = append(buf, '(')
529 return &CompositeTextBuilder{m: m, buf: buf}
530 }
531
532 func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) {
533 if b.err != nil {
534 return
535 }
536
537 if field == nil {
538 b.buf = append(b.buf, ',')
539 return
540 }
541
542 plan := b.m.PlanEncode(oid, TextFormatCode, field)
543 if plan == nil {
544 b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
545 return
546 }
547
548 fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0])
549 if err != nil {
550 b.err = err
551 return
552 }
553 if fieldBuf != nil {
554 b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
555 }
556
557 b.buf = append(b.buf, ',')
558 }
559
560 func (b *CompositeTextBuilder) Finish() ([]byte, error) {
561 if b.err != nil {
562 return nil, b.err
563 }
564
565 b.buf[len(b.buf)-1] = ')'
566 return b.buf, nil
567 }
568
// quoteCompositeReplacer backslash-escapes the two characters that are special
// inside a quoted composite field: backslash and double quote.
var quoteCompositeReplacer = strings.NewReplacer("\\", "\\\\", "\"", "\\\"")

// quoteCompositeField wraps src in double quotes, escaping backslashes and
// double quotes.
func quoteCompositeField(src string) string {
	return "\"" + quoteCompositeReplacer.Replace(src) + "\""
}

// quoteCompositeFieldIfNeeded quotes src when it would otherwise be ambiguous
// in composite text syntax. That is the case when src is empty (which would
// read as NULL), starts or ends with a space, or contains any of ( ) , " \.
// All other values are returned unchanged.
func quoteCompositeFieldIfNeeded(src string) string {
	switch {
	case len(src) == 0,
		src[0] == ' ',
		src[len(src)-1] == ' ',
		strings.ContainsAny(src, `(),"\`):
		return quoteCompositeField(src)
	default:
		return src
	}
}
581
582
583
// CompositeFields is a positional, untyped representation of a composite
// value. It implements both CompositeIndexGetter and CompositeIndexScanner.
// When encoding, element i is the value of field i, and a nil CompositeFields
// encodes as NULL. When scanning, element i is used as the scan target for
// field i (typically a pointer), and a nil element skips that field.
type CompositeFields []any

// SkipUnderlyingTypePlan has no behavior of its own. NOTE(review): presumably
// a marker recognized elsewhere in pgtype to stop planning via the underlying
// []any type — confirm against the planner.
func (cf CompositeFields) SkipUnderlyingTypePlan() {}

// IsNull reports whether cf is nil, i.e. the composite is NULL. An empty but
// non-nil CompositeFields is not NULL.
func (cf CompositeFields) IsNull() bool { return cf == nil }

// Index returns the value of field i. It panics if i is out of range.
func (cf CompositeFields) Index(i int) any { return cf[i] }

// ScanNull always fails: a NULL composite cannot be represented by
// CompositeFields.
func (cf CompositeFields) ScanNull() error {
	return fmt.Errorf("cannot scan NULL into CompositeFields")
}

// ScanIndex returns the scan target for field i, which is the same element
// Index returns. It panics if i is out of range.
func (cf CompositeFields) ScanIndex(i int) any { return cf.Index(i) }
603
View as plain text