...

Source file src/github.com/jackc/pgx/v5/pgtype/composite.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  
    10  	"github.com/jackc/pgx/v5/internal/pgio"
    11  )
    12  
// CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite.
//
// The encode plans in this file call IsNull first and, when it reports false, call Index once per
// entry of CompositeCodec.Fields, in order, starting at 0.
type CompositeIndexGetter interface {
	// IsNull returns true if the value is SQL NULL.
	IsNull() bool

	// Index returns the element at i. A nil result is written by the composite builders as a
	// field with no value.
	Index(i int) any
}
    21  
// CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite.
//
// The scan plans in this file call ScanNull when the source bytes are nil; otherwise they call
// ScanIndex once per entry of CompositeCodec.Fields, in order, starting at 0.
type CompositeIndexScanner interface {
	// ScanNull sets the value to SQL NULL.
	ScanNull() error

	// ScanIndex returns a value usable as a scan target for i. Returning nil makes the scan
	// plans skip that field.
	ScanIndex(i int) any
}
    30  
// CompositeCodecField describes a single field of a composite handled by CompositeCodec.
type CompositeCodecField struct {
	// Name is the key under which DecodeValue stores the decoded field value.
	Name string
	// Type supplies the OID used to plan encoding/scanning of the field and the Codec consulted
	// by FormatSupported.
	Type *Type
}
    35  
// CompositeCodec encodes and decodes composite values whose fields are described by Fields.
type CompositeCodec struct {
	// Fields lists the composite's fields in positional order; position i corresponds to
	// CompositeIndexGetter.Index(i) and CompositeIndexScanner.ScanIndex(i).
	Fields []CompositeCodecField
}
    39  
    40  func (c *CompositeCodec) FormatSupported(format int16) bool {
    41  	for _, f := range c.Fields {
    42  		if !f.Type.Codec.FormatSupported(format) {
    43  			return false
    44  		}
    45  	}
    46  
    47  	return true
    48  }
    49  
    50  func (c *CompositeCodec) PreferredFormat() int16 {
    51  	if c.FormatSupported(BinaryFormatCode) {
    52  		return BinaryFormatCode
    53  	}
    54  	return TextFormatCode
    55  }
    56  
    57  func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
    58  	if _, ok := value.(CompositeIndexGetter); !ok {
    59  		return nil
    60  	}
    61  
    62  	switch format {
    63  	case BinaryFormatCode:
    64  		return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m}
    65  	case TextFormatCode:
    66  		return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m}
    67  	}
    68  
    69  	return nil
    70  }
    71  
// encodePlanCompositeCodecCompositeIndexGetterToBinary encodes a CompositeIndexGetter in the
// binary format.
type encodePlanCompositeCodecCompositeIndexGetterToBinary struct {
	cc *CompositeCodec // supplies the ordered field list and each field's OID
	m  *Map            // passed to the builder, which uses it to plan each field's encoding
}
    76  
    77  func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
    78  	getter := value.(CompositeIndexGetter)
    79  
    80  	if getter.IsNull() {
    81  		return nil, nil
    82  	}
    83  
    84  	builder := NewCompositeBinaryBuilder(plan.m, buf)
    85  	for i, field := range plan.cc.Fields {
    86  		builder.AppendValue(field.Type.OID, getter.Index(i))
    87  	}
    88  
    89  	return builder.Finish()
    90  }
    91  
// encodePlanCompositeCodecCompositeIndexGetterToText encodes a CompositeIndexGetter in the text
// format.
type encodePlanCompositeCodecCompositeIndexGetterToText struct {
	cc *CompositeCodec // supplies the ordered field list and each field's OID
	m  *Map            // passed to the builder, which uses it to plan each field's encoding
}
    96  
    97  func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
    98  	getter := value.(CompositeIndexGetter)
    99  
   100  	if getter.IsNull() {
   101  		return nil, nil
   102  	}
   103  
   104  	b := NewCompositeTextBuilder(plan.m, buf)
   105  	for i, field := range plan.cc.Fields {
   106  		b.AppendValue(field.Type.OID, getter.Index(i))
   107  	}
   108  
   109  	return b.Finish()
   110  }
   111  
   112  func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
   113  	switch format {
   114  	case BinaryFormatCode:
   115  		switch target.(type) {
   116  		case CompositeIndexScanner:
   117  			return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m}
   118  		}
   119  	case TextFormatCode:
   120  		switch target.(type) {
   121  		case CompositeIndexScanner:
   122  			return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m}
   123  		}
   124  	}
   125  
   126  	return nil
   127  }
   128  
