Skip to content

Commit 02f2244

Browse files
committed
use enzyme utils
1 parent ae9f672 commit 02f2244

File tree

2 files changed

+3
-68
lines changed

2 files changed

+3
-68
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -240,67 +240,6 @@ bool getEffectsAfter(Operation *op,
240240
return !conservative;
241241
}
242242

243-
bool isReadOnly(Operation *op) {
244-
bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
245-
if (hasRecursiveEffects) {
246-
for (Region &region : op->getRegions()) {
247-
for (auto &block : region) {
248-
for (auto &nestedOp : block)
249-
if (!isReadOnly(&nestedOp))
250-
return false;
251-
}
252-
}
253-
return true;
254-
}
255-
256-
// If the op has memory effects, try to characterize them to see if the op
257-
// is trivially dead here.
258-
if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
259-
// Check to see if this op either has no effects, or only allocates/reads
260-
// memory.
261-
SmallVector<MemoryEffects::EffectInstance, 1> effects;
262-
effectInterface.getEffects(effects);
263-
if (!llvm::all_of(effects, [](const MemoryEffects::EffectInstance &it) {
264-
return isa<MemoryEffects::Read>(it.getEffect());
265-
})) {
266-
return false;
267-
}
268-
return true;
269-
}
270-
return false;
271-
}
272-
273-
bool isReadNone(Operation *op) {
274-
bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
275-
if (hasRecursiveEffects) {
276-
for (Region &region : op->getRegions()) {
277-
for (auto &block : region) {
278-
for (auto &nestedOp : block)
279-
if (!isReadNone(&nestedOp))
280-
return false;
281-
}
282-
}
283-
return true;
284-
}
285-
286-
// If the op has memory effects, try to characterize them to see if the op
287-
// is trivially dead here.
288-
if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
289-
// Check to see if this op either has no effects, or only allocates/reads
290-
// memory.
291-
SmallVector<MemoryEffects::EffectInstance, 1> effects;
292-
effectInterface.getEffects(effects);
293-
if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &it) {
294-
return isa<MemoryEffects::Read>(it.getEffect()) ||
295-
isa<MemoryEffects::Write>(it.getEffect());
296-
})) {
297-
return false;
298-
}
299-
return true;
300-
}
301-
return false;
302-
}
303-
304243
const std::set<std::string> &getNonCapturingFunctions() {
305244
static std::set<std::string> NonCapturingFunctions = {
306245
"free", "printf", "fprintf", "scanf",

src/enzyme_ad/jax/Utils.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
namespace mlir {
4040
namespace enzyme {
4141

42+
using namespace ::mlir::enzyme::oputils;
43+
using isReadNone = isReadOnly;
44+
4245
template <typename T> inline Attribute makeAttr(mlir::Type elemType, T val) {
4346
if (auto TT = dyn_cast<RankedTensorType>(elemType))
4447
return SplatElementsAttr::get(
@@ -294,8 +297,6 @@ static inline bool hasElse(mlir::affine::AffineIfOp op) {
294297
return op.getElseRegion().getBlocks().size() > 0;
295298
}
296299

297-
const std::set<std::string> &getNonCapturingFunctions();
298-
299300
bool collectEffects(
300301
mlir::Operation *op,
301302
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
@@ -311,14 +312,9 @@ bool getEffectsAfter(
311312
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
312313
bool stopAtBarrier);
313314

314-
bool isReadOnly(mlir::Operation *);
315-
bool isReadNone(mlir::Operation *);
316-
317315
bool mayReadFrom(mlir::Operation *, mlir::Value);
318316
bool mayWriteTo(mlir::Operation *, mlir::Value, bool ignoreBarrier = false);
319317

320-
using ::mlir::enzyme::oputils::mayAlias;
321-
322318
template <typename AttrTy, typename T>
323319
SmallVector<Attribute> getUpdatedAttrList(Value val, StringRef attrName,
324320
T unknownValue, T newValue) {

0 commit comments

Comments
 (0)