...

Source file src/github.com/jackc/pgtype/pgtype.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"net"
    10  	"reflect"
    11  	"time"
    12  )
    13  
    14  // PostgreSQL oids for common types
    15  const (
    16  	BoolOID             = 16
    17  	ByteaOID            = 17
    18  	QCharOID            = 18
    19  	NameOID             = 19
    20  	Int8OID             = 20
    21  	Int2OID             = 21
    22  	Int4OID             = 23
    23  	TextOID             = 25
    24  	OIDOID              = 26
    25  	TIDOID              = 27
    26  	XIDOID              = 28
    27  	CIDOID              = 29
    28  	JSONOID             = 114
    29  	JSONArrayOID        = 199
    30  	PointOID            = 600
    31  	LsegOID             = 601
    32  	PathOID             = 602
    33  	BoxOID              = 603
    34  	PolygonOID          = 604
    35  	LineOID             = 628
    36  	CIDROID             = 650
    37  	CIDRArrayOID        = 651
    38  	Float4OID           = 700
    39  	Float8OID           = 701
    40  	CircleOID           = 718
    41  	UnknownOID          = 705
    42  	MacaddrOID          = 829
    43  	InetOID             = 869
    44  	BoolArrayOID        = 1000
    45  	Int2ArrayOID        = 1005
    46  	Int4ArrayOID        = 1007
    47  	TextArrayOID        = 1009
    48  	ByteaArrayOID       = 1001
    49  	BPCharArrayOID      = 1014
    50  	VarcharArrayOID     = 1015
    51  	Int8ArrayOID        = 1016
    52  	Float4ArrayOID      = 1021
    53  	Float8ArrayOID      = 1022
    54  	ACLItemOID          = 1033
    55  	ACLItemArrayOID     = 1034
    56  	InetArrayOID        = 1041
    57  	BPCharOID           = 1042
    58  	VarcharOID          = 1043
    59  	DateOID             = 1082
    60  	TimeOID             = 1083
    61  	TimestampOID        = 1114
    62  	TimestampArrayOID   = 1115
    63  	DateArrayOID        = 1182
    64  	TimestamptzOID      = 1184
    65  	TimestamptzArrayOID = 1185
    66  	IntervalOID         = 1186
    67  	NumericArrayOID     = 1231
    68  	BitOID              = 1560
    69  	VarbitOID           = 1562
    70  	NumericOID          = 1700
    71  	RecordOID           = 2249
    72  	UUIDOID             = 2950
    73  	UUIDArrayOID        = 2951
    74  	JSONBOID            = 3802
    75  	JSONBArrayOID       = 3807
    76  	DaterangeOID        = 3912
    77  	Int4rangeOID        = 3904
    78  	Int4multirangeOID   = 4451
    79  	NumrangeOID         = 3906
    80  	NummultirangeOID    = 4532
    81  	TsrangeOID          = 3908
    82  	TsrangeArrayOID     = 3909
    83  	TstzrangeOID        = 3910
    84  	TstzrangeArrayOID   = 3911
    85  	Int8rangeOID        = 3926
    86  	Int8multirangeOID   = 4536
    87  )
    88  
    89  type Status byte
    90  
    91  const (
    92  	Undefined Status = iota
    93  	Null
    94  	Present
    95  )
    96  
    97  type InfinityModifier int8
    98  
    99  const (
   100  	Infinity         InfinityModifier = 1
   101  	None             InfinityModifier = 0
   102  	NegativeInfinity InfinityModifier = -Infinity
   103  )
   104  
   105  func (im InfinityModifier) String() string {
   106  	switch im {
   107  	case None:
   108  		return "none"
   109  	case Infinity:
   110  		return "infinity"
   111  	case NegativeInfinity:
   112  		return "-infinity"
   113  	default:
   114  		return "invalid"
   115  	}
   116  }
   117  
   118  // PostgreSQL format codes
   119  const (
   120  	TextFormatCode   = 0
   121  	BinaryFormatCode = 1
   122  )
   123  
   124  // Value translates values to and from an internal canonical representation for the type. To actually be usable a type
   125  // that implements Value should also implement some combination of BinaryDecoder, BinaryEncoder, TextDecoder,
   126  // and TextEncoder.
   127  //
   128  // Operations that update a Value (e.g. Set, DecodeText, DecodeBinary) should entirely replace the value. e.g. Internal
   129  // slices should be replaced not resized and reused. This allows Get and AssignTo to return a slice directly rather
   130  // than incur a usually unnecessary copy.
   131  type Value interface {
   132  	// Set converts and assigns src to itself. Value takes ownership of src.
   133  	Set(src interface{}) error
   134  
   135  	// Get returns the simplest representation of Value. Get may return a pointer to an internal value but it must never
   136  	// mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte.
   137  	Get() interface{}
   138  
   139  	// AssignTo converts and assigns the Value to dst. AssignTo may a pointer to an internal value but it must never
   140  	// mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte.
   141  	AssignTo(dst interface{}) error
   142  }
   143  
   144  // TypeValue is a Value where instances can represent different PostgreSQL types. This can be useful for
   145  // representing types such as enums, composites, and arrays.
   146  //
   147  // In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an
   148  // encoder and decoder internal to ConnInfo.
   149  type TypeValue interface {
   150  	Value
   151  
   152  	// NewTypeValue creates a TypeValue including references to internal type information. e.g. the list of members
   153  	// in an EnumType.
   154  	NewTypeValue() Value
   155  
   156  	// TypeName returns the PostgreSQL name of this type.
   157  	TypeName() string
   158  }
   159  
   160  // ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces.
   161  type ValueTranscoder interface {
   162  	Value
   163  	TextEncoder
   164  	BinaryEncoder
   165  	TextDecoder
   166  	BinaryDecoder
   167  }
   168  
   169  // ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from
   170  // whether it is also a BinaryDecoder.
   171  type ResultFormatPreferrer interface {
   172  	PreferredResultFormat() int16
   173  }
   174  
   175  // ParamFormatPreferrer allows a type to specify its preferred param format instead of it being inferred from
   176  // whether it is also a BinaryEncoder.
   177  type ParamFormatPreferrer interface {
   178  	PreferredParamFormat() int16
   179  }
   180  
   181  type BinaryDecoder interface {
   182  	// DecodeBinary decodes src into BinaryDecoder. If src is nil then the
   183  	// original SQL value is NULL. BinaryDecoder takes ownership of src. The
   184  	// caller MUST not use it again.
   185  	DecodeBinary(ci *ConnInfo, src []byte) error
   186  }
   187  
   188  type TextDecoder interface {
   189  	// DecodeText decodes src into TextDecoder. If src is nil then the original
   190  	// SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not
   191  	// use it again.
   192  	DecodeText(ci *ConnInfo, src []byte) error
   193  }
   194  
   195  // BinaryEncoder is implemented by types that can encode themselves into the
   196  // PostgreSQL binary wire format.
   197  type BinaryEncoder interface {
   198  	// EncodeBinary should append the binary format of self to buf. If self is the
   199  	// SQL value NULL then append nothing and return (nil, nil). The caller of
   200  	// EncodeBinary is responsible for writing the correct NULL value or the
   201  	// length of the data written.
   202  	EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error)
   203  }
   204  
   205  // TextEncoder is implemented by types that can encode themselves into the
   206  // PostgreSQL text wire format.
   207  type TextEncoder interface {
   208  	// EncodeText should append the text format of self to buf. If self is the
   209  	// SQL value NULL then append nothing and return (nil, nil). The caller of
   210  	// EncodeText is responsible for writing the correct NULL value or the
   211  	// length of the data written.
   212  	EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error)
   213  }
   214  
   215  var errUndefined = errors.New("cannot encode status undefined")
   216  var errBadStatus = errors.New("invalid status")
   217  
   218  type nullAssignmentError struct {
   219  	dst interface{}
   220  }
   221  
   222  func (e *nullAssignmentError) Error() string {
   223  	return fmt.Sprintf("cannot assign NULL to %T", e.dst)
   224  }
   225  
   226  type DataType struct {
   227  	Value Value
   228  
   229  	textDecoder   TextDecoder
   230  	binaryDecoder BinaryDecoder
   231  
   232  	Name string
   233  	OID  uint32
   234  }
   235  
   236  type ConnInfo struct {
   237  	oidToDataType         map[uint32]*DataType
   238  	nameToDataType        map[string]*DataType
   239  	reflectTypeToName     map[reflect.Type]string
   240  	oidToParamFormatCode  map[uint32]int16
   241  	oidToResultFormatCode map[uint32]int16
   242  
   243  	reflectTypeToDataType map[reflect.Type]*DataType
   244  }
   245  
   246  func newConnInfo() *ConnInfo {
   247  	return &ConnInfo{
   248  		oidToDataType:         make(map[uint32]*DataType),
   249  		nameToDataType:        make(map[string]*DataType),
   250  		reflectTypeToName:     make(map[reflect.Type]string),
   251  		oidToParamFormatCode:  make(map[uint32]int16),
   252  		oidToResultFormatCode: make(map[uint32]int16),
   253  	}
   254  }
   255  
   256  func NewConnInfo() *ConnInfo {
   257  	ci := newConnInfo()
   258  
   259  	ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID})
   260  	ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID})
   261  	ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID})
   262  	ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID})
   263  	ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID})
   264  	ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID})
   265  	ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID})
   266  	ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID})
   267  	ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID})
   268  	ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID})
   269  	ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID})
   270  	ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID})
   271  	ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID})
   272  	ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID})
   273  	ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID})
   274  	ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID})
   275  	ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID})
   276  	ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID})
   277  	ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID})
   278  	ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID})
   279  	ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID})
   280  	ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID})
   281  	ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID})
   282  	ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID})
   283  	ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID})
   284  	ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID})
   285  	ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID})
   286  	ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID})
   287  	ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID})
   288  	ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID})
   289  	ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID})
   290  	ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID})
   291  	ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID})
   292  	ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID})
   293  	ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
   294  	ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
   295  	ci.RegisterDataType(DataType{Value: &Int4multirange{}, Name: "int4multirange", OID: Int4multirangeOID})
   296  	ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
   297  	ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
   298  	ci.RegisterDataType(DataType{Value: &Int8multirange{}, Name: "int8multirange", OID: Int8multirangeOID})
   299  	ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID})
   300  	ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID})
   301  	ci.RegisterDataType(DataType{Value: &JSONArray{}, Name: "_json", OID: JSONArrayOID})
   302  	ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID})
   303  	ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID})
   304  	ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID})
   305  	ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID})
   306  	ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID})
   307  	ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID})
   308  	ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID})
   309  	ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
   310  	ci.RegisterDataType(DataType{Value: &Nummultirange{}, Name: "nummultirange", OID: NummultirangeOID})
   311  	ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID})
   312  	ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID})
   313  	ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID})
   314  	ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID})
   315  	ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID})
   316  	ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID})
   317  	ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID})
   318  	ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID})
   319  	ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID})
   320  	ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID})
   321  	ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID})
   322  	ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID})
   323  	ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID})
   324  	ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID})
   325  	ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID})
   326  	ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID})
   327  	ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID})
   328  	ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID})
   329  	ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID})
   330  
   331  	registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) {
   332  		ci.RegisterDefaultPgType(value, name)
   333  		valueType := reflect.TypeOf(value)
   334  
   335  		ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
   336  
   337  		sliceType := reflect.SliceOf(valueType)
   338  		ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
   339  
   340  		ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
   341  	}
   342  
   343  	// Integer types that directly map to a PostgreSQL type
   344  	registerDefaultPgTypeVariants("int2", "_int2", int16(0))
   345  	registerDefaultPgTypeVariants("int4", "_int4", int32(0))
   346  	registerDefaultPgTypeVariants("int8", "_int8", int64(0))
   347  
   348  	// Integer types that do not have a direct match to a PostgreSQL type
   349  	registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
   350  	registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
   351  	registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
   352  	registerDefaultPgTypeVariants("int8", "_int8", int(0))
   353  	registerDefaultPgTypeVariants("int8", "_int8", uint(0))
   354  
   355  	registerDefaultPgTypeVariants("float4", "_float4", float32(0))
   356  	registerDefaultPgTypeVariants("float8", "_float8", float64(0))
   357  
   358  	registerDefaultPgTypeVariants("bool", "_bool", false)
   359  	registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
   360  	registerDefaultPgTypeVariants("text", "_text", "")
   361  	registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
   362  
   363  	registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
   364  	ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr")
   365  	ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr")
   366  
   367  	return ci
   368  }
   369  
   370  func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) {
   371  	for name, oid := range nameOIDs {
   372  		var value Value
   373  		if t, ok := nameValues[name]; ok {
   374  			value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value)
   375  		} else {
   376  			value = &GenericText{}
   377  		}
   378  		ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid})
   379  	}
   380  }
   381  
   382  func (ci *ConnInfo) RegisterDataType(t DataType) {
   383  	t.Value = NewValue(t.Value)
   384  
   385  	ci.oidToDataType[t.OID] = &t
   386  	ci.nameToDataType[t.Name] = &t
   387  
   388  	{
   389  		var formatCode int16
   390  		if pfp, ok := t.Value.(ParamFormatPreferrer); ok {
   391  			formatCode = pfp.PreferredParamFormat()
   392  		} else if _, ok := t.Value.(BinaryEncoder); ok {
   393  			formatCode = BinaryFormatCode
   394  		}
   395  		ci.oidToParamFormatCode[t.OID] = formatCode
   396  	}
   397  
   398  	{
   399  		var formatCode int16
   400  		if rfp, ok := t.Value.(ResultFormatPreferrer); ok {
   401  			formatCode = rfp.PreferredResultFormat()
   402  		} else if _, ok := t.Value.(BinaryDecoder); ok {
   403  			formatCode = BinaryFormatCode
   404  		}
   405  		ci.oidToResultFormatCode[t.OID] = formatCode
   406  	}
   407  
   408  	if d, ok := t.Value.(TextDecoder); ok {
   409  		t.textDecoder = d
   410  	}
   411  
   412  	if d, ok := t.Value.(BinaryDecoder); ok {
   413  		t.binaryDecoder = d
   414  	}
   415  
   416  	ci.reflectTypeToDataType = nil // Invalidated by type registration
   417  }
   418  
   419  // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be
   420  // encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is
   421  // unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type.
   422  func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) {
   423  	ci.reflectTypeToName[reflect.TypeOf(value)] = name
   424  	ci.reflectTypeToDataType = nil // Invalidated by registering a default type
   425  }
   426  
   427  func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) {
   428  	dt, ok := ci.oidToDataType[oid]
   429  	return dt, ok
   430  }
   431  
   432  func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) {
   433  	dt, ok := ci.nameToDataType[name]
   434  	return dt, ok
   435  }
   436  
   437  func (ci *ConnInfo) buildReflectTypeToDataType() {
   438  	ci.reflectTypeToDataType = make(map[reflect.Type]*DataType)
   439  
   440  	for _, dt := range ci.oidToDataType {
   441  		if _, is := dt.Value.(TypeValue); !is {
   442  			ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt
   443  		}
   444  	}
   445  
   446  	for reflectType, name := range ci.reflectTypeToName {
   447  		if dt, ok := ci.nameToDataType[name]; ok {
   448  			ci.reflectTypeToDataType[reflectType] = dt
   449  		}
   450  	}
   451  }
   452  
   453  // DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode
   454  // themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type.
   455  func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) {
   456  	if ci.reflectTypeToDataType == nil {
   457  		ci.buildReflectTypeToDataType()
   458  	}
   459  
   460  	if tv, ok := v.(TypeValue); ok {
   461  		dt, ok := ci.nameToDataType[tv.TypeName()]
   462  		return dt, ok
   463  	}
   464  
   465  	dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)]
   466  	return dt, ok
   467  }
   468  
   469  func (ci *ConnInfo) ParamFormatCodeForOID(oid uint32) int16 {
   470  	fc, ok := ci.oidToParamFormatCode[oid]
   471  	if ok {
   472  		return fc
   473  	}
   474  	return TextFormatCode
   475  }
   476  
   477  func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 {
   478  	fc, ok := ci.oidToResultFormatCode[oid]
   479  	if ok {
   480  		return fc
   481  	}
   482  	return TextFormatCode
   483  }
   484  
   485  // DeepCopy makes a deep copy of the ConnInfo.
   486  func (ci *ConnInfo) DeepCopy() *ConnInfo {
   487  	ci2 := newConnInfo()
   488  
   489  	for _, dt := range ci.oidToDataType {
   490  		ci2.RegisterDataType(DataType{
   491  			Value: NewValue(dt.Value),
   492  			Name:  dt.Name,
   493  			OID:   dt.OID,
   494  		})
   495  	}
   496  
   497  	for t, n := range ci.reflectTypeToName {
   498  		ci2.reflectTypeToName[t] = n
   499  	}
   500  
   501  	return ci2
   502  }
   503  
   504  // ScanPlan is a precompiled plan to scan into a type of destination.
   505  type ScanPlan interface {
   506  	// Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically
   507  	// replan and scan.
   508  	Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error
   509  }
   510  
   511  type scanPlanDstBinaryDecoder struct{}
   512  
   513  func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   514  	if d, ok := (dst).(BinaryDecoder); ok {
   515  		return d.DecodeBinary(ci, src)
   516  	}
   517  
   518  	newPlan := ci.PlanScan(oid, formatCode, dst)
   519  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   520  }
   521  
   522  type scanPlanDstTextDecoder struct{}
   523  
   524  func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   525  	if d, ok := (dst).(TextDecoder); ok {
   526  		return d.DecodeText(ci, src)
   527  	}
   528  
   529  	newPlan := ci.PlanScan(oid, formatCode, dst)
   530  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   531  }
   532  
   533  type scanPlanDataTypeSQLScanner DataType
   534  
   535  func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   536  	scanner, ok := dst.(sql.Scanner)
   537  	if !ok {
   538  		dv := reflect.ValueOf(dst)
   539  		if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
   540  			newPlan := ci.PlanScan(oid, formatCode, dst)
   541  			return newPlan.Scan(ci, oid, formatCode, src, dst)
   542  		}
   543  		if src == nil {
   544  			// Ensure the pointer points to a zero version of the value
   545  			dv.Elem().Set(reflect.Zero(dv.Type().Elem()))
   546  			return nil
   547  		}
   548  		dv = dv.Elem()
   549  		// If the pointer is to a nil pointer then set that before scanning
   550  		if dv.Kind() == reflect.Ptr && dv.IsNil() {
   551  			dv.Set(reflect.New(dv.Type().Elem()))
   552  		}
   553  		scanner = dv.Interface().(sql.Scanner)
   554  	}
   555  
   556  	dt := (*DataType)(plan)
   557  	var err error
   558  	switch formatCode {
   559  	case BinaryFormatCode:
   560  		err = dt.binaryDecoder.DecodeBinary(ci, src)
   561  	case TextFormatCode:
   562  		err = dt.textDecoder.DecodeText(ci, src)
   563  	}
   564  	if err != nil {
   565  		return err
   566  	}
   567  
   568  	sqlSrc, err := DatabaseSQLValue(ci, dt.Value)
   569  	if err != nil {
   570  		return err
   571  	}
   572  	return scanner.Scan(sqlSrc)
   573  }
   574  
   575  type scanPlanDataTypeAssignTo DataType
   576  
   577  func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   578  	dt := (*DataType)(plan)
   579  	var err error
   580  	switch formatCode {
   581  	case BinaryFormatCode:
   582  		err = dt.binaryDecoder.DecodeBinary(ci, src)
   583  	case TextFormatCode:
   584  		err = dt.textDecoder.DecodeText(ci, src)
   585  	}
   586  	if err != nil {
   587  		return err
   588  	}
   589  
   590  	assignToErr := dt.Value.AssignTo(dst)
   591  	if assignToErr == nil {
   592  		return nil
   593  	}
   594  
   595  	if dstPtr, ok := dst.(*interface{}); ok {
   596  		*dstPtr = dt.Value.Get()
   597  		return nil
   598  	}
   599  
   600  	// assignToErr might have failed because the type of destination has changed
   601  	newPlan := ci.PlanScan(oid, formatCode, dst)
   602  	if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType {
   603  		return newPlan.Scan(ci, oid, formatCode, src, dst)
   604  	}
   605  
   606  	return assignToErr
   607  }
   608  
   609  type scanPlanSQLScanner struct{}
   610  
   611  func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   612  	scanner, ok := dst.(sql.Scanner)
   613  	if !ok {
   614  		dv := reflect.ValueOf(dst)
   615  		if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
   616  			newPlan := ci.PlanScan(oid, formatCode, dst)
   617  			return newPlan.Scan(ci, oid, formatCode, src, dst)
   618  		}
   619  		if src == nil {
   620  			// Ensure the pointer points to a zero version of the value
   621  			dv.Elem().Set(reflect.Zero(dv.Elem().Type()))
   622  			return nil
   623  		}
   624  		dv = dv.Elem()
   625  		// If the pointer is to a nil pointer then set that before scanning
   626  		if dv.Kind() == reflect.Ptr && dv.IsNil() {
   627  			dv.Set(reflect.New(dv.Type().Elem()))
   628  		}
   629  		scanner = dv.Interface().(sql.Scanner)
   630  	}
   631  	if src == nil {
   632  		// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
   633  		// text format path would be converted to empty string.
   634  		return scanner.Scan(nil)
   635  	} else if formatCode == BinaryFormatCode {
   636  		return scanner.Scan(src)
   637  	} else {
   638  		return scanner.Scan(string(src))
   639  	}
   640  }
   641  
   642  type scanPlanReflection struct{}
   643  
   644  func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   645  	// We might be given a pointer to something that implements the decoder interface(s),
   646  	// even though the pointer itself doesn't.
   647  	refVal := reflect.ValueOf(dst)
   648  	if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr {
   649  		// If the database returned NULL, then we set dest as nil to indicate that.
   650  		if src == nil {
   651  			nilPtr := reflect.Zero(refVal.Type().Elem())
   652  			refVal.Elem().Set(nilPtr)
   653  			return nil
   654  		}
   655  
   656  		// We need to allocate an element, and set the destination to it
   657  		// Then we can retry as that element.
   658  		elemPtr := reflect.New(refVal.Type().Elem().Elem())
   659  		refVal.Elem().Set(elemPtr)
   660  
   661  		plan := ci.PlanScan(oid, formatCode, elemPtr.Interface())
   662  		return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface())
   663  	}
   664  
   665  	return scanUnknownType(oid, formatCode, src, dst)
   666  }
   667  
   668  type scanPlanBinaryInt16 struct{}
   669  
   670  func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   671  	if src == nil {
   672  		return fmt.Errorf("cannot scan null into %T", dst)
   673  	}
   674  
   675  	if len(src) != 2 {
   676  		return fmt.Errorf("invalid length for int2: %v", len(src))
   677  	}
   678  
   679  	if p, ok := (dst).(*int16); ok {
   680  		*p = int16(binary.BigEndian.Uint16(src))
   681  		return nil
   682  	}
   683  
   684  	newPlan := ci.PlanScan(oid, formatCode, dst)
   685  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   686  }
   687  
   688  type scanPlanBinaryInt32 struct{}
   689  
   690  func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   691  	if src == nil {
   692  		return fmt.Errorf("cannot scan null into %T", dst)
   693  	}
   694  
   695  	if len(src) != 4 {
   696  		return fmt.Errorf("invalid length for int4: %v", len(src))
   697  	}
   698  
   699  	if p, ok := (dst).(*int32); ok {
   700  		*p = int32(binary.BigEndian.Uint32(src))
   701  		return nil
   702  	}
   703  
   704  	newPlan := ci.PlanScan(oid, formatCode, dst)
   705  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   706  }
   707  
   708  type scanPlanBinaryInt64 struct{}
   709  
   710  func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   711  	if src == nil {
   712  		return fmt.Errorf("cannot scan null into %T", dst)
   713  	}
   714  
   715  	if len(src) != 8 {
   716  		return fmt.Errorf("invalid length for int8: %v", len(src))
   717  	}
   718  
   719  	if p, ok := (dst).(*int64); ok {
   720  		*p = int64(binary.BigEndian.Uint64(src))
   721  		return nil
   722  	}
   723  
   724  	newPlan := ci.PlanScan(oid, formatCode, dst)
   725  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   726  }
   727  
   728  type scanPlanBinaryFloat32 struct{}
   729  
   730  func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   731  	if src == nil {
   732  		return fmt.Errorf("cannot scan null into %T", dst)
   733  	}
   734  
   735  	if len(src) != 4 {
   736  		return fmt.Errorf("invalid length for int4: %v", len(src))
   737  	}
   738  
   739  	if p, ok := (dst).(*float32); ok {
   740  		n := int32(binary.BigEndian.Uint32(src))
   741  		*p = float32(math.Float32frombits(uint32(n)))
   742  		return nil
   743  	}
   744  
   745  	newPlan := ci.PlanScan(oid, formatCode, dst)
   746  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   747  }
   748  
   749  type scanPlanBinaryFloat64 struct{}
   750  
   751  func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   752  	if src == nil {
   753  		return fmt.Errorf("cannot scan null into %T", dst)
   754  	}
   755  
   756  	if len(src) != 8 {
   757  		return fmt.Errorf("invalid length for int8: %v", len(src))
   758  	}
   759  
   760  	if p, ok := (dst).(*float64); ok {
   761  		n := int64(binary.BigEndian.Uint64(src))
   762  		*p = float64(math.Float64frombits(uint64(n)))
   763  		return nil
   764  	}
   765  
   766  	newPlan := ci.PlanScan(oid, formatCode, dst)
   767  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   768  }
   769  
   770  type scanPlanBinaryBytes struct{}
   771  
   772  func (scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   773  	if p, ok := (dst).(*[]byte); ok {
   774  		*p = src
   775  		return nil
   776  	}
   777  
   778  	newPlan := ci.PlanScan(oid, formatCode, dst)
   779  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   780  }
   781  
   782  type scanPlanString struct{}
   783  
   784  func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
   785  	if src == nil {
   786  		return fmt.Errorf("cannot scan null into %T", dst)
   787  	}
   788  
   789  	if p, ok := (dst).(*string); ok {
   790  		*p = string(src)
   791  		return nil
   792  	}
   793  
   794  	newPlan := ci.PlanScan(oid, formatCode, dst)
   795  	return newPlan.Scan(ci, oid, formatCode, src, dst)
   796  }
   797  
   798  var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
   799  
   800  func isScanner(dst interface{}) bool {
   801  	if _, ok := dst.(sql.Scanner); ok {
   802  		return true
   803  	}
   804  	if t := reflect.TypeOf(dst); t != nil && t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) {
   805  		return true
   806  	}
   807  	return false
   808  }
   809  
   810  // PlanScan prepares a plan to scan a value into dst.
   811  func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
   812  	switch formatCode {
   813  	case BinaryFormatCode:
   814  		switch dst.(type) {
   815  		case *string:
   816  			switch oid {
   817  			case TextOID, VarcharOID:
   818  				return scanPlanString{}
   819  			}
   820  		case *int16:
   821  			if oid == Int2OID {
   822  				return scanPlanBinaryInt16{}
   823  			}
   824  		case *int32:
   825  			if oid == Int4OID {
   826  				return scanPlanBinaryInt32{}
   827  			}
   828  		case *int64:
   829  			if oid == Int8OID {
   830  				return scanPlanBinaryInt64{}
   831  			}
   832  		case *float32:
   833  			if oid == Float4OID {
   834  				return scanPlanBinaryFloat32{}
   835  			}
   836  		case *float64:
   837  			if oid == Float8OID {
   838  				return scanPlanBinaryFloat64{}
   839  			}
   840  		case *[]byte:
   841  			switch oid {
   842  			case ByteaOID, TextOID, VarcharOID, JSONOID:
   843  				return scanPlanBinaryBytes{}
   844  			}
   845  		case BinaryDecoder:
   846  			return scanPlanDstBinaryDecoder{}
   847  		}
   848  	case TextFormatCode:
   849  		switch dst.(type) {
   850  		case *string:
   851  			return scanPlanString{}
   852  		case *[]byte:
   853  			if oid != ByteaOID {
   854  				return scanPlanBinaryBytes{}
   855  			}
   856  		case TextDecoder:
   857  			return scanPlanDstTextDecoder{}
   858  		}
   859  	}
   860  
   861  	var dt *DataType
   862  
   863  	if oid == 0 {
   864  		if dataType, ok := ci.DataTypeForValue(dst); ok {
   865  			dt = dataType
   866  		}
   867  	} else {
   868  		if dataType, ok := ci.DataTypeForOID(oid); ok {
   869  			dt = dataType
   870  		}
   871  	}
   872  
   873  	if dt != nil {
   874  		if isScanner(dst) {
   875  			return (*scanPlanDataTypeSQLScanner)(dt)
   876  		}
   877  		return (*scanPlanDataTypeAssignTo)(dt)
   878  	}
   879  
   880  	if isScanner(dst) {
   881  		return scanPlanSQLScanner{}
   882  	}
   883  
   884  	return scanPlanReflection{}
   885  }
   886  
   887  func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error {
   888  	if dst == nil {
   889  		return nil
   890  	}
   891  
   892  	plan := ci.PlanScan(oid, formatCode, dst)
   893  	return plan.Scan(ci, oid, formatCode, src, dst)
   894  }
   895  
   896  func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error {
   897  	switch dest := dest.(type) {
   898  	case *string:
   899  		if formatCode == BinaryFormatCode {
   900  			return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest)
   901  		}
   902  		*dest = string(buf)
   903  		return nil
   904  	case *[]byte:
   905  		*dest = buf
   906  		return nil
   907  	default:
   908  		if nextDst, retry := GetAssignToDstType(dest); retry {
   909  			return scanUnknownType(oid, formatCode, buf, nextDst)
   910  		}
   911  		return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest)
   912  	}
   913  }
   914  
   915  // NewValue returns a new instance of the same type as v.
   916  func NewValue(v Value) Value {
   917  	if tv, ok := v.(TypeValue); ok {
   918  		return tv.NewTypeValue()
   919  	} else {
   920  		return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value)
   921  	}
   922  }
   923  
   924  var nameValues map[string]Value
   925  
   926  func init() {
   927  	nameValues = map[string]Value{
   928  		"_aclitem":       &ACLItemArray{},
   929  		"_bool":          &BoolArray{},
   930  		"_bpchar":        &BPCharArray{},
   931  		"_bytea":         &ByteaArray{},
   932  		"_cidr":          &CIDRArray{},
   933  		"_date":          &DateArray{},
   934  		"_float4":        &Float4Array{},
   935  		"_float8":        &Float8Array{},
   936  		"_inet":          &InetArray{},
   937  		"_int2":          &Int2Array{},
   938  		"_int4":          &Int4Array{},
   939  		"_int8":          &Int8Array{},
   940  		"_numeric":       &NumericArray{},
   941  		"_text":          &TextArray{},
   942  		"_timestamp":     &TimestampArray{},
   943  		"_timestamptz":   &TimestamptzArray{},
   944  		"_uuid":          &UUIDArray{},
   945  		"_varchar":       &VarcharArray{},
   946  		"_json":          &JSONArray{},
   947  		"_jsonb":         &JSONBArray{},
   948  		"aclitem":        &ACLItem{},
   949  		"bit":            &Bit{},
   950  		"bool":           &Bool{},
   951  		"box":            &Box{},
   952  		"bpchar":         &BPChar{},
   953  		"bytea":          &Bytea{},
   954  		"char":           &QChar{},
   955  		"cid":            &CID{},
   956  		"cidr":           &CIDR{},
   957  		"circle":         &Circle{},
   958  		"date":           &Date{},
   959  		"daterange":      &Daterange{},
   960  		"float4":         &Float4{},
   961  		"float8":         &Float8{},
   962  		"hstore":         &Hstore{},
   963  		"inet":           &Inet{},
   964  		"int2":           &Int2{},
   965  		"int4":           &Int4{},
   966  		"int4range":      &Int4range{},
   967  		"int4multirange": &Int4multirange{},
   968  		"int8":           &Int8{},
   969  		"int8range":      &Int8range{},
   970  		"int8multirange": &Int8multirange{},
   971  		"interval":       &Interval{},
   972  		"json":           &JSON{},
   973  		"jsonb":          &JSONB{},
   974  		"line":           &Line{},
   975  		"lseg":           &Lseg{},
   976  		"ltree":          &Ltree{},
   977  		"macaddr":        &Macaddr{},
   978  		"name":           &Name{},
   979  		"numeric":        &Numeric{},
   980  		"numrange":       &Numrange{},
   981  		"nummultirange":  &Nummultirange{},
   982  		"oid":            &OIDValue{},
   983  		"path":           &Path{},
   984  		"point":          &Point{},
   985  		"polygon":        &Polygon{},
   986  		"record":         &Record{},
   987  		"text":           &Text{},
   988  		"tid":            &TID{},
   989  		"timestamp":      &Timestamp{},
   990  		"timestamptz":    &Timestamptz{},
   991  		"tsrange":        &Tsrange{},
   992  		"_tsrange":       &TsrangeArray{},
   993  		"tstzrange":      &Tstzrange{},
   994  		"_tstzrange":     &TstzrangeArray{},
   995  		"unknown":        &Unknown{},
   996  		"uuid":           &UUID{},
   997  		"varbit":         &Varbit{},
   998  		"varchar":        &Varchar{},
   999  		"xid":            &XID{},
  1000  	}
  1001  }
  1002  

View as plain text