1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "fmt"
7 "reflect"
8
9 "github.com/jackc/pgx/v5/internal/anynil"
10 "github.com/jackc/pgx/v5/internal/pgio"
11 )
12
13
14 type ArrayGetter interface {
15
16 Dimensions() []ArrayDimension
17
18
19 Index(i int) any
20
21
22 IndexType() any
23 }
24
25
26 type ArraySetter interface {
27
28
29
30 SetDimensions(dimensions []ArrayDimension) error
31
32
33 ScanIndex(i int) any
34
35
36
37 ScanIndexType() any
38 }
39
40
41 type ArrayCodec struct {
42 ElementType *Type
43 }
44
45 func (c *ArrayCodec) FormatSupported(format int16) bool {
46 return c.ElementType.Codec.FormatSupported(format)
47 }
48
49 func (c *ArrayCodec) PreferredFormat() int16 {
50
51
52
53
54
55
56 if c.ElementType.Codec.FormatSupported(BinaryFormatCode) {
57 return BinaryFormatCode
58 }
59 return TextFormatCode
60 }
61
62 func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
63 arrayValuer, ok := value.(ArrayGetter)
64 if !ok {
65 return nil
66 }
67
68 elementType := arrayValuer.IndexType()
69
70 elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
71 if elementEncodePlan == nil {
72 if reflect.TypeOf(elementType) != nil {
73 return nil
74 }
75 }
76
77 switch format {
78 case BinaryFormatCode:
79 return &encodePlanArrayCodecBinary{ac: c, m: m, oid: oid}
80 case TextFormatCode:
81 return &encodePlanArrayCodecText{ac: c, m: m, oid: oid}
82 }
83
84 return nil
85 }
86
87 type encodePlanArrayCodecText struct {
88 ac *ArrayCodec
89 m *Map
90 oid uint32
91 }
92
93 func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
94 array := value.(ArrayGetter)
95
96 dimensions := array.Dimensions()
97 if dimensions == nil {
98 return nil, nil
99 }
100
101 elementCount := cardinality(dimensions)
102 if elementCount == 0 {
103 return append(buf, '{', '}'), nil
104 }
105
106 buf = encodeTextArrayDimensions(buf, dimensions)
107
108
109
110
111
112
113 dimElemCounts := make([]int, len(dimensions))
114 dimElemCounts[len(dimensions)-1] = int(dimensions[len(dimensions)-1].Length)
115 for i := len(dimensions) - 2; i > -1; i-- {
116 dimElemCounts[i] = int(dimensions[i].Length) * dimElemCounts[i+1]
117 }
118
119 var encodePlan EncodePlan
120 var lastElemType reflect.Type
121 inElemBuf := make([]byte, 0, 32)
122 for i := 0; i < elementCount; i++ {
123 if i > 0 {
124 buf = append(buf, ',')
125 }
126
127 for _, dec := range dimElemCounts {
128 if i%dec == 0 {
129 buf = append(buf, '{')
130 }
131 }
132
133 elem := array.Index(i)
134 var elemBuf []byte
135 if elem != nil {
136 elemType := reflect.TypeOf(elem)
137 if lastElemType != elemType {
138 lastElemType = elemType
139 encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
140 if encodePlan == nil {
141 return nil, fmt.Errorf("unable to encode %v", array.Index(i))
142 }
143 }
144 elemBuf, err = encodePlan.Encode(elem, inElemBuf)
145 if err != nil {
146 return nil, err
147 }
148 }
149
150 if elemBuf == nil {
151 buf = append(buf, `NULL`...)
152 } else {
153 buf = append(buf, quoteArrayElementIfNeeded(string(elemBuf))...)
154 }
155
156 for _, dec := range dimElemCounts {
157 if (i+1)%dec == 0 {
158 buf = append(buf, '}')
159 }
160 }
161 }
162
163 return buf, nil
164 }
165
166 type encodePlanArrayCodecBinary struct {
167 ac *ArrayCodec
168 m *Map
169 oid uint32
170 }
171
172 func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
173 array := value.(ArrayGetter)
174
175 dimensions := array.Dimensions()
176 if dimensions == nil {
177 return nil, nil
178 }
179
180 arrayHeader := arrayHeader{
181 Dimensions: dimensions,
182 ElementOID: p.ac.ElementType.OID,
183 }
184
185 containsNullIndex := len(buf) + 4
186
187 buf = arrayHeader.EncodeBinary(buf)
188
189 elementCount := cardinality(dimensions)
190
191 var encodePlan EncodePlan
192 var lastElemType reflect.Type
193 for i := 0; i < elementCount; i++ {
194 sp := len(buf)
195 buf = pgio.AppendInt32(buf, -1)
196
197 elem := array.Index(i)
198 var elemBuf []byte
199 if elem != nil {
200 elemType := reflect.TypeOf(elem)
201 if lastElemType != elemType {
202 lastElemType = elemType
203 encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
204 if encodePlan == nil {
205 return nil, fmt.Errorf("unable to encode %v", array.Index(i))
206 }
207 }
208 elemBuf, err = encodePlan.Encode(elem, buf)
209 if err != nil {
210 return nil, err
211 }
212 }
213
214 if elemBuf == nil {
215 pgio.SetInt32(buf[containsNullIndex:], 1)
216 } else {
217 buf = elemBuf
218 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
219 }
220 }
221
222 return buf, nil
223 }
224
225 func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
226 arrayScanner, ok := target.(ArraySetter)
227 if !ok {
228 return nil
229 }
230
231
232
233 if anynil.Is(target) {
234 arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter)
235 }
236
237 elementType := arrayScanner.ScanIndexType()
238
239 elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType)
240 if _, ok := elementScanPlan.(*scanPlanFail); ok {
241 return nil
242 }
243
244 return &scanPlanArrayCodec{
245 arrayCodec: c,
246 m: m,
247 oid: oid,
248 formatCode: format,
249 }
250 }
251
252 func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array ArraySetter) error {
253 var arrayHeader arrayHeader
254 rp, err := arrayHeader.DecodeBinary(m, src)
255 if err != nil {
256 return err
257 }
258
259 err = array.SetDimensions(arrayHeader.Dimensions)
260 if err != nil {
261 return err
262 }
263
264 elementCount := cardinality(arrayHeader.Dimensions)
265 if elementCount == 0 {
266 return nil
267 }
268
269 elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0))
270 if elementScanPlan == nil {
271 elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0))
272 }
273
274 for i := 0; i < elementCount; i++ {
275 elem := array.ScanIndex(i)
276 elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
277 rp += 4
278 var elemSrc []byte
279 if elemLen >= 0 {
280 elemSrc = src[rp : rp+elemLen]
281 rp += elemLen
282 }
283 err = elementScanPlan.Scan(elemSrc, elem)
284 if err != nil {
285 return fmt.Errorf("failed to scan array element %d: %w", i, err)
286 }
287 }
288
289 return nil
290 }
291
292 func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array ArraySetter) error {
293 uta, err := parseUntypedTextArray(string(src))
294 if err != nil {
295 return err
296 }
297
298 err = array.SetDimensions(uta.Dimensions)
299 if err != nil {
300 return err
301 }
302
303 if len(uta.Elements) == 0 {
304 return nil
305 }
306
307 elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, array.ScanIndex(0))
308 if elementScanPlan == nil {
309 elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, array.ScanIndex(0))
310 }
311
312 for i, s := range uta.Elements {
313 elem := array.ScanIndex(i)
314 var elemSrc []byte
315 if s != "NULL" || uta.Quoted[i] {
316 elemSrc = []byte(s)
317 }
318
319 err = elementScanPlan.Scan(elemSrc, elem)
320 if err != nil {
321 return err
322 }
323 }
324
325 return nil
326 }
327
328 type scanPlanArrayCodec struct {
329 arrayCodec *ArrayCodec
330 m *Map
331 oid uint32
332 formatCode int16
333 elementScanPlan ScanPlan
334 }
335
336 func (spac *scanPlanArrayCodec) Scan(src []byte, dst any) error {
337 c := spac.arrayCodec
338 m := spac.m
339 oid := spac.oid
340 formatCode := spac.formatCode
341
342 array := dst.(ArraySetter)
343
344 if src == nil {
345 return array.SetDimensions(nil)
346 }
347
348 switch formatCode {
349 case BinaryFormatCode:
350 return c.decodeBinary(m, oid, src, array)
351 case TextFormatCode:
352 return c.decodeText(m, oid, src, array)
353 default:
354 return fmt.Errorf("unknown format code %d", formatCode)
355 }
356 }
357
358 func (c *ArrayCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
359 if src == nil {
360 return nil, nil
361 }
362
363 switch format {
364 case TextFormatCode:
365 return string(src), nil
366 case BinaryFormatCode:
367 buf := make([]byte, len(src))
368 copy(buf, src)
369 return buf, nil
370 default:
371 return nil, fmt.Errorf("unknown format code %d", format)
372 }
373 }
374
375 func (c *ArrayCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
376 if src == nil {
377 return nil, nil
378 }
379
380 var slice []any
381 err := m.PlanScan(oid, format, &slice).Scan(src, &slice)
382 return slice, err
383 }
384
// isRagged reports whether slice, a slice of slices, has sub-slices of
// differing lengths at any nesting depth. A slice whose element type is not
// itself a slice kind is never ragged, and neither is an empty one.
func isRagged(slice reflect.Value) bool {
	if slice.Type().Elem().Kind() != reflect.Slice {
		return false
	}

	n := slice.Len()
	if n == 0 {
		return false
	}

	// Every sub-slice must match the first one's length, and must itself be
	// rectangular.
	want := slice.Index(0).Len()
	for i := 0; i < n; i++ {
		inner := slice.Index(i)
		if inner.Len() != want || isRagged(inner) {
			return true
		}
	}
	return false
}
407
View as plain text