Skip to content

Commit

Permalink
[Arc][Sim] Lower Sim DPI func to func.func and support dpi call in Arc (
Browse files Browse the repository at this point in the history
#7386)

This PR implements initial support for lowering Sim DPI operations to Arc. 

* `sim::LowerDPIFuncPass` implements the lowering from `sim.func.dpi` to a `func.func` that respects the C-level ABI.
* `arc::LowerStatePass` is modified to allocate states and call functions for `sim.func.dpi.call` ops.

Unclocked calls are not supported yet.
  • Loading branch information
uenoku committed Aug 7, 2024
1 parent 1a8f82e commit 9828707
Show file tree
Hide file tree
Showing 16 changed files with 378 additions and 41 deletions.
5 changes: 5 additions & 0 deletions include/circt/Dialect/Sim/SimPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@ def ProceduralizeSim : Pass<"sim-proceduralize", "hw::HWModuleOp"> {
let dependentDialects = ["circt::hw::HWDialect, circt::seq::SeqDialect, mlir::scf::SCFDialect"];
}

// Module-level pass, registered as `sim-lower-dpi-func`, that lowers Sim DPI
// function declarations (`sim.func.dpi`) to `func.func`. It may create
// operations from the Func and LLVM dialects, so both are declared as
// dependent dialects.
def LowerDPIFunc : Pass<"sim-lower-dpi-func", "mlir::ModuleOp"> {
  let summary = "Lower sim.func.dpi into func.func for the simulation flow";
  let dependentDialects = ["mlir::func::FuncDialect", "mlir::LLVM::LLVMDialect"];
}

#endif // CIRCT_DIALECT_SIM_SEQPASSES
39 changes: 39 additions & 0 deletions integration_test/arcilator/JIT/dpi.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit

// CHECK: c = 0
// CHECK-NEXT: c = 5
// Declaration of the DPI function @dpi: two i32 inputs and one i32 output.
// Its verilogName attribute is "adder_func", the same name as the func.func
// defined right after it (presumably this is how the call gets bound to its
// implementation -- confirm against the sim-lower-dpi-func pass).
sim.func.dpi @dpi(in %a : i32, in %b : i32, out c : i32) attributes {verilogName = "adder_func"}
// Adder implementation: sums the two i32 arguments and stores the result
// through the pointer passed as the third argument.
func.func @adder_func(%arg0: i32, %arg1: i32, %arg2: !llvm.ptr) {
%0 = arith.addi %arg0, %arg1 : i32
llvm.store %0, %arg2 : i32, !llvm.ptr
return
}
// Hardware module under test: converts the i1 %clock input to a seq clock,
// issues a clocked call to @dpi with inputs %a and %b, and drives output c
// with the call result.
hw.module @adder(in %clock : i1, in %a : i32, in %b : i32, out c : i32) {
%seq_clk = seq.to_clock %clock

%0 = sim.func.dpi.call @dpi(%a, %b) clock %seq_clk : (i32, i32) -> i32
hw.output %0 : i32
}
// JIT entry point: instantiates @adder, drives a = 2 and b = 3, and emits the
// value of port c after each of two simulation steps. The expected output
// lines are listed at the top of this file.
func.func @main() {
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%one = arith.constant 1 : i1
%zero = arith.constant 0 : i1
arc.sim.instantiate @adder as %arg0 {
// Apply the operands and raise the clock input before the first step.
arc.sim.set_input %arg0, "a" = %c2_i32 : i32, !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "b" = %c3_i32 : i32, !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@adder>

// First step: lower the clock input afterwards, then read and emit c.
arc.sim.step %arg0 : !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %zero : i1, !arc.sim.instance<@adder>
%0 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@adder>
arc.sim.emit "c", %0 : i32

// Second step: raise the clock input afterwards, then read and emit c again.
arc.sim.step %arg0 : !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@adder>
%2 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@adder>
arc.sim.emit "c", %2 : i32
}
return
}
1 change: 1 addition & 0 deletions lib/Conversion/ConvertToArcs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ add_circt_conversion_library(CIRCTConvertToArcs
CIRCTArc
CIRCTHW
CIRCTSeq
CIRCTSim
MLIRTransforms
)
3 changes: 2 additions & 1 deletion lib/Conversion/ConvertToArcs/ConvertToArcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Support/Namespace.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -25,7 +26,7 @@ using llvm::MapVector;
/// Returns true if `op` carries the `ConstantLike` trait, is one of the op
/// kinds listed in the `isa<>` check below, or produces more than one result.
/// NOTE(review): judging by the name, such ops presumably act as arc
/// boundaries during conversion -- confirm against the callers.
static bool isArcBreakingOp(Operation *op) {
return op->hasTrait<OpTrait::ConstantLike>() ||
isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, ClockedOpInterface,
seq::ClockGateOp>(op) ||
seq::ClockGateOp, sim::DPICallOp>(op) ||
op->getNumResults() > 1;
}

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Arc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_circt_dialect_library(CIRCTArcTransforms
CIRCTOM
CIRCTSV
CIRCTSeq
CIRCTSim
CIRCTSupport
MLIRFuncDialect
MLIRLLVMDialect
Expand Down
93 changes: 60 additions & 33 deletions lib/Dialect/Arc/Transforms/LowerState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Support/BackedgeBuilder.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -117,7 +118,12 @@ struct ModuleLowering {
LogicalResult lowerPrimaryInputs();
LogicalResult lowerPrimaryOutputs();
LogicalResult lowerStates();
template <typename CallTy>
LogicalResult lowerStateLike(Operation *op, Value clock, Value enable,
Value reset, ArrayRef<Value> inputs,
FlatSymbolRefAttr callee);
LogicalResult lowerState(StateOp stateOp);
LogicalResult lowerState(sim::DPICallOp dpiCallOp);
LogicalResult lowerState(MemoryOp memOp);
LogicalResult lowerState(MemoryWritePortOp memWriteOp);
LogicalResult lowerState(TapOp tapOp);
Expand All @@ -139,7 +145,7 @@ static bool shouldMaterialize(Operation *op) {
return !isa<MemoryOp, AllocStateOp, AllocMemoryOp, AllocStorageOp,
ClockTreeOp, PassThroughOp, RootInputOp, RootOutputOp,
StateWriteOp, MemoryWritePortOp, igraph::InstanceOpInterface,
StateOp>(op);
StateOp, sim::DPICallOp>(op);
}

static bool shouldMaterialize(Value value) {
Expand Down Expand Up @@ -390,53 +396,48 @@ LogicalResult ModuleLowering::lowerPrimaryOutputs() {
/// Lowers every state-like operation found directly in the module body. The
/// candidates are gathered into a list before any lowering runs; each one is
/// then dispatched to the `lowerState` overload for its op type. The first
/// failing lowering makes this function return failure.
LogicalResult ModuleLowering::lowerStates() {
// Snapshot the ops to lower up front (presumably because the lowerings
// rewrite the block being iterated -- confirm against the `lowerState`
// overloads).
SmallVector<Operation *> opsToLower;
for (auto &op : *moduleOp.getBodyBlock())
if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(&op))
if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp, sim::DPICallOp>(&op))
opsToLower.push_back(&op);

// Dispatch on the concrete op type; ops of any other type count as success.
for (auto *op : opsToLower) {
LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n");
auto result = TypeSwitch<Operation *, LogicalResult>(op)
.Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(
[&](auto op) { return lowerState(op); })
.Default(success());
auto result =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp, sim::DPICallOp>(
[&](auto op) { return lowerState(op); })
.Default(success());
if (failed(result))
return failure();
}
return success();
}

LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
// We don't support arcs beyond latency 1 yet. These should be easy to add in
// the future though.
if (stateOp.getLatency() > 1)
return stateOp.emitError("state with latency > 1 not supported");

// Grab all operands from the state op and make it drop all its references.
// This allows `materializeValue` to move an operation if this state was the
// last user.
auto stateClock = stateOp.getClock();
auto stateEnable = stateOp.getEnable();
auto stateReset = stateOp.getReset();
auto stateInputs = SmallVector<Value>(stateOp.getInputs());
template <typename CallOpTy>
LogicalResult ModuleLowering::lowerStateLike(
Operation *stateOp, Value stateClock, Value stateEnable, Value stateReset,
ArrayRef<Value> stateInputs, FlatSymbolRefAttr callee) {
// Grab all operands from the state op at the callsite and make it drop all
// its references. This allows `materializeValue` to move an operation if this
// state was the last user.

// Get the clock tree and enable condition for this state's clock. If this arc
// carries an explicit enable condition, fold that into the enable provided by
// the clock gates in the arc's clock tree.
auto info = getOrCreateClockLowering(stateClock);
info.enable = info.clock.getOrCreateAnd(
info.enable, info.clock.materializeValue(stateEnable), stateOp.getLoc());
info.enable, info.clock.materializeValue(stateEnable), stateOp->getLoc());

// Allocate the necessary state within the model.
SmallVector<Value> allocatedStates;
for (unsigned stateIdx = 0; stateIdx < stateOp.getNumResults(); ++stateIdx) {
auto type = stateOp.getResult(stateIdx).getType();
for (unsigned stateIdx = 0; stateIdx < stateOp->getNumResults(); ++stateIdx) {
auto type = stateOp->getResult(stateIdx).getType();
auto intType = dyn_cast<IntegerType>(type);
if (!intType)
return stateOp.emitOpError("result ")
return stateOp->emitOpError("result ")
<< stateIdx << " has non-integer type " << type
<< "; only integer types are supported";
auto stateType = StateType::get(intType);
auto state = stateBuilder.create<AllocStateOp>(stateOp.getLoc(), stateType,
auto state = stateBuilder.create<AllocStateOp>(stateOp->getLoc(), stateType,
storageArg);
if (auto names = stateOp->getAttrOfType<ArrayAttr>("names"))
state->setAttr("name", names[stateIdx]);
Expand All @@ -455,18 +456,18 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
OpBuilder nonResetBuilder = info.clock.builder;
if (stateReset) {
auto materializedReset = info.clock.materializeValue(stateReset);
auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp.getLoc(),
auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp->getLoc(),
materializedReset, true);

for (auto [alloc, resTy] :
llvm::zip(allocatedStates, stateOp.getResultTypes())) {
llvm::zip(allocatedStates, stateOp->getResultTypes())) {
if (!isa<IntegerType>(resTy))
stateOp->emitOpError("Non-integer result not supported yet!");

auto thenBuilder = ifOp.getThenBodyBuilder();
Value constZero =
thenBuilder.create<hw::ConstantOp>(stateOp.getLoc(), resTy, 0);
thenBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, constZero,
thenBuilder.create<hw::ConstantOp>(stateOp->getLoc(), resTy, 0);
thenBuilder.create<StateWriteOp>(stateOp->getLoc(), alloc, constZero,
Value());
}

Expand All @@ -475,24 +476,50 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) {

stateOp->dropAllReferences();

auto newStateOp = nonResetBuilder.create<CallOp>(
stateOp.getLoc(), stateOp.getResultTypes(), stateOp.getArcAttr(),
auto newStateOp = nonResetBuilder.create<CallOpTy>(
stateOp->getLoc(), stateOp->getResultTypes(), callee,
materializedOperands);

// Create the write ops that write the result of the transfer function to the
// allocated state storage.
for (auto [alloc, result] :
llvm::zip(allocatedStates, newStateOp.getResults()))
nonResetBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, result,
nonResetBuilder.create<StateWriteOp>(stateOp->getLoc(), alloc, result,
info.enable);

// Replace all uses of the arc with reads from the allocated state.
for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp.getResults()))
for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp->getResults()))
replaceValueWithStateRead(result, alloc);
stateOp.erase();
stateOp->erase();
return success();
}

/// Lowers an `arc.state` operation by forwarding its clock, enable, reset,
/// inputs, and arc symbol to the shared state-like lowering, which emits an
/// `arc::CallOp` for the transfer function.
LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
  // Arcs with a latency above 1 are not handled yet; supporting them should be
  // straightforward in the future.
  if (stateOp.getLatency() > 1)
    return stateOp.emitError("state with latency > 1 not supported");

  // Capture every operand in a local before handing the op over.
  Value clock = stateOp.getClock();
  Value enable = stateOp.getEnable();
  Value reset = stateOp.getReset();
  SmallVector<Value> inputs(stateOp.getInputs());
  FlatSymbolRefAttr arc = stateOp.getArcAttr();

  return lowerStateLike<arc::CallOp>(stateOp, clock, enable, reset, inputs,
                                     arc);
}

/// Lowers a clocked `sim::DPICallOp` through the shared state-like lowering,
/// emitting a `func::CallOp` to the callee. Calls without a clock are rejected
/// with an error.
LogicalResult ModuleLowering::lowerState(sim::DPICallOp callOp) {
  Value clock = callOp.getClock();
  if (!clock)
    return callOp.emitError("unclocked DPI call not implemented yet");

  // A clocked call is handled like an arc state with a latency of one. No
  // reset value is passed for DPI calls.
  Value enable = callOp.getEnable();
  SmallVector<Value> inputs(callOp.getInputs());
  FlatSymbolRefAttr callee = callOp.getCalleeAttr();

  return lowerStateLike<func::CallOp>(callOp, clock, enable, /*reset=*/Value(),
                                      inputs, callee);
}

LogicalResult ModuleLowering::lowerState(MemoryOp memOp) {
auto allocMemOp = stateBuilder.create<AllocMemoryOp>(
memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs());
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Sim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_circt_dialect_library(CIRCTSim
CIRCTHW
CIRCTSeq
CIRCTSV
MLIRFuncDialect
MLIRIR
MLIRPass
MLIRTransforms
Expand Down
10 changes: 7 additions & 3 deletions lib/Dialect/Sim/SimOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/HW/ModuleImplementation.h"
#include "circt/Dialect/SV/SVOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionImplementation.h"

Expand Down Expand Up @@ -69,12 +70,15 @@ ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {

/// Verifies the callee symbol of this DPI call. The symbol is looked up
/// starting from this operation; an error is emitted when it does not resolve
/// to a suitable operation.
LogicalResult
sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto referencedOp = dyn_cast_or_null<sim::DPIFuncOp>(
symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr()));
auto referencedOp =
symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr());
if (!referencedOp)
return emitError("cannot find function declaration '")
<< getCallee() << "'";
return success();
// Only a `func::FuncOp` or a `sim::DPIFuncOp` is accepted as the callee; any
// other kind of symbol is rejected with a diagnostic naming the offending op.
if (isa<func::FuncOp, sim::DPIFuncOp>(referencedOp))
return success();
return emitError("callee must be 'sim.dpi.func' or 'func.func' but got '")
<< referencedOp->getName() << "'";
}

void DPIFuncOp::print(OpAsmPrinter &p) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Sim/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTSimTransforms
LowerDPIFunc.cpp
ProceduralizeSim.cpp


Expand All @@ -12,8 +13,10 @@ add_circt_dialect_library(CIRCTSimTransforms
CIRCTSV
CIRCTComb
CIRCTSupport
MLIRFuncDialect
MLIRIR
MLIRPass
MLIRLLVMDialect
MLIRSCFDialect
MLIRTransformUtils
)
Loading

0 comments on commit 9828707

Please sign in to comment.