// scanPlanBinaryCompositeToCompositeIndexScanner scans a binary-format composite into a
// CompositeIndexScanner.
type scanPlanBinaryCompositeToCompositeIndexScanner struct {
	cc *CompositeCodec // supplies the ordered field list and each field's OID
	m  *Map            // used to plan the scan of each individual field
}
   133  
   134  func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
   135  	targetScanner := (target).(CompositeIndexScanner)
   136  
   137  	if src == nil {
   138  		return targetScanner.ScanNull()
   139  	}
   140  
   141  	scanner := NewCompositeBinaryScanner(plan.m, src)
   142  	for i, field := range plan.cc.Fields {
   143  		if scanner.Next() {
   144  			fieldTarget := targetScanner.ScanIndex(i)
   145  			if fieldTarget != nil {
   146  				fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget)
   147  				if fieldPlan == nil {
   148  					return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID)
   149  				}
   150  
   151  				err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
   152  				if err != nil {
   153  					return err
   154  				}
   155  			}
   156  		} else {
   157  			return errors.New("read past end of composite")
   158  		}
   159  	}
   160  
   161  	if err := scanner.Err(); err != nil {
   162  		return err
   163  	}
   164  
   165  	return nil
   166  }
   167  
// scanPlanTextCompositeToCompositeIndexScanner scans a text-format composite into a
// CompositeIndexScanner.
type scanPlanTextCompositeToCompositeIndexScanner struct {
	cc *CompositeCodec // supplies the ordered field list and each field's OID
	m  *Map            // used to plan the scan of each individual field
}
   172  
   173  func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error {
   174  	targetScanner := (target).(CompositeIndexScanner)
   175  
   176  	if src == nil {
   177  		return targetScanner.ScanNull()
   178  	}
   179  
   180  	scanner := NewCompositeTextScanner(plan.m, src)
   181  	for i, field := range plan.cc.Fields {
   182  		if scanner.Next() {
   183  			fieldTarget := targetScanner.ScanIndex(i)
   184  			if fieldTarget != nil {
   185  				fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget)
   186  				if fieldPlan == nil {
   187  					return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID)
   188  				}
   189  
   190  				err := fieldPlan.Scan(scanner.Bytes(), fieldTarget)
   191  				if err != nil {
   192  					return err
   193  				}
   194  			}
   195  		} else {
   196  			return errors.New("read past end of composite")
   197  		}
   198  	}
   199  
   200  	if err := scanner.Err(); err != nil {
   201  		return err
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
   208  	if src == nil {
   209  		return nil, nil
   210  	}
   211  
   212  	switch format {
   213  	case TextFormatCode:
   214  		return string(src), nil
   215  	case BinaryFormatCode:
   216  		buf := make([]byte, len(src))
   217  		copy(buf, src)
   218  		return buf, nil
   219  	default:
   220  		return nil, fmt.Errorf("unknown format code %d", format)
   221  	}
   222  }
   223  
   224  func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
   225  	if src == nil {
   226  		return nil, nil
   227  	}
   228  
   229  	switch format {
   230  	case TextFormatCode:
   231  		scanner := NewCompositeTextScanner(m, src)
   232  		values := make(map[string]any, len(c.Fields))
   233  		for i := 0; scanner.Next() && i < len(c.Fields); i++ {
   234  			var v any
   235  			fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v)
   236  			if fieldPlan == nil {
   237  				return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v)
   238  			}
   239  
   240  			err := fieldPlan.Scan(scanner.Bytes(), &v)
   241  			if err != nil {
   242  				return nil, err
   243  			}
   244  
   245  			values[c.Fields[i].Name] = v
   246  		}
   247  
   248  		if err := scanner.Err(); err != nil {
   249  			return nil, err
   250  		}
   251  
   252  		return values, nil
   253  	case BinaryFormatCode:
   254  		scanner := NewCompositeBinaryScanner(m, src)
   255  		values := make(map[string]any, len(c.Fields))
   256  		for i := 0; scanner.Next() && i < len(c.Fields); i++ {
   257  			var v any
   258  			fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v)
   259  			if fieldPlan == nil {
   260  				return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v)
   261  			}
   262  
   263  			err := fieldPlan.Scan(scanner.Bytes(), &v)
   264  			if err != nil {
   265  				return nil, err
   266  			}
   267  
   268  			values[c.Fields[i].Name] = v
   269  		}
   270  
   271  		if err := scanner.Err(); err != nil {
   272  			return nil, err
   273  		}
   274  
   275  		return values, nil
   276  	default:
   277  		return nil, fmt.Errorf("unknown format code %d", format)
   278  	}
   279  
   280  }
   281  
