...

Source file src/github.com/jackc/pgtype/hstore.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql/driver"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"strings"
    10  	"unicode"
    11  	"unicode/utf8"
    12  
    13  	"github.com/jackc/pgio"
    14  )
    15  
// Hstore represents an hstore column that can be null or have null values
// associated with its keys.
//
// Map holds the key/value pairs; an entry whose Text has Status Null is a
// NULL value for that key. Status describes the column as a whole (Present,
// Null, ...) and is checked before Map is consulted.
type Hstore struct {
	Map    map[string]Text
	Status Status
}
    22  
    23  func (dst *Hstore) Set(src interface{}) error {
    24  	if src == nil {
    25  		*dst = Hstore{Status: Null}
    26  		return nil
    27  	}
    28  
    29  	if value, ok := src.(interface{ Get() interface{} }); ok {
    30  		value2 := value.Get()
    31  		if value2 != value {
    32  			return dst.Set(value2)
    33  		}
    34  	}
    35  
    36  	switch value := src.(type) {
    37  	case map[string]string:
    38  		m := make(map[string]Text, len(value))
    39  		for k, v := range value {
    40  			m[k] = Text{String: v, Status: Present}
    41  		}
    42  		*dst = Hstore{Map: m, Status: Present}
    43  	case map[string]*string:
    44  		m := make(map[string]Text, len(value))
    45  		for k, v := range value {
    46  			if v == nil {
    47  				m[k] = Text{Status: Null}
    48  			} else {
    49  				m[k] = Text{String: *v, Status: Present}
    50  			}
    51  		}
    52  		*dst = Hstore{Map: m, Status: Present}
    53  	case map[string]Text:
    54  		*dst = Hstore{Map: value, Status: Present}
    55  	default:
    56  		return fmt.Errorf("cannot convert %v to Hstore", src)
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (dst Hstore) Get() interface{} {
    63  	switch dst.Status {
    64  	case Present:
    65  		return dst.Map
    66  	case Null:
    67  		return nil
    68  	default:
    69  		return dst.Status
    70  	}
    71  }
    72  
    73  func (src *Hstore) AssignTo(dst interface{}) error {
    74  	switch src.Status {
    75  	case Present:
    76  		switch v := dst.(type) {
    77  		case *map[string]string:
    78  			*v = make(map[string]string, len(src.Map))
    79  			for k, val := range src.Map {
    80  				if val.Status != Present {
    81  					return fmt.Errorf("cannot decode %#v into %T", src, dst)
    82  				}
    83  				(*v)[k] = val.String
    84  			}
    85  			return nil
    86  		case *map[string]*string:
    87  			*v = make(map[string]*string, len(src.Map))
    88  			for k, val := range src.Map {
    89  				switch val.Status {
    90  				case Null:
    91  					(*v)[k] = nil
    92  				case Present:
    93  					str := val.String
    94  					(*v)[k] = &str
    95  				default:
    96  					return fmt.Errorf("cannot decode %#v into %T", src, dst)
    97  				}
    98  			}
    99  			return nil
   100  		default:
   101  			if nextDst, retry := GetAssignToDstType(dst); retry {
   102  				return src.AssignTo(nextDst)
   103  			}
   104  			return fmt.Errorf("unable to assign to %T", dst)
   105  		}
   106  	case Null:
   107  		return NullAssignTo(dst)
   108  	}
   109  
   110  	return fmt.Errorf("cannot decode %#v into %T", src, dst)
   111  }
   112  
   113  func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error {
   114  	if src == nil {
   115  		*dst = Hstore{Status: Null}
   116  		return nil
   117  	}
   118  
   119  	keys, values, err := parseHstore(string(src))
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	m := make(map[string]Text, len(keys))
   125  	for i := range keys {
   126  		m[keys[i]] = values[i]
   127  	}
   128  
   129  	*dst = Hstore{Map: m, Status: Present}
   130  	return nil
   131  }
   132  
   133  func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error {
   134  	if src == nil {
   135  		*dst = Hstore{Status: Null}
   136  		return nil
   137  	}
   138  
   139  	rp := 0
   140  
   141  	if len(src[rp:]) < 4 {
   142  		return fmt.Errorf("hstore incomplete %v", src)
   143  	}
   144  	pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
   145  	rp += 4
   146  
   147  	m := make(map[string]Text, pairCount)
   148  
   149  	for i := 0; i < pairCount; i++ {
   150  		if len(src[rp:]) < 4 {
   151  			return fmt.Errorf("hstore incomplete %v", src)
   152  		}
   153  		keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
   154  		rp += 4
   155  
   156  		if len(src[rp:]) < keyLen {
   157  			return fmt.Errorf("hstore incomplete %v", src)
   158  		}
   159  		key := string(src[rp : rp+keyLen])
   160  		rp += keyLen
   161  
   162  		if len(src[rp:]) < 4 {
   163  			return fmt.Errorf("hstore incomplete %v", src)
   164  		}
   165  		valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
   166  		rp += 4
   167  
   168  		var valueBuf []byte
   169  		if valueLen >= 0 {
   170  			valueBuf = src[rp : rp+valueLen]
   171  			rp += valueLen
   172  		}
   173  
   174  		var value Text
   175  		err := value.DecodeBinary(ci, valueBuf)
   176  		if err != nil {
   177  			return err
   178  		}
   179  		m[key] = value
   180  	}
   181  
   182  	*dst = Hstore{Map: m, Status: Present}
   183  
   184  	return nil
   185  }
   186  
   187  func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
   188  	switch src.Status {
   189  	case Null:
   190  		return nil, nil
   191  	case Undefined:
   192  		return nil, errUndefined
   193  	}
   194  
   195  	firstPair := true
   196  
   197  	inElemBuf := make([]byte, 0, 32)
   198  	for k, v := range src.Map {
   199  		if firstPair {
   200  			firstPair = false
   201  		} else {
   202  			buf = append(buf, ',')
   203  		}
   204  
   205  		buf = append(buf, quoteHstoreElementIfNeeded(k)...)
   206  		buf = append(buf, "=>"...)
   207  
   208  		elemBuf, err := v.EncodeText(ci, inElemBuf)
   209  		if err != nil {
   210  			return nil, err
   211  		}
   212  
   213  		if elemBuf == nil {
   214  			buf = append(buf, "NULL"...)
   215  		} else {
   216  			buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...)
   217  		}
   218  	}
   219  
   220  	return buf, nil
   221  }
   222  
   223  func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
   224  	switch src.Status {
   225  	case Null:
   226  		return nil, nil
   227  	case Undefined:
   228  		return nil, errUndefined
   229  	}
   230  
   231  	buf = pgio.AppendInt32(buf, int32(len(src.Map)))
   232  
   233  	var err error
   234  	for k, v := range src.Map {
   235  		buf = pgio.AppendInt32(buf, int32(len(k)))
   236  		buf = append(buf, k...)
   237  
   238  		sp := len(buf)
   239  		buf = pgio.AppendInt32(buf, -1)
   240  
   241  		elemBuf, err := v.EncodeText(ci, buf)
   242  		if err != nil {
   243  			return nil, err
   244  		}
   245  		if elemBuf != nil {
   246  			buf = elemBuf
   247  			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   248  		}
   249  	}
   250  
   251  	return buf, err
   252  }
   253  
   254  var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
   255  
   256  func quoteHstoreElement(src string) string {
   257  	return `"` + quoteArrayReplacer.Replace(src) + `"`
   258  }
   259  
   260  func quoteHstoreElementIfNeeded(src string) string {
   261  	if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) {
   262  		return quoteArrayElement(src)
   263  	}
   264  	return src
   265  }
   266  
