diff --git a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc index 234ea66f..77264ce3 100644 --- a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc +++ b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc @@ -71,22 +71,29 @@ bool hasShardedPermutationFactors( !factorSharding.axisRefs.empty(); }); } -} // namespace +// Returns the common axes per factor if the factor sharding is compatible. +// Otherwise, returns an empty AxesPerFactor. +// +// The factor sharding is compatible if it satisfies: +// 1. Factors are sharded the same way across operands and results. +// 2. Factors that need replication are unsharded. +// 3. There is no overlap between the sharding axes across different factors. +// +// Assumes factor shardings do not have overflow axes. // TODO(enver): Handle the case when some factor shardings have overflow axes. AxesPerFactor getCompatibleFactorShardings( const ShardingProjection& shardingProjection, OpShardingRuleAttr shardingRule) { AxesPerFactor commonAxesPerFactor(shardingRule.getNumFactors()); BitVector seenFactors(shardingRule.getNumFactors()); + SmallVector<AxisRefAttr> seenAxisRefs; for (const TensorFactorShardings& tensorFactorSharding : llvm::concat<const TensorFactorShardings>( shardingProjection.getOperands(), shardingProjection.getResults())) { - // Detects conflicts within the same factor. for (const auto& [factorIndex, factorSharding] : tensorFactorSharding.factorIndexToSharding) { - // Factors that need replication should be unsharded across all operands - // and results in order for it to have a compatible sharding. + // Factors that need replication should be unsharded to be compatible. if (shardingRule.isNeedReplicationFactor(factorIndex)) { if (!factorSharding.axisRefs.empty()) { return {}; @@ -94,7 +101,11 @@ AxesPerFactor getCompatibleFactorShardings( continue; } if (!seenFactors.test(factorIndex)) { + if (overlaps(factorSharding.axisRefs, seenAxisRefs)) { + return {}; + } commonAxesPerFactor[factorIndex] = factorSharding.axisRefs; + seenAxisRefs.append(factorSharding.axisRefs); seenFactors.set(factorIndex); } else if (factorSharding.axisRefs != commonAxesPerFactor[factorIndex]) { return {}; @@ -102,25 +113,9 @@ AxesPerFactor getCompatibleFactorShardings( } } - // Detect conflict between reduction factors and output shardings. - // TODO(enver): Improve the compile-time performance. - for (const int64_t factorIndex : shardingRule.getReductionFactors()) { - ArrayRef<AxisRefAttr> reductionSharding = commonAxesPerFactor[factorIndex]; - for (const TensorFactorShardings& outTensorFactorSharding : - shardingProjection.getResults()) { - for (const auto& [outFactorIndex, outFactorSharding] : - outTensorFactorSharding.factorIndexToSharding) { - if (overlaps(reductionSharding, outFactorSharding.axisRefs)) { - return {}; - } - } - } - } return commonAxesPerFactor; } -namespace { - void insertExplicitReshardsOnOperand( Operation* op, const int64_t operandIndex, const ShardingProjection& shardingProjection, @@ -202,8 +197,8 @@ struct FactorAxesPair { int64_t factorIndex = kEmptyFactorIndex; AxisListRef axes; - FactorAxesPair(int64_t factorIndex, AxisListRef axes) - : factorIndex(factorIndex), axes(axes) {} + FactorAxesPair(int64_t factorIndex, ArrayRef<AxisRefAttr> axisRefs) - : factorIndex(factorIndex), axes(AxisListRef(axisRefs)) {} // TODO(enver): Define EmptyFactorAxesPair class with overloaded methods and // use it when the axes is empty.
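As a quick, concrete reading of the compatibility rule documented in the hunk above, the sketch below restates the three conditions with simplified stand-in types (plain strings for axes, one map per tensor) instead of Shardy's ShardingProjection and AxisRefAttr, and it ignores sub-axis overlap; it is an illustration only, not the implementation in this patch.

#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

using AxisList = std::vector<std::string>;
// Per tensor: factor index -> axes that factor is sharded on.
using TensorFactorMap = std::map<int, AxisList>;

// Returns the common axes per factor if (1) every tensor shards each factor
// the same way, (2) replication factors are unsharded, and (3) no axis is
// used by two different factors; returns std::nullopt otherwise.
std::optional<std::map<int, AxisList>> getCompatibleAxes(
    const std::vector<TensorFactorMap>& tensors,
    const std::set<int>& replicationFactors) {
  std::map<int, AxisList> common;
  std::set<std::string> seenAxes;
  for (const TensorFactorMap& tensor : tensors) {
    for (const auto& [factor, axes] : tensor) {
      if (replicationFactors.count(factor) > 0) {
        if (!axes.empty()) return std::nullopt;  // Violates (2).
        continue;
      }
      auto [it, inserted] = common.try_emplace(factor, axes);
      if (inserted) {
        for (const std::string& axis : axes) {
          if (!seenAxes.insert(axis).second) return std::nullopt;  // (3).
        }
      } else if (it->second != axes) {
        return std::nullopt;  // Violates (1).
      }
    }
  }
  return common;
}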
@@ -223,7 +218,7 @@ struct FactorAxesPair { bool empty() const { return factorIndex == kEmptyFactorIndex; } - bool isFullySharded(OpShardingRuleAttr shardingRule, MeshAttr mesh) { + bool isFullySharded(OpShardingRuleAttr shardingRule, MeshAttr mesh) const { return axes.getShardingSize(mesh) == shardingRule.getFactorSizes()[factorIndex]; } @@ -246,50 +241,50 @@ struct FactorAxesPairInfo : public llvm::DenseMapInfo<FactorAxesPair> { struct FactorAxesCandidate { FactorAxesPair factorAxes; - // The total size of the source tensors. - int64_t totalSourceTensorSize = 0; - // The size of the source tensor. In case the factor-axes pair has multiple - // source tensors, the size of the largest one. A tensor is a source for a - // factor-axes pair if the axes is a prefix of the factor sharding on the - // tensor. - int64_t largestSourceTensorSize = 0; + // The total global size of the source tensors. + int64_t totalGlobalSourceTensorSize = 0; + // The size of the local source tensor. In case the factor-axes pair has + // multiple source tensors, the size of the largest local one. A tensor is a + // source for a factor-axes pair if the axes is a prefix of the factor + // sharding on the tensor. + int64_t largestLocalSourceTensorSize = 0; // The size of axes to shard further. Hence, if the factor is already assigned // to axes A, and this factor-axes pair has axes B, the size of further // sharding is size(B)/size(A), and where A is a strict prefix of B. int64_t shardingSize = 0; int64_t factorTypePrecedence = 0; + int64_t communicationCost = INT64_MAX; - FactorAxesCandidate(FactorAxesPair factorAxes, int64_t sourceTensorSize, - int64_t shardingSize, FactorType factorType) + FactorAxesCandidate(FactorAxesPair factorAxes, int64_t shardingSize, + FactorType factorType) : factorAxes(factorAxes), - totalSourceTensorSize(sourceTensorSize), - largestSourceTensorSize(sourceTensorSize), shardingSize(shardingSize), factorTypePrecedence(precedence(factorType)) {} FactorAxesCandidate() = default; // Multi-level comparison. - // 1. totalSourceTensorSize + // 0. communicationCost + // 1. totalGlobalSourceTensorSize // 2. factorTypePrecedence - // 3. largestSourceTensorSize + // 3. largestLocalSourceTensorSize // 4. shardingSize // 5. factorAxes: If A is a strict prefix of B, then A is smaller than B. bool operator<(const FactorAxesCandidate& rhs) const { auto makeComparisonTuple = [](const FactorAxesCandidate& candidate) { - return std::forward_as_tuple( - candidate.totalSourceTensorSize, candidate.factorTypePrecedence, - candidate.largestSourceTensorSize, candidate.shardingSize, - candidate.factorAxes); + return std::make_tuple(-candidate.communicationCost, + candidate.totalGlobalSourceTensorSize, + candidate.factorTypePrecedence, + candidate.largestLocalSourceTensorSize, + candidate.shardingSize, candidate.factorAxes); }; return makeComparisonTuple(*this) < makeComparisonTuple(rhs); } - // A candidate with a higher precedence will be preferable (given their source + // A candidate with a higher precedence will be preferred (given their source // tensor sizes are the same) to a candidate with a lower precedence when - // finding the best candidate to extend the factor sharding assignment during - // the majority vote heuristic. + // finding the best candidate to extend the factor sharding assignment.
+ int64_t precedence(FactorType factorType) const { switch (factorType) { case FactorType::kPassThrough: return 3; @@ -305,26 +300,157 @@ struct FactorAxesCandidate { bool empty() const { return factorAxes.empty(); } }; -using FactorAxesCandidatesMap = - DenseMap<FactorAxesPair, FactorAxesCandidate, FactorAxesPairInfo>; - -// Increment the count for the factor-axes pair, also modify source tensor size -// to keep the largest. -void updateFactorAxesCandidate(FactorAxesCandidatesMap& factorAxesCandidatesMap, - const FactorAxesPair& factorAxes, - int64_t sourceTensorSize, const Mesh& mesh, - const FactorType factorType) { - if (auto it = factorAxesCandidatesMap.find(factorAxes); - it != factorAxesCandidatesMap.end()) { - FactorAxesCandidate& candidate = it->second; - candidate.totalSourceTensorSize += sourceTensorSize; - candidate.largestSourceTensorSize = - std::max(candidate.largestSourceTensorSize, sourceTensorSize); - return; +int64_t getShardingSize(ArrayRef<AxisRefAttr> axisRefs, MeshAttr mesh) { + int64_t shardingSize = 1; + for (AxisRefAttr axisRef : axisRefs) { + shardingSize *= axisRef.getSize(mesh); } - factorAxesCandidatesMap.try_emplace( - factorAxes, factorAxes, sourceTensorSize, - factorAxes.axes.getShardingSize(mesh.attr()), factorType); + return shardingSize; +} + +std::pair<SmallVector<AxisRefAttr>, SmallVector<AxisRefAttr>> +getShardingAxesInOtherAndThisFactor( + const TensorFactorShardings& tensorFactorSharding, + const int64_t factorIndex) { + SmallVector<AxisRefAttr> axesInOtherFactor; + SmallVector<AxisRefAttr> axesInThisFactor; + for (const auto& [i, factorSharding] : + tensorFactorSharding.factorIndexToSharding) { + if (i == factorIndex) { + axesInThisFactor = factorSharding.axisRefs; + } else { + axesInOtherFactor.append(factorSharding.axisRefs.begin(), + factorSharding.axisRefs.end()); + } + } + return {axesInOtherFactor, axesInThisFactor}; +} + +int64_t getCommunicationCost(const ShardingProjection& shardingProjection, + OpShardingRuleAttr shardingRule, + ArrayRef<int64_t> tensorSizes, + ArrayRef<int64_t> localTensorSizes, MeshAttr mesh, + const FactorAxesPair& factorAxesPair, + const int64_t expandedShardingSize) { + // The relative costs of collective operations. + constexpr int64_t allToAllCost = 1; + constexpr int64_t collectivePermuteCost = 2; + constexpr int64_t allGatherCost = 4; + constexpr int64_t reduceScatterCost = 4; + constexpr int64_t allReduceCost = 8; + + int64_t communicationCost = 0; + + // For each tensor (operand or result), we use the following notation. + // + // `factorAxesPair` is the candidate factor-axes pair. + // * X = factorAxesPair.axes. + // * A = sharding axes in other factors in the original sharding. + // * B = sharding axes in this factor in the original sharding. + // * AX = the intersection (overlap) of A and X. + // * B-X = the difference of B and X. + + SmallVector<AxisRefAttr> axesX = factorAxesPair.axes.toVector(); + int64_t axesXSize = factorAxesPair.axes.getShardingSize(mesh); + + // For each operand, estimate the cost of resharding from the original + // sharding to the candidate sharding axes. + // + // If the operand does not contain this factor, we need an all-gather on AX. + // + // If the operand contains this factor, we need + // 1. all-to-all to move AX from other factors to this factor. + // 2. collective-permute to handle B-X. + // 3. all-gather to shrink the sharding size if needed.
+ for (const auto& [tensorSize, tensorFactorSharding] : llvm::zip_equal( + tensorSizes.drop_back(shardingProjection.getNumResults()), + shardingProjection.getOperands())) { + bool operandContainsFactor = + tensorFactorSharding.factorIndexToSharding.contains( + factorAxesPair.factorIndex); + int64_t shardedTensorSize = + tensorSize / tensorFactorSharding.getShardingSize(mesh); + auto [axesA, axesB] = getShardingAxesInOtherAndThisFactor( + tensorFactorSharding, factorAxesPair.factorIndex); + + SmallVector<AxisRefAttr> diffXA = getAxisSetDiff(axesX, axesA, mesh); + int64_t diffXASize = getShardingSize(diffXA, mesh); + + if (axesXSize > diffXASize) { + // all-to-all on AX. + communicationCost += + (operandContainsFactor ? allToAllCost : allGatherCost) * + shardedTensorSize; + } + + if (operandContainsFactor) { + if (!getAxisSetDiff(axesB, axesX, mesh).empty()) { + communicationCost += collectivePermuteCost * shardedTensorSize; + } + if (getShardingSize(axesB, mesh) > diffXASize) { + // The operand is more sharded than the candidate. We need an all-gather + // to shrink the sharding size. + communicationCost += allGatherCost * shardedTensorSize; + } + } + } + + // For each result, estimate the cost of resharding from the candidate + // sharding axes to the original sharding. + // + // We use the same notation as above. + // + // If the candidate factor is a reduction factor, we need all-reduce or + // reduce-scatter on the result. + // + // If the result does not contain this factor, there is no additional cost. + // + // If the result contains this factor, we need + // 1. all-to-all to move AX from this factor to other factors. + // 2. all-gather to shrink the sharding size after the all-to-all above. + for (const auto& [tensorSize, localTensorSize, tensorFactorSharding] : + llvm::zip_equal( + tensorSizes.drop_front(shardingProjection.getNumOperands()), + localTensorSizes.drop_front(shardingProjection.getNumOperands()), + shardingProjection.getResults())) { + // A candidate factor-axes pair (`factorAxesPair`) is guaranteed to be an + // expansion of its existing sharding, and `localTensorSize` has already + // taken the existing sharding into account. To avoid double counting, we + // only shard further by the expanded sharding size. + auto [axesA, axesB] = getShardingAxesInOtherAndThisFactor( + tensorFactorSharding, factorAxesPair.factorIndex); + + SmallVector<AxisRefAttr> diffXA = getAxisSetDiff(axesX, axesA, mesh); + int64_t diffXASize = getShardingSize(diffXA, mesh); + + if (shardingRule.isReductionFactor(factorAxesPair.factorIndex)) { + communicationCost += + (diffXASize > 1 ? allReduceCost : reduceScatterCost) * + (localTensorSize / expandedShardingSize); + } + + int64_t shardedTensorSize = + tensorSize / tensorFactorSharding.getShardingSize(mesh); + if (!tensorFactorSharding.factorIndexToSharding.contains( + factorAxesPair.factorIndex)) { + continue; + } + if (axesXSize > diffXASize) { + // all-to-all on AX. + communicationCost += allToAllCost * shardedTensorSize; + } + + if (!getAxisSetDiff(axesB, axesX, mesh).empty()) { + communicationCost += collectivePermuteCost * shardedTensorSize; + } + if (getShardingSize(axesB, mesh) < diffXASize) { + // The result is less sharded than the candidate. We need an all-gather + // to shrink the sharding size.
+ communicationCost += allGatherCost * shardedTensorSize; + } + } + + return communicationCost; } // A container for FactorAxesCandidates where the order of iteration does not @@ -341,14 +467,14 @@ class FactorAxesCandidateBag { bool empty() const { return candidates.empty(); } // Inserts a new candidate to the bag. Performs in constant-time. - void insert(const FactorAxesCandidate& candidate) { - candidates.push_back(candidate); - updateBestCandidateIfValid(candidate); + void insert(const FactorAxesPair& factorAxes, + OpShardingRuleAttr shardingRule) { + candidates.emplace_back(factorAxes, factorAxes.axes.getShardingSize(mesh), + shardingRule.getFactorType(factorAxes.factorIndex)); } // Updates the sharding size of the one at index as the product of the - // sharding sizes of all individual axes excluding the `prefix`, also update - // the best. + // sharding sizes of all individual axes excluding the `prefix`. // // Assumes `prefix` is a prefix of the axes of the candidate at index. void updateShardingSizeAt(const int64_t index, @@ -356,23 +482,44 @@ class FactorAxesCandidateBag { FactorAxesCandidate& candidate = candidates[index]; candidate.shardingSize = candidate.factorAxes.axes.getExpandedShardingSize(mesh, prefix); - updateBestCandidateIfValid(candidate); } - // Updates the source tensor sizes of all candidates. - // TODO(enver): Optimize updating source tensor sizes. - void updateSourceTensorSizes(const ShardingProjection& shardingProjection, - ArrayRef<int64_t> tensorSizes, - const SmallVector<AxisListRef>& factorAxisRefs) { + // TODO(enver): Optimize by grouping candidates on the same factors. + void updateTotalGlobalSourceTensorSizes( + const int64_t sourceFactorIndex, + ArrayRef<AxisRefAttr> sourceFactorAxisRefs, + const int64_t sourceTensorSize) { + AxisListRef sourceFactorAxes(sourceFactorAxisRefs); + for (FactorAxesCandidate& candidate : candidates) { + FactorAxesPair& factorAxesPair = candidate.factorAxes; + if (factorAxesPair.factorIndex == sourceFactorIndex && + (sourceFactorAxes == factorAxesPair.axes || + factorAxesPair.axes.strictPrefixOf(sourceFactorAxes))) { + candidate.totalGlobalSourceTensorSize += sourceTensorSize; + } + } + } + + // Updates the largest local source tensor sizes and communication costs of + // all candidates and returns the new best. + // TODO(enver): Optimize updating communication costs. + FactorAxesCandidate updateCommunicationCostsAndGetBest( + const ShardingProjection& shardingProjection, + ArrayRef<int64_t> tensorSizes, + const SmallVector<AxisListRef>& factorAxisRefs, + OpShardingRuleAttr shardingRule) { // Since the (local) source tensor sizes get smaller at each iteration on // which we extend sharding of a factor, in order to recompute largest // source tensor sizes, we first need to reset them to zero. - resetLargestSourceTensorSizes(); + for (FactorAxesCandidate& candidate : candidates) { + candidate.largestLocalSourceTensorSize = 0; + } + SmallVector<int64_t> localTensorSizes = llvm::to_vector(tensorSizes); for (const auto& [tensorIndex, tensorFactorSharding] : llvm::enumerate(llvm::concat<const TensorFactorShardings>( shardingProjection.getOperands(), shardingProjection.getResults()))) { - int64_t localTensorSize = tensorSizes[tensorIndex]; + int64_t& localTensorSize = localTensorSizes[tensorIndex]; for (const auto& [factorIndex, _] : tensorFactorSharding.factorIndexToSharding) { // TODO(enver): Consider cases tensor size may not be divisible.
@@ -381,12 +528,29 @@ class FactorAxesCandidateBag { for (FactorAxesCandidate& candidate : candidates) { if (tensorFactorSharding.factorIndexToSharding.contains( candidate.factorAxes.factorIndex)) { - candidate.largestSourceTensorSize = - std::max(candidate.largestSourceTensorSize, localTensorSize); - updateBestCandidateIfValid(candidate); + candidate.largestLocalSourceTensorSize = + std::max(candidate.largestLocalSourceTensorSize, localTensorSize); } } } + + FactorAxesCandidate bestCandidate; + for (FactorAxesCandidate& candidate : candidates) { + // NOTE: The axes on replication factors are distributed to batching + // dimensions after the common axes are found for all non-replication + // factors. The communication cost calculation does not take this into + // account yet and hence is not ready for cases that sharding rule has + // replication factors. + if (shardingRule.getNeedReplicationFactors().empty()) { + candidate.communicationCost = getCommunicationCost( + shardingProjection, shardingRule, tensorSizes, localTensorSizes, + mesh, candidate.factorAxes, candidate.shardingSize); + } + if (isValid(candidate)) { + bestCandidate = std::max(bestCandidate, candidate); + } + } + return bestCandidate; } void dropFactorDependencies(const int64_t factorIndex) { @@ -395,9 +559,6 @@ class FactorAxesCandidateBag { } } - // Resets best. Performs in constant-time. - void resetBest() { bestCandidate = FactorAxesCandidate(); } - // Removes candidate at index. Performs in constant-time. After the // operation, the candidates before the index keep being before the index, and // the candidates after the index (except the removed one) keep being after @@ -411,20 +572,16 @@ class FactorAxesCandidateBag { candidates.pop_back(); } - // Returns the best. Performs in constant-time. - FactorAxesCandidate best() const { return bestCandidate; } // Returns the candidate at index. Performs in constant-time. FactorAxesCandidate& at(const int64_t index) { return candidates[index]; } // Returns the number of candidates in the bag. int64_t size() const { return candidates.size(); } - - private: - void resetLargestSourceTensorSizes() { - for (FactorAxesCandidate& candidate : candidates) { - candidate.largestSourceTensorSize = 0; - } + bool isValid(const FactorAxesCandidate& candidate) { + auto it = factorDependenciesMap.find(candidate.factorAxes.factorIndex); + return it == factorDependenciesMap.end() || it->second.none(); } + private: void initFactorDependencies(OpShardingRuleAttr shardingRule) { for (const TensorMappingAttr& tensorMapping : llvm::concat( @@ -443,17 +600,6 @@ class FactorAxesCandidateBag { } } - void updateBestCandidateIfValid(const FactorAxesCandidate& candidate) { - if (isValid(candidate)) { - bestCandidate = std::max(bestCandidate, candidate); - } - } - - bool isValid(const FactorAxesCandidate& candidate) { - auto it = factorDependenciesMap.find(candidate.factorAxes.factorIndex); - return it == factorDependenciesMap.end() || it->second.none(); - } - // A factor is non-full if its sharding size is smaller than the size of the // factor. `factorDependenciesMap` is a map from factor indices to bitvectors, // each bitvector is associated with a factor f, and represents the set of @@ -468,7 +614,6 @@ class FactorAxesCandidateBag { // hence it may depend on multiple factors. llvm::SmallDenseMap factorDependenciesMap; SmallVector candidates; - FactorAxesCandidate bestCandidate; // Used for recalculating sharding size of a candidate. 
MeshAttr mesh; }; @@ -476,13 +621,13 @@ class FactorAxesCandidateBag { FactorAxesCandidateBag findFactorAxesCandidates( const ShardingProjection& shardingProjection, OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes, - const Mesh& mesh) { + MeshAttr mesh) { // TODO(enver): For two factor-axes pairs, if both have the same factor and // the same count, and one is the prefix of the other, drop the prefix one. // Count factor-axes pairs by iterating through each sharding, and for each // sharding, update candidate for the sharding and all its prefixes. - FactorAxesCandidatesMap factorAxesCandidatesMap; + DenseSet<FactorAxesPair, FactorAxesPairInfo> factorAxesPairs; for (const auto& [tensorSize, tensorFactorSharding] : llvm::zip_equal(tensorSizes, llvm::concat<const TensorFactorShardings>( shardingProjection.getOperands(), shardingProjection.getResults()))) { @@ -494,19 +639,29 @@ FactorAxesCandidateBag findFactorAxesCandidates( } ArrayRef<AxisRefAttr> axisRefs = factorSharding.axisRefs; while (!axisRefs.empty()) { - updateFactorAxesCandidate( - factorAxesCandidatesMap, - FactorAxesPair(factorIndex, AxisListRef(axisRefs)), tensorSize, - mesh, shardingRule.getFactorType(factorIndex)); + factorAxesPairs.insert(FactorAxesPair(factorIndex, axisRefs)); axisRefs = axisRefs.drop_back(); } } } - FactorAxesCandidateBag factorAxesCandidates(mesh.attr(), shardingRule); - for (const auto& [_, candidate] : factorAxesCandidatesMap) { - factorAxesCandidates.insert(candidate); + FactorAxesCandidateBag factorAxesCandidates(mesh, shardingRule); + for (const FactorAxesPair& factorAxes : factorAxesPairs) { + factorAxesCandidates.insert(factorAxes, shardingRule); + } + + // Set total global source tensor sizes of candidates. + for (const auto& [tensorSize, tensorFactorSharding] : + llvm::zip_equal(tensorSizes, llvm::concat<const TensorFactorShardings>( + shardingProjection.getOperands(), + shardingProjection.getResults()))) { + for (const auto& [factorIndex, factorSharding] : + tensorFactorSharding.factorIndexToSharding) { + factorAxesCandidates.updateTotalGlobalSourceTensorSizes( + factorIndex, factorSharding.axisRefs, tensorSize); + } } + return factorAxesCandidates; } @@ -527,18 +682,18 @@ AxesPerFactor toAxesPerFactor(const SmallVector<AxisListRef>& factorAxisRefs) { // until the list is empty. // // Guarantees to return a non-empty AxesPerFactor. -AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic( +AxesPerFactor findCommonAxesHeuristic( const ShardingProjection& shardingProjection, OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes, const Mesh& mesh) { SmallVector<AxisListRef> factorAxisRefs(shardingRule.getNumFactors()); FactorAxesCandidateBag factorAxesCandidates = findFactorAxesCandidates( - shardingProjection, shardingRule, tensorSizes, mesh); - // TODO(enver): Assign an axis to a factor immediately if the count is more - // than floor(n/2) where n is the number of tensors.
- while (!factorAxesCandidates.best().empty()) { - FactorAxesPair bestFactorAxes = factorAxesCandidates.best().factorAxes; - factorAxesCandidates.resetBest(); + shardingProjection, shardingRule, tensorSizes, mesh.attr()); + FactorAxesCandidate bestCandidate = + factorAxesCandidates.updateCommunicationCostsAndGetBest( + shardingProjection, tensorSizes, factorAxisRefs, shardingRule); + while (!bestCandidate.empty()) { + FactorAxesPair bestFactorAxes = bestCandidate.factorAxes; factorAxisRefs[bestFactorAxes.factorIndex] = bestFactorAxes.axes; if (bestFactorAxes.isFullySharded(shardingRule, mesh.attr())) { factorAxesCandidates.dropFactorDependencies(bestFactorAxes.factorIndex); @@ -601,10 +756,8 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic( factorAxesCandidates.updateShardingSizeAt(candidateIndex++); } - // TODO(enver): Optimize updating source tensor sizes. - factorAxesCandidates.resetBest(); - factorAxesCandidates.updateSourceTensorSizes(shardingProjection, - tensorSizes, factorAxisRefs); + bestCandidate = factorAxesCandidates.updateCommunicationCostsAndGetBest( + shardingProjection, tensorSizes, factorAxisRefs, shardingRule); } // TODO(enver): Consider to keep factorAxisRefs for longer until actual @@ -728,6 +881,12 @@ void distributeAxisRefsToBatchingFactors( AxesPerFactor findCommonAxes(const ShardingProjection& shardingProjection, OpShardingRuleAttr shardingRule, ArrayRef tensorSizes, const Mesh& mesh) { + if (AxesPerFactor compatibleFactorShardings = + getCompatibleFactorShardings(shardingProjection, shardingRule); + !compatibleFactorShardings.empty()) { + return compatibleFactorShardings; + } + // Handle the special case of unary operations without factors that need // replication. Reshard only one of the tensors. if (shardingRule.getNonScalarTensorIndices().size() == 2 && @@ -737,7 +896,7 @@ AxesPerFactor findCommonAxes(const ShardingProjection& shardingProjection, tensorSizes, mesh); } - AxesPerFactor factorCommonAxes = findCommonAxesUsingMajorityVoteHeuristic( + AxesPerFactor factorCommonAxes = findCommonAxesHeuristic( shardingProjection, shardingRule, tensorSizes, mesh); // Distribute the greatest common prefix of shardings of factors that need diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/concatenate.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/concatenate.mlir index f1123a1c..d814b16f 100644 --- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/concatenate.mlir +++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/concatenate.mlir @@ -64,10 +64,11 @@ func.func @concatenate_operands_are_from_slices_of_the_same_tensor(%arg0: tensor func.func @concatenate_operands_are_results_of_slices_different_shardings_on_permutation_dim_with_equal_counts(%arg0: tensor<4x40x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}, {}]>}, %arg1: tensor<4x60x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) -> (tensor<4x80x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}]>}) { %0 = stablehlo.slice %arg0 [0:4, 0:32, 0:256] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}, {}]>]>} : (tensor<4x40x256xf32>) -> tensor<4x32x256xf32> %1 = stablehlo.slice %arg1 [0:4, 0:48, 0:256] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x60x256xf32>) -> tensor<4x48x256xf32> - // CHECK: %[[RESHARD1:.*]] = sdy.reshard %0 <@mesh, [{}, {"x"}, {}]> : tensor<4x32x256xf32> - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %1 <@mesh, [{}, {"x"}, 
{}]> : tensor<4x48x256xf32> - // CHECK-NEXT: %[[CONCATENATE:.*]] = stablehlo.concatenate %[[RESHARD1]], %[[RESHARD2]], dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32> - // CHECK-NEXT: return %[[CONCATENATE]] : tensor<4x80x256xf32> + // CHECK: %[[RESHARD0:.*]] = sdy.reshard %0 <@mesh, [{"x"}, {"y"}, {}]> : tensor<4x32x256xf32> + // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %1 <@mesh, [{"x"}, {"y"}, {}]> : tensor<4x48x256xf32> + // CHECK-NEXT: %[[CONCATENATE:.*]] = stablehlo.concatenate %[[RESHARD0]], %[[RESHARD1]], dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32> + // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[CONCATENATE]] <@mesh, [{}, {"x"}, {}]> : tensor<4x80x256xf32> + // CHECK-NEXT: return %[[RESHARD2]] : tensor<4x80x256xf32> %2 = stablehlo.concatenate %0, %1, dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [ {}, {"x"}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32> return %2 : tensor<4x80x256xf32> @@ -77,10 +78,11 @@ func.func @concatenate_operands_are_results_of_slices_different_shardings_on_per func.func @concatenate_operands_are_results_of_slices_different_shardings_on_permutation_dim_with_equal_counts_but_conflicting_on_batching_dim(%arg0: tensor<4x40x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(2)2}, {}, {}]>}, %arg1: tensor<4x60x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {}]>}) -> (tensor<4x80x256xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {}]>}) { %0 = stablehlo.slice %arg0 [0:4, 0:32, 0:256] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x":(2)2}, {}, {}]>]>} : (tensor<4x40x256xf32>) -> tensor<4x32x256xf32> %1 = stablehlo.slice %arg1 [0:4, 0:48, 0:256] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {}]>]>} : (tensor<4x60x256xf32>) -> tensor<4x48x256xf32> - // CHECK: %[[RESHARD1:.*]] = sdy.reshard %0 <@mesh, [{}, {"x"}, {}]> : tensor<4x32x256xf32> - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %1 <@mesh, [{}, {"x"}, {}]> : tensor<4x48x256xf32> - // CHECK-NEXT: %[[CONCATENATE:.*]] = stablehlo.concatenate %[[RESHARD1]], %[[RESHARD2]], dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32> - // CHECK-NEXT: return %[[CONCATENATE]] : tensor<4x80x256xf32> + // CHECK: %[[RESHARD0:.*]] = sdy.reshard %0 <@mesh, [{"x":(2)2}, {"y"}, {}]> : tensor<4x32x256xf32> + // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %1 <@mesh, [{"x":(2)2}, {"y"}, {}]> : tensor<4x48x256xf32> + // CHECK-NEXT: %[[CONCATENATE:.*]] = stablehlo.concatenate %[[RESHARD0]], %[[RESHARD1]], dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x":(2)2}, {"y"}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32> + // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[CONCATENATE]] <@mesh, [{}, {"x"}, {}]> : tensor<4x80x256xf32> + // CHECK-NEXT: return %[[RESHARD2]] : tensor<4x80x256xf32> %2 = stablehlo.concatenate %0, %1, dim = 1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}, {}]>]>} : (tensor<4x32x256xf32>, tensor<4x48x256xf32>) -> tensor<4x80x256xf32> return %2 : tensor<4x80x256xf32> } diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dot_dot_general.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dot_dot_general.mlir index 
9105844d..65d5a113 100644 --- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dot_dot_general.mlir +++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dot_dot_general.mlir @@ -251,11 +251,10 @@ func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_n // CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) { - // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x16xf32> - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<16x16xf32> - // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32> - // CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32> + // CHECK-NEXT: %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %1 = sdy.all_reduce {"y"} %0 out_sharding=<@mesh, [{"x"}, {}]> : tensor<8x16xf32> + // CHECK-NEXT: %2 = sdy.reshard %1 <@mesh, [{}, {"x"}]> : tensor<8x16xf32> + // CHECK-NEXT: return %2 : tensor<8x16xf32> %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> } @@ -527,10 +526,11 @@ func.func @dot_general_one_suffix_has_larger_count_on_another_factor(%arg0: tens // CHECK-LABEL: func @dot_general_batching_dimension_shardings_have_common_prefix func.func @dot_general_batching_dimension_shardings_have_common_prefix(%arg0: tensor<64x8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"y", "x":(1)2, "t":(1)2}, {"t":(2)2}, {}]>}, %arg1: tensor<64x32x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{"y", "x":(1)2, "t":(2)2}, {}, {"t":(1)2}]>}) ->(tensor<64x8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xyzt, [{}, {"t":(2)2}, {"t":(1)2}]>}) { - // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh_xyzt, [{"y", "x":(1)2, "t":(2)2}, {}, {}]> : tensor<64x8x32xf32> - // CHECK-NEXT: %[[DOTGENERAL:.*]] = stablehlo.dot_general %[[RESHARD1]], %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyzt, [{"y", "x":(1)2, "t":(2)2}, {}, {"t":(1)2}]>]>} : (tensor<64x8x32xf32>, tensor<64x32x16xf32>) -> tensor<64x8x16xf32> - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOTGENERAL]] <@mesh_xyzt, [{}, {"t":(2)2}, {"t":(1)2}]> : tensor<64x8x16xf32> - // CHECK-NEXT: return %[[RESHARD2]] : tensor<64x8x16xf32> + // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh_xyzt, [{"y", "x":(1)2}, {"t":(2)2}, {}]> + // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh_xyzt, [{"y", "x":(1)2}, {}, {"t":(1)2}]> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot_general %[[RESHARD0]], %[[RESHARD1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] 
{sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyzt, [{"y", "x":(1)2}, {"t":(2)2}, {"t":(1)2}]>]>} + // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOT]] <@mesh_xyzt, [{}, {"t":(2)2}, {"t":(1)2}]> + // CHECK-NEXT: return %[[RESHARD2]] %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyzt, [{}, {"t":(2)2}, {"t":(1)2}]>]>} : (tensor<64x8x32xf32>, tensor<64x32x16xf32>) -> tensor<64x8x16xf32> return %0 : tensor<64x8x16xf32> } @@ -586,15 +586,71 @@ func.func @dot_only_contracting_dims_sharded_and_has_same_shardings( return %0 : tensor<8x16xf32> } +// The following 4 test targets are analyzed quantitatively in b/448376870#comment6. +// In short, keep the largest factor sharded. + +// CHECK-LABEL: func @dot_ij_jk_ik_i_is_largest +func.func @dot_ij_jk_ik_i_is_largest( + %arg0: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) + -> (tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) { + // CHECK-NEXT: %[[RESHARD_LHS:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<16x8xf32> + // CHECK-NEXT: %[[RESHARD_RHS:.*]] = sdy.reshard %arg1 <@mesh, [{}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD_LHS]], %[[RESHARD_RHS]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<16x8xf32>, tensor<8x8xf32>) -> tensor<16x8xf32> + // CHECK-NEXT: return %[[DOT]] : tensor<16x8xf32> + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<16x8xf32>, tensor<8x8xf32>) -> tensor<16x8xf32> + return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: func @dot_ij_jk_ik_j_is_largest +func.func @dot_ij_jk_ik_j_is_largest( + %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}, + %arg1: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) + -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) { + // CHECK-NEXT: %[[RESHARD_RHS:.*]] = sdy.reshard %arg1 <@mesh, [{"x"}, {}]> : tensor<16x8xf32> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD_RHS]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"x"} %[[DOT]] out_sharding=<@mesh, [{}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: %[[RESHARD_OUT:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh, [{"x"}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: return %[[RESHARD_OUT]] : tensor<8x8xf32> + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @dot_ij_jk_ik_k_is_largest +func.func @dot_ij_jk_ik_k_is_largest( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}, + %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) + -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) { + // CHECK-NEXT: %[[RESHARD_LHS:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD_LHS]], %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[RESHARD_OUT:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"x"}, {}]> : tensor<8x16xf32> + // CHECK-NEXT: return 
%[[RESHARD_OUT]] : tensor<8x16xf32> + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: func @dot_ij_jk_ik_same_ijk +func.func @dot_ij_jk_ik_same_ijk( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) + -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) { + // CHECK-NEXT: %[[RESHARD_LHS:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD_LHS]], %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[RESHARD_OUT:.*]] = sdy.reshard %[[DOT]] <@mesh, [{"x"}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: return %[[RESHARD_OUT]] : tensor<8x8xf32> + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} + // CHECK-LABEL: func @dot_on_square_matrices_lhs_2nd_dim_rhs_2nd_dim_sharded_the_same_way func.func @dot_on_square_matrices_lhs_2nd_dim_rhs_2nd_dim_sharded_the_same_way( %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) -> tensor<8x8xf32> { - // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> - // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {}]> - // CHECK-NEXT: return %[[RESHARD2]] + // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD0]] + // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {}]> + // CHECK-NEXT: return %[[ALL_REDUCE]] %0 = stablehlo.dot %arg0, %arg1 : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -604,10 +660,10 @@ func.func @dot_on_rectangular_inputs_square_output_small_contracting_dim_lhs_2nd %arg0: tensor<8x2xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<2x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) -> tensor<8x8xf32> { - // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> - // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {}]> - // CHECK-NEXT: return %[[RESHARD2]] + // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %0 + // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {}]> + // CHECK-NEXT: return %[[RESHARD1]] %0 = stablehlo.dot %arg0, %arg1 : (tensor<8x2xf32>, tensor<2x8xf32>) -> tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -617,10 +673,10 @@ func.func @dot_on_rectangular_inputs_square_output_large_contracting_dim_lhs_2nd %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<16x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) -> tensor<8x8xf32> { - // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {}]> - // 
CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}]>]>} - // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {}]> - // CHECK-NEXT: return %[[RESHARD2]] + // CHECK-NEXT: %0 = sdy.reshard %arg1 <@mesh, [{"y"}, {}]> : tensor<16x8xf32> + // CHECK-NEXT: %1 = stablehlo.dot %arg0, %0 : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %2 = sdy.all_reduce {"y"} %1 out_sharding=<@mesh, [{}, {}]> : tensor<8x8xf32> + // CHECK-NEXT: return %2 : tensor<8x8xf32> %0 = stablehlo.dot %arg0, %arg1 : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -652,6 +708,46 @@ func.func @dot_on_rectangular_inputs_square_output_large_contracting_dim_lhs_2nd return %0 : tensor<8x8xf32> } +// CHECK-LABEL: func @dot_result_is_smaller_than_rhs_due_to_other_axes +func.func @dot_result_is_smaller_than_rhs_due_to_other_axes( + %arg0: tensor<8x32x64xf32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"x"}, {"z"}, {"y"}]>}, + %arg1: tensor<64x256xf32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"y"}, {}]>}) + -> (tensor<8x32x256xf32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"x", "y"}, {"z"}, {}]>}) { + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"x"}, {"z"}, {}]>]>} + // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh_xyz, [{"x"}, {"z"}, {}]> + // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh_xyz, [{"x", "y"}, {"z"}, {}]> + // CHECK-NEXT: return %[[RESHARD]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [2] x [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {"z"}, {}]>]>} : (tensor<8x32x64xf32>, tensor<64x256xf32>) -> tensor<8x32x256xf32> + return %0 : tensor<8x32x256xf32> +} + +// CHECK-LABEL: func @dot_all_factors_have_the_same_sharding_one_non_contracting_dim_is_largest_the_other_smallest_result_tensor_is_sharded_on_larger_factor +func.func @dot_all_factors_have_the_same_sharding_one_non_contracting_dim_is_largest_the_other_smallest_result_tensor_is_sharded_on_larger_factor( + %arg0: tensor<128x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}, + %arg1: tensor<8x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) + -> (tensor<128x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) { + // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> + // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh, [{}, {}]> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD0]], %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} + // CHECK-NEXT: return %[[DOT]] + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<128x8xf32>, tensor<8x4xf32>) -> tensor<128x4xf32> + return %0 : tensor<128x4xf32> +} + +// CHECK-LABEL: func @dot_all_factors_have_the_same_sharding_one_non_contracting_dim_is_largest_the_other_smallest_result_tensor_is_sharded_on_smaller_factor +func.func @dot_all_factors_have_the_same_sharding_one_non_contracting_dim_is_largest_the_other_smallest_result_tensor_is_sharded_on_smaller_factor( + %arg0: tensor<128x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, + %arg1: tensor<8x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) + -> (tensor<128x4xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) { + // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %arg1 <@mesh, 
[{}, {}]> + // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD0]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} + // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %[[DOT]] <@mesh, [{}, {"x"}]> + // CHECK-NEXT: return %[[RESHARD1]] + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<128x8xf32>, tensor<8x4xf32>) -> tensor<128x4xf32> + return %0 : tensor<128x4xf32> +} + // This one is derived from b/456082569#comment8. // CHECK-LABEL: func @dot_general_sharded_contracting_dim_with_axes_redistribution func.func @dot_general_sharded_contracting_dim_with_axes_redistribution( diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/reshape.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/reshape.mlir index fd420418..ff52cace 100644 --- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/reshape.mlir +++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/reshape.mlir @@ -22,9 +22,9 @@ func.func @reshape_simple_merge_sharding_is_from_x_to_x_and_x_fits_exactly_to_fi // CHECK-LABEL: func.func @reshape_simple_merge_sharding_is_from_x_to_y_and_x_fits_exactly_to_first_dim func.func @reshape_simple_merge_sharding_is_from_x_to_y_and_x_fits_exactly_to_first_dim(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) { - // CHECK-NEXT: %0 = stablehlo.reshape %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> - // CHECK-NEXT: %1 = sdy.reshard %0 <@mesh, [{"y"}]> : tensor<32xf32> - // CHECK-NEXT: return %1 : tensor<32xf32> + // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}, {}]> : tensor<4x8xf32> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> + // CHECK-NEXT: return %[[RESHAPE]] : tensor<32xf32> %0 = stablehlo.reshape %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> return %0 : tensor<32xf32> } @@ -38,9 +38,9 @@ func.func @reshape_simple_merge_sharding_is_from_y_to_y_and_y_underfits_to_first // CHECK-LABEL: func.func @reshape_simple_merge_sharding_is_from_y_to_x_and_y_underfits_to_first_dim func.func @reshape_simple_merge_sharding_is_from_y_to_x_and_y_underfits_to_first_dim(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) { - // CHECK-NEXT: %0 = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<4x8xf32> - // CHECK-NEXT: %1 = stablehlo.reshape %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> - // CHECK-NEXT: return %1 : tensor<32xf32> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> + // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[RESHAPE]] <@mesh, [{"x"}]> : tensor<32xf32> + // CHECK-NEXT: return %[[RESHARD]] : tensor<32xf32> %0 = stablehlo.reshape %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> return %0 : tensor<32xf32> } @@ -55,9 +55,9 @@ func.func @reshape_simple_merge_sharding_is_from_xy_to_xy_and_x_fits_exactly_to_ // CHECK-LABEL: func.func 
@reshape_simple_merge_sharding_is_from_xy_to_yx_and_x_fits_exactly_to_first_dim // NOTE: It reshards this way because the dependencies are dropped as factors are fully-sharded. func.func @reshape_simple_merge_sharding_is_from_xy_to_yx_and_x_fits_exactly_to_first_dim(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y", "x"}]>}) { - // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y", "x":(1)2}, {"x":(2)2}]> : tensor<4x8xf32> - // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[RESHARD]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y", "x"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> - // CHECK-NEXT: return %[[RESHAPE]] : tensor<32xf32> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", "y"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> + // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[RESHAPE]] <@mesh, [{"y", "x"}]> : tensor<32xf32> + // CHECK-NEXT: return %[[RESHARD]] : tensor<32xf32> %0 = stablehlo.reshape %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y", "x"}]>]>} : (tensor<4x8xf32>) -> tensor<32xf32> return %0 : tensor<32xf32> }
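To make the cost ordering behind the new dot tests easier to follow, here is a self-contained toy; it is not the code in this patch and does not use the exact tensor sizes of any test above. It counts only the dominant collective of each plan, weighted with the relative costs used in getCommunicationCost, and shows why unsharding a large contracting factor on both operands loses to keeping it sharded and all-reducing a much smaller result.

#include <cstdint>

// Relative collective costs, mirroring the constants in getCommunicationCost.
constexpr int64_t kAllGather = 4;
constexpr int64_t kAllReduce = 8;

// dot(A[i,j], B[j,k]) with i = k = 8 and j = 1024 on a mesh axis of size 2;
// A and B start sharded on j, the output is sharded on i. Sizes are local
// (per-device) element counts.
constexpr int64_t kLhsLocal = 8 * 1024 / 2;
constexpr int64_t kRhsLocal = 1024 * 8 / 2;
constexpr int64_t kOutLocal = 8 * 8 / 2;

// Plan 1: unshard j on both operands and compute with i sharded instead.
constexpr int64_t kUnshardContracting = kAllGather * (kLhsLocal + kRhsLocal);
// Plan 2: keep j sharded and all-reduce the much smaller partial result.
constexpr int64_t kKeepContracting = kAllReduce * kOutLocal;

// 32768 vs. 256: the plan that keeps the largest factor sharded wins.
static_assert(kKeepContracting < kUnshardContracting, "keep j sharded");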