1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "fmt"
7 "strconv"
8 "strings"
9 "time"
10
11 "github.com/jackc/pgio"
12 )
13
// Conversion factors from interval components to microseconds. Each factor
// is built from the previous one. A month is treated as exactly 30 days, which
// is an approximation because calendar months vary in length.
const (
	microsecondsPerSecond = 1000000
	microsecondsPerMinute = microsecondsPerSecond * 60
	microsecondsPerHour   = microsecondsPerMinute * 60
	microsecondsPerDay    = microsecondsPerHour * 24
	microsecondsPerMonth  = microsecondsPerDay * 30
)
21
// Interval holds an interval value split into three independently signed
// components: a time component in microseconds, a day count, and a month
// count. They are kept separate rather than folded into one duration.
type Interval struct {
	Microseconds int64  // time component (hours, minutes, seconds) in microseconds
	Days         int32  // day component
	Months       int32  // month component; DecodeText folds years in as 12 months each
	Status       Status // Present, Null or Undefined
}
28
29 func (dst *Interval) Set(src interface{}) error {
30 if src == nil {
31 *dst = Interval{Status: Null}
32 return nil
33 }
34
35 if value, ok := src.(interface{ Get() interface{} }); ok {
36 value2 := value.Get()
37 if value2 != value {
38 return dst.Set(value2)
39 }
40 }
41
42 switch value := src.(type) {
43 case time.Duration:
44 *dst = Interval{Microseconds: int64(value) / 1000, Status: Present}
45 default:
46 if originalSrc, ok := underlyingPtrType(src); ok {
47 return dst.Set(originalSrc)
48 }
49 return fmt.Errorf("cannot convert %v to Interval", value)
50 }
51
52 return nil
53 }
54
55 func (dst Interval) Get() interface{} {
56 switch dst.Status {
57 case Present:
58 return dst
59 case Null:
60 return nil
61 default:
62 return dst.Status
63 }
64 }
65
66 func (src *Interval) AssignTo(dst interface{}) error {
67 switch src.Status {
68 case Present:
69 switch v := dst.(type) {
70 case *time.Duration:
71 us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds
72 *v = time.Duration(us) * time.Microsecond
73 return nil
74 default:
75 if nextDst, retry := GetAssignToDstType(dst); retry {
76 return src.AssignTo(nextDst)
77 }
78 return fmt.Errorf("unable to assign to %T", dst)
79 }
80 case Null:
81 return NullAssignTo(dst)
82 }
83
84 return fmt.Errorf("cannot decode %#v into %T", src, dst)
85 }
86
87 func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error {
88 if src == nil {
89 *dst = Interval{Status: Null}
90 return nil
91 }
92
93 var microseconds int64
94 var days int32
95 var months int32
96
97 parts := strings.Split(string(src), " ")
98
99 for i := 0; i < len(parts)-1; i += 2 {
100 scalar, err := strconv.ParseInt(parts[i], 10, 64)
101 if err != nil {
102 return fmt.Errorf("bad interval format")
103 }
104
105 switch parts[i+1] {
106 case "year", "years":
107 months += int32(scalar * 12)
108 case "mon", "mons":
109 months += int32(scalar)
110 case "day", "days":
111 days = int32(scalar)
112 }
113 }
114
115 if len(parts)%2 == 1 {
116 timeParts := strings.SplitN(parts[len(parts)-1], ":", 3)
117 if len(timeParts) != 3 {
118 return fmt.Errorf("bad interval format")
119 }
120
121 var negative bool
122 if timeParts[0][0] == '-' {
123 negative = true
124 timeParts[0] = timeParts[0][1:]
125 }
126
127 hours, err := strconv.ParseInt(timeParts[0], 10, 64)
128 if err != nil {
129 return fmt.Errorf("bad interval hour format: %s", timeParts[0])
130 }
131
132 minutes, err := strconv.ParseInt(timeParts[1], 10, 64)
133 if err != nil {
134 return fmt.Errorf("bad interval minute format: %s", timeParts[1])
135 }
136
137 secondParts := strings.SplitN(timeParts[2], ".", 2)
138
139 seconds, err := strconv.ParseInt(secondParts[0], 10, 64)
140 if err != nil {
141 return fmt.Errorf("bad interval second format: %s", secondParts[0])
142 }
143
144 var uSeconds int64
145 if len(secondParts) == 2 {
146 uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64)
147 if err != nil {
148 return fmt.Errorf("bad interval decimal format: %s", secondParts[1])
149 }
150
151 for i := 0; i < 6-len(secondParts[1]); i++ {
152 uSeconds *= 10
153 }
154 }
155
156 microseconds = hours * microsecondsPerHour
157 microseconds += minutes * microsecondsPerMinute
158 microseconds += seconds * microsecondsPerSecond
159 microseconds += uSeconds
160
161 if negative {
162 microseconds = -microseconds
163 }
164 }
165
166 *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present}
167 return nil
168 }
169
170 func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error {
171 if src == nil {
172 *dst = Interval{Status: Null}
173 return nil
174 }
175
176 if len(src) != 16 {
177 return fmt.Errorf("Received an invalid size for an interval: %d", len(src))
178 }
179
180 microseconds := int64(binary.BigEndian.Uint64(src))
181 days := int32(binary.BigEndian.Uint32(src[8:]))
182 months := int32(binary.BigEndian.Uint32(src[12:]))
183
184 *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present}
185 return nil
186 }
187
188 func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
189 switch src.Status {
190 case Null:
191 return nil, nil
192 case Undefined:
193 return nil, errUndefined
194 }
195
196 if src.Months != 0 {
197 buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...)
198 buf = append(buf, " mon "...)
199 }
200
201 if src.Days != 0 {
202 buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...)
203 buf = append(buf, " day "...)
204 }
205
206 absMicroseconds := src.Microseconds
207 if absMicroseconds < 0 {
208 absMicroseconds = -absMicroseconds
209 buf = append(buf, '-')
210 }
211
212 hours := absMicroseconds / microsecondsPerHour
213 minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
214 seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
215 microseconds := absMicroseconds % microsecondsPerSecond
216
217 timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds)
218 return append(buf, timeStr...), nil
219 }
220
221
222 func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
223 switch src.Status {
224 case Null:
225 return nil, nil
226 case Undefined:
227 return nil, errUndefined
228 }
229
230 buf = pgio.AppendInt64(buf, src.Microseconds)
231 buf = pgio.AppendInt32(buf, src.Days)
232 return pgio.AppendInt32(buf, src.Months), nil
233 }
234
235
236 func (dst *Interval) Scan(src interface{}) error {
237 if src == nil {
238 *dst = Interval{Status: Null}
239 return nil
240 }
241
242 switch src := src.(type) {
243 case string:
244 return dst.DecodeText(nil, []byte(src))
245 case []byte:
246 srcCopy := make([]byte, len(src))
247 copy(srcCopy, src)
248 return dst.DecodeText(nil, srcCopy)
249 }
250
251 return fmt.Errorf("cannot scan %T", src)
252 }
253
254
255 func (src Interval) Value() (driver.Value, error) {
256 return EncodeValueText(src)
257 }
258
View as plain text