...

Source file src/github.com/jackc/pgtype/composite_type.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"strings"
     9  
    10  	"github.com/jackc/pgio"
    11  )
    12  
    13  type CompositeTypeField struct {
    14  	Name string
    15  	OID  uint32
    16  }
    17  
    18  type CompositeType struct {
    19  	status Status
    20  
    21  	typeName string
    22  
    23  	fields           []CompositeTypeField
    24  	valueTranscoders []ValueTranscoder
    25  }
    26  
    27  // NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used
    28  // for fields. All field OIDs must be previously registered in ci.
    29  func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
    30  	valueTranscoders := make([]ValueTranscoder, len(fields))
    31  
    32  	for i := range fields {
    33  		dt, ok := ci.DataTypeForOID(fields[i].OID)
    34  		if !ok {
    35  			return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID)
    36  		}
    37  
    38  		value := NewValue(dt.Value)
    39  		valueTranscoder, ok := value.(ValueTranscoder)
    40  		if !ok {
    41  			return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
    42  		}
    43  
    44  		valueTranscoders[i] = valueTranscoder
    45  	}
    46  
    47  	return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
    48  }
    49  
    50  // NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length.
    51  // Prefer NewCompositeType unless overriding the transcoding of fields is required.
    52  func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
    53  	if len(fields) != len(values) {
    54  		return nil, errors.New("fields and valueTranscoders must have same length")
    55  	}
    56  
    57  	return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
    58  }
    59  
    60  func (src CompositeType) Get() interface{} {
    61  	switch src.status {
    62  	case Present:
    63  		results := make(map[string]interface{}, len(src.valueTranscoders))
    64  		for i := range src.valueTranscoders {
    65  			results[src.fields[i].Name] = src.valueTranscoders[i].Get()
    66  		}
    67  		return results
    68  	case Null:
    69  		return nil
    70  	default:
    71  		return src.status
    72  	}
    73  }
    74  
    75  func (ct *CompositeType) NewTypeValue() Value {
    76  	a := &CompositeType{
    77  		typeName:         ct.typeName,
    78  		fields:           ct.fields,
    79  		valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
    80  	}
    81  
    82  	for i := range ct.valueTranscoders {
    83  		a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
    84  	}
    85  
    86  	return a
    87  }
    88  
    89  func (ct *CompositeType) TypeName() string {
    90  	return ct.typeName
    91  }
    92  
    93  func (ct *CompositeType) Fields() []CompositeTypeField {
    94  	return ct.fields
    95  }
    96  
    97  func (dst *CompositeType) Set(src interface{}) error {
    98  	if src == nil {
    99  		dst.status = Null
   100  		return nil
   101  	}
   102  
   103  	switch value := src.(type) {
   104  	case []interface{}:
   105  		if len(value) != len(dst.valueTranscoders) {
   106  			return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
   107  		}
   108  		for i, v := range value {
   109  			if err := dst.valueTranscoders[i].Set(v); err != nil {
   110  				return err
   111  			}
   112  		}
   113  		dst.status = Present
   114  	case *[]interface{}:
   115  		if value == nil {
   116  			dst.status = Null
   117  			return nil
   118  		}
   119  		return dst.Set(*value)
   120  	default:
   121  		return fmt.Errorf("Can not convert %v to Composite", src)
   122  	}
   123  
   124  	return nil
   125  }
   126  
   127  // AssignTo should never be called on composite value directly
   128  func (src CompositeType) AssignTo(dst interface{}) error {
   129  	switch src.status {
   130  	case Present:
   131  		switch v := dst.(type) {
   132  		case []interface{}:
   133  			if len(v) != len(src.valueTranscoders) {
   134  				return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
   135  			}
   136  			for i := range src.valueTranscoders {
   137  				if v[i] == nil {
   138  					continue
   139  				}
   140  
   141  				err := assignToOrSet(src.valueTranscoders[i], v[i])
   142  				if err != nil {
   143  					return fmt.Errorf("unable to assign to dst[%d]: %v", i, err)
   144  				}
   145  			}
   146  			return nil
   147  		case *[]interface{}:
   148  			return src.AssignTo(*v)
   149  		default:
   150  			if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct {
   151  				return err
   152  			}
   153  
   154  			if nextDst, retry := GetAssignToDstType(dst); retry {
   155  				return src.AssignTo(nextDst)
   156  			}
   157  			return fmt.Errorf("unable to assign to %T", dst)
   158  		}
   159  	case Null:
   160  		return NullAssignTo(dst)
   161  	}
   162  	return fmt.Errorf("cannot decode %#v into %T", src, dst)
   163  }
   164  
   165  func assignToOrSet(src Value, dst interface{}) error {
   166  	assignToErr := src.AssignTo(dst)
   167  	if assignToErr != nil {
   168  		// Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self.
   169  		setSucceeded := false
   170  		if setter, ok := dst.(Value); ok {
   171  			err := setter.Set(src.Get())
   172  			setSucceeded = err == nil
   173  		}
   174  		if !setSucceeded {
   175  			return assignToErr
   176  		}
   177  	}
   178  
   179  	return nil
   180  }
   181  
   182  func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
   183  	dstValue := reflect.ValueOf(dst)
   184  	if dstValue.Kind() != reflect.Ptr {
   185  		return false, nil
   186  	}
   187  
   188  	if dstValue.IsNil() {
   189  		return false, nil
   190  	}
   191  
   192  	dstElemValue := dstValue.Elem()
   193  	dstElemType := dstElemValue.Type()
   194  
   195  	if dstElemType.Kind() != reflect.Struct {
   196  		return false, nil
   197  	}
   198  
   199  	exportedFields := make([]int, 0, dstElemType.NumField())
   200  	for i := 0; i < dstElemType.NumField(); i++ {
   201  		sf := dstElemType.Field(i)
   202  		if sf.PkgPath == "" {
   203  			exportedFields = append(exportedFields, i)
   204  		}
   205  	}
   206  
   207  	if len(exportedFields) != len(src.valueTranscoders) {
   208  		return false, nil
   209  	}
   210  
   211  	for i := range exportedFields {
   212  		err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
   213  		if err != nil {
   214  			return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
   215  		}
   216  	}
   217  
   218  	return true, nil
   219  }
   220  
   221  func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
   222  	switch src.status {
   223  	case Null:
   224  		return nil, nil
   225  	case Undefined:
   226  		return nil, errUndefined
   227  	}
   228  
   229  	b := NewCompositeBinaryBuilder(ci, buf)
   230  	for i := range src.valueTranscoders {
   231  		b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
   232  	}
   233  
   234  	return b.Finish()
   235  }
   236  
   237  // DecodeBinary implements BinaryDecoder interface.
   238  // Opposite to Record, fields in a composite act as a "schema"
   239  // and decoding fails if SQL value can't be assigned due to
   240  // type mismatch
   241  func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
   242  	if buf == nil {
   243  		dst.status = Null
   244  		return nil
   245  	}
   246  
   247  	scanner := NewCompositeBinaryScanner(ci, buf)
   248  
   249  	for _, f := range dst.valueTranscoders {
   250  		scanner.ScanDecoder(f)
   251  	}
   252  
   253  	if scanner.Err() != nil {
   254  		return scanner.Err()
   255  	}
   256  
   257  	dst.status = Present
   258  
   259  	return nil
   260  }
   261  
   262  func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
   263  	if buf == nil {
   264  		dst.status = Null
   265  		return nil
   266  	}
   267  
   268  	scanner := NewCompositeTextScanner(ci, buf)
   269  
   270  	for _, f := range dst.valueTranscoders {
   271  		scanner.ScanDecoder(f)
   272  	}
   273  
   274  	if scanner.Err() != nil {
   275  		return scanner.Err()
   276  	}
   277  
   278  	dst.status = Present
   279  
   280  	return nil
   281  }
   282  
   283  func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) {
   284  	switch src.status {
   285  	case Null:
   286  		return nil, nil
   287  	case Undefined:
   288  		return nil, errUndefined
   289  	}
   290  
   291  	b := NewCompositeTextBuilder(ci, buf)
   292  	for _, f := range src.valueTranscoders {
   293  		b.AppendEncoder(f)
   294  	}
   295  
   296  	return b.Finish()
   297  }
   298  
   299  type CompositeBinaryScanner struct {
   300  	ci  *ConnInfo
   301  	rp  int
   302  	src []byte
   303  
   304  	fieldCount int32
   305  	fieldBytes []byte
   306  	fieldOID   uint32
   307  	err        error
   308  }
   309  
   310  // NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
   311  func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner {
   312  	rp := 0
   313  	if len(src[rp:]) < 4 {
   314  		return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
   315  	}
   316  
   317  	fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
   318  	rp += 4
   319  
   320  	return &CompositeBinaryScanner{
   321  		ci:         ci,
   322  		rp:         rp,
   323  		src:        src,
   324  		fieldCount: fieldCount,
   325  	}
   326  }
   327  
   328  // ScanDecoder calls Next and decodes the result with d.
   329  func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) {
   330  	if cfs.err != nil {
   331  		return
   332  	}
   333  
   334  	if cfs.Next() {
   335  		cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes)
   336  	} else {
   337  		cfs.err = errors.New("read past end of composite")
   338  	}
   339  }
   340  
   341  // ScanDecoder calls Next and scans the result into d.
   342  func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) {
   343  	if cfs.err != nil {
   344  		return
   345  	}
   346  
   347  	if cfs.Next() {
   348  		cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d)
   349  	} else {
   350  		cfs.err = errors.New("read past end of composite")
   351  	}
   352  }
   353  
   354  // Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
   355  // Next returns false, the Err method can be called to check if any errors occurred.
   356  func (cfs *CompositeBinaryScanner) Next() bool {
   357  	if cfs.err != nil {
   358  		return false
   359  	}
   360  
   361  	if cfs.rp == len(cfs.src) {
   362  		return false
   363  	}
   364  
   365  	if len(cfs.src[cfs.rp:]) < 8 {
   366  		cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
   367  		return false
   368  	}
   369  	cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
   370  	cfs.rp += 4
   371  
   372  	fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
   373  	cfs.rp += 4
   374  
   375  	if fieldLen >= 0 {
   376  		if len(cfs.src[cfs.rp:]) < fieldLen {
   377  			cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
   378  			return false
   379  		}
   380  		cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
   381  		cfs.rp += fieldLen
   382  	} else {
   383  		cfs.fieldBytes = nil
   384  	}
   385  
   386  	return true
   387  }
   388  
   389  func (cfs *CompositeBinaryScanner) FieldCount() int {
   390  	return int(cfs.fieldCount)
   391  }
   392  
   393  // Bytes returns the bytes of the field most recently read by Scan().
   394  func (cfs *CompositeBinaryScanner) Bytes() []byte {
   395  	return cfs.fieldBytes
   396  }
   397  
   398  // OID returns the OID of the field most recently read by Scan().
   399  func (cfs *CompositeBinaryScanner) OID() uint32 {
   400  	return cfs.fieldOID
   401  }
   402  
   403  // Err returns any error encountered by the scanner.
   404  func (cfs *CompositeBinaryScanner) Err() error {
   405  	return cfs.err
   406  }
   407  
   408  type CompositeTextScanner struct {
   409  	ci  *ConnInfo
   410  	rp  int
   411  	src []byte
   412  
   413  	fieldBytes []byte
   414  	err        error
   415  }
   416  
   417  // NewCompositeTextScanner a scanner over a text encoded composite value.
   418  func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner {
   419  	if len(src) < 2 {
   420  		return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
   421  	}
   422  
   423  	if src[0] != '(' {
   424  		return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
   425  	}
   426  
   427  	if src[len(src)-1] != ')' {
   428  		return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
   429  	}
   430  
   431  	return &CompositeTextScanner{
   432  		ci:  ci,
   433  		rp:  1,
   434  		src: src,
   435  	}
   436  }
   437  
   438  // ScanDecoder calls Next and decodes the result with d.
   439  func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) {
   440  	if cfs.err != nil {
   441  		return
   442  	}
   443  
   444  	if cfs.Next() {
   445  		cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes)
   446  	} else {
   447  		cfs.err = errors.New("read past end of composite")
   448  	}
   449  }
   450  
   451  // ScanDecoder calls Next and scans the result into d.
   452  func (cfs *CompositeTextScanner) ScanValue(d interface{}) {
   453  	if cfs.err != nil {
   454  		return
   455  	}
   456  
   457  	if cfs.Next() {
   458  		cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d)
   459  	} else {
   460  		cfs.err = errors.New("read past end of composite")
   461  	}
   462  }
   463  
   464  // Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
   465  // Next returns false, the Err method can be called to check if any errors occurred.
   466  func (cfs *CompositeTextScanner) Next() bool {
   467  	if cfs.err != nil {
   468  		return false
   469  	}
   470  
   471  	if cfs.rp == len(cfs.src) {
   472  		return false
   473  	}
   474  
   475  	switch cfs.src[cfs.rp] {
   476  	case ',', ')': // null
   477  		cfs.rp++
   478  		cfs.fieldBytes = nil
   479  		return true
   480  	case '"': // quoted value
   481  		cfs.rp++
   482  		cfs.fieldBytes = make([]byte, 0, 16)
   483  		for {
   484  			ch := cfs.src[cfs.rp]
   485  
   486  			if ch == '"' {
   487  				cfs.rp++
   488  				if cfs.src[cfs.rp] == '"' {
   489  					cfs.fieldBytes = append(cfs.fieldBytes, '"')
   490  					cfs.rp++
   491  				} else {
   492  					break
   493  				}
   494  			} else if ch == '\\' {
   495  				cfs.rp++
   496  				cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
   497  				cfs.rp++
   498  			} else {
   499  				cfs.fieldBytes = append(cfs.fieldBytes, ch)
   500  				cfs.rp++
   501  			}
   502  		}
   503  		cfs.rp++
   504  		return true
   505  	default: // unquoted value
   506  		start := cfs.rp
   507  		for {
   508  			ch := cfs.src[cfs.rp]
   509  			if ch == ',' || ch == ')' {
   510  				break
   511  			}
   512  			cfs.rp++
   513  		}
   514  		cfs.fieldBytes = cfs.src[start:cfs.rp]
   515  		cfs.rp++
   516  		return true
   517  	}
   518  }
   519  
   520  // Bytes returns the bytes of the field most recently read by Scan().
   521  func (cfs *CompositeTextScanner) Bytes() []byte {
   522  	return cfs.fieldBytes
   523  }
   524  
   525  // Err returns any error encountered by the scanner.
   526  func (cfs *CompositeTextScanner) Err() error {
   527  	return cfs.err
   528  }
   529  
   530  type CompositeBinaryBuilder struct {
   531  	ci         *ConnInfo
   532  	buf        []byte
   533  	startIdx   int
   534  	fieldCount uint32
   535  	err        error
   536  }
   537  
   538  func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder {
   539  	startIdx := len(buf)
   540  	buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
   541  	return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx}
   542  }
   543  
   544  func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) {
   545  	if b.err != nil {
   546  		return
   547  	}
   548  
   549  	dt, ok := b.ci.DataTypeForOID(oid)
   550  	if !ok {
   551  		b.err = fmt.Errorf("unknown data type for OID: %d", oid)
   552  		return
   553  	}
   554  
   555  	err := dt.Value.Set(field)
   556  	if err != nil {
   557  		b.err = err
   558  		return
   559  	}
   560  
   561  	binaryEncoder, ok := dt.Value.(BinaryEncoder)
   562  	if !ok {
   563  		b.err = fmt.Errorf("unable to encode binary for OID: %d", oid)
   564  		return
   565  	}
   566  
   567  	b.AppendEncoder(oid, binaryEncoder)
   568  }
   569  
   570  func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) {
   571  	if b.err != nil {
   572  		return
   573  	}
   574  
   575  	b.buf = pgio.AppendUint32(b.buf, oid)
   576  	lengthPos := len(b.buf)
   577  	b.buf = pgio.AppendInt32(b.buf, -1)
   578  	fieldBuf, err := field.EncodeBinary(b.ci, b.buf)
   579  	if err != nil {
   580  		b.err = err
   581  		return
   582  	}
   583  	if fieldBuf != nil {
   584  		binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
   585  		b.buf = fieldBuf
   586  	}
   587  
   588  	b.fieldCount++
   589  }
   590  
   591  func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
   592  	if b.err != nil {
   593  		return nil, b.err
   594  	}
   595  
   596  	binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
   597  	return b.buf, nil
   598  }
   599  
   600  type CompositeTextBuilder struct {
   601  	ci         *ConnInfo
   602  	buf        []byte
   603  	startIdx   int
   604  	fieldCount uint32
   605  	err        error
   606  	fieldBuf   [32]byte
   607  }
   608  
   609  func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder {
   610  	buf = append(buf, '(') // allocate room for number of fields
   611  	return &CompositeTextBuilder{ci: ci, buf: buf}
   612  }
   613  
   614  func (b *CompositeTextBuilder) AppendValue(field interface{}) {
   615  	if b.err != nil {
   616  		return
   617  	}
   618  
   619  	if field == nil {
   620  		b.buf = append(b.buf, ',')
   621  		return
   622  	}
   623  
   624  	dt, ok := b.ci.DataTypeForValue(field)
   625  	if !ok {
   626  		b.err = fmt.Errorf("unknown data type for field: %v", field)
   627  		return
   628  	}
   629  
   630  	err := dt.Value.Set(field)
   631  	if err != nil {
   632  		b.err = err
   633  		return
   634  	}
   635  
   636  	textEncoder, ok := dt.Value.(TextEncoder)
   637  	if !ok {
   638  		b.err = fmt.Errorf("unable to encode text for value: %v", field)
   639  		return
   640  	}
   641  
   642  	b.AppendEncoder(textEncoder)
   643  }
   644  
   645  func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) {
   646  	if b.err != nil {
   647  		return
   648  	}
   649  
   650  	fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0])
   651  	if err != nil {
   652  		b.err = err
   653  		return
   654  	}
   655  	if fieldBuf != nil {
   656  		b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
   657  	}
   658  
   659  	b.buf = append(b.buf, ',')
   660  }
   661  
   662  func (b *CompositeTextBuilder) Finish() ([]byte, error) {
   663  	if b.err != nil {
   664  		return nil, b.err
   665  	}
   666  
   667  	b.buf[len(b.buf)-1] = ')'
   668  	return b.buf, nil
   669  }
   670  
   671  var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
   672  
   673  func quoteCompositeField(src string) string {
   674  	return `"` + quoteCompositeReplacer.Replace(src) + `"`
   675  }
   676  
   677  func quoteCompositeFieldIfNeeded(src string) string {
   678  	if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
   679  		return quoteCompositeField(src)
   680  	}
   681  	return src
   682  }
   683  

View as plain text