Skip to content

Commit ae417e6

Browse files
authored
fix(spark): potential hang with skipped segments on overlapping decompression (#2745)
### What changes were proposed in this pull request? Fixes a regression introduced by PR #2735. ### Why are the changes needed? To fix a potential hang when segments are skipped during overlapping decompression. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests.
1 parent b80940d commit ae417e6

4 files changed

Lines changed: 74 additions & 3 deletions

File tree

client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,28 @@ public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
154154
}
155155

156156
public DecompressedShuffleBlock get(int batchIndex, int segmentIndex) {
157+
// guardedly safe to remove the previous batches if exist since the upstream will fetch the
158+
// segments in order
159+
for (int i = 0; i < batchIndex; i++) {
160+
ConcurrentHashMap<Integer, DecompressedShuffleBlock> prevBlocks = tasks.remove(i);
161+
if (prevBlocks != null) {
162+
segmentPermits.ifPresent(x -> x.release(prevBlocks.values().size()));
163+
}
164+
}
165+
157166
ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks = tasks.get(batchIndex);
158167
if (blocks == null) {
159168
return null;
160169
}
170+
171+
// guardedly safe to remove the previous segments if exist since the upstream will fetch the
172+
// segments in order
173+
for (int i = 0; i < segmentIndex; i++) {
174+
if (blocks.remove(i) != null) {
175+
segmentPermits.ifPresent(x -> x.release());
176+
}
177+
}
178+
161179
DecompressedShuffleBlock block = blocks.remove(segmentIndex);
162180
// simplify the memory statistic logic here, just decrease the memory used when the block is
163181
// fetched, this is effective due to the upstream will use single-thread to get and release the

client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,16 @@ public ShuffleBlock readShuffleBlockData() {
314314
// mark block as processed
315315
processedBlockIds.add(bs.getBlockId());
316316
pendingBlockIds.removeLong(bs.getBlockId());
317-
// update the segment index to skip the unnecessary block in overlapping decompression mode
318-
segmentIndex += 1;
317+
318+
// update the segment index to skip the unnecessary block in overlapping decompression mode.
319+
// In overlapping decompression mode, decompression tasks for the whole batch have already
320+
// been submitted. If we skip a segment without removing the corresponding handler, the
321+
// backpressure permits may never be released, which can block subsequent decompression.
322+
if (decompressionWorker != null) {
323+
decompressionWorker.get(batchIndex - 1, segmentIndex++);
324+
} else {
325+
segmentIndex += 1;
326+
}
319327
}
320328

321329
if (bs != null) {

client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
3636
import static org.junit.jupiter.api.Assertions.assertEquals;
3737
import static org.junit.jupiter.api.Assertions.assertNull;
38+
import static org.junit.jupiter.api.Assertions.assertTrue;
3839

3940
public class DecompressionWorkerTest {
4041

@@ -66,7 +67,9 @@ public void testBackpressure() throws Exception {
6667
}
6768
Thread.sleep(10);
6869
worker.get(0, maxSegments).getByteBuffer();
69-
assertEquals(1024 * maxSegments, worker.getPeekMemoryUsed());
70+
// Peak memory is a runtime metric and may include one additional segment due to thread timing.
71+
assertTrue(worker.getPeekMemoryUsed() <= 1024L * (maxSegments + 1));
72+
assertTrue(worker.getPeekMemoryUsed() >= 1024L * maxSegments);
7073
assertEquals(maxSegments, worker.getAvailablePermits());
7174
}
7275

client/src/test/java/org/apache/uniffle/client/impl/ShuffleReadClientImplTest.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
5353
import org.apache.uniffle.storage.util.StorageType;
5454

55+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
5556
import static org.junit.jupiter.api.Assertions.assertEquals;
5657
import static org.junit.jupiter.api.Assertions.assertNotEquals;
5758
import static org.junit.jupiter.api.Assertions.assertNull;
@@ -769,6 +770,47 @@ public void readTest16(Supplier<ShuffleClientFactory.ReadClientBuilder> builderS
769770
readClient.close();
770771
}
771772

773+
@ParameterizedTest
774+
@MethodSource("clientBuilderProvider")
775+
public void readTestSkipBlocksWithBackpressureDoesNotHang(
776+
Supplier<ShuffleClientFactory.ReadClientBuilder> builderSupplier) throws Exception {
777+
// This test is meaningful only when overlapping decompression is enabled.
778+
// For non-overlapping mode, it should still pass and act as a regression guard.
779+
String basePath = uniq(HDFS_URI + "clientReadTestSkipBlocksWithBackpressureDoesNotHang");
780+
HadoopShuffleWriteHandler writeHandler =
781+
new HadoopShuffleWriteHandler("appId", 0, 1, 1, basePath, ssi1.getId(), conf);
782+
783+
Map<Long, byte[]> expectedData = Maps.newHashMap();
784+
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
785+
786+
// Write skipped blocks first to increase the chance of exhausting permits if permits are not
787+
// released when skipping.
788+
writeTestData(writeHandler, 20, 30, 1, 2, Maps.newHashMap(), blockIdBitmap);
789+
writeTestData(writeHandler, 5, 30, 1, 0, expectedData, blockIdBitmap);
790+
791+
RssConf rssConf = new RssConf();
792+
// Provide required base configs to avoid reader treating this as "prod mode" with empty values.
793+
rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
794+
rssConf.setInteger(RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD, 1);
795+
rssConf.setInteger(RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS, 1);
796+
797+
// Expect only taskAttemptId=0 blocks; taskAttemptId=2 blocks will be skipped by the reader.
798+
Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
799+
ShuffleReadClientImpl readClient =
800+
builderSupplier
801+
.get()
802+
.partitionId(1)
803+
.basePath(basePath)
804+
.blockIdBitmap(blockIdBitmap)
805+
.taskIdBitmap(taskIdBitmap)
806+
.rssConf(rssConf)
807+
.build();
808+
809+
assertDoesNotThrow(() -> TestUtils.validateResult(readClient, expectedData));
810+
readClient.checkProcessedBlockIds();
811+
readClient.close();
812+
}
813+
772814
private void writeTestData(
773815
HadoopShuffleWriteHandler writeHandler,
774816
int num,

0 commit comments

Comments
 (0)