barometer: update DMA's vendoring packages
[barometer.git] / src / dma / vendor / github.com / go-redis / redis / redis.go
index c0f142c..aca3064 100644 (file)
@@ -26,6 +26,7 @@ func SetLogger(logger *log.Logger) {
 type baseClient struct {
        opt      *Options
        connPool pool.Pooler
+       limiter  Limiter
 
        process           func(Cmder) error
        processPipeline   func([]Cmder) error
@@ -50,7 +51,7 @@ func (c *baseClient) newConn() (*pool.Conn, error) {
                return nil, err
        }
 
-       if !cn.Inited {
+       if cn.InitedAt.IsZero() {
                if err := c.initConn(cn); err != nil {
                        _ = c.connPool.CloseConn(cn)
                        return nil, err
@@ -61,12 +62,30 @@ func (c *baseClient) newConn() (*pool.Conn, error) {
 }
 
 func (c *baseClient) getConn() (*pool.Conn, error) {
+       if c.limiter != nil {
+               err := c.limiter.Allow()
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       cn, err := c._getConn()
+       if err != nil {
+               if c.limiter != nil {
+                       c.limiter.ReportResult(err)
+               }
+               return nil, err
+       }
+       return cn, nil
+}
+
+func (c *baseClient) _getConn() (*pool.Conn, error) {
        cn, err := c.connPool.Get()
        if err != nil {
                return nil, err
        }
 
-       if !cn.Inited {
+       if cn.InitedAt.IsZero() {
                err := c.initConn(cn)
                if err != nil {
                        c.connPool.Remove(cn)
@@ -77,18 +96,32 @@ func (c *baseClient) getConn() (*pool.Conn, error) {
        return cn, nil
 }
 
-func (c *baseClient) releaseConn(cn *pool.Conn, err error) bool {
+func (c *baseClient) releaseConn(cn *pool.Conn, err error) {
+       if c.limiter != nil {
+               c.limiter.ReportResult(err)
+       }
+
        if internal.IsBadConn(err, false) {
                c.connPool.Remove(cn)
-               return false
+       } else {
+               c.connPool.Put(cn)
+       }
+}
+
+func (c *baseClient) releaseConnStrict(cn *pool.Conn, err error) {
+       if c.limiter != nil {
+               c.limiter.ReportResult(err)
        }
 
-       c.connPool.Put(cn)
-       return true
+       if err == nil || internal.IsRedisError(err) {
+               c.connPool.Put(cn)
+       } else {
+               c.connPool.Remove(cn)
+       }
 }
 
 func (c *baseClient) initConn(cn *pool.Conn) error {
-       cn.Inited = true
+       cn.InitedAt = time.Now()
 
        if c.opt.Password == "" &&
                c.opt.DB == 0 &&
@@ -123,8 +156,17 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
        return nil
 }
 
+// Do creates a Cmd from the args and processes the cmd.
+func (c *baseClient) Do(args ...interface{}) *Cmd {
+       cmd := NewCmd(args...)
+       _ = c.Process(cmd)
+       return cmd
+}
+
 // WrapProcess wraps function that processes Redis commands.
-func (c *baseClient) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error) {
+func (c *baseClient) WrapProcess(
+       fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error,
+) {
        c.process = fn(c.process)
 }
 
@@ -147,8 +189,10 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
                        return err
                }
 
-               cn.SetWriteTimeout(c.opt.WriteTimeout)
-               if err := writeCmd(cn, cmd); err != nil {
+               err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
+                       return writeCmd(wr, cmd)
+               })
+               if err != nil {
                        c.releaseConn(cn, err)
                        cmd.setErr(err)
                        if internal.IsRetryableError(err, true) {
@@ -157,8 +201,9 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
                        return err
                }
 
-               cn.SetReadTimeout(c.cmdTimeout(cmd))
-               err = cmd.readReply(cn)
+               err = cn.WithReader(c.cmdTimeout(cmd), func(rd *proto.Reader) error {
+                       return cmd.readReply(rd)
+               })
                c.releaseConn(cn, err)
                if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) {
                        continue
@@ -176,7 +221,11 @@ func (c *baseClient) retryBackoff(attempt int) time.Duration {
 
 func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
        if timeout := cmd.readTimeout(); timeout != nil {
-               return readTimeout(*timeout)
+               t := *timeout
+               if t == 0 {
+                       return 0
+               }
+               return t + 10*time.Second
        }
        return c.opt.ReadTimeout
 }
@@ -232,35 +281,33 @@ func (c *baseClient) generalProcessPipeline(cmds []Cmder, p pipelineProcessor) e
                }
 
                canRetry, err := p(cn, cmds)
-
-               if err == nil || internal.IsRedisError(err) {
-                       c.connPool.Put(cn)
-                       break
-               }
-               c.connPool.Remove(cn)
+               c.releaseConnStrict(cn, err)
 
                if !canRetry || !internal.IsRetryableError(err, true) {
                        break
                }
        }
-       return firstCmdsErr(cmds)
+       return cmdsFirstErr(cmds)
 }
 
 func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
-       cn.SetWriteTimeout(c.opt.WriteTimeout)
-       if err := writeCmd(cn, cmds...); err != nil {
+       err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
+               return writeCmd(wr, cmds...)
+       })
+       if err != nil {
                setCmdsErr(cmds, err)
                return true, err
        }
 
-       // Set read timeout for all commands.
-       cn.SetReadTimeout(c.opt.ReadTimeout)
-       return true, pipelineReadCmds(cn, cmds)
+       err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
+               return pipelineReadCmds(rd, cmds)
+       })
+       return true, err
 }
 
-func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) error {
+func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
        for _, cmd := range cmds {
-               err := cmd.readReply(cn)
+               err := cmd.readReply(rd)
                if err != nil && !internal.IsRedisError(err) {
                        return err
                }
@@ -269,47 +316,50 @@ func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) error {
 }
 
 func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
-       cn.SetWriteTimeout(c.opt.WriteTimeout)
-       if err := txPipelineWriteMulti(cn, cmds); err != nil {
+       err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
+               return txPipelineWriteMulti(wr, cmds)
+       })
+       if err != nil {
                setCmdsErr(cmds, err)
                return true, err
        }
 
-       // Set read timeout for all commands.
-       cn.SetReadTimeout(c.opt.ReadTimeout)
-
-       if err := c.txPipelineReadQueued(cn, cmds); err != nil {
-               setCmdsErr(cmds, err)
-               return false, err
-       }
-
-       return false, pipelineReadCmds(cn, cmds)
+       err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
+               err := txPipelineReadQueued(rd, cmds)
+               if err != nil {
+                       setCmdsErr(cmds, err)
+                       return err
+               }
+               return pipelineReadCmds(rd, cmds)
+       })
+       return false, err
 }
 
