diff --git a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc
index 8752484e..f28fbc4d 100644
--- a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc
+++ b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc
@@ -71,22 +71,29 @@ bool hasShardedPermutationFactors(
                !factorSharding.axisRefs.empty();
       });
 }
-}  // namespace
 
+// Returns the common axes per factor if the factor sharding is compatible.
+// Otherwise, returns an empty AxesPerFactor.
+//
+// The factor sharding is compatible if it satisfies:
+// 1. Factors are sharded the same way across operands and results.
+// 2. Factors that need replication are unsharded.
+// 3. There is no overlap between the sharding axes across different factors.
+//
+// Assumes factor shardings do not have overflow axes.
 // TODO(enver): Handle the case when some factor shardings have overflow axes.
 AxesPerFactor getCompatibleFactorShardings(
     const ShardingProjection& shardingProjection,
     OpShardingRuleAttr shardingRule) {
   AxesPerFactor commonAxesPerFactor(shardingRule.getNumFactors());
   BitVector seenFactors(shardingRule.getNumFactors());
+  SmallVector<AxisRefAttr> seenAxisRefs;
   for (const TensorFactorShardings& tensorFactorSharding :
        llvm::concat<const TensorFactorShardings>(
            shardingProjection.getOperands(),
            shardingProjection.getResults())) {
-    // Detects conflicts within the same factor.
     for (const auto& [factorIndex, factorSharding] :
          tensorFactorSharding.factorIndexToSharding) {
-      // Factors that need replication should be unsharded across all operands
-      // and results in order for it to have a compatible sharding.
+      // Factors that need replication should be unsharded to be compatible.
       if (shardingRule.isNeedReplicationFactor(factorIndex)) {
         if (!factorSharding.axisRefs.empty()) {
           return {};
@@ -94,7 +101,11 @@ AxesPerFactor getCompatibleFactorShardings(
         continue;
       }
       if (!seenFactors.test(factorIndex)) {
+        if (overlaps(factorSharding.axisRefs, seenAxisRefs)) {
+          return {};
+        }
         commonAxesPerFactor[factorIndex] = factorSharding.axisRefs;
+        seenAxisRefs.append(factorSharding.axisRefs);
         seenFactors.set(factorIndex);
       } else if (factorSharding.axisRefs != commonAxesPerFactor[factorIndex]) {
         return {};
@@ -102,25 +113,9 @@ AxesPerFactor getCompatibleFactorShardings(
     }
   }
 
-  // Detect conflict between reduction factors and output shardings.
-  // TODO(enver): Improve the compile-time performance.
-  for (const int64_t factorIndex : shardingRule.getReductionFactors()) {
-    ArrayRef<AxisRefAttr> reductionSharding = commonAxesPerFactor[factorIndex];
-    for (const TensorFactorShardings& outTensorFactorSharding :
-         shardingProjection.getResults()) {
-      for (const auto& [outFactorIndex, outFactorSharding] :
-           outTensorFactorSharding.factorIndexToSharding) {
-        if (overlaps(reductionSharding, outFactorSharding.axisRefs)) {
-          return {};
-        }
-      }
-    }
-  }
   return commonAxesPerFactor;
 }
 
-namespace {
-
 void insertExplicitReshardsOnOperand(
     Operation* op, const int64_t operandIndex,
     const ShardingProjection& shardingProjection,
@@ -202,8 +197,8 @@ struct FactorAxesPair {
   int64_t factorIndex = kEmptyFactorIndex;
   AxisListRef axes;
 
-  FactorAxesPair(int64_t factorIndex, AxisListRef axes)
-      : factorIndex(factorIndex), axes(axes) {}
+  FactorAxesPair(int64_t factorIndex, ArrayRef<AxisRefAttr> axisRefs)
+      : factorIndex(factorIndex), axes(AxisListRef(axisRefs)) {}
 
   // TODO(enver): Define EmptyFactorAxesPair class with overloaded methods and
   // use it when the axes is empty.
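The rewritten getCompatibleFactorShardings above folds the old reduction-factor conflict loop into a single cross-factor overlap test: every axis already claimed by some factor is accumulated in seenAxisRefs, and any later factor that reuses one of those axes makes the whole sharding incompatible. A minimal standalone sketch of that bookkeeping, with std::string axis names and std::map standing in for AxisRefAttr, AxesPerFactor, and the real overlaps() (which also detects sub-axis overlap); the need-replication check is elided:

#include <map>
#include <string>
#include <vector>

using Axes = std::vector<std::string>;
using AxesPerFactor = std::map<int, Axes>;  // factor index -> sharding axes

// True if `axes` reuses any axis already claimed by another factor. The real
// overlaps() also detects partial (sub-axis) overlap; equality is enough here.
bool overlapsSeen(const Axes& axes, const Axes& seen) {
  for (const std::string& axis : axes)
    for (const std::string& s : seen)
      if (axis == s) return true;
  return false;
}

// Each tensor maps factor indices to the axes that shard that factor.
// An empty result signals an incompatible sharding.
AxesPerFactor getCompatibleFactorShardingsSketch(
    const std::vector<AxesPerFactor>& tensors) {
  AxesPerFactor commonAxesPerFactor;
  Axes seenAxisRefs;
  for (const AxesPerFactor& tensor : tensors) {
    for (const auto& [factorIndex, axes] : tensor) {
      auto it = commonAxesPerFactor.find(factorIndex);
      if (it == commonAxesPerFactor.end()) {
        // First occurrence: the factor must not reuse an axis claimed by a
        // different factor (condition 3 in the comment above).
        if (overlapsSeen(axes, seenAxisRefs)) return {};
        commonAxesPerFactor[factorIndex] = axes;
        seenAxisRefs.insert(seenAxisRefs.end(), axes.begin(), axes.end());
      } else if (it->second != axes) {
        // Sharded differently on another tensor (condition 1).
        return {};
      }
    }
  }
  return commonAxesPerFactor;
}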
@@ -248,22 +243,15 @@ struct FactorAxesCandidate {
   FactorAxesPair factorAxes;
   // The total global size of the source tensors.
   int64_t totalGlobalSourceTensorSize = 0;
-  // The size of the local source tensor. In case the factor-axes pair has
-  // multiple source tensors, the size of the largest local one. A tensor is a
-  // source for a factor-axes pair if the axes is a prefix of the factor
-  // sharding on the tensor.
-  int64_t largestLocalSourceTensorSize = 0;
   // The size of axes to shard further. Hence, if the factor is already assigned
   // to axes A, and this factor-axes pair has axes B, the size of further
   // sharding is size(B)/size(A), where A is a strict prefix of B.
   int64_t shardingSize = 0;
   int64_t factorTypePrecedence = 0;
 
-  FactorAxesCandidate(FactorAxesPair factorAxes, int64_t sourceTensorSize,
-                      int64_t shardingSize, FactorType factorType)
+  FactorAxesCandidate(FactorAxesPair factorAxes, int64_t shardingSize,
+                      FactorType factorType)
       : factorAxes(factorAxes),
-        totalGlobalSourceTensorSize(sourceTensorSize),
-        largestLocalSourceTensorSize(sourceTensorSize),
         shardingSize(shardingSize),
         factorTypePrecedence(precedence(factorType)) {}
 
@@ -272,14 +260,12 @@ struct FactorAxesCandidate {
   // Multi-level comparison.
   // 1. totalGlobalSourceTensorSize
   // 2. factorTypePrecedence
-  // 3. largestLocalSourceTensorSize
-  // 4. shardingSize
-  // 5. factorAxes: If A is a strict prefix of B, then A is smaller than B.
+  // 3. shardingSize
+  // 4. factorAxes: If A is a strict prefix of B, then A is smaller than B.
   bool operator<(const FactorAxesCandidate& rhs) const {
     auto makeComparisonTuple = [](const FactorAxesCandidate& candidate) {
       return std::make_tuple(candidate.totalGlobalSourceTensorSize,
                              candidate.factorTypePrecedence,
-                             candidate.largestLocalSourceTensorSize,
                              candidate.shardingSize, candidate.factorAxes);
     };
     return makeComparisonTuple(*this) < makeComparisonTuple(rhs);
@@ -304,28 +290,6 @@ struct FactorAxesCandidate {
   bool empty() const { return factorAxes.empty(); }
 };
 
-using FactorAxesCandidatesMap =
-    DenseMap<FactorAxesPair, FactorAxesCandidate>;
-
-// Increment the count for the factor-axes pair, also modify source tensor size
-// to keep the largest.
-void updateFactorAxesCandidate(FactorAxesCandidatesMap& factorAxesCandidatesMap,
-                               const FactorAxesPair& factorAxes,
-                               int64_t sourceTensorSize, const Mesh& mesh,
-                               const FactorType factorType) {
-  if (auto it = factorAxesCandidatesMap.find(factorAxes);
-      it != factorAxesCandidatesMap.end()) {
-    FactorAxesCandidate& candidate = it->second;
-    candidate.totalGlobalSourceTensorSize += sourceTensorSize;
-    candidate.largestLocalSourceTensorSize =
-        std::max(candidate.largestLocalSourceTensorSize, sourceTensorSize);
-    return;
-  }
-  factorAxesCandidatesMap.try_emplace(
-      factorAxes, factorAxes, sourceTensorSize,
-      factorAxes.axes.getShardingSize(mesh.attr()), factorType);
-}
-
 // A container for FactorAxesCandidates where the order of iteration does not
-// matter, and provides methods to insert and remove candidates in constant-time
-// while maintaining the best candidate.
+// matter, and provides methods to insert and remove candidates in
+// constant-time.
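The operator< above delegates to std::tuple, whose lexicographic comparison yields the multi-level ordering directly; dropping largestLocalSourceTensorSize from the order is just deleting one tuple element. A self-contained sketch of the pattern, with an int key standing in for the factorAxes tie-breaker (all names here are stand-ins, not the real types):

#include <cassert>
#include <cstdint>
#include <tuple>

// Stand-in for FactorAxesCandidate, keeping only the ordered fields.
struct Candidate {
  int64_t totalGlobalSourceTensorSize = 0;
  int64_t factorTypePrecedence = 0;
  int64_t shardingSize = 0;
  int64_t axesKey = 0;  // models the factorAxes tie-breaker

  bool operator<(const Candidate& rhs) const {
    // std::tuple compares lexicographically, giving the multi-level order
    // for free; removing a level is just deleting one tuple element.
    auto key = [](const Candidate& c) {
      return std::make_tuple(c.totalGlobalSourceTensorSize,
                             c.factorTypePrecedence, c.shardingSize,
                             c.axesKey);
    };
    return key(*this) < key(rhs);
  }
};

int main() {
  Candidate a{/*size=*/100, /*precedence=*/1, /*sharding=*/4};
  Candidate b{/*size=*/100, /*precedence=*/2, /*sharding=*/2};
  assert(a < b);  // Tied on size, so the second level (precedence) decides.
  return 0;
}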
@@ -340,14 +304,14 @@ class FactorAxesCandidateBag {
   bool empty() const { return candidates.empty(); }
 
   // Inserts a new candidate to the bag. Performs in constant-time.
-  void insert(const FactorAxesCandidate& candidate) {
-    candidates.push_back(candidate);
-    updateBestCandidateIfValid(candidate);
+  void insert(const FactorAxesPair& factorAxes,
+              OpShardingRuleAttr shardingRule) {
+    candidates.emplace_back(factorAxes, factorAxes.axes.getShardingSize(mesh),
+                            shardingRule.getFactorType(factorAxes.factorIndex));
   }
 
   // Updates the sharding size of the one at index as the product of the
-  // sharding sizes of all individual axes excluding the `prefix`, also update
-  // the best.
+  // sharding sizes of all individual axes excluding the `prefix`.
   //
   // Assumes `prefix` is a prefix of the axes of the candidate at index.
   void updateShardingSizeAt(const int64_t index,
@@ -355,41 +319,34 @@ class FactorAxesCandidateBag {
     FactorAxesCandidate& candidate = candidates[index];
     candidate.shardingSize =
         candidate.factorAxes.axes.getExpandedShardingSize(mesh, prefix);
-    updateBestCandidateIfValid(candidate);
-  }
-
-  // Updates the source tensor sizes of all candidates.
-  // TODO(enver): Optimize updating source tensor sizes.
-  void updateLocalSourceTensorSizes(
-      const ShardingProjection& shardingProjection,
-      ArrayRef<int64_t> tensorSizes,
-      const SmallVector<AxisListRef>& factorAxisRefs) {
-    // Since the (local) source tensor sizes get smaller at each iteration on
-    // which we extend sharding of a factor, in order to recompute largest
-    // source tensor sizes, we first need to reset them to zero.
+  }
+
+  // TODO(enver): Optimize by grouping candidates on the same factors.
+  void updateTotalGlobalSourceTensorSizes(
+      const int64_t sourceFactorIndex,
+      ArrayRef<AxisRefAttr> sourceFactorAxisRefs,
+      const int64_t sourceTensorSize) {
+    AxisListRef sourceFactorAxes(sourceFactorAxisRefs);
     for (FactorAxesCandidate& candidate : candidates) {
-      candidate.largestLocalSourceTensorSize = 0;
+      FactorAxesPair& factorAxesPair = candidate.factorAxes;
+      if (factorAxesPair.factorIndex == sourceFactorIndex &&
+          (sourceFactorAxes == factorAxesPair.axes ||
+           factorAxesPair.axes.strictPrefixOf(sourceFactorAxes))) {
+        candidate.totalGlobalSourceTensorSize += sourceTensorSize;
+      }
     }
+  }
 
-    for (const auto& [tensorIndex, tensorFactorSharding] :
-         llvm::enumerate(llvm::concat<const TensorFactorShardings>(
-             shardingProjection.getOperands(),
-             shardingProjection.getResults()))) {
-      int64_t localTensorSize = tensorSizes[tensorIndex];
-      for (const auto& [factorIndex, _] :
-           tensorFactorSharding.factorIndexToSharding) {
-        // TODO(enver): Consider cases tensor size may not be divisible.
-        localTensorSize /= factorAxisRefs[factorIndex].getShardingSize(mesh);
-      }
-      for (FactorAxesCandidate& candidate : candidates) {
-        if (tensorFactorSharding.factorIndexToSharding.contains(
-                candidate.factorAxes.factorIndex)) {
-          candidate.largestLocalSourceTensorSize =
-              std::max(candidate.largestLocalSourceTensorSize, localTensorSize);
-          updateBestCandidateIfValid(candidate);
-        }
+  FactorAxesCandidate getBestCandidate() {
+    FactorAxesCandidate bestCandidate;
+    for (FactorAxesCandidate& candidate : candidates) {
+      // The axes on replication factors are distributed to batching dimensions
+      // after the common axes are found for all non-replication factors.
+      if (isValid(candidate)) {
+        bestCandidate = std::max(bestCandidate, candidate);
       }
     }
+    return bestCandidate;
   }
 
   void dropFactorDependencies(const int64_t factorIndex) {
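getBestCandidate above replaces the removed incremental best-tracking (resetBest / updateBestCandidateIfValid) with a linear scan over the bag, filtered by isValid; a default-constructed candidate is empty and compares smallest, so an all-invalid or empty bag naturally yields an empty result and ends the caller's loop. A sketch of that scan under stand-in types, with the dependency check reduced to a plain flag:

#include <algorithm>
#include <cstdint>
#include <vector>

// Stand-in candidate: ordered by a single size field; validity (no unmet
// factor dependencies) is modeled as a flag.
struct Candidate {
  int64_t size = 0;
  bool dependenciesMet = true;
  bool empty() const { return size == 0; }
  bool operator<(const Candidate& rhs) const { return size < rhs.size; }
};

// A default-constructed candidate is empty and minimal, so it is the
// identity for std::max: an all-invalid bag returns an empty candidate.
Candidate getBestCandidate(const std::vector<Candidate>& candidates) {
  Candidate best;
  for (const Candidate& candidate : candidates) {
    if (candidate.dependenciesMet) best = std::max(best, candidate);
  }
  return best;
}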
@@ -398,9 +355,6 @@ class FactorAxesCandidateBag {
     }
   }
 
-  // Resets best. Performs in constant-time.
-  void resetBest() { bestCandidate = FactorAxesCandidate(); }
-
   // Removes candidate at index. Performs in constant-time. After the
   // operation, the candidates before the index keep being before the index,
   // and the candidates after the index (except the removed one) keep being
   // after the index.
@@ -414,12 +368,14 @@ class FactorAxesCandidateBag {
     candidates.pop_back();
   }
 
-  // Returns the best. Performs in constant-time.
-  FactorAxesCandidate best() const { return bestCandidate; }
-
   // Returns the candidate at index. Performs in constant-time.
   FactorAxesCandidate& at(const int64_t index) { return candidates[index]; }
 
   // Returns the number of candidates in the bag.
   int64_t size() const { return candidates.size(); }
 
+  bool isValid(const FactorAxesCandidate& candidate) {
+    auto it = factorDependenciesMap.find(candidate.factorAxes.factorIndex);
+    return it == factorDependenciesMap.end() || it->second.none();
+  }
+
  private:
   void initFactorDependencies(OpShardingRuleAttr shardingRule) {
@@ -440,13 +396,6 @@ class FactorAxesCandidateBag {
     }
   }
 
-  void updateBestCandidateIfValid(const FactorAxesCandidate& candidate) {
-    auto it = factorDependenciesMap.find(candidate.factorAxes.factorIndex);
-    if (it == factorDependenciesMap.end() || it->second.none()) {
-      bestCandidate = std::max(bestCandidate, candidate);
-    }
-  }
-
   // A factor is non-full if its sharding size is smaller than the size of
   // the factor. `factorDependenciesMap` is a map from factor indices to
   // bitvectors; each bitvector is associated with a factor f, and represents
   // the set of non-full factors that f depends on. A factor may appear in
   // multiple dimensions, hence it may depend on multiple factors.
   llvm::SmallDenseMap<int64_t, BitVector> factorDependenciesMap;
   SmallVector<FactorAxesCandidate> candidates;
-  FactorAxesCandidate bestCandidate;
   // Used for recalculating sharding size of a candidate.
   MeshAttr mesh;
 };
@@ -469,13 +417,13 @@ FactorAxesCandidateBag findFactorAxesCandidates(
     const ShardingProjection& shardingProjection,
     OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
-    const Mesh& mesh) {
+    MeshAttr mesh) {
   // TODO(enver): For two factor-axes pairs, if both have the same factor and
   // the same count, and one is the prefix of the other, drop the prefix one.
 
   // Count factor-axes pairs by iterating through each sharding, and for each
   // sharding, update candidate for the sharding and all its prefixes.
-  FactorAxesCandidatesMap factorAxesCandidatesMap;
+  DenseSet<FactorAxesPair> factorAxesPairs;
   for (const auto& [tensorSize, tensorFactorSharding] :
        llvm::zip_equal(tensorSizes,
                        llvm::concat<const TensorFactorShardings>(
                            shardingProjection.getOperands(),
                            shardingProjection.getResults()))) {
@@ -487,19 +435,29 @@ FactorAxesCandidateBag findFactorAxesCandidates(
       }
       ArrayRef<AxisRefAttr> axisRefs = factorSharding.axisRefs;
       while (!axisRefs.empty()) {
-        updateFactorAxesCandidate(
-            factorAxesCandidatesMap,
-            FactorAxesPair(factorIndex, AxisListRef(axisRefs)), tensorSize,
-            mesh, shardingRule.getFactorType(factorIndex));
+        factorAxesPairs.insert(FactorAxesPair(factorIndex, axisRefs));
         axisRefs = axisRefs.drop_back();
       }
     }
   }
 
-  FactorAxesCandidateBag factorAxesCandidates(mesh.attr(), shardingRule);
-  for (const auto& [_, candidate] : factorAxesCandidatesMap) {
-    factorAxesCandidates.insert(candidate);
+  FactorAxesCandidateBag factorAxesCandidates(mesh, shardingRule);
+  for (const FactorAxesPair& factorAxes : factorAxesPairs) {
+    factorAxesCandidates.insert(factorAxes, shardingRule);
+  }
+
+  // Set total global source tensor sizes of candidates.
+  for (const auto& [tensorSize, tensorFactorSharding] :
+       llvm::zip_equal(tensorSizes,
+                       llvm::concat<const TensorFactorShardings>(
+                           shardingProjection.getOperands(),
+                           shardingProjection.getResults()))) {
+    for (const auto& [factorIndex, factorSharding] :
+         tensorFactorSharding.factorIndexToSharding) {
+      factorAxesCandidates.updateTotalGlobalSourceTensorSizes(
+          factorIndex, factorSharding.axisRefs, tensorSize);
+    }
   }
+
   return factorAxesCandidates;
 }
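findFactorAxesCandidates now splits enumeration from scoring: a first pass collects each factor's axis list and all of its non-empty prefixes into a DenseSet of FactorAxesPairs, and a second pass accumulates totalGlobalSourceTensorSize per candidate. A sketch of the prefix enumeration, with string axis names in place of ArrayRef<AxisRefAttr>, std::set in place of DenseSet, and pop_back() modeling ArrayRef::drop_back():

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Every non-empty prefix of a factor's axis list becomes one candidate key,
// so a factor sharded on {"x", "y", "z"} yields {"x"}, {"x","y"}, and
// {"x","y","z"}.
int main() {
  const int64_t factorIndex = 0;
  std::vector<std::string> axisRefs = {"x", "y", "z"};
  std::set<std::pair<int64_t, std::vector<std::string>>> factorAxesPairs;
  while (!axisRefs.empty()) {
    factorAxesPairs.insert({factorIndex, axisRefs});
    axisRefs.pop_back();  // drop_back() equivalent
  }
  std::cout << factorAxesPairs.size() << " candidates\n";  // prints "3 candidates"
  return 0;
}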
@@ -526,10 +484,10 @@ AxesPerFactor findCommonAxesHeuristic(
     const Mesh& mesh) {
   SmallVector<AxisListRef> factorAxisRefs(shardingRule.getNumFactors());
   FactorAxesCandidateBag factorAxesCandidates = findFactorAxesCandidates(
-      shardingProjection, shardingRule, tensorSizes, mesh);
-  while (!factorAxesCandidates.best().empty()) {
-    FactorAxesPair bestFactorAxes = factorAxesCandidates.best().factorAxes;
-    factorAxesCandidates.resetBest();
+      shardingProjection, shardingRule, tensorSizes, mesh.attr());
+  FactorAxesCandidate bestCandidate = factorAxesCandidates.getBestCandidate();
+  while (!bestCandidate.empty()) {
+    FactorAxesPair bestFactorAxes = bestCandidate.factorAxes;
     factorAxisRefs[bestFactorAxes.factorIndex] = bestFactorAxes.axes;
     if (bestFactorAxes.isFullySharded(shardingRule, mesh.attr())) {
       factorAxesCandidates.dropFactorDependencies(bestFactorAxes.factorIndex);
@@ -592,10 +550,7 @@ AxesPerFactor findCommonAxesHeuristic(
       factorAxesCandidates.updateShardingSizeAt(candidateIndex++);
     }
 
-    // TODO(enver): Optimize updating source tensor sizes.
-    factorAxesCandidates.resetBest();
-    factorAxesCandidates.updateLocalSourceTensorSizes(
-        shardingProjection, tensorSizes, factorAxisRefs);
+    bestCandidate = factorAxesCandidates.getBestCandidate();
   }
 
   // TODO(enver): Consider keeping factorAxisRefs for longer until actual
@@ -719,6 +674,12 @@ void distributeAxisRefsToBatchingFactors(
 AxesPerFactor findCommonAxes(const ShardingProjection& shardingProjection,
                              OpShardingRuleAttr shardingRule,
                              ArrayRef<int64_t> tensorSizes, const Mesh& mesh) {
+  if (AxesPerFactor compatibleFactorShardings =
+          getCompatibleFactorShardings(shardingProjection, shardingRule);
+      !compatibleFactorShardings.empty()) {
+    return compatibleFactorShardings;
+  }
+
   // Handle the special case of unary operations without factors that need
   // replication. Reshard only one of the tensors.
   if (shardingRule.getNonScalarTensorIndices().size() == 2 &&
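The last complete hunk makes findCommonAxes try the compatible case before any heuristic work: a non-empty result from getCompatibleFactorShardings means the existing shardings already agree, so no explicit reshard is needed. A sketch of that control flow, with a vector stand-in for AxesPerFactor and hypothetical stubs:

#include <cstdint>
#include <vector>

using AxesPerFactor = std::vector<int64_t>;  // stand-in for the real type

// Hypothetical stubs: an empty result signals "incompatible".
AxesPerFactor getCompatibleFactorShardings(bool compatible) {
  return compatible ? AxesPerFactor{2, 4} : AxesPerFactor{};
}
AxesPerFactor findCommonAxesHeuristic() { return {1, 1}; }

AxesPerFactor findCommonAxes(bool compatible) {
  // C++17 if-with-initializer, as in the hunk above: when the existing
  // shardings are already compatible, return them and skip the heuristic.
  if (AxesPerFactor result = getCompatibleFactorShardings(compatible);
      !result.empty()) {
    return result;
  }
  return findCommonAxesHeuristic();
}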