...

Source file src/github.com/jackc/pgtype/array_type.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"reflect"
     8  
     9  	"github.com/jackc/pgio"
    10  )
    11  
    12  // ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties
    13  // when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience
    14  // type for types that do not have a concrete array type.
    15  type ArrayType struct {
    16  	elements   []ValueTranscoder
    17  	dimensions []ArrayDimension
    18  
    19  	typeName   string
    20  	newElement func() ValueTranscoder
    21  
    22  	elementOID uint32
    23  	status     Status
    24  }
    25  
    26  func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType {
    27  	return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement}
    28  }
    29  
    30  func (at *ArrayType) NewTypeValue() Value {
    31  	return &ArrayType{
    32  		elements:   at.elements,
    33  		dimensions: at.dimensions,
    34  		status:     at.status,
    35  
    36  		typeName:   at.typeName,
    37  		elementOID: at.elementOID,
    38  		newElement: at.newElement,
    39  	}
    40  }
    41  
    42  func (at *ArrayType) TypeName() string {
    43  	return at.typeName
    44  }
    45  
    46  func (dst *ArrayType) setNil() {
    47  	dst.elements = nil
    48  	dst.dimensions = nil
    49  	dst.status = Null
    50  }
    51  
    52  func (dst *ArrayType) Set(src interface{}) error {
    53  	// untyped nil and typed nil interfaces are different
    54  	if src == nil {
    55  		dst.setNil()
    56  		return nil
    57  	}
    58  
    59  	sliceVal := reflect.ValueOf(src)
    60  	if sliceVal.Kind() != reflect.Slice {
    61  		return fmt.Errorf("cannot set non-slice")
    62  	}
    63  
    64  	if sliceVal.IsNil() {
    65  		dst.setNil()
    66  		return nil
    67  	}
    68  
    69  	dst.elements = make([]ValueTranscoder, sliceVal.Len())
    70  	for i := range dst.elements {
    71  		v := dst.newElement()
    72  		err := v.Set(sliceVal.Index(i).Interface())
    73  		if err != nil {
    74  			return err
    75  		}
    76  
    77  		dst.elements[i] = v
    78  	}
    79  	dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}}
    80  	dst.status = Present
    81  
    82  	return nil
    83  }
    84  
    85  func (dst ArrayType) Get() interface{} {
    86  	switch dst.status {
    87  	case Present:
    88  		elementValues := make([]interface{}, len(dst.elements))
    89  		for i := range dst.elements {
    90  			elementValues[i] = dst.elements[i].Get()
    91  		}
    92  		return elementValues
    93  	case Null:
    94  		return nil
    95  	default:
    96  		return dst.status
    97  	}
    98  }
    99  
   100  func (src *ArrayType) AssignTo(dst interface{}) error {
   101  	ptrSlice := reflect.ValueOf(dst)
   102  	if ptrSlice.Kind() != reflect.Ptr {
   103  		return fmt.Errorf("cannot assign to non-pointer")
   104  	}
   105  
   106  	sliceVal := ptrSlice.Elem()
   107  	sliceType := sliceVal.Type()
   108  
   109  	if sliceType.Kind() != reflect.Slice {
   110  		return fmt.Errorf("cannot assign to pointer to non-slice")
   111  	}
   112  
   113  	switch src.status {
   114  	case Present:
   115  		slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements))
   116  		elemType := sliceType.Elem()
   117  
   118  		for i := range src.elements {
   119  			ptrElem := reflect.New(elemType)
   120  			err := src.elements[i].AssignTo(ptrElem.Interface())
   121  			if err != nil {
   122  				return err
   123  			}
   124  
   125  			slice.Index(i).Set(ptrElem.Elem())
   126  		}
   127  
   128  		sliceVal.Set(slice)
   129  		return nil
   130  	case Null:
   131  		sliceVal.Set(reflect.Zero(sliceType))
   132  		return nil
   133  	}
   134  
   135  	return fmt.Errorf("cannot decode %#v into %T", src, dst)
   136  }
   137  
   138  func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
   139  	if src == nil {
   140  		dst.setNil()
   141  		return nil
   142  	}
   143  
   144  	uta, err := ParseUntypedTextArray(string(src))
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	var elements []ValueTranscoder
   150  
   151  	if len(uta.Elements) > 0 {
   152  		elements = make([]ValueTranscoder, len(uta.Elements))
   153  
   154  		for i, s := range uta.Elements {
   155  			elem := dst.newElement()
   156  			var elemSrc []byte
   157  			if s != "NULL" {
   158  				elemSrc = []byte(s)
   159  			}
   160  			err = elem.DecodeText(ci, elemSrc)
   161  			if err != nil {
   162  				return err
   163  			}
   164  
   165  			elements[i] = elem
   166  		}
   167  	}
   168  
   169  	dst.elements = elements
   170  	dst.dimensions = uta.Dimensions
   171  	dst.status = Present
   172  
   173  	return nil
   174  }
   175  
   176  func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
   177  	if src == nil {
   178  		dst.setNil()
   179  		return nil
   180  	}
   181  
   182  	var arrayHeader ArrayHeader
   183  	rp, err := arrayHeader.DecodeBinary(ci, src)
   184  	if err != nil {
   185  		return err
   186  	}
   187  
   188  	var elements []ValueTranscoder
   189  
   190  	if len(arrayHeader.Dimensions) == 0 {
   191  		dst.elements = elements
   192  		dst.dimensions = arrayHeader.Dimensions
   193  		dst.status = Present
   194  		return nil
   195  	}
   196  
   197  	elementCount := arrayHeader.Dimensions[0].Length
   198  	for _, d := range arrayHeader.Dimensions[1:] {
   199  		elementCount *= d.Length
   200  	}
   201  
   202  	elements = make([]ValueTranscoder, elementCount)
   203  
   204  	for i := range elements {
   205  		elem := dst.newElement()
   206  		elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
   207  		rp += 4
   208  		var elemSrc []byte
   209  		if elemLen >= 0 {
   210  			elemSrc = src[rp : rp+elemLen]
   211  			rp += elemLen
   212  		}
   213  		err = elem.DecodeBinary(ci, elemSrc)
   214  		if err != nil {
   215  			return err
   216  		}
   217  
   218  		elements[i] = elem
   219  	}
   220  
   221  	dst.elements = elements
   222  	dst.dimensions = arrayHeader.Dimensions
   223  	dst.status = Present
   224  
   225  	return nil
   226  }
   227  
   228  func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
   229  	switch src.status {
   230  	case Null:
   231  		return nil, nil
   232  	case Undefined:
   233  		return nil, errUndefined
   234  	}
   235  
   236  	if len(src.dimensions) == 0 {
   237  		return append(buf, '{', '}'), nil
   238  	}
   239  
   240  	buf = EncodeTextArrayDimensions(buf, src.dimensions)
   241  
   242  	// dimElemCounts is the multiples of elements that each array lies on. For
   243  	// example, a single dimension array of length 4 would have a dimElemCounts of
   244  	// [4]. A multi-dimensional array of lengths [3,5,2] would have a
   245  	// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
   246  	// or '}'.
   247  	dimElemCounts := make([]int, len(src.dimensions))
   248  	dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length)
   249  	for i := len(src.dimensions) - 2; i > -1; i-- {
   250  		dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1]
   251  	}
   252  
   253  	inElemBuf := make([]byte, 0, 32)
   254  	for i, elem := range src.elements {
   255  		if i > 0 {
   256  			buf = append(buf, ',')
   257  		}
   258  
   259  		for _, dec := range dimElemCounts {
   260  			if i%dec == 0 {
   261  				buf = append(buf, '{')
   262  			}
   263  		}
   264  
   265  		elemBuf, err := elem.EncodeText(ci, inElemBuf)
   266  		if err != nil {
   267  			return nil, err
   268  		}
   269  		if elemBuf == nil {
   270  			buf = append(buf, `NULL`...)
   271  		} else {
   272  			buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
   273  		}
   274  
   275  		for _, dec := range dimElemCounts {
   276  			if (i+1)%dec == 0 {
   277  				buf = append(buf, '}')
   278  			}
   279  		}
   280  	}
   281  
   282  	return buf, nil
   283  }
   284  
   285  func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
   286  	switch src.status {
   287  	case Null:
   288  		return nil, nil
   289  	case Undefined:
   290  		return nil, errUndefined
   291  	}
   292  
   293  	arrayHeader := ArrayHeader{
   294  		Dimensions: src.dimensions,
   295  		ElementOID: int32(src.elementOID),
   296  	}
   297  
   298  	for i := range src.elements {
   299  		if src.elements[i].Get() == nil {
   300  			arrayHeader.ContainsNull = true
   301  			break
   302  		}
   303  	}
   304  
   305  	buf = arrayHeader.EncodeBinary(ci, buf)
   306  
   307  	for i := range src.elements {
   308  		sp := len(buf)
   309  		buf = pgio.AppendInt32(buf, -1)
   310  
   311  		elemBuf, err := src.elements[i].EncodeBinary(ci, buf)
   312  		if err != nil {
   313  			return nil, err
   314  		}
   315  		if elemBuf != nil {
   316  			buf = elemBuf
   317  			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
   318  		}
   319  	}
   320  
   321  	return buf, nil
   322  }
   323  
   324  // Scan implements the database/sql Scanner interface.
   325  func (dst *ArrayType) Scan(src interface{}) error {
   326  	if src == nil {
   327  		return dst.DecodeText(nil, nil)
   328  	}
   329  
   330  	switch src := src.(type) {
   331  	case string:
   332  		return dst.DecodeText(nil, []byte(src))
   333  	case []byte:
   334  		srcCopy := make([]byte, len(src))
   335  		copy(srcCopy, src)
   336  		return dst.DecodeText(nil, srcCopy)
   337  	}
   338  
   339  	return fmt.Errorf("cannot scan %T", src)
   340  }
   341  
   342  // Value implements the database/sql/driver Valuer interface.
   343  func (src ArrayType) Value() (driver.Value, error) {
   344  	buf, err := src.EncodeText(nil, nil)
   345  	if err != nil {
   346  		return nil, err
   347  	}
   348  	if buf == nil {
   349  		return nil, nil
   350  	}
   351  
   352  	return string(buf), nil
   353  }
   354  

View as plain text