1 package redis
2
3 import (
4 "errors"
5 "fmt"
6 "strings"
7 "sync"
8 "time"
9
10 "github.com/go-redis/redis/internal"
11 "github.com/go-redis/redis/internal/pool"
12 "github.com/go-redis/redis/internal/proto"
13 )
14
var (
	// errPingTimeout is the reconnect reason used by the Channel health
	// check when two consecutive timeout windows pass without any reply and
	// the PING it sent did not itself report an error.
	errPingTimeout = errors.New("redis: ping timeout")
)
16
17
18
19
20
21
22
// PubSub implements Redis Pub/Sub on a dedicated connection.
//
// When an operation fails with an error that internal.IsBadConn reports as
// a bad connection, the connection is discarded and redialed. Every channel
// and pattern recorded in channels/patterns is then subscribed again on the
// new connection.
//
// Receiving is NOT safe for concurrent use by multiple goroutines: the
// Receive* methods decode into the single cmd field without locking.
type PubSub struct {
	// opt supplies WriteTimeout and the retry backoff bounds.
	opt *Options

	// newConn dials a new connection. It receives the already-subscribed
	// channels plus the ones about to be subscribed. How it uses them is up
	// to whoever builds the PubSub (TODO confirm at the construction site).
	// closeConn releases a connection obtained from newConn.
	newConn   func([]string) (*pool.Conn, error)
	closeConn func(*pool.Conn) error

	// mu serializes access to the connection state below.
	mu       sync.Mutex
	cn       *pool.Conn          // current connection; nil until dialed or after a bad one is discarded
	channels map[string]struct{} // subscribed channel names, re-sent after a reconnect
	patterns map[string]struct{} // subscribed patterns, re-sent after a reconnect

	closed bool          // set by Close; afterwards _conn returns pool.ErrClosed
	exit   chan struct{} // closed by Close to stop the Channel health-check goroutine

	// cmd is created lazily and reused by ReceiveTimeout to read replies.
	cmd *Cmd

	// Lazily initialised by the first Channel/ChannelSize call.
	chOnce sync.Once
	ch     chan *Message // messages delivered to Channel consumers
	ping   chan struct{} // receive goroutine -> health check: "a reply arrived"
}
43
44 func (c *PubSub) String() string {
45 channels := mapKeys(c.channels)
46 channels = append(channels, mapKeys(c.patterns)...)
47 return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
48 }
49
50 func (c *PubSub) init() {
51 c.exit = make(chan struct{})
52 }
53
54 func (c *PubSub) conn() (*pool.Conn, error) {
55 c.mu.Lock()
56 cn, err := c._conn(nil)
57 c.mu.Unlock()
58 return cn, err
59 }
60
61 func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
62 if c.closed {
63 return nil, pool.ErrClosed
64 }
65 if c.cn != nil {
66 return c.cn, nil
67 }
68
69 channels := mapKeys(c.channels)
70 channels = append(channels, newChannels...)
71
72 cn, err := c.newConn(channels)
73 if err != nil {
74 return nil, err
75 }
76
77 if err := c.resubscribe(cn); err != nil {
78 _ = c.closeConn(cn)
79 return nil, err
80 }
81
82 c.cn = cn
83 return cn, nil
84 }
85
86 func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
87 return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
88 return writeCmd(wr, cmd)
89 })
90 }
91
92 func (c *PubSub) resubscribe(cn *pool.Conn) error {
93 var firstErr error
94
95 if len(c.channels) > 0 {
96 err := c._subscribe(cn, "subscribe", mapKeys(c.channels))
97 if err != nil && firstErr == nil {
98 firstErr = err
99 }
100 }
101
102 if len(c.patterns) > 0 {
103 err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
104 if err != nil && firstErr == nil {
105 firstErr = err
106 }
107 }
108
109 return firstErr
110 }
111
// mapKeys returns the keys of m in unspecified (map iteration) order. The
// result is always a non-nil slice; it is empty when m is empty or nil.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
121
122 func (c *PubSub) _subscribe(
123 cn *pool.Conn, redisCmd string, channels []string,
124 ) error {
125 args := make([]interface{}, 0, 1+len(channels))
126 args = append(args, redisCmd)
127 for _, channel := range channels {
128 args = append(args, channel)
129 }
130 cmd := NewSliceCmd(args...)
131 return c.writeCmd(cn, cmd)
132 }
133
134 func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
135 c.mu.Lock()
136 c._releaseConn(cn, err, allowTimeout)
137 c.mu.Unlock()
138 }
139
140 func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
141 if c.cn != cn {
142 return
143 }
144 if internal.IsBadConn(err, allowTimeout) {
145 c._reconnect(err)
146 }
147 }
148
149 func (c *PubSub) _reconnect(reason error) {
150 _ = c._closeTheCn(reason)
151 _, _ = c._conn(nil)
152 }
153
154 func (c *PubSub) _closeTheCn(reason error) error {
155 if c.cn == nil {
156 return nil
157 }
158 if !c.closed {
159 internal.Logf("redis: discarding bad PubSub connection: %s", reason)
160 }
161 err := c.closeConn(c.cn)
162 c.cn = nil
163 return err
164 }
165
166 func (c *PubSub) Close() error {
167 c.mu.Lock()
168 defer c.mu.Unlock()
169
170 if c.closed {
171 return pool.ErrClosed
172 }
173 c.closed = true
174 close(c.exit)
175
176 err := c._closeTheCn(pool.ErrClosed)
177 return err
178 }
179
180
181
182 func (c *PubSub) Subscribe(channels ...string) error {
183 c.mu.Lock()
184 defer c.mu.Unlock()
185
186 err := c.subscribe("subscribe", channels...)
187 if c.channels == nil {
188 c.channels = make(map[string]struct{})
189 }
190 for _, s := range channels {
191 c.channels[s] = struct{}{}
192 }
193 return err
194 }
195
196
197
198 func (c *PubSub) PSubscribe(patterns ...string) error {
199 c.mu.Lock()
200 defer c.mu.Unlock()
201
202 err := c.subscribe("psubscribe", patterns...)
203 if c.patterns == nil {
204 c.patterns = make(map[string]struct{})
205 }
206 for _, s := range patterns {
207 c.patterns[s] = struct{}{}
208 }
209 return err
210 }
211
212
213
214 func (c *PubSub) Unsubscribe(channels ...string) error {
215 c.mu.Lock()
216 defer c.mu.Unlock()
217
218 for _, channel := range channels {
219 delete(c.channels, channel)
220 }
221 err := c.subscribe("unsubscribe", channels...)
222 return err
223 }
224
225
226
227 func (c *PubSub) PUnsubscribe(patterns ...string) error {
228 c.mu.Lock()
229 defer c.mu.Unlock()
230
231 for _, pattern := range patterns {
232 delete(c.patterns, pattern)
233 }
234 err := c.subscribe("punsubscribe", patterns...)
235 return err
236 }
237
238 func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
239 cn, err := c._conn(channels)
240 if err != nil {
241 return err
242 }
243
244 err = c._subscribe(cn, redisCmd, channels)
245 c._releaseConn(cn, err, false)
246 return err
247 }
248
249 func (c *PubSub) Ping(payload ...string) error {
250 args := []interface{}{"ping"}
251 if len(payload) == 1 {
252 args = append(args, payload[0])
253 }
254 cmd := NewCmd(args...)
255
256 cn, err := c.conn()
257 if err != nil {
258 return err
259 }
260
261 err = c.writeCmd(cn, cmd)
262 c.releaseConn(cn, err, false)
263 return err
264 }
265
266
// Subscription is a confirmation received after subscribing to or
// unsubscribing from a channel or pattern.
type Subscription struct {
	// Kind is one of "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the channel name or pattern the confirmation refers to.
	Channel string
	// Count is the subscription count the server sent with the confirmation.
	Count int
}

