
Commit

address comments
Signed-off-by: hanhanW <[email protected]>
hanhanW committed Dec 18, 2024
1 parent a3b3739 commit cb19fce
Showing 5 changed files with 47 additions and 43 deletions.
@@ -114,7 +114,7 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
              AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
 }
 
-EncodingAttr EncodingAttr::cloneWithLayouts(SmallVector<Attribute> layouts) {
+EncodingAttr EncodingAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) {
   MLIRContext *ctx = getContext();
   return get(ctx, getOperandIndex(), getOpType(), getElementTypes(),
              /*user_indexing_maps=*/ArrayAttr(),
@@ -116,7 +116,7 @@ def EncodingAttr :
 
     /// Clones an encoding with a new layout list and drops other optional
     /// parameters (because they are resolved).
-    EncodingAttr cloneWithLayouts(SmallVector<Attribute> layouts);
+    EncodingAttr cloneWithLayouts(ArrayRef<Attribute> layouts);
   }];
 
   let genVerifyDecl = 0;
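Both the definition and the declaration now take `ArrayRef<Attribute>` rather than `SmallVector<Attribute>` by value, so callers can hand over any contiguous attribute storage without materializing a copy at the call boundary. A minimal sketch of the two common call shapes (the wrapper functions are illustrative, and the Encoding dialect header declaring `IREE::Encoding::EncodingAttr` is assumed):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
// Plus the Encoding dialect header that declares IREE::Encoding::EncodingAttr.

// Sketch: both containers bind to the ArrayRef<Attribute> parameter without
// copying elements at the call site.
static IREE::Encoding::EncodingAttr
cloneFromSetVector(IREE::Encoding::EncodingAttr encodingAttr,
                   const llvm::SetVector<mlir::Attribute> &layouts) {
  // SetVector exposes its underlying vector storage as an ArrayRef.
  return encodingAttr.cloneWithLayouts(layouts.getArrayRef());
}

static IREE::Encoding::EncodingAttr
cloneFromSmallVector(IREE::Encoding::EncodingAttr encodingAttr,
                     const llvm::SmallVector<mlir::Attribute> &layouts) {
  // SmallVector converts implicitly to ArrayRef<Attribute>.
  return encodingAttr.cloneWithLayouts(layouts);
}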
18 changes: 9 additions & 9 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -15,6 +15,7 @@
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/SourceMgr.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -121,19 +122,18 @@ class HALAffinityAnalysisDialectInterface
     : public IREE::Stream::AffinityAnalysisDialectInterface {
  public:
   using AffinityAnalysisDialectInterface::AffinityAnalysisDialectInterface;
-  IREE::Stream::LayoutAttrSolverFn
-  makeLayoutAttrSolver(ModuleOp moduleOp) const {
-    return [=](IREE::Stream::AffinityAttr aff, Operation *op,
-               SetVector<Attribute> &layoutAttrs) {
-      // TODO: This needs to be in the lambda. Otherwise, it could crash because
-      // the root op (i.e., the original moduleOp) could be outdated.
+  IREE::Stream::ResolveLayoutAttrFn
+  makeLayoutAttrResolver(ModuleOp moduleOp) const {
+    return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op,
+               SetVector<Attribute> &layoutAttrs) -> LogicalResult {
+      // This needs to be in the lambda because the moduleOp could be modified.
       IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
       if (failed(deviceAnalysis.run())) {
-        op->emitError("failed to run DeviceAnalysis");
-        return failure();
+        return op->emitError("failed to run DeviceAnalysis");
       }
       SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
-      deviceAnalysis.gatherRequiredExecutableTargets(aff, op, resultSet);
+      deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
+                                                     resultSet);
       // TODO(hanchung): Populate the EncodingLayoutAttr when it is ready.
       layoutAttrs.insert(resultSet.begin(), resultSet.end());
       return success();
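The comment about keeping the analysis inside the lambda is the heart of this hunk: the resolver may be invoked long after it is created, so `DeviceAnalysis` must be constructed and run against the module as it exists at call time. For contrast, a hedged sketch of the hoisted variant the comment warns against (illustrative only, reusing the APIs that appear above):

// Counter-example sketch (not committed code): the analysis is computed once
// at resolver-creation time, so its results go stale if later passes mutate
// moduleOp before the callback runs.
IREE::Stream::ResolveLayoutAttrFn
makeStaleLayoutAttrResolver(ModuleOp moduleOp) {
  auto deviceAnalysis = std::make_shared<IREE::HAL::DeviceAnalysis>(moduleOp);
  (void)deviceAnalysis->run(); // snapshots the *current* IR
  return [=](IREE::Stream::AffinityAttr affinityAttr, Operation *op,
             SetVector<Attribute> &layoutAttrs) -> LogicalResult {
    SetVector<IREE::HAL::ExecutableTargetAttr> resultSet;
    // May reference erased ops by the time this executes.
    deviceAnalysis->gatherRequiredExecutableTargets(affinityAttr, op,
                                                    resultSet);
    layoutAttrs.insert(resultSet.begin(), resultSet.end());
    return success();
  };
}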
@@ -16,15 +16,19 @@
 
 namespace mlir::iree_compiler::IREE::Stream {
 
-using LayoutAttrSolverFn = std::function<LogicalResult(
+using ResolveLayoutAttrFn = std::function<LogicalResult(
     AffinityAttr, Operation *, SetVector<Attribute> &)>;
 
 class AffinityAnalysisDialectInterface
     : public DialectInterface::Base<AffinityAnalysisDialectInterface> {
  public:
   AffinityAnalysisDialectInterface(Dialect *dialect) : Base(dialect) {}
 
-  virtual LayoutAttrSolverFn makeLayoutAttrSolver(ModuleOp moduleOp) const = 0;
+  /// The `moduleOp` must remain live and unmodified for as long as the
+  /// returned callback is in use; otherwise the callback may compute
+  /// incorrect results or crash, especially when module-scope analysis runs.
+  virtual ResolveLayoutAttrFn
+  makeLayoutAttrResolver(ModuleOp moduleOp) const = 0;
 };
 
 } // namespace mlir::iree_compiler::IREE::Stream
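Because `ResolveLayoutAttrFn` is a plain `std::function`, any callable with the matching signature satisfies it; dialects provide one via `makeLayoutAttrResolver`, but tests can substitute a stub without a real dialect interface. A minimal sketch (the fixed `layout` attribute is illustrative):

// Sketch: a resolver that always reports one fixed layout attribute.
IREE::Stream::ResolveLayoutAttrFn makeFixedLayoutResolver(Attribute layout) {
  return [layout](IREE::Stream::AffinityAttr, Operation *,
                  SetVector<Attribute> &layoutAttrs) -> LogicalResult {
    layoutAttrs.insert(layout);
    return success();
  };
}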
@@ -54,16 +54,16 @@ SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
 
 // TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType.
 static RankedTensorType cloneWithEncoding(RankedTensorType type,
-                                          Attribute encoding) {
+                                          Attribute encodingAttr) {
   return RankedTensorType::get(type.getShape(), type.getElementType(),
-                               encoding);
+                               encodingAttr);
 }
 
-static LogicalResult
-addLayoutsToTensorPhaseOps(ModuleOp moduleOp, FunctionOpInterface funcOp,
-                           LayoutAttrSolverFn makeLayoutAttrFn) {
-  SmallVector<AffinityOpInterface> candidates;
-  funcOp.walk([&](AffinityOpInterface affinityOp) {
+static LogicalResult addLayoutsToTensorPhaseOps(
+    ModuleOp moduleOp, FunctionOpInterface funcOp,
+    IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
+  SmallVector<IREE::Stream::AffinityOpInterface> candidates;
+  funcOp.walk([&](IREE::Stream::AffinityOpInterface affinityOp) {
     // Only need to update encoding types for ops that have TensorPhaseOp trait.
     if (!affinityOp->hasTrait<OpTrait::IREE::Stream::TensorPhaseOp>()) {
       return;
@@ -73,8 +73,8 @@ addLayoutsToTensorPhaseOps(ModuleOp moduleOp, FunctionOpInterface funcOp,
     // TODO(hanchung): We should use the default device in this case. However,
     // it is not guaranteed that default device attribute will always be set in
     // the IR. (Is the statement correct?)
-    auto affAttr = affinityOp.getAffinityAttr();
-    if (!affAttr) {
+    auto affinityAttr = affinityOp.getAffinityAttr();
+    if (!affinityAttr) {
       return;
     }
     candidates.push_back(affinityOp);
@@ -86,47 +86,48 @@ addLayoutsToTensorPhaseOps(ModuleOp moduleOp, FunctionOpInterface funcOp,
 
   IRRewriter rewriter(funcOp.getContext());
   for (auto affinityOp : candidates) {
-    auto affAttr = affinityOp.getAffinityAttr();
+    auto affinityAttr = affinityOp.getAffinityAttr();
     SetVector<Attribute> layouts;
-    if (failed(makeLayoutAttrFn(affAttr, moduleOp, layouts))) {
-      affinityOp.emitError("failed on making layouts");
-      return failure();
+    if (failed(resolveLayoutAttr(affinityAttr, moduleOp, layouts))) {
+      return affinityOp.emitError("failed on making layouts");
     }
 
+    // Returns an updated encoding attribute if an encoding attribute is present
+    // in the type. Otherwise, returns std::nullopt.
     auto getEncodingWithNewLayouts =
         [=](Type type) -> std::optional<IREE::Encoding::EncodingAttr> {
       auto rankedTensorType = dyn_cast<RankedTensorType>(type);
       if (!rankedTensorType) {
         return std::nullopt;
       }
-      auto encoding = IREE::Encoding::getEncodingAttr(rankedTensorType);
-      if (!encoding) {
+      auto encodingAttr = IREE::Encoding::getEncodingAttr(rankedTensorType);
+      if (!encodingAttr) {
         return std::nullopt;
       }
-      SmallVector<Attribute> attrs(layouts.begin(), layouts.end());
-      return encoding.cloneWithLayouts(attrs);
+      return encodingAttr.cloneWithLayouts(layouts.getArrayRef());
     };
 
     // TODO(hanchung): Update other Stream operations.
     LogicalResult result =
         TypeSwitch<Operation *, LogicalResult>(affinityOp)
-            .Case<Stream::TensorSizeOfOp>([&](auto sizeOfOp) {
+            .Case<IREE::Stream::TensorSizeOfOp>([&](auto sizeOfOp) {
              auto encodingType =
                  dyn_cast<RankedTensorType>(sizeOfOp.getEncoding());
              if (!encodingType) {
                return success();
              }
-             std::optional<IREE::Encoding::EncodingAttr> encoding =
+             std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
                  getEncodingWithNewLayouts(encodingType);
-             if (!encoding.has_value()) {
+             if (!encodingAttr.has_value()) {
                return success();
              }
              rewriter.modifyOpInPlace(sizeOfOp, [&] {
                sizeOfOp.setEncoding(
-                   cloneWithEncoding(encodingType, encoding.value()));
+                   cloneWithEncoding(encodingType, encodingAttr.value()));
              });
              return success();
            })
-            .Default([](auto *op) { return success(); });
+            .Default([](auto *op) { return failure(); });
 
     if (failed(result)) {
       return failure();
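Two behavioral details in this hunk are worth calling out: `modifyOpInPlace` wraps the mutation in the rewriter's notification hooks so listeners observe the in-place update, and the `.Default` case now returns `failure()`, so a candidate tensor-phase op with no dedicated `Case` fails the pass instead of being silently skipped. A hedged sketch of the in-place-update pattern (the op variable and "layouts" attribute name are hypothetical):

// Sketch: mutate an op in place under the rewriter so change notifications
// fire; "layouts" is an illustrative attribute name, not one this pass sets.
rewriter.modifyOpInPlace(op, [&] {
  op->setAttr("layouts", rewriter.getArrayAttr(layouts.getArrayRef()));
});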
@@ -140,25 +141,24 @@ struct SpecializeEncodingsPass
     : public impl::SpecializeEncodingsPassBase<SpecializeEncodingsPass> {
   void runOnOperation() override {
     ModuleOp moduleOp = getOperation();
-    auto usedDialects =
-        gatherUsedDialectInterfaces<AffinityAnalysisDialectInterface>(moduleOp);
+    auto usedDialects = gatherUsedDialectInterfaces<
+        IREE::Stream::AffinityAnalysisDialectInterface>(moduleOp);
     if (usedDialects.size() != 1) {
       moduleOp.emitError("expected only one dialect implementing "
                          "AffinityAnalysisDialectInterface");
       return signalPassFailure();
     }
 
-    SymbolTable symbolTable(moduleOp);
     llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps;
     for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) {
       executableOps[executableOp.getName()] = executableOp;
     }
 
-    LayoutAttrSolverFn makeLayoutAttrFn =
-        usedDialects[0]->makeLayoutAttrSolver(moduleOp);
-    for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
-      if (failed(
-              addLayoutsToTensorPhaseOps(moduleOp, funcOp, makeLayoutAttrFn))) {
+    IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr =
+        usedDialects[0]->makeLayoutAttrResolver(moduleOp);
+    for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+      if (failed(addLayoutsToTensorPhaseOps(moduleOp, funcOp,
+                                            resolveLayoutAttr))) {
         funcOp.emitError(
             "failed on adding layouts to Stream::TensorPhaseOp with encodings");
         return signalPassFailure();

