Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/Dialect/ModArith/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ cc_library(
":ops_inc_gen",
":type_interfaces_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@heir//lib/Dialect/RNS/IR:TypeInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down Expand Up @@ -66,6 +67,7 @@ td_library(
includes = ["../../../.."],
deps = [
":type_interfaces_td",
"@heir//lib/Dialect:td_files",
"@heir//lib/Dialect/RNS/IR:type_interfaces_td",
"@heir//lib/Utils/DRR",
"@llvm-project//mlir:ArithOpsTdFiles",
Expand Down
166 changes: 104 additions & 62 deletions lib/Dialect/ModArith/IR/ModArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

// NOLINTBEGIN(misc-include-cleaner): Required to define
// ModArithDialect, ModArithTypes, ModArithOps,
#include "lib/Dialect/HEIRInterfaces.h"
#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
// NOLINTEND(misc-include-cleaner)

#define DEBUG_TYPE "mod-arith"
Expand Down Expand Up @@ -381,101 +382,142 @@ OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
llvm::dbgs() << " Folded : " << foldedVal << "\n";
llvm::dbgs() << "========================================\n";
});

// Create the result
return IntegerAttr::get(storageType, foldedVal);
}

/// Helper function to handle common folding logic for binary arithmetic
/// operations.
/// - `opName` is used for debug output.
/// - `foldBinFn` defines how the actual binary operation (+, -, *) should be
/// performed.
template <typename FoldAdaptor, typename FoldBinFn>
static OpFoldResult foldBinModOp(Operation* op, FoldAdaptor adaptor,
FoldBinFn&& foldBinFn,
llvm::StringRef opName) {
// TODO(#1759): support dense attributes
namespace {
enum class ModOp { Add, Sub, Mul, Mac };

std::optional<APInt> getRawAPInt(Attribute attr) {
if (!attr) return std::nullopt;
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
return intAttr.getValue();
}
if (auto modAttr = dyn_cast<ModArithAttr>(attr)) {
return modAttr.getValue().getValue();
}
return std::nullopt;
}

// Check if lhs and rhs are IntegerAttrs
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
Attribute foldScalarModOp(ArrayRef<Attribute> operands,
ModArithType residueType, ModOp op,
StringRef opName) {
auto lhs = getRawAPInt(operands[0]);
auto rhs = getRawAPInt(operands[1]);
if (!lhs || !rhs) return {};

auto modType = dyn_cast<ModArithType>(op->getResultTypes().front());
if (!modType) return {};
std::optional<APInt> acc = std::nullopt;
if (op == ModOp::Mac) {
if (operands.size() < 3) return {};
acc = getRawAPInt(operands[2]);
if (!acc) return {};
}

// Retrieve the modulus value and its bit width
APInt modulus = modType.getModulus().getValue();
APInt modulus = residueType.getModulus().getValue();
unsigned modBitWidth = modulus.getBitWidth();

// Extract the actual integer values
APInt lhsVal = lhs.getValue();
APInt rhsVal = rhs.getValue();

// Adjust lhsVal and rhsVal bit widths to match modulus if necessary
lhsVal = lhsVal.zextOrTrunc(modBitWidth);
rhsVal = rhsVal.zextOrTrunc(modBitWidth);
// Strict accumulation safety (1 extra bit or more)
unsigned workWidth = 2 * modBitWidth + 1;

APInt lw = lhs->zextOrTrunc(modBitWidth).zext(workWidth);
APInt rw = rhs->zextOrTrunc(modBitWidth).zext(workWidth);
APInt mw = modulus.zext(workWidth);

APInt res;
switch (op) {
case ModOp::Add:
res = lw + rw;
break;
case ModOp::Sub:
res = lw + mw - rw.urem(mw);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add the modulus here? Is there an assumption somewhere that values should not be negative? c.f. my confusion here

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the urem below... any other reason?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignoring the urem at the end, I think until ModArithToArith is updated to support anything except standard representatives in [0, q), we will have to maintain that invariant in the folders or else the lowerings may break.

break;
case ModOp::Mul:
res = lw * rw;
break;
case ModOp::Mac: {
APInt aw = acc->zextOrTrunc(modBitWidth).zext(workWidth);
res = aw + (lw * rw);
break;
}
}

// Perform the operation using the provided foldBinFn
APInt foldedVal = foldBinFn(lhsVal, rhsVal, modulus);
APInt foldedVal = res.urem(mw).trunc(modBitWidth);

LLVM_DEBUG({
llvm::dbgs() << "\n";
llvm::dbgs() << "========================================\n";
llvm::dbgs() << " Folding Operation: " << opName << "\n";
llvm::dbgs() << " Folding Operation: " << opName << " (Limbwise)\n";
llvm::dbgs() << "----------------------------------------\n";
llvm::dbgs() << " LHS : " << lhsVal << "\n";
llvm::dbgs() << " RHS : " << rhsVal << "\n";
llvm::dbgs() << " LHS : " << *lhs << "\n";
llvm::dbgs() << " RHS : " << *rhs << "\n";
if (acc) {
llvm::dbgs() << " ACC : " << *acc << "\n";
}
llvm::dbgs() << " Modulus : " << modulus << "\n";
llvm::dbgs() << " Folded : " << foldedVal << "\n";
llvm::dbgs() << "========================================\n";
});

// Create the result
auto elementType = modType.getModulus().getType();
return IntegerAttr::get(elementType, foldedVal);
auto intAttr =
IntegerAttr::get(residueType.getModulus().getType(), foldedVal);
return ModArithAttr::get(residueType.getContext(), residueType, intAttr);
}

// add(c0, c1) -> (c0 + c1) mod q
} // namespace