// CompositeBinaryScanner iterates over the fields of a binary encoded composite value. Create
// one with NewCompositeBinaryScanner, call Next to advance, and read the current field with
// Bytes and OID.
type CompositeBinaryScanner struct {
	m   *Map
	rp  int    // read position: index into src of the next unread byte
	src []byte // the complete binary encoded composite

	fieldCount int32  // field count read from the first four bytes of src
	fieldBytes []byte // bytes of the field most recently read by Next; nil when its length was negative
	fieldOID   uint32 // OID of the field most recently read by Next
	err        error  // first error encountered; once set, Next returns false
}
   292  
   293  // NewCompositeBinaryScanner a scanner over a binary encoded composite balue.
   294  func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner {
   295  	rp := 0
   296  	if len(src[rp:]) < 4 {
   297  		return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)}
   298  	}
   299  
   300  	fieldCount := int32(binary.BigEndian.Uint32(src[rp:]))
   301  	rp += 4
   302  
   303  	return &CompositeBinaryScanner{
   304  		m:          m,
   305  		rp:         rp,
   306  		src:        src,
   307  		fieldCount: fieldCount,
   308  	}
   309  }
   310  
   311  // Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
   312  // Next returns false, the Err method can be called to check if any errors occurred.
   313  func (cfs *CompositeBinaryScanner) Next() bool {
   314  	if cfs.err != nil {
   315  		return false
   316  	}
   317  
   318  	if cfs.rp == len(cfs.src) {
   319  		return false
   320  	}
   321  
   322  	if len(cfs.src[cfs.rp:]) < 8 {
   323  		cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
   324  		return false
   325  	}
   326  	cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:])
   327  	cfs.rp += 4
   328  
   329  	fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:])))
   330  	cfs.rp += 4
   331  
   332  	if fieldLen >= 0 {
   333  		if len(cfs.src[cfs.rp:]) < fieldLen {
   334  			cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
   335  			return false
   336  		}
   337  		cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen]
   338  		cfs.rp += fieldLen
   339  	} else {
   340  		cfs.fieldBytes = nil
   341  	}
   342  
   343  	return true
   344  }
   345  
// FieldCount returns the field count that NewCompositeBinaryScanner read from the first four
// bytes of the source. It is zero if the scanner was created from a source that was too short.
func (cfs *CompositeBinaryScanner) FieldCount() int {
	return int(cfs.fieldCount)
}
   349  
// Bytes returns the bytes of the field most recently read by Next. The result is a sub-slice of
// the scanner's source, or nil when the field's encoded length was negative.
func (cfs *CompositeBinaryScanner) Bytes() []byte {
	return cfs.fieldBytes
}
   354  
// OID returns the OID of the field most recently read by Next.
func (cfs *CompositeBinaryScanner) OID() uint32 {
	return cfs.fieldOID
}
   359  
// Err returns any error encountered by the scanner, either while it was constructed or during a
// call to Next. It returns nil if no error has occurred.
func (cfs *CompositeBinaryScanner) Err() error {
	return cfs.err
}
   364  
