1 package pgtype
2
3 import (
4 "database/sql"
5 "fmt"
6 "math"
7 "reflect"
8 "time"
9 )
10
// Bounds of the platform-sized uint and int types (32 or 64 bits wide).
const (
	// maxUint has every bit set.
	maxUint = ^uint(0)

	// maxInt is the all-ones uint pattern with the sign bit cleared.
	maxInt = int(^uint(0) >> 1)

	// minInt is the two's-complement counterpart of maxInt.
	minInt = -1 - maxInt
)
16
17
// underlyingNumberType unwraps val into a value of a predeclared numeric type
// (or string).
//
// A non-nil pointer is dereferenced and its pointee is returned with
// ok == true. A value whose kind is an int, uint, float or string kind is
// converted to the predeclared type of that kind; ok reports whether this
// changed the type (i.e. val had a named type such as `type MyInt int`).
// Anything else, including a nil pointer or an untyped nil, yields
// (nil, false).
func underlyingNumberType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	var base reflect.Type
	switch rv.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	case reflect.Int:
		base = reflect.TypeOf(int(0))
	case reflect.Int8:
		base = reflect.TypeOf(int8(0))
	case reflect.Int16:
		base = reflect.TypeOf(int16(0))
	case reflect.Int32:
		base = reflect.TypeOf(int32(0))
	case reflect.Int64:
		base = reflect.TypeOf(int64(0))
	case reflect.Uint:
		base = reflect.TypeOf(uint(0))
	case reflect.Uint8:
		base = reflect.TypeOf(uint8(0))
	case reflect.Uint16:
		base = reflect.TypeOf(uint16(0))
	case reflect.Uint32:
		base = reflect.TypeOf(uint32(0))
	case reflect.Uint64:
		base = reflect.TypeOf(uint64(0))
	case reflect.Float32:
		base = reflect.TypeOf(float32(0))
	case reflect.Float64:
		base = reflect.TypeOf(float64(0))
	case reflect.String:
		base = reflect.TypeOf("")
	default:
		return nil, false
	}

	// Source and target share a kind, so the conversion keeps the exact value.
	return rv.Convert(base).Interface(), rv.Type() != base
}
71
72
// underlyingBoolType unwraps val into a plain bool.
//
// A value of bool kind is returned as a bool; ok reports whether its type was
// a named bool type rather than bool itself. A non-nil pointer is dereferenced
// and its pointee is returned with ok == true. Anything else yields
// (nil, false).
func underlyingBoolType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	switch {
	case rv.Kind() == reflect.Bool:
		b := rv.Bool()
		return b, rv.Type() != reflect.TypeOf(b)
	case rv.Kind() == reflect.Ptr && !rv.IsNil():
		return rv.Elem().Interface(), true
	default:
		return nil, false
	}
}
90
91
// underlyingBytesType unwraps val into a plain []byte.
//
// A non-nil pointer is dereferenced and its pointee is returned with
// ok == true. A slice whose element kind is uint8 is returned as []byte; ok
// reports whether its type differed from []byte. Anything else yields
// (nil, false).
func underlyingBytesType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}

	if kind != reflect.Slice || rv.Type().Elem().Kind() != reflect.Uint8 {
		return nil, false
	}

	raw := rv.Bytes()
	return raw, rv.Type() != reflect.TypeOf(raw)
}
111
112
// underlyingStringType unwraps val into a plain string.
//
// A value of string kind is returned as a string; ok reports whether its type
// was a named string type rather than string itself. A non-nil pointer is
// dereferenced and its pointee is returned with ok == true. Anything else
// yields (nil, false).
func underlyingStringType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	kind := rv.Kind()

	if kind == reflect.String {
		s := rv.String()
		return s, rv.Type() != reflect.TypeOf(s)
	}
	if kind == reflect.Ptr && !rv.IsNil() {
		return rv.Elem().Interface(), true
	}
	return nil, false
}
130
131
// underlyingPtrType dereferences val when it is a non-nil pointer, returning
// the pointee with ok == true. Every other input, including a nil pointer or
// an untyped nil, yields (nil, false).
func underlyingPtrType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return nil, false
	}
	return rv.Elem().Interface(), true
}
146
147
// underlyingTimeType unwraps val into a time.Time.
//
// A non-nil pointer is dereferenced and its pointee is returned with
// ok == true. A value whose type is convertible to time.Time (for example
// `type MyTime time.Time`) is converted. As with the other underlying*Type
// helpers, ok reports whether the conversion changed the type, so a plain
// time.Time is returned with ok == false. Anything else, including an untyped
// nil, yields (nil, false).
func underlyingTimeType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// Untyped nil: refVal.Type() below would panic on the zero Value.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		return refVal.Elem().Interface(), true
	}

	timeType := reflect.TypeOf(time.Time{})
	if !refVal.Type().ConvertibleTo(timeType) {
		return nil, false
	}
	return refVal.Convert(timeType).Interface(), refVal.Type() != timeType
}
167
168
// underlyingUUIDType unwraps val into a [16]byte.
//
// A non-nil pointer is dereferenced and its pointee is returned with
// ok == true. A value whose type is convertible to [16]byte (a named
// [16]byte array, or a byte slice of exactly 16 bytes) is converted. ok
// reports whether the conversion changed the type, so a plain [16]byte is
// returned with ok == false. Anything else, including an untyped nil or a
// byte slice of any other length, yields (nil, false).
func underlyingUUIDType(val interface{}) (interface{}, bool) {
	refVal := reflect.ValueOf(val)

	switch refVal.Kind() {
	case reflect.Invalid:
		// Untyped nil: refVal.Type() below would panic on the zero Value.
		return nil, false
	case reflect.Ptr:
		if refVal.IsNil() {
			return nil, false
		}
		return refVal.Elem().Interface(), true
	case reflect.Slice:
		// Since Go 1.20 a byte slice is ConvertibleTo [16]byte, but Convert
		// panics for shorter slices and silently truncates longer ones.
		if refVal.Len() != 16 {
			return nil, false
		}
	}

	uuidType := reflect.TypeOf([16]byte{})
	if !refVal.Type().ConvertibleTo(uuidType) {
		return nil, false
	}
	return refVal.Convert(uuidType).Interface(), refVal.Type() != uuidType
}
188
189
// underlyingSliceType unwraps val into an unnamed slice type.
//
// A non-nil pointer is dereferenced and its pointee is returned with
// ok == true. A slice is converted to []E, where E is its element type; ok
// reports whether that changed the type (i.e. val had a named slice type).
// Anything else yields (nil, false).
func underlyingSliceType(val interface{}) (interface{}, bool) {
	rv := reflect.ValueOf(val)

	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return nil, false
		}
		return rv.Elem().Interface(), true
	}
	if rv.Kind() != reflect.Slice {
		return nil, false
	}

	plain := reflect.SliceOf(rv.Type().Elem())
	if !rv.Type().ConvertibleTo(plain) {
		return nil, false
	}
	out := rv.Convert(plain).Interface()
	return out, rv.Type() != reflect.TypeOf(out)
}
210
211 func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
212 if srcStatus == Present {
213 switch v := dst.(type) {
214 case *int:
215 if srcVal < int64(minInt) {
216 return fmt.Errorf("%d is less than minimum value for int", srcVal)
217 } else if srcVal > int64(maxInt) {
218 return fmt.Errorf("%d is greater than maximum value for int", srcVal)
219 }
220 *v = int(srcVal)
221 case *int8:
222 if srcVal < math.MinInt8 {
223 return fmt.Errorf("%d is less than minimum value for int8", srcVal)
224 } else if srcVal > math.MaxInt8 {
225 return fmt.Errorf("%d is greater than maximum value for int8", srcVal)
226 }
227 *v = int8(srcVal)
228 case *int16:
229 if srcVal < math.MinInt16 {
230 return fmt.Errorf("%d is less than minimum value for int16", srcVal)
231 } else if srcVal > math.MaxInt16 {
232 return fmt.Errorf("%d is greater than maximum value for int16", srcVal)
233 }
234 *v = int16(srcVal)
235 case *int32:
236 if srcVal < math.MinInt32 {
237 return fmt.Errorf("%d is less than minimum value for int32", srcVal)
238 } else if srcVal > math.MaxInt32 {
239 return fmt.Errorf("%d is greater than maximum value for int32", srcVal)
240 }
241 *v = int32(srcVal)
242 case *int64:
243 if srcVal < math.MinInt64 {
244 return fmt.Errorf("%d is less than minimum value for int64", srcVal)
245 } else if srcVal > math.MaxInt64 {
246 return fmt.Errorf("%d is greater than maximum value for int64", srcVal)
247 }
248 *v = int64(srcVal)
249 case *uint:
250 if srcVal < 0 {
251 return fmt.Errorf("%d is less than zero for uint", srcVal)
252 } else if uint64(srcVal) > uint64(maxUint) {
253 return fmt.Errorf("%d is greater than maximum value for uint", srcVal)
254 }
255 *v = uint(srcVal)
256 case *uint8:
257 if srcVal < 0 {
258 return fmt.Errorf("%d is less than zero for uint8", srcVal)
259 } else if srcVal > math.MaxUint8 {
260 return fmt.Errorf("%d is greater than maximum value for uint8", srcVal)
261 }
262 *v = uint8(srcVal)
263 case *uint16:
264 if srcVal < 0 {
265 return fmt.Errorf("%d is less than zero for uint32", srcVal)
266 } else if srcVal > math.MaxUint16 {
267 return fmt.Errorf("%d is greater than maximum value for uint16", srcVal)
268 }
269 *v = uint16(srcVal)
270 case *uint32:
271 if srcVal < 0 {
272 return fmt.Errorf("%d is less than zero for uint32", srcVal)
273 } else if srcVal > math.MaxUint32 {
274 return fmt.Errorf("%d is greater than maximum value for uint32", srcVal)
275 }
276 *v = uint32(srcVal)
277 case *uint64:
278 if srcVal < 0 {
279 return fmt.Errorf("%d is less than zero for uint64", srcVal)
280 }
281 *v = uint64(srcVal)
282 case sql.Scanner:
283 return v.Scan(srcVal)
284 default:
285 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
286 el := v.Elem()
287 switch el.Kind() {
288
289 case reflect.Ptr:
290 if el.IsNil() {
291
292 el.Set(reflect.New(el.Type().Elem()))
293 }
294 return int64AssignTo(srcVal, srcStatus, el.Interface())
295 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
296 if el.OverflowInt(int64(srcVal)) {
297 return fmt.Errorf("cannot put %d into %T", srcVal, dst)
298 }
299 el.SetInt(int64(srcVal))
300 return nil
301 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
302 if srcVal < 0 {
303 return fmt.Errorf("%d is less than zero for %T", srcVal, dst)
304 }
305 if el.OverflowUint(uint64(srcVal)) {
306 return fmt.Errorf("cannot put %d into %T", srcVal, dst)
307 }
308 el.SetUint(uint64(srcVal))
309 return nil
310 }
311 }
312 return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
313 }
314 return nil
315 }
316
317
318 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
319 el := v.Elem()
320 if el.Kind() == reflect.Ptr {
321 el.Set(reflect.Zero(el.Type()))
322 return nil
323 }
324 }
325
326 return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
327 }
328
329 func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
330 if srcStatus == Present {
331 switch v := dst.(type) {
332 case *float32:
333 *v = float32(srcVal)
334 case *float64:
335 *v = srcVal
336 default:
337 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
338 el := v.Elem()
339 switch el.Kind() {
340
341 case reflect.Float32, reflect.Float64:
342 el.SetFloat(srcVal)
343 return nil
344
345 case reflect.Ptr:
346 if el.IsNil() {
347
348 el.Set(reflect.New(el.Type().Elem()))
349 }
350 return float64AssignTo(srcVal, srcStatus, el.Interface())
351 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
352 i64 := int64(srcVal)
353 if float64(i64) == srcVal {
354 return int64AssignTo(i64, srcStatus, dst)
355 }
356 }
357 }
358 return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
359 }
360 return nil
361 }
362
363
364 if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
365 el := v.Elem()
366 if el.Kind() == reflect.Ptr {
367 el.Set(reflect.Zero(el.Type()))
368 return nil
369 }
370 }
371
372 return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
373 }
374
375 func NullAssignTo(dst interface{}) error {
376 dstPtr := reflect.ValueOf(dst)
377
378
379 if dstPtr.Kind() != reflect.Ptr {
380 return &nullAssignmentError{dst: dst}
381 }
382
383 dstVal := dstPtr.Elem()
384
385 switch dstVal.Kind() {
386 case reflect.Ptr, reflect.Slice, reflect.Map:
387 dstVal.Set(reflect.Zero(dstVal.Type()))
388 return nil
389 }
390
391 return &nullAssignmentError{dst: dst}
392 }
393
// kindTypes maps a reflect.Kind to the predeclared Go type of that kind. It is
// used to retype pointers to named types (e.g. *MyInt) as pointers to their
// predeclared counterparts (e.g. *int), and is populated in init.
var (
	kindTypes map[reflect.Kind]reflect.Type
)
395
// toInterface converts dst to type t and returns the converted value as an
// interface{}. It also reports whether t differs from dst's original type.
// Like reflect.Value.Convert, it panics if dst is not convertible to t.
func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) {
	converted := dst.Convert(t)
	changed := converted.Type() != dst.Type()
	return converted.Interface(), changed
}
400
401
402
403
404
405
406
407
408 func GetAssignToDstType(dst interface{}) (interface{}, bool) {
409 dstPtr := reflect.ValueOf(dst)
410
411
412 if dstPtr.Kind() != reflect.Ptr {
413 return nil, false
414 }
415
416 dstVal := dstPtr.Elem()
417
418
419 if dstVal.Kind() == reflect.Ptr {
420 dstVal.Set(reflect.New(dstVal.Type().Elem()))
421 return dstVal.Interface(), true
422 }
423
424
425 if baseValType, ok := kindTypes[dstVal.Kind()]; ok {
426 return toInterface(dstPtr, reflect.PtrTo(baseValType))
427 }
428
429 if dstVal.Kind() == reflect.Slice {
430 if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
431 return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType)))
432 }
433 }
434
435 if dstVal.Kind() == reflect.Array {
436 if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
437 return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)))
438 }
439 }
440
441 if dstVal.Kind() == reflect.Struct {
442 if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous {
443 dstPtr = dstVal.Field(0).Addr()
444 nested := dstVal.Type().Field(0).Type
445 if nested.Kind() == reflect.Array {
446 if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok {
447 return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType)))
448 }
449 }
450 if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() {
451 return dstPtr.Interface(), true
452 }
453 }
454 }
455
456 return nil, false
457 }
458
459 func init() {
460 kindTypes = map[reflect.Kind]reflect.Type{
461 reflect.Bool: reflect.TypeOf(false),
462 reflect.Float32: reflect.TypeOf(float32(0)),
463 reflect.Float64: reflect.TypeOf(float64(0)),
464 reflect.Int: reflect.TypeOf(int(0)),
465 reflect.Int8: reflect.TypeOf(int8(0)),
466 reflect.Int16: reflect.TypeOf(int16(0)),
467 reflect.Int32: reflect.TypeOf(int32(0)),
468 reflect.Int64: reflect.TypeOf(int64(0)),
469 reflect.Uint: reflect.TypeOf(uint(0)),
470 reflect.Uint8: reflect.TypeOf(uint8(0)),
471 reflect.Uint16: reflect.TypeOf(uint16(0)),
472 reflect.Uint32: reflect.TypeOf(uint32(0)),
473 reflect.Uint64: reflect.TypeOf(uint64(0)),
474 reflect.String: reflect.TypeOf(""),
475 }
476 }
477
View as plain text