1 package pgtype
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/binary"
7 "fmt"
8 "reflect"
9
10 "github.com/jackc/pgx/v5/internal/pgio"
11 )
12
13
// MultirangeGetter is implemented by values that MultirangeCodec can encode as a multirange.
type MultirangeGetter interface {
	// IsNull reports whether the value is NULL. The encode plans in this file return a nil buffer when it is true.
	IsNull() bool

	// Len returns the number of elements in the multirange.
	Len() int

	// Index returns the element at position i. The encode plans call it for every i in [0, Len()).
	Index(i int) any

	// IndexType returns a value of the type Index returns. MultirangeCodec.PlanEncode uses it to check that an
	// encode plan exists for the elements.
	IndexType() any
}
27
28
// MultirangeSetter is implemented by values that MultirangeCodec can scan a multirange into.
type MultirangeSetter interface {
	// ScanNull sets the value to NULL. It is called when the source bytes are nil.
	ScanNull() error

	// SetLen sizes the value to hold n elements. The decoders call it once, before any call to ScanIndex on the
	// decoded data.
	SetLen(n int) error

	// ScanIndex returns a scan target for the element at position i. The decoders call it only after SetLen.
	ScanIndex(i int) any

	// ScanIndexType returns a scan target of the type ScanIndex returns. MultirangeCodec.PlanScan uses it to check
	// that a scan plan exists for the elements.
	ScanIndexType() any
}
44
45
// MultirangeCodec encodes and decodes multirange values whose elements are of ElementType.
type MultirangeCodec struct {
	// ElementType describes the element type. Its OID and Codec are used to plan element encoding and scanning.
	ElementType *Type
}
49
50 func (c *MultirangeCodec) FormatSupported(format int16) bool {
51 return c.ElementType.Codec.FormatSupported(format)
52 }
53
54 func (c *MultirangeCodec) PreferredFormat() int16 {
55 return c.ElementType.Codec.PreferredFormat()
56 }
57
58 func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
59 multirangeValuer, ok := value.(MultirangeGetter)
60 if !ok {
61 return nil
62 }
63
64 elementType := multirangeValuer.IndexType()
65
66 elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
67 if elementEncodePlan == nil {
68 return nil
69 }
70
71 switch format {
72 case BinaryFormatCode:
73 return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid}
74 case TextFormatCode:
75 return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid}
76 }
77
78 return nil
79 }
80
81 type encodePlanMultirangeCodecText struct {
82 ac *MultirangeCodec
83 m *Map
84 oid uint32
85 }
86
87 func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
88 multirange := value.(MultirangeGetter)
89
90 if multirange.IsNull() {
91 return nil, nil
92 }
93
94 elementCount := multirange.Len()
95
96 buf = append(buf, '{')
97
98 var encodePlan EncodePlan
99 var lastElemType reflect.Type
100 inElemBuf := make([]byte, 0, 32)
101 for i := 0; i < elementCount; i++ {
102 if i > 0 {
103 buf = append(buf, ',')
104 }
105
106 elem := multirange.Index(i)
107 var elemBuf []byte
108 if elem != nil {
109 elemType := reflect.TypeOf(elem)
110 if lastElemType != elemType {
111 lastElemType = elemType
112 encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
113 if encodePlan == nil {
114 return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
115 }
116 }
117 elemBuf, err = encodePlan.Encode(elem, inElemBuf)
118 if err != nil {
119 return nil, err
120 }
121 }
122
123 if elemBuf == nil {
124 return nil, fmt.Errorf("multirange cannot contain NULL element")
125 } else {
126 buf = append(buf, elemBuf...)
127 }
128 }
129
130 buf = append(buf, '}')
131
132 return buf, nil
133 }
134
135 type encodePlanMultirangeCodecBinary struct {
136 ac *MultirangeCodec
137 m *Map
138 oid uint32
139 }
140
141 func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
142 multirange := value.(MultirangeGetter)
143
144 if multirange.IsNull() {
145 return nil, nil
146 }
147
148 elementCount := multirange.Len()
149
150 buf = pgio.AppendInt32(buf, int32(elementCount))
151
152 var encodePlan EncodePlan
153 var lastElemType reflect.Type
154 for i := 0; i < elementCount; i++ {
155 sp := len(buf)
156 buf = pgio.AppendInt32(buf, -1)
157
158 elem := multirange.Index(i)
159 var elemBuf []byte
160 if elem != nil {
161 elemType := reflect.TypeOf(elem)
162 if lastElemType != elemType {
163 lastElemType = elemType
164 encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
165 if encodePlan == nil {
166 return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
167 }
168 }
169 elemBuf, err = encodePlan.Encode(elem, buf)
170 if err != nil {
171 return nil, err
172 }
173 }
174
175 if elemBuf == nil {
176 return nil, fmt.Errorf("multirange cannot contain NULL element")
177 } else {
178 buf = elemBuf
179 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
180 }
181 }
182
183 return buf, nil
184 }
185
186 func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
187 multirangeScanner, ok := target.(MultirangeSetter)
188 if !ok {
189 return nil
190 }
191
192 elementType := multirangeScanner.ScanIndexType()
193
194 elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType)
195 if _, ok := elementScanPlan.(*scanPlanFail); ok {
196 return nil
197 }
198
199 return &scanPlanMultirangeCodec{
200 multirangeCodec: c,
201 m: m,
202 oid: oid,
203 formatCode: format,
204 }
205 }
206
207 func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
208 rp := 0
209
210 elementCount := int(binary.BigEndian.Uint32(src[rp:]))
211 rp += 4
212
213 err := multirange.SetLen(elementCount)
214 if err != nil {
215 return err
216 }
217
218 if elementCount == 0 {
219 return nil
220 }
221
222 elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
223 if elementScanPlan == nil {
224 elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
225 }
226
227 for i := 0; i < elementCount; i++ {
228 elem := multirange.ScanIndex(i)
229 elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
230 rp += 4
231 var elemSrc []byte
232 if elemLen >= 0 {
233 elemSrc = src[rp : rp+elemLen]
234 rp += elemLen
235 }
236 err = elementScanPlan.Scan(elemSrc, elem)
237 if err != nil {
238 return fmt.Errorf("failed to scan multirange element %d: %w", i, err)
239 }
240 }
241
242 return nil
243 }
244
245 func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
246 elements, err := parseUntypedTextMultirange(src)
247 if err != nil {
248 return err
249 }
250
251 err = multirange.SetLen(len(elements))
252 if err != nil {
253 return err
254 }
255
256 if len(elements) == 0 {
257 return nil
258 }
259
260 elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
261 if elementScanPlan == nil {
262 elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
263 }
264
265 for i, s := range elements {
266 elem := multirange.ScanIndex(i)
267 err = elementScanPlan.Scan([]byte(s), elem)
268 if err != nil {
269 return err
270 }
271 }
272
273 return nil
274 }
275
276 type scanPlanMultirangeCodec struct {
277 multirangeCodec *MultirangeCodec
278 m *Map
279 oid uint32
280 formatCode int16
281 elementScanPlan ScanPlan
282 }
283
284 func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error {
285 c := spac.multirangeCodec
286 m := spac.m
287 oid := spac.oid
288 formatCode := spac.formatCode
289
290 multirange := dst.(MultirangeSetter)
291
292 if src == nil {
293 return multirange.ScanNull()
294 }
295
296 switch formatCode {
297 case BinaryFormatCode:
298 return c.decodeBinary(m, oid, src, multirange)
299 case TextFormatCode:
300 return c.decodeText(m, oid, src, multirange)
301 default:
302 return fmt.Errorf("unknown format code %d", formatCode)
303 }
304 }
305
306 func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
307 if src == nil {
308 return nil, nil
309 }
310
311 switch format {
312 case TextFormatCode:
313 return string(src), nil
314 case BinaryFormatCode:
315 buf := make([]byte, len(src))
316 copy(buf, src)
317 return buf, nil
318 default:
319 return nil, fmt.Errorf("unknown format code %d", format)
320 }
321 }
322
323 func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
324 if src == nil {
325 return nil, nil
326 }
327
328 var multirange Multirange[Range[any]]
329 err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange)
330 return multirange, err
331 }
332
333 func parseUntypedTextMultirange(src []byte) ([]string, error) {
334 elements := make([]string, 0)
335
336 buf := bytes.NewBuffer(src)
337
338 skipWhitespace(buf)
339
340 r, _, err := buf.ReadRune()
341 if err != nil {
342 return nil, fmt.Errorf("invalid array: %w", err)
343 }
344
345 if r != '{' {
346 return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r)
347 }
348
349 parseValueLoop:
350 for {
351 r, _, err = buf.ReadRune()
352 if err != nil {
353 return nil, fmt.Errorf("invalid multirange: %w", err)
354 }
355
356 switch r {
357 case ',':
358 case '}':
359 break parseValueLoop
360 default:
361 buf.UnreadRune()
362 value, err := parseRange(buf)
363 if err != nil {
364 return nil, fmt.Errorf("invalid multirange value: %w", err)
365 }
366 elements = append(elements, value)
367 }
368 }
369
370 skipWhitespace(buf)
371
372 if buf.Len() > 0 {
373 return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
374 }
375
376 return elements, nil
377
378 }
379
// parseRange reads the text of a single range element from buf and returns it verbatim, quotes and escapes
// included. The element ends at the second unquoted ',' (the first one separates the two bounds) or at an
// unquoted '}'. That terminator is left unread in buf for the caller.
//
// Double-quoted sections and backslash escapes are honored, so a ',' or '}' inside a quoted or escaped bound does
// not end the element. A doubled quote ("") inside a quoted section closes and reopens it, which leaves the
// quoting state unchanged. Reaching the end of buf before a terminator returns the read error.
func parseRange(buf *bytes.Buffer) (string, error) {
	var s bytes.Buffer

	boundSepRead := false
	inQuote := false
	for {
		r, _, err := buf.ReadRune()
		if err != nil {
			return "", err
		}

		switch {
		case r == '\\':
			// A backslash makes the following rune literal: copy both through untouched.
			s.WriteRune(r)
			r, _, err = buf.ReadRune()
			if err != nil {
				return "", err
			}
		case r == '"':
			inQuote = !inQuote
		case inQuote:
			// Structural characters inside quotes are plain data.
		case r == ',' && !boundSepRead:
			boundSepRead = true
		case r == ',' || r == '}':
			// UnreadRune cannot fail directly after a successful ReadRune.
			_ = buf.UnreadRune()
			return s.String(), nil
		}

		s.WriteRune(r)
	}
}
403
404
405
406
407
408 type Multirange[T RangeValuer] []T
409
410 func (r Multirange[T]) IsNull() bool {
411 return r == nil
412 }
413
414 func (r Multirange[T]) Len() int {
415 return len(r)
416 }
417
418 func (r Multirange[T]) Index(i int) any {
419 return r[i]
420 }
421
422 func (r Multirange[T]) IndexType() any {
423 var zero T
424 return zero
425 }
426
427 func (r *Multirange[T]) ScanNull() error {
428 *r = nil
429 return nil
430 }
431
432 func (r *Multirange[T]) SetLen(n int) error {
433 *r = make([]T, n)
434 return nil
435 }
436
437 func (r Multirange[T]) ScanIndex(i int) any {
438 return &r[i]
439 }
440
441 func (r Multirange[T]) ScanIndexType() any {
442 return new(T)
443 }
444
View as plain text