// States of the parseHstore state machine.
const (
	hsPre  = iota // expecting the opening '"' of the first key
	hsKey         // inside a quoted key, up to its closing '"'
	hsSep         // after a key; expecting "=>" then '"' or the 'N' of NULL
	hsVal         // inside a quoted value, up to its closing '"'
	hsNul         // after the 'N' of an unquoted NULL; expecting "ULL"
	hsNext        // after a complete pair; expecting ',' and whitespace, or end of input
)
   275  
   276  type hstoreParser struct {
   277  	str string
   278  	pos int
   279  }
   280  
   281  func newHSP(in string) *hstoreParser {
   282  	return &hstoreParser{
   283  		pos: 0,
   284  		str: in,
   285  	}
   286  }
   287  
   288  func (p *hstoreParser) Consume() (r rune, end bool) {
   289  	if p.pos >= len(p.str) {
   290  		end = true
   291  		return
   292  	}
   293  	r, w := utf8.DecodeRuneInString(p.str[p.pos:])
   294  	p.pos += w
   295  	return
   296  }
   297  
   298  func (p *hstoreParser) Peek() (r rune, end bool) {
   299  	if p.pos >= len(p.str) {
   300  		end = true
   301  		return
   302  	}
   303  	r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
   304  	return
   305  }
   306  
   307  // parseHstore parses the string representation of an hstore column (the same
   308  // you would get from an ordinary SELECT) into two slices of keys and values. it
   309  // is used internally in the default parsing of hstores.
   310  func parseHstore(s string) (k []string, v []Text, err error) {
   311  	if s == "" {
   312  		return
   313  	}
   314  
   315  	buf := bytes.Buffer{}
   316  	keys := []string{}
   317  	values := []Text{}
   318  	p := newHSP(s)
   319  
   320  	r, end := p.Consume()
   321  	state := hsPre
   322  
   323  	for !end {
   324  		switch state {
   325  		case hsPre:
   326  			if r == '"' {
   327  				state = hsKey
   328  			} else {
   329  				err = errors.New("String does not begin with \"")
   330  			}
   331  		case hsKey:
   332  			switch r {
   333  			case '"': //End of the key
   334  				keys = append(keys, buf.String())
   335  				buf = bytes.Buffer{}
   336  				state = hsSep
   337  			case '\\': //Potential escaped character
   338  				n, end := p.Consume()
   339  				switch {
   340  				case end:
   341  					err = errors.New("Found EOS in key, expecting character or \"")
   342  				case n == '"', n == '\\':
   343  					buf.WriteRune(n)
   344  				default:
   345  					buf.WriteRune(r)
   346  					buf.WriteRune(n)
   347  				}
   348  			default: //Any other character
   349  				buf.WriteRune(r)
   350  			}
   351  		case hsSep:
   352  			if r == '=' {
   353  				r, end = p.Consume()
   354  				switch {
   355  				case end:
   356  					err = errors.New("Found EOS after '=', expecting '>'")
   357  				case r == '>':
   358  					r, end = p.Consume()
   359  					switch {
   360  					case end:
   361  						err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
   362  					case r == '"':
   363  						state = hsVal
   364  					case r == 'N':
   365  						state = hsNul
   366  					default:
   367  						err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
   368  					}
   369  				default:
   370  					err = fmt.Errorf("Invalid character after '=', expecting '>'")
   371  				}
   372  			} else {
   373  				err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
   374  			}
   375  		case hsVal:
   376  			switch r {
   377  			case '"': //End of the value
   378  				values = append(values, Text{String: buf.String(), Status: Present})
   379  				buf = bytes.Buffer{}
   380  				state = hsNext
   381  			case '\\': //Potential escaped character
   382  				n, end := p.Consume()
   383  				switch {
   384  				case end:
   385  					err = errors.New("Found EOS in key, expecting character or \"")
   386  				case n == '"', n == '\\':
   387  					buf.WriteRune(n)
   388  				default:
   389  					buf.WriteRune(r)
   390  					buf.WriteRune(n)
   391  				}
   392  			default: //Any other character
   393  				buf.WriteRune(r)
   394  			}
   395  		case hsNul:
   396  			nulBuf := make([]rune, 3)
   397  			nulBuf[0] = r
   398  			for i := 1; i < 3; i++ {
   399  				r, end = p.Consume()
   400  				if end {
   401  					err = errors.New("Found EOS in NULL value")
   402  					return
   403  				}
   404  				nulBuf[i] = r
   405  			}
   406  			if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
   407  				values = append(values, Text{Status: Null})
   408  				state = hsNext
   409  			} else {
   410  				err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
   411  			}
   412  		case hsNext:
   413  			if r == ',' {
   414  				r, end = p.Consume()
   415  				switch {
   416  				case end:
   417  					err = errors.New("Found EOS after ',', expecting space")
   418  				case (unicode.IsSpace(r)):
   419  					r, end = p.Consume()
   420  					state = hsKey
   421  				default:
   422  					err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
   423  				}
   424  			} else {
   425  				err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
   426  			}
   427  		}
   428  
   429  		if err != nil {
   430  			return
   431  		}
   432  		r, end = p.Consume()
   433  	}
   434  	if state != hsNext {
   435  		err = errors.New("Improperly formatted hstore")
   436  		return
   437  	}
   438  	k = keys
   439  	v = values
   440  	return
   441  }
   442  
   443  // Scan implements the database/sql Scanner interface.
   444  func (dst *Hstore) Scan(src interface{}) error {
   445  	if src == nil {
   446  		*dst = Hstore{Status: Null}
   447  		return nil
   448  	}
   449  
   450  	switch src := src.(type) {
   451  	case string:
   452  		return dst.DecodeText(nil, []byte(src))
   453  	case []byte:
   454  		srcCopy := make([]byte, len(src))
   455  		copy(srcCopy, src)
   456  		return dst.DecodeText(nil, srcCopy)
   457  	}
   458  
   459  	return fmt.Errorf("cannot scan %T", src)
   460  }
   461  
   462  // Value implements the database/sql/driver Valuer interface.
   463  func (src Hstore) Value() (driver.Value, error) {
   464  	return EncodeValueText(src)
   465  }
   466  

View as plain text