1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "fmt"
6
7 "github.com/jackc/pgx/v5/internal/pgio"
8 )
9
10
11 type RangeValuer interface {
12
13 IsNull() bool
14
15
16 BoundTypes() (lower, upper BoundType)
17
18
19 Bounds() (lower, upper any)
20 }
21
22
23 type RangeScanner interface {
24
25 ScanNull() error
26
27
28
29 ScanBounds() (lowerTarget, upperTarget any)
30
31
32
33
34 SetBoundTypes(lower, upper BoundType) error
35 }
36
37
38 type RangeCodec struct {
39 ElementType *Type
40 }
41
42 func (c *RangeCodec) FormatSupported(format int16) bool {
43 return c.ElementType.Codec.FormatSupported(format)
44 }
45
46 func (c *RangeCodec) PreferredFormat() int16 {
47 if c.FormatSupported(BinaryFormatCode) {
48 return BinaryFormatCode
49 }
50 return TextFormatCode
51 }
52
53 func (c *RangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
54 if _, ok := value.(RangeValuer); !ok {
55 return nil
56 }
57
58 switch format {
59 case BinaryFormatCode:
60 return &encodePlanRangeCodecRangeValuerToBinary{rc: c, m: m}
61 case TextFormatCode:
62 return &encodePlanRangeCodecRangeValuerToText{rc: c, m: m}
63 }
64
65 return nil
66 }
67
68 type encodePlanRangeCodecRangeValuerToBinary struct {
69 rc *RangeCodec
70 m *Map
71 }
72
73 func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
74 getter := value.(RangeValuer)
75
76 if getter.IsNull() {
77 return nil, nil
78 }
79
80 lowerType, upperType := getter.BoundTypes()
81 lower, upper := getter.Bounds()
82
83 var rangeType byte
84 switch lowerType {
85 case Inclusive:
86 rangeType |= lowerInclusiveMask
87 case Unbounded:
88 rangeType |= lowerUnboundedMask
89 case Exclusive:
90 case Empty:
91 return append(buf, emptyMask), nil
92 default:
93 return nil, fmt.Errorf("unknown LowerType: %v", lowerType)
94 }
95
96 switch upperType {
97 case Inclusive:
98 rangeType |= upperInclusiveMask
99 case Unbounded:
100 rangeType |= upperUnboundedMask
101 case Exclusive:
102 default:
103 return nil, fmt.Errorf("unknown UpperType: %v", upperType)
104 }
105
106 buf = append(buf, rangeType)
107
108 if lowerType != Unbounded {
109 if lower == nil {
110 return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
111 }
112
113 sp := len(buf)
114 buf = pgio.AppendInt32(buf, -1)
115
116 lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, lower)
117 if lowerPlan == nil {
118 return nil, fmt.Errorf("cannot encode %v as element of range", lower)
119 }
120
121 buf, err = lowerPlan.Encode(lower, buf)
122 if err != nil {
123 return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
124 }
125 if buf == nil {
126 return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
127 }
128
129 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
130 }
131
132 if upperType != Unbounded {
133 if upper == nil {
134 return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
135 }
136
137 sp := len(buf)
138 buf = pgio.AppendInt32(buf, -1)
139
140 upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, BinaryFormatCode, upper)
141 if upperPlan == nil {
142 return nil, fmt.Errorf("cannot encode %v as element of range", upper)
143 }
144
145 buf, err = upperPlan.Encode(upper, buf)
146 if err != nil {
147 return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
148 }
149 if buf == nil {
150 return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
151 }
152
153 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
154 }
155
156 return buf, nil
157 }
158
159 type encodePlanRangeCodecRangeValuerToText struct {
160 rc *RangeCodec
161 m *Map
162 }
163
164 func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
165 getter := value.(RangeValuer)
166
167 if getter.IsNull() {
168 return nil, nil
169 }
170
171 lowerType, upperType := getter.BoundTypes()
172 lower, upper := getter.Bounds()
173
174 switch lowerType {
175 case Exclusive, Unbounded:
176 buf = append(buf, '(')
177 case Inclusive:
178 buf = append(buf, '[')
179 case Empty:
180 return append(buf, "empty"...), nil
181 default:
182 return nil, fmt.Errorf("unknown lower bound type %v", lowerType)
183 }
184
185 if lowerType != Unbounded {
186 if lower == nil {
187 return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
188 }
189
190 lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower)
191 if lowerPlan == nil {
192 return nil, fmt.Errorf("cannot encode %v as element of range", lower)
193 }
194
195 buf, err = lowerPlan.Encode(lower, buf)
196 if err != nil {
197 return nil, fmt.Errorf("failed to encode %v as element of range: %w", lower, err)
198 }
199 if buf == nil {
200 return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
201 }
202 }
203
204 buf = append(buf, ',')
205
206 if upperType != Unbounded {
207 if upper == nil {
208 return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
209 }
210
211 upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper)
212 if upperPlan == nil {
213 return nil, fmt.Errorf("cannot encode %v as element of range", upper)
214 }
215
216 buf, err = upperPlan.Encode(upper, buf)
217 if err != nil {
218 return nil, fmt.Errorf("failed to encode %v as element of range: %w", upper, err)
219 }
220 if buf == nil {
221 return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
222 }
223 }
224
225 switch upperType {
226 case Exclusive, Unbounded:
227 buf = append(buf, ')')
228 case Inclusive:
229 buf = append(buf, ']')
230 default:
231 return nil, fmt.Errorf("unknown upper bound type %v", upperType)
232 }
233
234 return buf, nil
235 }
236
237 func (c *RangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
238 switch format {
239 case BinaryFormatCode:
240 switch target.(type) {
241 case RangeScanner:
242 return &scanPlanBinaryRangeToRangeScanner{rc: c, m: m}
243 }
244 case TextFormatCode:
245 switch target.(type) {
246 case RangeScanner:
247 return &scanPlanTextRangeToRangeScanner{rc: c, m: m}
248 }
249 }
250
251 return nil
252 }
253
254 type scanPlanBinaryRangeToRangeScanner struct {
255 rc *RangeCodec
256 m *Map
257 }
258
259 func (plan *scanPlanBinaryRangeToRangeScanner) Scan(src []byte, target any) error {
260 rangeScanner := (target).(RangeScanner)
261
262 if src == nil {
263 return rangeScanner.ScanNull()
264 }
265
266 ubr, err := parseUntypedBinaryRange(src)
267 if err != nil {
268 return err
269 }
270
271 if ubr.LowerType == Empty {
272 return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
273 }
274
275 lowerTarget, upperTarget := rangeScanner.ScanBounds()
276
277 if ubr.LowerType == Inclusive || ubr.LowerType == Exclusive {
278 lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, lowerTarget)
279 if lowerPlan == nil {
280 return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
281 }
282
283 err = lowerPlan.Scan(ubr.Lower, lowerTarget)
284 if err != nil {
285 return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
286 }
287 }
288
289 if ubr.UpperType == Inclusive || ubr.UpperType == Exclusive {
290 upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, BinaryFormatCode, upperTarget)
291 if upperPlan == nil {
292 return fmt.Errorf("cannot scan into %v from range element", upperTarget)
293 }
294
295 err = upperPlan.Scan(ubr.Upper, upperTarget)
296 if err != nil {
297 return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
298 }
299 }
300
301 return rangeScanner.SetBoundTypes(ubr.LowerType, ubr.UpperType)
302 }
303
304 type scanPlanTextRangeToRangeScanner struct {
305 rc *RangeCodec
306 m *Map
307 }
308
309 func (plan *scanPlanTextRangeToRangeScanner) Scan(src []byte, target any) error {
310 rangeScanner := (target).(RangeScanner)
311
312 if src == nil {
313 return rangeScanner.ScanNull()
314 }
315
316 utr, err := parseUntypedTextRange(string(src))
317 if err != nil {
318 return err
319 }
320
321 if utr.LowerType == Empty {
322 return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
323 }
324
325 lowerTarget, upperTarget := rangeScanner.ScanBounds()
326
327 if utr.LowerType == Inclusive || utr.LowerType == Exclusive {
328 lowerPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, lowerTarget)
329 if lowerPlan == nil {
330 return fmt.Errorf("cannot scan into %v from range element", lowerTarget)
331 }
332
333 err = lowerPlan.Scan([]byte(utr.Lower), lowerTarget)
334 if err != nil {
335 return fmt.Errorf("cannot scan into %v from range element: %w", lowerTarget, err)
336 }
337 }
338
339 if utr.UpperType == Inclusive || utr.UpperType == Exclusive {
340 upperPlan := plan.m.PlanScan(plan.rc.ElementType.OID, TextFormatCode, upperTarget)
341 if upperPlan == nil {
342 return fmt.Errorf("cannot scan into %v from range element", upperTarget)
343 }
344
345 err = upperPlan.Scan([]byte(utr.Upper), upperTarget)
346 if err != nil {
347 return fmt.Errorf("cannot scan into %v from range element: %w", upperTarget, err)
348 }
349 }
350
351 return rangeScanner.SetBoundTypes(utr.LowerType, utr.UpperType)
352 }
353
354 func (c *RangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
355 if src == nil {
356 return nil, nil
357 }
358
359 switch format {
360 case TextFormatCode:
361 return string(src), nil
362 case BinaryFormatCode:
363 buf := make([]byte, len(src))
364 copy(buf, src)
365 return buf, nil
366 default:
367 return nil, fmt.Errorf("unknown format code %d", format)
368 }
369 }
370
371 func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
372 if src == nil {
373 return nil, nil
374 }
375
376 var r Range[any]
377 err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
378 return r, err
379 }
380
View as plain text