...

Source file src/github.com/jackc/pgx/v5/pgtype/multirange.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql/driver"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"reflect"
     9  
    10  	"github.com/jackc/pgx/v5/internal/pgio"
    11  )
    12  
    13  // MultirangeGetter is a type that can be converted into a PostgreSQL multirange.
    14  type MultirangeGetter interface {
    15  	// IsNull returns true if the value is SQL NULL.
    16  	IsNull() bool
    17  
    18  	// Len returns the number of elements in the multirange.
    19  	Len() int
    20  
    21  	// Index returns the element at i.
    22  	Index(i int) any
    23  
    24  	// IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode.
    25  	IndexType() any
    26  }
    27  
    28  // MultirangeSetter is a type can be set from a PostgreSQL multirange.
    29  type MultirangeSetter interface {
    30  	// ScanNull sets the value to SQL NULL.
    31  	ScanNull() error
    32  
    33  	// SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing
    34  	// elements.
    35  	SetLen(n int) error
    36  
    37  	// ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex.
    38  	ScanIndex(i int) any
    39  
    40  	// ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by
    41  	// MultirangeCodec.PlanScan.
    42  	ScanIndexType() any
    43  }
    44  
    45  // MultirangeCodec is a codec for any multirange type.
    46  type MultirangeCodec struct {
    47  	ElementType *Type
    48  }
    49  
    50  func (c *MultirangeCodec) FormatSupported(format int16) bool {
    51  	return c.ElementType.Codec.FormatSupported(format)
    52  }
    53  
    54  func (c *MultirangeCodec) PreferredFormat() int16 {
    55  	return c.ElementType.Codec.PreferredFormat()
    56  }
    57  
    58  func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
    59  	multirangeValuer, ok := value.(MultirangeGetter)
    60  	if !ok {
    61  		return nil
    62  	}
    63  
    64  	elementType := multirangeValuer.IndexType()
    65  
    66  	elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
    67  	if elementEncodePlan == nil {
    68  		return nil
    69  	}
    70  
    71  	switch format {
    72  	case BinaryFormatCode:
    73  		return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid}
    74  	case TextFormatCode:
    75  		return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid}
    76  	}
    77  
    78  	return nil
    79  }
    80  
    81  type encodePlanMultirangeCodecText struct {
    82  	ac  *MultirangeCodec
    83  	m   *Map
    84  	oid uint32
    85  }
    86  
    87  func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
    88  	multirange := value.(MultirangeGetter)
    89  
    90  	if multirange.IsNull() {
    91  		return nil, nil
    92  	}
    93  
    94  	elementCount := multirange.Len()
    95  
    96  	buf = append(buf, '{')
    97  
    98  	var encodePlan EncodePlan
    99  	var lastElemType reflect.Type
   100  	inElemBuf := make([]byte, 0, 32)
   101  	for i := 0; i < elementCount; i++ {
   102  		if i > 0 {
   103  			buf = append(buf, ',')
   104  		}
   105  
   106  		elem := multirange.Index(i)
   107  		var elemBuf []byte
   108  		if elem != nil {
   109  			elemType := reflect.TypeOf(elem)
   110  			if lastElemType != elemType {
   111  				lastElemType = elemType
   112  				encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem)
   113  				if encodePlan == nil {
   114  					return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
   115  				}
   116  			}
   117  			elemBuf, err = encodePlan.Encode(elem, inElemBuf)
   118  			if err != nil {
   119  				return nil, err
   120  			}
   121  		}
   122  
   123  		if elemBuf == nil {
   124  			return nil, fmt.Errorf("multirange cannot contain NULL element")
   125  		} else {
   126  			buf = append(buf, elemBuf...)
   127  		}
   128  	}
   129  
   130  	buf = append(buf, '}')
   131  
   132  	return buf, nil
   133  }
   134  
   135  type encodePlanMultirangeCodecBinary struct {
   136  	ac  *MultirangeCodec
   137  	m   *Map
   138  	oid uint32
   139  }
   140  
   141  func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
   142  	multirange := value.(MultirangeGetter)
   143  
   144  	if multirange.IsNull() {
   145  		return nil, nil
   146  	}
   147  
   148  	elementCount := multirange.Len()
   149  
   150  	buf = pgio.AppendInt32(buf, int32(elementCount))
   151  
   152  	var encodePlan EncodePlan
   153  	var lastElemType reflect.Type
   154  	for i := 0; i < elementCount; i++ {
   155  		sp := len(buf)
   156  		buf = pgio.AppendInt32(buf, -1)
   157  
   158  		elem := multirange.Index(i)
   159  		var elemBuf []byte
   160  		if elem != nil {
   161  			elemType := reflect.TypeOf(elem)
   162  			if lastElemType != elemType {
   163  				lastElemType = elemType
   164  				encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem)
   165  				if encodePlan == nil {
   166  					return nil, fmt.Errorf("unable to encode %v", multirange.Index(i))
   167  				}
   168  			}
   169  			elemBuf, err = encodePlan.Encode(elem, buf)
   170  			if err != nil {
   171  				return nil, err
   172  			}
   173  		}
   174  
   175  		if elemBuf == nil {
   176  			return nil, fmt.Errorf("multirange cannot contain NULL element")
   177  		} else {
   178  			buf = elemBuf
   179  			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   180  		}
   181  	}
   182  
   183  	return buf, nil
   184  }
   185  
   186  func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
   187  	multirangeScanner, ok := target.(MultirangeSetter)
   188  	if !ok {
   189  		return nil
   190  	}
   191  
   192  	elementType := multirangeScanner.ScanIndexType()
   193  
   194  	elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType)
   195  	if _, ok := elementScanPlan.(*scanPlanFail); ok {
   196  		return nil
   197  	}
   198  
   199  	return &scanPlanMultirangeCodec{
   200  		multirangeCodec: c,
   201  		m:               m,
   202  		oid:             oid,
   203  		formatCode:      format,
   204  	}
   205  }
   206  
   207  func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
   208  	rp := 0
   209  
   210  	elementCount := int(binary.BigEndian.Uint32(src[rp:]))
   211  	rp += 4
   212  
   213  	err := multirange.SetLen(elementCount)
   214  	if err != nil {
   215  		return err
   216  	}
   217  
   218  	if elementCount == 0 {
   219  		return nil
   220  	}
   221  
   222  	elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
   223  	if elementScanPlan == nil {
   224  		elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0))
   225  	}
   226  
   227  	for i := 0; i < elementCount; i++ {
   228  		elem := multirange.ScanIndex(i)
   229  		elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
   230  		rp += 4
   231  		var elemSrc []byte
   232  		if elemLen >= 0 {
   233  			elemSrc = src[rp : rp+elemLen]
   234  			rp += elemLen
   235  		}
   236  		err = elementScanPlan.Scan(elemSrc, elem)
   237  		if err != nil {
   238  			return fmt.Errorf("failed to scan multirange element %d: %w", i, err)
   239  		}
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error {
   246  	elements, err := parseUntypedTextMultirange(src)
   247  	if err != nil {
   248  		return err
   249  	}
   250  
   251  	err = multirange.SetLen(len(elements))
   252  	if err != nil {
   253  		return err
   254  	}
   255  
   256  	if len(elements) == 0 {
   257  		return nil
   258  	}
   259  
   260  	elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
   261  	if elementScanPlan == nil {
   262  		elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0))
   263  	}
   264  
   265  	for i, s := range elements {
   266  		elem := multirange.ScanIndex(i)
   267  		err = elementScanPlan.Scan([]byte(s), elem)
   268  		if err != nil {
   269  			return err
   270  		}
   271  	}
   272  
   273  	return nil
   274  }
   275  
   276  type scanPlanMultirangeCodec struct {
   277  	multirangeCodec *MultirangeCodec
   278  	m               *Map
   279  	oid             uint32
   280  	formatCode      int16
   281  	elementScanPlan ScanPlan
   282  }
   283  
   284  func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error {
   285  	c := spac.multirangeCodec
   286  	m := spac.m
   287  	oid := spac.oid
   288  	formatCode := spac.formatCode
   289  
   290  	multirange := dst.(MultirangeSetter)
   291  
   292  	if src == nil {
   293  		return multirange.ScanNull()
   294  	}
   295  
   296  	switch formatCode {
   297  	case BinaryFormatCode:
   298  		return c.decodeBinary(m, oid, src, multirange)
   299  	case TextFormatCode:
   300  		return c.decodeText(m, oid, src, multirange)
   301  	default:
   302  		return fmt.Errorf("unknown format code %d", formatCode)
   303  	}
   304  }
   305  
   306  func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
   307  	if src == nil {
   308  		return nil, nil
   309  	}
   310  
   311  	switch format {
   312  	case TextFormatCode:
   313  		return string(src), nil
   314  	case BinaryFormatCode:
   315  		buf := make([]byte, len(src))
   316  		copy(buf, src)
   317  		return buf, nil
   318  	default:
   319  		return nil, fmt.Errorf("unknown format code %d", format)
   320  	}
   321  }
   322  
   323  func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
   324  	if src == nil {
   325  		return nil, nil
   326  	}
   327  
   328  	var multirange Multirange[Range[any]]
   329  	err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange)
   330  	return multirange, err
   331  }
   332  
   333  func parseUntypedTextMultirange(src []byte) ([]string, error) {
   334  	elements := make([]string, 0)
   335  
   336  	buf := bytes.NewBuffer(src)
   337  
   338  	skipWhitespace(buf)
   339  
   340  	r, _, err := buf.ReadRune()
   341  	if err != nil {
   342  		return nil, fmt.Errorf("invalid array: %w", err)
   343  	}
   344  
   345  	if r != '{' {
   346  		return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r)
   347  	}
   348  
   349  parseValueLoop:
   350  	for {
   351  		r, _, err = buf.ReadRune()
   352  		if err != nil {
   353  			return nil, fmt.Errorf("invalid multirange: %w", err)
   354  		}
   355  
   356  		switch r {
   357  		case ',': // skip range separator
   358  		case '}':
   359  			break parseValueLoop
   360  		default:
   361  			buf.UnreadRune()
   362  			value, err := parseRange(buf)
   363  			if err != nil {
   364  				return nil, fmt.Errorf("invalid multirange value: %w", err)
   365  			}
   366  			elements = append(elements, value)
   367  		}
   368  	}
   369  
   370  	skipWhitespace(buf)
   371  
   372  	if buf.Len() > 0 {
   373  		return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
   374  	}
   375  
   376  	return elements, nil
   377  
   378  }
   379  
   380  func parseRange(buf *bytes.Buffer) (string, error) {
   381  	s := &bytes.Buffer{}
   382  
   383  	boundSepRead := false
   384  	for {
   385  		r, _, err := buf.ReadRune()
   386  		if err != nil {
   387  			return "", err
   388  		}
   389  
   390  		switch r {
   391  		case ',', '}':
   392  			if r == ',' && !boundSepRead {
   393  				boundSepRead = true
   394  				break
   395  			}
   396  			buf.UnreadRune()
   397  			return s.String(), nil
   398  		}
   399  
   400  		s.WriteRune(r)
   401  	}
   402  }
   403  
   404  // Multirange is a generic multirange type.
   405  //
   406  // T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
   407  // enforce the RangeScanner constraint.
   408  type Multirange[T RangeValuer] []T
   409  
   410  func (r Multirange[T]) IsNull() bool {
   411  	return r == nil
   412  }
   413  
   414  func (r Multirange[T]) Len() int {
   415  	return len(r)
   416  }
   417  
   418  func (r Multirange[T]) Index(i int) any {
   419  	return r[i]
   420  }
   421  
   422  func (r Multirange[T]) IndexType() any {
   423  	var zero T
   424  	return zero
   425  }
   426  
   427  func (r *Multirange[T]) ScanNull() error {
   428  	*r = nil
   429  	return nil
   430  }
   431  
   432  func (r *Multirange[T]) SetLen(n int) error {
   433  	*r = make([]T, n)
   434  	return nil
   435  }
   436  
   437  func (r Multirange[T]) ScanIndex(i int) any {
   438  	return &r[i]
   439  }
   440  
   441  func (r Multirange[T]) ScanIndexType() any {
   442  	return new(T)
   443  }
   444  

View as plain text