1 package pgx
2
3 import (
4 "database/sql/driver"
5 "fmt"
6
7 "github.com/jackc/pgx/v5/internal/anynil"
8 "github.com/jackc/pgx/v5/pgconn"
9 "github.com/jackc/pgx/v5/pgtype"
10 )
11
12
13
// ExtendedQueryBuilder accumulates the encoded parameter values, parameter
// format codes and result format codes for one extended-protocol query. It is
// filled by Build and may be reused: Build calls reset first.
type ExtendedQueryBuilder struct {
	// ParamValues holds one encoded value per parameter, in argument order.
	// A nil entry is produced for a nil argument or when the encoder returns no
	// bytes. Non-nil entries are windows into paramValueBytes.
	ParamValues [][]byte
	// paramValueBytes is the shared buffer that parameter encodings are
	// appended to, so all values of one Build can share a single allocation.
	paramValueBytes []byte
	// ParamFormats holds the format code used for each entry of ParamValues
	// (same index, same length).
	ParamFormats []int16
	// ResultFormats holds one format code per field of the statement
	// description passed to Build; it stays empty when Build gets a nil
	// description.
	ResultFormats []int16
}
20
21
22
23 func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
24 eqb.reset()
25
26 anynil.NormalizeSlice(args)
27
28 if sd == nil {
29 return eqb.appendParamsForQueryExecModeExec(m, args)
30 }
31
32 if len(sd.ParamOIDs) != len(args) {
33 return fmt.Errorf("mismatched param and argument count")
34 }
35
36 for i := range args {
37 err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
38 if err != nil {
39 err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
40 return err
41 }
42 }
43
44 for i := range sd.Fields {
45 eqb.appendResultFormat(m.FormatCodeForOID(sd.Fields[i].DataTypeOID))
46 }
47
48 return nil
49 }
50
51
52
53 func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
54 if format == -1 {
55 preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
56 preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
57 if preferredErr == nil {
58 return nil
59 }
60
61 var otherFormat int16
62 if preferredFormat == TextFormatCode {
63 otherFormat = BinaryFormatCode
64 } else {
65 otherFormat = TextFormatCode
66 }
67
68 otherErr := eqb.appendParam(m, oid, otherFormat, arg)
69 if otherErr == nil {
70 return nil
71 }
72
73 return preferredErr
74 }
75
76 v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
77 if err != nil {
78 return err
79 }
80
81 eqb.ParamFormats = append(eqb.ParamFormats, format)
82 eqb.ParamValues = append(eqb.ParamValues, v)
83
84 return nil
85 }
86
87
88 func (eqb *ExtendedQueryBuilder) appendResultFormat(format int16) {
89 eqb.ResultFormats = append(eqb.ResultFormats, format)
90 }
91
92
93 func (eqb *ExtendedQueryBuilder) reset() {
94 eqb.ParamValues = eqb.ParamValues[0:0]
95 eqb.paramValueBytes = eqb.paramValueBytes[0:0]
96 eqb.ParamFormats = eqb.ParamFormats[0:0]
97 eqb.ResultFormats = eqb.ResultFormats[0:0]
98
99 if cap(eqb.ParamValues) > 64 {
100 eqb.ParamValues = make([][]byte, 0, 64)
101 }
102
103 if cap(eqb.paramValueBytes) > 256 {
104 eqb.paramValueBytes = make([]byte, 0, 256)
105 }
106
107 if cap(eqb.ParamFormats) > 64 {
108 eqb.ParamFormats = make([]int16, 0, 64)
109 }
110 if cap(eqb.ResultFormats) > 64 {
111 eqb.ResultFormats = make([]int16, 0, 64)
112 }
113 }
114
115 func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
116 if anynil.Is(arg) {
117 return nil, nil
118 }
119
120 if eqb.paramValueBytes == nil {
121 eqb.paramValueBytes = make([]byte, 0, 128)
122 }
123
124 pos := len(eqb.paramValueBytes)
125
126 buf, err := m.Encode(oid, formatCode, arg, eqb.paramValueBytes)
127 if err != nil {
128 return nil, err
129 }
130 if buf == nil {
131 return nil, nil
132 }
133 eqb.paramValueBytes = buf
134 return eqb.paramValueBytes[pos:], nil
135 }
136
137
138
139
140 func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid uint32, arg any) int16 {
141 switch arg.(type) {
142 case string, *string:
143 return TextFormatCode
144 }
145
146 return m.FormatCodeForOID(oid)
147 }
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162 func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
163 for _, arg := range args {
164 if arg == nil {
165 err := eqb.appendParam(m, 0, TextFormatCode, arg)
166 if err != nil {
167 return err
168 }
169 } else {
170 dt, ok := m.TypeForValue(arg)
171 if !ok {
172 var tv pgtype.TextValuer
173 if tv, ok = arg.(pgtype.TextValuer); ok {
174 t, err := tv.TextValue()
175 if err != nil {
176 return err
177 }
178
179 dt, ok = m.TypeForOID(pgtype.TextOID)
180 if ok {
181 arg = t
182 }
183 }
184 }
185 if !ok {
186 var dv driver.Valuer
187 if dv, ok = arg.(driver.Valuer); ok {
188 v, err := dv.Value()
189 if err != nil {
190 return err
191 }
192 dt, ok = m.TypeForValue(v)
193 if ok {
194 arg = v
195 }
196 }
197 }
198 if !ok {
199 var str fmt.Stringer
200 if str, ok = arg.(fmt.Stringer); ok {
201 dt, ok = m.TypeForOID(pgtype.TextOID)
202 if ok {
203 arg = str.String()
204 }
205 }
206 }
207 if !ok {
208 return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
209 }
210 err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
211 if err != nil {
212 return err
213 }
214 }
215 }
216
217 return nil
218 }
219
View as plain text