...

Source file src/github.com/jackc/pgx/v5/pgtype/hstore.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  
    10  	"github.com/jackc/pgx/v5/internal/pgio"
    11  )
    12  
    13  type HstoreScanner interface {
    14  	ScanHstore(v Hstore) error
    15  }
    16  
    17  type HstoreValuer interface {
    18  	HstoreValue() (Hstore, error)
    19  }
    20  
    21  // Hstore represents an hstore column that can be null or have null values
    22  // associated with its keys.
    23  type Hstore map[string]*string
    24  
    25  func (h *Hstore) ScanHstore(v Hstore) error {
    26  	*h = v
    27  	return nil
    28  }
    29  
    30  func (h Hstore) HstoreValue() (Hstore, error) {
    31  	return h, nil
    32  }
    33  
    34  // Scan implements the database/sql Scanner interface.
    35  func (h *Hstore) Scan(src any) error {
    36  	if src == nil {
    37  		*h = nil
    38  		return nil
    39  	}
    40  
    41  	switch src := src.(type) {
    42  	case string:
    43  		return scanPlanTextAnyToHstoreScanner{}.scanString(src, h)
    44  	}
    45  
    46  	return fmt.Errorf("cannot scan %T", src)
    47  }
    48  
    49  // Value implements the database/sql/driver Valuer interface.
    50  func (h Hstore) Value() (driver.Value, error) {
    51  	if h == nil {
    52  		return nil, nil
    53  	}
    54  
    55  	buf, err := HstoreCodec{}.PlanEncode(nil, 0, TextFormatCode, h).Encode(h, nil)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	return string(buf), err
    60  }
    61  
    62  type HstoreCodec struct{}
    63  
    64  func (HstoreCodec) FormatSupported(format int16) bool {
    65  	return format == TextFormatCode || format == BinaryFormatCode
    66  }
    67  
    68  func (HstoreCodec) PreferredFormat() int16 {
    69  	return BinaryFormatCode
    70  }
    71  
    72  func (HstoreCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
    73  	if _, ok := value.(HstoreValuer); !ok {
    74  		return nil
    75  	}
    76  
    77  	switch format {
    78  	case BinaryFormatCode:
    79  		return encodePlanHstoreCodecBinary{}
    80  	case TextFormatCode:
    81  		return encodePlanHstoreCodecText{}
    82  	}
    83  
    84  	return nil
    85  }
    86  
    87  type encodePlanHstoreCodecBinary struct{}
    88  
    89  func (encodePlanHstoreCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
    90  	hstore, err := value.(HstoreValuer).HstoreValue()
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	if hstore == nil {
    96  		return nil, nil
    97  	}
    98  
    99  	buf = pgio.AppendInt32(buf, int32(len(hstore)))
   100  
   101  	for k, v := range hstore {
   102  		buf = pgio.AppendInt32(buf, int32(len(k)))
   103  		buf = append(buf, k...)
   104  
   105  		if v == nil {
   106  			buf = pgio.AppendInt32(buf, -1)
   107  		} else {
   108  			buf = pgio.AppendInt32(buf, int32(len(*v)))
   109  			buf = append(buf, (*v)...)
   110  		}
   111  	}
   112  
   113  	return buf, nil
   114  }
   115  
   116  type encodePlanHstoreCodecText struct{}
   117  
   118  func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
   119  	hstore, err := value.(HstoreValuer).HstoreValue()
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	if len(hstore) == 0 {
   125  		// distinguish between empty and nil: Not strictly required by Postgres, since its protocol
   126  		// explicitly marks NULL column values separately. However, the Binary codec does this, and
   127  		// this means we can "round trip" Encode and Scan without data loss.
   128  		// nil: []byte(nil); empty: []byte{}
   129  		if hstore == nil {
   130  			return nil, nil
   131  		}
   132  		return []byte{}, nil
   133  	}
   134  
   135  	firstPair := true
   136  
   137  	for k, v := range hstore {
   138  		if firstPair {
   139  			firstPair = false
   140  		} else {
   141  			buf = append(buf, ',', ' ')
   142  		}
   143  
   144  		// unconditionally quote hstore keys/values like Postgres does
   145  		// this avoids a Mac OS X Postgres hstore parsing bug:
   146  		// https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com
   147  		buf = append(buf, '"')
   148  		buf = append(buf, quoteArrayReplacer.Replace(k)...)
   149  		buf = append(buf, '"')
   150  		buf = append(buf, "=>"...)
   151  
   152  		if v == nil {
   153  			buf = append(buf, "NULL"...)
   154  		} else {
   155  			buf = append(buf, '"')
   156  			buf = append(buf, quoteArrayReplacer.Replace(*v)...)
   157  			buf = append(buf, '"')
   158  		}
   159  	}
   160  
   161  	return buf, nil
   162  }
   163  
   164  func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
   165  
   166  	switch format {
   167  	case BinaryFormatCode:
   168  		switch target.(type) {
   169  		case HstoreScanner:
   170  			return scanPlanBinaryHstoreToHstoreScanner{}
   171  		}
   172  	case TextFormatCode:
   173  		switch target.(type) {
   174  		case HstoreScanner:
   175  			return scanPlanTextAnyToHstoreScanner{}
   176  		}
   177  	}
   178  
   179  	return nil
   180  }
   181  
   182  type scanPlanBinaryHstoreToHstoreScanner struct{}
   183  
   184  func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
   185  	scanner := (dst).(HstoreScanner)
   186  
   187  	if src == nil {
   188  		return scanner.ScanHstore(Hstore(nil))
   189  	}
   190  
   191  	rp := 0
   192  
   193  	const uint32Len = 4
   194  	if len(src[rp:]) < uint32Len {
   195  		return fmt.Errorf("hstore incomplete %v", src)
   196  	}
   197  	pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
   198  	rp += uint32Len
   199  
   200  	hstore := make(Hstore, pairCount)
   201  	// one allocation for all *string, rather than one per string, just like text parsing
   202  	valueStrings := make([]string, pairCount)
   203  
   204  	for i := 0; i < pairCount; i++ {
   205  		if len(src[rp:]) < uint32Len {
   206  			return fmt.Errorf("hstore incomplete %v", src)
   207  		}
   208  		keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
   209  		rp += uint32Len
   210  
   211  		if len(src[rp:]) < keyLen {
   212  			return fmt.Errorf("hstore incomplete %v", src)
   213  		}
   214  		key := string(src[rp : rp+keyLen])
   215  		rp += keyLen
   216  
   217  		if len(src[rp:]) < uint32Len {
   218  			return fmt.Errorf("hstore incomplete %v", src)
   219  		}
   220  		valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
   221  		rp += 4
   222  
   223  		if valueLen >= 0 {
   224  			valueStrings[i] = string(src[rp : rp+valueLen])
   225  			rp += valueLen
   226  
   227  			hstore[key] = &valueStrings[i]
   228  		} else {
   229  			hstore[key] = nil
   230  		}
   231  	}
   232  
   233  	return scanner.ScanHstore(hstore)
   234  }
   235  
   236  type scanPlanTextAnyToHstoreScanner struct{}
   237  
   238  func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error {
   239  	scanner := (dst).(HstoreScanner)
   240  
   241  	if src == nil {
   242  		return scanner.ScanHstore(Hstore(nil))
   243  	}
   244  	return s.scanString(string(src), scanner)
   245  }
   246  
   247  // scanString does not return nil hstore values because string cannot be nil.
   248  func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error {
   249  	hstore, err := parseHstore(src)
   250  	if err != nil {
   251  		return err
   252  	}
   253  	return scanner.ScanHstore(hstore)
   254  }
   255  
   256  func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
   257  	return codecDecodeToTextFormat(c, m, oid, format, src)
   258  }
   259  
   260  func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
   261  	if src == nil {
   262  		return nil, nil
   263  	}
   264  
   265  	var hstore Hstore
   266  	err := codecScan(c, m, oid, format, src, &hstore)
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  	return hstore, nil
   271  }
   272  
   273  type hstoreParser struct {
   274  	str           string
   275  	pos           int
   276  	nextBackslash int
   277  }
   278  
   279  func newHSP(in string) *hstoreParser {
   280  	return &hstoreParser{
   281  		pos:           0,
   282  		str:           in,
   283  		nextBackslash: strings.IndexByte(in, '\\'),
   284  	}
   285  }
   286  
   287  func (p *hstoreParser) atEnd() bool {
   288  	return p.pos >= len(p.str)
   289  }
   290  
   291  // consume returns the next byte of the string, or end if the string is done.
   292  func (p *hstoreParser) consume() (b byte, end bool) {
   293  	if p.pos >= len(p.str) {
   294  		return 0, true
   295  	}
   296  	b = p.str[p.pos]
   297  	p.pos++
   298  	return b, false
   299  }
   300  
   301  func unexpectedByteErr(actualB byte, expectedB byte) error {
   302  	return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB)
   303  }
   304  
   305  // consumeExpectedByte consumes expectedB from the string, or returns an error.
   306  func (p *hstoreParser) consumeExpectedByte(expectedB byte) error {
   307  	nextB, end := p.consume()
   308  	if end {
   309  		return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB)
   310  	}
   311  	if nextB != expectedB {
   312  		return unexpectedByteErr(nextB, expectedB)
   313  	}
   314  	return nil
   315  }
   316  
   317  // consumeExpected2 consumes two expected bytes or returns an error.
   318  // This was a bit faster than using a string argument (better inlining? Not sure).
   319  func (p *hstoreParser) consumeExpected2(one byte, two byte) error {
   320  	if p.pos+2 > len(p.str) {
   321  		return errors.New("unexpected end of string")
   322  	}
   323  	if p.str[p.pos] != one {
   324  		return unexpectedByteErr(p.str[p.pos], one)
   325  	}
   326  	if p.str[p.pos+1] != two {
   327  		return unexpectedByteErr(p.str[p.pos+1], two)
   328  	}
   329  	p.pos += 2
   330  	return nil
   331  }
   332  
   333  var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`)
   334  
   335  // consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been
   336  // parsed already. This copies the string from the backing string so it can be garbage collected.
   337  func (p *hstoreParser) consumeDoubleQuoted() (string, error) {
   338  	// fast path: assume most keys/values do not contain escapes
   339  	nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"')
   340  	if nextDoubleQuote == -1 {
   341  		return "", errEOSInQuoted
   342  	}
   343  	nextDoubleQuote += p.pos
   344  	if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote {
   345  		// clone the string from the source string to ensure it can be garbage collected separately
   346  		// TODO: use strings.Clone on Go 1.20; this could get optimized away
   347  		s := strings.Clone(p.str[p.pos:nextDoubleQuote])
   348  		p.pos = nextDoubleQuote + 1
   349  		return s, nil
   350  	}
   351  
   352  	// slow path: string contains escapes
   353  	s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash)
   354  	p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\')
   355  	if p.nextBackslash != -1 {
   356  		p.nextBackslash += p.pos
   357  	}
   358  	return s, err
   359  }
   360  
   361  // consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting
   362  // at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be
   363  // garbage collected separately.
   364  func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) {
   365  	// copy the prefix that does not contain backslashes
   366  	var builder strings.Builder
   367  	builder.WriteString(p.str[p.pos:firstBackslash])
   368  
   369  	// skip to the backslash
   370  	p.pos = firstBackslash
   371  
   372  	// copy bytes until the end, unescaping backslashes
   373  	for {
   374  		nextB, end := p.consume()
   375  		if end {
   376  			return "", errEOSInQuoted
   377  		} else if nextB == '"' {
   378  			break
   379  		} else if nextB == '\\' {
   380  			// escape: skip the backslash and copy the char
   381  			nextB, end = p.consume()
   382  			if end {
   383  				return "", errEOSInQuoted
   384  			}
   385  			if !(nextB == '\\' || nextB == '"') {
   386  				return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB)
   387  			}
   388  			builder.WriteByte(nextB)
   389  		} else {
   390  			// normal byte: copy it
   391  			builder.WriteByte(nextB)
   392  		}
   393  	}
   394  	return builder.String(), nil
   395  }
   396  
   397  // consumePairSeparator consumes the Hstore pair separator ", " or returns an error.
   398  func (p *hstoreParser) consumePairSeparator() error {
   399  	return p.consumeExpected2(',', ' ')
   400  }
   401  
   402  // consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error.
   403  func (p *hstoreParser) consumeKVSeparator() error {
   404  	return p.consumeExpected2('=', '>')
   405  }
   406  
   407  // consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error.
   408  func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) {
   409  	// peek at the next byte
   410  	if p.atEnd() {
   411  		return Text{}, errors.New("found end instead of value")
   412  	}
   413  	next := p.str[p.pos]
   414  	if next == 'N' {
   415  		// must be the exact string NULL: use consumeExpected2 twice
   416  		err := p.consumeExpected2('N', 'U')
   417  		if err != nil {
   418  			return Text{}, err
   419  		}
   420  		err = p.consumeExpected2('L', 'L')
   421  		if err != nil {
   422  			return Text{}, err
   423  		}
   424  		return Text{String: "", Valid: false}, nil
   425  	} else if next != '"' {
   426  		return Text{}, unexpectedByteErr(next, '"')
   427  	}
   428  
   429  	// skip the double quote
   430  	p.pos += 1
   431  	s, err := p.consumeDoubleQuoted()
   432  	if err != nil {
   433  		return Text{}, err
   434  	}
   435  	return Text{String: s, Valid: true}, nil
   436  }
   437  
   438  func parseHstore(s string) (Hstore, error) {
   439  	p := newHSP(s)
   440  
   441  	// This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it
   442  	// is less likely to occur in keys/values than '=' or ','.
   443  	numPairsEstimate := strings.Count(s, ">")
   444  	// makes one allocation of strings for the entire Hstore, rather than one allocation per value.
   445  	valueStrings := make([]string, 0, numPairsEstimate)
   446  	result := make(Hstore, numPairsEstimate)
   447  	first := true
   448  	for !p.atEnd() {
   449  		if !first {
   450  			err := p.consumePairSeparator()
   451  			if err != nil {
   452  				return nil, err
   453  			}
   454  		} else {
   455  			first = false
   456  		}
   457  
   458  		err := p.consumeExpectedByte('"')
   459  		if err != nil {
   460  			return nil, err
   461  		}
   462  
   463  		key, err := p.consumeDoubleQuoted()
   464  		if err != nil {
   465  			return nil, err
   466  		}
   467  
   468  		err = p.consumeKVSeparator()
   469  		if err != nil {
   470  			return nil, err
   471  		}
   472  
   473  		value, err := p.consumeDoubleQuotedOrNull()
   474  		if err != nil {
   475  			return nil, err
   476  		}
   477  		if value.Valid {
   478  			valueStrings = append(valueStrings, value.String)
   479  			result[key] = &valueStrings[len(valueStrings)-1]
   480  		} else {
   481  			result[key] = nil
   482  		}
   483  	}
   484  
   485  	return result, nil
   486  }
   487  

View as plain text