Ensure watermark updates when position advances
sjvanrossum committed Mar 6, 2025
1 parent 275d39a commit a3cbc4f
Showing 2 changed files with 33 additions and 45 deletions.
ReadFromKafkaDoFn.java
@@ -21,6 +21,7 @@

import java.math.BigDecimal;
import java.math.MathContext;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -55,7 +56,6 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
@@ -140,8 +140,8 @@
* {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically
* by querying Kafka {@link Consumer} APIs. Please note that stopping reading may not happen as soon
* as the {@link TopicPartition} is removed. For example, the removal could happen at the same time
* when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that
* case, the {@link ReadFromKafkaDoFn} will still output the fetched records.
* when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(Duration)}. In that case, the
* {@link ReadFromKafkaDoFn} will still output the fetched records.
*
* <h4>Stop Reading from Stopped {@link TopicPartition}</h4>
*
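The stop-reading behavior described above can also be driven explicitly through the checkStopReadingFn wired into the constructor below. A minimal, hedged sketch of such a predicate, assuming the SerializableFunction<TopicPartition, Boolean> shape accepted by KafkaIO's withCheckStopReadingFn; the stopped-topic set is purely illustrative:

import java.util.Set;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.kafka.common.TopicPartition;

class StopReadingFromListedTopics implements SerializableFunction<TopicPartition, Boolean> {
  // Must hold a serializable set, e.g. a HashSet; illustrative only.
  private final Set<String> stoppedTopics;

  StopReadingFromListedTopics(Set<String> stoppedTopics) {
    this.stoppedTopics = stoppedTopics;
  }

  @Override
  public Boolean apply(TopicPartition topicPartition) {
    // Returning true asks ReadFromKafkaDoFn to stop reading from this partition.
    return stoppedTopics.contains(topicPartition.topic());
  }
}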
@@ -199,11 +199,11 @@ private ReadFromKafkaDoFn(
this.checkStopReadingFn = transform.getCheckStopReadingFn();
this.badRecordRouter = transform.getBadRecordRouter();
this.recordTag = recordTag;
if (transform.getConsumerPollingTimeout() > 0) {
this.consumerPollingTimeout = transform.getConsumerPollingTimeout();
} else {
this.consumerPollingTimeout = DEFAULT_KAFKA_POLL_TIMEOUT;
}
this.consumerPollingTimeout =
Duration.ofSeconds(
transform.getConsumerPollingTimeout() > 0
? transform.getConsumerPollingTimeout()
: DEFAULT_KAFKA_POLL_TIMEOUT);
}

private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
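For context, a hedged sketch of how this polling timeout is typically supplied, assuming KafkaIO's readSourceDescriptors() entry point; the bootstrap server and deserializer choices are placeholders. A non-positive value falls back to the two-second default resolved by the constructor above.

import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.kafka.common.serialization.StringDeserializer;

class PollingTimeoutConfigSketch {
  static KafkaIO.ReadSourceDescriptors<String, String> configure() {
    return KafkaIO.<String, String>readSourceDescriptors()
        .withBootstrapServers("localhost:9092") // placeholder address
        .withKeyDeserializer(StringDeserializer.class)
        .withValueDeserializer(StringDeserializer.class)
        // Interpreted as seconds and converted to a java.time.Duration above.
        .withConsumerPollingTimeout(5L);
  }
}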
@@ -248,7 +248,7 @@ private static final class SharedStateHolder {

private transient @Nullable LoadingCache<KafkaSourceDescriptor, MovingAvg> avgRecordSizeCache;
private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L;
@VisibleForTesting final long consumerPollingTimeout;
@VisibleForTesting final Duration consumerPollingTimeout;
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
@VisibleForTesting final DeserializerProvider<V> valueDeserializerProvider;
@VisibleForTesting final Map<String, Object> consumerConfig;
@@ -443,15 +443,17 @@ public ProcessContinuation processElement(
consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition()));
long expectedOffset = tracker.currentRestriction().getFrom();
consumer.seek(kafkaSourceDescriptor.getTopicPartition(), expectedOffset);
ConsumerRecords<byte[], byte[]> rawRecords = ConsumerRecords.empty();

while (true) {
// Fetch the record size accumulator.
final MovingAvg avgRecordSize = avgRecordSizeCache.getUnchecked(kafkaSourceDescriptor);
rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition());
// When there are no records available for the current TopicPartition, self-checkpoint
// and move to process the next element.
if (rawRecords.isEmpty()) {
// Fetch the next records.
final ConsumerRecords<byte[], byte[]> rawRecords =
consumer.poll(this.consumerPollingTimeout);

// No progress when the polling timeout expired.
// Self-checkpoint and move to process the next element.
if (rawRecords == ConsumerRecords.<byte[], byte[]>empty()) {
if (!topicPartitionExists(
kafkaSourceDescriptor.getTopicPartition(),
consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) {
@@ -462,6 +464,9 @@ public ProcessContinuation processElement(
}
return ProcessContinuation.resume();
}

// Visible progress within the consumer polling timeout.
// Partially or fully claim and process records in this batch.
for (ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
if (!tracker.tryClaim(rawRecord.offset())) {
return ProcessContinuation.stop();
@@ -512,6 +517,17 @@ public ProcessContinuation processElement(
}
}

// Non-visible progress within the consumer polling timeout.
// Claim up to the current position.
if (expectedOffset < (expectedOffset = consumer.position(topicPartition))) {
if (!tracker.tryClaim(expectedOffset - 1)) {
return ProcessContinuation.stop();
}
if (timestampPolicy != null) {
updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
}
}

backlogBytes.set(
(long)
(BigDecimal.valueOf(
@@ -531,34 +547,6 @@ private boolean topicPartitionExists(
.anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition()));
}

// see https://github.com/apache/beam/issues/25962
private ConsumerRecords<byte[], byte[]> poll(
Consumer<byte[], byte[]> consumer, TopicPartition topicPartition) {
final Stopwatch sw = Stopwatch.createStarted();
long previousPosition = -1;
java.time.Duration elapsed = java.time.Duration.ZERO;
java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout);
while (true) {
final ConsumerRecords<byte[], byte[]> rawRecords = consumer.poll(timeout.minus(elapsed));
if (!rawRecords.isEmpty()) {
// return as we have found some entries
return rawRecords;
}
if (previousPosition == (previousPosition = consumer.position(topicPartition))) {
// there was no progress on the offset/position, which indicates end of stream
return rawRecords;
}
elapsed = sw.elapsed();
if (elapsed.toMillis() >= timeout.toMillis()) {
// timeout is over
LOG.warn(
"No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.",
consumerPollingTimeout);
return rawRecords;
}
}
}

private TimestampPolicyContext updateWatermarkManually(
TimestampPolicy<K, V> timestampPolicy,
WatermarkEstimator<Instant> watermarkEstimator,
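The claim-up-to-position block added above is the core of this commit: a poll can return no visible records while the consumer's position still advances, for example past transaction control markers or records dropped by compaction, and previously the offset tracker and watermark would stall on such gaps. Below is a minimal, self-contained sketch of that pattern; the OffsetClaimer interface is a hypothetical stand-in for Beam's restriction tracker, and the watermark update is left to the caller.

import java.time.Duration;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.common.TopicPartition;

final class ClaimToPositionSketch {
  /** Hypothetical stand-in for Beam's restriction tracker; not the real API. */
  interface OffsetClaimer {
    boolean tryClaim(long offset);
  }

  /**
   * Polls once and returns the next expected offset. If no records came back but the
   * consumer's position moved past expectedOffset, the gap is claimed so the caller
   * can advance its watermark as well.
   */
  static long pollAndClaim(
      Consumer<byte[], byte[]> consumer,
      TopicPartition topicPartition,
      long expectedOffset,
      OffsetClaimer claimer,
      Duration pollTimeout) {
    ConsumerRecords<byte[], byte[]> records = consumer.poll(pollTimeout);
    for (ConsumerRecord<byte[], byte[]> record : records) {
      if (!claimer.tryClaim(record.offset())) {
        return expectedOffset; // Restriction exhausted; caller should stop.
      }
      expectedOffset = record.offset() + 1;
    }
    long position = consumer.position(topicPartition);
    if (expectedOffset < position && claimer.tryClaim(position - 1)) {
      // Non-visible progress: offsets were consumed without any visible records.
      expectedOffset = position;
    }
    return expectedOffset;
  }
}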
ReadFromKafkaDoFnTest.java
@@ -715,14 +715,14 @@ public void testUnbounded() {
@Test
public void testConstructorWithPollTimeout() {
ReadSourceDescriptors<String, String> descriptors = makeReadSourceDescriptor(consumer);
// default poll timeout = 1 scond
// default poll timeout = 2 seconds
ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
Assert.assertEquals(2L, dofnInstance.consumerPollingTimeout);
Assert.assertEquals(Duration.ofSeconds(2L), dofnInstance.consumerPollingTimeout);
// updated timeout = 5 seconds
descriptors = descriptors.withConsumerPollingTimeout(5L);
ReadFromKafkaDoFn<String, String> dofnInstanceNew =
ReadFromKafkaDoFn.create(descriptors, RECORDS);
Assert.assertEquals(5L, dofnInstanceNew.consumerPollingTimeout);
Assert.assertEquals(Duration.ofSeconds(5L), dofnInstanceNew.consumerPollingTimeout);
}

private BoundednessVisitor testBoundedness(
