[KafkaIO] Update tracker and watermark for non-visible progress #34202

Merged
@@ -21,6 +21,7 @@

import java.math.BigDecimal;
import java.math.MathContext;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -140,8 +141,8 @@
* {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically
* by querying Kafka {@link Consumer} APIs. Please note that stopping reading may not happen as soon
* as the {@link TopicPartition} is removed. For example, the removal could happen at the same time
* when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that
* case, the {@link ReadFromKafkaDoFn} will still output the fetched records.
* when {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(Duration)}. In that case, the
* {@link ReadFromKafkaDoFn} will still output the fetched records.
*
* <h4>Stop Reading from Stopped {@link TopicPartition}</h4>
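
For reference, a hedged sketch of the partition-existence check the Javadoc above alludes to; the class and method names are illustrative and not part of this PR. A TopicPartition counts as removed once Consumer#partitionsFor(String) no longer lists its partition number.

import java.util.List;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;

final class PartitionExistenceSketch {
  // Hypothetical helper mirroring the behavior described above: the partition is
  // considered removed when no PartitionInfo with its partition number remains.
  static boolean partitionStillExists(
      Consumer<byte[], byte[]> consumer, TopicPartition topicPartition) {
    List<PartitionInfo> partitions = consumer.partitionsFor(topicPartition.topic());
    return partitions != null
        && partitions.stream()
            .anyMatch(info -> info.partition() == topicPartition.partition());
  }
}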
*
@@ -199,11 +200,11 @@ private ReadFromKafkaDoFn(
this.checkStopReadingFn = transform.getCheckStopReadingFn();
this.badRecordRouter = transform.getBadRecordRouter();
this.recordTag = recordTag;
if (transform.getConsumerPollingTimeout() > 0) {
this.consumerPollingTimeout = transform.getConsumerPollingTimeout();
} else {
this.consumerPollingTimeout = DEFAULT_KAFKA_POLL_TIMEOUT;
}
this.consumerPollingTimeout =
Duration.ofSeconds(
transform.getConsumerPollingTimeout() > 0
? transform.getConsumerPollingTimeout()
: DEFAULT_KAFKA_POLL_TIMEOUT);
}

private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class);
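
For context, the polling timeout stays configured in whole seconds on the transform but is now stored as a java.time.Duration, with values of zero or below falling back to the 2-second default. A minimal, assumed usage sketch (broker address and deserializers are placeholders):

import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.kafka.common.serialization.StringDeserializer;

final class PollingTimeoutUsageSketch {
  // Assumed usage: each consumer.poll(...) inside ReadFromKafkaDoFn then waits up
  // to Duration.ofSeconds(5) per call; passing a value <= 0 keeps the 2-second default.
  static KafkaIO.ReadSourceDescriptors<String, String> descriptors() {
    return KafkaIO.<String, String>readSourceDescriptors()
        .withBootstrapServers("localhost:9092") // placeholder broker address
        .withKeyDeserializer(StringDeserializer.class)
        .withValueDeserializer(StringDeserializer.class)
        .withConsumerPollingTimeout(5L);
  }
}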
@@ -248,7 +249,7 @@ private static final class SharedStateHolder {

private transient @Nullable LoadingCache<KafkaSourceDescriptor, MovingAvg> avgRecordSizeCache;
private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L;
@VisibleForTesting final long consumerPollingTimeout;
@VisibleForTesting final Duration consumerPollingTimeout;
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
@VisibleForTesting final DeserializerProvider<V> valueDeserializerProvider;
@VisibleForTesting final Map<String, Object> consumerConfig;
@@ -443,19 +444,27 @@ public ProcessContinuation processElement(
long startOffset = tracker.currentRestriction().getFrom();
long expectedOffset = startOffset;
consumer.seek(kafkaSourceDescriptor.getTopicPartition(), startOffset);
ConsumerRecords<byte[], byte[]> rawRecords = ConsumerRecords.empty();
long skippedRecords = 0L;
final Stopwatch sw = Stopwatch.createStarted();

KafkaMetrics kafkaMetrics = KafkaSinkMetrics.kafkaMetrics();
final KafkaMetrics kafkaMetrics = KafkaSinkMetrics.kafkaMetrics();
try {
while (true) {
// Fetch the record size accumulator.
final MovingAvg avgRecordSize = avgRecordSizeCache.getUnchecked(kafkaSourceDescriptor);
rawRecords = poll(consumer, kafkaSourceDescriptor.getTopicPartition(), kafkaMetrics);
// When there are no records available for the current TopicPartition, self-checkpoint
// and move to process the next element.
if (rawRecords.isEmpty()) {
// TODO: Remove this timer and use the existing fetch-latency-avg metric.
// A consumer will often have prefetches waiting to be returned immediately in which case
// this timer may contribute more latency than it measures.
// See https://shipilev.net/blog/2014/nanotrusting-nanotime/ for more information.
final Stopwatch pollTimer = Stopwatch.createStarted();
// Fetch the next records.
final ConsumerRecords<byte[], byte[]> rawRecords =
consumer.poll(this.consumerPollingTimeout);
kafkaMetrics.updateSuccessfulRpcMetrics(topicPartition.topic(), pollTimer.elapsed());

// No progress when the polling timeout expired.
// Self-checkpoint and move to process the next element.
if (rawRecords == ConsumerRecords.<byte[], byte[]>empty()) {
if (!topicPartitionExists(
kafkaSourceDescriptor.getTopicPartition(),
consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) {
@@ -466,6 +475,9 @@ public ProcessContinuation processElement(
}
return ProcessContinuation.resume();
}

// Visible progress within the consumer polling timeout.
// Partially or fully claim and process records in this batch.
for (ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
// If the Kafka consumer returns a record with an offset that is already processed
// the record can be safely skipped. This is needed because there is a possibility
@@ -500,6 +512,7 @@ public ProcessContinuation processElement(
if (!tracker.tryClaim(rawRecord.offset())) {
return ProcessContinuation.stop();
}
expectedOffset = rawRecord.offset() + 1;
try {
KafkaRecord<K, V> kafkaRecord =
new KafkaRecord<>(
@@ -516,7 +529,6 @@ public ProcessContinuation processElement(
+ (rawRecord.value() == null ? 0 : rawRecord.value().length);
avgRecordSize.update(recordSize);
rawSizes.update(recordSize);
expectedOffset = rawRecord.offset() + 1;
Instant outputTimestamp;
// The outputTimestamp and watermark will be computed by timestampPolicy, where the
// WatermarkEstimator should be a manual one.
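
As a rough illustration of the "manual" watermark mentioned in the comment above, the sketch below (an assumed shape, not the DoFn's exact code) has the timestamp policy compute the watermark and a ManualWatermarkEstimator record it directly, rather than deriving it from output timestamps:

import org.apache.beam.sdk.io.kafka.TimestampPolicy;
import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
import org.joda.time.Instant;

final class ManualWatermarkSketch {
  // Sketch: ask the timestamp policy for the partition's watermark and set it on
  // the manual estimator; bounds checking and null handling are omitted.
  static <K, V> void advanceWatermark(
      TimestampPolicy<K, V> timestampPolicy,
      TimestampPolicy.PartitionContext context,
      WatermarkEstimator<Instant> watermarkEstimator) {
    Instant watermark = timestampPolicy.getWatermark(context);
    ((ManualWatermarkEstimator<Instant>) watermarkEstimator).setWatermark(watermark);
  }
}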
@@ -546,6 +558,17 @@
}
}

// Non-visible progress within the consumer polling timeout.
// Claim up to the current position.
if (expectedOffset < (expectedOffset = consumer.position(topicPartition))) {
if (!tracker.tryClaim(expectedOffset - 1)) {
return ProcessContinuation.stop();
}
if (timestampPolicy != null) {
updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
}
}

backlogBytes.set(
(long)
(BigDecimal.valueOf(
@@ -578,36 +601,6 @@ private boolean topicPartitionExists(
.anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition()));
}

// see https://github.com/apache/beam/issues/25962
private ConsumerRecords<byte[], byte[]> poll(
Consumer<byte[], byte[]> consumer, TopicPartition topicPartition, KafkaMetrics kafkaMetrics) {
final Stopwatch sw = Stopwatch.createStarted();
long previousPosition = -1;
java.time.Duration timeout = java.time.Duration.ofSeconds(this.consumerPollingTimeout);
java.time.Duration elapsed = java.time.Duration.ZERO;
while (true) {
final ConsumerRecords<byte[], byte[]> rawRecords = consumer.poll(timeout.minus(elapsed));
elapsed = sw.elapsed();
kafkaMetrics.updateSuccessfulRpcMetrics(
topicPartition.topic(), java.time.Duration.ofMillis(elapsed.toMillis()));
if (!rawRecords.isEmpty()) {
// return as we have found some entries
return rawRecords;
}
if (previousPosition == (previousPosition = consumer.position(topicPartition))) {
// there was no progress on the offset/position, which indicates end of stream
return rawRecords;
}
if (elapsed.toMillis() >= timeout.toMillis()) {
// timeout is over
LOG.warn(
"No messages retrieved with polling timeout {} seconds. Consider increasing the consumer polling timeout using withConsumerPollingTimeout method.",
consumerPollingTimeout);
return rawRecords;
}
}
}

private TimestampPolicyContext updateWatermarkManually(
TimestampPolicy<K, V> timestampPolicy,
WatermarkEstimator<Instant> watermarkEstimator,
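
The non-visible-progress block added above relies on an embedded assignment, expectedOffset < (expectedOffset = consumer.position(topicPartition)), which compares the previous value against the freshly read position. A hedged, expanded sketch of the same guard (class, method, and parameter names are illustrative; the surrounding DoFn state is assumed):

import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.common.TopicPartition;

final class NonVisibleProgressSketch {
  // Returns the consumer's current position as the new expected offset, or -1
  // when the claim was rejected and the caller must stop processing.
  static long claimNonVisibleProgress(
      Consumer<byte[], byte[]> consumer,
      TopicPartition topicPartition,
      RestrictionTracker<OffsetRange, Long> tracker,
      long expectedOffset) {
    long position = consumer.position(topicPartition);
    if (expectedOffset < position) {
      // Offsets were consumed without any visible records (for example transaction
      // markers or gaps from compaction); claiming up to the position keeps the
      // offset-range restriction, and with it the watermark, advancing.
      if (!tracker.tryClaim(position - 1)) {
        return -1L;
      }
    }
    return position;
  }
}

In the PR itself, a successful claim is followed by updateWatermarkManually(...) whenever a timestamp policy is configured, as shown in the hunk above.
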
@@ -717,14 +717,14 @@ public void testUnbounded() {
@Test
public void testConstructorWithPollTimeout() {
ReadSourceDescriptors<String, String> descriptors = makeReadSourceDescriptor(consumer);
// default poll timeout = 1 scond
// default poll timeout = 2 seconds
ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
Assert.assertEquals(2L, dofnInstance.consumerPollingTimeout);
Assert.assertEquals(Duration.ofSeconds(2L), dofnInstance.consumerPollingTimeout);
// updated timeout = 5 seconds
descriptors = descriptors.withConsumerPollingTimeout(5L);
ReadFromKafkaDoFn<String, String> dofnInstanceNew =
ReadFromKafkaDoFn.create(descriptors, RECORDS);
Assert.assertEquals(5L, dofnInstanceNew.consumerPollingTimeout);
Assert.assertEquals(Duration.ofSeconds(5L), dofnInstanceNew.consumerPollingTimeout);
}

private BoundednessVisitor testBoundedness(