barometer: update DMA's vendoring packages
[barometer.git] / src / dma / vendor / github.com / go-redis / redis / pubsub.go
index 2cfcd15..0afb47c 100644 (file)
@@ -1,15 +1,19 @@
 package redis
 
 import (
+       "errors"
        "fmt"
        "sync"
        "time"
 
        "github.com/go-redis/redis/internal"
        "github.com/go-redis/redis/internal/pool"
+       "github.com/go-redis/redis/internal/proto"
 )
 
-// PubSub implements Pub/Sub commands as described in
+var errPingTimeout = errors.New("redis: ping timeout")
+
+// PubSub implements Pub/Sub commands bas described in
 // http://redis.io/topics/pubsub. Message receiving is NOT safe
 // for concurrent use by multiple goroutines.
 //
@@ -46,15 +50,17 @@ func (c *PubSub) conn() (*pool.Conn, error) {
        return cn, err
 }
 
-func (c *PubSub) _conn(channels []string) (*pool.Conn, error) {
+func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
        if c.closed {
                return nil, pool.ErrClosed
        }
-
        if c.cn != nil {
                return c.cn, nil
        }
 
+       channels := mapKeys(c.channels)
+       channels = append(channels, newChannels...)
+
        cn, err := c.newConn(channels)
        if err != nil {
                return nil, err
@@ -69,20 +75,24 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) {
        return cn, nil
 }
 
+func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
+       return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
+               return writeCmd(wr, cmd)
+       })
+}
+
 func (c *PubSub) resubscribe(cn *pool.Conn) error {
        var firstErr error
 
        if len(c.channels) > 0 {
-               channels := mapKeys(c.channels)
-               err := c._subscribe(cn, "subscribe", channels...)
+               err := c._subscribe(cn, "subscribe", mapKeys(c.channels))
                if err != nil && firstErr == nil {
                        firstErr = err
                }
        }
 
        if len(c.patterns) > 0 {
-               patterns := mapKeys(c.patterns)
-               err := c._subscribe(cn, "psubscribe", patterns...)
+               err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
                if err != nil && firstErr == nil {
                        firstErr = err
                }
@@ -101,51 +111,48 @@ func mapKeys(m map[string]struct{}) []string {
        return s
 }
 
-func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error {
-       args := make([]interface{}, 1+len(channels))
-       args[0] = redisCmd
-       for i, channel := range channels {
-               args[1+i] = channel
+func (c *PubSub) _subscribe(
+       cn *pool.Conn, redisCmd string, channels []string,
+) error {
+       args := make([]interface{}, 0, 1+len(channels))
+       args = append(args, redisCmd)
+       for _, channel := range channels {
+               args = append(args, channel)
        }
        cmd := NewSliceCmd(args...)
-
-       cn.SetWriteTimeout(c.opt.WriteTimeout)
-       return writeCmd(cn, cmd)
+       return c.writeCmd(cn, cmd)
 }
 
-func (c *PubSub) releaseConn(cn *pool.Conn, err error) {
+func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
        c.mu.Lock()
-       c._releaseConn(cn, err)
+       c._releaseConn(cn, err, allowTimeout)
        c.mu.Unlock()
 }
 
-func (c *PubSub) _releaseConn(cn *pool.Conn, err error) {
+func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
        if c.cn != cn {
                return
        }
-       if internal.IsBadConn(err, true) {
-               c._reconnect()
+       if internal.IsBadConn(err, allowTimeout) {
+               c._reconnect(err)
        }
 }
 
-func (c *PubSub) _closeTheCn() error {
-       var err error
-       if c.cn != nil {
-               err = c.closeConn(c.cn)
-               c.cn = nil
-       }
-       return err
-}
-
-func (c *PubSub) reconnect() {
-       c.mu.Lock()
-       c._reconnect()
-       c.mu.Unlock()
+func (c *PubSub) _reconnect(reason error) {
+       _ = c._closeTheCn(reason)
+       _, _ = c._conn(nil)
 }
 
-func (c *PubSub) _reconnect() {
-       _ = c._closeTheCn()
-       _, _ = c._conn(nil)
+func (c *PubSub) _closeTheCn(reason error) error {
+       if c.cn == nil {
+               return nil
+       }
+       if !c.closed {
+               internal.Logf("redis: discarding bad PubSub connection: %s", reason)
+       }
+       err := c.closeConn(c.cn)
+       c.cn = nil
+       return err
 }
 
 func (c *PubSub) Close() error {
@@ -158,7 +165,7 @@ func (c *PubSub) Close() error {
        c.closed = true
        close(c.exit)
 
-       err := c._closeTheCn()
+       err := c._closeTheCn(pool.ErrClosed)
        return err
 }
 
@@ -172,8 +179,8 @@ func (c *PubSub) Subscribe(channels ...string) error {
        if c.channels == nil {
                c.channels = make(map[string]struct{})
        }
-       for _, channel := range channels {
-               c.channels[channel] = struct{}{}
+       for _, s := range channels {
+               c.channels[s] = struct{}{}
        }
        return err
 }
@@ -188,8 +195,8 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
        if c.patterns == nil {
                c.patterns = make(map[string]struct{})
        }
-       for _, pattern := range patterns {
-               c.patterns[pattern] = struct{}{}
+       for _, s := range patterns {
+               c.patterns[s] = struct{}{}
        }
        return err
 }
@@ -200,10 +207,10 @@ func (c *PubSub) Unsubscribe(channels ...string) error {
        c.mu.Lock()
        defer c.mu.Unlock()
 
-       err := c.subscribe("unsubscribe", channels...)
        for _, channel := range channels {
                delete(c.channels, channel)
        }
+       err := c.subscribe("unsubscribe", channels...)
        return err
 }
 
@@ -213,10 +220,10 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
        c.mu.Lock()
        defer c.mu.Unlock()
 
-       err := c.subscribe("punsubscribe", patterns...)
        for _, pattern := range patterns {
                delete(c.patterns, pattern)
        }
+       err := c.subscribe("punsubscribe", patterns...)
        return err
 }
 
@@ -226,8 +233,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
                return err
        }
 
-       err = c._subscribe(cn, redisCmd, channels...)
-       c._releaseConn(cn, err)
+       err = c._subscribe(cn, redisCmd, channels)
+       c._releaseConn(cn, err, false)
        return err
 }
 
@@ -243,9 +250,8 @@ func (c *PubSub) Ping(payload ...string) error {
                return err
        }
 
-       cn.SetWriteTimeout(c.opt.WriteTimeout)
-       err = writeCmd(cn, cmd)
-       c.releaseConn(cn, err)
+       err = c.writeCmd(cn, cmd)
+       c.releaseConn(cn, err, false)
        return err
 }
 
@@ -336,9 +342,11 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
                return nil, err
        }
 
-       cn.SetReadTimeout(timeout)
-       err = c.cmd.readReply(cn)
-       c.releaseConn(cn, err)
+       err = cn.WithReader(timeout, func(rd *proto.Reader) error {
+               return c.cmd.readReply(rd)
+       })
+
+       c.releaseConn(cn, err, timeout > 0)
        if err != nil {
                return nil, err
        }
@@ -432,21 +440,26 @@ func (c *PubSub) initChannel() {
                timer := time.NewTimer(timeout)
                timer.Stop()
 
-               var hasPing bool
+               healthy := true
                for {
                        timer.Reset(timeout)
                        select {
                        case <-c.ping:
-                               hasPing = true
+                               healthy = true
                                if !timer.Stop() {
                                        <-timer.C
                                }
                        case <-timer.C:
-                               if hasPing {
-                                       hasPing = false
-                                       _ = c.Ping()
+                               pingErr := c.Ping()
+                               if healthy {
+                                       healthy = false
                                } else {
-                                       c.reconnect()
+                                       if pingErr == nil {
+                                               pingErr = errPingTimeout
+                                       }
+                                       c.mu.Lock()
+                                       c._reconnect(pingErr)
+                                       c.mu.Unlock()
                                }
                        case <-c.exit:
                                return