1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "fmt"
7 "reflect"
8
9 "github.com/jackc/pgio"
10 )
11
12
13
14
// ArrayType is an array value whose element type is not fixed at compile
// time: elements are produced on demand by the newElement factory supplied to
// NewArrayType.
type ArrayType struct {
	elements   []ValueTranscoder // element values, held in one flat slice regardless of dimensionality
	dimensions []ArrayDimension  // length and lower bound of each dimension; empty for an empty array

	typeName   string                 // name reported by TypeName
	newElement func() ValueTranscoder // creates a blank element for Set and the decoders to fill

	elementOID uint32 // element OID written into the binary array header by EncodeBinary
	status     Status // Present, Null, or Undefined
}
25
26 func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType {
27 return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement}
28 }
29
30 func (at *ArrayType) NewTypeValue() Value {
31 return &ArrayType{
32 elements: at.elements,
33 dimensions: at.dimensions,
34 status: at.status,
35
36 typeName: at.typeName,
37 elementOID: at.elementOID,
38 newElement: at.newElement,
39 }
40 }
41
42 func (at *ArrayType) TypeName() string {
43 return at.typeName
44 }
45
46 func (dst *ArrayType) setNil() {
47 dst.elements = nil
48 dst.dimensions = nil
49 dst.status = Null
50 }
51
52 func (dst *ArrayType) Set(src interface{}) error {
53
54 if src == nil {
55 dst.setNil()
56 return nil
57 }
58
59 sliceVal := reflect.ValueOf(src)
60 if sliceVal.Kind() != reflect.Slice {
61 return fmt.Errorf("cannot set non-slice")
62 }
63
64 if sliceVal.IsNil() {
65 dst.setNil()
66 return nil
67 }
68
69 dst.elements = make([]ValueTranscoder, sliceVal.Len())
70 for i := range dst.elements {
71 v := dst.newElement()
72 err := v.Set(sliceVal.Index(i).Interface())
73 if err != nil {
74 return err
75 }
76
77 dst.elements[i] = v
78 }
79 dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}}
80 dst.status = Present
81
82 return nil
83 }
84
85 func (dst ArrayType) Get() interface{} {
86 switch dst.status {
87 case Present:
88 elementValues := make([]interface{}, len(dst.elements))
89 for i := range dst.elements {
90 elementValues[i] = dst.elements[i].Get()
91 }
92 return elementValues
93 case Null:
94 return nil
95 default:
96 return dst.status
97 }
98 }
99
100 func (src *ArrayType) AssignTo(dst interface{}) error {
101 ptrSlice := reflect.ValueOf(dst)
102 if ptrSlice.Kind() != reflect.Ptr {
103 return fmt.Errorf("cannot assign to non-pointer")
104 }
105
106 sliceVal := ptrSlice.Elem()
107 sliceType := sliceVal.Type()
108
109 if sliceType.Kind() != reflect.Slice {
110 return fmt.Errorf("cannot assign to pointer to non-slice")
111 }
112
113 switch src.status {
114 case Present:
115 slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements))
116 elemType := sliceType.Elem()
117
118 for i := range src.elements {
119 ptrElem := reflect.New(elemType)
120 err := src.elements[i].AssignTo(ptrElem.Interface())
121 if err != nil {
122 return err
123 }
124
125 slice.Index(i).Set(ptrElem.Elem())
126 }
127
128 sliceVal.Set(slice)
129 return nil
130 case Null:
131 sliceVal.Set(reflect.Zero(sliceType))
132 return nil
133 }
134
135 return fmt.Errorf("cannot decode %#v into %T", src, dst)
136 }
137
138 func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
139 if src == nil {
140 dst.setNil()
141 return nil
142 }
143
144 uta, err := ParseUntypedTextArray(string(src))
145 if err != nil {
146 return err
147 }
148
149 var elements []ValueTranscoder
150
151 if len(uta.Elements) > 0 {
152 elements = make([]ValueTranscoder, len(uta.Elements))
153
154 for i, s := range uta.Elements {
155 elem := dst.newElement()
156 var elemSrc []byte
157 if s != "NULL" {
158 elemSrc = []byte(s)
159 }
160 err = elem.DecodeText(ci, elemSrc)
161 if err != nil {
162 return err
163 }
164
165 elements[i] = elem
166 }
167 }
168
169 dst.elements = elements
170 dst.dimensions = uta.Dimensions
171 dst.status = Present
172
173 return nil
174 }
175
176 func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
177 if src == nil {
178 dst.setNil()
179 return nil
180 }
181
182 var arrayHeader ArrayHeader
183 rp, err := arrayHeader.DecodeBinary(ci, src)
184 if err != nil {
185 return err
186 }
187
188 var elements []ValueTranscoder
189
190 if len(arrayHeader.Dimensions) == 0 {
191 dst.elements = elements
192 dst.dimensions = arrayHeader.Dimensions
193 dst.status = Present
194 return nil
195 }
196
197 elementCount := arrayHeader.Dimensions[0].Length
198 for _, d := range arrayHeader.Dimensions[1:] {
199 elementCount *= d.Length
200 }
201
202 elements = make([]ValueTranscoder, elementCount)
203
204 for i := range elements {
205 elem := dst.newElement()
206 elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
207 rp += 4
208 var elemSrc []byte
209 if elemLen >= 0 {
210 elemSrc = src[rp : rp+elemLen]
211 rp += elemLen
212 }
213 err = elem.DecodeBinary(ci, elemSrc)
214 if err != nil {
215 return err
216 }
217
218 elements[i] = elem
219 }
220
221 dst.elements = elements
222 dst.dimensions = arrayHeader.Dimensions
223 dst.status = Present
224
225 return nil
226 }
227
228 func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
229 switch src.status {
230 case Null:
231 return nil, nil
232 case Undefined:
233 return nil, errUndefined
234 }
235
236 if len(src.dimensions) == 0 {
237 return append(buf, '{', '}'), nil
238 }
239
240 buf = EncodeTextArrayDimensions(buf, src.dimensions)
241
242
243
244
245
246
247 dimElemCounts := make([]int, len(src.dimensions))
248 dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length)
249 for i := len(src.dimensions) - 2; i > -1; i-- {
250 dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1]
251 }
252
253 inElemBuf := make([]byte, 0, 32)
254 for i, elem := range src.elements {
255 if i > 0 {
256 buf = append(buf, ',')
257 }
258
259 for _, dec := range dimElemCounts {
260 if i%dec == 0 {
261 buf = append(buf, '{')
262 }
263 }
264
265 elemBuf, err := elem.EncodeText(ci, inElemBuf)
266 if err != nil {
267 return nil, err
268 }
269 if elemBuf == nil {
270 buf = append(buf, `NULL`...)
271 } else {
272 buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
273 }
274
275 for _, dec := range dimElemCounts {
276 if (i+1)%dec == 0 {
277 buf = append(buf, '}')
278 }
279 }
280 }
281
282 return buf, nil
283 }
284
285 func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
286 switch src.status {
287 case Null:
288 return nil, nil
289 case Undefined:
290 return nil, errUndefined
291 }
292
293 arrayHeader := ArrayHeader{
294 Dimensions: src.dimensions,
295 ElementOID: int32(src.elementOID),
296 }
297
298 for i := range src.elements {
299 if src.elements[i].Get() == nil {
300 arrayHeader.ContainsNull = true
301 break
302 }
303 }
304
305 buf = arrayHeader.EncodeBinary(ci, buf)
306
307 for i := range src.elements {
308 sp := len(buf)
309 buf = pgio.AppendInt32(buf, -1)
310
311 elemBuf, err := src.elements[i].EncodeBinary(ci, buf)
312 if err != nil {
313 return nil, err
314 }
315 if elemBuf != nil {
316 buf = elemBuf
317 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
318 }
319 }
320
321 return buf, nil
322 }
323
324
325 func (dst *ArrayType) Scan(src interface{}) error {
326 if src == nil {
327 return dst.DecodeText(nil, nil)
328 }
329
330 switch src := src.(type) {
331 case string:
332 return dst.DecodeText(nil, []byte(src))
333 case []byte:
334 srcCopy := make([]byte, len(src))
335 copy(srcCopy, src)
336 return dst.DecodeText(nil, srcCopy)
337 }
338
339 return fmt.Errorf("cannot scan %T", src)
340 }
341
342
343 func (src ArrayType) Value() (driver.Value, error) {
344 buf, err := src.EncodeText(nil, nil)
345 if err != nil {
346 return nil, err
347 }
348 if buf == nil {
349 return nil, nil
350 }
351
352 return string(buf), nil
353 }
354
View as plain text