Skip to content

Commit

Permalink
Address comments (3)
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Feb 15, 2025
1 parent dcfa537 commit 968c63a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct FuseHorizontalContractionsPass final
} // namespace

/// Helper method to check operation equivalence
static bool checkOperationEquivalence(Operation *aOp, Operation *bOp) {
static bool checkContractionOpEquivalence(Operation *aOp, Operation *bOp) {
auto aLinalgOp = dyn_cast<linalg::LinalgOp>(aOp);
auto bLinalgOp = dyn_cast<linalg::LinalgOp>(bOp);

Expand Down Expand Up @@ -112,7 +112,7 @@ static bool checkOperationEquivalence(Operation *aOp, Operation *bOp) {
}

// Check that the output rank and element type are the same. We dont check the
// type cause we allow RHS to be transposes.
// type because we allow the output to be transposed.
if (!checkSameRankAndElementType(aLinalgOp.getDpsInitOperand(0)->get(),
bLinalgOp.getDpsInitOperand(0)->get())) {
return false;
Expand All @@ -133,15 +133,15 @@ static bool checkOperationEquivalence(Operation *aOp, Operation *bOp) {
return true;
}

/// Check that an operation is a `empty -> fill -> contraction`
/// Check that an operation is a `contraction`
static bool isEquivalentContractionOp(
linalg::LinalgOp linalgOp,
std::optional<linalg::LinalgOp> seedContractionOp = std::nullopt) {
if (!linalg::isaContractionOpInterface(linalgOp)) {
return false;
}
if (seedContractionOp) {
return checkOperationEquivalence(seedContractionOp.value(), linalgOp);
return checkContractionOpEquivalence(seedContractionOp.value(), linalgOp);
}
return true;
}
Expand Down Expand Up @@ -345,9 +345,9 @@ fuseContractionsHorizontally(RewriterBase &rewriter, Location loc,
SmallVector<AffineMap> fusedInsIndexingMaps;
SmallVector<AffineMap> fusedOutsIndexingMaps;

linalg::LinalgOp seedOp = cast<linalg::LinalgOp>(linalgOps.front());
auto seedOp = cast<linalg::LinalgOp>(linalgOps.front());
SmallVector<utils::IteratorType> fusedIteratorTypes =
cast<linalg::LinalgOp>(linalgOps.front()).getIteratorTypesArray();
cast<linalg::LinalgOp>(seedOp).getIteratorTypesArray();

OpOperand *seedOpLhs = seedOp.getDpsInputOperand(0);
AffineMap seedOpLhsIndexingMap = seedOp.getMatchingIndexingMap(seedOpLhs);
Expand Down Expand Up @@ -459,7 +459,7 @@ static void fuseGroup(RewriterBase &rewriter,
})) {
return;
}
linalg::LinalgOp baseContractOp = cast<linalg::LinalgOp>(fusionGroup.front());
auto baseContractOp = cast<linalg::LinalgOp>(fusionGroup.front());
Location loc = baseContractOp.getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(baseContractOp);
Expand Down
26 changes: 25 additions & 1 deletion compiler/src/iree/compiler/DispatchCreation/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,31 @@ def FoldUnitExtentDimsPass :

def FuseHorizontalContractionsPass:
InterfacePass<"iree-dispatch-creation-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
let summary = "Fuses horizontal contraction ops without fusions";
let summary = "Fuses horizontal contraction ops";
let description = [{
For cases where multiple contractions
- that don't have a direct dependence
- that have the same LHS operand
- all the N dimensions of the RHS operands used are the same
Such contractions can be executed as a single contraction, i.e.

A = matmul(lhs, rhs0);
B = matmul(lhs, rhs1);
C = matmul(lhs, rhs2);

can be combined into
result = matmul(lhs, concat_along_N(rhs0, rhs1, rhs2));
A = slice0(result)
B = slice1(result)
C = slice2(result)

Instead of doing an actual concat of the RHS operands,
and extracting slices of the result, the pass generates a single
operation with
- the shared lhs operand
- all the rhs operands
- multiple results representing the individual matmuls
}]
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::tensor::TensorDialect",
Expand Down

0 comments on commit 968c63a

Please sign in to comment.