Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 81 additions & 120 deletions shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,56 +71,51 @@ bool hasShardedPermutationFactors(
!factorSharding.axisRefs.empty();
});
}
} // namespace

// Returns the common axes per factor if the factor sharding is compatible.
// Otherwise, returns empty AxesPerFactor.
//
// The factor sharding is compatible if it satisfies:
// 1. Factors are sharded the same way across operands and results.
// 2. Factors that need replication are unsharded.
// 3. There is no overlap between the sharding axes across different factors.
//
// Assumes factor shardings do not have overflow axes.
// TODO(enver): Handle the case when some factor shardings have overflow axes.
AxesPerFactor getCompatibleFactorShardings(
const ShardingProjection& shardingProjection,
OpShardingRuleAttr shardingRule) {
AxesPerFactor commonAxesPerFactor(shardingRule.getNumFactors());
BitVector seenFactors(shardingRule.getNumFactors());
SmallVector<AxisRefAttr> seenAxisRefs;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(
shardingProjection.getOperands(), shardingProjection.getResults())) {
// Detects conflicts within the same factor.
for (const auto& [factorIndex, factorSharding] :
tensorFactorSharding.factorIndexToSharding) {
// Factors that need replication should be unsharded across all operands
// and results in order for it to have a compatible sharding.
// Factors that need replication should be unsharded to be compatible.
if (shardingRule.isNeedReplicationFactor(factorIndex)) {
if (!factorSharding.axisRefs.empty()) {
return {};
}
continue;
}
if (!seenFactors.test(factorIndex)) {
if (overlaps(factorSharding.axisRefs, seenAxisRefs)) {
return {};
}
commonAxesPerFactor[factorIndex] = factorSharding.axisRefs;
seenAxisRefs.append(factorSharding.axisRefs);
seenFactors.set(factorIndex);
} else if (factorSharding.axisRefs != commonAxesPerFactor[factorIndex]) {
return {};
}
}
}

// Detect conflict between reduction factors and output shardings.
// TODO(enver): Improve the compile-time performance.
for (const int64_t factorIndex : shardingRule.getReductionFactors()) {
ArrayRef<AxisRefAttr> reductionSharding = commonAxesPerFactor[factorIndex];
for (const TensorFactorShardings& outTensorFactorSharding :
shardingProjection.getResults()) {
for (const auto& [outFactorIndex, outFactorSharding] :
outTensorFactorSharding.factorIndexToSharding) {
if (overlaps(reductionSharding, outFactorSharding.axisRefs)) {
return {};
}
}
}
}
return commonAxesPerFactor;
}

