Skip to content

Commit

Permalink
Add pattern to convert generic conv ops to IGEMM (#19798)
Browse files Browse the repository at this point in the history
This PR removes the named op patterns to convert convs to IGEMM and
replaces them with a generic pattern that works for all supported convs.
A new utility function that populates the shared details required for
setting lowering config and doing the IGEMM computation is added.
The PR is currently using a default true flag
`iree-gpu-use-tile-and-fuse-generic-convolution` . The idea is that
since a lot more convolutions will go down the IGEMM path with this PR
if any of them run into issues we can turn the flag off by default
rather then needing to revert the whole PR. If after some time we find
that there are no issues then we can drop the flag and have this
happening always.

---------

Signed-off-by: Nirvedh Meshram <[email protected]>
  • Loading branch information
nirvedhmeshram authored Jan 29, 2025
1 parent 3f713f5 commit 50a7087
Show file tree
Hide file tree
Showing 7 changed files with 493 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@

namespace mlir::iree_compiler::IREE::GPU {

// TODO (nirvedhmeshram) : This flag allows a lot more convolutions to use IGEMM
// so drop this flag after sufficient use with no issues.
llvm::cl::opt<bool> clGPUUseTileAndFuseGenericConvolution(
"iree-gpu-use-tile-and-fuse-generic-convolution",
llvm::cl::desc(
"enable the tile and fuse pipeline for generic convolutions"),
llvm::cl::init(true));

constexpr int64_t kCacheLineSizeBits = 128 * 8;
constexpr int64_t kPreferredCopyNumBits = 128;

Expand Down Expand Up @@ -371,12 +379,25 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
return failure();

LDBG("IGEMM TileAndFuse Config");
FailureOr<SmallVector<AffineMap>> igemmContractionMaps =
LinalgExt::getIGEMMContractionIndexingMaps(linalgOp);
FailureOr<SmallVector<int64_t>> igemmLoopBounds =
LinalgExt::getIGEMMLoopBounds(linalgOp);
FailureOr<SmallVector<Value>> igemmOperands =
LinalgExt::getIGEMMOperands(linalgOp);
FailureOr<SmallVector<AffineMap>> igemmContractionMaps;
FailureOr<SmallVector<int64_t>> igemmLoopBounds;
FailureOr<SmallVector<Value>> igemmOperands;
if (!clGPUUseTileAndFuseGenericConvolution) {
igemmContractionMaps = LinalgExt::getIGEMMContractionIndexingMaps(linalgOp);
igemmLoopBounds = LinalgExt::getIGEMMLoopBounds(linalgOp);
igemmOperands = LinalgExt::getIGEMMOperands(linalgOp);
} else {
FailureOr<LinalgExt::IGEMMGenericConvDetails> igemmGenericConvDetails =
LinalgExt::getIGEMMGenericConvDetails(linalgOp);
if (failed(igemmGenericConvDetails)) {
LDBG("Unsupported generic convolution type");
return failure();
}
igemmContractionMaps = igemmGenericConvDetails->igemmContractionMaps;
igemmLoopBounds = igemmGenericConvDetails->igemmLoopBounds;
igemmOperands = igemmGenericConvDetails->igemmOperands;
}

if (failed(igemmContractionMaps) || failed(igemmLoopBounds) ||
failed(igemmOperands)) {
LDBG("Unsupported convolution type");
Expand Down
Loading

0 comments on commit 50a7087

Please sign in to comment.