Skip to content

Commit

Permalink
[Encoding][NFC] Move non attribute implementation to EncodingTypes.cpp (
Browse files Browse the repository at this point in the history
#20045)

It follows the convention, which only put the attribute implementation
and local functions to EncodingAttr.cpp. The implementation of other
utilities is moved to EncodingTypes.cpp.

Additionally, it spells out the types for `auto`.

---------

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Feb 20, 2025
1 parent 308d176 commit fb3523b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 110 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_compiler_cc_library(
"EncodingOps.cpp",
"EncodingOps.cpp.inc",
"EncodingTypeInterfaces.cpp.inc",
"EncodingTypes.cpp",
"EncodingTypes.cpp.inc",
],
hdrs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ iree_cc_library(
"EncodingOps.cpp"
"EncodingOps.cpp.inc"
"EncodingTypeInterfaces.cpp.inc"
"EncodingTypes.cpp"
"EncodingTypes.cpp.inc"
DEPS
::EncodingEnumsGen
Expand Down
113 changes: 3 additions & 110 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

#include <cassert>

namespace mlir::iree_compiler::IREE::Encoding {

Expand Down Expand Up @@ -72,33 +65,6 @@ EncodingAttr::mapDimToOperandIndex(int64_t dimPos) const {
getAffineDimExpr(dimPos, getContext()));
}

MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp,
int narrowThreshold) {
linalg::ContractionDimensions cDims =
linalg::inferContractionDims(linalgOp).value();
auto map = linalgOp.getIndexingMapsArray().back();
auto outType = llvm::cast<ShapedType>(linalgOp.getDpsInits()[0].getType());
auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t {
return outType.getDimSize(
map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext()))
.value());
};
// M or N can be empty instead of having an explicit dim size of 1 for matvec
// and vecmat, so set to 1 if empty.
int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]);
int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]);

MatmulNarrowDim narrowM, narrowN;
if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) {
narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize};
}
if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) {
narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize};
}

return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN;
}

ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() const {
auto roundDimsTo = getRoundDimsTo();
if (!roundDimsTo) {
Expand Down Expand Up @@ -252,75 +218,6 @@ Value EncodingAttr::calculateStorageSizeInBytes(Location loc,
return result;
}

MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
if (encoding.getOpType().getValue() != EncodingOpType::matmul) {
return {};
}
ArrayRef<int64_t> roundDimsTo = encoding.getRoundDimsToArray();
if (roundDimsTo.empty()) {
return {};
}
int m = roundDimsTo[0];
int n = roundDimsTo[1];
if (m < n) {
return {MatmulNarrowDim::Dim::M, m};
}
if (n < m) {
return {MatmulNarrowDim::Dim::N, n};
}
return {};
}

bool isNarrowNResult(EncodingAttr encoding) {
if (encoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RESULT) {
return false;
}

return IREE::Encoding::getMatmulNarrowDim(encoding).isN();
}

SerializableEncodingAttrInterface
getSerializableEncodingAttrInterface(RankedTensorType type) {
return dyn_cast_or_null<SerializableEncodingAttrInterface>(
type.getEncoding());
}

EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

bool hasPackedStorageAttr(RankedTensorType type) {
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
}

FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding) {
auto indexingMapsAttr = encoding.getUserIndexingMaps();
if (!indexingMapsAttr) {
return failure();
}
SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap {
return cast<AffineMapAttr>(m).getAffineMap();
});
return linalg::inferContractionDims(indexingMaps);
}

std::string stringifyOperandIndex(IntegerAttr valueAttr) {
auto value = valueAttr.getValue().getZExtValue();
switch (value) {
case MATMUL_LHS:
return "LHS";
case MATMUL_RHS:
return "RHS";
case MATMUL_RESULT:
return "RESULT";
default:
assert(false && "invalid index");
return "";
}
}

Value PadEncodingLayoutAttr::calculateStorageSizeInBytes(
Location loc, OpBuilder &builder, RankedTensorType type,
ValueRange dynamicDims) const {
Expand Down Expand Up @@ -467,10 +364,6 @@ void SpecializedEncodingAttr::print(AsmPrinter &p) const {
os << ">";
}

RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

Attribute SpecializedEncodingAttr::getLayout(RankedTensorType type) const {
MLIRContext *ctx = getContext();
return get(ctx, getSeed(), TypeAttr::get(dropEncoding(type)));
Expand Down
120 changes: 120 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

#include <cassert>

namespace mlir::iree_compiler::IREE::Encoding {

SerializableEncodingAttrInterface
getSerializableEncodingAttrInterface(RankedTensorType type) {
return dyn_cast_or_null<SerializableEncodingAttrInterface>(
type.getEncoding());
}

EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

bool hasPackedStorageAttr(RankedTensorType type) {
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
}

FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding) {
ArrayAttr indexingMapsAttr = encoding.getUserIndexingMaps();
if (!indexingMapsAttr) {
return failure();
}
SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap {
return cast<AffineMapAttr>(m).getAffineMap();
});
return linalg::inferContractionDims(indexingMaps);
}

std::string stringifyOperandIndex(IntegerAttr valueAttr) {
uint64_t value = valueAttr.getValue().getZExtValue();
switch (value) {
case MATMUL_LHS:
return "LHS";
case MATMUL_RHS:
return "RHS";
case MATMUL_RESULT:
return "RESULT";
default:
assert(false && "invalid index");
return "";
}
}

MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp,
int narrowThreshold) {
linalg::ContractionDimensions cDims =
linalg::inferContractionDims(linalgOp).value();
AffineMap map = linalgOp.getIndexingMapsArray().back();
auto outType = llvm::cast<ShapedType>(linalgOp.getDpsInits()[0].getType());
auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t {
return outType.getDimSize(
map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext()))
.value());
};
// M or N can be empty instead of having an explicit dim size of 1 for matvec
// and vecmat, so set to 1 if empty.
int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]);
int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]);

MatmulNarrowDim narrowM, narrowN;
if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) {
narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize};
}
if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) {
narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize};
}

return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN;
}

MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
if (encoding.getOpType().getValue() != EncodingOpType::matmul) {
return {};
}
ArrayRef<int64_t> roundDimsTo = encoding.getRoundDimsToArray();
if (roundDimsTo.empty()) {
return {};
}
int m = roundDimsTo[0];
int n = roundDimsTo[1];
if (m < n) {
return {MatmulNarrowDim::Dim::M, m};
}
if (n < m) {
return {MatmulNarrowDim::Dim::N, n};
}
return {};
}

bool isNarrowNResult(EncodingAttr encoding) {
if (encoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RESULT) {
return false;
}

return IREE::Encoding::getMatmulNarrowDim(encoding).isN();
}

RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

} // namespace mlir::iree_compiler::IREE::Encoding

0 comments on commit fb3523b

Please sign in to comment.