1 package pgx
2
3 import (
4 "database/sql/driver"
5 "fmt"
6 "reflect"
7
8 "github.com/jackc/pgtype"
9 )
10
// extendedQueryBuilder collects encoded parameter values together with the
// parameter and result format codes for a query. Reset readies it for reuse.
type extendedQueryBuilder struct {
	// paramValues holds one encoded value per appended parameter. A nil entry
	// is stored for nil arguments and nil pointers.
	paramValues [][]byte

	// paramValueBytes is a shared scratch buffer that encoders append into;
	// entries of paramValues may alias it.
	paramValueBytes []byte

	// paramFormats holds one format code per parameter, and resultFormats holds
	// the result format codes added with AppendResultFormat.
	paramFormats, resultFormats []int16
}
17
18 func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error {
19 f := chooseParameterFormatCode(ci, oid, arg)
20 eqb.paramFormats = append(eqb.paramFormats, f)
21
22 v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg)
23 if err != nil {
24 return err
25 }
26 eqb.paramValues = append(eqb.paramValues, v)
27
28 return nil
29 }
30
31 func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) {
32 eqb.resultFormats = append(eqb.resultFormats, f)
33 }
34
35
36 func (eqb *extendedQueryBuilder) Reset() {
37 eqb.paramValues = eqb.paramValues[0:0]
38 eqb.paramValueBytes = eqb.paramValueBytes[0:0]
39 eqb.paramFormats = eqb.paramFormats[0:0]
40 eqb.resultFormats = eqb.resultFormats[0:0]
41
42 if cap(eqb.paramValues) > 64 {
43 eqb.paramValues = make([][]byte, 0, 64)
44 }
45
46 if cap(eqb.paramValueBytes) > 256 {
47 eqb.paramValueBytes = make([]byte, 0, 256)
48 }
49
50 if cap(eqb.paramFormats) > 64 {
51 eqb.paramFormats = make([]int16, 0, 64)
52 }
53 if cap(eqb.resultFormats) > 64 {
54 eqb.resultFormats = make([]int16, 0, 64)
55 }
56 }
57
// encodeExtendedParamValue encodes arg as the value of a parameter whose type
// OID is oid, in the format selected by formatCode, and returns the bytes.
//
// A nil arg or a nil pointer yields (nil, nil). NOTE(review): presumably the
// caller sends a nil value as SQL NULL; confirm. Encoder output is appended
// to the shared eqb.paramValueBytes buffer, and the returned slice aliases
// that buffer from the length it had on entry. Strings are the exception:
// they are returned as a fresh []byte copy.
//
// Resolution order (first match wins):
//  1. arg is a string: its bytes are returned whatever formatCode is.
//  2. arg implements the encoder matching formatCode (pgtype.TextEncoder for
//     TextFormatCode, pgtype.BinaryEncoder for BinaryFormatCode): that
//     encoder is used.
//  3. arg is a non-nil pointer: it is dereferenced and encoding is retried.
//  4. ci has a data type for oid: its value is Set from arg and encoding is
//     retried with that value. If Set fails and arg is a driver.Valuer,
//     encoding is retried with the Valuer's result instead.
//  5. ci has a data type for arg's Go type whose value is a
//     pgtype.TextEncoder: arg is text-encoded through it.
//  6. stripNamedType succeeds: encoding is retried with the stripped value.
//
// Otherwise a SerializationError is returned.
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
	if arg == nil {
		return nil, nil
	}

	refVal := reflect.ValueOf(arg)
	argIsPtr := refVal.Kind() == reflect.Ptr

	if argIsPtr && refVal.IsNil() {
		return nil, nil
	}

	// Lazily allocate the shared scratch buffer that encoders append into.
	if eqb.paramValueBytes == nil {
		eqb.paramValueBytes = make([]byte, 0, 128)
	}

	var err error
	var buf []byte
	// Start of this parameter's bytes within paramValueBytes.
	pos := len(eqb.paramValueBytes)

	if arg, ok := arg.(string); ok {
		return []byte(arg), nil
	}

	if formatCode == TextFormatCode {
		if arg, ok := arg.(pgtype.TextEncoder); ok {
			buf, err = arg.EncodeText(ci, eqb.paramValueBytes)
			if err != nil {
				return nil, err
			}
			// A nil result from the encoder is passed through as a nil value.
			if buf == nil {
				return nil, nil
			}
			// Keep the possibly reallocated buffer. Slices returned earlier
			// still reference the old backing array and remain valid.
			eqb.paramValueBytes = buf
			return eqb.paramValueBytes[pos:], nil
		}
	} else if formatCode == BinaryFormatCode {
		if arg, ok := arg.(pgtype.BinaryEncoder); ok {
			buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes)
			if err != nil {
				return nil, err
			}
			if buf == nil {
				return nil, nil
			}
			eqb.paramValueBytes = buf
			return eqb.paramValueBytes[pos:], nil
		}
	}

	if argIsPtr {
		// refVal was checked to be non-nil above, so dereferencing it is
		// safe. Retry with the pointed-to value.
		arg = refVal.Elem().Interface()
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg)
	}

	if dt, ok := ci.DataTypeForOID(oid); ok {
		// NOTE(review): dt.Value is used directly, not copied, so if it holds a
		// pointer, Set mutates the instance registered in ci. That looks safe
		// only if ci is not shared across goroutines; confirm.
		value := dt.Value
		err := value.Set(arg)
		if err != nil {
			// Fall back to the driver.Valuer representation of arg, if it has
			// one. The bare block below only scopes this fallback.
			{
				if arg, ok := arg.(driver.Valuer); ok {
					v, err := callValuerValue(arg)
					if err != nil {
						return nil, err
					}
					return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
				}
			}

			return nil, err
		}

		return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
	}

	// No data type is registered for oid, but one may be registered for arg's
	// Go type. If so, use its text encoder. The value is text-encoded even
	// when formatCode is BinaryFormatCode. NOTE(review): presumably
	// chooseParameterFormatCode selects text in this case; confirm.
	if dt, ok := ci.DataTypeForValue(arg); ok {
		value := dt.Value
		if textEncoder, ok := value.(pgtype.TextEncoder); ok {
			err := value.Set(arg)
			if err != nil {
				return nil, err
			}

			buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes)
			if err != nil {
				return nil, err
			}
			if buf == nil {
				return nil, nil
			}
			eqb.paramValueBytes = buf
			return eqb.paramValueBytes[pos:], nil
		}
	}

	// Last resort: retry with the value converted away from its named type,
	// if stripNamedType can do so.
	if strippedArg, ok := stripNamedType(&refVal); ok {
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg)
	}
	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
162
View as plain text