namespace {

void insertExplicitReshardsOnOperand(
Operation* op, const int64_t operandIndex,
const ShardingProjection& shardingProjection,
Expand Down Expand Up @@ -202,8 +197,8 @@ struct FactorAxesPair {
int64_t factorIndex = kEmptyFactorIndex;
AxisListRef axes;

FactorAxesPair(int64_t factorIndex, AxisListRef axes)
: factorIndex(factorIndex), axes(axes) {}
FactorAxesPair(int64_t factorIndex, ArrayRef<AxisRefAttr> axisRefs)
: factorIndex(factorIndex), axes(AxisListRef(axisRefs)) {}

// TODO(enver): Define EmptyFactorAxesPair class with overloaded methods and
// use it when the axes is empty.
Expand Down Expand Up @@ -248,22 +243,15 @@ struct FactorAxesCandidate {
FactorAxesPair factorAxes;
// The total global size of the source tensors.
int64_t totalGlobalSourceTensorSize = 0;
// The size of the local source tensor. In case the factor-axes pair has
// multiple source tensors, the size of the largest local one. A tensor is a
// source for a factor-axes pair if the axes is a prefix of the factor
// sharding on the tensor.
int64_t largestLocalSourceTensorSize = 0;
// The size of axes to shard further. Hence, if the factor is already assigned
// to axes A, and this factor-axes pair has axes B, the size of further
// sharding is size(B)/size(A), and where A is a strict prefix of B.
int64_t shardingSize = 0;
int64_t factorTypePrecedence = 0;

FactorAxesCandidate(FactorAxesPair factorAxes, int64_t sourceTensorSize,
int64_t shardingSize, FactorType factorType)
FactorAxesCandidate(FactorAxesPair factorAxes, int64_t shardingSize,
FactorType factorType)
: factorAxes(factorAxes),
totalGlobalSourceTensorSize(sourceTensorSize),
largestLocalSourceTensorSize(sourceTensorSize),
shardingSize(shardingSize),
factorTypePrecedence(precedence(factorType)) {}

Expand All @@ -272,14 +260,12 @@ struct FactorAxesCandidate {
// Multi-level comparison.
// 1. totalGlobalSourceTensorSize
// 2. factorTypePrecedence
// 3. largestLocalSourceTensorSize
// 4. shardingSize
// 5. factorAxes: If A is a strict prefix of B, then A is smaller than B.
bool operator<(const FactorAxesCandidate& rhs) const {
auto makeComparisonTuple = [](const FactorAxesCandidate& candidate) {
return std::make_tuple(candidate.totalGlobalSourceTensorSize,
candidate.factorTypePrecedence,
candidate.largestLocalSourceTensorSize,
candidate.shardingSize, candidate.factorAxes);
};
return makeComparisonTuple(*this) < makeComparisonTuple(rhs);
Expand All @@ -304,28 +290,6 @@ struct FactorAxesCandidate {
bool empty() const { return factorAxes.empty(); }
};

using FactorAxesCandidatesMap =
DenseMap<FactorAxesPair, FactorAxesCandidate, FactorAxesPairInfo>;

// Increment the count for the factor-axes pair, also modify source tensor size
// to keep the largest.
void updateFactorAxesCandidate(FactorAxesCandidatesMap& factorAxesCandidatesMap,
const FactorAxesPair& factorAxes,
int64_t sourceTensorSize, const Mesh& mesh,
const FactorType factorType) {
if (auto it = factorAxesCandidatesMap.find(factorAxes);
it != factorAxesCandidatesMap.end()) {
FactorAxesCandidate& candidate = it->second;
candidate.totalGlobalSourceTensorSize += sourceTensorSize;
candidate.largestLocalSourceTensorSize =
std::max(candidate.largestLocalSourceTensorSize, sourceTensorSize);
return;
}
factorAxesCandidatesMap.try_emplace(
factorAxes, factorAxes, sourceTensorSize,
factorAxes.axes.getShardingSize(mesh.attr()), factorType);
}

// A container for FactorAxesCandidates where the order of iteration does not
// matter, and provides methods to insert and remove candidates in constant-time
// while maintaining the best candidate.
Expand All @@ -340,56 +304,49 @@ class FactorAxesCandidateBag {
bool empty() const { return candidates.empty(); }

// Inserts a new candidate to the bag. Performs in constant-time.
void insert(const FactorAxesCandidate& candidate) {
candidates.push_back(candidate);
updateBestCandidateIfValid(candidate);
void insert(const FactorAxesPair& factorAxes,
OpShardingRuleAttr shardingRule) {
candidates.emplace_back(factorAxes, factorAxes.axes.getShardingSize(mesh),
shardingRule.getFactorType(factorAxes.factorIndex));
}

// Updates the sharding size of the one at index as the product of the
// sharding sizes of all individual axes excluding the `prefix`, also update
// the best.
// sharding sizes of all individual axes excluding the `prefix`.
//
// Assumes `prefix` is a prefix of the axes of the candidate at index.
void updateShardingSizeAt(const int64_t index,
const AxisListRef& prefix = AxisListRef()) {
FactorAxesCandidate& candidate = candidates[index];
candidate.shardingSize =
candidate.factorAxes.axes.getExpandedShardingSize(mesh, prefix);
updateBestCandidateIfValid(candidate);
}

// Updates the source tensor sizes of all candidates.
// TODO(enver): Optimize updating source tensor sizes.
void updateLocalSourceTensorSizes(
const ShardingProjection& shardingProjection,
ArrayRef<int64_t> tensorSizes,
const SmallVector<AxisListRef>& factorAxisRefs) {
// Since the (local) source tensor sizes get smaller at each iteration on
// which we extend sharding of a factor, in order to recompute largest
// source tensor sizes, we first need to reset them to zero.
}

// TODO(enver): Optimize by grouping candidates on the same factors.
void updateTotalGlobalSourceTensorSizes(
const int64_t sourceFactorIndex,
ArrayRef<AxisRefAttr> sourceFactorAxisRefs,
const int64_t sourceTensorSize) {
AxisListRef sourceFactorAxes(sourceFactorAxisRefs);
for (FactorAxesCandidate& candidate : candidates) {
candidate.largestLocalSourceTensorSize = 0;
FactorAxesPair& factorAxesPair = candidate.factorAxes;
if (factorAxesPair.factorIndex == sourceFactorIndex &&
(sourceFactorAxes == factorAxesPair.axes ||
factorAxesPair.axes.strictPrefixOf(sourceFactorAxes))) {
candidate.totalGlobalSourceTensorSize += sourceTensorSize;
}
}
}

for (const auto& [tensorIndex, tensorFactorSharding] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
shardingProjection.getOperands(),
shardingProjection.getResults()))) {
int64_t localTensorSize = tensorSizes[tensorIndex];
for (const auto& [factorIndex, _] :
tensorFactorSharding.factorIndexToSharding) {
// TODO(enver): Consider cases tensor size may not be divisible.
localTensorSize /= factorAxisRefs[factorIndex].getShardingSize(mesh);
}
for (FactorAxesCandidate& candidate : candidates) {
if (tensorFactorSharding.factorIndexToSharding.contains(
candidate.factorAxes.factorIndex)) {
candidate.largestLocalSourceTensorSize =
std::max(candidate.largestLocalSourceTensorSize, localTensorSize);
updateBestCandidateIfValid(candidate);
}
FactorAxesCandidate getBestCandidate() {
FactorAxesCandidate bestCandidate;
for (FactorAxesCandidate& candidate : candidates) {
// The axes on replication factors are distributed to batching dimensions
// after the common axes are found for all non-replication factors.
if (isValid(candidate)) {
bestCandidate = std::max(bestCandidate, candidate);
}
}
return bestCandidate;
}

void dropFactorDependencies(const int64_t factorIndex) {
Expand All @@ -398,9 +355,6 @@ class FactorAxesCandidateBag {
}
}

// Resets best. Performs in constant-time.
void resetBest() { bestCandidate = FactorAxesCandidate(); }

// Removes candidate at index. Performs in constant-time. After the
// operation, the candidates before the index keep being before the index, and
// the candidates after the index (except the removed one) keep being after
Expand All @@ -414,12 +368,14 @@ class FactorAxesCandidateBag {
candidates.pop_back();
}

// Returns the best. Performs in constant-time.
FactorAxesCandidate best() const { return bestCandidate; }
// Returns the candidate at index. Performs in constant-time.
FactorAxesCandidate& at(const int64_t index) { return candidates[index]; }
// Returns the number of candidates in the bag.
int64_t size() const { return candidates.size(); }
bool isValid(const FactorAxesCandidate& candidate) {
auto it = factorDependenciesMap.find(candidate.factorAxes.factorIndex);
return it == factorDependenciesMap.end() || it->second.none();
}

private:
void initFactorDependencies(OpShardingRuleAttr shardingRule) {
Expand All @@ -440,13 +396,6 @@ class FactorAxesCandidateBag {
}
}

void updateBestCandidateIfValid(const FactorAxesCandidate& candidate) {
auto it = factorDependenciesMap.find(candidate.factorAxes.factorIndex);
if (it == factorDependenciesMap.end() || it->second.none()) {
bestCandidate = std::max(bestCandidate, candidate);
}
}

// A factor is non-full if its sharding size is smaller than the size of the
// factor. `factorDependenciesMap` is a map from factor indices to bitvectors,
// each bitvector is associated with a factor f, and represents the set of
Expand All @@ -461,21 +410,20 @@ class FactorAxesCandidateBag {
// hence it may depend on multiple factors.
llvm::SmallDenseMap<int64_t, BitVector> factorDependenciesMap;
SmallVector<FactorAxesCandidate> candidates;
FactorAxesCandidate bestCandidate;
// Used for recalculating sharding size of a candidate.
MeshAttr mesh;
};

FactorAxesCandidateBag findFactorAxesCandidates(
const ShardingProjection& shardingProjection,
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
const Mesh& mesh) {
MeshAttr mesh) {
// TODO(enver): For two factor-axes pairs, if both have the same factor and
// the same count, and one is the prefix of the other, drop the prefix one.

// Count factor-axes pairs by iterating through each sharding, and for each
// sharding, update candidate for the sharding and all its prefixes.
FactorAxesCandidatesMap factorAxesCandidatesMap;
DenseSet<FactorAxesPair, FactorAxesPairInfo> factorAxesPairs;
for (const auto& [tensorSize, tensorFactorSharding] :
llvm::zip_equal(tensorSizes, llvm::concat<const TensorFactorShardings>(
shardingProjection.getOperands(),
Expand All @@ -487,19 +435,29 @@ FactorAxesCandidateBag findFactorAxesCandidates(
}
ArrayRef<AxisRefAttr> axisRefs = factorSharding.axisRefs;
while (!axisRefs.empty()) {
updateFactorAxesCandidate(
factorAxesCandidatesMap,
FactorAxesPair(factorIndex, AxisListRef(axisRefs)), tensorSize,
mesh, shardingRule.getFactorType(factorIndex));
factorAxesPairs.insert(FactorAxesPair(factorIndex, axisRefs));
axisRefs = axisRefs.drop_back();
}
}
}

FactorAxesCandidateBag factorAxesCandidates(mesh.attr(), shardingRule);
for (const auto& [_, candidate] : factorAxesCandidatesMap) {
factorAxesCandidates.insert(candidate);
FactorAxesCandidateBag factorAxesCandidates(mesh, shardingRule);
for (const FactorAxesPair& factorAxes : factorAxesPairs) {
factorAxesCandidates.insert(factorAxes, shardingRule);
}

// Set total global source tensor sizes of candidates.
for (const auto& [tensorSize, tensorFactorSharding] :
llvm::zip_equal(tensorSizes, llvm::concat<const TensorFactorShardings>(
shardingProjection.getOperands(),
shardingProjection.getResults()))) {
for (const auto& [factorIndex, factorSharding] :
tensorFactorSharding.factorIndexToSharding) {
factorAxesCandidates.updateTotalGlobalSourceTensorSizes(
factorIndex, factorSharding.axisRefs, tensorSize);
}
}

return factorAxesCandidates;
}

Expand All @@ -526,10 +484,10 @@ AxesPerFactor findCommonAxesHeuristic(
const Mesh& mesh) {
SmallVector<AxisListRef> factorAxisRefs(shardingRule.getNumFactors());
FactorAxesCandidateBag factorAxesCandidates = findFactorAxesCandidates(
shardingProjection, shardingRule, tensorSizes, mesh);
while (!factorAxesCandidates.best().empty()) {
FactorAxesPair bestFactorAxes = factorAxesCandidates.best().factorAxes;
factorAxesCandidates.resetBest();
shardingProjection, shardingRule, tensorSizes, mesh.attr());
FactorAxesCandidate bestCandidate = factorAxesCandidates.getBestCandidate();
while (!bestCandidate.empty()) {
FactorAxesPair bestFactorAxes = bestCandidate.factorAxes;
factorAxisRefs[bestFactorAxes.factorIndex] = bestFactorAxes.axes;
if (bestFactorAxes.isFullySharded(shardingRule, mesh.attr())) {
factorAxesCandidates.dropFactorDependencies(bestFactorAxes.factorIndex);
Expand Down Expand Up @@ -592,10 +550,7 @@ AxesPerFactor findCommonAxesHeuristic(
factorAxesCandidates.updateShardingSizeAt(candidateIndex++);
}

// TODO(enver): Optimize updating source tensor sizes.
factorAxesCandidates.resetBest();
factorAxesCandidates.updateLocalSourceTensorSizes(
shardingProjection, tensorSizes, factorAxisRefs);
bestCandidate = factorAxesCandidates.getBestCandidate();
}

// TODO(enver): Consider to keep factorAxisRefs for longer until actual
Expand Down Expand Up @@ -719,6 +674,12 @@ void distributeAxisRefsToBatchingFactors(
AxesPerFactor findCommonAxes(const ShardingProjection& shardingProjection,
OpShardingRuleAttr shardingRule,
ArrayRef<int64_t> tensorSizes, const Mesh& mesh) {
if (AxesPerFactor compatibleFactorShardings =
getCompatibleFactorShardings(shardingProjection, shardingRule);
!compatibleFactorShardings.empty()) {
return compatibleFactorShardings;
}

// Handle the special case of unary operations without factors that need
// replication. Reshard only one of the tensors.
if (shardingRule.getNonScalarTensorIndices().size() == 2 &&
Expand Down
Loading