Attribute AddOp::foldScalarResidue(ArrayRef<Attribute> operands,
Type residueType) {
auto modType = dyn_cast<ModArithType>(residueType);
if (!modType) return {};
return foldScalarModOp(operands, modType, ModOp::Add, "Add");
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return foldBinModOp(
getOperation(), adaptor,
[](APInt lhs, APInt rhs, APInt modulus) {
APInt sum = lhs + rhs;
return sum.urem(modulus);
},
"Add");
return heir::foldLimbwise(getOperation(), adaptor.getOperands(),
getResult().getType());
}

// sub(c0, c1) -> (c0 - c1) mod q
Attribute SubOp::foldScalarResidue(ArrayRef<Attribute> operands,
Type residueType) {
auto modType = dyn_cast<ModArithType>(residueType);
if (!modType) return {};
return foldScalarModOp(operands, modType, ModOp::Sub, "Sub");
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
return foldBinModOp(
getOperation(), adaptor,
[](APInt lhs, APInt rhs, APInt modulus) {
APInt diff = lhs - rhs;
if (diff.isNegative()) {
diff += modulus;
}
return diff.urem(modulus);
},
"Sub");
return heir::foldLimbwise(getOperation(), adaptor.getOperands(),
getResult().getType());
}

// mul(c0, c1) -> (c0 * c1) mod q
Attribute MulOp::foldScalarResidue(ArrayRef<Attribute> operands,
Type residueType) {
auto modType = dyn_cast<ModArithType>(residueType);
if (!modType) return {};
return foldScalarModOp(operands, modType, ModOp::Mul, "Mul");
}
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return foldBinModOp(
getOperation(), adaptor,
[](APInt lhs, APInt rhs, APInt modulus) {
APInt product = lhs * rhs;
return product.urem(modulus);
},
"Mul");
return heir::foldLimbwise(getOperation(), adaptor.getOperands(),
getResult().getType());
}