-func txPipelineWriteMulti(cn *pool.Conn, cmds []Cmder) error {
+func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error {
        multiExec := make([]Cmder, 0, len(cmds)+2)
        multiExec = append(multiExec, NewStatusCmd("MULTI"))
        multiExec = append(multiExec, cmds...)
        multiExec = append(multiExec, NewSliceCmd("EXEC"))
-       return writeCmd(cn, multiExec...)
+       return writeCmd(wr, multiExec...)
 }
 
-func (c *baseClient) txPipelineReadQueued(cn *pool.Conn, cmds []Cmder) error {
+func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error {
        // Parse queued replies.
        var statusCmd StatusCmd
-       if err := statusCmd.readReply(cn); err != nil {
+       err := statusCmd.readReply(rd)
+       if err != nil {
                return err
        }
 
-       for _ = range cmds {
-               err := statusCmd.readReply(cn)
+       for range cmds {
+               err = statusCmd.readReply(rd)
                if err != nil && !internal.IsRedisError(err) {
                        return err
                }
        }
 
        // Parse number of replies.
-       line, err := cn.Rd.ReadLine()
+       line, err := rd.ReadLine()
        if err != nil {
                if err == Nil {
                        err = TxFailedErr
@@ -373,12 +423,12 @@ func (c *Client) WithContext(ctx context.Context) *Client {
        if ctx == nil {
                panic("nil context")
        }
-       c2 := c.copy()
+       c2 := c.clone()
        c2.ctx = ctx
        return c2
 }
 
-func (c *Client) copy() *Client {
+func (c *Client) clone() *Client {
        cp := *c
        cp.init()
        return &cp
@@ -389,6 +439,11 @@ func (c *Client) Options() *Options {
        return c.opt
 }
 
+func (c *Client) SetLimiter(l Limiter) *Client {
+       c.limiter = l
+       return c
+}
+
 type PoolStats pool.Stats
 
 // PoolStats returns connection pool stats.
@@ -437,6 +492,30 @@ func (c *Client) pubSub() *PubSub {
 
 // Subscribe subscribes the client to the specified channels.
 // Channels can be omitted to create empty subscription.
+// Note that this method does not wait on a response from Redis, so the
+// subscription may not be active immediately. To force the connection to wait,
+// you may call the Receive() method on the returned *PubSub like so:
+//
+//    sub := client.Subscribe(queryResp)
+//    iface, err := sub.Receive()
+//    if err != nil {
+//        // handle error
+//    }
+//
+//    // Should be *Subscription, but others are possible if other actions have been
+//    // taken on sub since it was created.
+//    switch iface.(type) {
+//    case *Subscription:
+//        // subscribe succeeded
+//    case *Message:
+//        // received first message
+//    case *Pong:
+//        // pong received
+//    default:
+//        // handle error
+//    }
+//
+//    ch := sub.Channel()
 func (c *Client) Subscribe(channels ...string) *PubSub {
        pubsub := c.pubSub()
        if len(channels) > 0 {