1
2
3 package pgtype
4
5 import (
6 "database/sql/driver"
7 "encoding/binary"
8 "fmt"
9 "reflect"
10
11 "github.com/jackc/pgio"
12 )
13
// ByteaArray represents a PostgreSQL bytea[] value.
//
// Elements holds the array's elements flattened in row-major order (the last
// dimension varies fastest), and Dimensions holds the Length and LowerBound of
// each array dimension. Status records whether the value is Present, Null, or
// Undefined.
type ByteaArray struct {
	Elements   []Bytea
	Dimensions []ArrayDimension
	Status     Status
}
19
20 func (dst *ByteaArray) Set(src interface{}) error {
21
22 if src == nil {
23 *dst = ByteaArray{Status: Null}
24 return nil
25 }
26
27 if value, ok := src.(interface{ Get() interface{} }); ok {
28 value2 := value.Get()
29 if value2 != value {
30 return dst.Set(value2)
31 }
32 }
33
34
35 switch value := src.(type) {
36
37 case [][]byte:
38 if value == nil {
39 *dst = ByteaArray{Status: Null}
40 } else if len(value) == 0 {
41 *dst = ByteaArray{Status: Present}
42 } else {
43 elements := make([]Bytea, len(value))
44 for i := range value {
45 if err := elements[i].Set(value[i]); err != nil {
46 return err
47 }
48 }
49 *dst = ByteaArray{
50 Elements: elements,
51 Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
52 Status: Present,
53 }
54 }
55
56 case []Bytea:
57 if value == nil {
58 *dst = ByteaArray{Status: Null}
59 } else if len(value) == 0 {
60 *dst = ByteaArray{Status: Present}
61 } else {
62 *dst = ByteaArray{
63 Elements: value,
64 Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
65 Status: Present,
66 }
67 }
68 default:
69
70
71
72 reflectedValue := reflect.ValueOf(src)
73 if !reflectedValue.IsValid() || reflectedValue.IsZero() {
74 *dst = ByteaArray{Status: Null}
75 return nil
76 }
77
78 dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
79 if !ok {
80 return fmt.Errorf("cannot find dimensions of %v for ByteaArray", src)
81 }
82 if elementsLength == 0 {
83 *dst = ByteaArray{Status: Present}
84 return nil
85 }
86 if len(dimensions) == 0 {
87 if originalSrc, ok := underlyingSliceType(src); ok {
88 return dst.Set(originalSrc)
89 }
90 return fmt.Errorf("cannot convert %v to ByteaArray", src)
91 }
92
93 *dst = ByteaArray{
94 Elements: make([]Bytea, elementsLength),
95 Dimensions: dimensions,
96 Status: Present,
97 }
98 elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
99 if err != nil {
100
101 if len(dst.Dimensions) > 1 {
102 dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
103 elementsLength = 0
104 for _, dim := range dst.Dimensions {
105 if elementsLength == 0 {
106 elementsLength = int(dim.Length)
107 } else {
108 elementsLength *= int(dim.Length)
109 }
110 }
111 dst.Elements = make([]Bytea, elementsLength)
112 elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
113 if err != nil {
114 return err
115 }
116 } else {
117 return err
118 }
119 }
120 if elementCount != len(dst.Elements) {
121 return fmt.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
122 }
123 }
124
125 return nil
126 }
127
128 func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
129 switch value.Kind() {
130 case reflect.Array:
131 fallthrough
132 case reflect.Slice:
133 if len(dst.Dimensions) == dimension {
134 break
135 }
136
137 valueLen := value.Len()
138 if int32(valueLen) != dst.Dimensions[dimension].Length {
139 return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
140 }
141 for i := 0; i < valueLen; i++ {
142 var err error
143 index, err = dst.setRecursive(value.Index(i), index, dimension+1)
144 if err != nil {
145 return 0, err
146 }
147 }
148
149 return index, nil
150 }
151 if !value.CanInterface() {
152 return 0, fmt.Errorf("cannot convert all values to ByteaArray")
153 }
154 if err := dst.Elements[index].Set(value.Interface()); err != nil {
155 return 0, fmt.Errorf("%v in ByteaArray", err)
156 }
157 index++
158
159 return index, nil
160 }
161
162 func (dst ByteaArray) Get() interface{} {
163 switch dst.Status {
164 case Present:
165 return dst
166 case Null:
167 return nil
168 default:
169 return dst.Status
170 }
171 }
172
173 func (src *ByteaArray) AssignTo(dst interface{}) error {
174 switch src.Status {
175 case Present:
176 if len(src.Dimensions) <= 1 {
177
178 switch v := dst.(type) {
179
180 case *[][]byte:
181 *v = make([][]byte, len(src.Elements))
182 for i := range src.Elements {
183 if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
184 return err
185 }
186 }
187 return nil
188
189 }
190 }
191
192
193 if nextDst, retry := GetAssignToDstType(dst); retry {
194 return src.AssignTo(nextDst)
195 }
196
197
198
199
200 value := reflect.ValueOf(dst)
201 if value.Kind() == reflect.Ptr {
202 value = value.Elem()
203 }
204
205 switch value.Kind() {
206 case reflect.Array, reflect.Slice:
207 default:
208 return fmt.Errorf("cannot assign %T to %T", src, dst)
209 }
210
211 if len(src.Elements) == 0 {
212 if value.Kind() == reflect.Slice {
213 value.Set(reflect.MakeSlice(value.Type(), 0, 0))
214 return nil
215 }
216 }
217
218 elementCount, err := src.assignToRecursive(value, 0, 0)
219 if err != nil {
220 return err
221 }
222 if elementCount != len(src.Elements) {
223 return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
224 }
225
226 return nil
227 case Null:
228 return NullAssignTo(dst)
229 }
230
231 return fmt.Errorf("cannot decode %#v into %T", src, dst)
232 }
233
234 func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
235 switch kind := value.Kind(); kind {
236 case reflect.Array:
237 fallthrough
238 case reflect.Slice:
239 if len(src.Dimensions) == dimension {
240 break
241 }
242
243 length := int(src.Dimensions[dimension].Length)
244 if reflect.Array == kind {
245 typ := value.Type()
246 if typ.Len() != length {
247 return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
248 }
249 value.Set(reflect.New(typ).Elem())
250 } else {
251 value.Set(reflect.MakeSlice(value.Type(), length, length))
252 }
253
254 var err error
255 for i := 0; i < length; i++ {
256 index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
257 if err != nil {
258 return 0, err
259 }
260 }
261
262 return index, nil
263 }
264 if len(src.Dimensions) != dimension {
265 return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
266 }
267 if !value.CanAddr() {
268 return 0, fmt.Errorf("cannot assign all values from ByteaArray")
269 }
270 addr := value.Addr()
271 if !addr.CanInterface() {
272 return 0, fmt.Errorf("cannot assign all values from ByteaArray")
273 }
274 if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
275 return 0, err
276 }
277 index++
278 return index, nil
279 }
280
281 func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error {
282 if src == nil {
283 *dst = ByteaArray{Status: Null}
284 return nil
285 }
286
287 uta, err := ParseUntypedTextArray(string(src))
288 if err != nil {
289 return err
290 }
291
292 var elements []Bytea
293
294 if len(uta.Elements) > 0 {
295 elements = make([]Bytea, len(uta.Elements))
296
297 for i, s := range uta.Elements {
298 var elem Bytea
299 var elemSrc []byte
300 if s != "NULL" || uta.Quoted[i] {
301 elemSrc = []byte(s)
302 }
303 err = elem.DecodeText(ci, elemSrc)
304 if err != nil {
305 return err
306 }
307
308 elements[i] = elem
309 }
310 }
311
312 *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
313
314 return nil
315 }
316
317 func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error {
318 if src == nil {
319 *dst = ByteaArray{Status: Null}
320 return nil
321 }
322
323 var arrayHeader ArrayHeader
324 rp, err := arrayHeader.DecodeBinary(ci, src)
325 if err != nil {
326 return err
327 }
328
329 if len(arrayHeader.Dimensions) == 0 {
330 *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present}
331 return nil
332 }
333
334 elementCount := arrayHeader.Dimensions[0].Length
335 for _, d := range arrayHeader.Dimensions[1:] {
336 elementCount *= d.Length
337 }
338
339 elements := make([]Bytea, elementCount)
340
341 for i := range elements {
342 elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
343 rp += 4
344 var elemSrc []byte
345 if elemLen >= 0 {
346 elemSrc = src[rp : rp+elemLen]
347 rp += elemLen
348 }
349 err = elements[i].DecodeBinary(ci, elemSrc)
350 if err != nil {
351 return err
352 }
353 }
354
355 *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
356 return nil
357 }
358
359 func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
360 switch src.Status {
361 case Null:
362 return nil, nil
363 case Undefined:
364 return nil, errUndefined
365 }
366
367 if len(src.Dimensions) == 0 {
368 return append(buf, '{', '}'), nil
369 }
370
371 buf = EncodeTextArrayDimensions(buf, src.Dimensions)
372
373
374
375
376
377
378 dimElemCounts := make([]int, len(src.Dimensions))
379 dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
380 for i := len(src.Dimensions) - 2; i > -1; i-- {
381 dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
382 }
383
384 inElemBuf := make([]byte, 0, 32)
385 for i, elem := range src.Elements {
386 if i > 0 {
387 buf = append(buf, ',')
388 }
389
390 for _, dec := range dimElemCounts {
391 if i%dec == 0 {
392 buf = append(buf, '{')
393 }
394 }
395
396 elemBuf, err := elem.EncodeText(ci, inElemBuf)
397 if err != nil {
398 return nil, err
399 }
400 if elemBuf == nil {
401 buf = append(buf, `NULL`...)
402 } else {
403 buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
404 }
405
406 for _, dec := range dimElemCounts {
407 if (i+1)%dec == 0 {
408 buf = append(buf, '}')
409 }
410 }
411 }
412
413 return buf, nil
414 }
415
416 func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
417 switch src.Status {
418 case Null:
419 return nil, nil
420 case Undefined:
421 return nil, errUndefined
422 }
423
424 arrayHeader := ArrayHeader{
425 Dimensions: src.Dimensions,
426 }
427
428 if dt, ok := ci.DataTypeForName("bytea"); ok {
429 arrayHeader.ElementOID = int32(dt.OID)
430 } else {
431 return nil, fmt.Errorf("unable to find oid for type name %v", "bytea")
432 }
433
434 for i := range src.Elements {
435 if src.Elements[i].Status == Null {
436 arrayHeader.ContainsNull = true
437 break
438 }
439 }
440
441 buf = arrayHeader.EncodeBinary(ci, buf)
442
443 for i := range src.Elements {
444 sp := len(buf)
445 buf = pgio.AppendInt32(buf, -1)
446
447 elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
448 if err != nil {
449 return nil, err
450 }
451 if elemBuf != nil {
452 buf = elemBuf
453 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
454 }
455 }
456
457 return buf, nil
458 }
459
460
461 func (dst *ByteaArray) Scan(src interface{}) error {
462 if src == nil {
463 return dst.DecodeText(nil, nil)
464 }
465
466 switch src := src.(type) {
467 case string:
468 return dst.DecodeText(nil, []byte(src))
469 case []byte:
470 srcCopy := make([]byte, len(src))
471 copy(srcCopy, src)
472 return dst.DecodeText(nil, srcCopy)
473 }
474
475 return fmt.Errorf("cannot scan %T", src)
476 }
477
478
479 func (src ByteaArray) Value() (driver.Value, error) {
480 buf, err := src.EncodeText(nil, nil)
481 if err != nil {
482 return nil, err
483 }
484 if buf == nil {
485 return nil, nil
486 }
487
488 return string(buf), nil
489 }
490
View as plain text