// CompositeTextScanner iterates over the fields of a text encoded composite value. Create one
// with NewCompositeTextScanner, call Next to advance, and read the current field with Bytes.
type CompositeTextScanner struct {
	m   *Map
	rp  int    // read position: index into src of the next unread byte
	src []byte // the complete text encoded composite, including the surrounding parentheses

	fieldBytes []byte // unescaped bytes of the field most recently read by Next; nil for an empty unquoted field
	err        error  // first error encountered; once set, Next returns false
}
   373  
   374  // NewCompositeTextScanner a scanner over a text encoded composite value.
   375  func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner {
   376  	if len(src) < 2 {
   377  		return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
   378  	}
   379  
   380  	if src[0] != '(' {
   381  		return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
   382  	}
   383  
   384  	if src[len(src)-1] != ')' {
   385  		return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
   386  	}
   387  
   388  	return &CompositeTextScanner{
   389  		m:   m,
   390  		rp:  1,
   391  		src: src,
   392  	}
   393  }
   394  
   395  // Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
   396  // Next returns false, the Err method can be called to check if any errors occurred.
   397  func (cfs *CompositeTextScanner) Next() bool {
   398  	if cfs.err != nil {
   399  		return false
   400  	}
   401  
   402  	if cfs.rp == len(cfs.src) {
   403  		return false
   404  	}
   405  
   406  	switch cfs.src[cfs.rp] {
   407  	case ',', ')': // null
   408  		cfs.rp++
   409  		cfs.fieldBytes = nil
   410  		return true
   411  	case '"': // quoted value
   412  		cfs.rp++
   413  		cfs.fieldBytes = make([]byte, 0, 16)
   414  		for {
   415  			ch := cfs.src[cfs.rp]
   416  
   417  			if ch == '"' {
   418  				cfs.rp++
   419  				if cfs.src[cfs.rp] == '"' {
   420  					cfs.fieldBytes = append(cfs.fieldBytes, '"')
   421  					cfs.rp++
   422  				} else {
   423  					break
   424  				}
   425  			} else if ch == '\\' {
   426  				cfs.rp++
   427  				cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
   428  				cfs.rp++
   429  			} else {
   430  				cfs.fieldBytes = append(cfs.fieldBytes, ch)
   431  				cfs.rp++
   432  			}
   433  		}
   434  		cfs.rp++
   435  		return true
   436  	default: // unquoted value
   437  		start := cfs.rp
   438  		for {
   439  			ch := cfs.src[cfs.rp]
   440  			if ch == ',' || ch == ')' {
   441  				break
   442  			}
   443  			cfs.rp++
   444  		}
   445  		cfs.fieldBytes = cfs.src[start:cfs.rp]
   446  		cfs.rp++
   447  		return true
   448  	}
   449  }
   450  
// Bytes returns the bytes of the field most recently read by Next. For a quoted field this is a
// freshly built slice with the quoting removed; for an unquoted field it is a sub-slice of the
// scanner's source; for an empty unquoted field it is nil.
func (cfs *CompositeTextScanner) Bytes() []byte {
	return cfs.fieldBytes
}
   455  
// Err returns any error encountered by the scanner. It returns nil if no error has occurred.
func (cfs *CompositeTextScanner) Err() error {
	return cfs.err
}
   460  
