...

Source file src/github.com/gomodule/redigo/redis/conn.go

Documentation: github.com/gomodule/redigo/redis

     1  // Copyright 2012 Gary Burd
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"): you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    11  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    12  // License for the specific language governing permissions and limitations
    13  // under the License.
    14  
    15  package redis
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"crypto/tls"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"net/url"
    26  	"regexp"
    27  	"strconv"
    28  	"sync"
    29  	"time"
    30  )
    31  
    32  var (
    33  	_ ConnWithTimeout = (*conn)(nil)
    34  )
    35  
    36  // conn is the low-level implementation of Conn
    37  type conn struct {
    38  	// Shared
    39  	mu      sync.Mutex
    40  	pending int
    41  	err     error
    42  	conn    net.Conn
    43  
    44  	// Read
    45  	readTimeout time.Duration
    46  	br          *bufio.Reader
    47  
    48  	// Write
    49  	writeTimeout time.Duration
    50  	bw           *bufio.Writer
    51  
    52  	// Scratch space for formatting argument length.
    53  	// '*' or '$', length, "\r\n"
    54  	lenScratch [32]byte
    55  
    56  	// Scratch space for formatting integers and floats.
    57  	numScratch [40]byte
    58  }
    59  
    60  // DialTimeout acts like Dial but takes timeouts for establishing the
    61  // connection to the server, writing a command and reading a reply.
    62  //
    63  // Deprecated: Use Dial with options instead.
    64  func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
    65  	return Dial(network, address,
    66  		DialConnectTimeout(connectTimeout),
    67  		DialReadTimeout(readTimeout),
    68  		DialWriteTimeout(writeTimeout))
    69  }
    70  
    71  // DialOption specifies an option for dialing a Redis server.
    72  type DialOption struct {
    73  	f func(*dialOptions)
    74  }
    75  
    76  type dialOptions struct {
    77  	readTimeout  time.Duration
    78  	writeTimeout time.Duration
    79  	dialer       *net.Dialer
    80  	dial         func(network, addr string) (net.Conn, error)
    81  	db           int
    82  	password     string
    83  	useTLS       bool
    84  	skipVerify   bool
    85  	tlsConfig    *tls.Config
    86  }
    87  
    88  // DialReadTimeout specifies the timeout for reading a single command reply.
    89  func DialReadTimeout(d time.Duration) DialOption {
    90  	return DialOption{func(do *dialOptions) {
    91  		do.readTimeout = d
    92  	}}
    93  }
    94  
    95  // DialWriteTimeout specifies the timeout for writing a single command.
    96  func DialWriteTimeout(d time.Duration) DialOption {
    97  	return DialOption{func(do *dialOptions) {
    98  		do.writeTimeout = d
    99  	}}
   100  }
   101  
   102  // DialConnectTimeout specifies the timeout for connecting to the Redis server when
   103  // no DialNetDial option is specified.
   104  func DialConnectTimeout(d time.Duration) DialOption {
   105  	return DialOption{func(do *dialOptions) {
   106  		do.dialer.Timeout = d
   107  	}}
   108  }
   109  
   110  // DialKeepAlive specifies the keep-alive period for TCP connections to the Redis server
   111  // when no DialNetDial option is specified.
   112  // If zero, keep-alives are not enabled. If no DialKeepAlive option is specified then
   113  // the default of 5 minutes is used to ensure that half-closed TCP sessions are detected.
   114  func DialKeepAlive(d time.Duration) DialOption {
   115  	return DialOption{func(do *dialOptions) {
   116  		do.dialer.KeepAlive = d
   117  	}}
   118  }
   119  
   120  // DialNetDial specifies a custom dial function for creating TCP
   121  // connections, otherwise a net.Dialer customized via the other options is used.
   122  // DialNetDial overrides DialConnectTimeout and DialKeepAlive.
   123  func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
   124  	return DialOption{func(do *dialOptions) {
   125  		do.dial = dial
   126  	}}
   127  }
   128  
   129  // DialDatabase specifies the database to select when dialing a connection.
   130  func DialDatabase(db int) DialOption {
   131  	return DialOption{func(do *dialOptions) {
   132  		do.db = db
   133  	}}
   134  }
   135  
   136  // DialPassword specifies the password to use when connecting to
   137  // the Redis server.
   138  func DialPassword(password string) DialOption {
   139  	return DialOption{func(do *dialOptions) {
   140  		do.password = password
   141  	}}
   142  }
   143  
   144  // DialTLSConfig specifies the config to use when a TLS connection is dialed.
   145  // Has no effect when not dialing a TLS connection.
   146  func DialTLSConfig(c *tls.Config) DialOption {
   147  	return DialOption{func(do *dialOptions) {
   148  		do.tlsConfig = c
   149  	}}
   150  }
   151  
   152  // DialTLSSkipVerify disables server name verification when connecting over
   153  // TLS. Has no effect when not dialing a TLS connection.
   154  func DialTLSSkipVerify(skip bool) DialOption {
   155  	return DialOption{func(do *dialOptions) {
   156  		do.skipVerify = skip
   157  	}}
   158  }
   159  
   160  // DialUseTLS specifies whether TLS should be used when connecting to the
   161  // server. This option is ignore by DialURL.
   162  func DialUseTLS(useTLS bool) DialOption {
   163  	return DialOption{func(do *dialOptions) {
   164  		do.useTLS = useTLS
   165  	}}
   166  }
   167  
   168  // Dial connects to the Redis server at the given network and
   169  // address using the specified options.
   170  func Dial(network, address string, options ...DialOption) (Conn, error) {
   171  	do := dialOptions{
   172  		dialer: &net.Dialer{
   173  			KeepAlive: time.Minute * 5,
   174  		},
   175  	}
   176  	for _, option := range options {
   177  		option.f(&do)
   178  	}
   179  	if do.dial == nil {
   180  		do.dial = do.dialer.Dial
   181  	}
   182  
   183  	netConn, err := do.dial(network, address)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	if do.useTLS {
   189  		var tlsConfig *tls.Config
   190  		if do.tlsConfig == nil {
   191  			tlsConfig = &tls.Config{InsecureSkipVerify: do.skipVerify}
   192  		} else {
   193  			tlsConfig = cloneTLSConfig(do.tlsConfig)
   194  		}
   195  		if tlsConfig.ServerName == "" {
   196  			host, _, err := net.SplitHostPort(address)
   197  			if err != nil {
   198  				netConn.Close()
   199  				return nil, err
   200  			}
   201  			tlsConfig.ServerName = host
   202  		}
   203  
   204  		tlsConn := tls.Client(netConn, tlsConfig)
   205  		if err := tlsConn.Handshake(); err != nil {
   206  			netConn.Close()
   207  			return nil, err
   208  		}
   209  		netConn = tlsConn
   210  	}
   211  
   212  	c := &conn{
   213  		conn:         netConn,
   214  		bw:           bufio.NewWriter(netConn),
   215  		br:           bufio.NewReader(netConn),
   216  		readTimeout:  do.readTimeout,
   217  		writeTimeout: do.writeTimeout,
   218  	}
   219  
   220  	if do.password != "" {
   221  		if _, err := c.Do("AUTH", do.password); err != nil {
   222  			netConn.Close()
   223  			return nil, err
   224  		}
   225  	}
   226  
   227  	if do.db != 0 {
   228  		if _, err := c.Do("SELECT", do.db); err != nil {
   229  			netConn.Close()
   230  			return nil, err
   231  		}
   232  	}
   233  
   234  	return c, nil
   235  }
   236  
   237  var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
   238  
   239  // DialURL connects to a Redis server at the given URL using the Redis
   240  // URI scheme. URLs should follow the draft IANA specification for the
   241  // scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
   242  func DialURL(rawurl string, options ...DialOption) (Conn, error) {
   243  	u, err := url.Parse(rawurl)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	if u.Scheme != "redis" && u.Scheme != "rediss" {
   249  		return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
   250  	}
   251  
   252  	// As per the IANA draft spec, the host defaults to localhost and
   253  	// the port defaults to 6379.
   254  	host, port, err := net.SplitHostPort(u.Host)
   255  	if err != nil {
   256  		// assume port is missing
   257  		host = u.Host
   258  		port = "6379"
   259  	}
   260  	if host == "" {
   261  		host = "localhost"
   262  	}
   263  	address := net.JoinHostPort(host, port)
   264  
   265  	if u.User != nil {
   266  		password, isSet := u.User.Password()
   267  		if isSet {
   268  			options = append(options, DialPassword(password))
   269  		}
   270  	}
   271  
   272  	match := pathDBRegexp.FindStringSubmatch(u.Path)
   273  	if len(match) == 2 {
   274  		db := 0
   275  		if len(match[1]) > 0 {
   276  			db, err = strconv.Atoi(match[1])
   277  			if err != nil {
   278  				return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
   279  			}
   280  		}
   281  		if db != 0 {
   282  			options = append(options, DialDatabase(db))
   283  		}
   284  	} else if u.Path != "" {
   285  		return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
   286  	}
   287  
   288  	options = append(options, DialUseTLS(u.Scheme == "rediss"))
   289  
   290  	return Dial("tcp", address, options...)
   291  }
   292  
   293  // NewConn returns a new Redigo connection for the given net connection.
   294  func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
   295  	return &conn{
   296  		conn:         netConn,
   297  		bw:           bufio.NewWriter(netConn),
   298  		br:           bufio.NewReader(netConn),
   299  		readTimeout:  readTimeout,
   300  		writeTimeout: writeTimeout,
   301  	}
   302  }
   303  
   304  func (c *conn) Close() error {
   305  	c.mu.Lock()
   306  	err := c.err
   307  	if c.err == nil {
   308  		c.err = errors.New("redigo: closed")
   309  		err = c.conn.Close()
   310  	}
   311  	c.mu.Unlock()
   312  	return err
   313  }
   314  
   315  func (c *conn) fatal(err error) error {
   316  	c.mu.Lock()
   317  	if c.err == nil {
   318  		c.err = err
   319  		// Close connection to force errors on subsequent calls and to unblock
   320  		// other reader or writer.
   321  		c.conn.Close()
   322  	}
   323  	c.mu.Unlock()
   324  	return err
   325  }
   326  
   327  func (c *conn) Err() error {
   328  	c.mu.Lock()
   329  	err := c.err
   330  	c.mu.Unlock()
   331  	return err
   332  }
   333  
   334  func (c *conn) writeLen(prefix byte, n int) error {
   335  	c.lenScratch[len(c.lenScratch)-1] = '\n'
   336  	c.lenScratch[len(c.lenScratch)-2] = '\r'
   337  	i := len(c.lenScratch) - 3
   338  	for {
   339  		c.lenScratch[i] = byte('0' + n%10)
   340  		i -= 1
   341  		n = n / 10
   342  		if n == 0 {
   343  			break
   344  		}
   345  	}
   346  	c.lenScratch[i] = prefix
   347  	_, err := c.bw.Write(c.lenScratch[i:])
   348  	return err
   349  }
   350  
   351  func (c *conn) writeString(s string) error {
   352  	c.writeLen('$', len(s))
   353  	c.bw.WriteString(s)
   354  	_, err := c.bw.WriteString("\r\n")
   355  	return err
   356  }
   357  
   358  func (c *conn) writeBytes(p []byte) error {
   359  	c.writeLen('$', len(p))
   360  	c.bw.Write(p)
   361  	_, err := c.bw.WriteString("\r\n")
   362  	return err
   363  }
   364  
   365  func (c *conn) writeInt64(n int64) error {
   366  	return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
   367  }
   368  
   369  func (c *conn) writeFloat64(n float64) error {
   370  	return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
   371  }
   372  
   373  func (c *conn) writeCommand(cmd string, args []interface{}) error {
   374  	c.writeLen('*', 1+len(args))
   375  	if err := c.writeString(cmd); err != nil {
   376  		return err
   377  	}
   378  	for _, arg := range args {
   379  		if err := c.writeArg(arg, true); err != nil {
   380  			return err
   381  		}
   382  	}
   383  	return nil
   384  }
   385  
   386  func (c *conn) writeArg(arg interface{}, argumentTypeOK bool) (err error) {
   387  	switch arg := arg.(type) {
   388  	case string:
   389  		return c.writeString(arg)
   390  	case []byte:
   391  		return c.writeBytes(arg)
   392  	case int:
   393  		return c.writeInt64(int64(arg))
   394  	case int64:
   395  		return c.writeInt64(arg)
   396  	case float64:
   397  		return c.writeFloat64(arg)
   398  	case bool:
   399  		if arg {
   400  			return c.writeString("1")
   401  		} else {
   402  			return c.writeString("0")
   403  		}
   404  	case nil:
   405  		return c.writeString("")
   406  	case Argument:
   407  		if argumentTypeOK {
   408  			return c.writeArg(arg.RedisArg(), false)
   409  		}
   410  		// See comment in default clause below.
   411  		var buf bytes.Buffer
   412  		fmt.Fprint(&buf, arg)
   413  		return c.writeBytes(buf.Bytes())
   414  	default:
   415  		// This default clause is intended to handle builtin numeric types.
   416  		// The function should return an error for other types, but this is not
   417  		// done for compatibility with previous versions of the package.
   418  		var buf bytes.Buffer
   419  		fmt.Fprint(&buf, arg)
   420  		return c.writeBytes(buf.Bytes())
   421  	}
   422  }
   423  
   424  type protocolError string
   425  
   426  func (pe protocolError) Error() string {
   427  	return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
   428  }
   429  
   430  func (c *conn) readLine() ([]byte, error) {
   431  	p, err := c.br.ReadSlice('\n')
   432  	if err == bufio.ErrBufferFull {
   433  		return nil, protocolError("long response line")
   434  	}
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  	i := len(p) - 2
   439  	if i < 0 || p[i] != '\r' {
   440  		return nil, protocolError("bad response line terminator")
   441  	}
   442  	return p[:i], nil
   443  }
   444  
   445  // parseLen parses bulk string and array lengths.
   446  func parseLen(p []byte) (int, error) {
   447  	if len(p) == 0 {
   448  		return -1, protocolError("malformed length")
   449  	}
   450  
   451  	if p[0] == '-' && len(p) == 2 && p[1] == '1' {
   452  		// handle $-1 and $-1 null replies.
   453  		return -1, nil
   454  	}
   455  
   456  	var n int
   457  	for _, b := range p {
   458  		n *= 10
   459  		if b < '0' || b > '9' {
   460  			return -1, protocolError("illegal bytes in length")
   461  		}
   462  		n += int(b - '0')
   463  	}
   464  
   465  	return n, nil
   466  }
   467  
   468  // parseInt parses an integer reply.
   469  func parseInt(p []byte) (interface{}, error) {
   470  	if len(p) == 0 {
   471  		return 0, protocolError("malformed integer")
   472  	}
   473  
   474  	var negate bool
   475  	if p[0] == '-' {
   476  		negate = true
   477  		p = p[1:]
   478  		if len(p) == 0 {
   479  			return 0, protocolError("malformed integer")
   480  		}
   481  	}
   482  
   483  	var n int64
   484  	for _, b := range p {
   485  		n *= 10
   486  		if b < '0' || b > '9' {
   487  			return 0, protocolError("illegal bytes in length")
   488  		}
   489  		n += int64(b - '0')
   490  	}
   491  
   492  	if negate {
   493  		n = -n
   494  	}
   495  	return n, nil
   496  }
   497  
   498  var (
   499  	okReply   interface{} = "OK"
   500  	pongReply interface{} = "PONG"
   501  )
   502  
   503  func (c *conn) readReply() (interface{}, error) {
   504  	line, err := c.readLine()
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	if len(line) == 0 {
   509  		return nil, protocolError("short response line")
   510  	}
   511  	switch line[0] {
   512  	case '+':
   513  		switch {
   514  		case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
   515  			// Avoid allocation for frequent "+OK" response.
   516  			return okReply, nil
   517  		case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
   518  			// Avoid allocation in PING command benchmarks :)
   519  			return pongReply, nil
   520  		default:
   521  			return string(line[1:]), nil
   522  		}
   523  	case '-':
   524  		return Error(string(line[1:])), nil
   525  	case ':':
   526  		return parseInt(line[1:])
   527  	case '$':
   528  		n, err := parseLen(line[1:])
   529  		if n < 0 || err != nil {
   530  			return nil, err
   531  		}
   532  		p := make([]byte, n)
   533  		_, err = io.ReadFull(c.br, p)
   534  		if err != nil {
   535  			return nil, err
   536  		}
   537  		if line, err := c.readLine(); err != nil {
   538  			return nil, err
   539  		} else if len(line) != 0 {
   540  			return nil, protocolError("bad bulk string format")
   541  		}
   542  		return p, nil
   543  	case '*':
   544  		n, err := parseLen(line[1:])
   545  		if n < 0 || err != nil {
   546  			return nil, err
   547  		}
   548  		r := make([]interface{}, n)
   549  		for i := range r {
   550  			r[i], err = c.readReply()
   551  			if err != nil {
   552  				return nil, err
   553  			}
   554  		}
   555  		return r, nil
   556  	}
   557  	return nil, protocolError("unexpected response line")
   558  }
   559  
   560  func (c *conn) Send(cmd string, args ...interface{}) error {
   561  	c.mu.Lock()
   562  	c.pending += 1
   563  	c.mu.Unlock()
   564  	if c.writeTimeout != 0 {
   565  		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
   566  	}
   567  	if err := c.writeCommand(cmd, args); err != nil {
   568  		return c.fatal(err)
   569  	}
   570  	return nil
   571  }
   572  
   573  func (c *conn) Flush() error {
   574  	if c.writeTimeout != 0 {
   575  		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
   576  	}
   577  	if err := c.bw.Flush(); err != nil {
   578  		return c.fatal(err)
   579  	}
   580  	return nil
   581  }
   582  
   583  func (c *conn) Receive() (interface{}, error) {
   584  	return c.ReceiveWithTimeout(c.readTimeout)
   585  }
   586  
   587  func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
   588  	var deadline time.Time
   589  	if timeout != 0 {
   590  		deadline = time.Now().Add(timeout)
   591  	}
   592  	c.conn.SetReadDeadline(deadline)
   593  
   594  	if reply, err = c.readReply(); err != nil {
   595  		return nil, c.fatal(err)
   596  	}
   597  	// When using pub/sub, the number of receives can be greater than the
   598  	// number of sends. To enable normal use of the connection after
   599  	// unsubscribing from all channels, we do not decrement pending to a
   600  	// negative value.
   601  	//
   602  	// The pending field is decremented after the reply is read to handle the
   603  	// case where Receive is called before Send.
   604  	c.mu.Lock()
   605  	if c.pending > 0 {
   606  		c.pending -= 1
   607  	}
   608  	c.mu.Unlock()
   609  	if err, ok := reply.(Error); ok {
   610  		return nil, err
   611  	}
   612  	return
   613  }
   614  
   615  func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
   616  	return c.DoWithTimeout(c.readTimeout, cmd, args...)
   617  }
   618  
   619  func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
   620  	c.mu.Lock()
   621  	pending := c.pending
   622  	c.pending = 0
   623  	c.mu.Unlock()
   624  
   625  	if cmd == "" && pending == 0 {
   626  		return nil, nil
   627  	}
   628  
   629  	if c.writeTimeout != 0 {
   630  		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
   631  	}
   632  
   633  	if cmd != "" {
   634  		if err := c.writeCommand(cmd, args); err != nil {
   635  			return nil, c.fatal(err)
   636  		}
   637  	}
   638  
   639  	if err := c.bw.Flush(); err != nil {
   640  		return nil, c.fatal(err)
   641  	}
   642  
   643  	var deadline time.Time
   644  	if readTimeout != 0 {
   645  		deadline = time.Now().Add(readTimeout)
   646  	}
   647  	c.conn.SetReadDeadline(deadline)
   648  
   649  	if cmd == "" {
   650  		reply := make([]interface{}, pending)
   651  		for i := range reply {
   652  			r, e := c.readReply()
   653  			if e != nil {
   654  				return nil, c.fatal(e)
   655  			}
   656  			reply[i] = r
   657  		}
   658  		return reply, nil
   659  	}
   660  
   661  	var err error
   662  	var reply interface{}
   663  	for i := 0; i <= pending; i++ {
   664  		var e error
   665  		if reply, e = c.readReply(); e != nil {
   666  			return nil, c.fatal(e)
   667  		}
   668  		if e, ok := reply.(Error); ok && err == nil {
   669  			err = e
   670  		}
   671  	}
   672  	return reply, err
   673  }
   674  

View as plain text