// String renders the subscription as "<kind>: <channel>".
func (m *Subscription) String() string {
	return m.Kind + ": " + m.Channel
}
279
280
// Message is a payload published to a channel the client is subscribed to.
// Pattern is set only when the message arrived through a pattern
// subscription (a "pmessage" reply).
type Message struct {
	Channel string
	Pattern string
	Payload string
}

// String renders the message as "Message<channel: payload>"; the pattern is
// not included.
func (m *Message) String() string {
	return "Message<" + m.Channel + ": " + m.Payload + ">"
}
290
291
// Pong is the reply to a PING sent on a PubSub connection.
type Pong struct {
	Payload string
}

// String renders "Pong" when there is no payload, otherwise "Pong<payload>".
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return "Pong<" + p.Payload + ">"
}
302
303 func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
304 switch reply := reply.(type) {
305 case string:
306 return &Pong{
307 Payload: reply,
308 }, nil
309 case []interface{}:
310 switch kind := reply[0].(string); kind {
311 case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
312 channel, _ := reply[1].(string)
313 return &Subscription{
314 Kind: kind,
315 Channel: channel,
316 Count: int(reply[2].(int64)),
317 }, nil
318 case "message":
319 return &Message{
320 Channel: reply[1].(string),
321 Payload: reply[2].(string),
322 }, nil
323 case "pmessage":
324 return &Message{
325 Pattern: reply[1].(string),
326 Channel: reply[2].(string),
327 Payload: reply[3].(string),
328 }, nil
329 case "pong":
330 return &Pong{
331 Payload: reply[1].(string),
332 }, nil
333 default:
334 return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
335 }
336 default:
337 return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
338 }
339 }
340
341
342
343
344 func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
345 if c.cmd == nil {
346 c.cmd = NewCmd()
347 }
348
349 cn, err := c.conn()
350 if err != nil {
351 return nil, err
352 }
353
354 err = cn.WithReader(timeout, func(rd *proto.Reader) error {
355 return c.cmd.readReply(rd)
356 })
357
358 c.releaseConn(cn, err, timeout > 0)
359 if err != nil {
360 return nil, err
361 }
362
363 return c.newMessage(c.cmd.Val())
364 }
365
366
367
368
369 func (c *PubSub) Receive() (interface{}, error) {
370 return c.ReceiveTimeout(0)
371 }
372
373
374
375
376 func (c *PubSub) ReceiveMessage() (*Message, error) {
377 for {
378 msg, err := c.Receive()
379 if err != nil {
380 return nil, err
381 }
382
383 switch msg := msg.(type) {
384 case *Subscription:
385
386 case *Pong:
387
388 case *Message:
389 return msg, nil
390 default:
391 err := fmt.Errorf("redis: unknown message: %T", msg)
392 return nil, err
393 }
394 }
395 }
396
397
398
399
400
401
402
403 func (c *PubSub) Channel() <-chan *Message {
404 return c.channel(100)
405 }
406
407
408
409 func (c *PubSub) ChannelSize(size int) <-chan *Message {
410 return c.channel(size)
411 }
412
413 func (c *PubSub) channel(size int) <-chan *Message {
414 c.chOnce.Do(func() {
415 c.initChannel(size)
416 })
417 if cap(c.ch) != size {
418 err := fmt.Errorf("redis: PubSub.Channel is called with different buffer size")
419 panic(err)
420 }
421 return c.ch
422 }
423
424 func (c *PubSub) initChannel(size int) {
425 const timeout = 30 * time.Second
426
427 c.ch = make(chan *Message, size)
428 c.ping = make(chan struct{}, 1)
429
430 go func() {
431 timer := time.NewTimer(timeout)
432 timer.Stop()
433
434 var errCount int
435 for {
436 msg, err := c.Receive()
437 if err != nil {
438 if err == pool.ErrClosed {
439 close(c.ch)
440 return
441 }
442 if errCount > 0 {
443 time.Sleep(c.retryBackoff(errCount))
444 }
445 errCount++
446 continue
447 }
448
449 errCount = 0
450
451
452 select {
453 case c.ping <- struct{}{}:
454 default:
455 }
456
457 switch msg := msg.(type) {
458 case *Subscription:
459
460 case *Pong:
461
462 case *Message:
463 timer.Reset(timeout)
464 select {
465 case c.ch <- msg:
466 if !timer.Stop() {
467 <-timer.C
468 }
469 case <-timer.C:
470 internal.Logf(
471 "redis: %s channel is full for %s (message is dropped)",
472 c, timeout)
473 }
474 default:
475 internal.Logf("redis: unknown message type: %T", msg)
476 }
477 }
478 }()
479
480 go func() {
481 timer := time.NewTimer(timeout)
482 timer.Stop()
483
484 healthy := true
485 for {
486 timer.Reset(timeout)
487 select {
488 case <-c.ping:
489 healthy = true
490 if !timer.Stop() {
491 <-timer.C
492 }
493 case <-timer.C:
494 pingErr := c.Ping()
495 if healthy {
496 healthy = false
497 } else {
498 if pingErr == nil {
499 pingErr = errPingTimeout
500 }
501 c.mu.Lock()
502 c._reconnect(pingErr)
503 c.mu.Unlock()
504 }
505 case <-c.exit:
506 return
507 }
508 }
509 }()
510 }
511
512 func (c *PubSub) retryBackoff(attempt int) time.Duration {
513 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
514 }
515
View as plain text