// CompositeBinaryBuilder incrementally builds the binary encoding of a composite value. Create
// one with NewCompositeBinaryBuilder, add fields in order with AppendValue, and obtain the
// result with Finish.
type CompositeBinaryBuilder struct {
	m          *Map   // used to plan the encoding of each appended field
	buf        []byte // output buffer, including any bytes supplied by the caller
	startIdx   int    // offset in buf of the 4-byte field count placeholder
	fieldCount uint32 // number of fields appended so far
	err        error  // first error encountered; once set, AppendValue is a no-op
}
   468  
   469  func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder {
   470  	startIdx := len(buf)
   471  	buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields
   472  	return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx}
   473  }
   474  
   475  func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) {
   476  	if b.err != nil {
   477  		return
   478  	}
   479  
   480  	if field == nil {
   481  		b.buf = pgio.AppendUint32(b.buf, oid)
   482  		b.buf = pgio.AppendInt32(b.buf, -1)
   483  		b.fieldCount++
   484  		return
   485  	}
   486  
   487  	plan := b.m.PlanEncode(oid, BinaryFormatCode, field)
   488  	if plan == nil {
   489  		b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid)
   490  		return
   491  	}
   492  
   493  	b.buf = pgio.AppendUint32(b.buf, oid)
   494  	lengthPos := len(b.buf)
   495  	b.buf = pgio.AppendInt32(b.buf, -1)
   496  	fieldBuf, err := plan.Encode(field, b.buf)
   497  	if err != nil {
   498  		b.err = err
   499  		return
   500  	}
   501  	if fieldBuf != nil {
   502  		binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf)))
   503  		b.buf = fieldBuf
   504  	}
   505  
   506  	b.fieldCount++
   507  }
   508  
   509  func (b *CompositeBinaryBuilder) Finish() ([]byte, error) {
   510  	if b.err != nil {
   511  		return nil, b.err
   512  	}
   513  
   514  	binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount)
   515  	return b.buf, nil
   516  }
   517  
// CompositeTextBuilder incrementally builds the text encoding of a composite value. Create one
// with NewCompositeTextBuilder, add fields in order with AppendValue, and obtain the result with
// Finish.
type CompositeTextBuilder struct {
	m          *Map     // used to plan the encoding of each appended field
	buf        []byte   // output buffer, including any bytes supplied by the caller
	startIdx   int      // NOTE(review): not referenced by the CompositeTextBuilder code visible in this file
	fieldCount uint32   // NOTE(review): not referenced by the CompositeTextBuilder code visible in this file
	err        error    // first error encountered; once set, AppendValue is a no-op
	fieldBuf   [32]byte // scratch space handed to field encoders as their initial buffer
}
   526  
   527  func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder {
   528  	buf = append(buf, '(') // allocate room for number of fields
   529  	return &CompositeTextBuilder{m: m, buf: buf}
   530  }
   531  
   532  func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) {
   533  	if b.err != nil {
   534  		return
   535  	}
   536  
   537  	if field == nil {
   538  		b.buf = append(b.buf, ',')
   539  		return
   540  	}
   541  
   542  	plan := b.m.PlanEncode(oid, TextFormatCode, field)
   543  	if plan == nil {
   544  		b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid)
   545  		return
   546  	}
   547  
   548  	fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0])
   549  	if err != nil {
   550  		b.err = err
   551  		return
   552  	}
   553  	if fieldBuf != nil {
   554  		b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...)
   555  	}
   556  
   557  	b.buf = append(b.buf, ',')
   558  }
   559  
   560  func (b *CompositeTextBuilder) Finish() ([]byte, error) {
   561  	if b.err != nil {
   562  		return nil, b.err
   563  	}
   564  
   565  	b.buf[len(b.buf)-1] = ')'
   566  	return b.buf, nil
   567  }
   568  
   569  var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
   570  
   571  func quoteCompositeField(src string) string {
   572  	return `"` + quoteCompositeReplacer.Replace(src) + `"`
   573  }
   574  
   575  func quoteCompositeFieldIfNeeded(src string) string {
   576  	if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) {
   577  		return quoteCompositeField(src)
   578  	}
   579  	return src
   580  }
   581  
   582  // CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target.
   583  // It cannot scan a NULL, but the composite fields can be NULL.
   584  type CompositeFields []any
   585  
   586  func (cf CompositeFields) SkipUnderlyingTypePlan() {}
   587  
   588  func (cf CompositeFields) IsNull() bool {
   589  	return cf == nil
   590  }
   591  
   592  func (cf CompositeFields) Index(i int) any {
   593  	return cf[i]
   594  }
   595  
   596  func (cf CompositeFields) ScanNull() error {
   597  	return fmt.Errorf("cannot scan NULL into CompositeFields")
   598  }
   599  
   600  func (cf CompositeFields) ScanIndex(i int) any {
   601  	return cf[i]
   602  }
   603  

View as plain text