...

Source file src/github.com/go-redis/redis/pubsub.go

Documentation: github.com/go-redis/redis

     1  package redis
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/go-redis/redis/internal"
    11  	"github.com/go-redis/redis/internal/pool"
    12  	"github.com/go-redis/redis/internal/proto"
    13  )
    14  
    15  var errPingTimeout = errors.New("redis: ping timeout")
    16  
// PubSub implements Pub/Sub commands as described in
// http://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines.
//
// PubSub automatically reconnects to Redis Server and resubscribes
// to the channels in case of network errors.
type PubSub struct {
	// opt supplies the WriteTimeout used for every command written and the
	// Min/MaxRetryBackoff bounds used by retryBackoff.
	opt *Options

	// newConn dials a new connection. It is given the currently tracked
	// channel names plus any being subscribed right now (see _conn);
	// presumably so the caller can choose a node — confirm with callers.
	newConn func([]string) (*pool.Conn, error)
	// closeConn releases a connection previously returned by newConn.
	closeConn func(*pool.Conn) error

	// mu guards cn, channels, patterns and closed.
	mu sync.Mutex
	// cn is the active connection; nil until first dialed and after it has
	// been discarded by _closeTheCn.
	cn *pool.Conn
	// channels and patterns are the sets re-subscribed (SUBSCRIBE /
	// PSUBSCRIBE) on every newly dialed connection.
	channels map[string]struct{}
	patterns map[string]struct{}

	// closed is set by Close; afterwards _conn returns pool.ErrClosed.
	closed bool
	// exit is closed by Close to stop the health-check goroutine started
	// by initChannel. It is created by init.
	exit chan struct{}

	// cmd is the reply holder reused by ReceiveTimeout. It is not guarded by
	// mu, which is one reason receiving is not safe for concurrent use.
	cmd *Cmd

	// chOnce makes sure initChannel runs only once.
	chOnce sync.Once
	// ch is the message channel returned by Channel/ChannelSize.
	ch chan *Message
	// ping is signalled by the receive goroutine whenever any reply arrives,
	// telling the health-check goroutine the connection is alive.
	ping chan struct{}
}
    43  
    44  func (c *PubSub) String() string {
    45  	channels := mapKeys(c.channels)
    46  	channels = append(channels, mapKeys(c.patterns)...)
    47  	return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
    48  }
    49  
    50  func (c *PubSub) init() {
    51  	c.exit = make(chan struct{})
    52  }
    53  
    54  func (c *PubSub) conn() (*pool.Conn, error) {
    55  	c.mu.Lock()
    56  	cn, err := c._conn(nil)
    57  	c.mu.Unlock()
    58  	return cn, err
    59  }
    60  
    61  func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
    62  	if c.closed {
    63  		return nil, pool.ErrClosed
    64  	}
    65  	if c.cn != nil {
    66  		return c.cn, nil
    67  	}
    68  
    69  	channels := mapKeys(c.channels)
    70  	channels = append(channels, newChannels...)
    71  
    72  	cn, err := c.newConn(channels)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	if err := c.resubscribe(cn); err != nil {
    78  		_ = c.closeConn(cn)
    79  		return nil, err
    80  	}
    81  
    82  	c.cn = cn
    83  	return cn, nil
    84  }
    85  
    86  func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
    87  	return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
    88  		return writeCmd(wr, cmd)
    89  	})
    90  }
    91  
    92  func (c *PubSub) resubscribe(cn *pool.Conn) error {
    93  	var firstErr error
    94  
    95  	if len(c.channels) > 0 {
    96  		err := c._subscribe(cn, "subscribe", mapKeys(c.channels))
    97  		if err != nil && firstErr == nil {
    98  			firstErr = err
    99  		}
   100  	}
   101  
   102  	if len(c.patterns) > 0 {
   103  		err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
   104  		if err != nil && firstErr == nil {
   105  			firstErr = err
   106  		}
   107  	}
   108  
   109  	return firstErr
   110  }
   111  
   112  func mapKeys(m map[string]struct{}) []string {
   113  	s := make([]string, len(m))
   114  	i := 0
   115  	for k := range m {
   116  		s[i] = k
   117  		i++
   118  	}
   119  	return s
   120  }
   121  
   122  func (c *PubSub) _subscribe(
   123  	cn *pool.Conn, redisCmd string, channels []string,
   124  ) error {
   125  	args := make([]interface{}, 0, 1+len(channels))
   126  	args = append(args, redisCmd)
   127  	for _, channel := range channels {
   128  		args = append(args, channel)
   129  	}
   130  	cmd := NewSliceCmd(args...)
   131  	return c.writeCmd(cn, cmd)
   132  }
   133  
   134  func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
   135  	c.mu.Lock()
   136  	c._releaseConn(cn, err, allowTimeout)
   137  	c.mu.Unlock()
   138  }
   139  
   140  func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
   141  	if c.cn != cn {
   142  		return
   143  	}
   144  	if internal.IsBadConn(err, allowTimeout) {
   145  		c._reconnect(err)
   146  	}
   147  }
   148  
   149  func (c *PubSub) _reconnect(reason error) {
   150  	_ = c._closeTheCn(reason)
   151  	_, _ = c._conn(nil)
   152  }
   153  
   154  func (c *PubSub) _closeTheCn(reason error) error {
   155  	if c.cn == nil {
   156  		return nil
   157  	}
   158  	if !c.closed {
   159  		internal.Logf("redis: discarding bad PubSub connection: %s", reason)
   160  	}
   161  	err := c.closeConn(c.cn)
   162  	c.cn = nil
   163  	return err
   164  }
   165  
   166  func (c *PubSub) Close() error {
   167  	c.mu.Lock()
   168  	defer c.mu.Unlock()
   169  
   170  	if c.closed {
   171  		return pool.ErrClosed
   172  	}
   173  	c.closed = true
   174  	close(c.exit)
   175  
   176  	err := c._closeTheCn(pool.ErrClosed)
   177  	return err
   178  }
   179  
   180  // Subscribe the client to the specified channels. It returns
   181  // empty subscription if there are no channels.
   182  func (c *PubSub) Subscribe(channels ...string) error {
   183  	c.mu.Lock()
   184  	defer c.mu.Unlock()
   185  
   186  	err := c.subscribe("subscribe", channels...)
   187  	if c.channels == nil {
   188  		c.channels = make(map[string]struct{})
   189  	}
   190  	for _, s := range channels {
   191  		c.channels[s] = struct{}{}
   192  	}
   193  	return err
   194  }
   195  
   196  // PSubscribe the client to the given patterns. It returns
   197  // empty subscription if there are no patterns.
   198  func (c *PubSub) PSubscribe(patterns ...string) error {
   199  	c.mu.Lock()
   200  	defer c.mu.Unlock()
   201  
   202  	err := c.subscribe("psubscribe", patterns...)
   203  	if c.patterns == nil {
   204  		c.patterns = make(map[string]struct{})
   205  	}
   206  	for _, s := range patterns {
   207  		c.patterns[s] = struct{}{}
   208  	}
   209  	return err
   210  }
   211  
   212  // Unsubscribe the client from the given channels, or from all of
   213  // them if none is given.
   214  func (c *PubSub) Unsubscribe(channels ...string) error {
   215  	c.mu.Lock()
   216  	defer c.mu.Unlock()
   217  
   218  	for _, channel := range channels {
   219  		delete(c.channels, channel)
   220  	}
   221  	err := c.subscribe("unsubscribe", channels...)
   222  	return err
   223  }
   224  
   225  // PUnsubscribe the client from the given patterns, or from all of
   226  // them if none is given.
   227  func (c *PubSub) PUnsubscribe(patterns ...string) error {
   228  	c.mu.Lock()
   229  	defer c.mu.Unlock()
   230  
   231  	for _, pattern := range patterns {
   232  		delete(c.patterns, pattern)
   233  	}
   234  	err := c.subscribe("punsubscribe", patterns...)
   235  	return err
   236  }
   237  
   238  func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
   239  	cn, err := c._conn(channels)
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	err = c._subscribe(cn, redisCmd, channels)
   245  	c._releaseConn(cn, err, false)
   246  	return err
   247  }
   248  
   249  func (c *PubSub) Ping(payload ...string) error {
   250  	args := []interface{}{"ping"}
   251  	if len(payload) == 1 {
   252  		args = append(args, payload[0])
   253  	}
   254  	cmd := NewCmd(args...)
   255  
   256  	cn, err := c.conn()
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	err = c.writeCmd(cn, cmd)
   262  	c.releaseConn(cn, err, false)
   263  	return err
   264  }
   265  
   266  // Subscription received after a successful subscription to channel.
   267  type Subscription struct {
   268  	// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
   269  	Kind string
   270  	// Channel name we have subscribed to.
   271  	Channel string
   272  	// Number of channels we are currently subscribed to.
   273  	Count int
   274  }
   275  
   276  func (m *Subscription) String() string {
   277  	return fmt.Sprintf("%s: %s", m.Kind, m.Channel)
   278  }
   279  
   280  // Message received as result of a PUBLISH command issued by another client.
   281  type Message struct {
   282  	Channel string
   283  	Pattern string
   284  	Payload string
   285  }
   286  
   287  func (m *Message) String() string {
   288  	return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
   289  }
   290  
   291  // Pong received as result of a PING command issued by another client.
   292  type Pong struct {
   293  	Payload string
   294  }
   295  
   296  func (p *Pong) String() string {
   297  	if p.Payload != "" {
   298  		return fmt.Sprintf("Pong<%s>", p.Payload)
   299  	}
   300  	return "Pong"
   301  }
   302  
   303  func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
   304  	switch reply := reply.(type) {
   305  	case string:
   306  		return &Pong{
   307  			Payload: reply,
   308  		}, nil
   309  	case []interface{}:
   310  		switch kind := reply[0].(string); kind {
   311  		case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
   312  			channel, _ := reply[1].(string)
   313  			return &Subscription{
   314  				Kind:    kind,
   315  				Channel: channel,
   316  				Count:   int(reply[2].(int64)),
   317  			}, nil
   318  		case "message":
   319  			return &Message{
   320  				Channel: reply[1].(string),
   321  				Payload: reply[2].(string),
   322  			}, nil
   323  		case "pmessage":
   324  			return &Message{
   325  				Pattern: reply[1].(string),
   326  				Channel: reply[2].(string),
   327  				Payload: reply[3].(string),
   328  			}, nil
   329  		case "pong":
   330  			return &Pong{
   331  				Payload: reply[1].(string),
   332  			}, nil
   333  		default:
   334  			return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
   335  		}
   336  	default:
   337  		return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
   338  	}
   339  }
   340  
   341  // ReceiveTimeout acts like Receive but returns an error if message
   342  // is not received in time. This is low-level API and in most cases
   343  // Channel should be used instead.
   344  func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
   345  	if c.cmd == nil {
   346  		c.cmd = NewCmd()
   347  	}
   348  
   349  	cn, err := c.conn()
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	err = cn.WithReader(timeout, func(rd *proto.Reader) error {
   355  		return c.cmd.readReply(rd)
   356  	})
   357  
   358  	c.releaseConn(cn, err, timeout > 0)
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  
   363  	return c.newMessage(c.cmd.Val())
   364  }
   365  
   366  // Receive returns a message as a Subscription, Message, Pong or error.
   367  // See PubSub example for details. This is low-level API and in most cases
   368  // Channel should be used instead.
   369  func (c *PubSub) Receive() (interface{}, error) {
   370  	return c.ReceiveTimeout(0)
   371  }
   372  
   373  // ReceiveMessage returns a Message or error ignoring Subscription and Pong
   374  // messages. This is low-level API and in most cases Channel should be used
   375  // instead.
   376  func (c *PubSub) ReceiveMessage() (*Message, error) {
   377  	for {
   378  		msg, err := c.Receive()
   379  		if err != nil {
   380  			return nil, err
   381  		}
   382  
   383  		switch msg := msg.(type) {
   384  		case *Subscription:
   385  			// Ignore.
   386  		case *Pong:
   387  			// Ignore.
   388  		case *Message:
   389  			return msg, nil
   390  		default:
   391  			err := fmt.Errorf("redis: unknown message: %T", msg)
   392  			return nil, err
   393  		}
   394  	}
   395  }
   396  
   397  // Channel returns a Go channel for concurrently receiving messages.
   398  // It periodically sends Ping messages to test connection health.
   399  // The channel is closed with PubSub. Receive* APIs can not be used
   400  // after channel is created.
   401  //
   402  // If the Go channel is full for 30 seconds the message is dropped.
   403  func (c *PubSub) Channel() <-chan *Message {
   404  	return c.channel(100)
   405  }
   406  
   407  // ChannelSize is like Channel, but creates a Go channel
   408  // with specified buffer size.
   409  func (c *PubSub) ChannelSize(size int) <-chan *Message {
   410  	return c.channel(size)
   411  }
   412  
   413  func (c *PubSub) channel(size int) <-chan *Message {
   414  	c.chOnce.Do(func() {
   415  		c.initChannel(size)
   416  	})
   417  	if cap(c.ch) != size {
   418  		err := fmt.Errorf("redis: PubSub.Channel is called with different buffer size")
   419  		panic(err)
   420  	}
   421  	return c.ch
   422  }
   423  
