1 package pq
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/binary"
7 "encoding/hex"
8 "errors"
9 "fmt"
10 "math"
11 "regexp"
12 "strconv"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/lib/pq/oid"
18 )
19
// time2400Regex matches a time-of-day string whose clock part is exactly
// 24:00 (optionally written 24:00:00 or 24:00:00.0…), optionally followed by
// a zone suffix that begins with 'Z', '+' or '-'. Capture group 1 holds the
// 24:00 portion; mustParse replaces it before calling time.Parse.
var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`)
21
22 func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
23 switch v := x.(type) {
24 case []byte:
25 return v
26 default:
27 return encode(parameterStatus, x, oid.T_unknown)
28 }
29 }
30
31 func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
32 switch v := x.(type) {
33 case int64:
34 return strconv.AppendInt(nil, v, 10)
35 case float64:
36 return strconv.AppendFloat(nil, v, 'f', -1, 64)
37 case []byte:
38 if pgtypOid == oid.T_bytea {
39 return encodeBytea(parameterStatus.serverVersion, v)
40 }
41
42 return v
43 case string:
44 if pgtypOid == oid.T_bytea {
45 return encodeBytea(parameterStatus.serverVersion, []byte(v))
46 }
47
48 return []byte(v)
49 case bool:
50 return strconv.AppendBool(nil, v)
51 case time.Time:
52 return formatTs(v)
53
54 default:
55 errorf("encode: unknown type for %T", v)
56 }
57
58 panic("not reached")
59 }
60
61 func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
62 switch f {
63 case formatBinary:
64 return binaryDecode(parameterStatus, s, typ)
65 case formatText:
66 return textDecode(parameterStatus, s, typ)
67 default:
68 panic("not reached")
69 }
70 }
71
72 func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
73 switch typ {
74 case oid.T_bytea:
75 return s
76 case oid.T_int8:
77 return int64(binary.BigEndian.Uint64(s))
78 case oid.T_int4:
79 return int64(int32(binary.BigEndian.Uint32(s)))
80 case oid.T_int2:
81 return int64(int16(binary.BigEndian.Uint16(s)))
82 case oid.T_uuid:
83 b, err := decodeUUIDBinary(s)
84 if err != nil {
85 panic(err)
86 }
87 return b
88
89 default:
90 errorf("don't know how to decode binary parameter of type %d", uint32(typ))
91 }
92
93 panic("not reached")
94 }
95
96 func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
97 switch typ {
98 case oid.T_char, oid.T_varchar, oid.T_text:
99 return string(s)
100 case oid.T_bytea:
101 b, err := parseBytea(s)
102 if err != nil {
103 errorf("%s", err)
104 }
105 return b
106 case oid.T_timestamptz:
107 return parseTs(parameterStatus.currentLocation, string(s))
108 case oid.T_timestamp, oid.T_date:
109 return parseTs(nil, string(s))
110 case oid.T_time:
111 return mustParse("15:04:05", typ, s)
112 case oid.T_timetz:
113 return mustParse("15:04:05-07", typ, s)
114 case oid.T_bool:
115 return s[0] == 't'
116 case oid.T_int8, oid.T_int4, oid.T_int2:
117 i, err := strconv.ParseInt(string(s), 10, 64)
118 if err != nil {
119 errorf("%s", err)
120 }
121 return i
122 case oid.T_float4, oid.T_float8:
123
124
125
126 f, err := strconv.ParseFloat(string(s), 64)
127 if err != nil {
128 errorf("%s", err)
129 }
130 return f
131 }
132
133 return s
134 }
135
136
137
138 func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
139 switch v := x.(type) {
140 case int64:
141 return strconv.AppendInt(buf, v, 10)
142 case float64:
143 return strconv.AppendFloat(buf, v, 'f', -1, 64)
144 case []byte:
145 encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
146 return appendEscapedText(buf, string(encodedBytea))
147 case string:
148 return appendEscapedText(buf, v)
149 case bool:
150 return strconv.AppendBool(buf, v)
151 case time.Time:
152 return append(buf, formatTs(v)...)
153 case nil:
154 return append(buf, "\\N"...)
155 default:
156 errorf("encode: unknown type for %T", v)
157 }
158
159 panic("not reached")
160 }
161
// appendEscapedText appends text to buf, replacing backslash, newline,
// carriage return and tab with the two-character sequences \\, \n, \r and
// \t. When text contains none of them it is appended unchanged in one step.
func appendEscapedText(buf []byte, text string) []byte {
	first := strings.IndexAny(text, "\\\n\r\t")
	if first < 0 {
		return append(buf, text...)
	}

	// Everything before the first special character can be copied as-is.
	out := append(buf, text[:first]...)
	for j := first; j < len(text); j++ {
		switch ch := text[j]; ch {
		case '\\':
			out = append(out, '\\', '\\')
		case '\n':
			out = append(out, '\\', 'n')
		case '\r':
			out = append(out, '\\', 'r')
		case '\t':
			out = append(out, '\\', 't')
		default:
			out = append(out, ch)
		}
	}
	return out
}
199
200 func mustParse(f string, typ oid.Oid, s []byte) time.Time {
201 str := string(s)
202
203
204 if typ == oid.T_timestamptz || typ == oid.T_timetz {
205 for i := 3; i <= 6; i += 3 {
206 if str[len(str)-i] == ':' {
207 f += ":00"
208 continue
209 }
210 break
211 }
212 }
213
214
215
216
217
218
219 var is2400Time bool
220 switch typ {
221 case oid.T_timetz, oid.T_time:
222 if matches := time2400Regex.FindStringSubmatch(str); matches != nil {
223
224 str = "00:00:00" + str[len(matches[1]):]
225 is2400Time = true
226 }
227 }
228 t, err := time.Parse(f, str)
229 if err != nil {
230 errorf("decode: %s", err)
231 }
232 if is2400Time {
233 t = t.Add(24 * time.Hour)
234 }
235 return t
236 }
237
// errInvalidTimestamp is recorded when a timestamp string is too short for,
// or otherwise inconsistent with, the positions the parser inspects.
var errInvalidTimestamp = errors.New("invalid timestamp")

// timestampParser keeps the first error encountered while scanning a
// timestamp string. Once err is set, every further expect/mustAtoi call is a
// no-op, so callers can chain calls and check err once at the end.
type timestampParser struct {
	err error
}

// expect records an error unless str[pos] == char. A position outside str
// records errInvalidTimestamp.
func (p *timestampParser) expect(str string, char byte, pos int) {
	if p.err != nil {
		return
	}
	if pos < 0 || pos >= len(str) {
		p.err = errInvalidTimestamp
		return
	}
	if c := str[pos]; c != char {
		// %c prints the characters themselves; %v would print byte values.
		p.err = fmt.Errorf("expected '%c' at position %v; got '%c'", char, pos, c)
	}
}

// mustAtoi parses str[begin:end] as a decimal integer.
//   - An out-of-range or inverted span records errInvalidTimestamp.
//   - Non-numeric text records an error that quotes the whole input.
//
// It returns 0 whenever an error is, or already was, recorded.
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
	if p.err != nil {
		return 0
	}
	if begin < 0 || end < 0 || begin > end || end > len(str) {
		p.err = errInvalidTimestamp
		return 0
	}
	result, err := strconv.Atoi(str[begin:end])
	if err != nil {
		p.err = fmt.Errorf("expected number; got '%v'", str)
		return 0
	}
	return result
}
274
275
// locationCache memoizes one fixed-offset *time.Location per UTC offset (in
// seconds east of UTC), so repeated parses share a Location instead of
// allocating a new one each time. The map is guarded by lock. Build
// instances with newLocationCache: the zero value has a nil map, and writing
// to it would panic.
type locationCache struct {
	cache map[int]*time.Location
	lock  sync.Mutex
}

// globalLocationCache is the package-wide cache used when parsing
// timestamps.
var globalLocationCache = newLocationCache()

// newLocationCache returns an empty, ready-to-use locationCache.
func newLocationCache() *locationCache {
	c := &locationCache{}
	c.cache = make(map[int]*time.Location)
	return c
}

// getLocation returns the cached unnamed fixed zone for offset, creating and
// storing it on first use. It is safe for concurrent use.
func (c *locationCache) getLocation(offset int) *time.Location {
	c.lock.Lock()
	defer c.lock.Unlock()

	if loc, found := c.cache[offset]; found {
		return loc
	}
	loc := time.FixedZone("", offset)
	c.cache[offset] = loc
	return loc
}
305
// Infinity-timestamp configuration, set by EnableInfinityTs. These are plain
// package-level variables with no synchronization.
var (
	infinityTsEnabled  bool
	infinityTsNegative time.Time
	infinityTsPositive time.Time
)

const (
	infinityTsEnabledAlready        = "pq: infinity timestamp enabled already"
	infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
)

// EnableInfinityTs controls how Postgres' "-infinity" and "infinity"
// timestamp values are handled.
//
// Once enabled:
//   - "-infinity" is decoded as negative and "infinity" as positive.
//   - When encoding, any time.Time at or before negative is sent as
//     "-infinity", and any time at or after positive is sent as "infinity".
//
// While disabled, those two strings are returned undecoded as []byte.
//
// EnableInfinityTs panics if it has already been enabled, or if negative is
// not strictly before positive. The settings live in unsynchronized
// package-level variables. Configure them before the driver is used
// concurrently, e.g. during program initialization.
func EnableInfinityTs(negative time.Time, positive time.Time) {
	switch {
	case infinityTsEnabled:
		panic(infinityTsEnabledAlready)
	case !negative.Before(positive):
		panic(infinityTsNegativeMustBeSmaller)
	}
	infinityTsNegative, infinityTsPositive = negative, positive
	infinityTsEnabled = true
}

// disableInfinityTs switches infinity handling back off so EnableInfinityTs
// may be called again. The stored bounds are left in place but are no
// longer consulted.
func disableInfinityTs() {
	infinityTsEnabled = false
}
353
354
355
356
357
358 func parseTs(currentLocation *time.Location, str string) interface{} {
359 switch str {
360 case "-infinity":
361 if infinityTsEnabled {
362 return infinityTsNegative
363 }
364 return []byte(str)
365 case "infinity":
366 if infinityTsEnabled {
367 return infinityTsPositive
368 }
369 return []byte(str)
370 }
371 t, err := ParseTimestamp(currentLocation, str)
372 if err != nil {
373 panic(err)
374 }
375 return t
376 }
377
378
379
380
381
382 func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
383 p := timestampParser{}
384
385 monSep := strings.IndexRune(str, '-')
386
387
388 year := p.mustAtoi(str, 0, monSep)
389 daySep := monSep + 3
390 month := p.mustAtoi(str, monSep+1, daySep)
391 p.expect(str, '-', daySep)
392 timeSep := daySep + 3
393 day := p.mustAtoi(str, daySep+1, timeSep)
394
395 minLen := monSep + len("01-01") + 1
396
397 isBC := strings.HasSuffix(str, " BC")
398 if isBC {
399 minLen += 3
400 }
401
402 var hour, minute, second int
403 if len(str) > minLen {
404 p.expect(str, ' ', timeSep)
405 minSep := timeSep + 3
406 p.expect(str, ':', minSep)
407 hour = p.mustAtoi(str, timeSep+1, minSep)
408 secSep := minSep + 3
409 p.expect(str, ':', secSep)
410 minute = p.mustAtoi(str, minSep+1, secSep)
411 secEnd := secSep + 3
412 second = p.mustAtoi(str, secSep+1, secEnd)
413 }
414 remainderIdx := monSep + len("01-01 00:00:00") + 1
415
416
417
418
419
420 nanoSec := 0
421 tzOff := 0
422
423 if remainderIdx < len(str) && str[remainderIdx] == '.' {
424 fracStart := remainderIdx + 1
425 fracOff := strings.IndexAny(str[fracStart:], "-+Z ")
426 if fracOff < 0 {
427 fracOff = len(str) - fracStart
428 }
429 fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
430 nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
431
432 remainderIdx += fracOff + 1
433 }
434 if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
435
436 var tzSign int
437 switch c := str[tzStart]; c {
438 case '-':
439 tzSign = -1
440 case '+':
441 tzSign = +1
442 default:
443 return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
444 }
445 tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
446 remainderIdx += 3
447 var tzMin, tzSec int
448 if remainderIdx < len(str) && str[remainderIdx] == ':' {
449 tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
450 remainderIdx += 3
451 }
452 if remainderIdx < len(str) && str[remainderIdx] == ':' {
453 tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
454 remainderIdx += 3
455 }
456 tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
457 } else if tzStart < len(str) && str[tzStart] == 'Z' {
458
459 remainderIdx += 1
460 }
461
462 var isoYear int
463
464 if isBC {
465 isoYear = 1 - year
466 remainderIdx += 3
467 } else {
468 isoYear = year
469 }
470 if remainderIdx < len(str) {
471 return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
472 }
473 t := time.Date(isoYear, time.Month(month), day,
474 hour, minute, second, nanoSec,
475 globalLocationCache.getLocation(tzOff))
476
477 if currentLocation != nil {
478
479
480
481 lt := t.In(currentLocation)
482 _, newOff := lt.Zone()
483 if newOff == tzOff {
484 t = lt
485 }
486 }
487
488 return t, p.err
489 }
490
491
492 func formatTs(t time.Time) []byte {
493 if infinityTsEnabled {
494
495 if !t.After(infinityTsNegative) {
496 return []byte("-infinity")
497 }
498
499 if !t.Before(infinityTsPositive) {
500 return []byte("infinity")
501 }
502 }
503 return FormatTimestamp(t)
504 }
505
506
// FormatTimestamp formats t in Postgres' text format for timestamps, e.g.
// "2001-02-03 04:05:06.789+07:00".
//   - Fractional seconds are printed only when non-zero.
//   - A zero UTC offset is written as "Z".
//   - An offset with a seconds component gets a trailing ":SS".
//   - Years before 1 use Postgres' " BC" suffix: Go's year 0 is 1 BC,
//     year -1 is 2 BC, and so on.
func FormatTimestamp(t time.Time) []byte {
	const clockLayout = "-01-02 15:04:05.999999999Z07:00"

	var b []byte
	bc := t.Year() <= 0
	if bc {
		// Write the BC year directly rather than shifting t with AddDate.
		// Shifting would move February 29 of a leap BC year (year 0 = 1 BC
		// is leap) into a non-leap AD year, where it normalizes to March 1.
		b = append(b, fmt.Sprintf("%04d", 1-t.Year())...)
		b = t.AppendFormat(b, clockLayout)
	} else {
		b = t.AppendFormat(b, "2006"+clockLayout)
	}

	// The Z07:00 layout has no room for seconds in the offset; append them.
	_, offset := t.Zone()
	offset %= 60
	if offset != 0 {
		if offset < 0 {
			offset = -offset
		}
		b = append(b, ':')
		if offset < 10 {
			b = append(b, '0')
		}
		b = strconv.AppendInt(b, int64(offset), 10)
	}

	if bc {
		b = append(b, " BC"...)
	}
	return b
}
539
540
541
// parseBytea decodes the text representation of a bytea value.
//   - Input starting with `\x` is hex format.
//   - Anything else is escape format, where `\\` stands for one backslash,
//     `\NNN` for a byte given in octal, and every other byte for itself.
//
// On success the result is never nil. An empty bytea therefore decodes to an
// empty, non-nil slice in both formats and stays distinguishable from NULL.
func parseBytea(s []byte) (result []byte, err error) {
	if bytes.HasPrefix(s, []byte("\\x")) {
		s = s[2:]
		result = make([]byte, hex.DecodedLen(len(s)))
		if _, err := hex.Decode(result, s); err != nil {
			return nil, err
		}
		return result, nil
	}

	// Escape format never expands, so len(s) bounds the decoded size.
	result = make([]byte, 0, len(s))
	for len(s) > 0 {
		if s[0] != '\\' {
			// Copy the literal run up to the next backslash in one step.
			i := bytes.IndexByte(s, '\\')
			if i == -1 {
				result = append(result, s...)
				break
			}
			result = append(result, s[:i]...)
			s = s[i:]
			continue
		}

		// Escaped backslash.
		if len(s) >= 2 && s[1] == '\\' {
			result = append(result, '\\')
			s = s[2:]
			continue
		}

		// Otherwise a backslash must start a three-digit octal escape.
		if len(s) < 4 {
			return nil, fmt.Errorf("invalid bytea sequence %q", s)
		}
		r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
		if err != nil {
			return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
		}
		result = append(result, byte(r))
		s = s[4:]
	}

	return result, nil
}
588
// encodeBytea renders v as a bytea literal suited to serverVersion.
//
// When serverVersion >= 90000 (presumably server_version_num, i.e. 9.0+;
// confirm with the caller), the hex format is used: `\x` followed by two
// hex digits per byte.
//
// Otherwise the escape format is used. Backslashes are doubled, and bytes
// outside printable ASCII (0x20-0x7e) are written as a backslash plus three
// octal digits. An empty v yields a nil result in this format.
func encodeBytea(serverVersion int, v []byte) (result []byte) {
	if serverVersion < 90000 {
		for _, b := range v {
			switch {
			case b == '\\':
				result = append(result, '\\', '\\')
			case b < 0x20 || b > 0x7e:
				result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
			default:
				result = append(result, b)
			}
		}
		return result
	}

	result = make([]byte, 2+hex.EncodedLen(len(v)))
	copy(result, `\x`)
	hex.Encode(result[2:], v)
	return result
}
611
612
613
614
// NullTime represents a time.Time that may be NULL. It has the Scan and
// Value methods expected by database/sql (sql.Scanner and driver.Valuer),
// so it can be used both as a scan destination and as a query argument.
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the sql.Scanner interface. Any value that is not a
// time.Time, including nil, leaves Valid false and Time zero. It never
// returns an error.
func (nt *NullTime) Scan(value interface{}) error {
	t, ok := value.(time.Time)
	nt.Time = t
	nt.Valid = ok
	return nil
}

// Value implements the driver.Valuer interface, reporting nil when the value
// is not Valid.
func (nt NullTime) Value() (driver.Value, error) {
	if nt.Valid {
		return nt.Time, nil
	}
	return nil, nil
}
View as plain text