From 432304f5b5eb52ddd640decf28e9dcc42bafbe55 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Mon, 17 Jun 2024 14:21:17 +0800 Subject: [PATCH 01/23] [#1796] fix(spark): Implicitly unregister map output on fetch failure --- .../spark/shuffle/RssSparkShuffleUtils.java | 16 ++++++++++++++++ .../shuffle/reader/RssFetchFailedIterator.java | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index b3763df32a..f2704b3d0f 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -33,6 +33,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.deploy.SparkHadoopUtil; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; @@ -44,7 +45,9 @@ import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.factory.CoordinatorClientFactory; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; +import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest; +import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ClientType; @@ -371,6 +374,19 @@ public static RssException reportRssFetchFailedException( rssFetchFailedException.getMessage()); RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); if (response.getReSubmitWholeStage()) { + TaskContext 
taskContext = TaskContext.get(); + RssReassignServersRequest rssReassignServersRequest = + new RssReassignServersRequest( + taskContext.stageId(), + taskContext.stageAttemptNumber(), + shuffleId, + taskContext.numPartitions()); + RssReassignServersResponse reassignServersResponse = + client.reassignShuffleServers(rssReassignServersRequest); + LOG.info( + "Reassign servers for stage retry due to the fetch failure, result: {}", + reassignServersResponse.isNeedReassign()); + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 // is provided. FetchFailedException ffe = diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java index c394f510bb..d3dab84456 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java @@ -24,6 +24,7 @@ import scala.collection.AbstractIterator; import scala.collection.Iterator; +import org.apache.spark.TaskContext; import org.apache.spark.shuffle.FetchFailedException; import org.apache.spark.shuffle.RssSparkShuffleUtils; import org.slf4j.Logger; @@ -31,7 +32,9 @@ import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; +import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest; +import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.exception.RssException; @@ -120,6 +123,19 @@ private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) { e.getMessage()); 
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); if (response.getReSubmitWholeStage()) { + TaskContext taskContext = TaskContext.get(); + RssReassignServersRequest rssReassignServersRequest = + new RssReassignServersRequest( + taskContext.stageId(), + taskContext.stageAttemptNumber(), + builder.shuffleId, + taskContext.numPartitions()); + RssReassignServersResponse reassignServersResponse = + client.reassignShuffleServers(rssReassignServersRequest); + LOG.info( + "Reassign servers for stage retry due to the fetch failure, result: {}", + reassignServersResponse.isNeedReassign()); + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is // provided. FetchFailedException ffe = From 78448d78371ad9c89ba806ee276f60c1363b4826 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 18 Jun 2024 17:43:43 +0800 Subject: [PATCH 02/23] refactor --- .../spark/shuffle/RssSparkShuffleUtils.java | 26 ++++------ .../reader/RssFetchFailedIterator.java | 26 ++++------ .../exception/RssFetchFailedException.java | 13 +++++ .../impl/grpc/ShuffleServerGrpcClient.java | 17 ++++--- .../grpc/ShuffleServerGrpcNettyClient.java | 11 ++-- .../RssReportShuffleFetchFailureRequest.java | 51 ++++++++++++++++++- proto/src/main/proto/Rss.proto | 7 ++- 7 files changed, 106 insertions(+), 45 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index f2704b3d0f..ca3bd0c93c 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -21,6 +21,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ 
-33,7 +34,9 @@ import org.apache.hadoop.conf.Configuration; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; +import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.deploy.SparkHadoopUtil; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; @@ -45,9 +48,7 @@ import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.factory.CoordinatorClientFactory; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; -import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest; -import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ClientType; @@ -363,6 +364,7 @@ public static RssException reportRssFetchFailedException( try (ShuffleManagerClient client = ShuffleManagerClientFactory.getInstance() .createShuffleManagerClient(ClientType.GRPC, driver, port)) { + TaskContext taskContext = TaskContext$.MODULE$.get(); // todo: Create a new rpc interface to report failures in batch. 
for (int partitionId : failedPartitions) { RssReportShuffleFetchFailureRequest req = @@ -371,22 +373,14 @@ public static RssException reportRssFetchFailedException( shuffleId, stageAttemptId, partitionId, - rssFetchFailedException.getMessage()); + rssFetchFailedException.getMessage(), + Collections.emptyList(), + taskContext.stageId(), + taskContext.taskAttemptId(), + taskContext.attemptNumber(), + SparkEnv.get().executorId()); RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); if (response.getReSubmitWholeStage()) { - TaskContext taskContext = TaskContext.get(); - RssReassignServersRequest rssReassignServersRequest = - new RssReassignServersRequest( - taskContext.stageId(), - taskContext.stageAttemptNumber(), - shuffleId, - taskContext.numPartitions()); - RssReassignServersResponse reassignServersResponse = - client.reassignShuffleServers(rssReassignServersRequest); - LOG.info( - "Reassign servers for stage retry due to the fetch failure, result: {}", - reassignServersResponse.isNeedReassign()); - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 // is provided. 
FetchFailedException ffe = diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java index d3dab84456..7d9e596a20 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java @@ -18,13 +18,16 @@ package org.apache.spark.shuffle.reader; import java.io.IOException; +import java.util.Collections; import java.util.Objects; import scala.Product2; import scala.collection.AbstractIterator; import scala.collection.Iterator; +import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.shuffle.FetchFailedException; import org.apache.spark.shuffle.RssSparkShuffleUtils; import org.slf4j.Logger; @@ -32,9 +35,7 @@ import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; -import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest; -import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.exception.RssException; @@ -114,28 +115,21 @@ private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) { int port = builder.reportServerPort; // todo: reuse this manager client if this is a bottleneck. 
try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) { + TaskContext taskContext = TaskContext$.MODULE$.get(); RssReportShuffleFetchFailureRequest req = new RssReportShuffleFetchFailureRequest( builder.appId, builder.shuffleId, builder.stageAttemptId, builder.partitionId, - e.getMessage()); + e.getMessage(), + Collections.emptyList(), + taskContext.stageId(), + taskContext.taskAttemptId(), + taskContext.attemptNumber(), + SparkEnv.get().executorId()); RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); if (response.getReSubmitWholeStage()) { - TaskContext taskContext = TaskContext.get(); - RssReassignServersRequest rssReassignServersRequest = - new RssReassignServersRequest( - taskContext.stageId(), - taskContext.stageAttemptNumber(), - builder.shuffleId, - taskContext.numPartitions()); - RssReassignServersResponse reassignServersResponse = - client.reassignShuffleServers(rssReassignServersRequest); - LOG.info( - "Reassign servers for stage retry due to the fetch failure, result: {}", - reassignServersResponse.isNeedReassign()); - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is // provided. FetchFailedException ffe = diff --git a/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java b/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java index 08c2b8101c..c61e6776dd 100644 --- a/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java +++ b/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java @@ -17,8 +17,17 @@ package org.apache.uniffle.common.exception; +import org.apache.uniffle.common.ShuffleServerInfo; + /** Dedicated exception for rss client's shuffle failed related exception. 
*/ public class RssFetchFailedException extends RssException { + private ShuffleServerInfo fetchFailureServerId; + + public RssFetchFailedException(String message, ShuffleServerInfo fetchFailureServerId) { + super(message); + this.fetchFailureServerId = fetchFailureServerId; + } + public RssFetchFailedException(String message) { super(message); } @@ -26,4 +35,8 @@ public RssFetchFailedException(String message) { public RssFetchFailedException(String message, Throwable e) { super(message, e); } + + public ShuffleServerInfo getFetchFailureServerId() { + return fetchFailureServerId; + } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 14dbf2f60f..9df4070654 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -66,6 +66,7 @@ import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.NotRetryException; @@ -827,12 +828,16 @@ public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest r + ", errorMsg:" + rpcResponse.getRetMsg(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } return response; } + private ShuffleServerInfo getShuffleServerInfo() { + return new ShuffleServerInfo(host, port); + } + @Override public RssGetShuffleResultResponse getShuffleResultForMultiPart( RssGetShuffleResultForMultiPartRequest request) { @@ -876,7 +881,7 @@ public 
RssGetShuffleResultResponse getShuffleResultForMultiPart( + ", errorMsg:" + rpcResponse.getRetMsg(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } return response; @@ -939,7 +944,7 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request + ", errorMsg:" + rpcResponse.getRetMsg(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } return response; } @@ -1002,7 +1007,7 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ + ", errorMsg:" + rpcResponse.getRetMsg(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } return response; } @@ -1078,7 +1083,7 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( + ", errorMsg:" + rpcResponse.getRetMsg(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } return response; } @@ -1101,7 +1106,7 @@ protected void waitOrThrow( request.getRetryMax(), System.currentTimeMillis() - start); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } try { long backoffTime = diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index a05d94b51d..bde877b0e4 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -37,6 +37,7 @@ import org.apache.uniffle.client.response.RssGetShuffleIndexResponse; import org.apache.uniffle.client.response.RssSendShuffleDataResponse; import 
org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.NotRetryException; @@ -304,10 +305,14 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( + ", errorMsg:" + getMemoryShuffleDataResponse.getRetMessage(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } } + private ShuffleServerInfo getShuffleServerInfo() { + return new ShuffleServerInfo(host, port, nettyPort); + } + @Override public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest request) { TransportClient transportClient = getTransportClient(); @@ -364,7 +369,7 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ + ", errorMsg:" + getLocalShuffleIndexResponse.getRetMessage(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } } @@ -425,7 +430,7 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request + ", errorMsg:" + getLocalShuffleDataResponse.getRetMessage(); LOG.error(msg); - throw new RssFetchFailedException(msg); + throw new RssFetchFailedException(msg, getShuffleServerInfo()); } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleFetchFailureRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleFetchFailureRequest.java index d9cea576a3..c3a93ae603 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleFetchFailureRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleFetchFailureRequest.java @@ -17,6 +17,12 @@ package org.apache.uniffle.client.request; +import java.util.Collections; +import java.util.List; + 
+import com.google.common.annotations.VisibleForTesting; + +import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest; public class RssReportShuffleFetchFailureRequest { @@ -25,14 +31,50 @@ public class RssReportShuffleFetchFailureRequest { private int stageAttemptId; private int partitionId; private String exception; + private List fetchFailureServerInfos; + private int stageId; + private long taskAttemptId; + private int taskAttemptNumber; + private String executorId; public RssReportShuffleFetchFailureRequest( - String appId, int shuffleId, int stageAttemptId, int partitionId, String exception) { + String appId, + int shuffleId, + int stageAttemptId, + int partitionId, + String exception, + List fetchFailureServerInfos, + int stageId, + long taskAttemptId, + int taskAttemptNumber, + String executorId) { this.appId = appId; this.shuffleId = shuffleId; this.stageAttemptId = stageAttemptId; this.partitionId = partitionId; this.exception = exception; + this.fetchFailureServerInfos = fetchFailureServerInfos; + this.stageId = stageId; + this.taskAttemptId = taskAttemptId; + this.taskAttemptNumber = taskAttemptNumber; + this.executorId = executorId; + } + + // Only for tests + @VisibleForTesting + public RssReportShuffleFetchFailureRequest( + String appId, int shuffleId, int stageAttemptId, int partitionId, String exception) { + this( + appId, + shuffleId, + stageAttemptId, + partitionId, + exception, + Collections.emptyList(), + 0, + 0L, + 0, + "executor1"); } public ReportShuffleFetchFailureRequest toProto() { @@ -42,7 +84,12 @@ public ReportShuffleFetchFailureRequest toProto() { .setAppId(appId) .setShuffleId(shuffleId) .setStageAttemptId(stageAttemptId) - .setPartitionId(partitionId); + .setPartitionId(partitionId) + .setStageId(stageId) + .setTaskAttemptId(taskAttemptId) + .setTaskAttemptNumber(taskAttemptNumber) + .setExecutorId(executorId) + 
.addAllFetchFailureServerId(ShuffleServerInfo.toProto(fetchFailureServerInfos)); if (exception != null) { builder.setException(exception); } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 97928bf201..8f14c32305 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -557,8 +557,11 @@ message ReportShuffleFetchFailureRequest { int32 stageAttemptId = 3; int32 partitionId = 4; string exception = 5; - // todo: report ShuffleServerId if needed - // ShuffleServerId serverId = 6; + repeated ShuffleServerId fetchFailureServerId = 6; + int32 stageId = 7; + int64 taskAttemptId = 8; + int32 taskAttemptNumber = 9; + string executorId = 10; } message ReportShuffleFetchFailureResponse { From f1e69fe8662dcaaa5c328053a6843ba5de69cd81 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 15:08:53 +0800 Subject: [PATCH 03/23] refactor the shuffle status --- .../shuffle/RssStageResubmitManager.java | 54 ++++++- .../spark/shuffle/stage/RssShuffleStatus.java | 109 +++++++++++++++ .../stage/RssShuffleStatusForReader.java | 24 ++++ .../stage/RssShuffleStatusForWriter.java | 24 ++++ .../manager/RssShuffleManagerBase.java | 4 + .../manager/RssShuffleManagerInterface.java | 3 + .../manager/ShuffleManagerGrpcService.java | 132 ++++-------------- .../shuffle/stage/RssShuffleStatusTest.java | 53 +++++++ .../manager/DummyRssShuffleManager.java | 6 + proto/src/main/proto/Rss.proto | 1 + 10 files changed, 306 insertions(+), 104 deletions(-) create mode 100644 client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java create mode 100644 client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForReader.java create mode 100644 client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForWriter.java create mode 100644 client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index 028622f922..e8555a4af8 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -19,16 +19,75 @@ import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import com.google.common.collect.Sets; +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.stage.RssShuffleStatus; +import org.apache.spark.shuffle.stage.RssShuffleStatusForReader; +import org.apache.spark.shuffle.stage.RssShuffleStatusForWriter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.uniffle.common.util.JavaUtils; public class RssStageResubmitManager { - private static final Logger LOG = LoggerFactory.getLogger(RssStageResubmitManager.class); + + private final SparkConf sparkConf = new SparkConf(); + private final Map<Integer, RssShuffleStatus> shuffleStatusForReader = + new ConcurrentHashMap<>(); + private final Map<Integer, RssShuffleStatus> shuffleStatusForWriter = + new ConcurrentHashMap<>(); + + /** Removes the reader and writer status entries recorded for the given shuffleId. */ + public void clear(int shuffleId) { + shuffleStatusForReader.remove(shuffleId); + shuffleStatusForWriter.remove(shuffleId); + } + + /** Returns the reader status of the shuffle, or null if stageAttempt is older than the tracked one. */ + public RssShuffleStatus getShuffleStatusForReader(int shuffleId, int stageId, int stageAttempt) { + RssShuffleStatus shuffleStatus = + shuffleStatusForReader.computeIfAbsent( + shuffleId, x -> new RssShuffleStatusForReader(stageId, shuffleId)); + if (shuffleStatus.updateStageAttemptIfNecessary(stageAttempt)) { + return shuffleStatus; + } + return null; + } + + /** Returns the writer status of the shuffle, or null if stageAttempt is older than the tracked one. */ + public RssShuffleStatus getShuffleStatusForWriter(int shuffleId, int stageId, int stageAttempt) { + RssShuffleStatus shuffleStatus = + shuffleStatusForWriter.computeIfAbsent( + shuffleId, x -> new RssShuffleStatusForWriter(stageId, shuffleId)); + if 
(shuffleStatus.updateStageAttemptIfNecessary(stageAttempt)) { + return shuffleStatus; + } + return null; + } + + /** Returns true (and marks the stage attempt retried) once a reader status reaches spark.task.maxFailures. */ + public boolean triggerStageRetry(RssShuffleStatus shuffleStatus) { + final String TASK_MAX_FAILURE = "spark.task.maxFailures"; + int sparkTaskMaxFailures = sparkConf.getInt(TASK_MAX_FAILURE, 4); + if (shuffleStatus instanceof RssShuffleStatusForReader) { + if (shuffleStatus.getStageRetriedNumber() > 1) { + LOG.warn( + "The stage attempt:{} has already been retried {} time(s). Ignore it.", + shuffleStatus.getStageAttempt(), + shuffleStatus.getStageRetriedNumber()); + return false; + } + if (shuffleStatus.getTaskFailureAttemptCount() >= sparkTaskMaxFailures) { + shuffleStatus.markStageAttemptRetried(); + return true; + } + } + return false; + } + /** Blacklist of the Shuffle Server when the write fails. */ private Set serverIdBlackList; /** diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java new file mode 100644 index 0000000000..099166120d --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.stage; + +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Supplier; + +/** + * Tracks the status of a stage attempt in order to decide whether to trigger a Spark stage retry. + * All state is read and written under a read-write lock. + */ +public class RssShuffleStatus { + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock(); + private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock(); + private final int stageId; + private final int shuffleId; + // the retried stage attempt records + private final Set<Integer> stageAttemptRetriedRecords; + + private int stageAttemptNumber; + // the failed task attempt numbers. Attention: these are not task attempt ids! + private Set<Integer> taskAttemptFailureRecords; + + public RssShuffleStatus(int stageId, int shuffleId) { + this.shuffleId = shuffleId; + this.stageId = stageId; + this.stageAttemptRetriedRecords = new HashSet<>(); + this.taskAttemptFailureRecords = new HashSet<>(); + } + + private <T> T withReadLock(Supplier<T> fn) { + readLock.lock(); + try { + return fn.get(); + } finally { + readLock.unlock(); + } + } + + private <T> T withWriteLock(Supplier<T> fn) { + writeLock.lock(); + try { + return fn.get(); + } finally { + writeLock.unlock(); + } + } + + public int getStageRetriedNumber() { + return withReadLock(() -> this.stageAttemptRetriedRecords.size()); + } + + public void markStageAttemptRetried() { + withWriteLock( + () -> { + this.stageAttemptRetriedRecords.add(stageAttemptNumber); + return null; + }); + } + + public int getStageAttempt() { + return withReadLock(() -> this.stageAttemptNumber); + } + + public boolean updateStageAttemptIfNecessary(int stageAttempt) { + return withWriteLock( + () -> { + if (this.stageAttemptNumber < stageAttempt) { + // a new stage attempt is issued. 
+ this.stageAttemptNumber = stageAttempt; + this.taskAttemptFailureRecords = new HashSet<>(); + return true; + } else if (this.stageAttemptNumber > stageAttempt) { + return false; + } + return true; + }); + } + + public void incTaskFailure(int taskAttemptNumber) { + withWriteLock( + () -> { + taskAttemptFailureRecords.add(taskAttemptNumber); + return null; + }); + } + + public int getTaskFailureAttemptCount() { + return withReadLock(() -> taskAttemptFailureRecords.size()); + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForReader.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForReader.java new file mode 100644 index 0000000000..81f4d77c54 --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForReader.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.stage; + +public class RssShuffleStatusForReader extends RssShuffleStatus { + public RssShuffleStatusForReader(int stageId, int shuffleId) { + super(stageId, shuffleId); + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForWriter.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForWriter.java new file mode 100644 index 0000000000..d00487664b --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatusForWriter.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.stage; + +public class RssShuffleStatusForWriter extends RssShuffleStatus { + public RssShuffleStatusForWriter(int stageId, int shuffleId) { + super(stageId, shuffleId); + } +} diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 209ede25c7..fbc03c078d 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -1058,4 +1058,8 @@ public boolean isRssStageRetryForWriteFailureEnabled() { public boolean isRssStageRetryForFetchFailureEnabled() { return rssStageRetryForFetchFailureEnabled; } + + public RssStageResubmitManager getStageResubmitManager() { + return rssStageResubmitManager; + } } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java index 77379efb5f..d927e1b887 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java @@ -21,6 +21,7 @@ import java.util.Map; import org.apache.spark.SparkException; +import org.apache.spark.shuffle.RssStageResubmitManager; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.ShuffleHandleInfo; @@ -87,4 +88,6 @@ MutableShuffleHandleInfo reassignOnBlockSendFailure( int stageAttemptNumber, int shuffleId, Map> partitionToFailureServers); + + RssStageResubmitManager getStageResubmitManager(); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index b9828408c0..1c03434785 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -18,7 +18,6 @@ package org.apache.uniffle.shuffle.manager; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -29,8 +28,10 @@ import com.google.protobuf.UnsafeByteOperations; import io.grpc.stub.StreamObserver; +import org.apache.spark.shuffle.RssStageResubmitManager; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; +import org.apache.spark.shuffle.stage.RssShuffleStatus; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,7 +46,6 @@ public class ShuffleManagerGrpcService extends ShuffleManagerImplBase { private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class); - private final Map shuffleStatus = JavaUtils.newConcurrentMap(); // The shuffleId mapping records the number of ShuffleServer write failures private final Map shuffleWrtieStatus = JavaUtils.newConcurrentMap(); @@ -133,6 +133,13 @@ public void reportShuffleFetchFailure( String appId = request.getAppId(); int stageAttempt = request.getStageAttemptId(); int partitionId = request.getPartitionId(); + int shuffleId = request.getShuffleId(); + int stageId = request.getStageId(); + long taskAttemptId = request.getTaskAttemptId(); + int taskAttemptNumber = request.getTaskAttemptNumber(); + String executorId = request.getExecutorId(); + + RssStageResubmitManager stageResubmitManager = shuffleManager.getStageResubmitManager(); RssProtos.StatusCode code; boolean reSubmitWholeStage; String msg; @@ -145,35 
+152,36 @@ public void reportShuffleFetchFailure( code = RssProtos.StatusCode.INVALID_REQUEST; reSubmitWholeStage = false; } else { - RssShuffleStatus status = - shuffleStatus.computeIfAbsent( - request.getShuffleId(), - key -> { - int partitionNum = shuffleManager.getPartitionNum(key); - return new RssShuffleStatus(partitionNum, stageAttempt); - }); - int c = status.resetStageAttemptIfNecessary(stageAttempt); - if (c < 0) { + RssShuffleStatus rssShuffleStatus = + stageResubmitManager.getShuffleStatusForReader(shuffleId, stageId, stageAttempt); + if (rssShuffleStatus == null) { msg = String.format( - "got an old stage(%d vs %d) shuffle fetch failure report, which should be impossible.", - status.getStageAttempt(), stageAttempt); + "got an old stage(%d:%d) shuffle fetch failure report from executor(%s), task(%d:%d) which should be impossible.", + stageId, stageAttempt, executorId, taskAttemptId, taskAttemptNumber); LOG.warn(msg); code = RssProtos.StatusCode.INVALID_REQUEST; reSubmitWholeStage = false; - } else { // update the stage partition fetch failure count + } else { code = RssProtos.StatusCode.SUCCESS; - status.incPartitionFetchFailure(stageAttempt, partitionId); - int fetchFailureNum = status.getPartitionFetchFailureNum(stageAttempt, partitionId); - if (fetchFailureNum >= shuffleManager.getMaxFetchFailures()) { + rssShuffleStatus.incTaskFailure(taskAttemptNumber); + if (stageResubmitManager.triggerStageRetry(rssShuffleStatus)) { reSubmitWholeStage = true; msg = String.format( - "report shuffle fetch failure as maximum number(%d) of shuffle fetch is occurred", - shuffleManager.getMaxFetchFailures()); + "Activate stage retry on stage(%d:%d), taskFailuresCount:(%d)", + stageId, stageAttempt, rssShuffleStatus.getTaskFailureAttemptCount()); + if (shuffleManager.reassignOnStageResubmit(stageId, stageAttempt, shuffleId, -1)) { + LOG.info( + "{} from executorId({}), task({}:{})", + msg, + executorId, + taskAttemptId, + taskAttemptNumber); + } } else { 
reSubmitWholeStage = false; - msg = "don't report shuffle fetch failure"; + msg = "Accepted task fetch failure report"; } } } @@ -307,7 +315,7 @@ public void reassignOnBlockSendFailure( * @param shuffleId the shuffle id to unregister. */ public void unregisterShuffle(int shuffleId) { - shuffleStatus.remove(shuffleId); + shuffleManager.getStageResubmitManager().clear(shuffleId); } private static class ShuffleServerFailureRecord { @@ -394,88 +402,6 @@ public boolean incPartitionWriteFailure( } } - private static class RssShuffleStatus { - private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); - private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock(); - private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock(); - private final int[] partitions; - private int stageAttempt; - - private RssShuffleStatus(int partitionNum, int stageAttempt) { - this.stageAttempt = stageAttempt; - this.partitions = new int[partitionNum]; - } - - private T withReadLock(Supplier fn) { - readLock.lock(); - try { - return fn.get(); - } finally { - readLock.unlock(); - } - } - - private T withWriteLock(Supplier fn) { - writeLock.lock(); - try { - return fn.get(); - } finally { - writeLock.unlock(); - } - } - - // todo: maybe it's more performant to just use synchronized method here. - public int getStageAttempt() { - return withReadLock(() -> this.stageAttempt); - } - - /** - * Check whether the input stage attempt is a new stage or not. If a new stage attempt is - * requested, reset partitions. - * - * @param stageAttempt the incoming stage attempt number - * @return 0 if stageAttempt == this.stageAttempt 1 if stageAttempt > this.stageAttempt -1 if - * stateAttempt < this.stageAttempt which means nothing happens - */ - public int resetStageAttemptIfNecessary(int stageAttempt) { - return withWriteLock( - () -> { - if (this.stageAttempt < stageAttempt) { - // a new stage attempt is issued. the partitions array should be clear and reset. 
- Arrays.fill(this.partitions, 0); - this.stageAttempt = stageAttempt; - return 1; - } else if (this.stageAttempt > stageAttempt) { - return -1; - } - return 0; - }); - } - - public void incPartitionFetchFailure(int stageAttempt, int partition) { - withWriteLock( - () -> { - if (this.stageAttempt != stageAttempt) { - // do nothing here - } else { - this.partitions[partition] = this.partitions[partition] + 1; - } - return null; - }); - } - - public int getPartitionFetchFailureNum(int stageAttempt, int partition) { - return withReadLock( - () -> { - if (this.stageAttempt != stageAttempt) { - return 0; - } else { - return this.partitions[partition]; - } - }); - } - } - @Override public void getShuffleResult( RssProtos.GetShuffleResultRequest request, diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java new file mode 100644 index 0000000000..2a5d28865d --- /dev/null +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.stage; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class RssShuffleStatusTest { + + @Test + public void test() { + // case1 + RssShuffleStatus shuffleStatus = new RssShuffleStatus(10, 1); + shuffleStatus.updateStageAttemptIfNecessary(0); + shuffleStatus.incTaskFailure(0); + shuffleStatus.incTaskFailure(1); + assertEquals(0, shuffleStatus.getStageAttempt()); + assertEquals(2, shuffleStatus.getTaskFailureAttemptCount()); + + // case2 + shuffleStatus.updateStageAttemptIfNecessary(1); + assertEquals(1, shuffleStatus.getStageAttempt()); + assertEquals(0, shuffleStatus.getTaskFailureAttemptCount()); + shuffleStatus.incTaskFailure(1); + shuffleStatus.incTaskFailure(3); + shuffleStatus.incTaskFailure(2); + assertEquals(3, shuffleStatus.getTaskFailureAttemptCount()); + + // case3 + shuffleStatus.markStageAttemptRetried(); + assertEquals(1, shuffleStatus.getStageRetriedNumber()); + + // case4: illegal stage attempt + assertFalse(shuffleStatus.updateStageAttemptIfNecessary(0)); + } +} diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java index 317d0cd9ea..7891dc3597 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Set; +import org.apache.spark.shuffle.RssStageResubmitManager; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.ShuffleHandleInfo; @@ -83,4 +84,9 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( Map> partitionToFailureServers) { return null; } + + @Override + 
public RssStageResubmitManager getStageResubmitManager() { + return null; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 8f14c32305..1d49bc3d65 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -557,6 +557,7 @@ message ReportShuffleFetchFailureRequest { int32 stageAttemptId = 3; int32 partitionId = 4; string exception = 5; + // todo: support adding the fetchFailureServerIds into the blacklist to avoid reassign the same servers repeated ShuffleServerId fetchFailureServerId = 6; int32 stageId = 7; int64 taskAttemptId = 8; From a67aa8f38dce6bf7dd8fe7739c281baa96e857b8 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 15:15:11 +0800 Subject: [PATCH 04/23] log enhance --- .../shuffle/manager/ShuffleManagerGrpcService.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index 1c03434785..4c8a0cfc7b 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -171,13 +171,18 @@ public void reportShuffleFetchFailure( String.format( "Activate stage retry on stage(%d:%d), taskFailuresCount:(%d)", stageId, stageAttempt, rssShuffleStatus.getTaskFailureAttemptCount()); - if (shuffleManager.reassignOnStageResubmit(stageId, stageAttempt, shuffleId, -1)) { + int partitionNum = shuffleManager.getPartitionNum(shuffleId); + if (shuffleManager.reassignOnStageResubmit( + stageId, stageAttempt, shuffleId, partitionNum)) { LOG.info( - "{} from executorId({}), task({}:{})", + "{} from executorId({}), task({}:{}) on stageId({}:{}), shuffleId({})", msg, executorId, taskAttemptId, - taskAttemptNumber); + 
taskAttemptNumber, + stageId, + stageAttempt, + shuffleId); } } else { reSubmitWholeStage = false; From 6bf6c96f8cf8d75cdec3d471d6ab679b5c8e30f9 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 16:09:40 +0800 Subject: [PATCH 05/23] fix --- .../apache/uniffle/shuffle/manager/DummyRssShuffleManager.java | 3 ++- proto/src/main/proto/Rss.proto | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java index 7891dc3597..e3fbf0a9f1 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java @@ -31,6 +31,7 @@ public class DummyRssShuffleManager implements RssShuffleManagerInterface { public Set unregisteredShuffleIds = new LinkedHashSet<>(); + private RssStageResubmitManager stageResubmitManager = new RssStageResubmitManager(); @Override public String getAppId() { @@ -87,6 +88,6 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure( @Override public RssStageResubmitManager getStageResubmitManager() { - return null; + return stageResubmitManager; } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 1d49bc3d65..8f14c32305 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -557,7 +557,6 @@ message ReportShuffleFetchFailureRequest { int32 stageAttemptId = 3; int32 partitionId = 4; string exception = 5; - // todo: support adding the fetchFailureServerIds into the blacklist to avoid reassign the same servers repeated ShuffleServerId fetchFailureServerId = 6; int32 stageId = 7; int64 taskAttemptId = 8; From 0664f5f76d50627aabdb3bb10b47a0d10c068d7b Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 16:41:01 +0800 Subject: 
[PATCH 06/23] fix --- .../shuffle/RssStageResubmitManager.java | 79 +++++++++---------- .../spark/shuffle/stage/RssShuffleStatus.java | 6 +- .../manager/RssShuffleManagerBase.java | 21 ++--- .../manager/RssShuffleManagerInterface.java | 2 +- .../manager/ShuffleManagerGrpcService.java | 5 +- .../shuffle/stage/RssShuffleStatusTest.java | 2 +- .../manager/DummyRssShuffleManager.java | 6 +- .../ShuffleManagerGrpcServiceTest.java | 5 ++ .../spark/shuffle/RssShuffleManager.java | 6 +- .../spark/shuffle/RssShuffleManager.java | 6 +- 10 files changed, 71 insertions(+), 67 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index e8555a4af8..4e3f0a46f9 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -20,29 +20,58 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; -import com.google.common.collect.Sets; import org.apache.spark.SparkConf; import org.apache.spark.shuffle.stage.RssShuffleStatus; import org.apache.spark.shuffle.stage.RssShuffleStatusForReader; import org.apache.spark.shuffle.stage.RssShuffleStatusForWriter; +import org.eclipse.jetty.util.ConcurrentHashSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.uniffle.common.util.JavaUtils; - +/** This class is to manage shuffle status for stage retry */ public class RssStageResubmitManager { private static final Logger LOG = LoggerFactory.getLogger(RssStageResubmitManager.class); - private final SparkConf sparkConf = new SparkConf(); + private final SparkConf sparkConf; private final Map shuffleStatusForReader = new ConcurrentHashMap<>(); private final Map shuffleStatusForWriter = new ConcurrentHashMap<>(); + 
private final Map shuffleLock = new ConcurrentHashMap<>(); + + private final Set blackListedServerIds = new ConcurrentHashSet<>(); + + public RssStageResubmitManager(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + public Object getOrCreateShuffleLock(int shuffleId) { + return shuffleLock.computeIfAbsent(shuffleId, x -> new Object()); + } public void clear(int shuffleId) { shuffleStatusForReader.remove(shuffleId); shuffleStatusForWriter.remove(shuffleId); + shuffleLock.remove(shuffleId); + } + + public boolean isStageAttemptRetried(int shuffleId, int stageId, int stageAttemptNumber) { + RssShuffleStatus readerShuffleStatus = shuffleStatusForReader.get(shuffleId); + RssShuffleStatus writerShuffleStatus = shuffleStatusForWriter.get(shuffleId); + if (readerShuffleStatus == null && writerShuffleStatus == null) { + return false; + } + if (readerShuffleStatus != null + && readerShuffleStatus.isStageAttemptRetried(stageAttemptNumber)) { + return true; + } + if (writerShuffleStatus != null + && writerShuffleStatus.isStageAttemptRetried(stageAttemptNumber)) { + return true; + } + return false; } public RssShuffleStatus getShuffleStatusForReader(int shuffleId, int stageId, int stageAttempt) { @@ -65,11 +94,11 @@ public RssShuffleStatus getShuffleStatusForWriter(int shuffleId, int stageId, in return null; } - public boolean triggerStageRetry(RssShuffleStatus shuffleStatus) { + public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { final String TASK_MAX_FAILURE = "spark.task.maxFailures"; int sparkTaskMaxFailures = sparkConf.getInt(TASK_MAX_FAILURE, 4); if (shuffleStatus instanceof RssShuffleStatusForReader) { - if (shuffleStatus.getStageRetriedNumber() > 1) { + if (shuffleStatus.getStageRetriedCount() > 1) { LOG.warn("The shuffleId:{}, stageId:{} has been retried. 
Ignore it."); return false; } @@ -81,41 +110,11 @@ public boolean triggerStageRetry(RssShuffleStatus shuffleStatus) { return false; } - /** Blacklist of the Shuffle Server when the write fails. */ - private Set serverIdBlackList; - /** - * Prevent multiple tasks from reporting FetchFailed, resulting in multiple ShuffleServer - * assignments, stageID, Attemptnumber Whether to reassign the combination flag; - */ - private Map serverAssignedInfos; - - public RssStageResubmitManager() { - this.serverIdBlackList = Sets.newConcurrentHashSet(); - this.serverAssignedInfos = JavaUtils.newConcurrentMap(); - } - - public Set getServerIdBlackList() { - return serverIdBlackList; - } - - public void resetServerIdBlackList(Set failuresShuffleServerIds) { - this.serverIdBlackList = failuresShuffleServerIds; - } - - public void recordFailuresShuffleServer(String shuffleServerId) { - serverIdBlackList.add(shuffleServerId); - } - - public RssStageInfo recordAndGetServerAssignedInfo(int shuffleId, String stageIdAndAttempt) { - - return serverAssignedInfos.computeIfAbsent( - shuffleId, id -> new RssStageInfo(stageIdAndAttempt, false)); + public Set getBlackListedServerIds() { + return blackListedServerIds.stream().collect(Collectors.toSet()); } - public void recordAndGetServerAssignedInfo( - int shuffleId, String stageIdAndAttempt, boolean isRetried) { - serverAssignedInfos - .computeIfAbsent(shuffleId, id -> new RssStageInfo(stageIdAndAttempt, false)) - .setReassigned(isRetried); + public void addBlackListedServer(String shuffleServerId) { + blackListedServerIds.add(shuffleServerId); } } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java index 099166120d..43cbb85d68 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java +++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java @@ -64,7 +64,11 @@ private T withWriteLock(Supplier fn) { } } - public int getStageRetriedNumber() { + public boolean isStageAttemptRetried(int stageAttempt) { + return withReadLock(() -> stageAttemptRetriedRecords.contains(stageAttempt)); + } + + public int getStageRetriedCount() { return withReadLock(() -> this.stageAttemptRetriedRecords.size()); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index fbc03c078d..e0259c8225 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -46,7 +46,6 @@ import org.apache.spark.SparkException; import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.shuffle.RssSparkShuffleUtils; -import org.apache.spark.shuffle.RssStageInfo; import org.apache.spark.shuffle.RssStageResubmitManager; import org.apache.spark.shuffle.ShuffleHandleInfoManager; import org.apache.spark.shuffle.ShuffleManager; @@ -648,8 +647,8 @@ public int getMaxFetchFailures() { * @param shuffleServerId */ @Override - public void addFailuresShuffleServerInfos(String shuffleServerId) { - rssStageResubmitManager.recordFailuresShuffleServer(shuffleServerId); + public void addFaultShuffleServer(String shuffleServerId) { + rssStageResubmitManager.addBlackListedServer(shuffleServerId); } /** @@ -661,12 +660,9 @@ public void addFailuresShuffleServerInfos(String shuffleServerId) { @Override public boolean reassignOnStageResubmit( int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) { - String stageIdAndAttempt = stageId + "_" + stageAttemptNumber; - RssStageInfo rssStageInfo = - 
rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, stageIdAndAttempt); - synchronized (rssStageInfo) { - Boolean needReassign = rssStageInfo.isReassigned(); - if (!needReassign) { + Object shuffleLock = rssStageResubmitManager.getOrCreateShuffleLock(shuffleId); + synchronized (shuffleLock) { + if (!rssStageResubmitManager.isStageAttemptRetried(shuffleId, stageId, stageAttemptNumber)) { int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); @@ -682,7 +678,7 @@ public boolean reassignOnStageResubmit( 1, requiredShuffleServerNumber, estimateTaskConcurrency, - rssStageResubmitManager.getServerIdBlackList(), + rssStageResubmitManager.getBlackListedServerIds(), stageId, stageAttemptNumber, false); @@ -701,7 +697,6 @@ public boolean reassignOnStageResubmit( StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = (StageAttemptShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo); - rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, stageIdAndAttempt, true); LOG.info( "The stage retry has been triggered successfully for the stageId: {}, attemptNumber: {}", stageId, @@ -880,7 +875,7 @@ private Map> requestShuffleAssignment( assignmentTags.add(clientType); long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); - faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList()); + faultyServerIds.addAll(rssStageResubmitManager.getBlackListedServerIds()); try { return RetryUtils.retry( () -> { @@ -928,7 +923,7 @@ protected Map> requestShuffleAssignment( long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); - faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList()); + faultyServerIds.addAll(rssStageResubmitManager.getBlackListedServerIds()); try { return RetryUtils.retry( () -> { diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java index d927e1b887..61b14a020f 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java @@ -79,7 +79,7 @@ public interface RssShuffleManagerInterface { * * @param shuffleServerId */ - void addFailuresShuffleServerInfos(String shuffleServerId); + void addFaultShuffleServer(String shuffleServerId); boolean reassignOnStageResubmit(int stageId, int stageAttemptNumber, int shuffleId, int numMaps); diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index 4c8a0cfc7b..1caca58094 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -165,7 +165,7 @@ public void reportShuffleFetchFailure( } else { code = RssProtos.StatusCode.SUCCESS; rssShuffleStatus.incTaskFailure(taskAttemptNumber); - if (stageResubmitManager.triggerStageRetry(rssShuffleStatus)) { + if (stageResubmitManager.activateStageRetry(rssShuffleStatus)) { reSubmitWholeStage = true; msg = String.format( @@ -397,8 +397,7 @@ public boolean incPartitionWriteFailure( Map.Entry shuffleServerInfoIntegerEntry = list.get(0); if (shuffleServerInfoIntegerEntry.getValue().get() > 
shuffleManager.getMaxFetchFailures()) { - shuffleManager.addFailuresShuffleServerInfos( - shuffleServerInfoIntegerEntry.getKey()); + shuffleManager.addFaultShuffleServer(shuffleServerInfoIntegerEntry.getKey()); return true; } } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java index 2a5d28865d..3f3ae35d81 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/stage/RssShuffleStatusTest.java @@ -45,7 +45,7 @@ public void test() { // case3 shuffleStatus.markStageAttemptRetried(); - assertEquals(1, shuffleStatus.getStageRetriedNumber()); + assertEquals(1, shuffleStatus.getStageRetriedCount()); // case4: illegal stage attempt assertFalse(shuffleStatus.updateStageAttemptIfNecessary(0)); diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java index e3fbf0a9f1..025f8af75a 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Set; +import org.apache.spark.SparkConf; import org.apache.spark.shuffle.RssStageResubmitManager; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.ShuffleHandleInfo; @@ -31,7 +32,8 @@ public class DummyRssShuffleManager implements RssShuffleManagerInterface { public Set unregisteredShuffleIds = new LinkedHashSet<>(); - private RssStageResubmitManager stageResubmitManager = new RssStageResubmitManager(); + private RssStageResubmitManager stageResubmitManager = + new 
RssStageResubmitManager(new SparkConf()); @Override public String getAppId() { @@ -69,7 +71,7 @@ public int getMaxFetchFailures() { } @Override - public void addFailuresShuffleServerInfos(String shuffleServerId) {} + public void addFaultShuffleServer(String shuffleServerId) {} @Override public boolean reassignOnStageResubmit( diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java index ac3fbda7e3..c9007c3ada 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java @@ -18,6 +18,8 @@ package org.apache.uniffle.shuffle.manager; import io.grpc.stub.StreamObserver; +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.RssStageResubmitManager; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -41,6 +43,8 @@ public class ShuffleManagerGrpcServiceTest { private static final int shuffleId = 0; private static final int numMaps = 100; private static final int numReduces = 10; + private static final RssStageResubmitManager stageResubmitManager = + new RssStageResubmitManager(new SparkConf()); private static class MockedStreamObserver implements StreamObserver { T value; @@ -70,6 +74,7 @@ public static void setup() { Mockito.when(mockShuffleManager.getNumMaps(shuffleId)).thenReturn(numMaps); Mockito.when(mockShuffleManager.getPartitionNum(shuffleId)).thenReturn(numReduces); Mockito.when(mockShuffleManager.getMaxFetchFailures()).thenReturn(maxFetchFailures); + Mockito.when(mockShuffleManager.getStageResubmitManager()).thenReturn(stageResubmitManager); } @Test diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index de5d4da635..2572fc21ae 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -254,7 +254,7 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { keepAliveTime); } this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - this.rssStageResubmitManager = new RssStageResubmitManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(sparkConf); } // This method is called in Spark driver side, @@ -346,7 +346,7 @@ public ShuffleHandle registerShuffle( 1, requiredShuffleServerNumber, estimateTaskConcurrency, - rssStageResubmitManager.getServerIdBlackList(), + rssStageResubmitManager.getBlackListedServerIds(), 0); startHeartbeat(); @@ -742,7 +742,7 @@ public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) { private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) { Set faultyServerIds = Sets.newHashSet(faultyShuffleServerId); - faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList()); + faultyServerIds.addAll(rssStageResubmitManager.getBlackListedServerIds()); Map> partitionToServers = requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, 0); if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) { diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 1d50507904..5d170b4471 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -284,7 +284,7 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { this.partitionReassignMaxServerNum = 
rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - this.rssStageResubmitManager = new RssStageResubmitManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(sparkConf); } public CompletableFuture sendData(AddBlockEvent event) { @@ -378,7 +378,7 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s this.partitionReassignMaxServerNum = rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - this.rssStageResubmitManager = new RssStageResubmitManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(sparkConf); } // This method is called in Spark driver side, @@ -464,7 +464,7 @@ public ShuffleHandle registerShuffle( 1, requiredShuffleServerNumber, estimateTaskConcurrency, - rssStageResubmitManager.getServerIdBlackList(), + rssStageResubmitManager.getBlackListedServerIds(), 0); startHeartbeat(); shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions()); From 46bc3ae863b4c06334030b6470a5d682a5056a7e Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 16:56:06 +0800 Subject: [PATCH 07/23] fix incorrect logic --- .../shuffle/RssStageResubmitManager.java | 1 - .../manager/ShuffleManagerGrpcService.java | 25 +++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index 4e3f0a46f9..1d410ced56 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -103,7 +103,6 @@ public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { return false; } if 
(shuffleStatus.getTaskFailureAttemptCount() >= sparkTaskMaxFailures) { - shuffleStatus.markStageAttemptRetried(); return true; } } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index 1caca58094..e6589c4ebe 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -172,17 +172,20 @@ public void reportShuffleFetchFailure( "Activate stage retry on stage(%d:%d), taskFailuresCount:(%d)", stageId, stageAttempt, rssShuffleStatus.getTaskFailureAttemptCount()); int partitionNum = shuffleManager.getPartitionNum(shuffleId); - if (shuffleManager.reassignOnStageResubmit( - stageId, stageAttempt, shuffleId, partitionNum)) { - LOG.info( - "{} from executorId({}), task({}:{}) on stageId({}:{}), shuffleId({})", - msg, - executorId, - taskAttemptId, - taskAttemptNumber, - stageId, - stageAttempt, - shuffleId); + synchronized (rssShuffleStatus) { + if (shuffleManager.reassignOnStageResubmit( + stageId, stageAttempt, shuffleId, partitionNum)) { + LOG.info( + "{} from executorId({}), task({}:{}) on stageId({}:{}), shuffleId({})", + msg, + executorId, + taskAttemptId, + taskAttemptNumber, + stageId, + stageAttempt, + shuffleId); + } + rssShuffleStatus.markStageAttemptRetried(); } } else { reSubmitWholeStage = false; From 9e0c0534f81447c50e4fe9f9d773aac7b1a9ab26 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 17:22:50 +0800 Subject: [PATCH 08/23] refactor write side reassign --- .../manager/ShuffleManagerGrpcService.java | 168 +++++------------- .../shuffle/writer/RssShuffleWriter.java | 23 +-- .../shuffle/writer/RssShuffleWriter.java | 21 +-- .../RssReportShuffleWriteFailureRequest.java | 20 ++- proto/src/main/proto/Rss.proto | 4 + 5 files 
changed, 84 insertions(+), 152 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index e6589c4ebe..db700d9897 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -17,13 +17,8 @@ package org.apache.uniffle.shuffle.manager; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.function.Supplier; import java.util.stream.Collectors; import com.google.protobuf.UnsafeByteOperations; @@ -38,7 +33,6 @@ import org.apache.uniffle.common.ReceivingFailureServer; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.RssUtils; import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.proto.ShuffleManagerGrpc.ShuffleManagerImplBase; @@ -46,9 +40,6 @@ public class ShuffleManagerGrpcService extends ShuffleManagerImplBase { private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class); - // The shuffleId mapping records the number of ShuffleServer write failures - private final Map shuffleWrtieStatus = - JavaUtils.newConcurrentMap(); private final RssShuffleManagerInterface shuffleManager; public ShuffleManagerGrpcService(RssShuffleManagerInterface shuffleManager) { @@ -63,6 +54,11 @@ public void reportShuffleWriteFailure( int shuffleId = request.getShuffleId(); int stageAttemptNumber = request.getStageAttemptNumber(); List shuffleServerIdsList = request.getShuffleServerIdsList(); + int stageId = request.getStageId(); + String executorId = 
request.getExecutorId(); + long taskAttemptId = request.getTaskAttemptId(); + int taskAttemptNumber = request.getTaskAttemptNumber(); + RssProtos.StatusCode code; boolean reSubmitWholeStage; String msg; @@ -75,43 +71,55 @@ public void reportShuffleWriteFailure( code = RssProtos.StatusCode.INVALID_REQUEST; reSubmitWholeStage = false; } else { - Map shuffleServerInfoIntegerMap = JavaUtils.newConcurrentMap(); - List shuffleServerInfos = - ShuffleServerInfo.fromProto(shuffleServerIdsList); - shuffleServerInfos.forEach( - shuffleServerInfo -> { - shuffleServerInfoIntegerMap.put(shuffleServerInfo.getId(), new AtomicInteger(0)); - }); - ShuffleServerFailureRecord shuffleServerFailureRecord = - shuffleWrtieStatus.computeIfAbsent( - shuffleId, - key -> - new ShuffleServerFailureRecord(shuffleServerInfoIntegerMap, stageAttemptNumber)); - boolean resetflag = - shuffleServerFailureRecord.resetStageAttemptIfNecessary(stageAttemptNumber); - if (resetflag) { + RssStageResubmitManager stageResubmitManager = shuffleManager.getStageResubmitManager(); + RssShuffleStatus shuffleStatus = + stageResubmitManager.getShuffleStatusForWriter(shuffleId, stageId, stageAttemptNumber); + if (shuffleStatus == null) { msg = String.format( - "got an old stage(%d vs %d) shuffle write failure report, which should be impossible.", - shuffleServerFailureRecord.getStageAttempt(), stageAttemptNumber); + "got an old stage(%d:%d) shuffle(%d) write failure report from executor(%s), task(%d:%d) which should be impossible.", + stageId, + stageAttemptNumber, + shuffleId, + executorId, + taskAttemptId, + taskAttemptNumber); LOG.warn(msg); code = RssProtos.StatusCode.INVALID_REQUEST; reSubmitWholeStage = false; } else { code = RssProtos.StatusCode.SUCCESS; - // update the stage shuffleServer write failed count - boolean fetchFailureflag = - shuffleServerFailureRecord.incPartitionWriteFailure( - stageAttemptNumber, shuffleServerInfos, shuffleManager); - if (fetchFailureflag) { + 
shuffleStatus.incTaskFailure(taskAttemptNumber); + if (shuffleServerIdsList != null) { + List serverInfos = ShuffleServerInfo.fromProto(shuffleServerIdsList); + serverInfos.stream().forEach(x -> stageResubmitManager.addBlackListedServer(x.getId())); + } + if (stageResubmitManager.activateStageRetry(shuffleStatus)) { reSubmitWholeStage = true; msg = String.format( - "report shuffle write failure as maximum number(%d) of shuffle write is occurred", - shuffleManager.getMaxFetchFailures()); + "Activate stage retry for writer on stage(%d:%d), taskFailuresCount:(%d)", + stageId, stageAttemptNumber, shuffleStatus.getTaskFailureAttemptCount()); + int partitionNum = shuffleManager.getPartitionNum(shuffleId); + Object shuffleLock = stageResubmitManager.getOrCreateShuffleLock(shuffleId); + synchronized (shuffleLock) { + if (shuffleManager.reassignOnStageResubmit( + stageId, stageAttemptNumber, shuffleId, partitionNum)) { + LOG.info( + "{} from executorId({}), task({}:{}) on stageId({}:{}), shuffleId({})", + msg, + executorId, + taskAttemptId, + taskAttemptNumber, + stageId, + stageAttemptNumber, + shuffleId); + } + shuffleStatus.markStageAttemptRetried(); + } } else { reSubmitWholeStage = false; - msg = "don't report shuffle write failure"; + msg = "Accepted task write failure report"; } } } @@ -157,8 +165,8 @@ public void reportShuffleFetchFailure( if (rssShuffleStatus == null) { msg = String.format( - "got an old stage(%d:%d) shuffle fetch failure report from executor(%s), task(%d:%d) which should be impossible.", - stageId, stageAttempt, executorId, taskAttemptId, taskAttemptNumber); + "got an old stage(%d:%d) shuffle(%d) fetch failure report from executor(%s), task(%d:%d) which should be impossible.", + stageId, stageAttempt, shuffleId, executorId, taskAttemptId, taskAttemptNumber); LOG.warn(msg); code = RssProtos.StatusCode.INVALID_REQUEST; reSubmitWholeStage = false; @@ -169,10 +177,11 @@ public void reportShuffleFetchFailure( reSubmitWholeStage = true; msg = 
String.format( - "Activate stage retry on stage(%d:%d), taskFailuresCount:(%d)", + "Activate stage retry for reader on stage(%d:%d), taskFailuresCount:(%d)", stageId, stageAttempt, rssShuffleStatus.getTaskFailureAttemptCount()); int partitionNum = shuffleManager.getPartitionNum(shuffleId); - synchronized (rssShuffleStatus) { + Object shuffleLock = stageResubmitManager.getOrCreateShuffleLock(shuffleId); + synchronized (shuffleLock) { if (shuffleManager.reassignOnStageResubmit( stageId, stageAttempt, shuffleId, partitionNum)) { LOG.info( @@ -326,89 +335,6 @@ public void unregisterShuffle(int shuffleId) { shuffleManager.getStageResubmitManager().clear(shuffleId); } - private static class ShuffleServerFailureRecord { - private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); - private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock(); - private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock(); - private final Map shuffleServerFailureRecordCount; - private int stageAttemptNumber; - - private ShuffleServerFailureRecord( - Map shuffleServerFailureRecordCount, int stageAttemptNumber) { - this.shuffleServerFailureRecordCount = shuffleServerFailureRecordCount; - this.stageAttemptNumber = stageAttemptNumber; - } - - private T withReadLock(Supplier fn) { - readLock.lock(); - try { - return fn.get(); - } finally { - readLock.unlock(); - } - } - - private T withWriteLock(Supplier fn) { - writeLock.lock(); - try { - return fn.get(); - } finally { - writeLock.unlock(); - } - } - - public int getStageAttempt() { - return withReadLock(() -> this.stageAttemptNumber); - } - - public boolean resetStageAttemptIfNecessary(int stageAttemptNumber) { - return withWriteLock( - () -> { - if (this.stageAttemptNumber < stageAttemptNumber) { - // a new stage attempt is issued. Record the shuffleServer status of the Map should be - // clear and reset. 
- shuffleServerFailureRecordCount.clear(); - this.stageAttemptNumber = stageAttemptNumber; - return false; - } else if (this.stageAttemptNumber > stageAttemptNumber) { - return true; - } - return false; - }); - } - - public boolean incPartitionWriteFailure( - int stageAttemptNumber, - List shuffleServerInfos, - RssShuffleManagerInterface shuffleManager) { - return withWriteLock( - () -> { - if (this.stageAttemptNumber != stageAttemptNumber) { - // do nothing here - return false; - } - shuffleServerInfos.forEach( - shuffleServerInfo -> { - shuffleServerFailureRecordCount - .computeIfAbsent(shuffleServerInfo.getId(), k -> new AtomicInteger()) - .incrementAndGet(); - }); - List> list = - new ArrayList(shuffleServerFailureRecordCount.entrySet()); - if (!list.isEmpty()) { - Collections.sort(list, (o1, o2) -> (o1.getValue().get() - o2.getValue().get())); - Map.Entry shuffleServerInfoIntegerEntry = list.get(0); - if (shuffleServerInfoIntegerEntry.getValue().get() - > shuffleManager.getMaxFetchFailures()) { - shuffleManager.addFaultShuffleServer(shuffleServerInfoIntegerEntry.getKey()); - return true; - } - } - return false; - }); - } - } - @Override public void getShuffleResult( RssProtos.GetShuffleResultRequest request, diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 65b66df3df..41866d5b42 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -46,7 +46,9 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import 
org.apache.spark.scheduler.MapStatus$; @@ -66,9 +68,7 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.impl.FailedBlockSendTracker; -import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; -import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.ShuffleBlockInfo; @@ -542,13 +542,18 @@ private void throwFetchFailedIfNecessary(Exception e) { shuffleManager.getBlockIdsFailedSendTracker(taskId); List shuffleServerInfos = Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers()); + TaskContext taskContext = TaskContext$.MODULE$.get(); RssReportShuffleWriteFailureRequest req = new RssReportShuffleWriteFailureRequest( appId, shuffleId, taskContext.stageAttemptNumber(), shuffleServerInfos, - e.getMessage()); + e.getMessage(), + taskContext.stageId(), + taskContext.taskAttemptId(), + taskContext.attemptNumber(), + SparkEnv.get().executorId()); RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); String driver = rssConf.getString("driver.host", ""); int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); @@ -556,18 +561,6 @@ private void throwFetchFailedIfNecessary(Exception e) { RssReportShuffleWriteFailureResponse response = shuffleManagerClient.reportShuffleWriteFailure(req); if (response.getReSubmitWholeStage()) { - // The shuffle server is reassigned. 
- RssReassignServersRequest rssReassignServersRequest = - new RssReassignServersRequest( - taskContext.stageId(), - taskContext.stageAttemptNumber(), - shuffleId, - partitioner.numPartitions()); - RssReassignServersResponse rssReassignServersResponse = - shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest); - LOG.info( - "Whether the reassignment is successful: {}", - rssReassignServersResponse.isNeedReassign()); // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is // provided. FetchFailedException ffe = diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 3dfc2fd620..4b9f62e910 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -54,6 +54,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.FetchFailedException; @@ -74,10 +75,8 @@ import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.client.impl.TrackingBlockStatus; import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest; -import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse; -import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.ReceivingFailureServer; @@ 
-828,13 +827,18 @@ private void throwFetchFailedIfNecessary(Exception e) { shuffleManager.getBlockIdsFailedSendTracker(taskId); List shuffleServerInfos = Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers()); + TaskContext taskContext = TaskContext$.MODULE$.get(); RssReportShuffleWriteFailureRequest req = new RssReportShuffleWriteFailureRequest( appId, shuffleId, taskContext.stageAttemptNumber(), shuffleServerInfos, - e.getMessage()); + e.getMessage(), + taskContext.stageId(), + taskContext.taskAttemptId(), + taskContext.attemptNumber(), + SparkEnv.get().executorId()); RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); String driver = rssConf.getString("driver.host", ""); int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); @@ -842,17 +846,6 @@ private void throwFetchFailedIfNecessary(Exception e) { RssReportShuffleWriteFailureResponse response = shuffleManagerClient.reportShuffleWriteFailure(req); if (response.getReSubmitWholeStage()) { - RssReassignServersRequest rssReassignServersRequest = - new RssReassignServersRequest( - taskContext.stageId(), - taskContext.stageAttemptNumber(), - shuffleId, - partitioner.numPartitions()); - RssReassignServersResponse rssReassignServersResponse = - shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest); - LOG.info( - "Whether the reassignment is successful: {}", - rssReassignServersResponse.isNeedReassign()); // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is // provided. 
FetchFailedException ffe = diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteFailureRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteFailureRequest.java index c05176769d..ad7de4699d 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteFailureRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteFailureRequest.java @@ -31,18 +31,30 @@ public class RssReportShuffleWriteFailureRequest { private int stageAttemptNumber; private List shuffleServerInfos; private String exception; + private int stageId; + private long taskAttemptId; + private int taskAttemptNumber; + private String executorId; public RssReportShuffleWriteFailureRequest( String appId, int shuffleId, int stageAttemptNumber, List shuffleServerInfos, - String exception) { + String exception, + int stageId, + long taskAttemptId, + int taskAttemptNumber, + String executorId) { this.appId = appId; this.shuffleId = shuffleId; this.stageAttemptNumber = stageAttemptNumber; this.shuffleServerInfos = shuffleServerInfos; this.exception = exception; + this.stageId = stageId; + this.taskAttemptId = taskAttemptId; + this.taskAttemptNumber = taskAttemptNumber; + this.executorId = executorId; } public ReportShuffleWriteFailureRequest toProto() { @@ -61,7 +73,11 @@ public ReportShuffleWriteFailureRequest toProto() { .setAppId(appId) .setShuffleId(shuffleId) .setStageAttemptNumber(stageAttemptNumber) - .addAllShuffleServerIds(shuffleServerIds); + .addAllShuffleServerIds(shuffleServerIds) + .setStageId(stageId) + .setTaskAttemptId(taskAttemptId) + .setTaskAttemptNumber(taskAttemptNumber) + .setExecutorId(executorId); if (exception != null) { builder.setException(exception); } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 8f14c32305..ad11ab9442 100644 --- a/proto/src/main/proto/Rss.proto +++ 
b/proto/src/main/proto/Rss.proto @@ -610,6 +610,10 @@ message ReportShuffleWriteFailureRequest { int32 stageAttemptNumber = 3; repeated ShuffleServerId shuffleServerIds= 5; string exception = 6; + int32 stageId = 7; + int64 taskAttemptId = 8; + int32 taskAttemptNumber = 9; + string executorId = 10; } message ReportShuffleWriteFailureResponse { From 9996dbc129316f31d96e9ea602dfb2fab8b41bdc Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 17:24:36 +0800 Subject: [PATCH 09/23] remove dead code --- .../uniffle/shuffle/manager/ShuffleManagerGrpcService.java | 1 - 1 file changed, 1 deletion(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index db700d9897..a721159b78 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -140,7 +140,6 @@ public void reportShuffleFetchFailure( StreamObserver responseObserver) { String appId = request.getAppId(); int stageAttempt = request.getStageAttemptId(); - int partitionId = request.getPartitionId(); int shuffleId = request.getShuffleId(); int stageId = request.getStageId(); long taskAttemptId = request.getTaskAttemptId(); From 3d0450ff9f1479e38d0d16839b90985d55093528 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 19 Jun 2024 17:51:53 +0800 Subject: [PATCH 10/23] add unit tests --- .../shuffle/RssStageResubmitManager.java | 14 ++-- .../shuffle/RssStageResubmitManagerTest.java | 67 +++++++++++++++++++ 2 files changed, 73 insertions(+), 8 deletions(-) create mode 100644 client-spark/common/src/test/java/org/apache/spark/shuffle/RssStageResubmitManagerTest.java diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index 1d410ced56..635e634732 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -97,14 +97,12 @@ public RssShuffleStatus getShuffleStatusForWriter(int shuffleId, int stageId, in public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { final String TASK_MAX_FAILURE = "spark.task.maxFailures"; int sparkTaskMaxFailures = sparkConf.getInt(TASK_MAX_FAILURE, 4); - if (shuffleStatus instanceof RssShuffleStatusForReader) { - if (shuffleStatus.getStageRetriedCount() > 1) { - LOG.warn("The shuffleId:{}, stageId:{} has been retried. Ignore it."); - return false; - } - if (shuffleStatus.getTaskFailureAttemptCount() >= sparkTaskMaxFailures) { - return true; - } + if (shuffleStatus.getStageRetriedCount() > 1) { + LOG.warn("The shuffleId:{}, stageId:{} has been retried. Ignore it."); + return false; + } + if (shuffleStatus.getTaskFailureAttemptCount() >= sparkTaskMaxFailures) { + return true; } return false; } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/RssStageResubmitManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssStageResubmitManagerTest.java new file mode 100644 index 0000000000..deef55df64 --- /dev/null +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/RssStageResubmitManagerTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle; + +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.stage.RssShuffleStatus; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RssStageResubmitManagerTest { + + @Test + public void testMultipleStatusFromWriterAndReader() { + int shuffleId = 1; + int stageId = 10; + RssStageResubmitManager manager = new RssStageResubmitManager(new SparkConf()); + RssShuffleStatus readerShuffleStatus = manager.getShuffleStatusForReader(shuffleId, stageId, 0); + + // case1 + readerShuffleStatus.incTaskFailure(0); + readerShuffleStatus.incTaskFailure(1); + readerShuffleStatus.incTaskFailure(2); + assertFalse(manager.activateStageRetry(readerShuffleStatus)); + + readerShuffleStatus.incTaskFailure(3); + assertTrue(manager.activateStageRetry(readerShuffleStatus)); + + readerShuffleStatus.markStageAttemptRetried(); + assertTrue(manager.isStageAttemptRetried(shuffleId, stageId, 0)); + assertFalse(manager.isStageAttemptRetried(shuffleId, stageId, 1)); + + readerShuffleStatus = manager.getShuffleStatusForReader(shuffleId, stageId, 1); + + // case2 + RssShuffleStatus writerShuffleStatus = manager.getShuffleStatusForWriter(shuffleId, stageId, 1); + writerShuffleStatus.incTaskFailure(0); + writerShuffleStatus.incTaskFailure(1); + readerShuffleStatus.incTaskFailure(0); + readerShuffleStatus.incTaskFailure(1); + assertFalse(manager.activateStageRetry(readerShuffleStatus)); + 
assertFalse(manager.activateStageRetry(writerShuffleStatus)); + + writerShuffleStatus.incTaskFailure(2); + writerShuffleStatus.incTaskFailure(3); + if (manager.activateStageRetry(writerShuffleStatus)) { + writerShuffleStatus.markStageAttemptRetried(); + } + assertTrue(manager.isStageAttemptRetried(shuffleId, stageId, 1)); + } +} From c8384cdfad736c69f7875a1f1e4bca1c7eb4e741 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Thu, 20 Jun 2024 10:25:19 +0800 Subject: [PATCH 11/23] make the reader fetch failure serverIds into blacklist --- .../spark/shuffle/RssSparkShuffleUtils.java | 4 ++-- .../reader/RssFetchFailedIterator.java | 4 ++-- .../manager/ShuffleManagerGrpcService.java | 7 +++++++ .../exception/RssFetchFailedException.java | 20 ++++++++++++++----- .../coordinator/CoordinatorGrpcService.java | 2 +- .../impl/ComposedClientReadHandler.java | 8 +++++++- 6 files changed, 34 insertions(+), 11 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index ca3bd0c93c..9e8cfa7011 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -374,7 +374,7 @@ public static RssException reportRssFetchFailedException( stageAttemptId, partitionId, rssFetchFailedException.getMessage(), - Collections.emptyList(), + new ArrayList<>(rssFetchFailedException.getFetchFailureServerIds()), taskContext.stageId(), taskContext.taskAttemptId(), taskContext.attemptNumber(), diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java index 7d9e596a20..13372d08b0 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.reader; import java.io.IOException; -import java.util.Collections; +import java.util.ArrayList; import java.util.Objects; import scala.Product2; @@ -123,7 +123,7 @@ private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) { builder.stageAttemptId, builder.partitionId, e.getMessage(), - Collections.emptyList(), + new ArrayList<>(e.getFetchFailureServerIds()), taskContext.stageId(), taskContext.taskAttemptId(), taskContext.attemptNumber(), diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index a721159b78..733d9fcd5c 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -23,6 +23,7 @@ import com.google.protobuf.UnsafeByteOperations; import io.grpc.stub.StreamObserver; +import org.apache.commons.collections.CollectionUtils; import org.apache.spark.shuffle.RssStageResubmitManager; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; @@ -145,6 +146,7 @@ public void reportShuffleFetchFailure( long taskAttemptId = request.getTaskAttemptId(); int taskAttemptNumber = request.getTaskAttemptNumber(); String executorId = request.getExecutorId(); + List serverIds = 
request.getFetchFailureServerIdList(); RssStageResubmitManager stageResubmitManager = shuffleManager.getStageResubmitManager(); RssProtos.StatusCode code; @@ -172,6 +174,11 @@ public void reportShuffleFetchFailure( } else { code = RssProtos.StatusCode.SUCCESS; rssShuffleStatus.incTaskFailure(taskAttemptNumber); + if (CollectionUtils.isNotEmpty(serverIds)) { + ShuffleServerInfo.fromProto(serverIds).stream() + .map(x -> x.getId()) + .forEach(x -> stageResubmitManager.addBlackListedServer(x)); + } if (stageResubmitManager.activateStageRetry(rssShuffleStatus)) { reSubmitWholeStage = true; msg = diff --git a/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java b/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java index c61e6776dd..8da18a6752 100644 --- a/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java +++ b/common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java @@ -17,15 +17,19 @@ package org.apache.uniffle.common.exception; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + import org.apache.uniffle.common.ShuffleServerInfo; /** Dedicated exception for rss client's shuffle failed related exception. */ public class RssFetchFailedException extends RssException { - private ShuffleServerInfo fetchFailureServerId; + private Set fetchFailureServerIds = new HashSet<>(); - public RssFetchFailedException(String message, ShuffleServerInfo fetchFailureServerId) { + public RssFetchFailedException(String message, ShuffleServerInfo... 
fetchFailureServerIds) { super(message); - this.fetchFailureServerId = fetchFailureServerId; + Arrays.stream(fetchFailureServerIds).forEach(x -> this.fetchFailureServerIds.add(x)); } public RssFetchFailedException(String message) { @@ -36,7 +40,13 @@ public RssFetchFailedException(String message, Throwable e) { super(message, e); } - public ShuffleServerInfo getFetchFailureServerId() { - return fetchFailureServerId; + public RssFetchFailedException( + String message, Throwable e, ShuffleServerInfo... fetchFailureServerIds) { + super(message, e); + Arrays.stream(fetchFailureServerIds).forEach(x -> this.fetchFailureServerIds.add(x)); + } + + public Set getFetchFailureServerIds() { + return fetchFailureServerIds; } } diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java index ddca8a1e33..962d2bd0e7 100644 --- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java +++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java @@ -127,7 +127,7 @@ public void getShuffleAssignments( replica, requiredTags, requiredShuffleServerNumber, - faultyServerIds.size(), + faultyServerIds, request.getStageId(), request.getStageAttemptNumber(), request.getReassign()); diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java index e619bbd1dd..affbd477e6 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java @@ -26,6 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import org.apache.commons.collections.CollectionUtils; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -118,7 +119,12 @@ public ShuffleDataResult readShuffleData() { + currentTier.name() + "handler, error: " + e.getMessage(); - throw new RssFetchFailedException(message, cause); + if (CollectionUtils.isEmpty(e.getFetchFailureServerIds())) { + throw new RssFetchFailedException(message, cause); + } else { + throw new RssFetchFailedException( + message, cause, e.getFetchFailureServerIds().toArray(new ShuffleServerInfo[0])); + } } catch (Exception e) { throw new RssFetchFailedException( "Failed to read shuffle data from " + currentTier.name() + " handler", e); From 8523d6dbebb83a75fcdf6b3438f7461374d662d0 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Mon, 24 Jun 2024 16:42:46 +0800 Subject: [PATCH 12/23] group the same stageId reader to calculate the task failure count --- .../spark/shuffle/RssStageResubmitManager.java | 16 ++++++++++++++++ .../spark/shuffle/stage/RssShuffleStatus.java | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index 635e634732..1ae259f672 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -97,6 +97,7 @@ public RssShuffleStatus getShuffleStatusForWriter(int shuffleId, int stageId, in public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { final String TASK_MAX_FAILURE = "spark.task.maxFailures"; int sparkTaskMaxFailures = sparkConf.getInt(TASK_MAX_FAILURE, 4); + // todo: use the extra config to control max stage retried count if (shuffleStatus.getStageRetriedCount() > 1) { LOG.warn("The shuffleId:{}, stageId:{} has been retried. 
Ignore it."); return false; @@ -104,6 +105,21 @@ public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { if (shuffleStatus.getTaskFailureAttemptCount() >= sparkTaskMaxFailures) { return true; } + // for the sort merge join, the same stageId could trigger stage retry + if (shuffleStatus instanceof RssShuffleStatusForReader) { + int stageId = shuffleStatus.getStageId(); + long taskFailureCnt = + shuffleStatusForReader.values().stream() + .filter(x -> x.getStageId() == stageId) + .map(x -> x.getTaskFailureAttemptCount()) + .count(); + if (taskFailureCnt >= sparkTaskMaxFailures) { + LOG.info( + "Multiple same stageIds reader shuffle status's task failure count is greater than the threshold: {}", + sparkTaskMaxFailures); + return true; + } + } return false; } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java index 43cbb85d68..ea3e80fa3a 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java @@ -110,4 +110,12 @@ public void incTaskFailure(int taskAttemptNumber) { public int getTaskFailureAttemptCount() { return withReadLock(() -> taskAttemptFailureRecords.size()); } + + public int getStageId() { + return stageId; + } + + public int getShuffleId() { + return shuffleId; + } } From 2d4e9a0d7a8e29f427239d7207ce3fbb6bca75ae Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 25 Jun 2024 13:57:22 +0800 Subject: [PATCH 13/23] add support of catch the failure result fetch servers --- .../uniffle/shuffle/manager/RssShuffleManagerBase.java | 6 ++++-- .../uniffle/client/impl/ShuffleWriteClientImpl.java | 10 ++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index e0259c8225..7aaefa98cc 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -663,6 +663,7 @@ public boolean reassignOnStageResubmit( Object shuffleLock = rssStageResubmitManager.getOrCreateShuffleLock(shuffleId); synchronized (shuffleLock) { if (!rssStageResubmitManager.isStageAttemptRetried(shuffleId, stageId, stageAttemptNumber)) { + long start = System.currentTimeMillis(); int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); @@ -698,9 +699,10 @@ public boolean reassignOnStageResubmit( (StageAttemptShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo); LOG.info( - "The stage retry has been triggered successfully for the stageId: {}, attemptNumber: {}", + "The stage retry has been triggered successfully for the stageId: {}, attemptNumber: {}. 
It costs {}(ms)", stageId, - stageAttemptNumber); + stageAttemptNumber, + System.currentTimeMillis() - start); return true; } else { LOG.info( diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index b16b88168d..ca40121892 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -797,6 +797,7 @@ public Roaring64NavigableMap getShuffleResult( boolean isSuccessful = false; Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); int successCnt = 0; + Set failureServers = new HashSet<>(); for (ShuffleServerInfo ssi : shuffleServerInfoSet) { try { RssGetShuffleResultResponse response = @@ -812,6 +813,7 @@ public Roaring64NavigableMap getShuffleResult( } } } catch (Exception e) { + failureServers.add(ssi); LOG.warn( "Get shuffle result is failed from " + ssi @@ -824,7 +826,8 @@ public Roaring64NavigableMap getShuffleResult( } if (!isSuccessful) { throw new RssFetchFailedException( - "Get shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]"); + "Get shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]", + failureServers.toArray(new ShuffleServerInfo[0])); } return blockIdBitmap; } @@ -839,6 +842,7 @@ public Roaring64NavigableMap getShuffleResultForMultiPart( PartitionDataReplicaRequirementTracking replicaRequirementTracking) { Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); Set allRequestedPartitionIds = new HashSet<>(); + Set failureServers = new HashSet<>(); for (Map.Entry> entry : serverToPartitions.entrySet()) { ShuffleServerInfo shuffleServerInfo = entry.getKey(); Set requestPartitions = Sets.newHashSet(); @@ -864,6 +868,7 @@ public Roaring64NavigableMap getShuffleResultForMultiPart( } } } catch (Exception e) { + 
failureServers.add(shuffleServerInfo); failedPartitions.addAll(requestPartitions); LOG.warn( "Get shuffle result is failed from " @@ -883,7 +888,8 @@ public Roaring64NavigableMap getShuffleResultForMultiPart( if (!isSuccessful) { LOG.error("Failed to meet replica requirement: {}", replicaRequirementTracking); throw new RssFetchFailedException( - "Get shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]"); + "Get shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]", + failureServers.toArray(new ShuffleServerInfo[0])); } return blockIdBitmap; } From bce959807861167a219cdc0ecb1f597744f79c33 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 25 Jun 2024 15:04:00 +0800 Subject: [PATCH 14/23] remove unnecessary lock --- .../manager/RssShuffleManagerBase.java | 98 +++++++++---------- 1 file changed, 48 insertions(+), 50 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 7aaefa98cc..d34da2a243 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -660,57 +660,55 @@ public void addFaultShuffleServer(String shuffleServerId) { @Override public boolean reassignOnStageResubmit( int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) { - Object shuffleLock = rssStageResubmitManager.getOrCreateShuffleLock(shuffleId); - synchronized (shuffleLock) { - if (!rssStageResubmitManager.isStageAttemptRetried(shuffleId, stageId, stageAttemptNumber)) { - long start = System.currentTimeMillis(); - int requiredShuffleServerNumber = - RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); - int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - - /** 
- * this will clear up the previous stage attempt all data when registering the same - * shuffleId at the second time - */ - Map> partitionToServers = - requestShuffleAssignment( - shuffleId, - numPartitions, - 1, - requiredShuffleServerNumber, - estimateTaskConcurrency, - rssStageResubmitManager.getBlackListedServerIds(), - stageId, - stageAttemptNumber, - false); - /** - * we need to clear the metadata of the completed task, otherwise some of the stage's data - * will be lost - */ - try { - unregisterAllMapOutput(shuffleId); - } catch (SparkException e) { - LOG.error("Clear MapoutTracker Meta failed!"); - throw new RssException("Clear MapoutTracker Meta failed!", e); - } - MutableShuffleHandleInfo shuffleHandleInfo = - new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo()); - StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = - (StageAttemptShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); - stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo); - LOG.info( - "The stage retry has been triggered successfully for the stageId: {}, attemptNumber: {}. 
It costs {}(ms)", - stageId, - stageAttemptNumber, - System.currentTimeMillis() - start); - return true; - } else { - LOG.info( - "Do nothing that the stage: {} has been reassigned for attempt{}", - stageId, - stageAttemptNumber); - return false; + if (!rssStageResubmitManager.isStageAttemptRetried(shuffleId, stageId, stageAttemptNumber)) { + LOG.info("Doing reassign on stage retry for stage:{}, stageAttempt: {}", stageId, stageAttemptNumber); + long start = System.currentTimeMillis(); + int requiredShuffleServerNumber = + RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); + int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); + + /** + * this will clear up the previous stage attempt all data when registering the same + * shuffleId at the second time + */ + Map> partitionToServers = + requestShuffleAssignment( + shuffleId, + numPartitions, + 1, + requiredShuffleServerNumber, + estimateTaskConcurrency, + rssStageResubmitManager.getBlackListedServerIds(), + stageId, + stageAttemptNumber, + false); + /** + * we need to clear the metadata of the completed task, otherwise some of the stage's data + * will be lost + */ + try { + unregisterAllMapOutput(shuffleId); + } catch (SparkException e) { + LOG.error("Clear MapoutTracker Meta failed!"); + throw new RssException("Clear MapoutTracker Meta failed!", e); } + MutableShuffleHandleInfo shuffleHandleInfo = + new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo()); + StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = + (StageAttemptShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); + stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo); + LOG.info( + "The stage retry has been triggered successfully for the stageId: {}, attemptNumber: {}. 
It costs {}(ms)", + stageId, + stageAttemptNumber, + System.currentTimeMillis() - start); + return true; + } else { + LOG.info( + "Do nothing that the stage: {} has been reassigned for attempt{}", + stageId, + stageAttemptNumber); + return false; } } From ab2ab2b92a68ff6ea428c73ec25a88eebfc04385 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 25 Jun 2024 15:25:30 +0800 Subject: [PATCH 15/23] register on reassign to use the next attempt number --- .../apache/uniffle/shuffle/manager/RssShuffleManagerBase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index d34da2a243..b5c492a34a 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -680,7 +680,7 @@ public boolean reassignOnStageResubmit( estimateTaskConcurrency, rssStageResubmitManager.getBlackListedServerIds(), stageId, - stageAttemptNumber, + stageAttemptNumber + 1, false); /** * we need to clear the metadata of the completed task, otherwise some of the stage's data From 07b8790a98eb7b7d540b53881f7d99812c89d3c8 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 25 Jun 2024 18:06:12 +0800 Subject: [PATCH 16/23] send blocks with stage attempt number --- .../apache/spark/shuffle/writer/WriteBufferManager.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index 95add50481..bdd852adfa 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -34,6 +34,7 @@ import com.clearspring.analytics.util.Lists; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; +import org.apache.spark.TaskContext$; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.MemoryMode; @@ -100,6 +101,7 @@ public class WriteBufferManager extends MemoryConsumer { private BlockIdLayout blockIdLayout; private double bufferSpillRatio; private Function> partitionAssignmentRetrieveFunc; + private final int stageAttemptNumber; public WriteBufferManager( int shuffleId, @@ -167,6 +169,7 @@ public WriteBufferManager( this.bufferSpillRatio = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_RATIO); this.blockIdLayout = BlockIdLayout.from(rssConf); this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc; + this.stageAttemptNumber = TaskContext$.MODULE$.get().stageAttemptNumber(); } public WriteBufferManager( @@ -486,7 +489,7 @@ public List buildBlockEvents(List shuffleBlockI + totalSize + " bytes"); } - events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); + events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent)); shuffleBlockInfosPerEvent = Lists.newArrayList(); totalSize = 0; } @@ -501,7 +504,7 @@ public List buildBlockEvents(List shuffleBlockI + " bytes"); } // Use final temporary variables for closures - events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); + events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent)); } return events; } From cbd71ff052b5c006f30c92407db1087b7784bc4f Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 26 Jun 2024 10:01:02 +0800 Subject: [PATCH 17/23] fix some bugs --- .../spark/shuffle/RssSparkShuffleUtils.java | 3 ++ .../shuffle/RssStageResubmitManager.java | 24 +++++--------- 
.../spark/shuffle/stage/RssShuffleStatus.java | 9 ++++++ .../spark/shuffle/writer/AddBlockEvent.java | 3 ++ .../manager/RssShuffleManagerBase.java | 6 +++- .../manager/ShuffleManagerGrpcService.java | 31 +++++++++++-------- client-spark/spark3-shaded/pom.xml | 5 +++ .../client/impl/ShuffleWriteClientImpl.java | 2 +- .../uniffle/server/ShuffleTaskInfo.java | 4 +++ .../uniffle/server/ShuffleTaskManager.java | 3 +- 10 files changed, 58 insertions(+), 32 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index 9e8cfa7011..1dfbc8a6da 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -381,12 +381,15 @@ public static RssException reportRssFetchFailedException( SparkEnv.get().executorId()); RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); if (response.getReSubmitWholeStage()) { + LOG.error("Task:{}-{} is throwing the spark's fetchFailure exception to trigger stage retry as [{}]", taskContext.taskAttemptId(), taskContext.attemptNumber(), response.getMessage()); // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 // is provided. 
FetchFailedException ffe = RssSparkShuffleUtils.createFetchFailedException( shuffleId, -1, partitionId, rssFetchFailedException); return new RssException(ffe); + } else { + LOG.warn("Task:{}-{} haven't receive the shuffle manager's retry signal as [{}]", taskContext.taskAttemptId(), taskContext.attemptNumber(), response.getMessage()); } } } catch (IOException ioe) { diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index 1ae259f672..8b353bff29 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -102,24 +102,16 @@ public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { LOG.warn("The shuffleId:{}, stageId:{} has been retried. Ignore it."); return false; } - if (shuffleStatus.getTaskFailureAttemptCount() >= sparkTaskMaxFailures) { + int maxTaskFailureAttempt = shuffleStatus.getMaxFailureAttemptNumber(); + if (maxTaskFailureAttempt >= sparkTaskMaxFailures - 1) { + LOG.warn("Task failure attempt:{} is the final task attempt: {}", maxTaskFailureAttempt, sparkTaskMaxFailures - 1); return true; } - // for the sort merge join, the same stageId could trigger stage retry - if (shuffleStatus instanceof RssShuffleStatusForReader) { - int stageId = shuffleStatus.getStageId(); - long taskFailureCnt = - shuffleStatusForReader.values().stream() - .filter(x -> x.getStageId() == stageId) - .map(x -> x.getTaskFailureAttemptCount()) - .count(); - if (taskFailureCnt >= sparkTaskMaxFailures) { - LOG.info( - "Multiple same stageIds reader shuffle status's task failure count is greater than the threshold: {}", - sparkTaskMaxFailures); - return true; - } - } +// int taskFailureAttemptCnt = shuffleStatus.getTaskFailureAttemptCount(); +// if (taskFailureAttemptCnt >= sparkTaskMaxFailures) { +// 
LOG.warn("Task failure attempt count:{} reaches the spark's max task failure threshold: {}", taskFailureAttemptCnt, sparkTaskMaxFailures); +// return true; +// } return false; } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java index ea3e80fa3a..43a4837fc6 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/stage/RssShuffleStatus.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.stage; +import java.util.Comparator; import java.util.HashSet; import java.util.Set; import java.util.concurrent.locks.ReentrantReadWriteLock; @@ -111,6 +112,14 @@ public int getTaskFailureAttemptCount() { return withReadLock(() -> taskAttemptFailureRecords.size()); } + public int getMaxFailureAttemptNumber() { + return withReadLock(() -> taskAttemptFailureRecords.stream().max(Comparator.comparing(Integer::intValue)).orElse(0)); + } + + public Set getTaskAttemptFailureRecords() { + return withReadLock(() -> new HashSet<>(taskAttemptFailureRecords)); + } + public int getStageId() { return stageId; } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java index f989fdb0b1..cf0635b318 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.writer; +import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.List; @@ -29,6 +30,8 @@ public class AddBlockEvent { private List shuffleDataInfoList; private List processedCallbackChain; + // only for tests. 
+ @VisibleForTesting public AddBlockEvent(String taskId, List shuffleDataInfoList) { this(taskId, 0, shuffleDataInfoList); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index b5c492a34a..aeada8f31b 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -53,6 +53,7 @@ import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.ShuffleHandleInfo; import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; +import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -652,7 +653,8 @@ public void addFaultShuffleServer(String shuffleServerId) { } /** - * Reassign the ShuffleServer list for ShuffleId + * Reassign the ShuffleServer list for ShuffleId. + * This is not thread safe which should be ensured by the invoking side. 
* * @param shuffleId * @param numPartitions @@ -667,6 +669,8 @@ public boolean reassignOnStageResubmit( RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); + // avoid heartbeat + ((ShuffleWriteClientImpl) shuffleWriteClient).removeShuffleServer(appId, shuffleId); /** * this will clear up the previous stage attempt all data when registering the same * shuffleId at the second time diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index 733d9fcd5c..a9a790df6b 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -148,6 +148,7 @@ public void reportShuffleFetchFailure( String executorId = request.getExecutorId(); List serverIds = request.getFetchFailureServerIdList(); + LOG.info("Accepted reportShuffleFetchFailure. 
stageId: {}, stageAttemptNumber: {}, taskId: {}, taskAttemptNumber: {}", stageId, stageAttempt, taskAttemptId, taskAttemptNumber); RssStageResubmitManager stageResubmitManager = shuffleManager.getStageResubmitManager(); RssProtos.StatusCode code; boolean reSubmitWholeStage; @@ -179,19 +180,20 @@ public void reportShuffleFetchFailure( .map(x -> x.getId()) .forEach(x -> stageResubmitManager.addBlackListedServer(x)); } - if (stageResubmitManager.activateStageRetry(rssShuffleStatus)) { - reSubmitWholeStage = true; - msg = - String.format( - "Activate stage retry for reader on stage(%d:%d), taskFailuresCount:(%d)", - stageId, stageAttempt, rssShuffleStatus.getTaskFailureAttemptCount()); - int partitionNum = shuffleManager.getPartitionNum(shuffleId); - Object shuffleLock = stageResubmitManager.getOrCreateShuffleLock(shuffleId); - synchronized (shuffleLock) { + Object shuffleLock = stageResubmitManager.getOrCreateShuffleLock(shuffleId); + synchronized (shuffleLock) { + if (stageResubmitManager.activateStageRetry(rssShuffleStatus)) { + reSubmitWholeStage = true; + msg = + String.format( + "Make stage retry for reader on stage(%d:%d), taskFailuresCount:(%d)", + stageId, stageAttempt, rssShuffleStatus.getTaskFailureAttemptCount()); + LOG.info(msg); + int partitionNum = shuffleManager.getPartitionNum(shuffleId); if (shuffleManager.reassignOnStageResubmit( stageId, stageAttempt, shuffleId, partitionNum)) { LOG.info( - "{} from executorId({}), task({}:{}) on stageId({}:{}), shuffleId({})", + "Finished reassign on stage retry from executorId({}), task({}-{}) on stageId({}-{}), shuffleId({})", msg, executorId, taskAttemptId, @@ -201,10 +203,13 @@ public void reportShuffleFetchFailure( shuffleId); } rssShuffleStatus.markStageAttemptRetried(); + } else { + reSubmitWholeStage = false; + msg = "Current task failure attempt records: " + + rssShuffleStatus.getTaskAttemptFailureRecords() + + ". 
And attempts count haven't reached the spark max task failure threshold"; + LOG.info(msg); } - } else { - reSubmitWholeStage = false; - msg = "Accepted task fetch failure report"; } } } diff --git a/client-spark/spark3-shaded/pom.xml b/client-spark/spark3-shaded/pom.xml index 7c4580cb5e..f771238c68 100644 --- a/client-spark/spark3-shaded/pom.xml +++ b/client-spark/spark3-shaded/pom.xml @@ -69,6 +69,7 @@ org.roaringbitmap:RoaringBitmap org.roaringbitmap:shims org.apache.commons:commons-collections4 + org.eclipse.jetty:jetty-util ${project.artifactId}-${project.version} @@ -124,6 +125,10 @@ org.roaringbitmap ${rss.shade.packageName}.org.roaringbitmap + + org.eclipse.jetty + ${rss.shade.packageName}.org.eclipse.jetty + diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index ca40121892..28ad208c56 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -1158,7 +1158,7 @@ void addShuffleServer(String appId, int shuffleId, ShuffleServerInfo serverInfo) } @VisibleForTesting - void removeShuffleServer(String appId, int shuffleId) { + public void removeShuffleServer(String appId, int shuffleId) { Map> appServerMap = shuffleServerInfoMap.get(appId); if (appServerMap != null) { appServerMap.remove(shuffleId); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java index e7848963fe..40a7009eed 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java @@ -229,6 +229,10 @@ public long getBlockNumber(int shuffleId, int partitionId) { return counter.get(); } + public void clearBlockNumber(int shuffleId) { + 
partitionBlockCounters.remove(shuffleId); + } + public Integer getLatestStageAttemptNumber(int shuffleId) { return latestStageAttemptNumbers.computeIfAbsent(shuffleId, key -> 0); } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index 8fe597d03a..730a5f66e4 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -755,6 +755,7 @@ public void removeResourcesByShuffleIds(String appId, List shuffleIds) taskInfo.getCachedBlockIds().remove(shuffleId); taskInfo.getCommitCounts().remove(shuffleId); taskInfo.getCommitLocks().remove(shuffleId); + taskInfo.clearBlockNumber(shuffleId); } } Optional.ofNullable(partitionsToBlockIds.get(appId)) @@ -794,7 +795,7 @@ public void removeResources(String appId, boolean checkAppExpired) { Lock lock = getAppWriteLock(appId); lock.lock(); try { - LOG.info("Start remove resource for appId[" + appId + "]"); + LOG.info("Start remove resource for appId[" + appId + "]" ); if (checkAppExpired && !isAppExpired(appId)) { LOG.info( "It seems that this appId[{}] has registered a new shuffle, just ignore this AppPurgeEvent event.", From 4bb973b900b4d3e6e44c8baf9c27e5c8341bc7fe Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 26 Jun 2024 10:12:49 +0800 Subject: [PATCH 18/23] remove outdate handlers --- .../apache/uniffle/common/util/RssUtils.java | 4 ++++ .../server/storage/HadoopStorageManager.java | 5 +++++ .../server/storage/LocalStorageManager.java | 3 +++ .../storage/common/AbstractStorage.java | 20 +++++++++++++++++++ .../uniffle/storage/common/Storage.java | 3 +++ 5 files changed, 35 insertions(+) diff --git a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java index 121fe10de4..b3cf57cd3b 100644 --- 
a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java +++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java @@ -269,6 +269,10 @@ public static Roaring64NavigableMap cloneBitMap(Roaring64NavigableMap bitmap) { return clone; } + public static String generateShuffleKeyWithSplitKey(String appId, int shuffleId) { + return String.join(Constants.KEY_SPLIT_CHAR, appId, String.valueOf(shuffleId), ""); + } + public static String generateShuffleKey(String appId, int shuffleId) { return String.join(Constants.KEY_SPLIT_CHAR, appId, String.valueOf(shuffleId)); } diff --git a/server/src/main/java/org/apache/uniffle/server/storage/HadoopStorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/HadoopStorageManager.java index 7ab9ef6396..c40730138c 100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/HadoopStorageManager.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/HadoopStorageManager.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -28,6 +29,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.uniffle.server.event.ShufflePurgeEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -102,6 +104,9 @@ public void removeResources(PurgeEvent event) { appIdToStorages.remove(appId); purgeForExpired = ((AppPurgeEvent) event).isAppExpired(); } + if (event instanceof ShufflePurgeEvent) { + storage.removeHandlers(appId, new HashSet<>(event.getShuffleIds())); + } ShuffleDeleteHandler deleteHandler = ShuffleHandlerFactory.getInstance() .createShuffleDeleteHandler( diff --git a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java index 7242012315..5272ebcabf 
100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java @@ -272,6 +272,9 @@ public void removeResources(PurgeEvent event) { if (event instanceof AppPurgeEvent) { storage.removeHandlers(appId); } + if (event instanceof ShufflePurgeEvent) { + storage.removeHandlers(appId, new HashSet<>(shuffleSet)); + } for (Integer shuffleId : shuffleSet) { storage.removeResources(RssUtils.generateShuffleKey(appId, shuffleId)); } diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java index f8bf8c9b4b..4b609c5ebe 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; +import java.util.Set; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.RssUtils; import org.apache.uniffle.storage.handler.api.ServerReadHandler; @@ -84,6 +85,25 @@ public void removeHandlers(String appId) { requests.remove(appId); } + @Override + public void removeHandlers(String appId, Set shuffleIds) { + for (int shuffleId : shuffleIds) { + String shuffleKeyPrefix = RssUtils.generateShuffleKeyWithSplitKey(appId, shuffleId); + Map writeHandlers = writerHandlers.get(appId); + if (writeHandlers != null) { + writeHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); + } + Map readHandlers = readerHandlers.get(appId); + if (readHandlers != null) { + readHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); + } + Map requests = this.requests.get(appId); + if (requests != null) { + requests.keySet().stream().filter(x -> 
x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); + } + } + } + @VisibleForTesting public int getHandlerSize() { return writerHandlers.size(); diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/Storage.java b/storage/src/main/java/org/apache/uniffle/storage/common/Storage.java index 43168c324f..62f82298d3 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/Storage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/Storage.java @@ -19,6 +19,7 @@ import java.io.IOException; +import java.util.Set; import org.apache.uniffle.storage.handler.api.ServerReadHandler; import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler; import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest; @@ -39,6 +40,8 @@ ShuffleWriteHandler getOrCreateWriteHandler(CreateShuffleWriteHandlerRequest req void removeHandlers(String appId); + void removeHandlers(String appId, Set shuffleIds); + void createMetadataIfNotExist(String shuffleKey); String getStoragePath(); From 55aebb75505245247e28df5edfbd42c6acece798 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 26 Jun 2024 11:06:42 +0800 Subject: [PATCH 19/23] get shuffle result with stage attempt number --- .../spark/shuffle/RssShuffleManager.java | 3 ++- .../client/api/ShuffleWriteClient.java | 21 +++++++++++++++- .../client/impl/ShuffleWriteClientImpl.java | 5 ++-- ...ssGetShuffleResultForMultiPartRequest.java | 15 ++++++++++- proto/src/main/proto/Rss.proto | 1 + .../server/ShuffleServerGrpcService.java | 25 +++++++++++++++++++ 6 files changed, 65 insertions(+), 5 deletions(-) diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 5d170b4471..924434abeb 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -1053,7 +1053,8 @@ private Roaring64NavigableMap getShuffleResultForMultiPart( appId, shuffleId, failedPartitions, - replicaRequirementTracking); + replicaRequirementTracking, + stageAttemptId); } catch (RssFetchFailedException e) { throw RssSparkShuffleUtils.reportRssFetchFailedException( e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions); diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index db0c914848..5e4f09f3eb 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -162,13 +162,32 @@ Roaring64NavigableMap getShuffleResult( int shuffleId, int partitionId); + default Roaring64NavigableMap getShuffleResultForMultiPart( + String clientType, + Map> serverToPartitions, + String appId, + int shuffleId, + Set failedPartitions, + PartitionDataReplicaRequirementTracking replicaRequirementTracking) { + return getShuffleResultForMultiPart( + clientType, + serverToPartitions, + appId, + shuffleId, + failedPartitions, + replicaRequirementTracking, + 0 + ); + } + Roaring64NavigableMap getShuffleResultForMultiPart( String clientType, Map> serverToPartitions, String appId, int shuffleId, Set failedPartitions, - PartitionDataReplicaRequirementTracking replicaRequirementTracking); + PartitionDataReplicaRequirementTracking replicaRequirementTracking, + int stageAttemptNumber); void close(); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index 28ad208c56..00f41af7ae 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -839,7 +839,8 @@ public Roaring64NavigableMap getShuffleResultForMultiPart( String appId, int shuffleId, Set failedPartitions, - PartitionDataReplicaRequirementTracking replicaRequirementTracking) { + PartitionDataReplicaRequirementTracking replicaRequirementTracking, + int stageAttemptNumber) { Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); Set allRequestedPartitionIds = new HashSet<>(); Set failureServers = new HashSet<>(); @@ -854,7 +855,7 @@ public Roaring64NavigableMap getShuffleResultForMultiPart( allRequestedPartitionIds.addAll(requestPartitions); RssGetShuffleResultForMultiPartRequest request = new RssGetShuffleResultForMultiPartRequest( - appId, shuffleId, requestPartitions, blockIdLayout); + appId, shuffleId, requestPartitions, blockIdLayout, stageAttemptNumber); try { RssGetShuffleResultResponse response = getShuffleServerClient(shuffleServerInfo).getShuffleResultForMultiPart(request); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java index 23c0a6a765..8da08801d1 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java @@ -28,12 +28,20 @@ public class RssGetShuffleResultForMultiPartRequest { private Set partitions; private BlockIdLayout blockIdLayout; + private int stageAttemptNumber; + public RssGetShuffleResultForMultiPartRequest( - String appId, int shuffleId, Set partitions, BlockIdLayout blockIdLayout) { + String appId, int shuffleId, Set partitions, BlockIdLayout blockIdLayout, int stageAttemptNumbers) { this.appId = appId; this.shuffleId = shuffleId; this.partitions = partitions; 
this.blockIdLayout = blockIdLayout; + this.stageAttemptNumber = stageAttemptNumbers; + } + + public RssGetShuffleResultForMultiPartRequest( + String appId, int shuffleId, Set partitions, BlockIdLayout blockIdLayout) { + this(appId, shuffleId, partitions, blockIdLayout, 0); } public String getAppId() { @@ -52,6 +60,10 @@ public BlockIdLayout getBlockIdLayout() { return blockIdLayout; } + public int getStageAttemptNumber() { + return stageAttemptNumber; + } + public RssProtos.GetShuffleResultForMultiPartRequest toProto() { RssGetShuffleResultForMultiPartRequest request = this; RssProtos.GetShuffleResultForMultiPartRequest rpcRequest = @@ -65,6 +77,7 @@ public RssProtos.GetShuffleResultForMultiPartRequest toProto() { .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits) .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits) .build()) + .setStageAttemptNumber(request.getStageAttemptNumber()) .build(); return rpcRequest; } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index ad11ab9442..d5ed558269 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -163,6 +163,7 @@ message GetShuffleResultForMultiPartRequest { int32 shuffleId = 2; repeated int32 partitions = 3; BlockIdLayout blockIdLayout = 4; + int32 stageAttemptNumber = 5; } message GetShuffleResultForMultiPartResponse { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index b6e37029f1..c6650f6031 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -656,6 +656,31 @@ public void getShuffleResultForMultiPart( StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetShuffleResultForMultiPartResponse reply; + + try { + ShuffleTaskInfo taskInfo = 
shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId); + if (taskInfo != null) { + synchronized (taskInfo) { + int latestAttemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId); + if (request.getStageAttemptNumber() != latestAttemptNumber) { + LOG.error("Abort this request with the old stageAttemptNumber:{}. latest: {}", request.getStageAttemptNumber(), latestAttemptNumber); + reply = + GetShuffleResultForMultiPartResponse.newBuilder() + .setStatus(StatusCode.INTERNAL_ERROR.toProto()) + .setRetMsg("Stage retry. Abort this request.") + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + } + } else { + LOG.warn("TaskInfo is null. This should not happen"); + } + } catch (Exception e) { + LOG.info("Errors on getting shuffle result with multi-parts.", e); + } + byte[] serializedBlockIds = null; String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], partitions" + partitionsList; From fee41dd43fcb5336cb0210971ba5c39ff3c9c55d Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 26 Jun 2024 13:38:33 +0800 Subject: [PATCH 20/23] record the cache clear time --- .../org/apache/uniffle/storage/common/AbstractStorage.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java index 4b609c5ebe..692baade24 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java @@ -29,8 +29,11 @@ import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest; import org.apache.uniffle.storage.request.CreateShuffleWriteHandlerRequest; import org.apache.uniffle.storage.util.ShuffleStorageUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public abstract class AbstractStorage implements Storage { + private static final Logger 
LOGGER = LoggerFactory.getLogger(AbstractStorage.class); private Map> writerHandlers = JavaUtils.newConcurrentMap(); @@ -87,6 +90,7 @@ public void removeHandlers(String appId) { @Override public void removeHandlers(String appId, Set shuffleIds) { + long start = System.currentTimeMillis(); for (int shuffleId : shuffleIds) { String shuffleKeyPrefix = RssUtils.generateShuffleKeyWithSplitKey(appId, shuffleId); Map writeHandlers = writerHandlers.get(appId); @@ -102,6 +106,7 @@ public void removeHandlers(String appId, Set shuffleIds) { requests.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); } } + LOGGER.info("Removed the handlers for appId:{}, shuffleId:{} costs {} ms", appId, shuffleIds, System.currentTimeMillis() - start); } @VisibleForTesting From 5c4c9e9f018db67635682f839947b96558230b75 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 26 Jun 2024 14:52:32 +0800 Subject: [PATCH 21/23] avoid adding the stage retry exception server into blacklist --- .../shuffle/RssStageResubmitManager.java | 2 +- .../client/impl/ShuffleWriteClientImpl.java | 5 ++- .../exception/StageRetryAbortException.java | 8 +++++ .../uniffle/test/RSSStageResubmitTest.java | 16 ++++++++-- .../test/SparkIntegrationTestBase.java | 5 +++ .../impl/grpc/ShuffleServerGrpcClient.java | 4 +++ .../server/ShuffleServerGrpcService.java | 5 +-- .../uniffle/server/ShuffleTaskManager.java | 14 +++++++-- .../server/event/ShufflePurgeEvent.java | 16 ++++++++++ .../server/storage/LocalStorageManager.java | 31 ++++++++++++++----- 10 files changed, 88 insertions(+), 18 deletions(-) create mode 100644 common/src/main/java/org/apache/uniffle/common/exception/StageRetryAbortException.java diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java index 8b353bff29..37ac69849a 100644 --- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -99,7 +99,7 @@ public boolean activateStageRetry(RssShuffleStatus shuffleStatus) { int sparkTaskMaxFailures = sparkConf.getInt(TASK_MAX_FAILURE, 4); // todo: use the extra config to control max stage retried count if (shuffleStatus.getStageRetriedCount() > 1) { - LOG.warn("The shuffleId:{}, stageId:{} has been retried. Ignore it."); + LOG.warn("The shuffleId:{}, stageId:{} has been retried. Ignore it.", shuffleStatus.getShuffleId(), shuffleStatus.getStageId()); return false; } int maxTaskFailureAttempt = shuffleStatus.getMaxFailureAttemptNumber(); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index 00f41af7ae..f7ed9e6018 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -41,6 +41,7 @@ import io.grpc.StatusRuntimeException; import org.apache.commons.collections4.CollectionUtils; import org.apache.hadoop.security.UserGroupInformation; +import org.apache.uniffle.common.exception.StageRetryAbortException; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -869,7 +870,9 @@ public Roaring64NavigableMap getShuffleResultForMultiPart( } } } catch (Exception e) { - failureServers.add(shuffleServerInfo); + if (!(e instanceof StageRetryAbortException)) { + failureServers.add(shuffleServerInfo); + } failedPartitions.addAll(requestPartitions); LOG.warn( "Get shuffle result is failed from " diff --git a/common/src/main/java/org/apache/uniffle/common/exception/StageRetryAbortException.java b/common/src/main/java/org/apache/uniffle/common/exception/StageRetryAbortException.java 
new file mode 100644 index 0000000000..ef83568a9e --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/exception/StageRetryAbortException.java @@ -0,0 +1,8 @@ +package org.apache.uniffle.common.exception; + +public class StageRetryAbortException extends RuntimeException { + + public StageRetryAbortException(String message) { + super(message); + } +} diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java index 774ba572bb..47d1352b32 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java @@ -36,6 +36,8 @@ import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER; + public class RSSStageResubmitTest extends SparkTaskFailureIntegrationTestBase { @BeforeAll @@ -51,6 +53,8 @@ public static void setupServers() throws Exception { createCoordinatorServer(coordinatorConf); ShuffleServerConf grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC); createMockedShuffleServer(grpcShuffleServerConf); + grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC); + createMockedShuffleServer(grpcShuffleServerConf); enableFirstReadRequest(2 * maxTaskFailures); ShuffleServerConf nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY); createMockedShuffleServer(nettyShuffleServerConf); @@ -58,9 +62,7 @@ public static void setupServers() throws Exception { } private static void enableFirstReadRequest(int failCount) { - for (ShuffleServer server : grpcShuffleServers) { - ((MockedGrpcServer) server.getServer()).getService().enableFirstNReadRequestToFail(failCount); - } + ((MockedGrpcServer) 
grpcShuffleServers.get(0).getServer()).getService().enableFirstNReadRequestToFail(Integer.MAX_VALUE); } @Override @@ -79,8 +81,16 @@ public void updateSparkConfCustomer(SparkConf sparkConf) { super.updateSparkConfCustomer(sparkConf); sparkConf.set( RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_RESUBMIT_STAGE, "true"); + sparkConf.set( + RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER, "2"); } +// protected void injectBeforeStop(SparkConf sparkConf) throws InterruptedException { +// if (sparkConf.getBoolean("spark.rss.resubmit.stage", false)) { +// Thread.sleep(1000000); +// } +// } + @Test public void testRSSStageResubmit() throws Exception { run(); diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java index 3a04680dac..e9c714e506 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java @@ -100,10 +100,15 @@ protected Map runSparkApp(SparkConf sparkConf, String testFileName) throws Excep } spark = SparkSession.builder().config(sparkConf).getOrCreate(); Map result = runTest(spark, testFileName); + injectBeforeStop(sparkConf); spark.stop(); return result; } + protected void injectBeforeStop(SparkConf conf) throws InterruptedException { + + } + protected SparkConf createSparkConf() { return new SparkConf().setAppName(this.getClass().getSimpleName()).setMaster("local[4]"); } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 9df4070654..c1f74572ef 100644 --- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -31,6 +31,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.UnsafeByteOperations; import io.netty.buffer.Unpooled; +import org.apache.uniffle.common.exception.StageRetryAbortException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -852,6 +853,7 @@ public RssGetShuffleResultResponse getShuffleResultForMultiPart( .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits) .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits) .build()) + .setStageAttemptNumber(request.getStageAttemptNumber()) .build(); GetShuffleResultForMultiPartResponse rpcResponse = getBlockingStub().getShuffleResultForMultiPart(rpcRequest); @@ -868,6 +870,8 @@ public RssGetShuffleResultResponse getShuffleResultForMultiPart( throw new RssException(e); } break; + case STAGE_RETRY_IGNORE: + throw new StageRetryAbortException(rpcResponse.getRetMsg()); default: String msg = "Can't get shuffle result from " diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index c6650f6031..bec94dc42d 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -167,9 +167,10 @@ public void registerShuffle( int attemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId); if (stageAttemptNumber > attemptNumber) { taskInfo.refreshLatestStageAttemptNumber(shuffleId, stageAttemptNumber); + LOG.info("Refreshed the stage attempt number from {} to {}", attemptNumber, stageAttemptNumber); try { long start = System.currentTimeMillis(); - shuffleServer.getShuffleTaskManager().removeShuffleDataSync(appId, shuffleId); + 
shuffleServer.getShuffleTaskManager().removeShuffleForStageRetry(appId, shuffleId, attemptNumber); LOG.info( "Deleted the previous stage attempt data due to stage recomputing for app: {}, " + "shuffleId: {}. It costs {} ms", @@ -666,7 +667,7 @@ public void getShuffleResultForMultiPart( LOG.error("Abort this request with the old stageAttemptNumber:{}. latest: {}", request.getStageAttemptNumber(), latestAttemptNumber); reply = GetShuffleResultForMultiPartResponse.newBuilder() - .setStatus(StatusCode.INTERNAL_ERROR.toProto()) + .setStatus(StatusCode.STAGE_RETRY_IGNORE.toProto()) .setRetMsg("Stage retry. Abort this request.") .build(); responseObserver.onNext(reply); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index 730a5f66e4..47deafcafb 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -733,13 +733,17 @@ private boolean isAppExpired(String appId) { > appExpiredWithoutHB; } + public void removeResourcesByShuffleIds(String appId, List shuffleIds) { + removeResourcesByShuffleIds(appId, shuffleIds, false, 0); + } + /** * Clear up the partial resources of shuffleIds of App. 
* * @param appId * @param shuffleIds */ - public void removeResourcesByShuffleIds(String appId, List shuffleIds) { + public void removeResourcesByShuffleIds(String appId, List shuffleIds, boolean isLazyDeletion, int stageAttemptNumber) { Lock writeLock = getAppWriteLock(appId); writeLock.lock(); try { @@ -768,7 +772,7 @@ public void removeResourcesByShuffleIds(String appId, List shuffleIds) shuffleBufferManager.removeBufferByShuffleId(appId, shuffleIds); shuffleFlushManager.removeResourcesOfShuffleId(appId, shuffleIds); storageManager.removeResources( - new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds)); + new ShufflePurgeEvent(appId, getUserByAppId(appId), shuffleIds, isLazyDeletion, stageAttemptNumber)); LOG.info( "Finish remove resource for appId[{}], shuffleIds[{}], cost[{}]", appId, @@ -795,7 +799,7 @@ public void removeResources(String appId, boolean checkAppExpired) { Lock lock = getAppWriteLock(appId); lock.lock(); try { - LOG.info("Start remove resource for appId[" + appId + "]" ); + LOG.info("Start remove resource for appId[" + appId + "]"); if (checkAppExpired && !isAppExpired(appId)) { LOG.info( "It seems that this appId[{}] has registered a new shuffle, just ignore this AppPurgeEvent event.", @@ -915,6 +919,10 @@ public void removeShuffleDataSync(String appId, int shuffleId) { removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId)); } + public void removeShuffleForStageRetry(String appId, int shuffleId, int stageAttemptNumber) { + removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId), true, stageAttemptNumber); + } + public ShuffleDataDistributionType getDataDistributionType(String appId) { return shuffleTaskInfos.get(appId).getDataDistType(); } diff --git a/server/src/main/java/org/apache/uniffle/server/event/ShufflePurgeEvent.java b/server/src/main/java/org/apache/uniffle/server/event/ShufflePurgeEvent.java index cbc39aab84..79715644cc 100644 --- a/server/src/main/java/org/apache/uniffle/server/event/ShufflePurgeEvent.java 
+++ b/server/src/main/java/org/apache/uniffle/server/event/ShufflePurgeEvent.java @@ -20,8 +20,24 @@ import java.util.List; public class ShufflePurgeEvent extends PurgeEvent { + private boolean lazyDeletion = false; + private int stageAttemptNumber = 0; public ShufflePurgeEvent(String appId, String user, List shuffleIds) { super(appId, user, shuffleIds); } + + public ShufflePurgeEvent(String appId, String user, List shuffleIds, boolean isLazyDelete, int stageAttemptNumber) { + super(appId, user, shuffleIds); + this.lazyDeletion = isLazyDelete; + this.stageAttemptNumber = stageAttemptNumber; + } + + public boolean isLazyDeletion() { + return lazyDeletion; + } + + public int getStageAttemptNumber() { + return stageAttemptNumber; + } } diff --git a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java index 5272ebcabf..1d8915a002 100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java @@ -17,6 +17,7 @@ package org.apache.uniffle.server.storage; +import java.io.File; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -279,14 +280,8 @@ public void removeResources(PurgeEvent event) { storage.removeResources(RssUtils.generateShuffleKey(appId, shuffleId)); } } - // delete shuffle data for application - ShuffleDeleteHandler deleteHandler = - ShuffleHandlerFactory.getInstance() - .createShuffleDeleteHandler( - new CreateShuffleDeleteHandlerRequest( - StorageType.LOCALFILE.name(), new Configuration())); - List deletePaths = + List dataPaths = storageBasePaths.stream() .flatMap( path -> { @@ -305,7 +300,27 @@ public void removeResources(PurgeEvent event) { }) .collect(Collectors.toList()); - deleteHandler.delete(deletePaths.toArray(new String[deletePaths.size()]), appId, user); + if (event instanceof 
ShufflePurgeEvent) { + if (((ShufflePurgeEvent) event).isLazyDeletion()) { + for (String path : dataPaths) { + File file = new File(path); + if (file.exists()) { + // todo: use the RenameHandler to do this. + file.renameTo(new File(file.getAbsolutePath() + "-" + ((ShufflePurgeEvent) event).getStageAttemptNumber())); + } + } + return; + } + } + + // delete shuffle data for application + ShuffleDeleteHandler deleteHandler = + ShuffleHandlerFactory.getInstance() + .createShuffleDeleteHandler( + new CreateShuffleDeleteHandlerRequest( + StorageType.LOCALFILE.name(), new Configuration())); + + deleteHandler.delete(dataPaths.toArray(new String[dataPaths.size()]), appId, user); } private void cleanupStorageSelectionCache(PurgeEvent event) { From 5bf6010a1875ee59d80ba3ae67ecd18ac787b6a6 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 26 Jun 2024 17:39:24 +0800 Subject: [PATCH 22/23] avoid removing cost too much times --- .../storage/common/AbstractStorage.java | 30 +++++++++---------- .../handler/impl/LocalFileWriteHandler.java | 1 + 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java index 692baade24..97006a75b1 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java @@ -91,21 +91,21 @@ public void removeHandlers(String appId) { @Override public void removeHandlers(String appId, Set shuffleIds) { long start = System.currentTimeMillis(); - for (int shuffleId : shuffleIds) { - String shuffleKeyPrefix = RssUtils.generateShuffleKeyWithSplitKey(appId, shuffleId); - Map writeHandlers = writerHandlers.get(appId); - if (writeHandlers != null) { - writeHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); - } - Map readHandlers = 
readerHandlers.get(appId); - if (readHandlers != null) { - readHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); - } - Map requests = this.requests.get(appId); - if (requests != null) { - requests.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); - } - } +// for (int shuffleId : shuffleIds) { +// String shuffleKeyPrefix = RssUtils.generateShuffleKeyWithSplitKey(appId, shuffleId); +// Map writeHandlers = writerHandlers.get(appId); +// if (writeHandlers != null) { +// writeHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); +// } +// Map readHandlers = readerHandlers.get(appId); +// if (readHandlers != null) { +// readHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> readHandlers.remove(x)); +// } +// Map requests = this.requests.get(appId); +// if (requests != null) { +// requests.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> requests.remove(x)); +// } +// } LOGGER.info("Removed the handlers for appId:{}, shuffleId:{} costs {} ms", appId, shuffleIds, System.currentTimeMillis() - start); } diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java index 4b06e5aa91..6eb13640cf 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java @@ -88,6 +88,7 @@ public synchronized void write(List shuffleBlocks) thro File baseFolder = new File(basePath); if (!baseFolder.exists()) { LOG.warn("{} don't exist, the app or shuffle may be deleted", baseFolder.getAbsolutePath()); + createBasePath(); return; } From 73e502023d3d878e3b1a7a5bdc7025828b83b1aa Mon Sep 17 00:00:00 2001 From: 
Junfan Zhang Date: Fri, 28 Jun 2024 15:24:44 +0800 Subject: [PATCH 23/23] draft all --- .../spark/shuffle/RssShuffleManager.java | 15 ++++++++++++- .../client/impl/ShuffleReadClientImpl.java | 3 +++ .../server/ShuffleServerGrpcService.java | 22 +++++++++++++++++++ .../storage/common/AbstractStorage.java | 20 +++++------------ .../impl/ComposedClientReadHandler.java | 16 ++++++++++++++ .../impl/DataSkippableReadHandler.java | 7 ++++++ .../handler/impl/MemoryClientReadHandler.java | 7 ++++++ 7 files changed, 74 insertions(+), 16 deletions(-) diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 924434abeb..bd7b318d52 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -525,11 +525,24 @@ public ShuffleWriter getWriter( shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage()); } String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); +// return new RssShuffleWriter<>( +// rssHandle.getAppId(), +// shuffleId, +// taskId, +// getTaskAttemptIdForBlockId(context.partitionId(), context.attemptNumber()), +// writeMetrics, +// this, +// sparkConf, +// shuffleWriteClient, +// rssHandle, +// this::markFailedTask, +// context, +// shuffleHandleInfo); return new RssShuffleWriter<>( rssHandle.getAppId(), shuffleId, taskId, - getTaskAttemptIdForBlockId(context.partitionId(), context.attemptNumber()), + context.taskAttemptId(), writeMetrics, this, sparkConf, diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java index e1aa0f9582..1fc5eb5229 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java +++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java @@ -27,6 +27,7 @@ import com.google.common.collect.Queues; import com.google.common.collect.Sets; import org.apache.hadoop.conf.Configuration; +import org.apache.uniffle.storage.handler.impl.ComposedClientReadHandler; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -323,6 +324,8 @@ private int read() { @Override public void checkProcessedBlockIds() { + long blockCounter = ((ComposedClientReadHandler)clientReadHandler).getBlockCounter(); + LOG.info("Fetched block counter: {}", blockCounter); RssUtils.checkProcessedBlockIds(blockIdBitmap, processedBlockIds); } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index bec94dc42d..8e2719fe5c 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import com.google.common.collect.Lists; @@ -33,6 +34,7 @@ import io.grpc.stub.StreamObserver; import io.netty.buffer.ByteBuf; import org.apache.commons.lang3.StringUtils; +import org.apache.uniffle.common.util.JavaUtils; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -96,6 +98,9 @@ public class ShuffleServerGrpcService extends ShuffleServerImplBase { private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerGrpcService.class); private final ShuffleServer shuffleServer; + // shuffleId -> partitionId -> blocks + private Map> blockIdCounter = JavaUtils.newConcurrentMap(); + public ShuffleServerGrpcService(ShuffleServer shuffleServer) { 
this.shuffleServer = shuffleServer; } @@ -658,6 +663,17 @@ public void getShuffleResultForMultiPart( String msg = "OK"; GetShuffleResultForMultiPartResponse reply; + try { + if (request.getStageAttemptNumber() == 1) { + for (int pid : partitionsList) { + long blockCnt = blockIdCounter.get(shuffleId).get(pid).get(); + LOG.info("ShuffleId:{}. partitionId:{}. blockCount: {}", shuffleId, pid, blockCnt); + } + } + } catch (Exception e) { + LOG.error("Errors on getting shuffle result. ", e); + } + try { ShuffleTaskInfo taskInfo = shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId); if (taskInfo != null) { @@ -1041,8 +1057,14 @@ public void getMemoryShuffleData( private List toPartitionedData(SendShuffleDataRequest req) { List ret = Lists.newArrayList(); + int shuffleId = req.getShuffleId(); + Map partitionBlockIds = + blockIdCounter.computeIfAbsent(shuffleId, x -> JavaUtils.newConcurrentMap()); for (ShuffleData data : req.getShuffleDataList()) { + if (req.getStageAttemptNumber() == 1) { + partitionBlockIds.computeIfAbsent(data.getPartitionId(), x -> new AtomicInteger()).incrementAndGet(); + } ret.add( new ShufflePartitionedData( data.getPartitionId(), toPartitionedBlock(data.getBlockList()))); diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java index 97006a75b1..69a9dcda8e 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/AbstractStorage.java @@ -55,6 +55,7 @@ public ShuffleWriteHandler getOrCreateWriteHandler(CreateShuffleWriteHandlerRequ Map requestMap = requests.get(request.getAppId()); requestMap.putIfAbsent(partitionKey, request); return map.get(partitionKey); +// return newWriteHandler(request); } @Override @@ -68,6 +69,7 @@ public ServerReadHandler getOrCreateReadHandler(CreateShuffleReadHandlerRequest 
RssUtils.generatePartitionKey(request.getAppId(), request.getShuffleId(), range[0]); map.computeIfAbsent(partitionKey, key -> newReadHandler(request)); return map.get(partitionKey); +// return newReadHandler(request); } protected abstract ServerReadHandler newReadHandler(CreateShuffleReadHandlerRequest request); @@ -91,21 +93,9 @@ public void removeHandlers(String appId) { @Override public void removeHandlers(String appId, Set shuffleIds) { long start = System.currentTimeMillis(); -// for (int shuffleId : shuffleIds) { -// String shuffleKeyPrefix = RssUtils.generateShuffleKeyWithSplitKey(appId, shuffleId); -// Map writeHandlers = writerHandlers.get(appId); -// if (writeHandlers != null) { -// writeHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> writeHandlers.remove(x)); -// } -// Map readHandlers = readerHandlers.get(appId); -// if (readHandlers != null) { -// readHandlers.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> readHandlers.remove(x)); -// } -// Map requests = this.requests.get(appId); -// if (requests != null) { -// requests.keySet().stream().filter(x -> x.startsWith(shuffleKeyPrefix)).forEach(x -> requests.remove(x)); -// } -// } + LOGGER.info("Removing handlers...."); +// writerHandlers.clear(); +// readerHandlers.clear(); LOGGER.info("Removed the handlers for appId:{}, shuffleId:{} costs {} ms", appId, shuffleIds, System.currentTimeMillis() - start); } diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java index affbd477e6..679bcabb18 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java @@ -17,6 +17,7 @@ package org.apache.uniffle.storage.handler.impl; +import java.util.Collections; import 
java.util.EnumMap; import java.util.List; import java.util.Map; @@ -216,4 +217,19 @@ private String getMetricsInfo( sb.append(" ]"); return sb.toString(); } + + public long getBlockCounter() { + long counter = 0; + for (ClientReadHandler readHandler : handlerMap.values()) { + if (readHandler instanceof MemoryClientReadHandler) { + counter += ((MemoryClientReadHandler) readHandler).getBlockCounter(); + LOG.info("Mem: {}", ((MemoryClientReadHandler) readHandler).getBlockCounter()); + } + if (readHandler instanceof DataSkippableReadHandler) { + counter += ((DataSkippableReadHandler) readHandler).getBlockCounter(); + LOG.info("Disk: {}", ((DataSkippableReadHandler) readHandler).getBlockCounter()); + } + } + return counter; + } } diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java index 220e02997a..9eb7164da8 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java @@ -42,6 +42,8 @@ public abstract class DataSkippableReadHandler extends AbstractClientReadHandler protected ShuffleDataDistributionType distributionType; protected Roaring64NavigableMap expectTaskIds; + protected long blockCounter; + public DataSkippableReadHandler( String appId, int shuffleId, @@ -78,6 +80,7 @@ public ShuffleDataResult readShuffleData() { SegmentSplitterFactory.getInstance() .get(distributionType, expectTaskIds, readBufferSize) .split(shuffleIndexResult); + shuffleDataSegments.forEach(x -> blockCounter += x.getBufferSegments().size()); } finally { shuffleIndexResult.release(); } @@ -105,4 +108,8 @@ public ShuffleDataResult readShuffleData() { } return result; } + + public long getBlockCounter() { + return blockCounter; + } } diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java index f1fbe2361c..722fc33880 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java @@ -41,6 +41,8 @@ public class MemoryClientReadHandler extends AbstractClientReadHandler { private int retryMax; private long retryIntervalMax; + private long blockCounter; + public MemoryClientReadHandler( String appId, int shuffleId, @@ -104,6 +106,11 @@ public ShuffleDataResult readShuffleData() { lastBlockId = bufferSegments.get(bufferSegments.size() - 1).getBlockId(); } + blockCounter += result.getBufferSegments().size(); return result; } + + public long getBlockCounter() { + return blockCounter; + } }