// initChannel creates the message channel (buffered for size messages) and
// the ping signal channel, then starts two goroutines:
//
//   - a receive loop that reads replies with Receive, forwards *Message values
//     to c.ch (dropping a message after waiting timeout for room), and closes
//     c.ch once Receive reports pool.ErrClosed;
//   - a health-check loop that sends a Ping after an idle period of timeout
//     and forces a reconnect if a further idle period passes with no traffic.
//
// It runs at most once, through c.chOnce in channel.
//
// NOTE(review): the receive loop does not watch c.exit. After Close it may
// keep blocking for up to timeout on a full c.ch (or sleep a backoff) before
// it observes pool.ErrClosed and closes c.ch.
func (c *PubSub) initChannel(size int) {
	// timeout is both the per-message delivery deadline for c.ch and the idle
	// interval of the health check.
	const timeout = 30 * time.Second

	c.ch = make(chan *Message, size)
	// Capacity 1 lets the receive loop signal activity with a non-blocking send.
	c.ping = make(chan struct{}, 1)

	go func() {
		// Start with a stopped timer; it is armed only while a message is
		// waiting for room in c.ch.
		timer := time.NewTimer(timeout)
		timer.Stop()

		var errCount int // consecutive Receive failures
		for {
			msg, err := c.Receive()
			if err != nil {
				if err == pool.ErrClosed {
					// The PubSub was closed: tell consumers and stop.
					close(c.ch)
					return
				}
				// Retry the first failure immediately; back off on the
				// following consecutive failures.
				if errCount > 0 {
					time.Sleep(c.retryBackoff(errCount))
				}
				errCount++
				continue
			}

			errCount = 0

			// Any message is as good as a ping.
			select {
			case c.ping <- struct{}{}:
			default:
			}

			switch msg := msg.(type) {
			case *Subscription:
				// Ignore.
			case *Pong:
				// Ignore.
			case *Message:
				// Wait at most timeout for room in c.ch, then drop the message.
				timer.Reset(timeout)
				select {
				case c.ch <- msg:
					// Stop the timer and drain it if it already fired, so the
					// next Reset starts from a clean state.
					if !timer.Stop() {
						<-timer.C
					}
				case <-timer.C:
					internal.Logf(
						"redis: %s channel is full for %s (message is dropped)",
						c, timeout)
				}
			default:
				internal.Logf("redis: unknown message type: %T", msg)
			}
		}
	}()

	go func() {
		timer := time.NewTimer(timeout)
		timer.Stop()

		// healthy is cleared after an idle period in which a Ping was sent.
		// A second idle period while it is false triggers a reconnect.
		healthy := true
		for {
			timer.Reset(timeout)
			select {
			case <-c.ping:
				// Traffic was received: the connection is considered alive.
				healthy = true
				if !timer.Stop() {
					<-timer.C
				}
			case <-timer.C:
				// No traffic for a whole period: probe with a Ping. Its reply,
				// like any reply, is reported through c.ping.
				pingErr := c.Ping()
				if healthy {
					healthy = false
				} else {
					// Still silent after the previous probe: drop the
					// connection, reporting the Ping error if there was one.
					if pingErr == nil {
						pingErr = errPingTimeout
					}
					c.mu.Lock()
					c._reconnect(pingErr)
					c.mu.Unlock()
				}
			case <-c.exit:
				// Close was called.
				return
			}
		}
	}()
}
   511  
   512  func (c *PubSub) retryBackoff(attempt int) time.Duration {
   513  	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
   514  }
   515  

View as plain text