1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "fmt"
7 "strconv"
8 "strings"
9
10 "github.com/jackc/pgx/v5/internal/pgio"
11 )
12
// Microsecond conversion factors for the components of an interval. Each factor is built
// from the previous one; microsecondsPerMonth treats a month as exactly 30 days.
const (
	microsecondsPerSecond = 1_000_000
	microsecondsPerMinute = microsecondsPerSecond * 60
	microsecondsPerHour   = microsecondsPerMinute * 60
	microsecondsPerDay    = microsecondsPerHour * 24
	microsecondsPerMonth  = microsecondsPerDay * 30
)
20
type (
	// IntervalScanner is implemented by types that can receive a decoded Interval.
	IntervalScanner interface {
		ScanInterval(v Interval) error
	}

	// IntervalValuer is implemented by types that can supply an Interval to be encoded.
	IntervalValuer interface {
		IntervalValue() (Interval, error)
	}
)

// Interval holds an interval split into month, day and microsecond components.
type Interval struct {
	Microseconds int64 // time component in microseconds (rendered as [-]hh:mm:ss.ffffff in text)
	Days         int32 // day component
	Months       int32 // month component; text parsing folds years into months at 12 per year
	Valid        bool  // false represents SQL NULL
}

// ScanInterval implements IntervalScanner by overwriting the receiver with v. It never fails.
func (interval *Interval) ScanInterval(v Interval) error {
	*interval = v
	return nil
}

