1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package redis
16
17 import (
18 "bufio"
19 "bytes"
20 "crypto/tls"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "net/url"
26 "regexp"
27 "strconv"
28 "sync"
29 "time"
30 )
31
32 var (
33 _ ConnWithTimeout = (*conn)(nil)
34 )
35
36
// conn is the package's connection implementation built on top of a net.Conn
// with buffered reading and writing.
type conn struct {
	// mu guards pending and err.
	mu sync.Mutex
	// pending counts commands queued with Send whose replies have not yet
	// been consumed by Receive or Do.
	pending int
	// err holds the first fatal error, or the "closed" error once Close has
	// been called; it is nil while the connection is usable.
	err error
	// conn is the underlying network connection.
	conn net.Conn

	// readTimeout is the default deadline span applied before reading a
	// reply; zero means no deadline.
	readTimeout time.Duration
	// br buffers reads from conn.
	br *bufio.Reader

	// writeTimeout is the deadline span applied before writing; zero means
	// no deadline is set.
	writeTimeout time.Duration
	// bw buffers writes to conn.
	bw *bufio.Writer

	// lenScratch is scratch space used by writeLen to format
	// "<prefix><decimal length>\r\n" without allocating.
	lenScratch [32]byte

	// numScratch is scratch space used by writeInt64 and writeFloat64 as the
	// destination for strconv.Append*.
	numScratch [40]byte
}
59
60
61
62
63
64 func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
65 return Dial(network, address,
66 DialConnectTimeout(connectTimeout),
67 DialReadTimeout(readTimeout),
68 DialWriteTimeout(writeTimeout))
69 }
70
71
72 type DialOption struct {
73 f func(*dialOptions)
74 }
75
// dialOptions collects the settings that DialOption values apply before Dial
// establishes the connection.
type dialOptions struct {
	// readTimeout and writeTimeout are copied onto the resulting conn.
	readTimeout  time.Duration
	writeTimeout time.Duration
	// dialer is used to connect unless dial is set.
	dialer *net.Dialer
	// dial, when non-nil, replaces dialer.Dial.
	dial func(network, addr string) (net.Conn, error)
	// db is selected with SELECT after connecting when non-zero.
	db int
	// password is sent with AUTH after connecting when non-empty.
	password string
	// useTLS requests a TLS handshake on the freshly dialed connection.
	useTLS bool
	// skipVerify sets InsecureSkipVerify when no tlsConfig is supplied.
	skipVerify bool
	// tlsConfig, when non-nil, is cloned and used for the TLS handshake.
	tlsConfig *tls.Config
}
87
88
89 func DialReadTimeout(d time.Duration) DialOption {
90 return DialOption{func(do *dialOptions) {
91 do.readTimeout = d
92 }}
93 }
94
95
96 func DialWriteTimeout(d time.Duration) DialOption {
97 return DialOption{func(do *dialOptions) {
98 do.writeTimeout = d
99 }}
100 }
101
102
103
104 func DialConnectTimeout(d time.Duration) DialOption {
105 return DialOption{func(do *dialOptions) {
106 do.dialer.Timeout = d
107 }}
108 }
109
110
111
112
113
114 func DialKeepAlive(d time.Duration) DialOption {
115 return DialOption{func(do *dialOptions) {
116 do.dialer.KeepAlive = d
117 }}
118 }
119
120
121
122
123 func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
124 return DialOption{func(do *dialOptions) {
125 do.dial = dial
126 }}
127 }
128
129
130 func DialDatabase(db int) DialOption {
131 return DialOption{func(do *dialOptions) {
132 do.db = db
133 }}
134 }
135
136
137
138 func DialPassword(password string) DialOption {
139 return DialOption{func(do *dialOptions) {
140 do.password = password
141 }}
142 }
143
144
145
146 func DialTLSConfig(c *tls.Config) DialOption {
147 return DialOption{func(do *dialOptions) {
148 do.tlsConfig = c
149 }}
150 }
151
152
153
154 func DialTLSSkipVerify(skip bool) DialOption {
155 return DialOption{func(do *dialOptions) {
156 do.skipVerify = skip
157 }}
158 }
159
160
161
162 func DialUseTLS(useTLS bool) DialOption {
163 return DialOption{func(do *dialOptions) {
164 do.useTLS = useTLS
165 }}
166 }
167
168
169
170 func Dial(network, address string, options ...DialOption) (Conn, error) {
171 do := dialOptions{
172 dialer: &net.Dialer{
173 KeepAlive: time.Minute * 5,
174 },
175 }
176 for _, option := range options {
177 option.f(&do)
178 }
179 if do.dial == nil {
180 do.dial = do.dialer.Dial
181 }
182
183 netConn, err := do.dial(network, address)
184 if err != nil {
185 return nil, err
186 }
187
188 if do.useTLS {
189 var tlsConfig *tls.Config
190 if do.tlsConfig == nil {
191 tlsConfig = &tls.Config{InsecureSkipVerify: do.skipVerify}
192 } else {
193 tlsConfig = cloneTLSConfig(do.tlsConfig)
194 }
195 if tlsConfig.ServerName == "" {
196 host, _, err := net.SplitHostPort(address)
197 if err != nil {
198 netConn.Close()
199 return nil, err
200 }
201 tlsConfig.ServerName = host
202 }
203
204 tlsConn := tls.Client(netConn, tlsConfig)
205 if err := tlsConn.Handshake(); err != nil {
206 netConn.Close()
207 return nil, err
208 }
209 netConn = tlsConn
210 }
211
212 c := &conn{
213 conn: netConn,
214 bw: bufio.NewWriter(netConn),
215 br: bufio.NewReader(netConn),
216 readTimeout: do.readTimeout,
217 writeTimeout: do.writeTimeout,
218 }
219
220 if do.password != "" {
221 if _, err := c.Do("AUTH", do.password); err != nil {
222 netConn.Close()
223 return nil, err
224 }
225 }
226
227 if do.db != 0 {
228 if _, err := c.Do("SELECT", do.db); err != nil {
229 netConn.Close()
230 return nil, err
231 }
232 }
233
234 return c, nil
235 }
236
// pathDBRegexp matches a URL path that ends in "/" followed by zero or more
// decimal digits; the digits (the database number) are captured in group 1.
var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
238
239
240
241
242 func DialURL(rawurl string, options ...DialOption) (Conn, error) {
243 u, err := url.Parse(rawurl)
244 if err != nil {
245 return nil, err
246 }
247
248 if u.Scheme != "redis" && u.Scheme != "rediss" {
249 return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
250 }
251
252
253
254 host, port, err := net.SplitHostPort(u.Host)
255 if err != nil {
256
257 host = u.Host
258 port = "6379"
259 }
260 if host == "" {
261 host = "localhost"
262 }
263 address := net.JoinHostPort(host, port)
264
265 if u.User != nil {
266 password, isSet := u.User.Password()
267 if isSet {
268 options = append(options, DialPassword(password))
269 }
270 }
271
272 match := pathDBRegexp.FindStringSubmatch(u.Path)
273 if len(match) == 2 {
274 db := 0
275 if len(match[1]) > 0 {
276 db, err = strconv.Atoi(match[1])
277 if err != nil {
278 return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
279 }
280 }
281 if db != 0 {
282 options = append(options, DialDatabase(db))
283 }
284 } else if u.Path != "" {
285 return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
286 }
287
288 options = append(options, DialUseTLS(u.Scheme == "rediss"))
289
290 return Dial("tcp", address, options...)
291 }
292
293
294 func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
295 return &conn{
296 conn: netConn,
297 bw: bufio.NewWriter(netConn),
298 br: bufio.NewReader(netConn),
299 readTimeout: readTimeout,
300 writeTimeout: writeTimeout,
301 }
302 }
303
304 func (c *conn) Close() error {
305 c.mu.Lock()
306 err := c.err
307 if c.err == nil {
308 c.err = errors.New("redigo: closed")
309 err = c.conn.Close()
310 }
311 c.mu.Unlock()
312 return err
313 }
314
315 func (c *conn) fatal(err error) error {
316 c.mu.Lock()
317 if c.err == nil {
318 c.err = err
319
320
321 c.conn.Close()
322 }
323 c.mu.Unlock()
324 return err
325 }
326
327 func (c *conn) Err() error {
328 c.mu.Lock()
329 err := c.err
330 c.mu.Unlock()
331 return err
332 }
333
334 func (c *conn) writeLen(prefix byte, n int) error {
335 c.lenScratch[len(c.lenScratch)-1] = '\n'
336 c.lenScratch[len(c.lenScratch)-2] = '\r'
337 i := len(c.lenScratch) - 3
338 for {
339 c.lenScratch[i] = byte('0' + n%10)
340 i -= 1
341 n = n / 10
342 if n == 0 {
343 break
344 }
345 }
346 c.lenScratch[i] = prefix
347 _, err := c.bw.Write(c.lenScratch[i:])
348 return err
349 }
350
351 func (c *conn) writeString(s string) error {
352 c.writeLen('$', len(s))
353 c.bw.WriteString(s)
354 _, err := c.bw.WriteString("\r\n")
355 return err
356 }
357
358 func (c *conn) writeBytes(p []byte) error {
359 c.writeLen('$', len(p))
360 c.bw.Write(p)
361 _, err := c.bw.WriteString("\r\n")
362 return err
363 }
364
365 func (c *conn) writeInt64(n int64) error {
366 return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
367 }
368
369 func (c *conn) writeFloat64(n float64) error {
370 return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
371 }
372
373 func (c *conn) writeCommand(cmd string, args []interface{}) error {
374 c.writeLen('*', 1+len(args))
375 if err := c.writeString(cmd); err != nil {
376 return err
377 }
378 for _, arg := range args {
379 if err := c.writeArg(arg, true); err != nil {
380 return err
381 }
382 }
383 return nil
384 }
385
386 func (c *conn) writeArg(arg interface{}, argumentTypeOK bool) (err error) {
387 switch arg := arg.(type) {
388 case string:
389 return c.writeString(arg)
390 case []byte:
391 return c.writeBytes(arg)
392 case int:
393 return c.writeInt64(int64(arg))
394 case int64:
395 return c.writeInt64(arg)
396 case float64:
397 return c.writeFloat64(arg)
398 case bool:
399 if arg {
400 return c.writeString("1")
401 } else {
402 return c.writeString("0")
403 }
404 case nil:
405 return c.writeString("")
406 case Argument:
407 if argumentTypeOK {
408 return c.writeArg(arg.RedisArg(), false)
409 }
410
411 var buf bytes.Buffer
412 fmt.Fprint(&buf, arg)
413 return c.writeBytes(buf.Bytes())
414 default:
415
416
417
418 var buf bytes.Buffer
419 fmt.Fprint(&buf, arg)
420 return c.writeBytes(buf.Bytes())
421 }
422 }
423
// protocolError is the error type reported when the data read from the
// server does not follow the expected wire format.
type protocolError string

