Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1077,4 +1077,11 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
];
}

def WhileLoopOutsideValuesAddToArgumentListPass : Pass<
"while-loop-outside-values-add-to-argument-list"> {
let dependentDialects = [
"stablehlo::StablehloDialect"
];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "stablehlo/dialect/StablehloOps.h"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_WHILELOOPOUTSIDEVALUESADDTOARGUMENTLISTPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::stablehlo;

namespace {

static bool definedOutside(Value v, Operation *op) {
return !op->isAncestor(v.getParentBlock()->getParentOp());
}

struct SHLOWhileOpUpdateArgumentListPattern final
: public OpRewritePattern<stablehlo::WhileOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::WhileOp whileOp,
PatternRewriter &rewriter) const override {
// Collect values used inside cond/body that are defined outside the WhileOp
SmallVector<Value, 4> extraValues;
SmallPtrSet<Value, 8> seen;

auto collectExternal = [&](Region &region) {
region.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
Value v = operand.get();
if (!v)
continue;
if (definedOutside(v, whileOp) && !seen.contains(v)) {
seen.insert(v);
extraValues.push_back(v);
}
}
});
};

collectExternal(whileOp.getCond());
collectExternal(whileOp.getBody());

if (extraValues.empty())
return failure();

// Build new operand list = existing operands + external values
SmallVector<Value, 8> newOperands(whileOp.getOperands().begin(),
whileOp.getOperands().end());
for (Value v : extraValues)
newOperands.push_back(v);

SmallVector<Type, 8> newResultTypes;
newResultTypes.reserve(newOperands.size());
for (Value v : newOperands)
newResultTypes.push_back(v.getType());

auto newWhile = stablehlo::WhileOp::create(rewriter, whileOp.getLoc(),
newResultTypes, newOperands);

rewriter.inlineRegionBefore(whileOp.getCond(), newWhile.getCond(),
newWhile.getCond().end());
rewriter.inlineRegionBefore(whileOp.getBody(), newWhile.getBody(),
newWhile.getBody().end());

// Append block arguments for the extra values
Block &condBlock = newWhile.getCond().front();
Block &bodyBlock = newWhile.getBody().front();

unsigned origArgCount = whileOp.getNumOperands();
SmallVector<BlockArgument, 8> addedCondArgs, addedBodyArgs;
addedCondArgs.reserve(extraValues.size());
addedBodyArgs.reserve(extraValues.size());
for (Value v : extraValues) {
addedCondArgs.push_back(condBlock.addArgument(v.getType(), v.getLoc()));
addedBodyArgs.push_back(bodyBlock.addArgument(v.getType(), v.getLoc()));
}

// Remap uses of external values inside the regions to the new block args
auto remapRegionUses = [&](Region &region, ArrayRef<Value> externals,
ArrayRef<BlockArgument> args) {
region.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
Value v = operand.get();
for (auto [ext, arg] : llvm::zip(externals, args)) {
if (v == ext) {
operand.set(arg);
break;
}
}
}
});
};

remapRegionUses(newWhile.getCond(), extraValues, addedCondArgs);
remapRegionUses(newWhile.getBody(), extraValues, addedBodyArgs);

Operation *terminator = bodyBlock.getTerminator();
if (!terminator) {
return rewriter.notifyMatchFailure(whileOp, "missing body terminator");
}

auto retOp = dyn_cast<stablehlo::ReturnOp>(terminator);
assert(retOp && "expected stablehlo::ReturnOp");

SmallVector<Value, 8> newRetVals(retOp.getOperands().begin(),
retOp.getOperands().end());
for (BlockArgument arg : addedBodyArgs)
newRetVals.push_back(arg);

rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<stablehlo::ReturnOp>(terminator, newRetVals);

for (unsigned i = 0; i < origArgCount; ++i)
rewriter.replaceAllUsesWith(whileOp.getResult(i), newWhile.getResult(i));
rewriter.eraseOp(whileOp);
return success();
}
};

struct WhileLoopOutsideValuesAddToArgumentListPass
: public enzyme::impl::WhileLoopOutsideValuesAddToArgumentListPassBase<
WhileLoopOutsideValuesAddToArgumentListPass> {
using WhileLoopOutsideValuesAddToArgumentListPassBase::
WhileLoopOutsideValuesAddToArgumentListPassBase;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<SHLOWhileOpUpdateArgumentListPattern>(patterns.getContext());
walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};

} // namespace
67 changes: 67 additions & 0 deletions test/lit_tests/loop_all_values_defined.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: enzymexlamlir-opt %s --while-loop-outside-values-add-to-argument-list | FileCheck %s

