...

Source file src/github.com/jackc/pgtype/numeric.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql/driver"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"math"
     9  	"math/big"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/jackc/pgio"
    14  )
    15  
    16  // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
    17  const nbase = 10000
    18  
    19  const (
    20  	pgNumericNaN     = 0x00000000c0000000
    21  	pgNumericNaNSign = 0xc000
    22  
    23  	pgNumericPosInf     = 0x00000000d0000000
    24  	pgNumericPosInfSign = 0xd000
    25  
    26  	pgNumericNegInf     = 0x00000000f0000000
    27  	pgNumericNegInfSign = 0xf000
    28  )
    29  
    30  var big0 *big.Int = big.NewInt(0)
    31  var big1 *big.Int = big.NewInt(1)
    32  var big10 *big.Int = big.NewInt(10)
    33  var big100 *big.Int = big.NewInt(100)
    34  var big1000 *big.Int = big.NewInt(1000)
    35  
    36  var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
    37  var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
    38  var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
    39  var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
    40  var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
    41  var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
    42  var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
    43  var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
    44  var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
    45  var bigMinInt *big.Int = big.NewInt(int64(minInt))
    46  
    47  var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
    48  var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
    49  var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
    50  var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
    51  var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
    52  
    53  var bigNBase *big.Int = big.NewInt(nbase)
    54  var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
    55  var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
    56  var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
    57  
