Skip to content

Commit

Permalink
Address comments (4)
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Feb 18, 2025
1 parent 160f9b7 commit ac77783
Showing 1 changed file with 25 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,20 @@ static bool checkContractionOpEquivalence(Operation *aOp, Operation *bOp) {
}

// Check that the n-dimensions are the same
FailureOr<linalg::ContractionDimensions> aContactionDims =
FailureOr<linalg::ContractionDimensions> aContractionDims =
linalg::inferContractionDims(aLinalgOp);
FailureOr<linalg::ContractionDimensions> bContactionDims =
linalg::inferContractionDims(bLinalgOp);
if (failed(aContactionDims) || failed(bContactionDims)) {
if (failed(aContractionDims) || failed(bContactionDims)) {
return false;
}
if (aContactionDims.value() != bContactionDims.value()) {
if (aContractionDims.value() != bContactionDims.value()) {
return false;
}

SmallVector<int64_t, 4> aStaticDims = aLinalgOp.getStaticLoopRanges();
SmallVector<int64_t, 4> bStaticDims = bLinalgOp.getStaticLoopRanges();
for (auto nDim : aContactionDims->n) {
for (auto nDim : aContractionDims->n) {
if (aStaticDims[nDim] != bStaticDims[nDim] ||
ShapedType::isDynamic(aStaticDims[nDim])) {
return false;
Expand Down Expand Up @@ -133,19 +133,6 @@ static bool checkContractionOpEquivalence(Operation *aOp, Operation *bOp) {
return true;
}

/// Check that an operation is a `contraction`
static bool isEquivalentContractionOp(
linalg::LinalgOp linalgOp,
std::optional<linalg::LinalgOp> seedContractionOp = std::nullopt) {
if (!linalg::isaContractionOpInterface(linalgOp)) {
return false;
}
if (seedContractionOp) {
return checkContractionOpEquivalence(seedContractionOp.value(), linalgOp);
}
return true;
}

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the `slice` of the operation does not contain any op
/// from the group.
Expand Down Expand Up @@ -179,7 +166,7 @@ static bool isHorizontalToGroup(Operation *op,
/// %4 = linalg.matmul ins(%arg0, concat(%arg1, %arg2, %arg3))
/// ```
///
/// Note: The actual operation generated does not concat the RHS
/// Note: The actual operation generated does not concat the RHS.
static std::optional<SmallVector<Operation *>> getHorizontalFusionGroupMembers(
linalg::LinalgOp seedOp,
const llvm::SmallDenseSet<Operation *> &groupedOperations,
Expand All @@ -197,7 +184,8 @@ static std::optional<SmallVector<Operation *>> getHorizontalFusionGroupMembers(
}

// Constraints of the operation itself.
if (!isEquivalentContractionOp(linalgOp, seedOp)) {
if (!linalg::isaContractionOpInterface(linalgOp) ||
!checkContractionOpEquivalence(linalgOp, seedOp)) {
return false;
}
if (groupedOperations.contains(linalgOp)) {
Expand Down Expand Up @@ -287,10 +275,10 @@ permuteIndexingMapsToMatchSeedLhs(RewriterBase &rewriter,
getResultDimsRange(seedLhsIndexingMap.getResults());
auto lhsResultDimsRange = getResultDimsRange(lhsIndexingMap.getResults());

// Start with a identity permutations. For now try to only swap dimensions
// Start with an identity permutations. For now try to only swap dimensions
// which is not a general solution.
SmallVector<unsigned> interchangeVector =
llvm::to_vector(llvm::seq<unsigned>(0, lhsIndexingMap.getNumDims()));
SmallVector<int64_t> interchangeVector =
llvm::to_vector(llvm::seq<int64_t>(0, lhsIndexingMap.getNumDims()));
for (auto [seedDimPos, lhsDimPos] :
llvm::zip_equal(seedLhsResultDimsRange, lhsResultDimsRange)) {
if (seedDimPos == lhsDimPos) {
Expand All @@ -299,8 +287,7 @@ permuteIndexingMapsToMatchSeedLhs(RewriterBase &rewriter,
// If the current positions are what we started with, swap the positions.
if (interchangeVector[lhsDimPos] == lhsDimPos &&
interchangeVector[seedDimPos] == seedDimPos) {
interchangeVector[lhsDimPos] = seedDimPos;
interchangeVector[seedDimPos] = lhsDimPos;
std::swap(interchangeVector[lhsDimPos], interchangeVector[seedDimPos]);
continue;
}
// If this was a changed dimension, check that it is consistent.
Expand All @@ -313,6 +300,7 @@ permuteIndexingMapsToMatchSeedLhs(RewriterBase &rewriter,
// Check that the iterator types remain the same
SmallVector<utils::IteratorType> permutedIteratorTypes =
llvm::to_vector(iteratorTypes);
applyPermutationToVector(permutedIteratorTypes, interchangeVector);
if (permutedIteratorTypes != iteratorTypes) {
return failure();
}
Expand All @@ -331,7 +319,9 @@ permuteIndexingMapsToMatchSeedLhs(RewriterBase &rewriter,
/// results, corresponding to the results of the fused operations. It is assumed
/// that the LHS of the contraction operations fused horizontally is the same
/// and have the same indexing map for all the operations. The RHS/outputs of
/// the operations can be different.
/// the operations can be different, but share the same iteration space.
/// Returns the generated fused op, or `std::nullopt` when the fused op
/// could not be generated.
static std::optional<linalg::GenericOp>
fuseContractionsHorizontally(RewriterBase &rewriter, Location loc,
MutableArrayRef<Operation *> linalgOps) {
Expand All @@ -347,7 +337,7 @@ fuseContractionsHorizontally(RewriterBase &rewriter, Location loc,

auto seedOp = cast<linalg::LinalgOp>(linalgOps.front());
SmallVector<utils::IteratorType> fusedIteratorTypes =
cast<linalg::LinalgOp>(seedOp).getIteratorTypesArray();
seedOp.getIteratorTypesArray();

OpOperand *seedOpLhs = seedOp.getDpsInputOperand(0);
AffineMap seedOpLhsIndexingMap = seedOp.getMatchingIndexingMap(seedOpLhs);
Expand All @@ -371,7 +361,7 @@ fuseContractionsHorizontally(RewriterBase &rewriter, Location loc,
continue;
}

// Append the RHS operands;
// Append the RHS operands.
SmallVector<OpOperand *> ins = linalgOp.getDpsInputOperands();
llvm::append_range(
fusedIns,
Expand Down Expand Up @@ -406,8 +396,8 @@ fuseContractionsHorizontally(RewriterBase &rewriter, Location loc,
fusedIteratorTypes, [](OpBuilder &, Location, ValueRange) {});

Block *fusedBody = fusedOp.getBlock();
auto rhsIndex = 0;
auto outsIndex = fusedOp.getNumDpsInputs();
int64_t rhsIndex = 0;
int64_t outsIndex = fusedOp.getNumDpsInputs();
SmallVector<Value> yieldVals;
for (auto op : linalgOps) {
if (droppedOps.contains(op)) {
Expand Down Expand Up @@ -440,9 +430,9 @@ fuseContractionsHorizontally(RewriterBase &rewriter, Location loc,
rewriter.setInsertionPointToEnd(fusedBody);
rewriter.create<linalg::YieldOp>(loc, yieldVals);

auto resultsIndex = 0;
unsigned resultsIndex = 0;
for (auto linalgOp : linalgOps) {
auto numResults = linalgOp->getNumResults();
unsigned numResults = linalgOp->getNumResults();
rewriter.replaceOp(linalgOp,
fusedOp->getResults().slice(resultsIndex, numResults));
resultsIndex += numResults;
Expand All @@ -469,7 +459,9 @@ static void fuseGroup(RewriterBase &rewriter,
return;
}

fuseContractionsHorizontally(rewriter, loc, fusionGroup);
std::optional<linalg::GenericOp> fusedOp =
fuseContractionsHorizontally(rewriter, loc, fusionGroup);
(void)fusedOp;
}

void FuseHorizontalContractionsPass::runOnOperation() {
Expand All @@ -480,7 +472,7 @@ void FuseHorizontalContractionsPass::runOnOperation() {
llvm::SmallDenseSet<Operation *> groupedOperations;

getOperation()->walk([&](linalg::LinalgOp linalgOp) {
if (!isEquivalentContractionOp(linalgOp)) {
if (!linalg::isaContractionOpInterface(linalgOp)) {
return;
}
// Avoid already grouped operations;
Expand Down

0 comments on commit ac77783

Please sign in to comment.