1
2
3 package pgtype
4
5 import (
6 "database/sql/driver"
7 "encoding/binary"
8 "fmt"
9 "net"
10 "reflect"
11
12 "github.com/jackc/pgio"
13 )
14
15 type CIDRArray struct {
16 Elements []CIDR
17 Dimensions []ArrayDimension
18 Status Status
19 }
20
21 func (dst *CIDRArray) Set(src interface{}) error {
22
23 if src == nil {
24 *dst = CIDRArray{Status: Null}
25 return nil
26 }
27
28 if value, ok := src.(interface{ Get() interface{} }); ok {
29 value2 := value.Get()
30 if value2 != value {
31 return dst.Set(value2)
32 }
33 }
34
35
36 switch value := src.(type) {
37
38 case []*net.IPNet:
39 if value == nil {
40 *dst = CIDRArray{Status: Null}
41 } else if len(value) == 0 {
42 *dst = CIDRArray{Status: Present}
43 } else {
44 elements := make([]CIDR, len(value))
45 for i := range value {
46 if err := elements[i].Set(value[i]); err != nil {
47 return err
48 }
49 }
50 *dst = CIDRArray{
51 Elements: elements,
52 Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
53 Status: Present,
54 }
55 }
56
57 case []net.IP:
58 if value == nil {
59 *dst = CIDRArray{Status: Null}
60 } else if len(value) == 0 {
61 *dst = CIDRArray{Status: Present}
62 } else {
63 elements := make([]CIDR, len(value))
64 for i := range value {
65 if err := elements[i].Set(value[i]); err != nil {
66 return err
67 }
68 }
69 *dst = CIDRArray{
70 Elements: elements,
71 Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
72 Status: Present,
73 }
74 }
75
76 case []*net.IP:
77 if value == nil {
78 *dst = CIDRArray{Status: Null}
79 } else if len(value) == 0 {
80 *dst = CIDRArray{Status: Present}
81 } else {
82 elements := make([]CIDR, len(value))
83 for i := range value {
84 if err := elements[i].Set(value[i]); err != nil {
85 return err
86 }
87 }
88 *dst = CIDRArray{
89 Elements: elements,
90 Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
91 Status: Present,
92 }
93 }
94
95 case []CIDR:
96 if value == nil {
97 *dst = CIDRArray{Status: Null}
98 } else if len(value) == 0 {
99 *dst = CIDRArray{Status: Present}
100 } else {
101 *dst = CIDRArray{
102 Elements: value,
103 Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
104 Status: Present,
105 }
106 }
107 default:
108
109
110
111 reflectedValue := reflect.ValueOf(src)
112 if !reflectedValue.IsValid() || reflectedValue.IsZero() {
113 *dst = CIDRArray{Status: Null}
114 return nil
115 }
116
117 dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
118 if !ok {
119 return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src)
120 }
121 if elementsLength == 0 {
122 *dst = CIDRArray{Status: Present}
123 return nil
124 }
125 if len(dimensions) == 0 {
126 if originalSrc, ok := underlyingSliceType(src); ok {
127 return dst.Set(originalSrc)
128 }
129 return fmt.Errorf("cannot convert %v to CIDRArray", src)
130 }
131
132 *dst = CIDRArray{
133 Elements: make([]CIDR, elementsLength),
134 Dimensions: dimensions,
135 Status: Present,
136 }
137 elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
138 if err != nil {
139
140 if len(dst.Dimensions) > 1 {
141 dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
142 elementsLength = 0
143 for _, dim := range dst.Dimensions {
144 if elementsLength == 0 {
145 elementsLength = int(dim.Length)
146 } else {
147 elementsLength *= int(dim.Length)
148 }
149 }
150 dst.Elements = make([]CIDR, elementsLength)
151 elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
152 if err != nil {
153 return err
154 }
155 } else {
156 return err
157 }
158 }
159 if elementCount != len(dst.Elements) {
160 return fmt.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
161 }
162 }
163
164 return nil
165 }
166
167 func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
168 switch value.Kind() {
169 case reflect.Array:
170 fallthrough
171 case reflect.Slice:
172 if len(dst.Dimensions) == dimension {
173 break
174 }
175
176 valueLen := value.Len()
177 if int32(valueLen) != dst.Dimensions[dimension].Length {
178 return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
179 }
180 for i := 0; i < valueLen; i++ {
181 var err error
182 index, err = dst.setRecursive(value.Index(i), index, dimension+1)
183 if err != nil {
184 return 0, err
185 }
186 }
187
188 return index, nil
189 }
190 if !value.CanInterface() {
191 return 0, fmt.Errorf("cannot convert all values to CIDRArray")
192 }
193 if err := dst.Elements[index].Set(value.Interface()); err != nil {
194 return 0, fmt.Errorf("%v in CIDRArray", err)
195 }
196 index++
197
198 return index, nil
199 }
200
201 func (dst CIDRArray) Get() interface{} {
202 switch dst.Status {
203 case Present:
204 return dst
205 case Null:
206 return nil
207 default:
208 return dst.Status
209 }
210 }
211
212 func (src *CIDRArray) AssignTo(dst interface{}) error {
213 switch src.Status {
214 case Present:
215 if len(src.Dimensions) <= 1 {
216
217 switch v := dst.(type) {
218
219 case *[]*net.IPNet:
220 *v = make([]*net.IPNet, len(src.Elements))
221 for i := range src.Elements {
222 if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
223 return err
224 }
225 }
226 return nil
227
228 case *[]net.IP:
229 *v = make([]net.IP, len(src.Elements))
230 for i := range src.Elements {
231 if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
232 return err
233 }
234 }
235 return nil
236
237 case *[]*net.IP:
238 *v = make([]*net.IP, len(src.Elements))
239 for i := range src.Elements {
240 if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
241 return err
242 }
243 }
244 return nil
245
246 }
247 }
248
249
250 if nextDst, retry := GetAssignToDstType(dst); retry {
251 return src.AssignTo(nextDst)
252 }
253
254
255
256
257 value := reflect.ValueOf(dst)
258 if value.Kind() == reflect.Ptr {
259 value = value.Elem()
260 }
261
262 switch value.Kind() {
263 case reflect.Array, reflect.Slice:
264 default:
265 return fmt.Errorf("cannot assign %T to %T", src, dst)
266 }
267
268 if len(src.Elements) == 0 {
269 if value.Kind() == reflect.Slice {
270 value.Set(reflect.MakeSlice(value.Type(), 0, 0))
271 return nil
272 }
273 }
274
275 elementCount, err := src.assignToRecursive(value, 0, 0)
276 if err != nil {
277 return err
278 }
279 if elementCount != len(src.Elements) {
280 return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
281 }
282
283 return nil
284 case Null:
285 return NullAssignTo(dst)
286 }
287
288 return fmt.Errorf("cannot decode %#v into %T", src, dst)
289 }
290
291 func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
292 switch kind := value.Kind(); kind {
293 case reflect.Array:
294 fallthrough
295 case reflect.Slice:
296 if len(src.Dimensions) == dimension {
297 break
298 }
299
300 length := int(src.Dimensions[dimension].Length)
301 if reflect.Array == kind {
302 typ := value.Type()
303 if typ.Len() != length {
304 return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
305 }
306 value.Set(reflect.New(typ).Elem())
307 } else {
308 value.Set(reflect.MakeSlice(value.Type(), length, length))
309 }
310
311 var err error
312 for i := 0; i < length; i++ {
313 index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
314 if err != nil {
315 return 0, err
316 }
317 }
318
319 return index, nil
320 }
321 if len(src.Dimensions) != dimension {
322 return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
323 }
324 if !value.CanAddr() {
325 return 0, fmt.Errorf("cannot assign all values from CIDRArray")
326 }
327 addr := value.Addr()
328 if !addr.CanInterface() {
329 return 0, fmt.Errorf("cannot assign all values from CIDRArray")
330 }
331 if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
332 return 0, err
333 }
334 index++
335 return index, nil
336 }
337
338 func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error {
339 if src == nil {
340 *dst = CIDRArray{Status: Null}
341 return nil
342 }
343
344 uta, err := ParseUntypedTextArray(string(src))
345 if err != nil {
346 return err
347 }
348
349 var elements []CIDR
350
351 if len(uta.Elements) > 0 {
352 elements = make([]CIDR, len(uta.Elements))
353
354 for i, s := range uta.Elements {
355 var elem CIDR
356 var elemSrc []byte
357 if s != "NULL" || uta.Quoted[i] {
358 elemSrc = []byte(s)
359 }
360 err = elem.DecodeText(ci, elemSrc)
361 if err != nil {
362 return err
363 }
364
365 elements[i] = elem
366 }
367 }
368
369 *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
370
371 return nil
372 }
373
374 func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error {
375 if src == nil {
376 *dst = CIDRArray{Status: Null}
377 return nil
378 }
379
380 var arrayHeader ArrayHeader
381 rp, err := arrayHeader.DecodeBinary(ci, src)
382 if err != nil {
383 return err
384 }
385
386 if len(arrayHeader.Dimensions) == 0 {
387 *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present}
388 return nil
389 }
390
391 elementCount := arrayHeader.Dimensions[0].Length
392 for _, d := range arrayHeader.Dimensions[1:] {
393 elementCount *= d.Length
394 }
395
396 elements := make([]CIDR, elementCount)
397
398 for i := range elements {
399 elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
400 rp += 4
401 var elemSrc []byte
402 if elemLen >= 0 {
403 elemSrc = src[rp : rp+elemLen]
404 rp += elemLen
405 }
406 err = elements[i].DecodeBinary(ci, elemSrc)
407 if err != nil {
408 return err
409 }
410 }
411
412 *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
413 return nil
414 }
415
416 func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
417 switch src.Status {
418 case Null:
419 return nil, nil
420 case Undefined:
421 return nil, errUndefined
422 }
423
424 if len(src.Dimensions) == 0 {
425 return append(buf, '{', '}'), nil
426 }
427
428 buf = EncodeTextArrayDimensions(buf, src.Dimensions)
429
430
431
432
433
434
435 dimElemCounts := make([]int, len(src.Dimensions))
436 dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
437 for i := len(src.Dimensions) - 2; i > -1; i-- {
438 dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
439 }
440
441 inElemBuf := make([]byte, 0, 32)
442 for i, elem := range src.Elements {
443 if i > 0 {
444 buf = append(buf, ',')
445 }
446
447 for _, dec := range dimElemCounts {
448 if i%dec == 0 {
449 buf = append(buf, '{')
450 }
451 }
452
453 elemBuf, err := elem.EncodeText(ci, inElemBuf)
454 if err != nil {
455 return nil, err
456 }
457 if elemBuf == nil {
458 buf = append(buf, `NULL`...)
459 } else {
460 buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
461 }
462
463 for _, dec := range dimElemCounts {
464 if (i+1)%dec == 0 {
465 buf = append(buf, '}')
466 }
467 }
468 }
469
470 return buf, nil
471 }
472
473 func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
474 switch src.Status {
475 case Null:
476 return nil, nil
477 case Undefined:
478 return nil, errUndefined
479 }
480
481 arrayHeader := ArrayHeader{
482 Dimensions: src.Dimensions,
483 }
484
485 if dt, ok := ci.DataTypeForName("cidr"); ok {
486 arrayHeader.ElementOID = int32(dt.OID)
487 } else {
488 return nil, fmt.Errorf("unable to find oid for type name %v", "cidr")
489 }
490
491 for i := range src.Elements {
492 if src.Elements[i].Status == Null {
493 arrayHeader.ContainsNull = true
494 break
495 }
496 }
497
498 buf = arrayHeader.EncodeBinary(ci, buf)
499
500 for i := range src.Elements {
501 sp := len(buf)
502 buf = pgio.AppendInt32(buf, -1)
503
504 elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
505 if err != nil {
506 return nil, err
507 }
508 if elemBuf != nil {
509 buf = elemBuf
510 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
511 }
512 }
513
514 return buf, nil
515 }
516
517
518 func (dst *CIDRArray) Scan(src interface{}) error {
519 if src == nil {
520 return dst.DecodeText(nil, nil)
521 }
522
523 switch src := src.(type) {
524 case string:
525 return dst.DecodeText(nil, []byte(src))
526 case []byte:
527 srcCopy := make([]byte, len(src))
528 copy(srcCopy, src)
529 return dst.DecodeText(nil, srcCopy)
530 }
531
532 return fmt.Errorf("cannot scan %T", src)
533 }
534
535
536 func (src CIDRArray) Value() (driver.Value, error) {
537 buf, err := src.EncodeText(nil, nil)
538 if err != nil {
539 return nil, err
540 }
541 if buf == nil {
542 return nil, nil
543 }
544
545 return string(buf), nil
546 }
547
View as plain text