// Numeric holds a PostgreSQL numeric value as an arbitrary-precision integer
// coefficient scaled by a power of ten, together with flags for the special
// values NaN and +/- infinity.
type Numeric struct {
	// Int is the coefficient. It is left nil when the value is NaN or infinite.
	Int              *big.Int
	// Exp is the base-10 exponent; the represented value is Int * 10^Exp.
	Exp              int32
	// Status records whether the value is Present, Null, or undefined.
	Status           Status
	// NaN is true when the value is "not a number".
	NaN              bool
	// InfinityModifier marks the value as Infinity or NegativeInfinity; it is
	// None for finite values.
	InfinityModifier InfinityModifier
}
    65  
    66  func (dst *Numeric) Set(src interface{}) error {
    67  	if src == nil {
    68  		*dst = Numeric{Status: Null}
    69  		return nil
    70  	}
    71  
    72  	if value, ok := src.(interface{ Get() interface{} }); ok {
    73  		value2 := value.Get()
    74  		if value2 != value {
    75  			return dst.Set(value2)
    76  		}
    77  	}
    78  
    79  	switch value := src.(type) {
    80  	case float32:
    81  		if math.IsNaN(float64(value)) {
    82  			*dst = Numeric{Status: Present, NaN: true}
    83  			return nil
    84  		} else if math.IsInf(float64(value), 1) {
    85  			*dst = Numeric{Status: Present, InfinityModifier: Infinity}
    86  			return nil
    87  		} else if math.IsInf(float64(value), -1) {
    88  			*dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity}
    89  			return nil
    90  		}
    91  		num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
    92  		if err != nil {
    93  			return err
    94  		}
    95  		*dst = Numeric{Int: num, Exp: exp, Status: Present}
    96  	case float64:
    97  		if math.IsNaN(value) {
    98  			*dst = Numeric{Status: Present, NaN: true}
    99  			return nil
   100  		} else if math.IsInf(value, 1) {
   101  			*dst = Numeric{Status: Present, InfinityModifier: Infinity}
   102  			return nil
   103  		} else if math.IsInf(value, -1) {
   104  			*dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity}
   105  			return nil
   106  		}
   107  		num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
   108  		if err != nil {
   109  			return err
   110  		}
   111  		*dst = Numeric{Int: num, Exp: exp, Status: Present}
   112  	case int8:
   113  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   114  	case uint8:
   115  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   116  	case int16:
   117  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   118  	case uint16:
   119  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   120  	case int32:
   121  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   122  	case uint32:
   123  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   124  	case int64:
   125  		*dst = Numeric{Int: big.NewInt(value), Status: Present}
   126  	case uint64:
   127  		*dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present}
   128  	case int:
   129  		*dst = Numeric{Int: big.NewInt(int64(value)), Status: Present}
   130  	case uint:
   131  		*dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present}
   132  	case string:
   133  		num, exp, err := parseNumericString(value)
   134  		if err != nil {
   135  			return err
   136  		}
   137  		*dst = Numeric{Int: num, Exp: exp, Status: Present}
   138  	case *float64:
   139  		if value == nil {
   140  			*dst = Numeric{Status: Null}
   141  		} else {
   142  			return dst.Set(*value)
   143  		}
   144  	case *float32:
   145  		if value == nil {
   146  			*dst = Numeric{Status: Null}
   147  		} else {
   148  			return dst.Set(*value)
   149  		}
   150  	case *int8:
   151  		if value == nil {
   152  			*dst = Numeric{Status: Null}
   153  		} else {
   154  			return dst.Set(*value)
   155  		}
   156  	case *uint8:
   157  		if value == nil {
   158  			*dst = Numeric{Status: Null}
   159  		} else {
   160  			return dst.Set(*value)
   161  		}
   162  	case *int16:
   163  		if value == nil {
   164  			*dst = Numeric{Status: Null}
   165  		} else {
   166  			return dst.Set(*value)
   167  		}
   168  	case *uint16:
   169  		if value == nil {
   170  			*dst = Numeric{Status: Null}
   171  		} else {
   172  			return dst.Set(*value)
   173  		}
   174  	case *int32:
   175  		if value == nil {
   176  			*dst = Numeric{Status: Null}
   177  		} else {
   178  			return dst.Set(*value)
   179  		}
   180  	case *uint32:
   181  		if value == nil {
   182  			*dst = Numeric{Status: Null}
   183  		} else {
   184  			return dst.Set(*value)
   185  		}
   186  	case *int64:
   187  		if value == nil {
   188  			*dst = Numeric{Status: Null}
   189  		} else {
   190  			return dst.Set(*value)
   191  		}
   192  	case *uint64:
   193  		if value == nil {
   194  			*dst = Numeric{Status: Null}
   195  		} else {
   196  			return dst.Set(*value)
   197  		}
   198  	case *int:
   199  		if value == nil {
   200  			*dst = Numeric{Status: Null}
   201  		} else {
   202  			return dst.Set(*value)
   203  		}
   204  	case *uint:
   205  		if value == nil {
   206  			*dst = Numeric{Status: Null}
   207  		} else {
   208  			return dst.Set(*value)
   209  		}
   210  	case *string:
   211  		if value == nil {
   212  			*dst = Numeric{Status: Null}
   213  		} else {
   214  			return dst.Set(*value)
   215  		}
   216  	case InfinityModifier:
   217  		*dst = Numeric{InfinityModifier: value, Status: Present}
   218  	default:
   219  		if originalSrc, ok := underlyingNumberType(src); ok {
   220  			return dst.Set(originalSrc)
   221  		}
   222  		return fmt.Errorf("cannot convert %v to Numeric", value)
   223  	}
   224  
   225  	return nil
   226  }
   227  
   228  func (dst Numeric) Get() interface{} {
   229  	switch dst.Status {
   230  	case Present:
   231  		if dst.InfinityModifier != None {
   232  			return dst.InfinityModifier
   233  		}
   234  		return dst
   235  	case Null:
   236  		return nil
   237  	default:
   238  		return dst.Status
   239  	}
   240  }
   241  
   242  func (src *Numeric) AssignTo(dst interface{}) error {
   243  	switch src.Status {
   244  	case Present:
   245  		switch v := dst.(type) {
   246  		case *float32:
   247  			f, err := src.toFloat64()
   248  			if err != nil {
   249  				return err
   250  			}
   251  			return float64AssignTo(f, src.Status, dst)
   252  		case *float64:
   253  			f, err := src.toFloat64()
   254  			if err != nil {
   255  				return err
   256  			}
   257  			return float64AssignTo(f, src.Status, dst)
   258  		case *int:
   259  			normalizedInt, err := src.toBigInt()
   260  			if err != nil {
   261  				return err
   262  			}
   263  			if normalizedInt.Cmp(bigMaxInt) > 0 {
   264  				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
   265  			}
   266  			if normalizedInt.Cmp(bigMinInt) < 0 {
   267  				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
   268  			}
   269  			*v = int(normalizedInt.Int64())
   270  		case *int8:
   271  			normalizedInt, err := src.toBigInt()
   272  			if err != nil {
   273  				return err
   274  			}
   275  			if normalizedInt.Cmp(bigMaxInt8) > 0 {
   276  				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
   277  			}
   278  			if normalizedInt.Cmp(bigMinInt8) < 0 {
   279  				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
   280  			}
   281  			*v = int8(normalizedInt.Int64())
   282  		case *int16:
   283  			normalizedInt, err := src.toBigInt()
   284  			if err != nil {
   285  				return err
   286  			}
   287  			if normalizedInt.Cmp(bigMaxInt16) > 0 {
   288  				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
   289  			}
   290  			if normalizedInt.Cmp(bigMinInt16) < 0 {
   291  				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
   292  			}
   293  			*v = int16(normalizedInt.Int64())
   294  		case *int32:
   295  			normalizedInt, err := src.toBigInt()
   296  			if err != nil {
   297  				return err
   298  			}
   299  			if normalizedInt.Cmp(bigMaxInt32) > 0 {
   300  				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
   301  			}
   302  			if normalizedInt.Cmp(bigMinInt32) < 0 {
   303  				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
   304  			}
   305  			*v = int32(normalizedInt.Int64())
   306  		case *int64:
   307  			normalizedInt, err := src.toBigInt()
   308  			if err != nil {
   309  				return err
   310  			}
   311  			if normalizedInt.Cmp(bigMaxInt64) > 0 {
   312  				return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v)
   313  			}
   314  			if normalizedInt.Cmp(bigMinInt64) < 0 {
   315  				return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v)
   316  			}
   317  			*v = normalizedInt.Int64()
   318  		case *uint:
   319  			normalizedInt, err := src.toBigInt()
   320  			if err != nil {
   321  				return err
   322  			}
   323  			if normalizedInt.Cmp(big0) < 0 {
   324  				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
   325  			} else if normalizedInt.Cmp(bigMaxUint) > 0 {
   326  				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
   327  			}
   328  			*v = uint(normalizedInt.Uint64())
   329  		case *uint8:
   330  			normalizedInt, err := src.toBigInt()
   331  			if err != nil {
   332  				return err
   333  			}
   334  			if normalizedInt.Cmp(big0) < 0 {
   335  				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
   336  			} else if normalizedInt.Cmp(bigMaxUint8) > 0 {
   337  				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
   338  			}
   339  			*v = uint8(normalizedInt.Uint64())
   340  		case *uint16:
   341  			normalizedInt, err := src.toBigInt()
   342  			if err != nil {
   343  				return err
   344  			}
   345  			if normalizedInt.Cmp(big0) < 0 {
   346  				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
   347  			} else if normalizedInt.Cmp(bigMaxUint16) > 0 {
   348  				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
   349  			}
   350  			*v = uint16(normalizedInt.Uint64())
   351  		case *uint32:
   352  			normalizedInt, err := src.toBigInt()
   353  			if err != nil {
   354  				return err
   355  			}
   356  			if normalizedInt.Cmp(big0) < 0 {
   357  				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
   358  			} else if normalizedInt.Cmp(bigMaxUint32) > 0 {
   359  				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
   360  			}
   361  			*v = uint32(normalizedInt.Uint64())
   362  		case *uint64:
   363  			normalizedInt, err := src.toBigInt()
   364  			if err != nil {
   365  				return err
   366  			}
   367  			if normalizedInt.Cmp(big0) < 0 {
   368  				return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v)
   369  			} else if normalizedInt.Cmp(bigMaxUint64) > 0 {
   370  				return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v)
   371  			}
   372  			*v = normalizedInt.Uint64()
   373  		case *big.Rat:
   374  			rat, err := src.toBigRat()
   375  			if err != nil {
   376  				return err
   377  			}
   378  			v.Set(rat)
   379  		case *string:
   380  			buf, err := encodeNumericText(*src, nil)
   381  			if err != nil {
   382  				return err
   383  			}
   384  			*v = string(buf)
   385  		default:
   386  			if nextDst, retry := GetAssignToDstType(dst); retry {
   387  				return src.AssignTo(nextDst)
   388  			}
   389  			return fmt.Errorf("unable to assign to %T", dst)
   390  		}
   391  	case Null:
   392  		return NullAssignTo(dst)
   393  	}
   394  
   395  	return nil
   396  }
   397  
   398  func (dst *Numeric) toBigInt() (*big.Int, error) {
   399  	if dst.Exp == 0 {
   400  		return dst.Int, nil
   401  	}
   402  
   403  	num := &big.Int{}
   404  	num.Set(dst.Int)
   405  	if dst.Exp > 0 {
   406  		mul := &big.Int{}
   407  		mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil)
   408  		num.Mul(num, mul)
   409  		return num, nil
   410  	}
   411  
   412  	div := &big.Int{}
   413  	div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
   414  	remainder := &big.Int{}
   415  	num.DivMod(num, div, remainder)
   416  	if remainder.Cmp(big0) != 0 {
   417  		return nil, fmt.Errorf("cannot convert %v to integer", dst)
   418  	}
   419  	return num, nil
   420  }
   421  
   422  func (dst *Numeric) toBigRat() (*big.Rat, error) {
   423  	if dst.NaN {
   424  		return nil, fmt.Errorf("%v is not a number", dst)
   425  	} else if dst.InfinityModifier == Infinity {
   426  		return nil, fmt.Errorf("%v is infinity", dst)
   427  	} else if dst.InfinityModifier == NegativeInfinity {
   428  		return nil, fmt.Errorf("%v is -infinity", dst)
   429  	}
   430  
   431  	num := new(big.Rat).SetInt(dst.Int)
   432  	if dst.Exp > 0 {
   433  		mul := new(big.Int).Exp(big10, big.NewInt(int64(dst.Exp)), nil)
   434  		num.Mul(num, new(big.Rat).SetInt(mul))
   435  	} else if dst.Exp < 0 {
   436  		mul := new(big.Int).Exp(big10, big.NewInt(int64(-dst.Exp)), nil)
   437  		num.Quo(num, new(big.Rat).SetInt(mul))
   438  	}
   439  	return num, nil
   440  }
   441  
   442  func (src *Numeric) toFloat64() (float64, error) {
   443  	if src.NaN {
   444  		return math.NaN(), nil
   445  	} else if src.InfinityModifier == Infinity {
   446  		return math.Inf(1), nil
   447  	} else if src.InfinityModifier == NegativeInfinity {
   448  		return math.Inf(-1), nil
   449  	}
   450  
   451  	buf := make([]byte, 0, 32)
   452  
   453  	buf = append(buf, src.Int.String()...)
   454  	buf = append(buf, 'e')
   455  	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
   456  
   457  	f, err := strconv.ParseFloat(string(buf), 64)
   458  	if err != nil {
   459  		return 0, err
   460  	}
   461  	return f, nil
   462  }
   463  
   464  func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
   465  	if src == nil {
   466  		*dst = Numeric{Status: Null}
   467  		return nil
   468  	}
   469  
   470  	if string(src) == "NaN" {
   471  		*dst = Numeric{Status: Present, NaN: true}
   472  		return nil
   473  	} else if string(src) == "Infinity" {
   474  		*dst = Numeric{Status: Present, InfinityModifier: Infinity}
   475  		return nil
   476  	} else if string(src) == "-Infinity" {
   477  		*dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity}
   478  		return nil
   479  	}
   480  
   481  	num, exp, err := parseNumericString(string(src))
   482  	if err != nil {
   483  		return err
   484  	}
   485  
   486  	*dst = Numeric{Int: num, Exp: exp, Status: Present}
   487  	return nil
   488  }
   489  
   490  func parseNumericString(str string) (n *big.Int, exp int32, err error) {
   491  	parts := strings.SplitN(str, ".", 2)
   492  	digits := strings.Join(parts, "")
   493  
   494  	if len(parts) > 1 {
   495  		exp = int32(-len(parts[1]))
   496  	} else {
   497  		for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' {
   498  			digits = digits[:len(digits)-1]
   499  			exp++
   500  		}
   501  	}
   502  
   503  	accum := &big.Int{}
   504  	if _, ok := accum.SetString(digits, 10); !ok {
   505  		return nil, 0, fmt.Errorf("%s is not a number", str)
   506  	}
   507  
   508  	return accum, exp, nil
   509  }
   510  
   511  func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
   512  	if src == nil {
   513  		*dst = Numeric{Status: Null}
   514  		return nil
   515  	}
   516  
   517  	if len(src) < 8 {
   518  		return fmt.Errorf("numeric incomplete %v", src)
   519  	}
   520  
   521  	rp := 0
   522  	ndigits := binary.BigEndian.Uint16(src[rp:])
   523  	rp += 2
   524  	weight := int16(binary.BigEndian.Uint16(src[rp:]))
   525  	rp += 2
   526  	sign := binary.BigEndian.Uint16(src[rp:])
   527  	rp += 2
   528  	dscale := int16(binary.BigEndian.Uint16(src[rp:]))
   529  	rp += 2
   530  
   531  	if sign == pgNumericNaNSign {
   532  		*dst = Numeric{Status: Present, NaN: true}
   533  		return nil
   534  	} else if sign == pgNumericPosInfSign {
   535  		*dst = Numeric{Status: Present, InfinityModifier: Infinity}
   536  		return nil
   537  	} else if sign == pgNumericNegInfSign {
   538  		*dst = Numeric{Status: Present, InfinityModifier: NegativeInfinity}
   539  		return nil
   540  	}
   541  
   542  	if ndigits == 0 {
   543  		*dst = Numeric{Int: big.NewInt(0), Status: Present}
   544  		return nil
   545  	}
   546  
   547  	if len(src[rp:]) < int(ndigits)*2 {
   548  		return fmt.Errorf("numeric incomplete %v", src)
   549  	}
   550  
   551  	accum := &big.Int{}
   552  
   553  	for i := 0; i < int(ndigits+3)/4; i++ {
   554  		int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:])
   555  		rp += bytesRead
   556  
   557  		if i > 0 {
   558  			var mul *big.Int
   559  			switch digitsRead {
   560  			case 1:
   561  				mul = bigNBase
   562  			case 2:
   563  				mul = bigNBaseX2
   564  			case 3:
   565  				mul = bigNBaseX3
   566  			case 4:
   567  				mul = bigNBaseX4
   568  			default:
   569  				return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead)
   570  			}
   571  			accum.Mul(accum, mul)
   572  		}
   573  
   574  		accum.Add(accum, big.NewInt(int64accum))
   575  	}
   576  
   577  	exp := (int32(weight) - int32(ndigits) + 1) * 4
   578  
   579  	if dscale > 0 {
   580  		fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1)
   581  		fracDecimalDigits := fracNBaseDigits * 4
   582  
   583  		if dscale > fracDecimalDigits {
   584  			multCount := int(dscale - fracDecimalDigits)
   585  			for i := 0; i < multCount; i++ {
   586  				accum.Mul(accum, big10)
   587  				exp--
   588  			}
   589  		} else if dscale < fracDecimalDigits {
   590  			divCount := int(fracDecimalDigits - dscale)
   591  			for i := 0; i < divCount; i++ {
   592  				accum.Div(accum, big10)
   593  				exp++
   594  			}
   595  		}
   596  	}
   597  
   598  	reduced := &big.Int{}
   599  	remainder := &big.Int{}
   600  	if exp >= 0 {
   601  		for {
   602  			reduced.DivMod(accum, big10, remainder)
   603  			if remainder.Cmp(big0) != 0 {
   604  				break
   605  			}
   606  			accum.Set(reduced)
   607  			exp++
   608  		}
   609  	}
   610  
   611  	if sign != 0 {
   612  		accum.Neg(accum)
   613  	}
   614  
   615  	*dst = Numeric{Int: accum, Exp: exp, Status: Present}
   616  
   617  	return nil
   618  
   619  }
   620  
   621  func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
   622  	digits := len(src) / 2
   623  	if digits > 4 {
   624  		digits = 4
   625  	}
   626  
   627  	rp := 0
   628  
   629  	for i := 0; i < digits; i++ {
   630  		if i > 0 {
   631  			accum *= nbase
   632  		}
   633  		accum += int64(binary.BigEndian.Uint16(src[rp:]))
   634  		rp += 2
   635  	}
   636  
   637  	return accum, rp, digits
   638  }
   639  
   640  func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
   641  	switch src.Status {
   642  	case Null:
   643  		return nil, nil
   644  	case Undefined:
   645  		return nil, errUndefined
   646  	}
   647  
   648  	if src.NaN {
   649  		buf = append(buf, "NaN"...)
   650  		return buf, nil
   651  	} else if src.InfinityModifier == Infinity {
   652  		buf = append(buf, "Infinity"...)
   653  		return buf, nil
   654  	} else if src.InfinityModifier == NegativeInfinity {
   655  		buf = append(buf, "-Infinity"...)
   656  		return buf, nil
   657  	}
   658  
   659  	buf = append(buf, src.Int.String()...)
   660  	buf = append(buf, 'e')
   661  	buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
   662  	return buf, nil
   663  }
   664  
   665  func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
   666  	switch src.Status {
   667  	case Null:
   668  		return nil, nil
   669  	case Undefined:
   670  		return nil, errUndefined
   671  	}
   672  
   673  	if src.NaN {
   674  		buf = pgio.AppendUint64(buf, pgNumericNaN)
   675  		return buf, nil
   676  	} else if src.InfinityModifier == Infinity {
   677  		buf = pgio.AppendUint64(buf, pgNumericPosInf)
   678  		return buf, nil
   679  	} else if src.InfinityModifier == NegativeInfinity {
   680  		buf = pgio.AppendUint64(buf, pgNumericNegInf)
   681  		return buf, nil
   682  	}
   683  
   684  	var sign int16
   685  	if src.Int.Cmp(big0) < 0 {
   686  		sign = 16384
   687  	}
   688  
   689  	absInt := &big.Int{}
   690  	wholePart := &big.Int{}
   691  	fracPart := &big.Int{}
   692  	remainder := &big.Int{}
   693  	absInt.Abs(src.Int)
   694  
   695  	// Normalize absInt and exp to where exp is always a multiple of 4. This makes
   696  	// converting to 16-bit base 10,000 digits easier.
   697  	var exp int32
   698  	switch src.Exp % 4 {
   699  	case 1, -3:
   700  		exp = src.Exp - 1
   701  		absInt.Mul(absInt, big10)
   702  	case 2, -2:
   703  		exp = src.Exp - 2
   704  		absInt.Mul(absInt, big100)
   705  	case 3, -1:
   706  		exp = src.Exp - 3
   707  		absInt.Mul(absInt, big1000)
   708  	default:
   709  		exp = src.Exp
   710  	}
   711  
   712  	if exp < 0 {
   713  		divisor := &big.Int{}
   714  		divisor.Exp(big10, big.NewInt(int64(-exp)), nil)
   715  		wholePart.DivMod(absInt, divisor, fracPart)
   716  		fracPart.Add(fracPart, divisor)
   717  	} else {
   718  		wholePart = absInt
   719  	}
   720  
   721  	var wholeDigits, fracDigits []int16
   722  
   723  	for wholePart.Cmp(big0) != 0 {
   724  		wholePart.DivMod(wholePart, bigNBase, remainder)
   725  		wholeDigits = append(wholeDigits, int16(remainder.Int64()))
   726  	}
   727  
   728  	if fracPart.Cmp(big0) != 0 {
   729  		for fracPart.Cmp(big1) != 0 {
   730  			fracPart.DivMod(fracPart, bigNBase, remainder)
   731  			fracDigits = append(fracDigits, int16(remainder.Int64()))
   732  		}
   733  	}
   734  
   735  	buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits)))
   736  
   737  	var weight int16
   738  	if len(wholeDigits) > 0 {
   739  		weight = int16(len(wholeDigits) - 1)
   740  		if exp > 0 {
   741  			weight += int16(exp / 4)
   742  		}
   743  	} else {
   744  		weight = int16(exp/4) - 1 + int16(len(fracDigits))
   745  	}
   746  	buf = pgio.AppendInt16(buf, weight)
   747  
   748  	buf = pgio.AppendInt16(buf, sign)
   749  
   750  	var dscale int16
   751  	if src.Exp < 0 {
   752  		dscale = int16(-src.Exp)
   753  	}
   754  	buf = pgio.AppendInt16(buf, dscale)
   755  
   756  	for i := len(wholeDigits) - 1; i >= 0; i-- {
   757  		buf = pgio.AppendInt16(buf, wholeDigits[i])
   758  	}
   759  
   760  	for i := len(fracDigits) - 1; i >= 0; i-- {
   761  		buf = pgio.AppendInt16(buf, fracDigits[i])
   762  	}
   763  
   764  	return buf, nil
   765  }
   766  
   767  // Scan implements the database/sql Scanner interface.
   768  func (dst *Numeric) Scan(src interface{}) error {
   769  	if src == nil {
   770  		*dst = Numeric{Status: Null}
   771  		return nil
   772  	}
   773  
   774  	switch src := src.(type) {
   775  	case string:
   776  		return dst.DecodeText(nil, []byte(src))
   777  	case []byte:
   778  		srcCopy := make([]byte, len(src))
   779  		copy(srcCopy, src)
   780  		return dst.DecodeText(nil, srcCopy)
   781  	}
   782  
   783  	return fmt.Errorf("cannot scan %T", src)
   784  }
   785  
   786  // Value implements the database/sql/driver Valuer interface.
   787  func (src Numeric) Value() (driver.Value, error) {
   788  	switch src.Status {
   789  	case Present:
   790  		buf, err := src.EncodeText(nil, nil)
   791  		if err != nil {
   792  			return nil, err
   793  		}
   794  
   795  		return string(buf), nil
   796  	case Null:
   797  		return nil, nil
   798  	default:
   799  		return nil, errUndefined
   800  	}
   801  }
   802  
   803  func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) {
   804  	// if !n.Valid {
   805  	// 	return nil, nil
   806  	// }
   807  
   808  	if n.NaN {
   809  		buf = append(buf, "NaN"...)
   810  		return buf, nil
   811  	} else if n.InfinityModifier == Infinity {
   812  		buf = append(buf, "Infinity"...)
   813  		return buf, nil
   814  	} else if n.InfinityModifier == NegativeInfinity {
   815  		buf = append(buf, "-Infinity"...)
   816  		return buf, nil
   817  	}
   818  
   819  	buf = append(buf, n.numberTextBytes()...)
   820  
   821  	return buf, nil
   822  }
   823  
   824  // numberString returns a string of the number. undefined if NaN, infinite, or NULL
   825  func (n Numeric) numberTextBytes() []byte {
   826  	intStr := n.Int.String()
   827  	buf := &bytes.Buffer{}
   828  	exp := int(n.Exp)
   829  	if exp > 0 {
   830  		buf.WriteString(intStr)
   831  		for i := 0; i < exp; i++ {
   832  			buf.WriteByte('0')
   833  		}
   834  	} else if exp < 0 {
   835  		if len(intStr) <= -exp {
   836  			buf.WriteString("0.")
   837  			leadingZeros := -exp - len(intStr)
   838  			for i := 0; i < leadingZeros; i++ {
   839  				buf.WriteByte('0')
   840  			}
   841  			buf.WriteString(intStr)
   842  		} else if len(intStr) > -exp {
   843  			dpPos := len(intStr) + exp
   844  			buf.WriteString(intStr[:dpPos])
   845  			buf.WriteByte('.')
   846  			buf.WriteString(intStr[dpPos:])
   847  		}
   848  	} else {
   849  		buf.WriteString(intStr)
   850  	}
   851  
   852  	return buf.Bytes()
   853  }
   854  

View as plain text