...

Source file src/github.com/jackc/pgtype/array.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"unicode"
    12  
    13  	"github.com/jackc/pgio"
    14  )
    15  
    16  // Information on the internals of PostgreSQL arrays can be found in
    17  // src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of
    18  // particular interest is the array_send function.
    19  
    20  type ArrayHeader struct {
    21  	ContainsNull bool
    22  	ElementOID   int32
    23  	Dimensions   []ArrayDimension
    24  }
    25  
    26  type ArrayDimension struct {
    27  	Length     int32
    28  	LowerBound int32
    29  }
    30  
    31  func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) {
    32  	if len(src) < 12 {
    33  		return 0, fmt.Errorf("array header too short: %d", len(src))
    34  	}
    35  
    36  	rp := 0
    37  
    38  	numDims := int(binary.BigEndian.Uint32(src[rp:]))
    39  	rp += 4
    40  
    41  	dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1
    42  	rp += 4
    43  
    44  	dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:]))
    45  	rp += 4
    46  
    47  	if numDims > 0 {
    48  		dst.Dimensions = make([]ArrayDimension, numDims)
    49  	}
    50  	if len(src) < 12+numDims*8 {
    51  		return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src))
    52  	}
    53  	for i := range dst.Dimensions {
    54  		dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:]))
    55  		rp += 4
    56  
    57  		dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:]))
    58  		rp += 4
    59  	}
    60  
    61  	return rp, nil
    62  }
    63  
    64  func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte {
    65  	buf = pgio.AppendInt32(buf, int32(len(src.Dimensions)))
    66  
    67  	var containsNull int32
    68  	if src.ContainsNull {
    69  		containsNull = 1
    70  	}
    71  	buf = pgio.AppendInt32(buf, containsNull)
    72  
    73  	buf = pgio.AppendInt32(buf, src.ElementOID)
    74  
    75  	for i := range src.Dimensions {
    76  		buf = pgio.AppendInt32(buf, src.Dimensions[i].Length)
    77  		buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound)
    78  	}
    79  
    80  	return buf
    81  }
    82  
    83  type UntypedTextArray struct {
    84  	Elements   []string
    85  	Quoted     []bool
    86  	Dimensions []ArrayDimension
    87  }
    88  
    89  func ParseUntypedTextArray(src string) (*UntypedTextArray, error) {
    90  	dst := &UntypedTextArray{}
    91  
    92  	buf := bytes.NewBufferString(src)
    93  
    94  	skipWhitespace(buf)
    95  
    96  	r, _, err := buf.ReadRune()
    97  	if err != nil {
    98  		return nil, fmt.Errorf("invalid array: %v", err)
    99  	}
   100  
   101  	var explicitDimensions []ArrayDimension
   102  
   103  	// Array has explicit dimensions
   104  	if r == '[' {
   105  		buf.UnreadRune()
   106  
   107  		for {
   108  			r, _, err = buf.ReadRune()
   109  			if err != nil {
   110  				return nil, fmt.Errorf("invalid array: %v", err)
   111  			}
   112  
   113  			if r == '=' {
   114  				break
   115  			} else if r != '[' {
   116  				return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r)
   117  			}
   118  
   119  			lower, err := arrayParseInteger(buf)
   120  			if err != nil {
   121  				return nil, fmt.Errorf("invalid array: %v", err)
   122  			}
   123  
   124  			r, _, err = buf.ReadRune()
   125  			if err != nil {
   126  				return nil, fmt.Errorf("invalid array: %v", err)
   127  			}
   128  
   129  			if r != ':' {
   130  				return nil, fmt.Errorf("invalid array, expected ':' got %v", r)
   131  			}
   132  
   133  			upper, err := arrayParseInteger(buf)
   134  			if err != nil {
   135  				return nil, fmt.Errorf("invalid array: %v", err)
   136  			}
   137  
   138  			r, _, err = buf.ReadRune()
   139  			if err != nil {
   140  				return nil, fmt.Errorf("invalid array: %v", err)
   141  			}
   142  
   143  			if r != ']' {
   144  				return nil, fmt.Errorf("invalid array, expected ']' got %v", r)
   145  			}
   146  
   147  			explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1})
   148  		}
   149  
   150  		r, _, err = buf.ReadRune()
   151  		if err != nil {
   152  			return nil, fmt.Errorf("invalid array: %v", err)
   153  		}
   154  	}
   155  
   156  	if r != '{' {
   157  		return nil, fmt.Errorf("invalid array, expected '{': %v", err)
   158  	}
   159  
   160  	implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}
   161  
   162  	// Consume all initial opening brackets. This provides number of dimensions.
   163  	for {
   164  		r, _, err = buf.ReadRune()
   165  		if err != nil {
   166  			return nil, fmt.Errorf("invalid array: %v", err)
   167  		}
   168  
   169  		if r == '{' {
   170  			implicitDimensions[len(implicitDimensions)-1].Length = 1
   171  			implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1})
   172  		} else {
   173  			buf.UnreadRune()
   174  			break
   175  		}
   176  	}
   177  	currentDim := len(implicitDimensions) - 1
   178  	counterDim := currentDim
   179  
   180  	for {
   181  		r, _, err = buf.ReadRune()
   182  		if err != nil {
   183  			return nil, fmt.Errorf("invalid array: %v", err)
   184  		}
   185  
   186  		switch r {
   187  		case '{':
   188  			if currentDim == counterDim {
   189  				implicitDimensions[currentDim].Length++
   190  			}
   191  			currentDim++
   192  		case ',':
   193  		case '}':
   194  			currentDim--
   195  			if currentDim < counterDim {
   196  				counterDim = currentDim
   197  			}
   198  		default:
   199  			buf.UnreadRune()
   200  			value, quoted, err := arrayParseValue(buf)
   201  			if err != nil {
   202  				return nil, fmt.Errorf("invalid array value: %v", err)
   203  			}
   204  			if currentDim == counterDim {
   205  				implicitDimensions[currentDim].Length++
   206  			}
   207  			dst.Quoted = append(dst.Quoted, quoted)
   208  			dst.Elements = append(dst.Elements, value)
   209  		}
   210  
   211  		if currentDim < 0 {
   212  			break
   213  		}
   214  	}
   215  
   216  	skipWhitespace(buf)
   217  
   218  	if buf.Len() > 0 {
   219  		return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
   220  	}
   221  
   222  	if len(dst.Elements) == 0 {
   223  		dst.Dimensions = nil
   224  	} else if len(explicitDimensions) > 0 {
   225  		dst.Dimensions = explicitDimensions
   226  	} else {
   227  		dst.Dimensions = implicitDimensions
   228  	}
   229  
   230  	return dst, nil
   231  }
   232  
   233  func skipWhitespace(buf *bytes.Buffer) {
   234  	var r rune
   235  	var err error
   236  	for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() {
   237  	}
   238  
   239  	if err != io.EOF {
   240  		buf.UnreadRune()
   241  	}
   242  }
   243  
   244  func arrayParseValue(buf *bytes.Buffer) (string, bool, error) {
   245  	r, _, err := buf.ReadRune()
   246  	if err != nil {
   247  		return "", false, err
   248  	}
   249  	if r == '"' {
   250  		return arrayParseQuotedValue(buf)
   251  	}
   252  	buf.UnreadRune()
   253  
   254  	s := &bytes.Buffer{}
   255  
   256  	for {
   257  		r, _, err := buf.ReadRune()
   258  		if err != nil {
   259  			return "", false, err
   260  		}
   261  
   262  		switch r {
   263  		case ',', '}':
   264  			buf.UnreadRune()
   265  			return s.String(), false, nil
   266  		}
   267  
   268  		s.WriteRune(r)
   269  	}
   270  }
   271  
   272  func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) {
   273  	s := &bytes.Buffer{}
   274  
   275  	for {
   276  		r, _, err := buf.ReadRune()
   277  		if err != nil {
   278  			return "", false, err
   279  		}
   280  
   281  		switch r {
   282  		case '\\':
   283  			r, _, err = buf.ReadRune()
   284  			if err != nil {
   285  				return "", false, err
   286  			}
   287  		case '"':
   288  			r, _, err = buf.ReadRune()
   289  			if err != nil {
   290  				return "", false, err
   291  			}
   292  			buf.UnreadRune()
   293  			return s.String(), true, nil
   294  		}
   295  		s.WriteRune(r)
   296  	}
   297  }
   298  
   299  func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
   300  	s := &bytes.Buffer{}
   301  
   302  	for {
   303  		r, _, err := buf.ReadRune()
   304  		if err != nil {
   305  			return 0, err
   306  		}
   307  
   308  		if ('0' <= r && r <= '9') || r == '-' {
   309  			s.WriteRune(r)
   310  		} else {
   311  			buf.UnreadRune()
   312  			n, err := strconv.ParseInt(s.String(), 10, 32)
   313  			if err != nil {
   314  				return 0, err
   315  			}
   316  			return int32(n), nil
   317  		}
   318  	}
   319  }
   320  
   321  func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte {
   322  	var customDimensions bool
   323  	for _, dim := range dimensions {
   324  		if dim.LowerBound != 1 {
   325  			customDimensions = true
   326  		}
   327  	}
   328  
   329  	if !customDimensions {
   330  		return buf
   331  	}
   332  
   333  	for _, dim := range dimensions {
   334  		buf = append(buf, '[')
   335  		buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...)
   336  		buf = append(buf, ':')
   337  		buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...)
   338  		buf = append(buf, ']')
   339  	}
   340  
   341  	return append(buf, '=')
   342  }
   343  
   344  var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
   345  
   346  func quoteArrayElement(src string) string {
   347  	return `"` + quoteArrayReplacer.Replace(src) + `"`
   348  }
   349  
   350  func isSpace(ch byte) bool {
   351  	// see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224
   352  	return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f'
   353  }
   354  
   355  func QuoteArrayElementIfNeeded(src string) string {
   356  	if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) {
   357  		return quoteArrayElement(src)
   358  	}
   359  	return src
   360  }
   361  
   362  func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) {
   363  	switch value.Kind() {
   364  	case reflect.Array:
   365  		fallthrough
   366  	case reflect.Slice:
   367  		length := value.Len()
   368  		if 0 == elementsLength {
   369  			elementsLength = length
   370  		} else {
   371  			elementsLength *= length
   372  		}
   373  		dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1})
   374  		for i := 0; i < length; i++ {
   375  			if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok {
   376  				return d, l, true
   377  			}
   378  		}
   379  	}
   380  	return dimensions, elementsLength, true
   381  }
   382  

View as plain text