...

Source file src/github.com/jackc/pgtype/convert.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"math"
     7  	"reflect"
     8  	"time"
     9  )
    10  
    11  const (
    12  	maxUint = ^uint(0)
    13  	maxInt  = int(maxUint >> 1)
    14  	minInt  = -maxInt - 1
    15  )
    16  
    17  // underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8
    18  func underlyingNumberType(val interface{}) (interface{}, bool) {
    19  	refVal := reflect.ValueOf(val)
    20  
    21  	switch refVal.Kind() {
    22  	case reflect.Ptr:
    23  		if refVal.IsNil() {
    24  			return nil, false
    25  		}
    26  		convVal := refVal.Elem().Interface()
    27  		return convVal, true
    28  	case reflect.Int:
    29  		convVal := int(refVal.Int())
    30  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    31  	case reflect.Int8:
    32  		convVal := int8(refVal.Int())
    33  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    34  	case reflect.Int16:
    35  		convVal := int16(refVal.Int())
    36  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    37  	case reflect.Int32:
    38  		convVal := int32(refVal.Int())
    39  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    40  	case reflect.Int64:
    41  		convVal := int64(refVal.Int())
    42  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    43  	case reflect.Uint:
    44  		convVal := uint(refVal.Uint())
    45  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    46  	case reflect.Uint8:
    47  		convVal := uint8(refVal.Uint())
    48  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    49  	case reflect.Uint16:
    50  		convVal := uint16(refVal.Uint())
    51  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    52  	case reflect.Uint32:
    53  		convVal := uint32(refVal.Uint())
    54  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    55  	case reflect.Uint64:
    56  		convVal := uint64(refVal.Uint())
    57  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    58  	case reflect.Float32:
    59  		convVal := float32(refVal.Float())
    60  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    61  	case reflect.Float64:
    62  		convVal := refVal.Float()
    63  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    64  	case reflect.String:
    65  		convVal := refVal.String()
    66  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    67  	}
    68  
    69  	return nil, false
    70  }
    71  
    72  // underlyingBoolType gets the underlying type that can be converted to Bool
    73  func underlyingBoolType(val interface{}) (interface{}, bool) {
    74  	refVal := reflect.ValueOf(val)
    75  
    76  	switch refVal.Kind() {
    77  	case reflect.Ptr:
    78  		if refVal.IsNil() {
    79  			return nil, false
    80  		}
    81  		convVal := refVal.Elem().Interface()
    82  		return convVal, true
    83  	case reflect.Bool:
    84  		convVal := refVal.Bool()
    85  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
    86  	}
    87  
    88  	return nil, false
    89  }
    90  
    91  // underlyingBytesType gets the underlying type that can be converted to []byte
    92  func underlyingBytesType(val interface{}) (interface{}, bool) {
    93  	refVal := reflect.ValueOf(val)
    94  
    95  	switch refVal.Kind() {
    96  	case reflect.Ptr:
    97  		if refVal.IsNil() {
    98  			return nil, false
    99  		}
   100  		convVal := refVal.Elem().Interface()
   101  		return convVal, true
   102  	case reflect.Slice:
   103  		if refVal.Type().Elem().Kind() == reflect.Uint8 {
   104  			convVal := refVal.Bytes()
   105  			return convVal, reflect.TypeOf(convVal) != refVal.Type()
   106  		}
   107  	}
   108  
   109  	return nil, false
   110  }
   111  
   112  // underlyingStringType gets the underlying type that can be converted to String
   113  func underlyingStringType(val interface{}) (interface{}, bool) {
   114  	refVal := reflect.ValueOf(val)
   115  
   116  	switch refVal.Kind() {
   117  	case reflect.Ptr:
   118  		if refVal.IsNil() {
   119  			return nil, false
   120  		}
   121  		convVal := refVal.Elem().Interface()
   122  		return convVal, true
   123  	case reflect.String:
   124  		convVal := refVal.String()
   125  		return convVal, reflect.TypeOf(convVal) != refVal.Type()
   126  	}
   127  
   128  	return nil, false
   129  }
   130  
   131  // underlyingPtrType dereferences a pointer
   132  func underlyingPtrType(val interface{}) (interface{}, bool) {
   133  	refVal := reflect.ValueOf(val)
   134  
   135  	switch refVal.Kind() {
   136  	case reflect.Ptr:
   137  		if refVal.IsNil() {
   138  			return nil, false
   139  		}
   140  		convVal := refVal.Elem().Interface()
   141  		return convVal, true
   142  	}
   143  
   144  	return nil, false
   145  }
   146  
   147  // underlyingTimeType gets the underlying type that can be converted to time.Time
   148  func underlyingTimeType(val interface{}) (interface{}, bool) {
   149  	refVal := reflect.ValueOf(val)
   150  
   151  	switch refVal.Kind() {
   152  	case reflect.Ptr:
   153  		if refVal.IsNil() {
   154  			return nil, false
   155  		}
   156  		convVal := refVal.Elem().Interface()
   157  		return convVal, true
   158  	}
   159  
   160  	timeType := reflect.TypeOf(time.Time{})
   161  	if refVal.Type().ConvertibleTo(timeType) {
   162  		return refVal.Convert(timeType).Interface(), true
   163  	}
   164  
   165  	return nil, false
   166  }
   167  
   168  // underlyingUUIDType gets the underlying type that can be converted to [16]byte
   169  func underlyingUUIDType(val interface{}) (interface{}, bool) {
   170  	refVal := reflect.ValueOf(val)
   171  
   172  	switch refVal.Kind() {
   173  	case reflect.Ptr:
   174  		if refVal.IsNil() {
   175  			return nil, false
   176  		}
   177  		convVal := refVal.Elem().Interface()
   178  		return convVal, true
   179  	}
   180  
   181  	uuidType := reflect.TypeOf([16]byte{})
   182  	if refVal.Type().ConvertibleTo(uuidType) {
   183  		return refVal.Convert(uuidType).Interface(), true
   184  	}
   185  
   186  	return nil, false
   187  }
   188  
   189  // underlyingSliceType gets the underlying slice type
   190  func underlyingSliceType(val interface{}) (interface{}, bool) {
   191  	refVal := reflect.ValueOf(val)
   192  
   193  	switch refVal.Kind() {
   194  	case reflect.Ptr:
   195  		if refVal.IsNil() {
   196  			return nil, false
   197  		}
   198  		convVal := refVal.Elem().Interface()
   199  		return convVal, true
   200  	case reflect.Slice:
   201  		baseSliceType := reflect.SliceOf(refVal.Type().Elem())
   202  		if refVal.Type().ConvertibleTo(baseSliceType) {
   203  			convVal := refVal.Convert(baseSliceType)
   204  			return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type()
   205  		}
   206  	}
   207  
   208  	return nil, false
   209  }
   210  
   211  func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error {
   212  	if srcStatus == Present {
   213  		switch v := dst.(type) {
   214  		case *int:
   215  			if srcVal < int64(minInt) {
   216  				return fmt.Errorf("%d is less than minimum value for int", srcVal)
   217  			} else if srcVal > int64(maxInt) {
   218  				return fmt.Errorf("%d is greater than maximum value for int", srcVal)
   219  			}
   220  			*v = int(srcVal)
   221  		case *int8:
   222  			if srcVal < math.MinInt8 {
   223  				return fmt.Errorf("%d is less than minimum value for int8", srcVal)
   224  			} else if srcVal > math.MaxInt8 {
   225  				return fmt.Errorf("%d is greater than maximum value for int8", srcVal)
   226  			}
   227  			*v = int8(srcVal)
   228  		case *int16:
   229  			if srcVal < math.MinInt16 {
   230  				return fmt.Errorf("%d is less than minimum value for int16", srcVal)
   231  			} else if srcVal > math.MaxInt16 {
   232  				return fmt.Errorf("%d is greater than maximum value for int16", srcVal)
   233  			}
   234  			*v = int16(srcVal)
   235  		case *int32:
   236  			if srcVal < math.MinInt32 {
   237  				return fmt.Errorf("%d is less than minimum value for int32", srcVal)
   238  			} else if srcVal > math.MaxInt32 {
   239  				return fmt.Errorf("%d is greater than maximum value for int32", srcVal)
   240  			}
   241  			*v = int32(srcVal)
   242  		case *int64:
   243  			if srcVal < math.MinInt64 {
   244  				return fmt.Errorf("%d is less than minimum value for int64", srcVal)
   245  			} else if srcVal > math.MaxInt64 {
   246  				return fmt.Errorf("%d is greater than maximum value for int64", srcVal)
   247  			}
   248  			*v = int64(srcVal)
   249  		case *uint:
   250  			if srcVal < 0 {
   251  				return fmt.Errorf("%d is less than zero for uint", srcVal)
   252  			} else if uint64(srcVal) > uint64(maxUint) {
   253  				return fmt.Errorf("%d is greater than maximum value for uint", srcVal)
   254  			}
   255  			*v = uint(srcVal)
   256  		case *uint8:
   257  			if srcVal < 0 {
   258  				return fmt.Errorf("%d is less than zero for uint8", srcVal)
   259  			} else if srcVal > math.MaxUint8 {
   260  				return fmt.Errorf("%d is greater than maximum value for uint8", srcVal)
   261  			}
   262  			*v = uint8(srcVal)
   263  		case *uint16:
   264  			if srcVal < 0 {
   265  				return fmt.Errorf("%d is less than zero for uint32", srcVal)
   266  			} else if srcVal > math.MaxUint16 {
   267  				return fmt.Errorf("%d is greater than maximum value for uint16", srcVal)
   268  			}
   269  			*v = uint16(srcVal)
   270  		case *uint32:
   271  			if srcVal < 0 {
   272  				return fmt.Errorf("%d is less than zero for uint32", srcVal)
   273  			} else if srcVal > math.MaxUint32 {
   274  				return fmt.Errorf("%d is greater than maximum value for uint32", srcVal)
   275  			}
   276  			*v = uint32(srcVal)
   277  		case *uint64:
   278  			if srcVal < 0 {
   279  				return fmt.Errorf("%d is less than zero for uint64", srcVal)
   280  			}
   281  			*v = uint64(srcVal)
   282  		case sql.Scanner:
   283  			return v.Scan(srcVal)
   284  		default:
   285  			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
   286  				el := v.Elem()
   287  				switch el.Kind() {
   288  				// if dst is a pointer to pointer, strip the pointer and try again
   289  				case reflect.Ptr:
   290  					if el.IsNil() {
   291  						// allocate destination
   292  						el.Set(reflect.New(el.Type().Elem()))
   293  					}
   294  					return int64AssignTo(srcVal, srcStatus, el.Interface())
   295  				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   296  					if el.OverflowInt(int64(srcVal)) {
   297  						return fmt.Errorf("cannot put %d into %T", srcVal, dst)
   298  					}
   299  					el.SetInt(int64(srcVal))
   300  					return nil
   301  				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   302  					if srcVal < 0 {
   303  						return fmt.Errorf("%d is less than zero for %T", srcVal, dst)
   304  					}
   305  					if el.OverflowUint(uint64(srcVal)) {
   306  						return fmt.Errorf("cannot put %d into %T", srcVal, dst)
   307  					}
   308  					el.SetUint(uint64(srcVal))
   309  					return nil
   310  				}
   311  			}
   312  			return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
   313  		}
   314  		return nil
   315  	}
   316  
   317  	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
   318  	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
   319  		el := v.Elem()
   320  		if el.Kind() == reflect.Ptr {
   321  			el.Set(reflect.Zero(el.Type()))
   322  			return nil
   323  		}
   324  	}
   325  
   326  	return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
   327  }
   328  
   329  func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error {
   330  	if srcStatus == Present {
   331  		switch v := dst.(type) {
   332  		case *float32:
   333  			*v = float32(srcVal)
   334  		case *float64:
   335  			*v = srcVal
   336  		default:
   337  			if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
   338  				el := v.Elem()
   339  				switch el.Kind() {
   340  				// if dst is a type alias of a float32 or 64, set dst val
   341  				case reflect.Float32, reflect.Float64:
   342  					el.SetFloat(srcVal)
   343  					return nil
   344  				// if dst is a pointer to pointer, strip the pointer and try again
   345  				case reflect.Ptr:
   346  					if el.IsNil() {
   347  						// allocate destination
   348  						el.Set(reflect.New(el.Type().Elem()))
   349  					}
   350  					return float64AssignTo(srcVal, srcStatus, el.Interface())
   351  				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   352  					i64 := int64(srcVal)
   353  					if float64(i64) == srcVal {
   354  						return int64AssignTo(i64, srcStatus, dst)
   355  					}
   356  				}
   357  			}
   358  			return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
   359  		}
   360  		return nil
   361  	}
   362  
   363  	// if dst is a pointer to pointer and srcStatus is not Present, nil it out
   364  	if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
   365  		el := v.Elem()
   366  		if el.Kind() == reflect.Ptr {
   367  			el.Set(reflect.Zero(el.Type()))
   368  			return nil
   369  		}
   370  	}
   371  
   372  	return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst)
   373  }
   374  
   375  func NullAssignTo(dst interface{}) error {
   376  	dstPtr := reflect.ValueOf(dst)
   377  
   378  	// AssignTo dst must always be a pointer
   379  	if dstPtr.Kind() != reflect.Ptr {
   380  		return &nullAssignmentError{dst: dst}
   381  	}
   382  
   383  	dstVal := dstPtr.Elem()
   384  
   385  	switch dstVal.Kind() {
   386  	case reflect.Ptr, reflect.Slice, reflect.Map:
   387  		dstVal.Set(reflect.Zero(dstVal.Type()))
   388  		return nil
   389  	}
   390  
   391  	return &nullAssignmentError{dst: dst}
   392  }
   393  
   394  var kindTypes map[reflect.Kind]reflect.Type
   395  
   396  func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) {
   397  	nextDst := dst.Convert(t)
   398  	return nextDst.Interface(), dst.Type() != nextDst.Type()
   399  }
   400  
   401  // GetAssignToDstType attempts to convert dst to something AssignTo can assign
   402  // to. If dst is a pointer to pointer it allocates a value and returns the
   403  // dereferences pointer. If dst is a named type such as *Foo where Foo is type
   404  // Foo int16, it converts dst to *int16.
   405  //
   406  // GetAssignToDstType returns the converted dst and a bool representing if any
   407  // change was made.
   408  func GetAssignToDstType(dst interface{}) (interface{}, bool) {
   409  	dstPtr := reflect.ValueOf(dst)
   410  
   411  	// AssignTo dst must always be a pointer
   412  	if dstPtr.Kind() != reflect.Ptr {
   413  		return nil, false
   414  	}
   415  
   416  	dstVal := dstPtr.Elem()
   417  
   418  	// if dst is a pointer to pointer, allocate space try again with the dereferenced pointer
   419  	if dstVal.Kind() == reflect.Ptr {
   420  		dstVal.Set(reflect.New(dstVal.Type().Elem()))
   421  		return dstVal.Interface(), true
   422  	}
   423  
   424  	// if dst is pointer to a base type that has been renamed
   425  	if baseValType, ok := kindTypes[dstVal.Kind()]; ok {
   426  		return toInterface(dstPtr, reflect.PtrTo(baseValType))
   427  	}
   428  
   429  	if dstVal.Kind() == reflect.Slice {
   430  		if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
   431  			return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType)))
   432  		}
   433  	}
   434  
   435  	if dstVal.Kind() == reflect.Array {
   436  		if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
   437  			return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)))
   438  		}
   439  	}
   440  
   441  	if dstVal.Kind() == reflect.Struct {
   442  		if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous {
   443  			dstPtr = dstVal.Field(0).Addr()
   444  			nested := dstVal.Type().Field(0).Type
   445  			if nested.Kind() == reflect.Array {
   446  				if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok {
   447  					return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType)))
   448  				}
   449  			}
   450  			if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() {
   451  				return dstPtr.Interface(), true
   452  			}
   453  		}
   454  	}
   455  
   456  	return nil, false
   457  }
   458  
   459  func init() {
   460  	kindTypes = map[reflect.Kind]reflect.Type{
   461  		reflect.Bool:    reflect.TypeOf(false),
   462  		reflect.Float32: reflect.TypeOf(float32(0)),
   463  		reflect.Float64: reflect.TypeOf(float64(0)),
   464  		reflect.Int:     reflect.TypeOf(int(0)),
   465  		reflect.Int8:    reflect.TypeOf(int8(0)),
   466  		reflect.Int16:   reflect.TypeOf(int16(0)),
   467  		reflect.Int32:   reflect.TypeOf(int32(0)),
   468  		reflect.Int64:   reflect.TypeOf(int64(0)),
   469  		reflect.Uint:    reflect.TypeOf(uint(0)),
   470  		reflect.Uint8:   reflect.TypeOf(uint8(0)),
   471  		reflect.Uint16:  reflect.TypeOf(uint16(0)),
   472  		reflect.Uint32:  reflect.TypeOf(uint32(0)),
   473  		reflect.Uint64:  reflect.TypeOf(uint64(0)),
   474  		reflect.String:  reflect.TypeOf(""),
   475  	}
   476  }
   477  

View as plain text