diff --git a/conn.go b/conn.go index 315148ee..c1b31859 100644 --- a/conn.go +++ b/conn.go @@ -52,7 +52,6 @@ type Conn struct { messagesInFlight int64 maxRdyCount int64 rdyCount int64 - lastRdyCount int64 lastRdyTimestamp int64 lastMsgTimestamp int64 @@ -207,13 +206,12 @@ func (c *Conn) RDY() int64 { // LastRDY returns the previously set RDY count func (c *Conn) LastRDY() int64 { - return atomic.LoadInt64(&c.lastRdyCount) + return atomic.LoadInt64(&c.rdyCount) } // SetRDY stores the specified RDY count func (c *Conn) SetRDY(rdy int64) { atomic.StoreInt64(&c.rdyCount, rdy) - atomic.StoreInt64(&c.lastRdyCount, rdy) if rdy > 0 { atomic.StoreInt64(&c.lastRdyTimestamp, time.Now().UnixNano()) } @@ -225,6 +223,8 @@ func (c *Conn) MaxRDY() int64 { return c.maxRdyCount } +// LastRdyTime returns the time of the last non-zero RDY +// update for this connection func (c *Conn) LastRdyTime() time.Time { return time.Unix(0, atomic.LoadInt64(&c.lastRdyTimestamp)) } @@ -523,7 +523,6 @@ func (c *Conn) readLoop() { msg.Delegate = delegate msg.NSQDAddress = c.String() - atomic.AddInt64(&c.rdyCount, -1) atomic.AddInt64(&c.messagesInFlight, 1) atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano()) diff --git a/consumer.go b/consumer.go index 78d64941..52b2eb1f 100644 --- a/consumer.go +++ b/consumer.go @@ -111,11 +111,11 @@ type Consumer struct { needRDYRedistributed int32 - backoffMtx sync.RWMutex + backoffMtx sync.Mutex incomingMessages chan *Message - rdyRetryMtx sync.RWMutex + rdyRetryMtx sync.Mutex rdyRetryTimers map[string]*time.Timer pendingConnections map[string]*Conn @@ -264,7 +264,7 @@ func (r *Consumer) perConnMaxInFlight() int64 { // before being able to receive more messages (ie. RDY count of 0 and not exiting) func (r *Consumer) IsStarved() bool { for _, conn := range r.conns() { - threshold := int64(float64(atomic.LoadInt64(&conn.lastRdyCount)) * 0.85) + threshold := int64(float64(conn.RDY()) * 0.85) inFlight := atomic.LoadInt64(&conn.messagesInFlight) if inFlight >= threshold && inFlight > 0 && !conn.IsClosing() { return true @@ -642,10 +642,8 @@ func (r *Consumer) DisconnectFromNSQLookupd(addr string) error { } func (r *Consumer) onConnMessage(c *Conn, msg *Message) { - atomic.AddInt64(&r.totalRdyCount, -1) atomic.AddUint64(&r.messagesReceived, 1) r.incomingMessages <- msg - r.maybeUpdateRDY(c) } func (r *Consumer) onConnMessageFinished(c *Conn, msg *Message) { @@ -771,11 +769,10 @@ func (r *Consumer) startStopContinueBackoff(conn *Conn, signal backoffSignal) { // max backoff/normal rate (by ensuring that we dont continually incr/decr // the counter during a backoff period) r.backoffMtx.Lock() + defer r.backoffMtx.Unlock() if r.inBackoffTimeout() { - r.backoffMtx.Unlock() return } - defer r.backoffMtx.Unlock() // update backoff state backoffUpdated := false @@ -879,19 +876,9 @@ func (r *Consumer) maybeUpdateRDY(conn *Conn) { return } - remain := conn.RDY() - lastRdyCount := conn.LastRDY() count := r.perConnMaxInFlight() - - // refill when at 1, or at 25%, or if connections have changed and we're imbalanced - if remain <= 1 || remain < (lastRdyCount/4) || (count > 0 && count < remain) { - r.log(LogLevelDebug, "(%s) sending RDY %d (%d remain from last RDY %d)", - conn, count, remain, lastRdyCount) - r.updateRDY(conn, count) - } else { - r.log(LogLevelDebug, "(%s) skip sending RDY %d (%d remain out of last RDY %d)", - conn, count, remain, lastRdyCount) - } + r.log(LogLevelDebug, "(%s) sending RDY %d", conn, count) + r.updateRDY(conn, count) } func (r *Consumer) rdyLoop() { @@ -961,7 +948,7 @@ func (r *Consumer) sendRDY(c *Conn, count int64) error { return nil } - atomic.AddInt64(&r.totalRdyCount, -c.RDY()+count) + atomic.AddInt64(&r.totalRdyCount, count-c.RDY()) c.SetRDY(count) err := c.WriteCommand(Ready(int(count))) if err != nil { diff --git a/consumer_test.go b/consumer_test.go index 4079dc91..5eb5a961 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -10,7 +10,6 @@ import ( "log" "net" "net/http" - "os" "strconv" "strings" "testing" @@ -116,11 +115,6 @@ func TestConsumerTLSSnappy(t *testing.T) { } func TestConsumerTLSClientCert(t *testing.T) { - envDl := os.Getenv("NSQ_DOWNLOAD") - if strings.HasPrefix(envDl, "nsq-0.2.24") || strings.HasPrefix(envDl, "nsq-0.2.27") { - t.Log("skipping due to older nsqd") - return - } cert, _ := tls.LoadX509KeyPair("./test/client.pem", "./test/client.key") consumerTest(t, func(c *Config) { c.TlsV1 = true @@ -132,11 +126,6 @@ func TestConsumerTLSClientCert(t *testing.T) { } func TestConsumerTLSClientCertViaSet(t *testing.T) { - envDl := os.Getenv("NSQ_DOWNLOAD") - if strings.HasPrefix(envDl, "nsq-0.2.24") || strings.HasPrefix(envDl, "nsq-0.2.27") { - t.Log("skipping due to older nsqd") - return - } consumerTest(t, func(c *Config) { c.Set("tls_v1", true) c.Set("tls_cert", "./test/client.pem") @@ -168,7 +157,7 @@ func consumerTest(t *testing.T, cb func(c *Config)) { } topicName = topicName + strconv.Itoa(int(time.Now().Unix())) q, _ := NewConsumer(topicName, "ch", config) - q.SetLogger(log.New(os.Stderr, "", log.Flags()), LogLevelDebug) + q.SetLogger(newTestLogger(t), LogLevelDebug) h := &MyTestHandler{ t: t, diff --git a/mock_test.go b/mock_test.go index 62b9fc00..8d9dbf91 100644 --- a/mock_test.go +++ b/mock_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "log" "net" "strconv" "testing" @@ -38,6 +37,7 @@ type instruction struct { } type mockNSQD struct { + t *testing.T script []instruction got [][]byte tcpAddr *net.TCPAddr @@ -45,15 +45,16 @@ type mockNSQD struct { exitChan chan int } -func newMockNSQD(script []instruction, addr string) *mockNSQD { +func newMockNSQD(t *testing.T, script []instruction, addr string) *mockNSQD { n := &mockNSQD{ + t: t, script: script, exitChan: make(chan int), } tcpListener, err := net.Listen("tcp", addr) if err != nil { - log.Fatalf("FATAL: listen (%s) failed - %s", n.tcpAddr.String(), err) + n.t.Fatalf("FATAL: listen (%s) failed - %s", n.tcpAddr, err) } n.tcpListener = tcpListener n.tcpAddr = tcpListener.Addr().(*net.TCPAddr) @@ -64,7 +65,7 @@ func newMockNSQD(script []instruction, addr string) *mockNSQD { } func (n *mockNSQD) listen() { - log.Printf("TCP: listening on %s", n.tcpListener.Addr().String()) + n.t.Logf("TCP: listening on %s", n.tcpListener.Addr()) for { conn, err := n.tcpListener.Accept() @@ -74,19 +75,19 @@ func (n *mockNSQD) listen() { go n.handle(conn) } - log.Printf("TCP: closing %s", n.tcpListener.Addr().String()) + n.t.Logf("TCP: closing %s", n.tcpListener.Addr()) close(n.exitChan) } func (n *mockNSQD) handle(conn net.Conn) { var idx int - log.Printf("TCP: new client(%s)", conn.RemoteAddr()) + n.t.Logf("TCP: new client(%s)", conn.RemoteAddr()) buf := make([]byte, 4) _, err := io.ReadFull(conn, buf) if err != nil { - log.Fatalf("ERROR: failed to read protocol version - %s", err) + n.t.Fatalf("ERROR: failed to read protocol version - %s", err) } readChan := make(chan []byte) @@ -111,7 +112,7 @@ func (n *mockNSQD) handle(conn net.Conn) { for idx < len(n.script) { select { case line := <-readChan: - log.Printf("mock: %s", line) + n.t.Logf("mock: %s", line) n.got = append(n.got, line) params := bytes.Split(line, []byte(" ")) switch { @@ -119,17 +120,17 @@ func (n *mockNSQD) handle(conn net.Conn) { l := make([]byte, 4) _, err := io.ReadFull(rdr, l) if err != nil { - log.Printf(err.Error()) + n.t.Log(err) goto exit } size := int32(binary.BigEndian.Uint32(l)) b := make([]byte, size) _, err = io.ReadFull(rdr, b) if err != nil { - log.Printf(err.Error()) + n.t.Log(err) goto exit } - log.Printf("%s", b) + n.t.Logf("%s", b) case bytes.Equal(params[0], []byte("RDY")): rdy, _ := strconv.Atoi(string(params[1])) rdyCount = rdy @@ -144,7 +145,7 @@ func (n *mockNSQD) handle(conn net.Conn) { } if inst.frameType == FrameTypeMessage { if rdyCount == 0 { - log.Printf("!!! RDY == 0") + n.t.Log("!!! RDY == 0") scriptTime = time.After(n.script[idx+1].delay) continue } @@ -152,7 +153,7 @@ func (n *mockNSQD) handle(conn net.Conn) { } _, err := conn.Write(framedResponse(inst.frameType, inst.body)) if err != nil { - log.Printf(err.Error()) + n.t.Log(err) goto exit } scriptTime = time.After(n.script[idx+1].delay) @@ -220,10 +221,10 @@ func TestConsumerBackoff(t *testing.T) { msgBad := NewMessage(msgIDBad, []byte("bad")) script := []instruction{ - // SUB - instruction{0, FrameTypeResponse, []byte("OK")}, // IDENTIFY instruction{0, FrameTypeResponse, []byte("OK")}, + // SUB + instruction{0, FrameTypeResponse, []byte("OK")}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, @@ -236,7 +237,7 @@ func TestConsumerBackoff(t *testing.T) { } addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0") - n := newMockNSQD(script, addr.String()) + n := newMockNSQD(t, script, addr.String()) topicName := "test_consumer_commands" + strconv.Itoa(int(time.Now().Unix())) config := NewConfig() @@ -253,7 +254,7 @@ func TestConsumerBackoff(t *testing.T) { <-n.exitChan for i, r := range n.got { - log.Printf("%d: %s", i, r) + t.Logf("%d: %s", i, r) } expected := []string{ @@ -263,7 +264,6 @@ func TestConsumerBackoff(t *testing.T) { fmt.Sprintf("FIN %s", msgIDGood), fmt.Sprintf("FIN %s", msgIDGood), fmt.Sprintf("FIN %s", msgIDGood), - "RDY 5", "RDY 0", fmt.Sprintf("REQ %s 0", msgIDBad), "RDY 1", @@ -296,10 +296,10 @@ func TestConsumerRequeueNoBackoff(t *testing.T) { msgRequeueNoBackoff := NewMessage(msgIDRequeueNoBackoff, []byte("requeue_no_backoff_1")) script := []instruction{ - // SUB - instruction{0, FrameTypeResponse, []byte("OK")}, // IDENTIFY instruction{0, FrameTypeResponse, []byte("OK")}, + // SUB + instruction{0, FrameTypeResponse, []byte("OK")}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgRequeue)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgRequeueNoBackoff)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, @@ -308,7 +308,7 @@ func TestConsumerRequeueNoBackoff(t *testing.T) { } addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0") - n := newMockNSQD(script, addr.String()) + n := newMockNSQD(t, script, addr.String()) topicName := "test_requeue" + strconv.Itoa(int(time.Now().Unix())) config := NewConfig() @@ -324,20 +324,19 @@ func TestConsumerRequeueNoBackoff(t *testing.T) { select { case <-n.exitChan: - log.Printf("clean exit") + t.Log("clean exit") case <-time.After(500 * time.Millisecond): - log.Printf("timeout") + t.Log("timeout") } for i, r := range n.got { - log.Printf("%d: %s", i, r) + t.Logf("%d: %s", i, r) } expected := []string{ "IDENTIFY", "SUB " + topicName + " ch", "RDY 1", - "RDY 1", "RDY 0", fmt.Sprintf("REQ %s 0", msgIDRequeue), "RDY 1", @@ -365,10 +364,10 @@ func TestConsumerBackoffDisconnect(t *testing.T) { msgRequeue := NewMessage(msgIDRequeue, []byte("requeue")) script := []instruction{ - // SUB - instruction{0, FrameTypeResponse, []byte("OK")}, // IDENTIFY instruction{0, FrameTypeResponse, []byte("OK")}, + // SUB + instruction{0, FrameTypeResponse, []byte("OK")}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgRequeue)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgRequeue)}, @@ -378,7 +377,7 @@ func TestConsumerBackoffDisconnect(t *testing.T) { } addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0") - n := newMockNSQD(script, addr.String()) + n := newMockNSQD(t, script, addr.String()) topicName := "test_requeue" + strconv.Itoa(int(time.Now().Unix())) config := NewConfig() @@ -396,13 +395,13 @@ func TestConsumerBackoffDisconnect(t *testing.T) { select { case <-n.exitChan: - log.Printf("clean exit") + t.Log("clean exit") case <-time.After(500 * time.Millisecond): - log.Printf("timeout") + t.Log("timeout") } for i, r := range n.got { - log.Printf("%d: %s", i, r) + t.Logf("%d: %s", i, r) } expected := []string{ @@ -430,27 +429,27 @@ func TestConsumerBackoffDisconnect(t *testing.T) { } script = []instruction{ - // SUB - instruction{0, FrameTypeResponse, []byte("OK")}, // IDENTIFY instruction{0, FrameTypeResponse, []byte("OK")}, + // SUB + instruction{0, FrameTypeResponse, []byte("OK")}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, // needed to exit test instruction{100 * time.Millisecond, -1, []byte("exit")}, } - n = newMockNSQD(script, n.tcpAddr.String()) + n = newMockNSQD(t, script, n.tcpAddr.String()) select { case <-n.exitChan: - log.Printf("clean exit") + t.Log("clean exit") case <-time.After(500 * time.Millisecond): - log.Printf("timeout") + t.Log("timeout") } for i, r := range n.got { - log.Printf("%d: %s", i, r) + t.Logf("%d: %s", i, r) } expected = []string{ @@ -470,3 +469,73 @@ func TestConsumerBackoffDisconnect(t *testing.T) { } } } + +func TestConsumerPause(t *testing.T) { + msgIDGood := MessageID{'1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 's', 'd', 'f', 'g', 'h'} + + msgGood := NewMessage(msgIDGood, []byte("good")) + + script := []instruction{ + // IDENTIFY + instruction{0, FrameTypeResponse, []byte("OK")}, + // SUB + instruction{0, FrameTypeResponse, []byte("OK")}, + instruction{20 * time.Millisecond, FrameTypeMessage, frameMessage(msgGood)}, + // needed to exit test + instruction{200 * time.Millisecond, -1, []byte("exit")}, + } + + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + n := newMockNSQD(t, script, addr.String()) + + topicName := "test_pause" + strconv.Itoa(int(time.Now().Unix())) + config := NewConfig() + config.MaxInFlight = 5 + q, _ := NewConsumer(topicName, "ch", config) + q.SetLogger(newTestLogger(t), LogLevelDebug) + q.AddHandler(&testHandler{}) + err := q.ConnectToNSQD(n.tcpAddr.String()) + if err != nil { + t.Fatalf(err.Error()) + } + + timeoutCh := time.After(500 * time.Millisecond) + pauseCh := time.After(50 * time.Millisecond) + unpauseCh := time.After(75 * time.Millisecond) + for { + select { + case <-n.exitChan: + t.Log("clean exit") + goto done + case <-timeoutCh: + t.Log("timeout") + goto done + case <-pauseCh: + q.ChangeMaxInFlight(0) + case <-unpauseCh: + q.ChangeMaxInFlight(config.MaxInFlight) + } + } +done: + + for i, r := range n.got { + t.Logf("%d: %s", i, r) + } + + expected := []string{ + "IDENTIFY", + "SUB " + topicName + " ch", + "RDY 5", + fmt.Sprintf("FIN %s", msgIDGood), + "RDY 0", + "RDY 5", + } + if len(n.got) != len(expected) { + t.Fatalf("we got %d commands != %d expected", len(n.got), len(expected)) + } + for i, r := range n.got { + if string(r) != expected[i] { + t.Fatalf("cmd %d bad %s != %s", i, r, expected[i]) + } + } +}