1 package redis
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "math/rand"
8 "strconv"
9 "sync"
10 "sync/atomic"
11 "time"
12
13 "github.com/go-redis/redis/internal"
14 "github.com/go-redis/redis/internal/consistenthash"
15 "github.com/go-redis/redis/internal/hashtag"
16 "github.com/go-redis/redis/internal/pool"
17 )
18
19
// Hash is the hash function used to place shard names and keys on the
// consistent-hash ring (see RingOptions.Hash). consistenthash lives in an
// internal package, so this named type is what code outside the module uses
// to supply one.
type Hash consistenthash.Hash

// errRingShardsDown is returned when the consistent-hash ring has no live
// shard to serve a key, or when no shard could answer a request.
var errRingShardsDown = errors.New("redis: all ring shards are down")
23
24
25
// RingOptions configures a Ring client and is passed to NewRing, which fills
// in defaults for zero-valued fields.
type RingOptions struct {
	// Addrs maps shard name to server address. The name is what is placed on
	// the consistent-hash ring; the address is used to create that shard's
	// Client.
	Addrs map[string]string

	// HeartbeatFrequency is how often every shard is pinged to update its
	// up/down state. A shard is considered down after 3 consecutive failed
	// pings. Defaults to 500ms when zero.
	HeartbeatFrequency time.Duration

	// Hash is the function the consistent-hash ring uses to place shard
	// names and keys. It is passed straight to consistenthash.New.
	// NOTE(review): presumably a nil value selects that package's default
	// hash function — confirm in internal/consistenthash.
	Hash Hash

	// HashReplicas is the number of points each shard name gets on the
	// consistent-hash ring. Defaults to 100 when zero.
	HashReplicas int

	// OnConnect, DB, Password, the Dial/Read/Write timeouts and the pool
	// settings below are copied into every shard Client's Options.

	OnConnect func(*Conn) error

	DB       int
	Password string

	// MaxRetries is the number of extra attempts the Ring makes for a
	// command or pipeline group after a retryable error (MaxRetries+1
	// attempts in total). It is not copied into the shard Clients' options.
	// MinRetryBackoff/MaxRetryBackoff bound the sleep between attempts:
	// 0 selects the default (8ms / 512ms) and -1 is normalized to 0.
	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	DialTimeout  time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration

	PoolSize           int
	MinIdleConns       int
	MaxConnAge         time.Duration
	PoolTimeout        time.Duration
	IdleTimeout        time.Duration
	IdleCheckFrequency time.Duration
}
77
78 func (opt *RingOptions) init() {
79 if opt.HeartbeatFrequency == 0 {
80 opt.HeartbeatFrequency = 500 * time.Millisecond
81 }
82
83 if opt.HashReplicas == 0 {
84 opt.HashReplicas = 100
85 }
86
87 switch opt.MinRetryBackoff {
88 case -1:
89 opt.MinRetryBackoff = 0
90 case 0:
91 opt.MinRetryBackoff = 8 * time.Millisecond
92 }
93 switch opt.MaxRetryBackoff {
94 case -1:
95 opt.MaxRetryBackoff = 0
96 case 0:
97 opt.MaxRetryBackoff = 512 * time.Millisecond
98 }
99 }
100
101 func (opt *RingOptions) clientOptions() *Options {
102 return &Options{
103 OnConnect: opt.OnConnect,
104
105 DB: opt.DB,
106 Password: opt.Password,
107
108 DialTimeout: opt.DialTimeout,
109 ReadTimeout: opt.ReadTimeout,
110 WriteTimeout: opt.WriteTimeout,
111
112 PoolSize: opt.PoolSize,
113 MinIdleConns: opt.MinIdleConns,
114 MaxConnAge: opt.MaxConnAge,
115 PoolTimeout: opt.PoolTimeout,
116 IdleTimeout: opt.IdleTimeout,
117 IdleCheckFrequency: opt.IdleCheckFrequency,
118 }
119 }
120
121
122
// ringShard is one member of the ring: the Client used to talk to it plus a
// counter of consecutive failed health-check votes (see Vote). The shard is
// treated as down once the counter reaches the threshold in IsDown.
type ringShard struct {
	Client *Client
	down   int32 // failed-vote counter; accessed only via sync/atomic
}
127
128 func (shard *ringShard) String() string {
129 var state string
130 if shard.IsUp() {
131 state = "up"
132 } else {
133 state = "down"
134 }
135 return fmt.Sprintf("%s is %s", shard.Client, state)
136 }
137
138 func (shard *ringShard) IsDown() bool {
139 const threshold = 3
140 return atomic.LoadInt32(&shard.down) >= threshold
141 }
142
143 func (shard *ringShard) IsUp() bool {
144 return !shard.IsDown()
145 }
146
147
148 func (shard *ringShard) Vote(up bool) bool {
149 if up {
150 changed := shard.IsDown()
151 atomic.StoreInt32(&shard.down, 0)
152 return changed
153 }
154
155 if shard.IsDown() {
156 return false
157 }
158
159 atomic.AddInt32(&shard.down, 1)
160 return shard.IsDown()
161 }
162
163
164
// ringShards holds the set of shards of a Ring together with the
// consistent-hash ring used to route keys to them.
type ringShards struct {
	opt *RingOptions

	// mu protects the fields below.
	mu     sync.RWMutex
	hash   *consistenthash.Map   // names of live shards; rebuilt by rebalance
	shards map[string]*ringShard // every shard by name, live or not
	list   []*ringShard          // every shard, in insertion order
	len    int                   // live-shard count reported by Len
	closed bool                  // set once by Close
}
175
176 func newRingShards(opt *RingOptions) *ringShards {
177 return &ringShards{
178 opt: opt,
179
180 hash: newConsistentHash(opt),
181 shards: make(map[string]*ringShard),
182 }
183 }
184
185 func (c *ringShards) Add(name string, cl *Client) {
186 shard := &ringShard{Client: cl}
187 c.hash.Add(name)
188 c.shards[name] = shard
189 c.list = append(c.list, shard)
190 }
191
192 func (c *ringShards) List() []*ringShard {
193 c.mu.RLock()
194 list := c.list
195 c.mu.RUnlock()
196 return list
197 }
198
199 func (c *ringShards) Hash(key string) string {
200 c.mu.RLock()
201 hash := c.hash.Get(key)
202 c.mu.RUnlock()
203 return hash
204 }
205
206 func (c *ringShards) GetByKey(key string) (*ringShard, error) {
207 key = hashtag.Key(key)
208
209 c.mu.RLock()
210
211 if c.closed {
212 c.mu.RUnlock()
213 return nil, pool.ErrClosed
214 }
215
216 hash := c.hash.Get(key)
217 if hash == "" {
218 c.mu.RUnlock()
219 return nil, errRingShardsDown
220 }
221
222 shard := c.shards[hash]
223 c.mu.RUnlock()
224
225 return shard, nil
226 }
227
228 func (c *ringShards) GetByHash(name string) (*ringShard, error) {
229 if name == "" {
230 return c.Random()
231 }
232
233 c.mu.RLock()
234 shard := c.shards[name]
235 c.mu.RUnlock()
236 return shard, nil
237 }
238
239 func (c *ringShards) Random() (*ringShard, error) {
240 return c.GetByKey(strconv.Itoa(rand.Int()))
241 }
242
243
244 func (c *ringShards) Heartbeat(frequency time.Duration) {
245 ticker := time.NewTicker(frequency)
246 defer ticker.Stop()
247 for range ticker.C {
248 var rebalance bool
249
250 c.mu.RLock()
251
252 if c.closed {
253 c.mu.RUnlock()
254 break
255 }
256
257 shards := c.list
258 c.mu.RUnlock()
259
260 for _, shard := range shards {
261 err := shard.Client.Ping().Err()
262 if shard.Vote(err == nil || err == pool.ErrPoolTimeout) {
263 internal.Logf("ring shard state changed: %s", shard)
264 rebalance = true
265 }
266 }
267
268 if rebalance {
269 c.rebalance()
270 }
271 }
272 }
273
274
275 func (c *ringShards) rebalance() {
276 c.mu.RLock()
277 shards := c.shards
278 c.mu.RUnlock()
279
280 hash := newConsistentHash(c.opt)
281 var shardsNum int
282 for name, shard := range shards {
283 if shard.IsUp() {
284 hash.Add(name)
285 shardsNum++
286 }
287 }
288
289 c.mu.Lock()
290 c.hash = hash
291 c.len = shardsNum
292 c.mu.Unlock()
293 }
294
295 func (c *ringShards) Len() int {
296 c.mu.RLock()
297 l := c.len
298 c.mu.RUnlock()
299 return l
300 }
301
302 func (c *ringShards) Close() error {
303 c.mu.Lock()
304 defer c.mu.Unlock()
305
306 if c.closed {
307 return nil
308 }
309 c.closed = true
310
311 var firstErr error
312 for _, shard := range c.shards {
313 if err := shard.Client.Close(); err != nil && firstErr == nil {
314 firstErr = err
315 }
316 }
317 c.hash = nil
318 c.shards = nil
319 c.list = nil
320
321 return firstErr
322 }
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
// Ring is a Redis client that distributes commands across several Redis
// servers (shards). It uses consistent hashing on each command's first key,
// honouring hash tags. A background heartbeat pings the shards and drops
// unhealthy ones from the hash ring until they respond again.
//
// Cross-shard transactions are not supported: TxPipeline and TxPipelined
// panic, and Watch requires all keys to map to the same shard.
type Ring struct {
	cmdable

	// ctx is the context attached by WithContext; Context falls back to
	// context.Background when it is nil.
	ctx context.Context

	opt           *RingOptions
	shards        *ringShards
	cmdsInfoCache *cmdsInfoCache // COMMAND info used to locate each command's first key

	// process and processPipeline execute single commands and pipelines.
	// They start as defaultProcess / defaultProcessPipeline and can be
	// wrapped via WrapProcess / WrapProcessPipeline.
	process         func(Cmder) error
	processPipeline func([]Cmder) error
}
352
353 func NewRing(opt *RingOptions) *Ring {
354 opt.init()
355
356 ring := &Ring{
357 opt: opt,
358 shards: newRingShards(opt),
359 }
360 ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
361
362 ring.process = ring.defaultProcess
363 ring.processPipeline = ring.defaultProcessPipeline
364
365 ring.init()
366
367 for name, addr := range opt.Addrs {
368 clopt := opt.clientOptions()
369 clopt.Addr = addr
370 ring.shards.Add(name, NewClient(clopt))
371 }
372
373 go ring.shards.Heartbeat(opt.HeartbeatFrequency)
374
375 return ring
376 }
377
378 func (c *Ring) init() {
379 c.cmdable.setProcessor(c.Process)
380 }
381
382 func (c *Ring) Context() context.Context {
383 if c.ctx != nil {
384 return c.ctx
385 }
386 return context.Background()
387 }
388
389 func (c *Ring) WithContext(ctx context.Context) *Ring {
390 if ctx == nil {
391 panic("nil context")
392 }
393 c2 := c.clone()
394 c2.ctx = ctx
395 return c2
396 }
397
398 func (c *Ring) clone() *Ring {
399 cp := *c
400 cp.init()
401
402 return &cp
403 }
404
405
406 func (c *Ring) Options() *RingOptions {
407 return c.opt
408 }
409
410 func (c *Ring) retryBackoff(attempt int) time.Duration {
411 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
412 }
413
414
415 func (c *Ring) PoolStats() *PoolStats {
416 shards := c.shards.List()
417 var acc PoolStats
418 for _, shard := range shards {
419 s := shard.Client.connPool.Stats()
420 acc.Hits += s.Hits
421 acc.Misses += s.Misses
422 acc.Timeouts += s.Timeouts
423 acc.TotalConns += s.TotalConns
424 acc.IdleConns += s.IdleConns
425 }
426 return &acc
427 }
428
429
430 func (c *Ring) Len() int {
431 return c.shards.Len()
432 }
433
434
435 func (c *Ring) Subscribe(channels ...string) *PubSub {
436 if len(channels) == 0 {
437 panic("at least one channel is required")
438 }
439
440 shard, err := c.shards.GetByKey(channels[0])
441 if err != nil {
442
443 panic(err)
444 }
445 return shard.Client.Subscribe(channels...)
446 }
447
448
449 func (c *Ring) PSubscribe(channels ...string) *PubSub {
450 if len(channels) == 0 {
451 panic("at least one channel is required")
452 }
453
454 shard, err := c.shards.GetByKey(channels[0])
455 if err != nil {
456
457 panic(err)
458 }
459 return shard.Client.PSubscribe(channels...)
460 }
461
462
463
464 func (c *Ring) ForEachShard(fn func(client *Client) error) error {
465 shards := c.shards.List()
466 var wg sync.WaitGroup
467 errCh := make(chan error, 1)
468 for _, shard := range shards {
469 if shard.IsDown() {
470 continue
471 }
472
473 wg.Add(1)
474 go func(shard *ringShard) {
475 defer wg.Done()
476 err := fn(shard.Client)
477 if err != nil {
478 select {
479 case errCh <- err:
480 default:
481 }
482 }
483 }(shard)
484 }
485 wg.Wait()
486
487 select {
488 case err := <-errCh:
489 return err
490 default:
491 return nil
492 }
493 }
494
495 func (c *Ring) cmdsInfo() (map[string]*CommandInfo, error) {
496 shards := c.shards.List()
497 firstErr := errRingShardsDown
498 for _, shard := range shards {
499 cmdsInfo, err := shard.Client.Command().Result()
500 if err == nil {
501 return cmdsInfo, nil
502 }
503 if firstErr == nil {
504 firstErr = err
505 }
506 }
507 return nil, firstErr
508 }
509
510 func (c *Ring) cmdInfo(name string) *CommandInfo {
511 cmdsInfo, err := c.cmdsInfoCache.Get()
512 if err != nil {
513 return nil
514 }
515 info := cmdsInfo[name]
516 if info == nil {
517 internal.Logf("info for cmd=%s not found", name)
518 }
519 return info
520 }
521
522 func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
523 cmdInfo := c.cmdInfo(cmd.Name())
524 pos := cmdFirstKeyPos(cmd, cmdInfo)
525 if pos == 0 {
526 return c.shards.Random()
527 }
528 firstKey := cmd.stringArg(pos)
529 return c.shards.GetByKey(firstKey)
530 }
531
532
533 func (c *Ring) Do(args ...interface{}) *Cmd {
534 cmd := NewCmd(args...)
535 c.Process(cmd)
536 return cmd
537 }
538
539 func (c *Ring) WrapProcess(
540 fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error,
541 ) {
542 c.process = fn(c.process)
543 }
544
545 func (c *Ring) Process(cmd Cmder) error {
546 return c.process(cmd)
547 }
548
549 func (c *Ring) defaultProcess(cmd Cmder) error {
550 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
551 if attempt > 0 {
552 time.Sleep(c.retryBackoff(attempt))
553 }
554
555 shard, err := c.cmdShard(cmd)
556 if err != nil {
557 cmd.setErr(err)
558 return err
559 }
560
561 err = shard.Client.Process(cmd)
562 if err == nil {
563 return nil
564 }
565 if !internal.IsRetryableError(err, cmd.readTimeout() == nil) {
566 return err
567 }
568 }
569 return cmd.Err()
570 }
571
572 func (c *Ring) Pipeline() Pipeliner {
573 pipe := Pipeline{
574 exec: c.processPipeline,
575 }
576 pipe.cmdable.setProcessor(pipe.Process)
577 return &pipe
578 }
579
580 func (c *Ring) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
581 return c.Pipeline().Pipelined(fn)
582 }
583
584 func (c *Ring) WrapProcessPipeline(
585 fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
586 ) {
587 c.processPipeline = fn(c.processPipeline)
588 }
589
// defaultProcessPipeline is the default executor for Ring pipelines.
//
// Commands are grouped by the shard name their first key hashes to, and each
// group is sent to its shard on its own goroutine. Groups that fail with a
// retryable error are re-sent after a backoff, up to opt.MaxRetries more
// times. The result is cmdsFirstErr over the whole pipeline — presumably the
// first error recorded on any command.
func (c *Ring) defaultProcessPipeline(cmds []Cmder) error {
	// Group commands by shard name. An empty name is resolved to a random
	// shard by GetByHash.
	// NOTE(review): when cmdFirstKeyPos returns 0, stringArg(0) is presumably
	// the command name, so keyless commands are routed by hashing their name
	// rather than randomly as in cmdShard — confirm this is intended.
	cmdsMap := make(map[string][]Cmder)
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(cmd.Name())
		hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo))
		if hash != "" {
			hash = c.shards.Hash(hashtag.Key(hash))
		}
		cmdsMap[hash] = append(cmdsMap[hash], cmd)
	}

	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			time.Sleep(c.retryBackoff(attempt))
		}

		// failedCmdsMap collects the groups to retry. mu guards it because
		// every group runs on its own goroutine.
		var mu sync.Mutex
		var failedCmdsMap map[string][]Cmder
		var wg sync.WaitGroup

		for hash, cmds := range cmdsMap {
			wg.Add(1)
			go func(hash string, cmds []Cmder) {
				defer wg.Done()

				// Lookup and connection errors are not retried; setCmdsErr
				// presumably records err on each command of the group.
				shard, err := c.shards.GetByHash(hash)
				if err != nil {
					setCmdsErr(cmds, err)
					return
				}

				cn, err := shard.Client.getConn()
				if err != nil {
					setCmdsErr(cmds, err)
					return
				}

				canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds)
				shard.Client.releaseConnStrict(cn, err)

				if canRetry && internal.IsRetryableError(err, true) {
					mu.Lock()
					if failedCmdsMap == nil {
						failedCmdsMap = make(map[string][]Cmder)
					}
					failedCmdsMap[hash] = cmds
					mu.Unlock()
				}
			}(hash, cmds)
		}

		wg.Wait()
		if len(failedCmdsMap) == 0 {
			break
		}
		// Only the failed groups are re-sent, to the same shard names.
		cmdsMap = failedCmdsMap
	}

	// cmds here is the function parameter (the range variable above is scoped
	// to its loop), i.e. every command in the pipeline.
	return cmdsFirstErr(cmds)
}
650
651 func (c *Ring) TxPipeline() Pipeliner {
652 panic("not implemented")
653 }
654
655 func (c *Ring) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
656 panic("not implemented")
657 }
658
659
660
661
662
663 func (c *Ring) Close() error {
664 return c.shards.Close()
665 }
666
667 func (c *Ring) Watch(fn func(*Tx) error, keys ...string) error {
668 if len(keys) == 0 {
669 return fmt.Errorf("redis: Watch requires at least one key")
670 }
671
672 var shards []*ringShard
673 for _, key := range keys {
674 if key != "" {
675 shard, err := c.shards.GetByKey(hashtag.Key(key))
676 if err != nil {
677 return err
678 }
679
680 shards = append(shards, shard)
681 }
682 }
683
684 if len(shards) == 0 {
685 return fmt.Errorf("redis: Watch requires at least one shard")
686 }
687
688 if len(shards) > 1 {
689 for _, shard := range shards[1:] {
690 if shard.Client != shards[0].Client {
691 err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
692 return err
693 }
694 }
695 }
696
697 return shards[0].Client.Watch(fn, keys...)
698 }
699
700 func newConsistentHash(opt *RingOptions) *consistenthash.Map {
701 return consistenthash.New(opt.HashReplicas, consistenthash.Hash(opt.Hash))
702 }
703
View as plain text