...

Source file src/github.com/jackc/pgx/v5/pgtype/array.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"strconv"
     9  	"strings"
    10  	"unicode"
    11  
    12  	"github.com/jackc/pgx/v5/internal/pgio"
    13  )
    14  
    15  // Information on the internals of PostgreSQL arrays can be found in
    16  // src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of
    17  // particular interest is the array_send function.
    18  
    19  type arrayHeader struct {
    20  	ContainsNull bool
    21  	ElementOID   uint32
    22  	Dimensions   []ArrayDimension
    23  }
    24  
    25  type ArrayDimension struct {
    26  	Length     int32
    27  	LowerBound int32
    28  }
    29  
    30  // cardinality returns the number of elements in an array of dimensions size.
    31  func cardinality(dimensions []ArrayDimension) int {
    32  	if len(dimensions) == 0 {
    33  		return 0
    34  	}
    35  
    36  	elementCount := int(dimensions[0].Length)
    37  	for _, d := range dimensions[1:] {
    38  		elementCount *= int(d.Length)
    39  	}
    40  
    41  	return elementCount
    42  }
    43  
    44  func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) {
    45  	if len(src) < 12 {
    46  		return 0, fmt.Errorf("array header too short: %d", len(src))
    47  	}
    48  
    49  	rp := 0
    50  
    51  	numDims := int(binary.BigEndian.Uint32(src[rp:]))
    52  	rp += 4
    53  
    54  	dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1
    55  	rp += 4
    56  
    57  	dst.ElementOID = binary.BigEndian.Uint32(src[rp:])
    58  	rp += 4
    59  
    60  	dst.Dimensions = make([]ArrayDimension, numDims)
    61  	if len(src) < 12+numDims*8 {
    62  		return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src))
    63  	}
    64  	for i := range dst.Dimensions {
    65  		dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:]))
    66  		rp += 4
    67  
    68  		dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:]))
    69  		rp += 4
    70  	}
    71  
    72  	return rp, nil
    73  }
    74  
    75  func (src arrayHeader) EncodeBinary(buf []byte) []byte {
    76  	buf = pgio.AppendInt32(buf, int32(len(src.Dimensions)))
    77  
    78  	var containsNull int32
    79  	if src.ContainsNull {
    80  		containsNull = 1
    81  	}
    82  	buf = pgio.AppendInt32(buf, containsNull)
    83  
    84  	buf = pgio.AppendUint32(buf, src.ElementOID)
    85  
    86  	for i := range src.Dimensions {
    87  		buf = pgio.AppendInt32(buf, src.Dimensions[i].Length)
    88  		buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound)
    89  	}
    90  
    91  	return buf
    92  }
    93  
    94  type untypedTextArray struct {
    95  	Elements   []string
    96  	Quoted     []bool
    97  	Dimensions []ArrayDimension
    98  }
    99  
   100  func parseUntypedTextArray(src string) (*untypedTextArray, error) {
   101  	dst := &untypedTextArray{
   102  		Elements:   []string{},
   103  		Quoted:     []bool{},
   104  		Dimensions: []ArrayDimension{},
   105  	}
   106  
   107  	buf := bytes.NewBufferString(src)
   108  
   109  	skipWhitespace(buf)
   110  
   111  	r, _, err := buf.ReadRune()
   112  	if err != nil {
   113  		return nil, fmt.Errorf("invalid array: %w", err)
   114  	}
   115  
   116  	var explicitDimensions []ArrayDimension
   117  
   118  	// Array has explicit dimensions
   119  	if r == '[' {
   120  		buf.UnreadRune()
   121  
   122  		for {
   123  			r, _, err = buf.ReadRune()
   124  			if err != nil {
   125  				return nil, fmt.Errorf("invalid array: %w", err)
   126  			}
   127  
   128  			if r == '=' {
   129  				break
   130  			} else if r != '[' {
   131  				return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r)
   132  			}
   133  
   134  			lower, err := arrayParseInteger(buf)
   135  			if err != nil {
   136  				return nil, fmt.Errorf("invalid array: %w", err)
   137  			}
   138  
   139  			r, _, err = buf.ReadRune()
   140  			if err != nil {
   141  				return nil, fmt.Errorf("invalid array: %w", err)
   142  			}
   143  
   144  			if r != ':' {
   145  				return nil, fmt.Errorf("invalid array, expected ':' got %v", r)
   146  			}
   147  
   148  			upper, err := arrayParseInteger(buf)
   149  			if err != nil {
   150  				return nil, fmt.Errorf("invalid array: %w", err)
   151  			}
   152  
   153  			r, _, err = buf.ReadRune()
   154  			if err != nil {
   155  				return nil, fmt.Errorf("invalid array: %w", err)
   156  			}
   157  
   158  			if r != ']' {
   159  				return nil, fmt.Errorf("invalid array, expected ']' got %v", r)
   160  			}
   161  
   162  			explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1})
   163  		}
   164  
   165  		r, _, err = buf.ReadRune()
   166  		if err != nil {
   167  			return nil, fmt.Errorf("invalid array: %w", err)
   168  		}
   169  	}
   170  
   171  	if r != '{' {
   172  		return nil, fmt.Errorf("invalid array, expected '{' got %v", r)
   173  	}
   174  
   175  	implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}
   176  
   177  	// Consume all initial opening brackets. This provides number of dimensions.
   178  	for {
   179  		r, _, err = buf.ReadRune()
   180  		if err != nil {
   181  			return nil, fmt.Errorf("invalid array: %w", err)
   182  		}
   183  
   184  		if r == '{' {
   185  			implicitDimensions[len(implicitDimensions)-1].Length = 1
   186  			implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1})
   187  		} else {
   188  			buf.UnreadRune()
   189  			break
   190  		}
   191  	}
   192  	currentDim := len(implicitDimensions) - 1
   193  	counterDim := currentDim
   194  
   195  	for {
   196  		r, _, err = buf.ReadRune()
   197  		if err != nil {
   198  			return nil, fmt.Errorf("invalid array: %w", err)
   199  		}
   200  
   201  		switch r {
   202  		case '{':
   203  			if currentDim == counterDim {
   204  				implicitDimensions[currentDim].Length++
   205  			}
   206  			currentDim++
   207  		case ',':
   208  		case '}':
   209  			currentDim--
   210  			if currentDim < counterDim {
   211  				counterDim = currentDim
   212  			}
   213  		default:
   214  			buf.UnreadRune()
   215  			value, quoted, err := arrayParseValue(buf)
   216  			if err != nil {
   217  				return nil, fmt.Errorf("invalid array value: %w", err)
   218  			}
   219  			if currentDim == counterDim {
   220  				implicitDimensions[currentDim].Length++
   221  			}
   222  			dst.Quoted = append(dst.Quoted, quoted)
   223  			dst.Elements = append(dst.Elements, value)
   224  		}
   225  
   226  		if currentDim < 0 {
   227  			break
   228  		}
   229  	}
   230  
   231  	skipWhitespace(buf)
   232  
   233  	if buf.Len() > 0 {
   234  		return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
   235  	}
   236  
   237  	if len(dst.Elements) == 0 {
   238  	} else if len(explicitDimensions) > 0 {
   239  		dst.Dimensions = explicitDimensions
   240  	} else {
   241  		dst.Dimensions = implicitDimensions
   242  	}
   243  
   244  	return dst, nil
   245  }
   246  
   247  func skipWhitespace(buf *bytes.Buffer) {
   248  	var r rune
   249  	var err error
   250  	for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() {
   251  	}
   252  
   253  	if err != io.EOF {
   254  		buf.UnreadRune()
   255  	}
   256  }
   257  
   258  func arrayParseValue(buf *bytes.Buffer) (string, bool, error) {
   259  	r, _, err := buf.ReadRune()
   260  	if err != nil {
   261  		return "", false, err
   262  	}
   263  	if r == '"' {
   264  		return arrayParseQuotedValue(buf)
   265  	}
   266  	buf.UnreadRune()
   267  
   268  	s := &bytes.Buffer{}
   269  
   270  	for {
   271  		r, _, err := buf.ReadRune()
   272  		if err != nil {
   273  			return "", false, err
   274  		}
   275  
   276  		switch r {
   277  		case ',', '}':
   278  			buf.UnreadRune()
   279  			return s.String(), false, nil
   280  		}
   281  
   282  		s.WriteRune(r)
   283  	}
   284  }
   285  
   286  func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) {
   287  	s := &bytes.Buffer{}
   288  
   289  	for {
   290  		r, _, err := buf.ReadRune()
   291  		if err != nil {
   292  			return "", false, err
   293  		}
   294  
   295  		switch r {
   296  		case '\\':
   297  			r, _, err = buf.ReadRune()
   298  			if err != nil {
   299  				return "", false, err
   300  			}
   301  		case '"':
   302  			r, _, err = buf.ReadRune()
   303  			if err != nil {
   304  				return "", false, err
   305  			}
   306  			buf.UnreadRune()
   307  			return s.String(), true, nil
   308  		}
   309  		s.WriteRune(r)
   310  	}
   311  }
   312  
   313  func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
   314  	s := &bytes.Buffer{}
   315  
   316  	for {
   317  		r, _, err := buf.ReadRune()
   318  		if err != nil {
   319  			return 0, err
   320  		}
   321  
   322  		if ('0' <= r && r <= '9') || r == '-' {
   323  			s.WriteRune(r)
   324  		} else {
   325  			buf.UnreadRune()
   326  			n, err := strconv.ParseInt(s.String(), 10, 32)
   327  			if err != nil {
   328  				return 0, err
   329  			}
   330  			return int32(n), nil
   331  		}
   332  	}
   333  }
   334  
   335  func encodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte {
   336  	var customDimensions bool
   337  	for _, dim := range dimensions {
   338  		if dim.LowerBound != 1 {
   339  			customDimensions = true
   340  		}
   341  	}
   342  
   343  	if !customDimensions {
   344  		return buf
   345  	}
   346  
   347  	for _, dim := range dimensions {
   348  		buf = append(buf, '[')
   349  		buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...)
   350  		buf = append(buf, ':')
   351  		buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...)
   352  		buf = append(buf, ']')
   353  	}
   354  
   355  	return append(buf, '=')
   356  }
   357  
   358  var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
   359  
   360  func quoteArrayElement(src string) string {
   361  	return `"` + quoteArrayReplacer.Replace(src) + `"`
   362  }
   363  
   364  func isSpace(ch byte) bool {
   365  	// see array_isspace:
   366  	// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c
   367  	return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f'
   368  }
   369  
   370  func quoteArrayElementIfNeeded(src string) string {
   371  	if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) {
   372  		return quoteArrayElement(src)
   373  	}
   374  	return src
   375  }
   376  
   377  // Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves
   378  // PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed.
   379  type Array[T any] struct {
   380  	Elements []T
   381  	Dims     []ArrayDimension
   382  	Valid    bool
   383  }
   384  
   385  func (a Array[T]) Dimensions() []ArrayDimension {
   386  	return a.Dims
   387  }
   388  
   389  func (a Array[T]) Index(i int) any {
   390  	return a.Elements[i]
   391  }
   392  
   393  func (a Array[T]) IndexType() any {
   394  	var el T
   395  	return el
   396  }
   397  
   398  func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error {
   399  	if dimensions == nil {
   400  		*a = Array[T]{}
   401  		return nil
   402  	}
   403  
   404  	elementCount := cardinality(dimensions)
   405  	*a = Array[T]{
   406  		Elements: make([]T, elementCount),
   407  		Dims:     dimensions,
   408  		Valid:    true,
   409  	}
   410  
   411  	return nil
   412  }
   413  
   414  func (a Array[T]) ScanIndex(i int) any {
   415  	return &a.Elements[i]
   416  }
   417  
   418  func (a Array[T]) ScanIndexType() any {
   419  	return new(T)
   420  }
   421  
   422  // FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions
   423  // and custom lower bounds. Use Array to preserve these.
   424  type FlatArray[T any] []T
   425  
   426  func (a FlatArray[T]) Dimensions() []ArrayDimension {
   427  	if a == nil {
   428  		return nil
   429  	}
   430  
   431  	return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}}
   432  }
   433  
   434  func (a FlatArray[T]) Index(i int) any {
   435  	return a[i]
   436  }
   437  
   438  func (a FlatArray[T]) IndexType() any {
   439  	var el T
   440  	return el
   441  }
   442  
   443  func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error {
   444  	if dimensions == nil {
   445  		*a = nil
   446  		return nil
   447  	}
   448  
   449  	elementCount := cardinality(dimensions)
   450  	*a = make(FlatArray[T], elementCount)
   451  	return nil
   452  }
   453  
   454  func (a FlatArray[T]) ScanIndex(i int) any {
   455  	return &a[i]
   456  }
   457  
   458  func (a FlatArray[T]) ScanIndexType() any {
   459  	return new(T)
   460  }
   461  

View as plain text