Skip to content

Commit

Permalink
Fix wrong result of hasNext after seeking by id or time
Browse files Browse the repository at this point in the history
  • Loading branch information
shibd committed Feb 26, 2025
1 parent 4e71a47 commit c6f09e7
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 41 deletions.
1 change: 1 addition & 0 deletions pulsar/consumer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func newConsumer(client *client, options ConsumerOptions) (Consumer, error) {

if options.EnableZeroQueueConsumer {
options.ReceiverQueueSize = 0
options.StartMessageIDInclusive = true
}

if options.Interceptors == nil {
Expand Down
97 changes: 74 additions & 23 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,13 @@ type partitionConsumer struct {
backoffPolicyFunc func() backoff.Policy

dispatcherSeekingControlCh chan struct{}
isSeeking atomic.Bool
ctx context.Context
cancelFunc context.CancelFunc
// handle to the dispatcher goroutine
isSeeking atomic.Bool
// After executing seekByTime, the client is unaware of the startMessageId.
// Use this flag to compare markDeletePosition with BrokerLastMessageId when checking hasMoreMessages.
hasSoughtByTime atomic.Bool
ctx context.Context
cancelFunc context.CancelFunc
}

// pauseDispatchMessage used to discard the message in the dispatcher goroutine.
Expand Down Expand Up @@ -429,11 +433,12 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon

startingMessageID := pc.startMessageID.get()
if pc.options.startMessageIDInclusive && startingMessageID != nil && startingMessageID.equal(latestMessageID) {
msgID, err := pc.requestGetLastMessageID()
msgIDResp, err := pc.requestGetLastMessageID()
if err != nil {
pc.Close()
return nil, err
}
msgID := convertToMessageID(msgIDResp.GetLastMessageId())
if msgID.entryID != noMessageEntry {
pc.startMessageID.set(msgID)

Expand Down Expand Up @@ -616,18 +621,27 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
}

// getLastMessageID returns only the last message ID reported by the broker.
// It delegates to getLastMessageIDAndMarkDeletePosition and drops the
// mark-delete position from that result; any error is passed through unchanged.
func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) {
	result, err := pc.getLastMessageIDAndMarkDeletePosition()
	if err != nil {
		return nil, err
	}
	// err is known to be nil here, so return it explicitly as nil.
	return result.msgID, nil
}

func (pc *partitionConsumer) getLastMessageIDAndMarkDeletePosition() (*getLastMsgResult, error) {
if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
pc.log.WithField("state", state).Error("Failed to getLastMessageID for the closing or closed consumer")
return nil, errors.New("failed to getLastMessageID for the closing or closed consumer")
}
bo := pc.backoffPolicyFunc()
request := func() (*trackingMessageID, error) {
request := func() (*getLastMsgResult, error) {
req := &getLastMsgIDRequest{doneCh: make(chan struct{})}
pc.eventsCh <- req

// wait for the request to complete
<-req.doneCh
return req.msgID, req.err
res := &getLastMsgResult{req.msgID, req.markDeletePosition}
return res, req.err
}

ctx, cancel := context.WithTimeout(context.Background(), pc.client.operationTimeout)
Expand All @@ -647,10 +661,16 @@ func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) {

func (pc *partitionConsumer) internalGetLastMessageID(req *getLastMsgIDRequest) {
defer close(req.doneCh)
req.msgID, req.err = pc.requestGetLastMessageID()
rsp, err := pc.requestGetLastMessageID()
if err != nil {
req.err = err
return
}
req.msgID = convertToMessageID(rsp.GetLastMessageId())
req.markDeletePosition = convertToMessageID(rsp.GetConsumerMarkDeletePosition())
}

func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID, error) {
func (pc *partitionConsumer) requestGetLastMessageID() (*pb.CommandGetLastMessageIdResponse, error) {
if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
pc.log.WithField("state", state).Error("Failed to getLastMessageID closing or closed consumer")
return nil, errors.New("failed to getLastMessageID closing or closed consumer")
Expand All @@ -667,8 +687,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID, erro
pc.log.WithError(err).Error("Failed to get last message id")
return nil, err
}
id := res.Response.GetLastMessageIdResponse.GetLastMessageId()
return convertToMessageID(id), nil
return res.Response.GetLastMessageIdResponse, nil
}

func (pc *partitionConsumer) sendIndividualAck(msgID MessageID) *ackRequest {
Expand Down Expand Up @@ -997,7 +1016,15 @@ func (pc *partitionConsumer) requestSeek(msgID *messageID) error {
if err := pc.requestSeekWithoutClear(msgID); err != nil {
return err
}
pc.clearReceiverQueue()
// When the seek operation is successful, it indicates:
// 1. The broker has reset the cursor and sent a request to close the consumer on the client side.
// Since this method is in the same goroutine as the reconnectToBroker,
// we can safely clear the messages in the queue (at this point, it won't contain messages after the seek).
// 2. The startMessageID is reset to ensure accurate judgment when calling hasNext next time.
// Since the messages in the queue are cleared here, reconnection won't reset startMessageID.
pc.lastDequeuedMsg = nil
pc.startMessageID.set(toTrackingMessageID(msgID))
pc.clearQueueAndGetNextMessage()
return nil
}

Expand Down Expand Up @@ -1069,7 +1096,9 @@ func (pc *partitionConsumer) internalSeekByTime(seek *seekByTimeRequest) {
seek.err = err
return
}
pc.clearReceiverQueue()
pc.lastDequeuedMsg = nil
pc.hasSoughtByTime.Store(true)
pc.clearQueueAndGetNextMessage()
}

func (pc *partitionConsumer) internalAck(req *ackRequest) {
Expand Down Expand Up @@ -1451,10 +1480,6 @@ func (pc *partitionConsumer) messageShouldBeDiscarded(msgID *trackingMessageID)
if pc.startMessageID.get() == nil {
return false
}
// if we start at latest message, we should never discard
if pc.options.startMessageID != nil && pc.options.startMessageID.equal(latestMessageID) {
return false
}

if pc.options.startMessageIDInclusive {
return pc.startMessageID.get().greater(msgID.messageID)
Expand Down Expand Up @@ -1709,9 +1734,15 @@ type redeliveryRequest struct {
}

type getLastMsgIDRequest struct {
doneCh chan struct{}
msgID *trackingMessageID
err error
doneCh chan struct{}
msgID *trackingMessageID
markDeletePosition *trackingMessageID
err error
}

// getLastMsgResult carries the two positions produced by a single
// get-last-message-id round trip, as returned by
// getLastMessageIDAndMarkDeletePosition.
//
// NOTE: the struct is built with a positional composite literal, so the
// order and number of fields must not change without updating that literal.
type getLastMsgResult struct {
// msgID is the last message ID taken from the broker response.
msgID *trackingMessageID
// markDeletePosition is the consumer's mark-delete position taken from the
// same broker response.
markDeletePosition *trackingMessageID
}

type seekRequest struct {
Expand Down Expand Up @@ -2195,6 +2226,24 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
}

func (pc *partitionConsumer) hasNext() bool {

// If a seek by time has been performed, then the `startMessageId` becomes irrelevant.
// We need to compare `markDeletePosition` and `lastMessageId`,
// and then reset `startMessageID` to `markDeletePosition`.
if pc.hasSoughtByTime.CompareAndSwap(true, false) {
res, err := pc.getLastMessageIDAndMarkDeletePosition()
if err != nil {
pc.log.WithError(err).Error("Failed to get last message id")
return false
}
pc.lastMessageInBroker = res.msgID
pc.startMessageID.set(res.markDeletePosition)
// We only care about comparing ledger ids and entry ids as mark delete position
// doesn't have other ids such as batch index
compareResult := pc.lastMessageInBroker.messageID.compareLedgerAndEntryID(pc.startMessageID.get().messageID)
return compareResult > 0 || (pc.options.startMessageIDInclusive && compareResult == 0)
}

if pc.lastMessageInBroker != nil && pc.hasMoreMessages() {
return true
}
Expand Down Expand Up @@ -2256,12 +2305,14 @@ func convertToMessageID(id *pb.MessageIdData) *trackingMessageID {

msgID := &trackingMessageID{
messageID: &messageID{
ledgerID: int64(*id.LedgerId),
entryID: int64(*id.EntryId),
ledgerID: int64(id.GetLedgerId()),
entryID: int64(id.GetEntryId()),
batchIdx: -1,
batchSize: id.GetBatchSize(),
},
}
if id.BatchIndex != nil {
msgID.batchIdx = *id.BatchIndex
if id.GetBatchSize() > 1 {
msgID.batchIdx = id.GetBatchIndex()
}

return msgID
Expand Down
5 changes: 3 additions & 2 deletions pulsar/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1262,8 +1262,9 @@ func TestConsumerSeek(t *testing.T) {
defer producer.Close()

consumer, err := client.Subscribe(ConsumerOptions{
Topic: topicName,
SubscriptionName: "sub-1",
Topic: topicName,
SubscriptionName: "sub-1",
StartMessageIDInclusive: true,
})
assert.Nil(t, err)
defer consumer.Close()
Expand Down
14 changes: 14 additions & 0 deletions pulsar/impl_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package pulsar

import (
"cmp"
"errors"
"fmt"
"math"
Expand Down Expand Up @@ -147,6 +148,13 @@ func (id *messageID) equal(other *messageID) bool {
id.batchIdx == other.batchIdx
}

// compareLedgerAndEntryID orders two message IDs by ledger ID first and by
// entry ID second, ignoring every other field (such as the batch index).
//
// It returns -1 if id sorts before other, +1 if it sorts after, and 0 when
// both the ledger ID and the entry ID are equal.
func (id *messageID) compareLedgerAndEntryID(other *messageID) int {
	switch {
	case id.ledgerID != other.ledgerID:
		// Different ledgers: the ledger ID alone decides the ordering.
		return cmp.Compare(id.ledgerID, other.ledgerID)
	default:
		// Same ledger: fall back to the entry ID.
		return cmp.Compare(id.entryID, other.entryID)
	}
}

// greaterEqual reports whether id is equal to other or greater than it.
// Equality is checked first; greater is consulted only when the two IDs
// are not equal.
func (id *messageID) greaterEqual(other *messageID) bool {
	if id.equal(other) {
		return true
	}
	return id.greater(other)
}
Expand Down Expand Up @@ -204,6 +212,9 @@ func deserializeMessageID(data []byte) (MessageID, error) {
}

func newMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32) MessageID {
if batchSize <= 1 {
batchIdx = -1
}
return &messageID{
ledgerID: ledgerID,
entryID: entryID,
Expand All @@ -225,6 +236,9 @@ func fromMessageID(msgID MessageID) *messageID {

func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32,
tracker *ackTracker) *trackingMessageID {
if batchSize <= 1 {
batchIdx = -1
}
return &trackingMessageID{
messageID: &messageID{
ledgerID: ledgerID,
Expand Down
22 changes: 6 additions & 16 deletions pulsar/reader_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,19 +196,6 @@ func (r *reader) Close() {
r.metrics.ReadersClosed.Inc()
}

func (r *reader) messageID(msgID MessageID) *trackingMessageID {
mid := toTrackingMessageID(msgID)

partition := int(mid.partitionIdx)
// did we receive a valid partition index?
if partition < 0 {
r.log.Warnf("invalid partition index %d expected", partition)
return nil
}

return mid
}

func (r *reader) Seek(msgID MessageID) error {
r.Lock()
defer r.Unlock()
Expand All @@ -218,9 +205,12 @@ func (r *reader) Seek(msgID MessageID) error {
return fmt.Errorf("invalid message id type %T", msgID)
}

mid := r.messageID(msgID)
if mid == nil {
return nil
mid := toTrackingMessageID(msgID)

partition := int(mid.partitionIdx)
if partition < 0 {
r.log.Warnf("invalid partition index %d expected", partition)
return fmt.Errorf("seek msgId must include partitoinIndex")
}

return r.c.Seek(mid)
Expand Down
Loading

0 comments on commit c6f09e7

Please sign in to comment.