Attribute MacOp::foldScalarResidue(ArrayRef<Attribute> operands,
Type residueType) {
auto modType = dyn_cast<ModArithType>(residueType);
if (!modType) return {};
return foldScalarModOp(operands, modType, ModOp::Mac, "Mac");
}
OpFoldResult MacOp::fold(FoldAdaptor adaptor) {
return heir::foldLimbwise(getOperation(), adaptor.getOperands(),
getResult().getType());
}

Operation* ModArithDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
Location loc) {
if (auto limbwiseAttr = dyn_cast_if_present<LimbwiseAttrInterface>(value)) {
return limbwiseAttr.materializeConstant(builder, loc, type);
}
if (auto modArithAttr = dyn_cast<ModArithAttr>(value)) {
value = modArithAttr.getValue();
}
// TODO(#1759): support dense attributes
auto intAttr = dyn_cast_if_present<IntegerAttr>(value);
if (!intAttr) return nullptr;
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/IR/ModArithOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_

// IWYU pragma: begin_keep
#include "lib/Dialect/HEIRInterfaces.h"
#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
#include "lib/Dialect/ModArith/IR/ModArithTypeInterfaces.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/ModArith/IR/ModArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
include "lib/Dialect/ModArith/IR/ModArithDialect.td"
include "lib/Dialect/ModArith/IR/ModArithEnums.td"
include "lib/Dialect/ModArith/IR/ModArithTypes.td"
include "lib/Dialect/HEIRInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
Expand Down Expand Up @@ -167,7 +168,7 @@ def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOp
}

