Skip to content

Commit 70a7f58

Browse files
committed
feat: hoist chain of ops
1 parent 62b1abd commit 70a7f58

File tree

6 files changed: +194 −39 lines changed

src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,8 @@ bool WhileLoopInfo::isConstantValue(Value v, llvm::APInt &constVal) {
214214
return true;
215215

216216
Value outerValue;
217-
if (isConstantAcrossIterations(v, outerValue) &&
217+
SmallVector<Operation *> canBeHoisted;
218+
if (isConstantAcrossIterations(v, outerValue, canBeHoisted, false) &&
218219
matchPattern(outerValue, m_ConstantInt(&constVal)))
219220
return true;
220221
return false;
@@ -295,11 +296,13 @@ void WhileLoopInfo::propagateAffineIndexInfo() {
295296

296297
bool WhileLoopInfo::isConstantAcrossIterations(Value v, bool checkOperands) {
297298
Value outerValue;
298-
return isConstantAcrossIterations(v, outerValue, checkOperands);
299+
SmallVector<Operation *> canBeHoisted;
300+
return isConstantAcrossIterations(v, outerValue, canBeHoisted, checkOperands);
299301
}
300302

301-
bool WhileLoopInfo::isConstantAcrossIterations(Value v, Value &outerValue,
302-
bool checkOperands) {
303+
bool WhileLoopInfo::isConstantAcrossIterations(
304+
Value v, Value &outerValue, SmallVector<Operation *> &canBeHoisted,
305+
bool checkOperands) {
303306
if (definedOutside(v, op)) {
304307
outerValue = v;
305308
return true;
@@ -326,12 +329,15 @@ bool WhileLoopInfo::isConstantAcrossIterations(Value v, Value &outerValue,
326329

327330
// all operands of the defining op are constant across iterations
328331
// don't populate the outerValue in this case
329-
return llvm::all_of(defOp->getOperands(), [&](Value operand) {
330-
// TODO: we should do `isConstantAcrossIterations` but for now we do a more
331-
// conservative check
332-
// return isConstantAcrossIterations(operand);
333-
return definedOutside(operand, op);
334-
});
332+
if (llvm::all_of(defOp->getOperands(), [&](Value operand) {
333+
return isConstantAcrossIterations(operand, outerValue, canBeHoisted,
334+
true);
335+
})) {
336+
outerValue = nullptr;
337+
canBeHoisted.push_back(defOp);
338+
return true;
339+
}
340+
return false;
335341
}
336342

337343
template <typename OpTy>
@@ -423,8 +429,6 @@ bool WhileLoopInfo::hoistOperationFromLoop(
423429
if (!canHoistOperationFromLoop(sliceOp, dimensions))
424430
return false;
425431

426-
auto totalIterCount = getConstantNumIters();
427-
428432
auto depIndex = sliceOp.getStartIndices()[sliceIndex];
429433
auto indexTy = depIndex.getType();
430434

@@ -439,7 +443,6 @@ bool WhileLoopInfo::hoistOperationFromLoop(
439443
auto step = getConstantStep().value();
440444
auto lb = getConstantStart().value();
441445
auto ub = getConstantLimit().value();
442-
int64_t N = ub - lb;
443446

444447
auto rawMin = scale * lb + offset;
445448
auto rawMax = scale * (ub - 1) + offset;

src/enzyme_ad/jax/Implementations/WhileLoopInfo.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ struct WhileLoopInfo {
6565

6666
bool isConstantAcrossIterations(Value v, bool checkOperands = true);
6767
bool isConstantAcrossIterations(Value v, Value &outerValue,
68+
SmallVector<Operation *> &canBeHoisted,
6869
bool checkOperands = true);
6970

7071
bool canHoistOperationFromLoop(mlir::stablehlo::DynamicSliceOp sliceOp,

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 36 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#include "src/enzyme_ad/jax/Passes/AutoBatching.h"
22

33
#include "Enzyme/MLIR/Passes/EnzymeBatchPass.h"
4+
#include "mlir/Analysis/TopologicalSortUtils.h"
45
#include "mlir/Dialect/Func/IR/FuncOps.h"
56
#include "mlir/IR/Builders.h"
67
#include "mlir/IR/Matchers.h"
@@ -925,11 +926,14 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
925926
SmallVector<SmallVector<int64_t>> sliceDims(op->getNumOperands());
926927
SmallVector<int64_t> hoistedDims(op->getNumOperands());
927928
SmallVector<DynamicSliceInfo> mappedSliceInfos(op->getNumOperands());
929+
DenseMap<Value, SmallVector<Operation *>> hoistMap;
930+
928931
for (int i = 0; i < op->getNumOperands(); i++) {
929932
auto operand = op->getOperand(i);
930933

931934
Value outerValue;
932-
if (info.isConstantAcrossIterations(operand, outerValue)) {
935+
SmallVector<Operation *> canBeHoisted;
936+
if (info.isConstantAcrossIterations(operand, outerValue, canBeHoisted)) {
933937
if (outerValue) {
934938
SplatElementsAttr splat;
935939
if (matchPattern(operand, m_Constant(&splat))) {
@@ -939,6 +943,7 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
939943
}
940944
batchOperands[i] = outerValue;
941945
} else {
946+
hoistMap[operand] = canBeHoisted;
942947
hoistedDims[i] = cast<mlir::OpResult>(operand).getResultNumber();
943948
batchLiftingModes[i] = BatchLiftingMode::NEEDS_HOISTING_OUTSIDE_WHILE;
944949
batchOperands[i] = operand;
@@ -1040,6 +1045,34 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
10401045

10411046
rewriter.setInsertionPoint(whileOp);
10421047

1048+
// hoist any operations that can be hoisted
1049+
DenseMap<Value, Value> hoistedValues;
1050+
for (auto &[val, ops] : hoistMap) {
1051+
llvm::SetVector<Operation *> toHoist(ops.begin(), ops.end());
1052+
auto sorted = mlir::topologicalSort(toHoist);
1053+
IRMapping mapper;
1054+
1055+
for (auto &op : sorted) {
1056+
for (auto operand : op->getOperands()) {
1057+
if (!definedOutside(operand, whileOp)) {
1058+
Value outerValue;
1059+
SmallVector<Operation *> canBeHoisted;
1060+
if (info.isConstantAcrossIterations(operand, outerValue, canBeHoisted,
1061+
false)) {
1062+
mapper.map(operand, outerValue);
1063+
}
1064+
}
1065+
}
1066+
auto hoisted = rewriter.clone(*op, mapper);
1067+
for (auto [origRes, newRes] :
1068+
llvm::zip(op->getResults(), hoisted->getResults())) {
1069+
mapper.map(origRes, newRes);
1070+
}
1071+
}
1072+
1073+
hoistedValues[val] = mapper.lookup(val);
1074+
}
1075+
10431076
SmallVector<Value> newOperands;
10441077
for (auto [consType, baseOp, sliceDim, sliceInfo, hoistDim] :
10451078
llvm::zip(batchLiftingModes, batchOperands, sliceDims, mappedSliceInfos,
@@ -1097,9 +1130,8 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
10971130
break;
10981131
}
10991132
case BatchLiftingMode::NEEDS_HOISTING_OUTSIDE_WHILE: {
1100-
auto hoisted = rewriter.clone(*baseOp.getDefiningOp());
1101-
baseOp = hoisted->getResult(hoistDim);
1102-
// intentionally fallthrough
1133+
baseOp = hoistedValues[baseOp];
1134+
LLVM_FALLTHROUGH;
11031135
}
11041136
case BatchLiftingMode::DEFINED_OUTSIDE_WHILE: {
11051137
auto operandShape = operandType.getShape();

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2210,7 +2210,6 @@ struct DUSPad final
22102210

22112211
SmallVector<int64_t> newDusStartIndexValues;
22122212
SmallVector<Value> newDusStartIndices;
2213-
SmallVector<int64_t> newDusStartIndexValues;
22142213
Location loc = dus.getLoc();
22152214
auto indexElementType =
22162215
cast<ShapedType>(startIndices[0].getType()).getElementType();

test/lit_tests/autobatching/higher_order_post_diff.mlir

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=dot_general_licm(0);elementwise_licm(0);greedy_while_loop_batch_fission;while_is_copy_simplify;remove_no_ops_from_while_loop;dynamic_slice_reshape_dynamic_slice" --transform-interpreter --enzyme-hlo-remove-transform --inline --enzyme-hlo-opt %s | FileCheck %s
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=greedy_while_loop_batch_fission;while_is_copy_simplify;remove_no_ops_from_while_loop;dynamic_slice_reshape_dynamic_slice" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s
22

33
func.func @main(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>, %arg2: tensor<3x5xf32>) -> tensor<3x5xf32> {
44
%cst = stablehlo.constant dense<"0x000000400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040"> : tensor<15x3x5xf32>
@@ -83,29 +83,29 @@ func.func @main(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>, %arg2: tensor<3x5x
8383
// CHECK-NEXT: %4 = stablehlo.dot_general %0, %cst_0, batching_dims = [0] x [0], contracting_dims = [1] x [2] : (tensor<15x5x5xf32>, tensor<15x3x5xf32>) -> tensor<15x5x3xf32>
8484
// CHECK-NEXT: %5 = stablehlo.dot_general %arg0, %arg2, contracting_dims = [0] x [1] : (tensor<5x5xf32>, tensor<3x5xf32>) -> tensor<5x3xf32>
8585
// CHECK-NEXT: %6 = stablehlo.add %5, %3 : tensor<5x3xf32>
86-
// CHECK-NEXT: %7 = stablehlo.broadcast_in_dim %6, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
87-
// CHECK-NEXT: %8 = stablehlo.multiply %7, %4 : tensor<15x5x3xf32>
88-
// CHECK-NEXT: %9 = stablehlo.multiply %8, %cst : tensor<15x5x3xf32>
89-
// CHECK-NEXT: %10 = stablehlo.multiply %cst_4, %6 : tensor<5x3xf32>
90-
// CHECK-NEXT: %11 = stablehlo.broadcast_in_dim %10, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
91-
// CHECK-NEXT: %12 = stablehlo.multiply %9, %11 : tensor<15x5x3xf32>
92-
// CHECK-NEXT: %13 = stablehlo.multiply %6, %6 : tensor<5x3xf32>
93-
// CHECK-NEXT: %14 = stablehlo.multiply %13, %cst_3 : tensor<5x3xf32>
94-
// CHECK-NEXT: %15 = stablehlo.add %14, %cst_2 : tensor<5x3xf32>
95-
// CHECK-NEXT: %16 = stablehlo.broadcast_in_dim %15, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
96-
// CHECK-NEXT: %17 = stablehlo.multiply %2, %16 : tensor<15x5x3xf32>
97-
// CHECK-NEXT: %18 = stablehlo.add %17, %12 : tensor<15x5x3xf32>
98-
// CHECK-NEXT: %19 = stablehlo.multiply %10, %15 : tensor<5x3xf32>
99-
// CHECK-NEXT: %20 = stablehlo.logistic %19 : tensor<5x3xf32>
100-
// CHECK-NEXT: %21 = stablehlo.broadcast_in_dim %20, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
101-
// CHECK-NEXT: %22 = stablehlo.multiply %1, %21 : tensor<15x5x3xf32>
102-
// CHECK-NEXT: %23 = stablehlo.logistic %19 : tensor<5x3xf32>
86+
// CHECK-NEXT: %7 = stablehlo.multiply %6, %6 : tensor<5x3xf32>
87+
// CHECK-NEXT: %8 = stablehlo.multiply %7, %cst_3 : tensor<5x3xf32>
88+
// CHECK-NEXT: %9 = stablehlo.add %8, %cst_2 : tensor<5x3xf32>
89+
// CHECK-NEXT: %10 = stablehlo.broadcast_in_dim %9, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
90+
// CHECK-NEXT: %11 = stablehlo.multiply %2, %10 : tensor<15x5x3xf32>
91+
// CHECK-NEXT: %12 = stablehlo.broadcast_in_dim %6, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
92+
// CHECK-NEXT: %13 = stablehlo.multiply %12, %4 : tensor<15x5x3xf32>
93+
// CHECK-NEXT: %14 = stablehlo.multiply %cst_4, %6 : tensor<5x3xf32>
94+
// CHECK-NEXT: %15 = stablehlo.multiply %14, %9 : tensor<5x3xf32>
95+
// CHECK-NEXT: %16 = stablehlo.logistic %15 : tensor<5x3xf32>
96+
// CHECK-NEXT: %17 = stablehlo.broadcast_in_dim %16, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
97+
// CHECK-NEXT: %18 = stablehlo.multiply %1, %17 : tensor<15x5x3xf32>
98+
// CHECK-NEXT: %19 = stablehlo.multiply %13, %cst : tensor<15x5x3xf32>
99+
// CHECK-NEXT: %20 = stablehlo.broadcast_in_dim %14, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
100+
// CHECK-NEXT: %21 = stablehlo.multiply %19, %20 : tensor<15x5x3xf32>
101+
// CHECK-NEXT: %22 = stablehlo.add %11, %21 : tensor<15x5x3xf32>
102+
// CHECK-NEXT: %23 = stablehlo.logistic %15 : tensor<5x3xf32>
103103
// CHECK-NEXT: %24 = stablehlo.subtract %cst_2, %23 : tensor<5x3xf32>
104104
// CHECK-NEXT: %25 = stablehlo.multiply %23, %24 : tensor<5x3xf32>
105105
// CHECK-NEXT: %26 = stablehlo.broadcast_in_dim %25, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
106-
// CHECK-NEXT: %27 = stablehlo.multiply %18, %26 : tensor<15x5x3xf32>
107-
// CHECK-NEXT: %28 = stablehlo.multiply %27, %7 : tensor<15x5x3xf32>
108-
// CHECK-NEXT: %29 = stablehlo.add %22, %28 : tensor<15x5x3xf32>
106+
// CHECK-NEXT: %27 = stablehlo.multiply %22, %26 : tensor<15x5x3xf32>
107+
// CHECK-NEXT: %28 = stablehlo.multiply %27, %12 : tensor<15x5x3xf32>
108+
// CHECK-NEXT: %29 = stablehlo.add %18, %28 : tensor<15x5x3xf32>
109109
// CHECK-NEXT: %30:2 = stablehlo.while(%iterArg = %c_7, %iterArg_12 = %cst_9) : tensor<i64>, tensor<3x5xf32>
110110
// CHECK-NEXT: cond {
111111
// CHECK-NEXT: %31 = stablehlo.compare LT, %iterArg, %c_11 : (tensor<i64>, tensor<i64>) -> tensor<i1>

0 commit comments

Comments (0)