func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> {
%cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
%c = stablehlo.constant dense<1> : tensor<i32>
%c_1 = stablehlo.constant dense<5> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_3 = stablehlo.constant dense<0> : tensor<i64>
%c_4 = stablehlo.constant dense<10> : tensor<i64>
%c_5 = stablehlo.constant dense<1> : tensor<i64>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32>
%0:2 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6) : tensor<i64>, tensor<13xf32>
cond {
%1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %1 : tensor<i1>
} do {
%1 = stablehlo.add %c_5, %iterArg : tensor<i64>
%2 = stablehlo.multiply %c_2, %1 : tensor<i64>
%3 = stablehlo.add %2, %c_1 : tensor<i64>
%4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
%5 = stablehlo.subtract %4, %c : tensor<i32>
%6 = stablehlo.dynamic_slice %arg0, %5, sizes = [1] : (tensor<25xf32>, tensor<i32>) -> tensor<1xf32>
%7 = stablehlo.multiply %6, %cst : tensor<1xf32>
%8 = stablehlo.subtract %7, %cst_0 : tensor<1xf32>
%9 = stablehlo.sine %8 : tensor<1xf32>
%10 = stablehlo.add %1, %c_2 : tensor<i64>
%11 = stablehlo.convert %10 : (tensor<i64>) -> tensor<i32>
%12 = stablehlo.subtract %11, %c : tensor<i32>
%13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor<i32>) -> tensor<13xf32>
stablehlo.return %1, %13 : tensor<i64>, tensor<13xf32>
}
return %0#1 : tensor<13xf32>
}

// CHECK: func.func @main(%arg0: tensor<25xf32>) -> tensor<13xf32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<3.000000e+00> : tensor<1xf32>
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i32>
// CHECK-NEXT: %c_1 = stablehlo.constant dense<5> : tensor<i64>
// CHECK-NEXT: %c_2 = stablehlo.constant dense<2> : tensor<i64>
// CHECK-NEXT: %c_3 = stablehlo.constant dense<0> : tensor<i64>
// CHECK-NEXT: %c_4 = stablehlo.constant dense<10> : tensor<i64>
// CHECK-NEXT: %c_5 = stablehlo.constant dense<1> : tensor<i64>
// CHECK-NEXT: %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<13xf32>
// CHECK-NEXT: %0:10 = stablehlo.while(%iterArg = %c_3, %iterArg_7 = %cst_6, %iterArg_8 = %c_4, %iterArg_9 = %c_5, %iterArg_10 = %c_2, %iterArg_11 = %c_1, %iterArg_12 = %c, %iterArg_13 = %arg0, %iterArg_14 = %cst, %iterArg_15 = %cst_0) : tensor<i64>, tensor<13xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i32>, tensor<25xf32>, tensor<1xf32>, tensor<1xf32>
// CHECK-NEXT: cond {
// CHECK-NEXT: %1 = stablehlo.compare LT, %iterArg, %iterArg_8 : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: stablehlo.return %1 : tensor<i1>
// CHECK-NEXT: } do {
// CHECK-NEXT: %1 = stablehlo.add %iterArg_9, %iterArg : tensor<i64>
// CHECK-NEXT: %2 = stablehlo.multiply %iterArg_10, %1 : tensor<i64>
// CHECK-NEXT: %3 = stablehlo.add %2, %iterArg_11 : tensor<i64>
// CHECK-NEXT: %4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: %5 = stablehlo.subtract %4, %iterArg_12 : tensor<i32>
// CHECK-NEXT: %6 = stablehlo.dynamic_slice %iterArg_13, %5, sizes = [1] : (tensor<25xf32>, tensor<i32>) -> tensor<1xf32>
// CHECK-NEXT: %7 = stablehlo.multiply %6, %iterArg_14 : tensor<1xf32>
// CHECK-NEXT: %8 = stablehlo.subtract %7, %iterArg_15 : tensor<1xf32>
// CHECK-NEXT: %9 = stablehlo.sine %8 : tensor<1xf32>
// CHECK-NEXT: %10 = stablehlo.add %1, %iterArg_10 : tensor<i64>
// CHECK-NEXT: %11 = stablehlo.convert %10 : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: %12 = stablehlo.subtract %11, %iterArg_12 : tensor<i32>
// CHECK-NEXT: %13 = stablehlo.dynamic_update_slice %iterArg_7, %9, %12 : (tensor<13xf32>, tensor<1xf32>, tensor<i32>) -> tensor<13xf32>
// CHECK-NEXT: stablehlo.return %1, %13, %iterArg_8, %iterArg_9, %iterArg_10, %iterArg_11, %iterArg_12, %iterArg_13, %iterArg_14, %iterArg_15 : tensor<i64>, tensor<13xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i32>, tensor<25xf32>, tensor<1xf32>, tensor<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %0#1 : tensor<13xf32>
// CHECK-NEXT: }
Loading