// IntervalValue implements IntervalValuer by returning a copy of the receiver. It never fails.
func (interval Interval) IntervalValue() (Interval, error) {
	return interval, nil
}
44
45
46 func (interval *Interval) Scan(src any) error {
47 if src == nil {
48 *interval = Interval{}
49 return nil
50 }
51
52 switch src := src.(type) {
53 case string:
54 return scanPlanTextAnyToIntervalScanner{}.Scan([]byte(src), interval)
55 }
56
57 return fmt.Errorf("cannot scan %T", src)
58 }
59
60
61 func (interval Interval) Value() (driver.Value, error) {
62 if !interval.Valid {
63 return nil, nil
64 }
65
66 buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil)
67 if err != nil {
68 return nil, err
69 }
70 return string(buf), err
71 }
72
73 type IntervalCodec struct{}
74
75 func (IntervalCodec) FormatSupported(format int16) bool {
76 return format == TextFormatCode || format == BinaryFormatCode
77 }
78
79 func (IntervalCodec) PreferredFormat() int16 {
80 return BinaryFormatCode
81 }
82
83 func (IntervalCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
84 if _, ok := value.(IntervalValuer); !ok {
85 return nil
86 }
87
88 switch format {
89 case BinaryFormatCode:
90 return encodePlanIntervalCodecBinary{}
91 case TextFormatCode:
92 return encodePlanIntervalCodecText{}
93 }
94
95 return nil
96 }
97
98 type encodePlanIntervalCodecBinary struct{}
99
100 func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
101 interval, err := value.(IntervalValuer).IntervalValue()
102 if err != nil {
103 return nil, err
104 }
105
106 if !interval.Valid {
107 return nil, nil
108 }
109
110 buf = pgio.AppendInt64(buf, interval.Microseconds)
111 buf = pgio.AppendInt32(buf, interval.Days)
112 buf = pgio.AppendInt32(buf, interval.Months)
113 return buf, nil
114 }
115
116 type encodePlanIntervalCodecText struct{}
117
118 func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
119 interval, err := value.(IntervalValuer).IntervalValue()
120 if err != nil {
121 return nil, err
122 }
123
124 if !interval.Valid {
125 return nil, nil
126 }
127
128 if interval.Months != 0 {
129 buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...)
130 buf = append(buf, " mon "...)
131 }
132
133 if interval.Days != 0 {
134 buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...)
135 buf = append(buf, " day "...)
136 }
137
138 absMicroseconds := interval.Microseconds
139 if absMicroseconds < 0 {
140 absMicroseconds = -absMicroseconds
141 buf = append(buf, '-')
142 }
143
144 hours := absMicroseconds / microsecondsPerHour
145 minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
146 seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
147 microseconds := absMicroseconds % microsecondsPerSecond
148
149 timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds)
150 buf = append(buf, timeStr...)
151 return buf, nil
152 }
153
154 func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
155
156 switch format {
157 case BinaryFormatCode:
158 switch target.(type) {
159 case IntervalScanner:
160 return scanPlanBinaryIntervalToIntervalScanner{}
161 }
162 case TextFormatCode:
163 switch target.(type) {
164 case IntervalScanner:
165 return scanPlanTextAnyToIntervalScanner{}
166 }
167 }
168
169 return nil
170 }
171
172 type scanPlanBinaryIntervalToIntervalScanner struct{}
173
174 func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error {
175 scanner := (dst).(IntervalScanner)
176
177 if src == nil {
178 return scanner.ScanInterval(Interval{})
179 }
180
181 if len(src) != 16 {
182 return fmt.Errorf("Received an invalid size for an interval: %d", len(src))
183 }
184
185 microseconds := int64(binary.BigEndian.Uint64(src))
186 days := int32(binary.BigEndian.Uint32(src[8:]))
187 months := int32(binary.BigEndian.Uint32(src[12:]))
188
189 return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true})
190 }
191
192 type scanPlanTextAnyToIntervalScanner struct{}
193
194 func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error {
195 scanner := (dst).(IntervalScanner)
196
197 if src == nil {
198 return scanner.ScanInterval(Interval{})
199 }
200
201 var microseconds int64
202 var days int32
203 var months int32
204
205 parts := strings.Split(string(src), " ")
206
207 for i := 0; i < len(parts)-1; i += 2 {
208 scalar, err := strconv.ParseInt(parts[i], 10, 64)
209 if err != nil {
210 return fmt.Errorf("bad interval format")
211 }
212
213 switch parts[i+1] {
214 case "year", "years":
215 months += int32(scalar * 12)
216 case "mon", "mons":
217 months += int32(scalar)
218 case "day", "days":
219 days = int32(scalar)
220 }
221 }
222
223 if len(parts)%2 == 1 {
224 timeParts := strings.SplitN(parts[len(parts)-1], ":", 3)
225 if len(timeParts) != 3 {
226 return fmt.Errorf("bad interval format")
227 }
228
229 var negative bool
230 if timeParts[0][0] == '-' {
231 negative = true
232 timeParts[0] = timeParts[0][1:]
233 }
234
235 hours, err := strconv.ParseInt(timeParts[0], 10, 64)
236 if err != nil {
237 return fmt.Errorf("bad interval hour format: %s", timeParts[0])
238 }
239
240 minutes, err := strconv.ParseInt(timeParts[1], 10, 64)
241 if err != nil {
242 return fmt.Errorf("bad interval minute format: %s", timeParts[1])
243 }
244
245 sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".")
246
247 seconds, err := strconv.ParseInt(sec, 10, 64)
248 if err != nil {
249 return fmt.Errorf("bad interval second format: %s", sec)
250 }
251
252 var uSeconds int64
253 if secFracFound {
254 uSeconds, err = strconv.ParseInt(secFrac, 10, 64)
255 if err != nil {
256 return fmt.Errorf("bad interval decimal format: %s", secFrac)
257 }
258
259 for i := 0; i < 6-len(secFrac); i++ {
260 uSeconds *= 10
261 }
262 }
263
264 microseconds = hours * microsecondsPerHour
265 microseconds += minutes * microsecondsPerMinute
266 microseconds += seconds * microsecondsPerSecond
267 microseconds += uSeconds
268
269 if negative {
270 microseconds = -microseconds
271 }
272 }
273
274 return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true})
275 }
276
277 func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
278 return codecDecodeToTextFormat(c, m, oid, format, src)
279 }
280
281 func (c IntervalCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
282 if src == nil {
283 return nil, nil
284 }
285
286 var interval Interval
287 err := codecScan(c, m, oid, format, src, &interval)
288 if err != nil {
289 return nil, err
290 }
291 return interval, nil
292 }
293
View as plain text