...

Source file src/github.com/jackc/pgx/v5/pgtype/range_codec.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"fmt"
     6  
     7  	"github.com/jackc/pgx/v5/internal/pgio"
     8  )
     9  
    10  // RangeValuer is a type that can be converted into a PostgreSQL range.
    11  type RangeValuer interface {
    12  	// IsNull returns true if the value is SQL NULL.
    13  	IsNull() bool
    14  
    15  	// BoundTypes returns the lower and upper bound types.
    16  	BoundTypes() (lower, upper BoundType)
    17  
    18  	// Bounds returns the lower and upper range values.
    19  	Bounds() (lower, upper any)
    20  }
    21  
    22  // RangeScanner is a type can be scanned from a PostgreSQL range.
    23  type RangeScanner interface {
    24  	// ScanNull sets the value to SQL NULL.
    25  	ScanNull() error
    26  
    27  	// ScanBounds returns values usable as a scan target. The returned values may not be scanned if the range is empty or
    28  	// the bound type is unbounded.
    29  	ScanBounds() (lowerTarget, upperTarget any)
    30  
    31  	// SetBoundTypes sets the lower and upper bound types. ScanBounds will be called and the returned values scanned
    32  	// (if appropriate) before SetBoundTypes is called. If the bound types are unbounded or empty this method must
    33  	// also set the bound values.
    34  	SetBoundTypes(lower, upper BoundType) error
    35  }
    36  
    37  // RangeCodec is a codec for any range type.
    38  type RangeCodec struct {
    39  	ElementType *Type
    40  }
    41  
    42  func (c *RangeCodec) FormatSupported(format int16) bool {
    43  	return c.ElementType.Codec.FormatSupported(format)
    44  }
    45  
    46  func (c *RangeCodec) PreferredFormat() int16 {
    47  	if c.FormatSupported(BinaryFormatCode) {
    48  		return BinaryFormatCode
    49  	}
    50  	return TextFormatCode
    51  }
    52  
    53  func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
    54  	if _, ok := value.(RangeValuer); !ok {
    55  		return nil
    56  	}
    57  
    58  	switch format {
    59  	case BinaryFormatCode:
    60  		return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m}
    61  	case TextFormatCode:
    62  		return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m}
    63  	}
    64  
    65  	return nil
    66  }
    67  
    68  type encodePlanRangeCodecRangeValuerToBinary struct {
    69  	rc *RangeCodec
    70  	m  *Map
    71  }
    72  
    73  func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
    74  	getter := value.(RangeValuer)
    75  
    76  	if getter.IsNull() {
    77  		return nil, nil
    78  	}
    79  
    80  	lowerType, upperType := getter.BoundTypes()
    81  	lower, upper := getter.Bounds()
    82  
    83  	var rangeType byte
    84  	switch lowerType {
    85  	case Inclusive:
    86  		rangeType |= lowerInclusiveMask
    87  	case Unbounded:
    88  		rangeType |= lowerUnboundedMask
    89  	case Exclusive:
    90  	case Empty:
    91  		return append(buf, emptyMask), nil
    92  	default:
    93  		return nil, fmt.Errorf("unknown LowerType: %v", lowerType)
    94  	}
    95  
    96  	switch upperType {
    97  	case Inclusive:
    98  		rangeType |= upperInclusiveMask
    99  	case Unbounded:
   100  		rangeType |= upperUnboundedMask
   101  	case Exclusive:
   102  	default:
   103  		return nil, fmt.Errorf("unknown UpperType: %v", upperType)
   104  	}
   105  
   106  	buf = append(buf, rangeType)
   107  
   108  	if lowerType != Unbounded {
   109  		if lower == nil {
   110  			return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
   111  		}
   112  
   113  		sp := len(buf)
   114  		buf = pgio.AppendInt32(buf, -1)
   115  
   116  		lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower)
   117  		if lowerPlan == nil {
   118  			return nil, fmt.Errorf("cannot encode %v as element of range", lower)
   119  		}
   120  
   121  		buf, err = lowerPlan.Encode(lower, buf)
   122  		if err != nil {
   123  			return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
   124  		}
   125  		if buf == nil {
   126  			return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
   127  		}
   128  
   129  		pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   130  	}
   131  
   132  	if upperType != Unbounded {
   133  		if upper == nil {
   134  			return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
   135  		}
   136  
   137  		sp := len(buf)
   138  		buf = pgio.AppendInt32(buf, -1)
   139  
   140  		upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper)
   141  		if upperPlan == nil {
   142  			return nil, fmt.Errorf("cannot encode %v as element of range", upper)
   143  		}
   144  
   145  		buf, err = upperPlan.Encode(upper, buf)
   146  		if err != nil {
   147  			return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
   148  		}
   149  		if buf == nil {
   150  			return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
   151  		}
   152  
   153  		pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   154  	}
   155  
   156  	return buf, nil
   157  }
   158  
   159  type encodePlanRangeCodecRangeValuerToText struct {
   160  	rc *RangeCodec
   161  	m  *Map
   162  }
   163  
   164  func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
   165  	getter := value.(RangeValuer)
   166  
   167  	if getter.IsNull() {
   168  		return nil, nil
   169  	}
   170  
   171  	lowerType, upperType := getter.BoundTypes()
   172  	lower, upper := getter.Bounds()
   173  
   174  	switch lowerType {
   175  	case Exclusive, Unbounded:
   176  		buf = append(buf, '(')
   177  	case Inclusive:
   178  		buf = append(buf, '[')
   179  	case Empty:
   180  		return append(buf, "empty"...), nil
   181  	default:
   182  		return nil, fmt.Errorf("unknown lower bound type %v", lowerType)
   183  	}
   184  
   185  	if lowerType != Unbounded {
   186  		if lower == nil {
   187  			return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
   188  		}
   189  
   190  		lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower)
   191  		if lowerPlan == nil {
   192  			return nil, fmt.Errorf("cannot encode %v as element of range", lower)
   193  		}
   194  
   195  		buf, err = lowerPlan.Encode(lower, buf)
   196  		if err != nil {
   197  			return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
   198  		}
   199  		if buf == nil {
   200  			return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
   201  		}
   202  	}
   203  
   204  	buf = append(buf, ',')
   205  
   206  	if upperType != Unbounded {
   207  		if upper == nil {
   208  			return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
   209  		}
   210  
   211  		upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper)
   212  		if upperPlan == nil {
   213  			return nil, fmt.Errorf("cannot encode %v as element of range", upper)
   214  		}
   215  
   216  		buf, err = upperPlan.Encode(upper, buf)
   217  		if err != nil {
   218  			return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
   219  		}
   220  		if buf == nil {
   221  			return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
   222  		}
   223  	}
   224  
   225  	switch upperType {
   226  	case Exclusive, Unbounded:
   227  		buf = append(buf, ')')
   228  	case Inclusive:
   229  		buf = append(buf, ']')
   230  	default:
   231  		return nil, fmt.Errorf("unknown upper bound type %v", upperType)
   232  	}
   233  
   234  	return buf, nil
   235  }
   236  
   237  func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
   238  	switch format {
   239  	case BinaryFormatCode:
   240  		switch target.(type) {
   241  		case RangeScanner:
   242  			return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m}
   243  		}
   244  	case TextFormatCode:
   245  		switch target.(type) {
   246  		case RangeScanner:
   247  			return &scanPlanTextRangeToRangeScanner{rc: c, m: m}
   248  		}
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  type scanPlanBinaryRangeToRangeScanner struct {
   255  	rc *RangeCodec
   256  	m  *Map
   257  }
   258  
   259  func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error {
   260  	rangeScanner := (target).(RangeScanner)
   261  
   262  	if src == nil {
   263  		return rangeScanner.ScanNull()
   264  	}
   265  
   266  	ubr, err := parseUntypedBinaryRange(src)
   267  	if err != nil {
   268  		return err
   269  	}
   270  
   271  	if ubr.LowerType == Empty {
   272  		return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
   273  	}
   274  
   275  	lowerTarget, upperTarget := rangeScanner.ScanBounds()
   276  
   277  	if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive {
   278  		lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget)
   279  		if lowerPlan == nil {
   280  			return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
   281  		}
   282  
   283  		err = lowerPlan.Scan(ubr.Lower, lowerTarget)
   284  		if err != nil {
   285  			return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
   286  		}
   287  	}
   288  
   289  	if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive {
   290  		upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget)
   291  		if upperPlan == nil {
   292  			return fmt.Errorf("cannot scan into %v from range element", upperTarget)
   293  		}
   294  
   295  		err = upperPlan.Scan(ubr.Upper, upperTarget)
   296  		if err != nil {
   297  			return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
   298  		}
   299  	}
   300  
   301  	return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
   302  }
   303  
   304  type scanPlanTextRangeToRangeScanner struct {
   305  	rc *RangeCodec
   306  	m  *Map
   307  }
   308  
   309  func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error {
   310  	rangeScanner := (target).(RangeScanner)
   311  
   312  	if src == nil {
   313  		return rangeScanner.ScanNull()
   314  	}
   315  
   316  	utr, err := parseUntypedTextRange(string(src))
   317  	if err != nil {
   318  		return err
   319  	}
   320  
   321  	if utr.LowerType == Empty {
   322  		return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
   323  	}
   324  
   325  	lowerTarget, upperTarget := rangeScanner.ScanBounds()
   326  
   327  	if utr.LowerType == Inclusive || utr.LowerType == Exclusive {
   328  		lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget)
   329  		if lowerPlan == nil {
   330  			return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
   331  		}
   332  
   333  		err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget)
   334  		if err != nil {
   335  			return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
   336  		}
   337  	}
   338  
   339  	if utr.UpperType == Inclusive || utr.UpperType == Exclusive {
   340  		upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget)
   341  		if upperPlan == nil {
   342  			return fmt.Errorf("cannot scan into %v from range element", upperTarget)
   343  		}
   344  
   345  		err = upperPlan.Scan([]byte(utr.Upper), upperTarget)
   346  		if err != nil {
   347  			return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
   348  		}
   349  	}
   350  
   351  	return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
   352  }
   353  
   354  func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
   355  	if src == nil {
   356  		return nil, nil
   357  	}
   358  
   359  	switch format {
   360  	case TextFormatCode:
   361  		return string(src), nil
   362  	case BinaryFormatCode:
   363  		buf := make([]byte, len(src))
   364  		copy(buf, src)
   365  		return buf, nil
   366  	default:
   367  		return nil, fmt.Errorf("unknown format code %d", format)
   368  	}
   369  }
   370  
   371  func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
   372  	if src == nil {
   373  		return nil, nil
   374  	}
   375  
   376  	var r Range[any]
   377  	err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
   378  	return r, err
   379  }
   380  

View as plain text