1 package pgx
2
3 import (
4 "database/sql/driver"
5 "fmt"
6 "math"
7 "reflect"
8 "time"
9
10 "github.com/jackc/pgio"
11 "github.com/jackc/pgtype"
12 )
13
14
// TextFormatCode is the format code that selects the text representation of a
// parameter value (see chooseParameterFormatCode).
const TextFormatCode = 0

// BinaryFormatCode is the format code that selects the binary representation
// of a parameter value (see chooseParameterFormatCode).
const BinaryFormatCode = 1
19
20
// SerializationError is the error returned when an argument cannot be
// converted or encoded for sending to the server. Its value is the
// human-readable message.
type SerializationError string

// Error implements the error interface by returning the stored message
// unchanged.
func (e SerializationError) Error() string {
	msg := string(e)
	return msg
}
26
27 func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
28 if arg == nil {
29 return nil, nil
30 }
31
32 refVal := reflect.ValueOf(arg)
33 if refVal.Kind() == reflect.Ptr && refVal.IsNil() {
34 return nil, nil
35 }
36
37 switch arg := arg.(type) {
38
39
40
41
42
43
44
45 case *pgtype.JSON:
46 buf, err := arg.EncodeText(ci, nil)
47 if err != nil {
48 return nil, err
49 }
50 if buf == nil {
51 return nil, nil
52 }
53 return string(buf), nil
54 case *pgtype.JSONB:
55 buf, err := arg.EncodeText(ci, nil)
56 if err != nil {
57 return nil, err
58 }
59 if buf == nil {
60 return nil, nil
61 }
62 return string(buf), nil
63
64 case driver.Valuer:
65 return callValuerValue(arg)
66 case pgtype.TextEncoder:
67 buf, err := arg.EncodeText(ci, nil)
68 if err != nil {
69 return nil, err
70 }
71 if buf == nil {
72 return nil, nil
73 }
74 return string(buf), nil
75 case float32:
76 return float64(arg), nil
77 case float64:
78 return arg, nil
79 case bool:
80 return arg, nil
81 case time.Duration:
82 return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil
83 case time.Time:
84 return arg, nil
85 case string:
86 return arg, nil
87 case []byte:
88 return arg, nil
89 case int8:
90 return int64(arg), nil
91 case int16:
92 return int64(arg), nil
93 case int32:
94 return int64(arg), nil
95 case int64:
96 return arg, nil
97 case int:
98 return int64(arg), nil
99 case uint8:
100 return int64(arg), nil
101 case uint16:
102 return int64(arg), nil
103 case uint32:
104 return int64(arg), nil
105 case uint64:
106 if arg > math.MaxInt64 {
107 return nil, fmt.Errorf("arg too big for int64: %v", arg)
108 }
109 return int64(arg), nil
110 case uint:
111 if uint64(arg) > math.MaxInt64 {
112 return nil, fmt.Errorf("arg too big for int64: %v", arg)
113 }
114 return int64(arg), nil
115 }
116
117 if dt, found := ci.DataTypeForValue(arg); found {
118 v := dt.Value
119 err := v.Set(arg)
120 if err != nil {
121 return nil, err
122 }
123 buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil)
124 if err != nil {
125 return nil, err
126 }
127 if buf == nil {
128 return nil, nil
129 }
130 return string(buf), nil
131 }
132
133 if refVal.Kind() == reflect.Ptr {
134 arg = refVal.Elem().Interface()
135 return convertSimpleArgument(ci, arg)
136 }
137
138 if strippedArg, ok := stripNamedType(&refVal); ok {
139 return convertSimpleArgument(ci, strippedArg)
140 }
141 return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
142 }
143
144 func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
145 if arg == nil {
146 return pgio.AppendInt32(buf, -1), nil
147 }
148
149 switch arg := arg.(type) {
150 case pgtype.BinaryEncoder:
151 sp := len(buf)
152 buf = pgio.AppendInt32(buf, -1)
153 argBuf, err := arg.EncodeBinary(ci, buf)
154 if err != nil {
155 return nil, err
156 }
157 if argBuf != nil {
158 buf = argBuf
159 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
160 }
161 return buf, nil
162 case pgtype.TextEncoder:
163 sp := len(buf)
164 buf = pgio.AppendInt32(buf, -1)
165 argBuf, err := arg.EncodeText(ci, buf)
166 if err != nil {
167 return nil, err
168 }
169 if argBuf != nil {
170 buf = argBuf
171 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
172 }
173 return buf, nil
174 case string:
175 buf = pgio.AppendInt32(buf, int32(len(arg)))
176 buf = append(buf, arg...)
177 return buf, nil
178 }
179
180 refVal := reflect.ValueOf(arg)
181
182 if refVal.Kind() == reflect.Ptr {
183 if refVal.IsNil() {
184 return pgio.AppendInt32(buf, -1), nil
185 }
186 arg = refVal.Elem().Interface()
187 return encodePreparedStatementArgument(ci, buf, oid, arg)
188 }
189
190 if dt, ok := ci.DataTypeForOID(oid); ok {
191 value := dt.Value
192 err := value.Set(arg)
193 if err != nil {
194 {
195 if arg, ok := arg.(driver.Valuer); ok {
196 v, err := callValuerValue(arg)
197 if err != nil {
198 return nil, err
199 }
200 return encodePreparedStatementArgument(ci, buf, oid, v)
201 }
202 }
203
204 return nil, err
205 }
206
207 sp := len(buf)
208 buf = pgio.AppendInt32(buf, -1)
209 argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
210 if err != nil {
211 return nil, err
212 }
213 if argBuf != nil {
214 buf = argBuf
215 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
216 }
217 return buf, nil
218 }
219
220 if strippedArg, ok := stripNamedType(&refVal); ok {
221 return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
222 }
223 return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
224 }
225
226
227
228
229 func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 {
230 switch arg := arg.(type) {
231 case pgtype.ParamFormatPreferrer:
232 return arg.PreferredParamFormat()
233 case pgtype.BinaryEncoder:
234 return BinaryFormatCode
235 case string, *string, pgtype.TextEncoder:
236 return TextFormatCode
237 }
238
239 return ci.ParamFormatCodeForOID(oid)
240 }
241
// stripNamedType converts the value in val to the predeclared type of the same
// kind. For example, a value of `type Celsius float64` becomes a plain float64.
// Supported kinds are the signed and unsigned integer kinds, float32/float64,
// bool and string.
//
// The bool result reports whether the conversion changed the type, i.e.
// whether val held a named (defined) type rather than the predeclared one.
// Callers use it to decide whether retrying with the converted value is
// worthwhile. Unsupported kinds return (nil, false).
func stripNamedType(val *reflect.Value) (interface{}, bool) {
	switch val.Kind() {
	case reflect.Int:
		convVal := int(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int8:
		convVal := int8(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int16:
		convVal := int16(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int32:
		convVal := int32(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Int64:
		convVal := int64(val.Int())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint:
		convVal := uint(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint8:
		convVal := uint8(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint16:
		convVal := uint16(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint32:
		convVal := uint32(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Uint64:
		convVal := uint64(val.Uint())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Float32:
		// Float32 values are exactly representable as float64 (what val.Float
		// returns), so narrowing back is lossless.
		convVal := float32(val.Float())
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Float64:
		convVal := val.Float()
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.Bool:
		convVal := val.Bool()
		return convVal, reflect.TypeOf(convVal) != val.Type()
	case reflect.String:
		convVal := val.String()
		return convVal, reflect.TypeOf(convVal) != val.Type()
	}

	return nil, false
}
281
View as plain text