...

Source file src/github.com/lib/pq/encode.go

Documentation: github.com/lib/pq

     1  package pq
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql/driver"
     6  	"encoding/binary"
     7  	"encoding/hex"
     8  	"errors"
     9  	"fmt"
    10  	"math"
    11  	"regexp"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/lib/pq/oid"
    18  )
    19  
// time2400Regex matches a time-of-day string whose clock part is "24:00",
// optionally followed by ":00" and an all-zero fractional part, and then
// optionally by a suffix beginning with 'Z', '+' or '-'. Capture group 1 holds
// only the clock part, without that suffix.
var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`)
    21  
    22  func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
    23  	switch v := x.(type) {
    24  	case []byte:
    25  		return v
    26  	default:
    27  		return encode(parameterStatus, x, oid.T_unknown)
    28  	}
    29  }
    30  
    31  func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
    32  	switch v := x.(type) {
    33  	case int64:
    34  		return strconv.AppendInt(nil, v, 10)
    35  	case float64:
    36  		return strconv.AppendFloat(nil, v, 'f', -1, 64)
    37  	case []byte:
    38  		if pgtypOid == oid.T_bytea {
    39  			return encodeBytea(parameterStatus.serverVersion, v)
    40  		}
    41  
    42  		return v
    43  	case string:
    44  		if pgtypOid == oid.T_bytea {
    45  			return encodeBytea(parameterStatus.serverVersion, []byte(v))
    46  		}
    47  
    48  		return []byte(v)
    49  	case bool:
    50  		return strconv.AppendBool(nil, v)
    51  	case time.Time:
    52  		return formatTs(v)
    53  
    54  	default:
    55  		errorf("encode: unknown type for %T", v)
    56  	}
    57  
    58  	panic("not reached")
    59  }
    60  
    61  func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
    62  	switch f {
    63  	case formatBinary:
    64  		return binaryDecode(parameterStatus, s, typ)
    65  	case formatText:
    66  		return textDecode(parameterStatus, s, typ)
    67  	default:
    68  		panic("not reached")
    69  	}
    70  }
    71  
    72  func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
    73  	switch typ {
    74  	case oid.T_bytea:
    75  		return s
    76  	case oid.T_int8:
    77  		return int64(binary.BigEndian.Uint64(s))
    78  	case oid.T_int4:
    79  		return int64(int32(binary.BigEndian.Uint32(s)))
    80  	case oid.T_int2:
    81  		return int64(int16(binary.BigEndian.Uint16(s)))
    82  	case oid.T_uuid:
    83  		b, err := decodeUUIDBinary(s)
    84  		if err != nil {
    85  			panic(err)
    86  		}
    87  		return b
    88  
    89  	default:
    90  		errorf("don't know how to decode binary parameter of type %d", uint32(typ))
    91  	}
    92  
    93  	panic("not reached")
    94  }
    95  
    96  func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
    97  	switch typ {
    98  	case oid.T_char, oid.T_varchar, oid.T_text:
    99  		return string(s)
   100  	case oid.T_bytea:
   101  		b, err := parseBytea(s)
   102  		if err != nil {
   103  			errorf("%s", err)
   104  		}
   105  		return b
   106  	case oid.T_timestamptz:
   107  		return parseTs(parameterStatus.currentLocation, string(s))
   108  	case oid.T_timestamp, oid.T_date:
   109  		return parseTs(nil, string(s))
   110  	case oid.T_time:
   111  		return mustParse("15:04:05", typ, s)
   112  	case oid.T_timetz:
   113  		return mustParse("15:04:05-07", typ, s)
   114  	case oid.T_bool:
   115  		return s[0] == 't'
   116  	case oid.T_int8, oid.T_int4, oid.T_int2:
   117  		i, err := strconv.ParseInt(string(s), 10, 64)
   118  		if err != nil {
   119  			errorf("%s", err)
   120  		}
   121  		return i
   122  	case oid.T_float4, oid.T_float8:
   123  		// We always use 64 bit parsing, regardless of whether the input text is for
   124  		// a float4 or float8, because clients expect float64s for all float datatypes
   125  		// and returning a 32-bit parsed float64 produces lossy results.
   126  		f, err := strconv.ParseFloat(string(s), 64)
   127  		if err != nil {
   128  			errorf("%s", err)
   129  		}
   130  		return f
   131  	}
   132  
   133  	return s
   134  }
   135  
   136  // appendEncodedText encodes item in text format as required by COPY
   137  // and appends to buf
   138  func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
   139  	switch v := x.(type) {
   140  	case int64:
   141  		return strconv.AppendInt(buf, v, 10)
   142  	case float64:
   143  		return strconv.AppendFloat(buf, v, 'f', -1, 64)
   144  	case []byte:
   145  		encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
   146  		return appendEscapedText(buf, string(encodedBytea))
   147  	case string:
   148  		return appendEscapedText(buf, v)
   149  	case bool:
   150  		return strconv.AppendBool(buf, v)
   151  	case time.Time:
   152  		return append(buf, formatTs(v)...)
   153  	case nil:
   154  		return append(buf, "\\N"...)
   155  	default:
   156  		errorf("encode: unknown type for %T", v)
   157  	}
   158  
   159  	panic("not reached")
   160  }
   161  
   162  func appendEscapedText(buf []byte, text string) []byte {
   163  	escapeNeeded := false
   164  	startPos := 0
   165  	var c byte
   166  
   167  	// check if we need to escape
   168  	for i := 0; i < len(text); i++ {
   169  		c = text[i]
   170  		if c == '\\' || c == '\n' || c == '\r' || c == '\t' {
   171  			escapeNeeded = true
   172  			startPos = i
   173  			break
   174  		}
   175  	}
   176  	if !escapeNeeded {
   177  		return append(buf, text...)
   178  	}
   179  
   180  	// copy till first char to escape, iterate the rest
   181  	result := append(buf, text[:startPos]...)
   182  	for i := startPos; i < len(text); i++ {
   183  		c = text[i]
   184  		switch c {
   185  		case '\\':
   186  			result = append(result, '\\', '\\')
   187  		case '\n':
   188  			result = append(result, '\\', 'n')
   189  		case '\r':
   190  			result = append(result, '\\', 'r')
   191  		case '\t':
   192  			result = append(result, '\\', 't')
   193  		default:
   194  			result = append(result, c)
   195  		}
   196  	}
   197  	return result
   198  }
   199  
   200  func mustParse(f string, typ oid.Oid, s []byte) time.Time {
   201  	str := string(s)
   202  
   203  	// Check for a minute and second offset in the timezone.
   204  	if typ == oid.T_timestamptz || typ == oid.T_timetz {
   205  		for i := 3; i <= 6; i += 3 {
   206  			if str[len(str)-i] == ':' {
   207  				f += ":00"
   208  				continue
   209  			}
   210  			break
   211  		}
   212  	}
   213  
   214  	// Special case for 24:00 time.
   215  	// Unfortunately, golang does not parse 24:00 as a proper time.
   216  	// In this case, we want to try "round to the next day", to differentiate.
   217  	// As such, we find if the 24:00 time matches at the beginning; if so,
   218  	// we default it back to 00:00 but add a day later.
   219  	var is2400Time bool
   220  	switch typ {
   221  	case oid.T_timetz, oid.T_time:
   222  		if matches := time2400Regex.FindStringSubmatch(str); matches != nil {
   223  			// Concatenate timezone information at the back.
   224  			str = "00:00:00" + str[len(matches[1]):]
   225  			is2400Time = true
   226  		}
   227  	}
   228  	t, err := time.Parse(f, str)
   229  	if err != nil {
   230  		errorf("decode: %s", err)
   231  	}
   232  	if is2400Time {
   233  		t = t.Add(24 * time.Hour)
   234  	}
   235  	return t
   236  }
   237  
   238  var errInvalidTimestamp = errors.New("invalid timestamp")
   239  
   240  type timestampParser struct {
   241  	err error
   242  }
   243  
   244  func (p *timestampParser) expect(str string, char byte, pos int) {
   245  	if p.err != nil {
   246  		return
   247  	}
   248  	if pos+1 > len(str) {
   249  		p.err = errInvalidTimestamp
   250  		return
   251  	}
   252  	if c := str[pos]; c != char && p.err == nil {
   253  		p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c)
   254  	}
   255  }
   256  
   257  func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
   258  	if p.err != nil {
   259  		return 0
   260  	}
   261  	if begin < 0 || end < 0 || begin > end || end > len(str) {
   262  		p.err = errInvalidTimestamp
   263  		return 0
   264  	}
   265  	result, err := strconv.Atoi(str[begin:end])
   266  	if err != nil {
   267  		if p.err == nil {
   268  			p.err = fmt.Errorf("expected number; got '%v'", str)
   269  		}
   270  		return 0
   271  	}
   272  	return result
   273  }
   274  
   275  // The location cache caches the time zones typically used by the client.
   276  type locationCache struct {
   277  	cache map[int]*time.Location
   278  	lock  sync.Mutex
   279  }
   280  
   281  // All connections share the same list of timezones. Benchmarking shows that
   282  // about 5% speed could be gained by putting the cache in the connection and
   283  // losing the mutex, at the cost of a small amount of memory and a somewhat
   284  // significant increase in code complexity.
   285  var globalLocationCache = newLocationCache()
   286  
   287  func newLocationCache() *locationCache {
   288  	return &locationCache{cache: make(map[int]*time.Location)}
   289  }
   290  
   291  // Returns the cached timezone for the specified offset, creating and caching
   292  // it if necessary.
   293  func (c *locationCache) getLocation(offset int) *time.Location {
   294  	c.lock.Lock()
   295  	defer c.lock.Unlock()
   296  
   297  	location, ok := c.cache[offset]
   298  	if !ok {
   299  		location = time.FixedZone("", offset)
   300  		c.cache[offset] = location
   301  	}
   302  
   303  	return location
   304  }
   305  
   306  var infinityTsEnabled = false
   307  var infinityTsNegative time.Time
   308  var infinityTsPositive time.Time
   309  
   310  const (
   311  	infinityTsEnabledAlready        = "pq: infinity timestamp enabled already"
   312  	infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
   313  )
   314  
   315  // EnableInfinityTs controls the handling of Postgres' "-infinity" and
   316  // "infinity" "timestamp"s.
   317  //
   318  // If EnableInfinityTs is not called, "-infinity" and "infinity" will return
   319  // []byte("-infinity") and []byte("infinity") respectively, and potentially
   320  // cause error "sql: Scan error on column index 0: unsupported driver -> Scan
   321  // pair: []uint8 -> *time.Time", when scanning into a time.Time value.
   322  //
   323  // Once EnableInfinityTs has been called, all connections created using this
   324  // driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
   325  // "timestamp with time zone" and "date" types to the predefined minimum and
   326  // maximum times, respectively.  When encoding time.Time values, any time which
   327  // equals or precedes the predefined minimum time will be encoded to
   328  // "-infinity".  Any values at or past the maximum time will similarly be
   329  // encoded to "infinity".
   330  //
   331  // If EnableInfinityTs is called with negative >= positive, it will panic.
   332  // Calling EnableInfinityTs after a connection has been established results in
   333  // undefined behavior.  If EnableInfinityTs is called more than once, it will
   334  // panic.
   335  func EnableInfinityTs(negative time.Time, positive time.Time) {
   336  	if infinityTsEnabled {
   337  		panic(infinityTsEnabledAlready)
   338  	}
   339  	if !negative.Before(positive) {
   340  		panic(infinityTsNegativeMustBeSmaller)
   341  	}
   342  	infinityTsEnabled = true
   343  	infinityTsNegative = negative
   344  	infinityTsPositive = positive
   345  }
   346  
// disableInfinityTs switches infinity timestamp handling back off by clearing
// infinityTsEnabled, which also allows EnableInfinityTs to be called again.
// The previously configured negative/positive times are left untouched.
// Testing might want to toggle infinityTsEnabled.
func disableInfinityTs() {
	infinityTsEnabled = false
}
   353  
   354  // This is a time function specific to the Postgres default DateStyle
   355  // setting ("ISO, MDY"), the only one we currently support. This
   356  // accounts for the discrepancies between the parsing available with
   357  // time.Parse and the Postgres date formatting quirks.
   358  func parseTs(currentLocation *time.Location, str string) interface{} {
   359  	switch str {
   360  	case "-infinity":
   361  		if infinityTsEnabled {
   362  			return infinityTsNegative
   363  		}
   364  		return []byte(str)
   365  	case "infinity":
   366  		if infinityTsEnabled {
   367  			return infinityTsPositive
   368  		}
   369  		return []byte(str)
   370  	}
   371  	t, err := ParseTimestamp(currentLocation, str)
   372  	if err != nil {
   373  		panic(err)
   374  	}
   375  	return t
   376  }
   377  
   378  // ParseTimestamp parses Postgres' text format. It returns a time.Time in
   379  // currentLocation iff that time's offset agrees with the offset sent from the
   380  // Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
   381  // fixed offset offset provided by the Postgres server.
   382  func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
   383  	p := timestampParser{}
   384  
   385  	monSep := strings.IndexRune(str, '-')
   386  	// this is Gregorian year, not ISO Year
   387  	// In Gregorian system, the year 1 BC is followed by AD 1
   388  	year := p.mustAtoi(str, 0, monSep)
   389  	daySep := monSep + 3
   390  	month := p.mustAtoi(str, monSep+1, daySep)
   391  	p.expect(str, '-', daySep)
   392  	timeSep := daySep + 3
   393  	day := p.mustAtoi(str, daySep+1, timeSep)
   394  
   395  	minLen := monSep + len("01-01") + 1
   396  
   397  	isBC := strings.HasSuffix(str, " BC")
   398  	if isBC {
   399  		minLen += 3
   400  	}
   401  
   402  	var hour, minute, second int
   403  	if len(str) > minLen {
   404  		p.expect(str, ' ', timeSep)
   405  		minSep := timeSep + 3
   406  		p.expect(str, ':', minSep)
   407  		hour = p.mustAtoi(str, timeSep+1, minSep)
   408  		secSep := minSep + 3
   409  		p.expect(str, ':', secSep)
   410  		minute = p.mustAtoi(str, minSep+1, secSep)
   411  		secEnd := secSep + 3
   412  		second = p.mustAtoi(str, secSep+1, secEnd)
   413  	}
   414  	remainderIdx := monSep + len("01-01 00:00:00") + 1
   415  	// Three optional (but ordered) sections follow: the
   416  	// fractional seconds, the time zone offset, and the BC
   417  	// designation. We set them up here and adjust the other
   418  	// offsets if the preceding sections exist.
   419  
   420  	nanoSec := 0
   421  	tzOff := 0
   422  
   423  	if remainderIdx < len(str) && str[remainderIdx] == '.' {
   424  		fracStart := remainderIdx + 1
   425  		fracOff := strings.IndexAny(str[fracStart:], "-+Z ")
   426  		if fracOff < 0 {
   427  			fracOff = len(str) - fracStart
   428  		}
   429  		fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
   430  		nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
   431  
   432  		remainderIdx += fracOff + 1
   433  	}
   434  	if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
   435  		// time zone separator is always '-' or '+' or 'Z' (UTC is +00)
   436  		var tzSign int
   437  		switch c := str[tzStart]; c {
   438  		case '-':
   439  			tzSign = -1
   440  		case '+':
   441  			tzSign = +1
   442  		default:
   443  			return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
   444  		}
   445  		tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
   446  		remainderIdx += 3
   447  		var tzMin, tzSec int
   448  		if remainderIdx < len(str) && str[remainderIdx] == ':' {
   449  			tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
   450  			remainderIdx += 3
   451  		}
   452  		if remainderIdx < len(str) && str[remainderIdx] == ':' {
   453  			tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
   454  			remainderIdx += 3
   455  		}
   456  		tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
   457  	} else if tzStart < len(str) && str[tzStart] == 'Z' {
   458  		// time zone Z separator indicates UTC is +00
   459  		remainderIdx += 1
   460  	}
   461  
   462  	var isoYear int
   463  
   464  	if isBC {
   465  		isoYear = 1 - year
   466  		remainderIdx += 3
   467  	} else {
   468  		isoYear = year
   469  	}
   470  	if remainderIdx < len(str) {
   471  		return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
   472  	}
   473  	t := time.Date(isoYear, time.Month(month), day,
   474  		hour, minute, second, nanoSec,
   475  		globalLocationCache.getLocation(tzOff))
   476  
   477  	if currentLocation != nil {
   478  		// Set the location of the returned Time based on the session's
   479  		// TimeZone value, but only if the local time zone database agrees with
   480  		// the remote database on the offset.
   481  		lt := t.In(currentLocation)
   482  		_, newOff := lt.Zone()
   483  		if newOff == tzOff {
   484  			t = lt
   485  		}
   486  	}
   487  
   488  	return t, p.err
   489  }
   490  
   491  // formatTs formats t into a format postgres understands.
   492  func formatTs(t time.Time) []byte {
   493  	if infinityTsEnabled {
   494  		// t <= -infinity : ! (t > -infinity)
   495  		if !t.After(infinityTsNegative) {
   496  			return []byte("-infinity")
   497  		}
   498  		// t >= infinity : ! (!t < infinity)
   499  		if !t.Before(infinityTsPositive) {
   500  			return []byte("infinity")
   501  		}
   502  	}
   503  	return FormatTimestamp(t)
   504  }
   505  
   506  // FormatTimestamp formats t into Postgres' text format for timestamps.
   507  func FormatTimestamp(t time.Time) []byte {
   508  	// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
   509  	// minus sign preferred by Go.
   510  	// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
   511  	bc := false
   512  	if t.Year() <= 0 {
   513  		// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
   514  		t = t.AddDate((-t.Year())*2+1, 0, 0)
   515  		bc = true
   516  	}
   517  	b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00"))
   518  
   519  	_, offset := t.Zone()
   520  	offset %= 60
   521  	if offset != 0 {
   522  		// RFC3339Nano already printed the minus sign
   523  		if offset < 0 {
   524  			offset = -offset
   525  		}
   526  
   527  		b = append(b, ':')
   528  		if offset < 10 {
   529  			b = append(b, '0')
   530  		}
   531  		b = strconv.AppendInt(b, int64(offset), 10)
   532  	}
   533  
   534  	if bc {
   535  		b = append(b, " BC"...)
   536  	}
   537  	return b
   538  }
   539  
   540  // Parse a bytea value received from the server.  Both "hex" and the legacy
   541  // "escape" format are supported.
   542  func parseBytea(s []byte) (result []byte, err error) {
   543  	if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
   544  		// bytea_output = hex
   545  		s = s[2:] // trim off leading "\\x"
   546  		result = make([]byte, hex.DecodedLen(len(s)))
   547  		_, err := hex.Decode(result, s)
   548  		if err != nil {
   549  			return nil, err
   550  		}
   551  	} else {
   552  		// bytea_output = escape
   553  		for len(s) > 0 {
   554  			if s[0] == '\\' {
   555  				// escaped '\\'
   556  				if len(s) >= 2 && s[1] == '\\' {
   557  					result = append(result, '\\')
   558  					s = s[2:]
   559  					continue
   560  				}
   561  
   562  				// '\\' followed by an octal number
   563  				if len(s) < 4 {
   564  					return nil, fmt.Errorf("invalid bytea sequence %v", s)
   565  				}
   566  				r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
   567  				if err != nil {
   568  					return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
   569  				}
   570  				result = append(result, byte(r))
   571  				s = s[4:]
   572  			} else {
   573  				// We hit an unescaped, raw byte.  Try to read in as many as
   574  				// possible in one go.
   575  				i := bytes.IndexByte(s, '\\')
   576  				if i == -1 {
   577  					result = append(result, s...)
   578  					break
   579  				}
   580  				result = append(result, s[:i]...)
   581  				s = s[i:]
   582  			}
   583  		}
   584  	}
   585  
   586  	return result, nil
   587  }
   588  
   589  func encodeBytea(serverVersion int, v []byte) (result []byte) {
   590  	if serverVersion >= 90000 {
   591  		// Use the hex format if we know that the server supports it
   592  		result = make([]byte, 2+hex.EncodedLen(len(v)))
   593  		result[0] = '\\'
   594  		result[1] = 'x'
   595  		hex.Encode(result[2:], v)
   596  	} else {
   597  		// .. or resort to "escape"
   598  		for _, b := range v {
   599  			if b == '\\' {
   600  				result = append(result, '\\', '\\')
   601  			} else if b < 0x20 || b > 0x7e {
   602  				result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
   603  			} else {
   604  				result = append(result, b)
   605  			}
   606  		}
   607  	}
   608  
   609  	return result
   610  }
   611  
   612  // NullTime represents a time.Time that may be null. NullTime implements the
   613  // sql.Scanner interface so it can be used as a scan destination, similar to
   614  // sql.NullString.
   615  type NullTime struct {
   616  	Time  time.Time
   617  	Valid bool // Valid is true if Time is not NULL
   618  }
   619  
   620  // Scan implements the Scanner interface.
   621  func (nt *NullTime) Scan(value interface{}) error {
   622  	nt.Time, nt.Valid = value.(time.Time)
   623  	return nil
   624  }
   625  
   626  // Value implements the driver Valuer interface.
   627  func (nt NullTime) Value() (driver.Value, error) {
   628  	if !nt.Valid {
   629  		return nil, nil
   630  	}
   631  	return nt.Time, nil
   632  }
   633  

View as plain text