...

Source file src/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math"
     7  	"math/big"
     8  	"net"
     9  	"net/netip"
    10  	"reflect"
    11  	"time"
    12  )
    13  
    14  type int8Wrapper int8
    15  
    16  func (w int8Wrapper) SkipUnderlyingTypePlan() {}
    17  
    18  func (w *int8Wrapper) ScanInt64(v Int8) error {
    19  	if !v.Valid {
    20  		return fmt.Errorf("cannot scan NULL into *int8")
    21  	}
    22  
    23  	if v.Int64 < math.MinInt8 {
    24  		return fmt.Errorf("%d is less than minimum value for int8", v.Int64)
    25  	}
    26  	if v.Int64 > math.MaxInt8 {
    27  		return fmt.Errorf("%d is greater than maximum value for int8", v.Int64)
    28  	}
    29  	*w = int8Wrapper(v.Int64)
    30  
    31  	return nil
    32  }
    33  
    34  func (w int8Wrapper) Int64Value() (Int8, error) {
    35  	return Int8{Int64: int64(w), Valid: true}, nil
    36  }
    37  
    38  type int16Wrapper int16
    39  
    40  func (w int16Wrapper) SkipUnderlyingTypePlan() {}
    41  
    42  func (w *int16Wrapper) ScanInt64(v Int8) error {
    43  	if !v.Valid {
    44  		return fmt.Errorf("cannot scan NULL into *int16")
    45  	}
    46  
    47  	if v.Int64 < math.MinInt16 {
    48  		return fmt.Errorf("%d is less than minimum value for int16", v.Int64)
    49  	}
    50  	if v.Int64 > math.MaxInt16 {
    51  		return fmt.Errorf("%d is greater than maximum value for int16", v.Int64)
    52  	}
    53  	*w = int16Wrapper(v.Int64)
    54  
    55  	return nil
    56  }
    57  
    58  func (w int16Wrapper) Int64Value() (Int8, error) {
    59  	return Int8{Int64: int64(w), Valid: true}, nil
    60  }
    61  
    62  type int32Wrapper int32
    63  
    64  func (w int32Wrapper) SkipUnderlyingTypePlan() {}
    65  
    66  func (w *int32Wrapper) ScanInt64(v Int8) error {
    67  	if !v.Valid {
    68  		return fmt.Errorf("cannot scan NULL into *int32")
    69  	}
    70  
    71  	if v.Int64 < math.MinInt32 {
    72  		return fmt.Errorf("%d is less than minimum value for int32", v.Int64)
    73  	}
    74  	if v.Int64 > math.MaxInt32 {
    75  		return fmt.Errorf("%d is greater than maximum value for int32", v.Int64)
    76  	}
    77  	*w = int32Wrapper(v.Int64)
    78  
    79  	return nil
    80  }
    81  
    82  func (w int32Wrapper) Int64Value() (Int8, error) {
    83  	return Int8{Int64: int64(w), Valid: true}, nil
    84  }
    85  
    86  type int64Wrapper int64
    87  
    88  func (w int64Wrapper) SkipUnderlyingTypePlan() {}
    89  
    90  func (w *int64Wrapper) ScanInt64(v Int8) error {
    91  	if !v.Valid {
    92  		return fmt.Errorf("cannot scan NULL into *int64")
    93  	}
    94  
    95  	*w = int64Wrapper(v.Int64)
    96  
    97  	return nil
    98  }
    99  
   100  func (w int64Wrapper) Int64Value() (Int8, error) {
   101  	return Int8{Int64: int64(w), Valid: true}, nil
   102  }
   103  
   104  type intWrapper int
   105  
   106  func (w intWrapper) SkipUnderlyingTypePlan() {}
   107  
   108  func (w *intWrapper) ScanInt64(v Int8) error {
   109  	if !v.Valid {
   110  		return fmt.Errorf("cannot scan NULL into *int")
   111  	}
   112  
   113  	if v.Int64 < math.MinInt {
   114  		return fmt.Errorf("%d is less than minimum value for int", v.Int64)
   115  	}
   116  	if v.Int64 > math.MaxInt {
   117  		return fmt.Errorf("%d is greater than maximum value for int", v.Int64)
   118  	}
   119  
   120  	*w = intWrapper(v.Int64)
   121  
   122  	return nil
   123  }
   124  
   125  func (w intWrapper) Int64Value() (Int8, error) {
   126  	return Int8{Int64: int64(w), Valid: true}, nil
   127  }
   128  
   129  type uint8Wrapper uint8
   130  
   131  func (w uint8Wrapper) SkipUnderlyingTypePlan() {}
   132  
   133  func (w *uint8Wrapper) ScanInt64(v Int8) error {
   134  	if !v.Valid {
   135  		return fmt.Errorf("cannot scan NULL into *uint8")
   136  	}
   137  
   138  	if v.Int64 < 0 {
   139  		return fmt.Errorf("%d is less than minimum value for uint8", v.Int64)
   140  	}
   141  	if v.Int64 > math.MaxUint8 {
   142  		return fmt.Errorf("%d is greater than maximum value for uint8", v.Int64)
   143  	}
   144  	*w = uint8Wrapper(v.Int64)
   145  
   146  	return nil
   147  }
   148  
   149  func (w uint8Wrapper) Int64Value() (Int8, error) {
   150  	return Int8{Int64: int64(w), Valid: true}, nil
   151  }
   152  
   153  type uint16Wrapper uint16
   154  
   155  func (w uint16Wrapper) SkipUnderlyingTypePlan() {}
   156  
   157  func (w *uint16Wrapper) ScanInt64(v Int8) error {
   158  	if !v.Valid {
   159  		return fmt.Errorf("cannot scan NULL into *uint16")
   160  	}
   161  
   162  	if v.Int64 < 0 {
   163  		return fmt.Errorf("%d is less than minimum value for uint16", v.Int64)
   164  	}
   165  	if v.Int64 > math.MaxUint16 {
   166  		return fmt.Errorf("%d is greater than maximum value for uint16", v.Int64)
   167  	}
   168  	*w = uint16Wrapper(v.Int64)
   169  
   170  	return nil
   171  }
   172  
   173  func (w uint16Wrapper) Int64Value() (Int8, error) {
   174  	return Int8{Int64: int64(w), Valid: true}, nil
   175  }
   176  
   177  type uint32Wrapper uint32
   178  
   179  func (w uint32Wrapper) SkipUnderlyingTypePlan() {}
   180  
   181  func (w *uint32Wrapper) ScanInt64(v Int8) error {
   182  	if !v.Valid {
   183  		return fmt.Errorf("cannot scan NULL into *uint32")
   184  	}
   185  
   186  	if v.Int64 < 0 {
   187  		return fmt.Errorf("%d is less than minimum value for uint32", v.Int64)
   188  	}
   189  	if v.Int64 > math.MaxUint32 {
   190  		return fmt.Errorf("%d is greater than maximum value for uint32", v.Int64)
   191  	}
   192  	*w = uint32Wrapper(v.Int64)
   193  
   194  	return nil
   195  }
   196  
   197  func (w uint32Wrapper) Int64Value() (Int8, error) {
   198  	return Int8{Int64: int64(w), Valid: true}, nil
   199  }
   200  
   201  type uint64Wrapper uint64
   202  
   203  func (w uint64Wrapper) SkipUnderlyingTypePlan() {}
   204  
   205  func (w *uint64Wrapper) ScanInt64(v Int8) error {
   206  	if !v.Valid {
   207  		return fmt.Errorf("cannot scan NULL into *uint64")
   208  	}
   209  
   210  	if v.Int64 < 0 {
   211  		return fmt.Errorf("%d is less than minimum value for uint64", v.Int64)
   212  	}
   213  
   214  	*w = uint64Wrapper(v.Int64)
   215  
   216  	return nil
   217  }
   218  
   219  func (w uint64Wrapper) Int64Value() (Int8, error) {
   220  	if uint64(w) > uint64(math.MaxInt64) {
   221  		return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w)
   222  	}
   223  
   224  	return Int8{Int64: int64(w), Valid: true}, nil
   225  }
   226  
   227  func (w *uint64Wrapper) ScanNumeric(v Numeric) error {
   228  	if !v.Valid {
   229  		return fmt.Errorf("cannot scan NULL into *uint64")
   230  	}
   231  
   232  	bi, err := v.toBigInt()
   233  	if err != nil {
   234  		return fmt.Errorf("cannot scan into *uint64: %w", err)
   235  	}
   236  
   237  	if !bi.IsUint64() {
   238  		return fmt.Errorf("cannot scan %v into *uint64", bi.String())
   239  	}
   240  
   241  	*w = uint64Wrapper(bi.Uint64())
   242  
   243  	return nil
   244  }
   245  
   246  func (w uint64Wrapper) NumericValue() (Numeric, error) {
   247  	return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil
   248  }
   249  
   250  type uintWrapper uint
   251  
   252  func (w uintWrapper) SkipUnderlyingTypePlan() {}
   253  
   254  func (w *uintWrapper) ScanInt64(v Int8) error {
   255  	if !v.Valid {
   256  		return fmt.Errorf("cannot scan NULL into *uint64")
   257  	}
   258  
   259  	if v.Int64 < 0 {
   260  		return fmt.Errorf("%d is less than minimum value for uint64", v.Int64)
   261  	}
   262  
   263  	if uint64(v.Int64) > math.MaxUint {
   264  		return fmt.Errorf("%d is greater than maximum value for uint", v.Int64)
   265  	}
   266  
   267  	*w = uintWrapper(v.Int64)
   268  
   269  	return nil
   270  }
   271  
   272  func (w uintWrapper) Int64Value() (Int8, error) {
   273  	if uint64(w) > uint64(math.MaxInt64) {
   274  		return Int8{}, fmt.Errorf("%d is greater than maximum value for int64", w)
   275  	}
   276  
   277  	return Int8{Int64: int64(w), Valid: true}, nil
   278  }
   279  
   280  func (w *uintWrapper) ScanNumeric(v Numeric) error {
   281  	if !v.Valid {
   282  		return fmt.Errorf("cannot scan NULL into *uint")
   283  	}
   284  
   285  	bi, err := v.toBigInt()
   286  	if err != nil {
   287  		return fmt.Errorf("cannot scan into *uint: %w", err)
   288  	}
   289  
   290  	if !bi.IsUint64() {
   291  		return fmt.Errorf("cannot scan %v into *uint", bi.String())
   292  	}
   293  
   294  	ui := bi.Uint64()
   295  
   296  	if math.MaxUint < ui {
   297  		return fmt.Errorf("cannot scan %v into *uint", ui)
   298  	}
   299  
   300  	*w = uintWrapper(ui)
   301  
   302  	return nil
   303  }
   304  
   305  func (w uintWrapper) NumericValue() (Numeric, error) {
   306  	return Numeric{Int: new(big.Int).SetUint64(uint64(w)), Valid: true}, nil
   307  }
   308  
   309  type float32Wrapper float32
   310  
   311  func (w float32Wrapper) SkipUnderlyingTypePlan() {}
   312  
   313  func (w *float32Wrapper) ScanInt64(v Int8) error {
   314  	if !v.Valid {
   315  		return fmt.Errorf("cannot scan NULL into *float32")
   316  	}
   317  
   318  	*w = float32Wrapper(v.Int64)
   319  
   320  	return nil
   321  }
   322  
   323  func (w float32Wrapper) Int64Value() (Int8, error) {
   324  	if w > math.MaxInt64 {
   325  		return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w)
   326  	}
   327  
   328  	return Int8{Int64: int64(w), Valid: true}, nil
   329  }
   330  
   331  func (w *float32Wrapper) ScanFloat64(v Float8) error {
   332  	if !v.Valid {
   333  		return fmt.Errorf("cannot scan NULL into *float32")
   334  	}
   335  
   336  	*w = float32Wrapper(v.Float64)
   337  
   338  	return nil
   339  }
   340  
   341  func (w float32Wrapper) Float64Value() (Float8, error) {
   342  	return Float8{Float64: float64(w), Valid: true}, nil
   343  }
   344  
   345  type float64Wrapper float64
   346  
   347  func (w float64Wrapper) SkipUnderlyingTypePlan() {}
   348  
   349  func (w *float64Wrapper) ScanInt64(v Int8) error {
   350  	if !v.Valid {
   351  		return fmt.Errorf("cannot scan NULL into *float64")
   352  	}
   353  
   354  	*w = float64Wrapper(v.Int64)
   355  
   356  	return nil
   357  }
   358  
   359  func (w float64Wrapper) Int64Value() (Int8, error) {
   360  	if w > math.MaxInt64 {
   361  		return Int8{}, fmt.Errorf("%f is greater than maximum value for int64", w)
   362  	}
   363  
   364  	return Int8{Int64: int64(w), Valid: true}, nil
   365  }
   366  
   367  func (w *float64Wrapper) ScanFloat64(v Float8) error {
   368  	if !v.Valid {
   369  		return fmt.Errorf("cannot scan NULL into *float64")
   370  	}
   371  
   372  	*w = float64Wrapper(v.Float64)
   373  
   374  	return nil
   375  }
   376  
   377  func (w float64Wrapper) Float64Value() (Float8, error) {
   378  	return Float8{Float64: float64(w), Valid: true}, nil
   379  }
   380  
   381  type stringWrapper string
   382  
   383  func (w stringWrapper) SkipUnderlyingTypePlan() {}
   384  
   385  func (w *stringWrapper) ScanText(v Text) error {
   386  	if !v.Valid {
   387  		return fmt.Errorf("cannot scan NULL into *string")
   388  	}
   389  
   390  	*w = stringWrapper(v.String)
   391  	return nil
   392  }
   393  
   394  func (w stringWrapper) TextValue() (Text, error) {
   395  	return Text{String: string(w), Valid: true}, nil
   396  }
   397  
   398  type timeWrapper time.Time
   399  
   400  func (w *timeWrapper) ScanDate(v Date) error {
   401  	if !v.Valid {
   402  		return fmt.Errorf("cannot scan NULL into *time.Time")
   403  	}
   404  
   405  	switch v.InfinityModifier {
   406  	case Finite:
   407  		*w = timeWrapper(v.Time)
   408  		return nil
   409  	case Infinity:
   410  		return fmt.Errorf("cannot scan Infinity into *time.Time")
   411  	case NegativeInfinity:
   412  		return fmt.Errorf("cannot scan -Infinity into *time.Time")
   413  	default:
   414  		return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier)
   415  	}
   416  }
   417  
   418  func (w timeWrapper) DateValue() (Date, error) {
   419  	return Date{Time: time.Time(w), Valid: true}, nil
   420  }
   421  
   422  func (w *timeWrapper) ScanTimestamp(v Timestamp) error {
   423  	if !v.Valid {
   424  		return fmt.Errorf("cannot scan NULL into *time.Time")
   425  	}
   426  
   427  	switch v.InfinityModifier {
   428  	case Finite:
   429  		*w = timeWrapper(v.Time)
   430  		return nil
   431  	case Infinity:
   432  		return fmt.Errorf("cannot scan Infinity into *time.Time")
   433  	case NegativeInfinity:
   434  		return fmt.Errorf("cannot scan -Infinity into *time.Time")
   435  	default:
   436  		return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier)
   437  	}
   438  }
   439  
   440  func (w timeWrapper) TimestampValue() (Timestamp, error) {
   441  	return Timestamp{Time: time.Time(w), Valid: true}, nil
   442  }
   443  
   444  func (w *timeWrapper) ScanTimestamptz(v Timestamptz) error {
   445  	if !v.Valid {
   446  		return fmt.Errorf("cannot scan NULL into *time.Time")
   447  	}
   448  
   449  	switch v.InfinityModifier {
   450  	case Finite:
   451  		*w = timeWrapper(v.Time)
   452  		return nil
   453  	case Infinity:
   454  		return fmt.Errorf("cannot scan Infinity into *time.Time")
   455  	case NegativeInfinity:
   456  		return fmt.Errorf("cannot scan -Infinity into *time.Time")
   457  	default:
   458  		return fmt.Errorf("invalid InfinityModifier: %v", v.InfinityModifier)
   459  	}
   460  }
   461  
   462  func (w timeWrapper) TimestamptzValue() (Timestamptz, error) {
   463  	return Timestamptz{Time: time.Time(w), Valid: true}, nil
   464  }
   465  
   466  func (w *timeWrapper) ScanTime(v Time) error {
   467  	if !v.Valid {
   468  		return fmt.Errorf("cannot scan NULL into *time.Time")
   469  	}
   470  
   471  	// 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day.
   472  	var maxRepresentableByTime int64 = 24*60*60*1000000 - 1
   473  	if v.Microseconds > maxRepresentableByTime {
   474  		return fmt.Errorf("%d microseconds cannot be represented as time.Time", v.Microseconds)
   475  	}
   476  
   477  	usec := v.Microseconds
   478  	hours := usec / microsecondsPerHour
   479  	usec -= hours * microsecondsPerHour
   480  	minutes := usec / microsecondsPerMinute
   481  	usec -= minutes * microsecondsPerMinute
   482  	seconds := usec / microsecondsPerSecond
   483  	usec -= seconds * microsecondsPerSecond
   484  	ns := usec * 1000
   485  	*w = timeWrapper(time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC))
   486  	return nil
   487  }
   488  
   489  func (w timeWrapper) TimeValue() (Time, error) {
   490  	t := time.Time(w)
   491  	usec := int64(t.Hour())*microsecondsPerHour +
   492  		int64(t.Minute())*microsecondsPerMinute +
   493  		int64(t.Second())*microsecondsPerSecond +
   494  		int64(t.Nanosecond())/1000
   495  	return Time{Microseconds: usec, Valid: true}, nil
   496  }
   497  
   498  type durationWrapper time.Duration
   499  
   500  func (w durationWrapper) SkipUnderlyingTypePlan() {}
   501  
   502  func (w *durationWrapper) ScanInterval(v Interval) error {
   503  	if !v.Valid {
   504  		return fmt.Errorf("cannot scan NULL into *time.Interval")
   505  	}
   506  
   507  	us := int64(v.Months)*microsecondsPerMonth + int64(v.Days)*microsecondsPerDay + v.Microseconds
   508  	*w = durationWrapper(time.Duration(us) * time.Microsecond)
   509  	return nil
   510  }
   511  
   512  func (w durationWrapper) IntervalValue() (Interval, error) {
   513  	return Interval{Microseconds: int64(w) / 1000, Valid: true}, nil
   514  }
   515  
   516  type netIPNetWrapper net.IPNet
   517  
   518  func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error {
   519  	if !v.IsValid() {
   520  		return fmt.Errorf("cannot scan NULL into *net.IPNet")
   521  	}
   522  
   523  	*w = netIPNetWrapper{
   524  		IP:   v.Addr().AsSlice(),
   525  		Mask: net.CIDRMask(v.Bits(), v.Addr().BitLen()),
   526  	}
   527  
   528  	return nil
   529  }
   530  func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) {
   531  	ip, ok := netip.AddrFromSlice(w.IP)
   532  	if !ok {
   533  		return netip.Prefix{}, errors.New("invalid net.IPNet")
   534  	}
   535  
   536  	ones, _ := w.Mask.Size()
   537  
   538  	return netip.PrefixFrom(ip, ones), nil
   539  }
   540  
   541  type netIPWrapper net.IP
   542  
   543  func (w netIPWrapper) SkipUnderlyingTypePlan() {}
   544  
   545  func (w *netIPWrapper) ScanNetipPrefix(v netip.Prefix) error {
   546  	if !v.IsValid() {
   547  		*w = nil
   548  		return nil
   549  	}
   550  
   551  	if v.Addr().BitLen() != v.Bits() {
   552  		return fmt.Errorf("cannot scan %v to *net.IP", v)
   553  	}
   554  
   555  	*w = netIPWrapper(v.Addr().AsSlice())
   556  	return nil
   557  }
   558  
   559  func (w netIPWrapper) NetipPrefixValue() (netip.Prefix, error) {
   560  	if w == nil {
   561  		return netip.Prefix{}, nil
   562  	}
   563  
   564  	addr, ok := netip.AddrFromSlice([]byte(w))
   565  	if !ok {
   566  		return netip.Prefix{}, errors.New("invalid net.IP")
   567  	}
   568  
   569  	return netip.PrefixFrom(addr, addr.BitLen()), nil
   570  }
   571  
   572  type netipPrefixWrapper netip.Prefix
   573  
   574  func (w *netipPrefixWrapper) ScanNetipPrefix(v netip.Prefix) error {
   575  	*w = netipPrefixWrapper(v)
   576  	return nil
   577  }
   578  
   579  func (w netipPrefixWrapper) NetipPrefixValue() (netip.Prefix, error) {
   580  	return netip.Prefix(w), nil
   581  }
   582  
   583  type netipAddrWrapper netip.Addr
   584  
   585  func (w *netipAddrWrapper) ScanNetipPrefix(v netip.Prefix) error {
   586  	if !v.IsValid() {
   587  		*w = netipAddrWrapper(netip.Addr{})
   588  		return nil
   589  	}
   590  
   591  	if v.Addr().BitLen() != v.Bits() {
   592  		return fmt.Errorf("cannot scan %v to netip.Addr", v)
   593  	}
   594  
   595  	*w = netipAddrWrapper(v.Addr())
   596  
   597  	return nil
   598  }
   599  
   600  func (w netipAddrWrapper) NetipPrefixValue() (netip.Prefix, error) {
   601  	addr := (netip.Addr)(w)
   602  	if !addr.IsValid() {
   603  		return netip.Prefix{}, nil
   604  	}
   605  
   606  	return netip.PrefixFrom(addr, addr.BitLen()), nil
   607  }
   608  
   609  type mapStringToPointerStringWrapper map[string]*string
   610  
   611  func (w *mapStringToPointerStringWrapper) ScanHstore(v Hstore) error {
   612  	*w = mapStringToPointerStringWrapper(v)
   613  	return nil
   614  }
   615  
   616  func (w mapStringToPointerStringWrapper) HstoreValue() (Hstore, error) {
   617  	return Hstore(w), nil
   618  }
   619  
   620  type mapStringToStringWrapper map[string]string
   621  
   622  func (w *mapStringToStringWrapper) ScanHstore(v Hstore) error {
   623  	*w = make(mapStringToStringWrapper, len(v))
   624  	for k, v := range v {
   625  		if v == nil {
   626  			return fmt.Errorf("cannot scan NULL to string")
   627  		}
   628  		(*w)[k] = *v
   629  	}
   630  	return nil
   631  }
   632  
   633  func (w mapStringToStringWrapper) HstoreValue() (Hstore, error) {
   634  	if w == nil {
   635  		return nil, nil
   636  	}
   637  
   638  	hstore := make(Hstore, len(w))
   639  	for k, v := range w {
   640  		s := v
   641  		hstore[k] = &s
   642  	}
   643  	return hstore, nil
   644  }
   645  
   646  type fmtStringerWrapper struct {
   647  	s fmt.Stringer
   648  }
   649  
   650  func (w fmtStringerWrapper) TextValue() (Text, error) {
   651  	return Text{String: w.s.String(), Valid: true}, nil
   652  }
   653  
   654  type byte16Wrapper [16]byte
   655  
   656  func (w *byte16Wrapper) ScanUUID(v UUID) error {
   657  	if !v.Valid {
   658  		return fmt.Errorf("cannot scan NULL into *[16]byte")
   659  	}
   660  	*w = byte16Wrapper(v.Bytes)
   661  	return nil
   662  }
   663  
   664  func (w byte16Wrapper) UUIDValue() (UUID, error) {
   665  	return UUID{Bytes: [16]byte(w), Valid: true}, nil
   666  }
   667  
   668  type byteSliceWrapper []byte
   669  
   670  func (w byteSliceWrapper) SkipUnderlyingTypePlan() {}
   671  
   672  func (w *byteSliceWrapper) ScanText(v Text) error {
   673  	if !v.Valid {
   674  		*w = nil
   675  		return nil
   676  	}
   677  
   678  	*w = byteSliceWrapper(v.String)
   679  	return nil
   680  }
   681  
   682  func (w byteSliceWrapper) TextValue() (Text, error) {
   683  	if w == nil {
   684  		return Text{}, nil
   685  	}
   686  
   687  	return Text{String: string(w), Valid: true}, nil
   688  }
   689  
   690  func (w *byteSliceWrapper) ScanUUID(v UUID) error {
   691  	if !v.Valid {
   692  		*w = nil
   693  		return nil
   694  	}
   695  	*w = make(byteSliceWrapper, 16)
   696  	copy(*w, v.Bytes[:])
   697  	return nil
   698  }
   699  
   700  func (w byteSliceWrapper) UUIDValue() (UUID, error) {
   701  	if w == nil {
   702  		return UUID{}, nil
   703  	}
   704  
   705  	uuid := UUID{Valid: true}
   706  	copy(uuid.Bytes[:], w)
   707  	return uuid, nil
   708  }
   709  
   710  // structWrapper implements CompositeIndexGetter for a struct.
   711  type structWrapper struct {
   712  	s              any
   713  	exportedFields []reflect.Value
   714  }
   715  
   716  func (w structWrapper) IsNull() bool {
   717  	return w.s == nil
   718  }
   719  
   720  func (w structWrapper) Index(i int) any {
   721  	if i >= len(w.exportedFields) {
   722  		return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
   723  	}
   724  
   725  	return w.exportedFields[i].Interface()
   726  }
   727  
   728  // ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct.
   729  type ptrStructWrapper struct {
   730  	s              any
   731  	exportedFields []reflect.Value
   732  }
   733  
   734  func (w *ptrStructWrapper) ScanNull() error {
   735  	return fmt.Errorf("cannot scan NULL into %#v", w.s)
   736  }
   737  
   738  func (w *ptrStructWrapper) ScanIndex(i int) any {
   739  	if i >= len(w.exportedFields) {
   740  		return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
   741  	}
   742  
   743  	return w.exportedFields[i].Addr().Interface()
   744  }
   745  
   746  type anySliceArrayReflect struct {
   747  	slice reflect.Value
   748  }
   749  
   750  func (a anySliceArrayReflect) Dimensions() []ArrayDimension {
   751  	if a.slice.IsNil() {
   752  		return nil
   753  	}
   754  
   755  	return []ArrayDimension{{Length: int32(a.slice.Len()), LowerBound: 1}}
   756  }
   757  
   758  func (a anySliceArrayReflect) Index(i int) any {
   759  	return a.slice.Index(i).Interface()
   760  }
   761  
   762  func (a anySliceArrayReflect) IndexType() any {
   763  	return reflect.New(a.slice.Type().Elem()).Elem().Interface()
   764  }
   765  
   766  func (a *anySliceArrayReflect) SetDimensions(dimensions []ArrayDimension) error {
   767  	sliceType := a.slice.Type()
   768  
   769  	if dimensions == nil {
   770  		a.slice.Set(reflect.Zero(sliceType))
   771  		return nil
   772  	}
   773  
   774  	elementCount := cardinality(dimensions)
   775  	slice := reflect.MakeSlice(sliceType, elementCount, elementCount)
   776  	a.slice.Set(slice)
   777  	return nil
   778  }
   779  
   780  func (a *anySliceArrayReflect) ScanIndex(i int) any {
   781  	return a.slice.Index(i).Addr().Interface()
   782  }
   783  
   784  func (a *anySliceArrayReflect) ScanIndexType() any {
   785  	return reflect.New(a.slice.Type().Elem()).Interface()
   786  }
   787  
   788  type anyMultiDimSliceArray struct {
   789  	slice reflect.Value
   790  	dims  []ArrayDimension
   791  }
   792  
   793  func (a *anyMultiDimSliceArray) Dimensions() []ArrayDimension {
   794  	if a.slice.IsNil() {
   795  		return nil
   796  	}
   797  
   798  	s := a.slice
   799  	for {
   800  		a.dims = append(a.dims, ArrayDimension{Length: int32(s.Len()), LowerBound: 1})
   801  		if s.Len() > 0 {
   802  			s = s.Index(0)
   803  		} else {
   804  			break
   805  		}
   806  		if s.Type().Kind() == reflect.Slice {
   807  		} else {
   808  			break
   809  		}
   810  	}
   811  
   812  	return a.dims
   813  }
   814  
   815  func (a *anyMultiDimSliceArray) Index(i int) any {
   816  	if len(a.dims) == 1 {
   817  		return a.slice.Index(i).Interface()
   818  	}
   819  
   820  	indexes := make([]int, len(a.dims))
   821  	for j := len(a.dims) - 1; j >= 0; j-- {
   822  		dimLen := int(a.dims[j].Length)
   823  		indexes[j] = i % dimLen
   824  		i = i / dimLen
   825  	}
   826  
   827  	v := a.slice
   828  	for _, si := range indexes {
   829  		v = v.Index(si)
   830  	}
   831  
   832  	return v.Interface()
   833  }
   834  
   835  func (a *anyMultiDimSliceArray) IndexType() any {
   836  	lowestSliceType := a.slice.Type()
   837  	for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() {
   838  	}
   839  	return reflect.New(lowestSliceType.Elem()).Elem().Interface()
   840  }
   841  
   842  func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error {
   843  	sliceType := a.slice.Type()
   844  
   845  	if dimensions == nil {
   846  		a.slice.Set(reflect.Zero(sliceType))
   847  		return nil
   848  	}
   849  
   850  	switch len(dimensions) {
   851  	case 0:
   852  		// Empty, but non-nil array
   853  		slice := reflect.MakeSlice(sliceType, 0, 0)
   854  		a.slice.Set(slice)
   855  		return nil
   856  	case 1:
   857  		elementCount := cardinality(dimensions)
   858  		slice := reflect.MakeSlice(sliceType, elementCount, elementCount)
   859  		a.slice.Set(slice)
   860  		return nil
   861  	default:
   862  		sliceDimensionCount := 1
   863  		lowestSliceType := sliceType
   864  		for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() {
   865  			sliceDimensionCount++
   866  		}
   867  
   868  		if sliceDimensionCount != len(dimensions) {
   869  			return fmt.Errorf("PostgreSQL array has %d dimensions but slice has %d dimensions", len(dimensions), sliceDimensionCount)
   870  		}
   871  
   872  		elementCount := cardinality(dimensions)
   873  		flatSlice := reflect.MakeSlice(lowestSliceType, elementCount, elementCount)
   874  
   875  		multiDimSlice := a.makeMultidimensionalSlice(sliceType, dimensions, flatSlice, 0)
   876  		a.slice.Set(multiDimSlice)
   877  
   878  		// Now that a.slice is a multi-dimensional slice with the underlying data pointed at flatSlice change a.slice to
   879  		// flatSlice so ScanIndex only has to handle simple one dimensional slices.
   880  		a.slice = flatSlice
   881  
   882  		return nil
   883  	}
   884  
   885  }
   886  
   887  func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value {
   888  	if len(dimensions) == 1 {
   889  		endIdx := flatSliceIdx + int(dimensions[0].Length)
   890  		return flatSlice.Slice3(flatSliceIdx, endIdx, endIdx)
   891  	}
   892  
   893  	sliceLen := int(dimensions[0].Length)
   894  	slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen)
   895  	for i := 0; i < sliceLen; i++ {
   896  		subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length)))
   897  		slice.Index(i).Set(subSlice)
   898  	}
   899  
   900  	return slice
   901  }
   902  
   903  func (a *anyMultiDimSliceArray) ScanIndex(i int) any {
   904  	return a.slice.Index(i).Addr().Interface()
   905  }
   906  
   907  func (a *anyMultiDimSliceArray) ScanIndexType() any {
   908  	lowestSliceType := a.slice.Type()
   909  	for ; lowestSliceType.Elem().Kind() == reflect.Slice; lowestSliceType = lowestSliceType.Elem() {
   910  	}
   911  	return reflect.New(lowestSliceType.Elem()).Interface()
   912  }
   913  
   914  type anyArrayArrayReflect struct {
   915  	array reflect.Value
   916  }
   917  
   918  func (a anyArrayArrayReflect) Dimensions() []ArrayDimension {
   919  	return []ArrayDimension{{Length: int32(a.array.Len()), LowerBound: 1}}
   920  }
   921  
   922  func (a anyArrayArrayReflect) Index(i int) any {
   923  	return a.array.Index(i).Interface()
   924  }
   925  
   926  func (a anyArrayArrayReflect) IndexType() any {
   927  	return reflect.New(a.array.Type().Elem()).Elem().Interface()
   928  }
   929  
   930  func (a *anyArrayArrayReflect) SetDimensions(dimensions []ArrayDimension) error {
   931  	if dimensions == nil {
   932  		return fmt.Errorf("anyArrayArrayReflect: cannot scan NULL into %v", a.array.Type().String())
   933  	}
   934  
   935  	if len(dimensions) != 1 {
   936  		return fmt.Errorf("anyArrayArrayReflect: cannot scan multi-dimensional array into %v", a.array.Type().String())
   937  	}
   938  
   939  	if int(dimensions[0].Length) != a.array.Len() {
   940  		return fmt.Errorf("anyArrayArrayReflect: cannot scan array with length %v into %v", dimensions[0].Length, a.array.Type().String())
   941  	}
   942  
   943  	return nil
   944  }
   945  
   946  func (a *anyArrayArrayReflect) ScanIndex(i int) any {
   947  	return a.array.Index(i).Addr().Interface()
   948  }
   949  
   950  func (a *anyArrayArrayReflect) ScanIndexType() any {
   951  	return reflect.New(a.array.Type().Elem()).Interface()
   952  }
   953  

View as plain text