Skip to content

Commit 231dea0

Browse files
authored
[SCFToCalyx] Arith DivFOp lowering and its emit (#8296)
1 parent b984e73 commit 231dea0

File tree

7 files changed

+190
-24
lines changed

7 files changed

+190
-24
lines changed

include/circt/Dialect/Calyx/CalyxPrimitives.td

+55-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
include "mlir/IR/BuiltinAttributeInterfaces.td"
1414

15+
def I3 : I<3>;
16+
def I5 : I<5>;
17+
1518
/// Base class for Calyx primitives.
1619
class CalyxPrimitive<string mnemonic, list<Trait> traits = []> :
1720
CalyxCell<mnemonic, traits> {
@@ -342,8 +345,8 @@ def AddFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.add", [
342345
SameTypeConstraint<"left", "out">
343346
]> {
344347
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
345-
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
346-
AnySignlessInteger:$exceptionalFlags, I1:$done);
348+
AnySignlessInteger:$left, AnySignlessInteger:$right, I3:$roundingMode, AnySignlessInteger:$out,
349+
I5:$exceptionalFlags, I1:$done);
347350

348351
let extraClassDefinition = [{
349352
SmallVector<StringRef> $cppClass::portNames() {
@@ -384,8 +387,8 @@ def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul", [
384387
SameTypeConstraint<"left", "out">
385388
]> {
386389
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control,
387-
AnySignlessInteger:$left, AnySignlessInteger:$right, AnySignlessInteger:$roundingMode, AnySignlessInteger:$out,
388-
AnySignlessInteger:$exceptionalFlags, I1:$done);
390+
AnySignlessInteger:$left, AnySignlessInteger:$right, I3:$roundingMode, AnySignlessInteger:$out,
391+
I5:$exceptionalFlags, I1:$done);
389392
let assemblyFormat = "$sym_name attr-dict `:` qualified(type(results))";
390393
let extraClassDefinition = [{
391394
SmallVector<StringRef> $cppClass::portNames() {
@@ -417,11 +420,58 @@ def MulFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.mul", [
417420
}];
418421
}
419422

423+
// This models the division and square root, distinguished by `sqrtOp`,
424+
// operation interface in Berkeley HardFloat. It computes `left`/`right`
425+
// if `sqrtOp` = 0; it ignores `right` and computes sqrt(`left`) if `sqrtOp` = 1.
426+
def DivSqrtOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.divSqrt", [
427+
SameTypeConstraint<"left", "out">
428+
]> {
429+
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$sqrtOp,
430+
AnySignlessInteger:$left, AnySignlessInteger:$right, I3:$roundingMode, AnySignlessInteger:$out,
431+
I5:$exceptionalFlags, I1:$done);
432+
let assemblyFormat = "$sym_name attr-dict `:` qualified(type(results))";
433+
let extraClassDefinition = [{
434+
SmallVector<StringRef> $cppClass::portNames() {
435+
return {clkPort, resetPort, goPort, "control", "sqrtOp", "left", "right",
436+
"roundingMode", "out", "exceptionalFlags", donePort
437+
};
438+
}
439+
SmallVector<Direction> $cppClass::portDirections() {
440+
return {Input, Input, Input, Input, Input, Input, Input, Input, Output, Output, Output};
441+
}
442+
void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
443+
getCellAsmResultNames(setNameFn, *this, this->portNames());
444+
}
445+
bool $cppClass::isCombinational() { return false; }
446+
SmallVector<DictionaryAttr> $cppClass::portAttributes() {
447+
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
448+
NamedAttrList go, clk, reset, done;
449+
go.append(goPort, isSet);
450+
clk.append(clkPort, isSet);
451+
reset.append(resetPort, isSet);
452+
done.append(donePort, isSet);
453+
return {
454+
clk.getDictionary(getContext()),
455+
reset.getDictionary(getContext()),
456+
go.getDictionary(getContext()),
457+
DictionaryAttr::get(getContext()), // control
458+
DictionaryAttr::get(getContext()), // sqrtOp
459+
DictionaryAttr::get(getContext()), // left
460+
DictionaryAttr::get(getContext()), // right
461+
DictionaryAttr::get(getContext()), // roundingMode
462+
DictionaryAttr::get(getContext()), // out
463+
DictionaryAttr::get(getContext()), // exceptionalFlags
464+
done.getDictionary(getContext())
465+
};
466+
}
467+
}];
468+
}
469+
420470
// This models the compare operation interface in Berkeley HardFloat implementation.
421471
def CompareFOpIEEE754 : ArithBinaryFloatingPointLibraryOp<"ieee754.compare", []> {
422472
let results = (outs I1:$clk, I1:$reset, I1:$go,
423473
AnySignlessInteger:$left, AnySignlessInteger:$right, I1:$signaling,
424-
I1:$lt, I1: $eq, I1: $gt, I1: $unordered, AnySignlessInteger: $exceptionalFlags, I1: $done);
474+
I1:$lt, I1: $eq, I1: $gt, I1: $unordered, I5: $exceptionalFlags, I1: $done);
425475
let assemblyFormat = "$sym_name attr-dict `:` qualified(type(results))";
426476
let extraClassDefinition = [{
427477
SmallVector<StringRef> $cppClass::portNames() {

lib/Conversion/SCFToCalyx/SCFToCalyx.cpp

+42-17
Original file line numberDiff line numberDiff line change
@@ -294,22 +294,22 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
294294
funcOp.walk([&](Operation *_op) {
295295
opBuiltSuccessfully &=
296296
TypeSwitch<mlir::Operation *, bool>(_op)
297-
.template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
298-
/// SCF
299-
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
300-
scf::ParallelOp, scf::ReduceOp,
301-
scf::ExecuteRegionOp,
302-
/// memref
303-
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
304-
memref::StoreOp, memref::GetGlobalOp,
305-
/// standard arithmetic
306-
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
307-
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
308-
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
309-
/// floating point
310-
AddFOp, SubFOp, MulFOp, CmpFOp, FPToSIOp, SIToFPOp,
311-
/// others
312-
SelectOp, IndexCastOp, BitcastOp, CallOp>(
297+
.template Case<
298+
arith::ConstantOp, ReturnOp, BranchOpInterface,
299+
/// SCF
300+
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
301+
scf::ParallelOp, scf::ReduceOp, scf::ExecuteRegionOp,
302+
/// memref
303+
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
304+
memref::StoreOp, memref::GetGlobalOp,
305+
/// standard arithmetic
306+
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp, AndIOp,
307+
XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp, MulIOp, DivUIOp,
308+
DivSIOp, RemUIOp, RemSIOp,
309+
/// floating point
310+
AddFOp, SubFOp, MulFOp, CmpFOp, FPToSIOp, SIToFPOp, DivFOp,
311+
/// others
312+
SelectOp, IndexCastOp, BitcastOp, CallOp>(
313313
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
314314
.template Case<FuncOp, scf::ConditionOp>([&](auto) {
315315
/// Skip: these special cases will be handled separately.
@@ -370,6 +370,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
370370
LogicalResult buildOp(PatternRewriter &rewriter, CmpFOp op) const;
371371
LogicalResult buildOp(PatternRewriter &rewriter, FPToSIOp op) const;
372372
LogicalResult buildOp(PatternRewriter &rewriter, SIToFPOp op) const;
373+
LogicalResult buildOp(PatternRewriter &rewriter, DivFOp op) const;
373374
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
374375
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
375376
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
@@ -514,6 +515,12 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
514515
/*subtract=*/1);
515516
}
516517
rewriter.create<calyx::AssignOp>(loc, opFOp.getSubOp(), subOp);
518+
} else if (auto opFOp =
519+
dyn_cast<calyx::DivSqrtOpIEEE754>(opPipe.getOperation())) {
520+
bool isSqrt = !isa<arith::DivFOp>(op);
521+
hw::ConstantOp sqrtOp =
522+
createConstant(loc, rewriter, getComponent(), /*width=*/1, isSqrt);
523+
rewriter.create<calyx::AssignOp>(loc, opFOp.getSqrtOp(), sqrtOp);
517524
}
518525

519526
// Register the values for the pipeline.
@@ -1056,6 +1063,24 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
10561063
sitofp.getOut().getType().getIntOrFloatBitWidth(), "signedIn");
10571064
}
10581065

1066+
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
1067+
DivFOp divf) const {
1068+
Location loc = divf.getLoc();
1069+
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
1070+
five = rewriter.getIntegerType(5),
1071+
width = rewriter.getIntegerType(
1072+
divf.getType().getIntOrFloatBitWidth());
1073+
auto divFOp = getState<ComponentLoweringState>()
1074+
.getNewLibraryOpInstance<calyx::DivSqrtOpIEEE754>(
1075+
rewriter, loc,
1076+
{/*clk=*/one, /*reset=*/one, /*go=*/one,
1077+
/*control=*/one, /*sqrtOp=*/one, /*left=*/width,
1078+
/*right=*/width, /*roundingMode=*/three, /*out=*/width,
1079+
/*exceptionalFlags=*/five, /*done=*/one});
1080+
return buildLibraryBinaryPipeOp<calyx::DivSqrtOpIEEE754>(
1081+
rewriter, divf, divFOp, divFOp.getOut());
1082+
}
1083+
10591084
template <typename TAllocOp>
10601085
static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
10611086
PatternRewriter &rewriter, TAllocOp allocOp) {
@@ -2482,7 +2507,7 @@ class SCFToCalyxPass : public circt::impl::SCFToCalyxBase<SCFToCalyxPass> {
24822507
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
24832508
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp,
24842509
BitcastOp, FuncOp, ExtSIOp, CallOp, AddFOp, SubFOp,
2485-
MulFOp, CmpFOp, FPToSIOp, SIToFPOp>();
2510+
MulFOp, CmpFOp, FPToSIOp, SIToFPOp, DivFOp>();
24862511

24872512
RewritePatternSet legalizePatterns(&getContext());
24882513
legalizePatterns.add<DummyPattern>(&getContext());

lib/Dialect/Calyx/CalyxOps.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,10 @@ FloatingPointStandard IntToFpOpIEEE754::getFloatingPointStandard() {
12211221
return FloatingPointStandard::IEEE754;
12221222
}
12231223

1224+
FloatingPointStandard DivSqrtOpIEEE754::getFloatingPointStandard() {
1225+
return FloatingPointStandard::IEEE754;
1226+
}
1227+
12241228
std::string AddFOpIEEE754::getCalyxLibraryName() { return "std_addFN"; }
12251229

12261230
std::string MulFOpIEEE754::getCalyxLibraryName() { return "std_mulFN"; }
@@ -1230,6 +1234,8 @@ std::string CompareFOpIEEE754::getCalyxLibraryName() { return "std_compareFN"; }
12301234
std::string FpToIntOpIEEE754::getCalyxLibraryName() { return "std_fpToInt"; }
12311235

12321236
std::string IntToFpOpIEEE754::getCalyxLibraryName() { return "std_intToFp"; }
1237+
1238+
std::string DivSqrtOpIEEE754::getCalyxLibraryName() { return "std_divSqrtFN"; }
12331239
//===----------------------------------------------------------------------===//
12341240
// GroupInterface
12351241
//===----------------------------------------------------------------------===//

lib/Dialect/Calyx/Export/CalyxEmitter.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ struct ImportTracker {
169169
static constexpr std::string_view sFloatingPoint = "float/intToFp";
170170
return {sFloatingPoint};
171171
})
172+
.Case<DivSqrtOpIEEE754>([&](auto op) -> FailureOr<StringRef> {
173+
static constexpr std::string_view sFloatingPoint = "float/divSqrtFN";
174+
return {sFloatingPoint};
175+
})
172176
.Default([&](auto op) {
173177
auto diag = op->emitOpError() << "not supported for emission";
174178
return diag;
@@ -692,7 +696,7 @@ void Emitter::emitComponent(ComponentInterface op) {
692696
op, /*calyxLibName=*/{"std_sdiv_pipe"});
693697
})
694698
.Case<AddFOpIEEE754, MulFOpIEEE754, CompareFOpIEEE754,
695-
FpToIntOpIEEE754, IntToFpOpIEEE754>(
699+
FpToIntOpIEEE754, IntToFpOpIEEE754, DivSqrtOpIEEE754>(
696700
[&](auto op) { emitLibraryFloatingPoint(op); })
697701
.Default([&](auto op) {
698702
emitOpError(op, "not supported for emission inside component");

lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,8 @@ void InlineCombGroups::recurseInlineCombGroups(
713713
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp,
714714
calyx::ConstantOp, calyx::AddFOpIEEE754, calyx::MulFOpIEEE754,
715715
calyx::CompareFOpIEEE754, calyx::FpToIntOpIEEE754,
716-
calyx::IntToFpOpIEEE754>(src.getDefiningOp()))
716+
calyx::IntToFpOpIEEE754, calyx::DivSqrtOpIEEE754>(
717+
src.getDefiningOp()))
717718
continue;
718719

719720
auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(

test/Conversion/SCFToCalyx/convert_simple.mlir

+30
Original file line numberDiff line numberDiff line change
@@ -603,3 +603,33 @@ module {
603603
return %0 : f64
604604
}
605605
}
606+
607+
// -----
608+
609+
// Test floating point division
610+
611+
// CHECK: %cst = calyx.constant @cst_0 <4.200000e+00 : f32> : i32
612+
// CHECK-DAG: %true = hw.constant true
613+
// CHECK-DAG: %false = hw.constant false
614+
// CHECK-DAG: %divf_0_reg.in, %divf_0_reg.write_en, %divf_0_reg.clk, %divf_0_reg.reset, %divf_0_reg.out, %divf_0_reg.done = calyx.register @divf_0_reg : i32, i1, i1, i1, i32, i1
615+
// CHECK-DAG: %std_divSqrtFN_0.clk, %std_divSqrtFN_0.reset, %std_divSqrtFN_0.go, %std_divSqrtFN_0.control, %std_divSqrtFN_0.sqrtOp, %std_divSqrtFN_0.left, %std_divSqrtFN_0.right, %std_divSqrtFN_0.roundingMode, %std_divSqrtFN_0.out, %std_divSqrtFN_0.exceptionalFlags, %std_divSqrtFN_0.done = calyx.ieee754.divSqrt @std_divSqrtFN_0 : i1, i1, i1, i1, i1, i32, i32, i3, i32, i5, i1
616+
// CHECK-DAG: %ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
617+
// CHECK: calyx.group @bb0_0 {
618+
// CHECK-DAG: calyx.assign %std_divSqrtFN_0.left = %in0 : i32
619+
// CHECK-DAG: calyx.assign %std_divSqrtFN_0.right = %cst : i32
620+
// CHECK-DAG: calyx.assign %divf_0_reg.in = %std_divSqrtFN_0.out : i32
621+
// CHECK-DAG: calyx.assign %divf_0_reg.write_en = %std_divSqrtFN_0.done : i1
622+
// CHECK-DAG: %0 = comb.xor %std_divSqrtFN_0.done, %true : i1
623+
// CHECK-DAG: calyx.assign %std_divSqrtFN_0.go = %0 ? %true : i1
624+
// CHECK-DAG: calyx.assign %std_divSqrtFN_0.sqrtOp = %false : i1
625+
// CHECK-DAG: calyx.group_done %divf_0_reg.done : i1
626+
// CHECK-DAG: }
627+
628+
module {
629+
func.func @main(%arg0 : f32) -> f32 {
630+
%0 = arith.constant 4.2 : f32
631+
%1 = arith.divf %arg0, %0 : f32
632+
633+
return %1 : f32
634+
}
635+
}

test/Dialect/Calyx/emit.mlir

+50
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,53 @@ module attributes {calyx.entrypoint = "main"} {
535535
}
536536
} {toplevel}
537537
}
538+
539+
// -----
540+
541+
module attributes {calyx.entrypoint = "main"} {
542+
// CHECK: import "primitives/float/divSqrtFN.futil";
543+
calyx.component @main(%in0: i32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %done: i1 {done}) {
544+
%cst = calyx.constant @cst_0 <4.200000e+00 : f32> : i32
545+
%true = hw.constant true
546+
%false = hw.constant false
547+
%divf_0_reg.in, %divf_0_reg.write_en, %divf_0_reg.clk, %divf_0_reg.reset, %divf_0_reg.out, %divf_0_reg.done = calyx.register @divf_0_reg : i32, i1, i1, i1, i32, i1
548+
// CHECK-DAG: std_divSqrtFN_0 = std_divSqrtFN(8, 24, 32);
549+
%std_divSqrtFN_0.clk, %std_divSqrtFN_0.reset, %std_divSqrtFN_0.go, %std_divSqrtFN_0.control, %std_divSqrtFN_0.sqrtOp, %std_divSqrtFN_0.left, %std_divSqrtFN_0.right, %std_divSqrtFN_0.roundingMode, %std_divSqrtFN_0.out, %std_divSqrtFN_0.exceptionalFlags, %std_divSqrtFN_0.done = calyx.ieee754.divSqrt @std_divSqrtFN_0 : i1, i1, i1, i1, i1, i32, i32, i3, i32, i5, i1
550+
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
551+
calyx.wires {
552+
calyx.assign %out0 = %ret_arg0_reg.out : i32
553+
// CHECK-LABEL: group bb0_0 {
554+
// CHECK-NEXT: std_divSqrtFN_0.left = in0;
555+
// CHECK-NEXT: std_divSqrtFN_0.right = cst_0.out;
556+
// CHECK-NEXT: divf_0_reg.in = std_divSqrtFN_0.out;
557+
// CHECK-NEXT: divf_0_reg.write_en = std_divSqrtFN_0.done;
558+
// CHECK-NEXT: std_divSqrtFN_0.go = !std_divSqrtFN_0.done ? 1'd1;
559+
// CHECK-NEXT: std_divSqrtFN_0.sqrtOp = 1'd0;
560+
// CHECK-NEXT: bb0_0[done] = divf_0_reg.done;
561+
// CHECK-NEXT: }
562+
calyx.group @bb0_0 {
563+
calyx.assign %std_divSqrtFN_0.left = %in0 : i32
564+
calyx.assign %std_divSqrtFN_0.right = %cst : i32
565+
calyx.assign %divf_0_reg.in = %std_divSqrtFN_0.out : i32
566+
calyx.assign %divf_0_reg.write_en = %std_divSqrtFN_0.done : i1
567+
%0 = comb.xor %std_divSqrtFN_0.done, %true : i1
568+
calyx.assign %std_divSqrtFN_0.go = %0 ? %true : i1
569+
calyx.assign %std_divSqrtFN_0.sqrtOp = %false : i1
570+
calyx.group_done %divf_0_reg.done : i1
571+
}
572+
calyx.group @ret_assign_0 {
573+
calyx.assign %ret_arg0_reg.in = %divf_0_reg.out : i32
574+
calyx.assign %ret_arg0_reg.write_en = %true : i1
575+
calyx.group_done %ret_arg0_reg.done : i1
576+
}
577+
}
578+
calyx.control {
579+
calyx.seq {
580+
calyx.seq {
581+
calyx.enable @bb0_0
582+
calyx.enable @ret_assign_0
583+
}
584+
}
585+
}
586+
} {toplevel}
587+
}

0 commit comments

Comments
 (0)