// Error implements the error interface, wrapping the message with a hint
// about the likely causes.
func (e protocolError) Error() string {
	msg := string(e)
	return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", msg)
}
429
430 func (c *conn) readLine() ([]byte, error) {
431 p, err := c.br.ReadSlice('\n')
432 if err == bufio.ErrBufferFull {
433 return nil, protocolError("long response line")
434 }
435 if err != nil {
436 return nil, err
437 }
438 i := len(p) - 2
439 if i < 0 || p[i] != '\r' {
440 return nil, protocolError("bad response line terminator")
441 }
442 return p[:i], nil
443 }
444
445
446 func parseLen(p []byte) (int, error) {
447 if len(p) == 0 {
448 return -1, protocolError("malformed length")
449 }
450
451 if p[0] == '-' && len(p) == 2 && p[1] == '1' {
452
453 return -1, nil
454 }
455
456 var n int
457 for _, b := range p {
458 n *= 10
459 if b < '0' || b > '9' {
460 return -1, protocolError("illegal bytes in length")
461 }
462 n += int(b - '0')
463 }
464
465 return n, nil
466 }
467
468
469 func parseInt(p []byte) (interface{}, error) {
470 if len(p) == 0 {
471 return 0, protocolError("malformed integer")
472 }
473
474 var negate bool
475 if p[0] == '-' {
476 negate = true
477 p = p[1:]
478 if len(p) == 0 {
479 return 0, protocolError("malformed integer")
480 }
481 }
482
483 var n int64
484 for _, b := range p {
485 n *= 10
486 if b < '0' || b > '9' {
487 return 0, protocolError("illegal bytes in length")
488 }
489 n += int64(b - '0')
490 }
491
492 if negate {
493 n = -n
494 }
495 return n, nil
496 }
497
// Pre-boxed values for the two most common simple-string replies, so that
// readReply can return them without converting the line to a new string.
var (
	// okReply is returned for a "+OK" reply.
	okReply interface{} = "OK"
	// pongReply is returned for a "+PONG" reply.
	pongReply interface{} = "PONG"
)
502
503 func (c *conn) readReply() (interface{}, error) {
504 line, err := c.readLine()
505 if err != nil {
506 return nil, err
507 }
508 if len(line) == 0 {
509 return nil, protocolError("short response line")
510 }
511 switch line[0] {
512 case '+':
513 switch {
514 case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
515
516 return okReply, nil
517 case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
518
519 return pongReply, nil
520 default:
521 return string(line[1:]), nil
522 }
523 case '-':
524 return Error(string(line[1:])), nil
525 case ':':
526 return parseInt(line[1:])
527 case '$':
528 n, err := parseLen(line[1:])
529 if n < 0 || err != nil {
530 return nil, err
531 }
532 p := make([]byte, n)
533 _, err = io.ReadFull(c.br, p)
534 if err != nil {
535 return nil, err
536 }
537 if line, err := c.readLine(); err != nil {
538 return nil, err
539 } else if len(line) != 0 {
540 return nil, protocolError("bad bulk string format")
541 }
542 return p, nil
543 case '*':
544 n, err := parseLen(line[1:])
545 if n < 0 || err != nil {
546 return nil, err
547 }
548 r := make([]interface{}, n)
549 for i := range r {
550 r[i], err = c.readReply()
551 if err != nil {
552 return nil, err
553 }
554 }
555 return r, nil
556 }
557 return nil, protocolError("unexpected response line")
558 }
559
560 func (c *conn) Send(cmd string, args ...interface{}) error {
561 c.mu.Lock()
562 c.pending += 1
563 c.mu.Unlock()
564 if c.writeTimeout != 0 {
565 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
566 }
567 if err := c.writeCommand(cmd, args); err != nil {
568 return c.fatal(err)
569 }
570 return nil
571 }
572
573 func (c *conn) Flush() error {
574 if c.writeTimeout != 0 {
575 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
576 }
577 if err := c.bw.Flush(); err != nil {
578 return c.fatal(err)
579 }
580 return nil
581 }
582
583 func (c *conn) Receive() (interface{}, error) {
584 return c.ReceiveWithTimeout(c.readTimeout)
585 }
586
587 func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
588 var deadline time.Time
589 if timeout != 0 {
590 deadline = time.Now().Add(timeout)
591 }
592 c.conn.SetReadDeadline(deadline)
593
594 if reply, err = c.readReply(); err != nil {
595 return nil, c.fatal(err)
596 }
597
598
599
600
601
602
603
604 c.mu.Lock()
605 if c.pending > 0 {
606 c.pending -= 1
607 }
608 c.mu.Unlock()
609 if err, ok := reply.(Error); ok {
610 return nil, err
611 }
612 return
613 }
614
615 func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
616 return c.DoWithTimeout(c.readTimeout, cmd, args...)
617 }
618
619 func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
620 c.mu.Lock()
621 pending := c.pending
622 c.pending = 0
623 c.mu.Unlock()
624
625 if cmd == "" && pending == 0 {
626 return nil, nil
627 }
628
629 if c.writeTimeout != 0 {
630 c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
631 }
632
633 if cmd != "" {
634 if err := c.writeCommand(cmd, args); err != nil {
635 return nil, c.fatal(err)
636 }
637 }
638
639 if err := c.bw.Flush(); err != nil {
640 return nil, c.fatal(err)
641 }
642
643 var deadline time.Time
644 if readTimeout != 0 {
645 deadline = time.Now().Add(readTimeout)
646 }
647 c.conn.SetReadDeadline(deadline)
648
649 if cmd == "" {
650 reply := make([]interface{}, pending)
651 for i := range reply {
652 r, e := c.readReply()
653 if e != nil {
654 return nil, c.fatal(e)
655 }
656 reply[i] = r
657 }
658 return reply, nil
659 }
660
661 var err error
662 var reply interface{}
663 for i := 0; i <= pending; i++ {
664 var e error
665 if reply, e = c.readReply(); e != nil {
666 return nil, c.fatal(e)
667 }
668 if e, ok := reply.(Error); ok && err == nil {
669 err = e
670 }
671 }
672 return reply, err
673 }
674
View as plain text