class ModArith_BinaryOp<string mnemonic, list<Trait> traits = []> :
ModArith_Op<mnemonic, traits # [SameOperandsAndResultType, Pure, ElementwiseMappable]>,
ModArith_Op<mnemonic, traits # [SameOperandsAndResultType, Pure, ElementwiseMappable, DeclareOpInterfaceMethods<LimbwiseMappableOpInterface, ["foldScalarResidue"]>]>,
Arguments<(ins ModQLike:$lhs, ModQLike:$rhs)>,
Results<(outs ModQLike:$output)> {
let assemblyFormat ="operands attr-dict `:` type($output)";
Expand Down Expand Up @@ -209,7 +210,7 @@ def ModArith_MulOp : ModArith_BinaryOp<"mul", [Commutative]> {
let hasFolder = 1;
}

def ModArith_MacOp : ModArith_Op<"mac", [SameOperandsAndResultType, Pure, ElementwiseMappable]> {
def ModArith_MacOp : ModArith_Op<"mac", [SameOperandsAndResultType, Pure, ElementwiseMappable, DeclareOpInterfaceMethods<LimbwiseMappableOpInterface, ["foldScalarResidue"]>]> {
let summary = "modular multiplication-and-accumulation operation";

let description = [{
Expand All @@ -218,9 +219,10 @@ def ModArith_MacOp : ModArith_Op<"mac", [SameOperandsAndResultType, Pure, Elemen
Unless otherwise specified, the operation assumes all inputs are canonical
representatives and guarantees the output being canonical representative.
}];
let arguments = (ins ModArithLike:$lhs, ModArithLike:$rhs, ModArithLike:$acc);
let results = (outs ModArithLike:$output);
let arguments = (ins ModQLike:$lhs, ModQLike:$rhs, ModQLike:$acc);
let results = (outs ModQLike:$output);
let assemblyFormat = "operands attr-dict `:` type($output)";
let hasFolder = 1;
}

// TODO(#1084): migrate barrett/subifge to mod arith type
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/RNS/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_library(
":ops_inc_gen",
":type_interfaces_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/ModArith/IR:TypeInterfaces",
"@heir//lib/Utils:APIntUtils",
Expand Down Expand Up @@ -65,6 +66,7 @@ td_library(
includes = ["../../../.."],
deps = [
":type_interfaces_td",
"@heir//lib/Dialect:td_files",
"@heir//lib/Dialect/ModArith/IR:type_interfaces_td",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
Expand Down
37 changes: 34 additions & 3 deletions lib/Dialect/RNS/IR/RNSAttributes.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#include "lib/Dialect/RNS/IR/RNSAttributes.h"

#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "lib/Dialect/RNS/IR/RNSOps.h"
#include "lib/Dialect/RNS/IR/RNSTypes.h"
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -55,6 +59,33 @@ Attribute RNSAttr::parse(AsmParser& parser, Type type) {
parser.getContext(), ArrayRef<Attribute>(attrValues), rnsType);
}

// LimbwiseAttrInterface
::mlir::Attribute RNSAttr::assembleFromLimbs(
::mlir::Type resultAttrType, ::llvm::ArrayRef< ::mlir::Attribute> limbs) {
auto rnsType = dyn_cast_if_present<RNSType>(resultAttrType);
if (!rnsType) return {};
return get(rnsType, limbs);
}

::mlir::Attribute RNSAttr::extractLimb(unsigned index) const {
return getValues()[index];
}

::mlir::Type RNSAttr::getLimbType(unsigned index) const {
return cast<RNSType>(getType()).getBasisTypes()[index];
}

unsigned RNSAttr::getNumLimbs() const { return getValues().size(); }

::mlir::Operation* RNSAttr::materializeConstant(::mlir::OpBuilder& builder,
::mlir::Location loc,
::mlir::Type type) const {
auto rnsType = dyn_cast_if_present<RNSType>(type);
if (!rnsType) return nullptr;
auto op = rns::ConstantOp::create(builder, loc, rnsType, *this);
return op.getOperation();

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is that different than return op?

@j2kun j2kun Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What you wrote indeed compiles, but that's because, technically speaking, the tablegen-generated classes like ConstantOp inherit from ::mlir::Op (cf here) which wrap an Operation *, and these classes have an implicit typecaster from ::mlir::Op to Operation *. Usually when I explicitly want an Operation *, however, I'll try to make it more explicit and ask for the Operation * explicitly.

}

} // namespace rns
} // namespace heir
} // namespace mlir
3 changes: 3 additions & 0 deletions lib/Dialect/RNS/IR/RNSAttributes.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#ifndef LIB_DIALECT_RNS_IR_RNSATTRIBUTES_H_
#define LIB_DIALECT_RNS_IR_RNSATTRIBUTES_H_

// IWYU pragma: begin_keep
#include "lib/Dialect/HEIRInterfaces.h"
#include "lib/Dialect/RNS/IR/RNSDialect.h"
#include "lib/Dialect/RNS/IR/RNSTypes.h"
// IWYU pragma: end_keep

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/RNS/IR/RNSAttributes.h.inc"
Expand Down
9 changes: 8 additions & 1 deletion lib/Dialect/RNS/IR/RNSAttributes.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LIB_DIALECT_RNS_IR_RNSATTRIBUTES_TD_
#define LIB_DIALECT_RNS_IR_RNSATTRIBUTES_TD_

include "lib/Dialect/HEIRInterfaces.td"
include "lib/Dialect/RNS/IR/RNSDialect.td"
include "lib/Dialect/RNS/IR/RNSTypes.td"

Expand All @@ -14,7 +15,7 @@ class RNS_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}

def RNS_RNSAttr : RNS_Attr<"RNS", "value"> {
def RNS_RNSAttr : RNS_Attr<"RNS", "value", [DeclareAttrInterfaceMethods<LimbwiseAttrInterface>]> {
let summary = "a typed RNS value";
let description = [{
A typed RNS value with one integer per basis limb.
Expand All @@ -36,6 +37,12 @@ def RNS_RNSAttr : RNS_Attr<"RNS", "value"> {
);
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
static RNSAttr get(::mlir::heir::rns::RNSType type, ::llvm::ArrayRef<::mlir::Attribute> values) {
return get(type.getContext(), values, type);
}

}];
}

#endif // LIB_DIALECT_RNS_IR_RNSATTRIBUTES_TD_
Loading
Loading