[#1796] fix(spark): Implicitly unregister map output on fetch failure #1797

Draft: wants to merge 23 commits into master
@@ -20,6 +20,7 @@
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
@@ -33,6 +34,9 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.deploy.SparkHadoopUtil;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
@@ -360,6 +364,7 @@ public static RssException reportRssFetchFailedException(
try (ShuffleManagerClient client =
ShuffleManagerClientFactory.getInstance()
.createShuffleManagerClient(ClientType.GRPC, driver, port)) {
TaskContext taskContext = TaskContext$.MODULE$.get();
// todo: Create a new rpc interface to report failures in batch.
for (int partitionId : failedPartitions) {
RssReportShuffleFetchFailureRequest req =
@@ -368,15 +373,23 @@ public static RssException reportRssFetchFailedException(
shuffleId,
stageAttemptId,
partitionId,
rssFetchFailedException.getMessage());
rssFetchFailedException.getMessage(),
new ArrayList<>(rssFetchFailedException.getFetchFailureServerIds()),
taskContext.stageId(),
taskContext.taskAttemptId(),
taskContext.attemptNumber(),
SparkEnv.get().executorId());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
LOG.error("Task:{}-{} is throwing the spark's fetchFailure exception to trigger stage retry as [{}]", taskContext.taskAttemptId(), taskContext.attemptNumber(), response.getMessage());
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1
// is provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
shuffleId, -1, partitionId, rssFetchFailedException);
return new RssException(ffe);
} else {
LOG.warn("Task:{}-{} haven't receive the shuffle manager's retry signal as [{}]", taskContext.taskAttemptId(), taskContext.attemptNumber(), response.getMessage());
}
}
} catch (IOException ioe) {
@@ -19,51 +19,107 @@

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import com.google.common.collect.Sets;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.stage.RssShuffleStatus;
import org.apache.spark.shuffle.stage.RssShuffleStatusForReader;
import org.apache.spark.shuffle.stage.RssShuffleStatusForWriter;
import org.eclipse.jetty.util.ConcurrentHashSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.util.JavaUtils;

/** Manages the shuffle status used for stage retry. */
public class RssStageResubmitManager {

private static final Logger LOG = LoggerFactory.getLogger(RssStageResubmitManager.class);
/** Blacklist of shuffle servers recorded when writes fail. */
private Set<String> serverIdBlackList;
/**
* Prevent multiple tasks from reporting FetchFailed and triggering multiple ShuffleServer
* reassignments; keyed by (stageId, attemptNumber), the flag records whether a reassignment has already happened.
*/
private Map<Integer, RssStageInfo> serverAssignedInfos;

public RssStageResubmitManager() {
this.serverIdBlackList = Sets.newConcurrentHashSet();
this.serverAssignedInfos = JavaUtils.newConcurrentMap();
private final SparkConf sparkConf;
private final Map<Integer, RssShuffleStatusForReader> shuffleStatusForReader =
new ConcurrentHashMap<>();
private final Map<Integer, RssShuffleStatusForWriter> shuffleStatusForWriter =
new ConcurrentHashMap<>();
private final Map<Integer, Object> shuffleLock = new ConcurrentHashMap<>();

private final Set<String> blackListedServerIds = new ConcurrentHashSet<>();

public RssStageResubmitManager(SparkConf sparkConf) {
this.sparkConf = sparkConf;
}

public Set<String> getServerIdBlackList() {
return serverIdBlackList;
public Object getOrCreateShuffleLock(int shuffleId) {
return shuffleLock.computeIfAbsent(shuffleId, x -> new Object());
}

public void resetServerIdBlackList(Set<String> failuresShuffleServerIds) {
this.serverIdBlackList = failuresShuffleServerIds;
public void clear(int shuffleId) {
shuffleStatusForReader.remove(shuffleId);
shuffleStatusForWriter.remove(shuffleId);
shuffleLock.remove(shuffleId);
}

public void recordFailuresShuffleServer(String shuffleServerId) {
serverIdBlackList.add(shuffleServerId);
public boolean isStageAttemptRetried(int shuffleId, int stageId, int stageAttemptNumber) {
RssShuffleStatus readerShuffleStatus = shuffleStatusForReader.get(shuffleId);
RssShuffleStatus writerShuffleStatus = shuffleStatusForWriter.get(shuffleId);
if (readerShuffleStatus == null && writerShuffleStatus == null) {
return false;
}
if (readerShuffleStatus != null
&& readerShuffleStatus.isStageAttemptRetried(stageAttemptNumber)) {
return true;
}
if (writerShuffleStatus != null
&& writerShuffleStatus.isStageAttemptRetried(stageAttemptNumber)) {
return true;
}
return false;
}
Contributor:

In one case, the reader triggers a retry and that retry is recorded. Later, when the writer fails to write data several times and should trigger its own retry, this method reports that the retry has already been performed.

Member Author:

Yes. A retry for the same stageId and attemptNumber will occur only once; is that incorrect? @yl09099
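
To make the concern above concrete, here is a minimal sketch against the classes added in this PR; the manager's import path and the literal shuffle/stage ids are assumptions for illustration, not part of the diff:

```java
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.RssStageResubmitManager; // package assumed for illustration
import org.apache.spark.shuffle.stage.RssShuffleStatus;

public class ReaderThenWriterRetrySketch {
  public static void main(String[] args) {
    RssStageResubmitManager manager = new RssStageResubmitManager(new SparkConf());

    // A reader-side fetch failure triggers a retry of stage attempt 0, which gets recorded.
    RssShuffleStatus readerStatus =
        manager.getShuffleStatusForReader(1 /* shuffleId */, 5 /* stageId */, 0 /* stageAttempt */);
    readerStatus.markStageAttemptRetried();

    // If the writer later fails on the same stage attempt, the combined check already
    // reports "retried", so a writer-triggered retry would be skipped.
    System.out.println(manager.isStageAttemptRetried(1, 5, 0)); // prints true
  }
}
```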


public RssStageInfo recordAndGetServerAssignedInfo(int shuffleId, String stageIdAndAttempt) {
public RssShuffleStatus getShuffleStatusForReader(int shuffleId, int stageId, int stageAttempt) {
RssShuffleStatus shuffleStatus =
shuffleStatusForReader.computeIfAbsent(
shuffleId, x -> new RssShuffleStatusForReader(stageId, shuffleId));
if (shuffleStatus.updateStageAttemptIfNecessary(stageAttempt)) {
return shuffleStatus;
}
return null;
}

public RssShuffleStatus getShuffleStatusForWriter(int shuffleId, int stageId, int stageAttempt) {
RssShuffleStatus shuffleStatus =
shuffleStatusForWriter.computeIfAbsent(
shuffleId, x -> new RssShuffleStatusForWriter(stageId, shuffleId));
if (shuffleStatus.updateStageAttemptIfNecessary(stageAttempt)) {
return shuffleStatus;
}
return null;
}

public boolean activateStageRetry(RssShuffleStatus shuffleStatus) {
final String TASK_MAX_FAILURE = "spark.task.maxFailures";
int sparkTaskMaxFailures = sparkConf.getInt(TASK_MAX_FAILURE, 4);
// todo: use the extra config to control max stage retried count
if (shuffleStatus.getStageRetriedCount() > 1) {
LOG.warn("The shuffleId:{}, stageId:{} has been retried. Ignore it.", shuffleStatus.getShuffleId(), shuffleStatus.getStageId());
return false;
}
int maxTaskFailureAttempt = shuffleStatus.getMaxFailureAttemptNumber();
if (maxTaskFailureAttempt >= sparkTaskMaxFailures - 1) {
LOG.warn("Task failure attempt:{} is the final task attempt: {}", maxTaskFailureAttempt, sparkTaskMaxFailures - 1);
return true;
}
// int taskFailureAttemptCnt = shuffleStatus.getTaskFailureAttemptCount();
// if (taskFailureAttemptCnt >= sparkTaskMaxFailures) {
// LOG.warn("Task failure attempt count:{} reaches the spark's max task failure threshold: {}", taskFailureAttemptCnt, sparkTaskMaxFailures);
// return true;
// }
return false;
}

return serverAssignedInfos.computeIfAbsent(
shuffleId, id -> new RssStageInfo(stageIdAndAttempt, false));
public Set<String> getBlackListedServerIds() {
return blackListedServerIds.stream().collect(Collectors.toSet());
}

public void recordAndGetServerAssignedInfo(
int shuffleId, String stageIdAndAttempt, boolean isRetried) {
serverAssignedInfos
.computeIfAbsent(shuffleId, id -> new RssStageInfo(stageIdAndAttempt, false))
.setReassigned(isRetried);
public void addBlackListedServer(String shuffleServerId) {
blackListedServerIds.add(shuffleServerId);
}
}
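
As a rough orientation for reviewers, the following is a hedged sketch of how a driver-side handler might drive this manager when a fetch-failure report arrives; the handler itself, its signature, and the manager's import path are assumptions, not part of this diff:

```java
import org.apache.spark.shuffle.RssStageResubmitManager; // package assumed for illustration
import org.apache.spark.shuffle.stage.RssShuffleStatus;

public class FetchFailureHandlingSketch {
  // Hypothetical driver-side handler for a reader-reported fetch failure.
  public static boolean shouldResubmitStage(
      RssStageResubmitManager manager,
      int shuffleId,
      int stageId,
      int stageAttempt,
      int taskAttemptNumber,
      String failedServerId) {
    synchronized (manager.getOrCreateShuffleLock(shuffleId)) {
      RssShuffleStatus status = manager.getShuffleStatusForReader(shuffleId, stageId, stageAttempt);
      if (status == null) {
        // The report belongs to an older stage attempt; ignore it.
        return false;
      }
      status.incTaskFailure(taskAttemptNumber);
      manager.addBlackListedServer(failedServerId);
      if (manager.activateStageRetry(status)) {
        // Record the retry so concurrent reports for the same attempt don't trigger it again.
        status.markStageAttemptRetried();
        return true; // answer the task with reSubmitWholeStage = true
      }
      return false;
    }
  }
}
```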
@@ -18,12 +18,16 @@
package org.apache.spark.shuffle.reader;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Objects;

import scala.Product2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;

import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.shuffle.FetchFailedException;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.slf4j.Logger;
@@ -111,13 +115,19 @@ private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) {
int port = builder.reportServerPort;
// todo: reuse this manager client if this is a bottleneck.
try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) {
TaskContext taskContext = TaskContext$.MODULE$.get();
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
e.getMessage(),
new ArrayList<>(e.getFetchFailureServerIds()),
taskContext.stageId(),
taskContext.taskAttemptId(),
taskContext.attemptNumber(),
SparkEnv.get().executorId());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.stage;

import java.util.Comparator;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;

/**
* Tracks the stage attempt status to decide whether a Spark stage retry should be triggered.
*/
public class RssShuffleStatus {
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock();
private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock();
private final int stageId;
private final int shuffleId;
// the retried stage attempt records
private final Set<Integer> stageAttemptRetriedRecords;

private int stageAttemptNumber;
// the failed task attempt numbers. Attention: these are not task attempt ids!
private Set<Integer> taskAttemptFailureRecords;

public RssShuffleStatus(int stageId, int shuffleId) {
this.shuffleId = shuffleId;
this.stageId = stageId;
this.stageAttemptRetriedRecords = new HashSet<>();
this.taskAttemptFailureRecords = new HashSet<>();
}

private <T> T withReadLock(Supplier<T> fn) {
readLock.lock();
try {
return fn.get();
} finally {
readLock.unlock();
}
}

private <T> T withWriteLock(Supplier<T> fn) {
writeLock.lock();
try {
return fn.get();
} finally {
writeLock.unlock();
}
}

public boolean isStageAttemptRetried(int stageAttempt) {
return withReadLock(() -> stageAttemptRetriedRecords.contains(stageAttempt));
}

public int getStageRetriedCount() {
return withReadLock(() -> this.stageAttemptRetriedRecords.size());
}

public void markStageAttemptRetried() {
withWriteLock(
() -> {
this.stageAttemptRetriedRecords.add(stageAttemptNumber);
return null;
});
}

public int getStageAttempt() {
return withReadLock(() -> this.stageAttemptNumber);
}

public boolean updateStageAttemptIfNecessary(int stageAttempt) {
return withWriteLock(
() -> {
if (this.stageAttemptNumber < stageAttempt) {
// a new stage attempt is issued.
this.stageAttemptNumber = stageAttempt;
this.taskAttemptFailureRecords = new HashSet<>();
return true;
} else if (this.stageAttemptNumber > stageAttempt) {
return false;
}
return true;
});
}

public void incTaskFailure(int taskAttemptNumber) {
withWriteLock(
() -> {
taskAttemptFailureRecords.add(taskAttemptNumber);
return null;
});
}

public int getTaskFailureAttemptCount() {
return withReadLock(() -> taskAttemptFailureRecords.size());
}

public int getMaxFailureAttemptNumber() {
return withReadLock(() -> taskAttemptFailureRecords.stream().max(Comparator.comparing(Integer::intValue)).orElse(0));
}

public Set<Integer> getTaskAttemptFailureRecords() {
return withReadLock(() -> new HashSet<>(taskAttemptFailureRecords));
}

public int getStageId() {
return stageId;
}

public int getShuffleId() {
return shuffleId;
}
}
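
A minimal usage sketch of the state transitions this class implements; the ids and the call order are illustrative only:

```java
import org.apache.spark.shuffle.stage.RssShuffleStatus;

public class RssShuffleStatusSketch {
  public static void main(String[] args) {
    RssShuffleStatus status = new RssShuffleStatus(5 /* stageId */, 1 /* shuffleId */);

    // Stage attempt 0: two task attempts fail.
    status.updateStageAttemptIfNecessary(0);
    status.incTaskFailure(0);
    status.incTaskFailure(1);
    System.out.println(status.getMaxFailureAttemptNumber()); // prints 1

    // A newer stage attempt resets the per-attempt failure records.
    status.updateStageAttemptIfNecessary(1);
    System.out.println(status.getTaskFailureAttemptCount()); // prints 0

    // Reports that still carry the stale attempt 0 are rejected.
    System.out.println(status.updateStageAttemptIfNecessary(0)); // prints false
  }
}
```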
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.stage;

public class RssShuffleStatusForReader extends RssShuffleStatus {
public RssShuffleStatusForReader(int stageId, int shuffleId) {
super(stageId, shuffleId);
}
}