
Commit 1335d13 (1 parent: 0af2908)

feat: support licm inside batching

7 files changed: +82 -49 lines changed


src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp

Lines changed: 19 additions & 4 deletions
@@ -293,12 +293,13 @@ void WhileLoopInfo::propagateAffineIndexInfo() {
   }
 }
 
-bool WhileLoopInfo::isConstantAcrossIterations(Value v) {
+bool WhileLoopInfo::isConstantAcrossIterations(Value v, bool checkOperands) {
   Value outerValue;
-  return isConstantAcrossIterations(v, outerValue);
+  return isConstantAcrossIterations(v, outerValue, checkOperands);
 }
 
-bool WhileLoopInfo::isConstantAcrossIterations(Value v, Value &outerValue) {
+bool WhileLoopInfo::isConstantAcrossIterations(Value v, Value &outerValue,
+                                               bool checkOperands) {
   if (definedOutside(v, op)) {
     outerValue = v;
     return true;
@@ -316,7 +317,21 @@ bool WhileLoopInfo::isConstantAcrossIterations(Value v, Value &outerValue) {
     }
   }
 
-  return false;
+  if (!checkOperands)
+    return false;
+
+  auto defOp = v.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // all operands of the defining op are constant across iterations
+  // don't populate the outerValue in this case
+  return llvm::all_of(defOp->getOperands(), [&](Value operand) {
+    // TODO: we should do `isConstantAcrossIterations` but for now we do a more
+    // conservative check
+    // return isConstantAcrossIterations(operand);
+    return definedOutside(operand, op);
+  });
 }
 
 template <typename OpTy>
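
The key behavioral change: a value defined inside the loop body whose defining op only consumes values defined outside the loop is now also reported as constant across iterations, with outerValue left unset since no outer SSA value exists for it yet. A minimal hypothetical sketch of such a loop (invented names and shapes, not from the test suite):

func.func @invariant_inside_loop(%a: tensor<4xf32>, %b: tensor<4xf32>) -> tensor<4xf32> {
  %c0 = stablehlo.constant dense<0> : tensor<i64>
  %c1 = stablehlo.constant dense<1> : tensor<i64>
  %c10 = stablehlo.constant dense<10> : tensor<i64>
  %init = stablehlo.constant dense<0.000000e+00> : tensor<4xf32>
  %r:2 = stablehlo.while(%iterArg = %c0, %acc = %init) : tensor<i64>, tensor<4xf32>
   cond {
    %cmp = stablehlo.compare LT, %iterArg, %c10 : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cmp : tensor<i1>
   } do {
    // %inv is defined inside the body, but both of its operands (%a, %b) are
    // defined outside the while op, so isConstantAcrossIterations(%inv) now
    // returns true via the new operand check.
    %inv = stablehlo.add %a, %b : tensor<4xf32>
    %next = stablehlo.add %iterArg, %c1 : tensor<i64>
    %sum = stablehlo.add %acc, %inv : tensor<4xf32>
    stablehlo.return %next, %sum : tensor<i64>, tensor<4xf32>
   }
  return %r#1 : tensor<4xf32>
}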

src/enzyme_ad/jax/Implementations/WhileLoopInfo.h

Lines changed: 3 additions & 2 deletions
@@ -63,8 +63,9 @@ struct WhileLoopInfo {
     return affineIndexInfo;
   }
 
-  bool isConstantAcrossIterations(Value v);
-  bool isConstantAcrossIterations(Value v, Value &outerValue);
+  bool isConstantAcrossIterations(Value v, bool checkOperands = true);
+  bool isConstantAcrossIterations(Value v, Value &outerValue,
+                                  bool checkOperands = true);
 
   bool canHoistOperationFromLoop(mlir::stablehlo::DynamicSliceOp sliceOp,
                                  SmallVectorImpl<int64_t> &dimensions);

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 29 additions & 15 deletions
@@ -911,26 +911,33 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
   SmallVector<BatchLiftingMode> batchLiftingModes(op->getNumOperands());
   SmallVector<Value> batchOperands(op->getNumOperands());
   SmallVector<SmallVector<int64_t>> sliceDims(op->getNumOperands());
+  SmallVector<int64_t> hoistedDims(op->getNumOperands());
   SmallVector<DynamicSliceInfo> mappedSliceInfos(op->getNumOperands());
   for (int i = 0; i < op->getNumOperands(); i++) {
     auto operand = op->getOperand(i);
 
-    Value outerValue = operand;
-    if (operand.getParentBlock() != &whileOp.getBody().front() ||
-        info.isConstantAcrossIterations(operand, outerValue)) {
-      SplatElementsAttr splat;
-      if (matchPattern(operand, m_Constant(&splat))) {
-        batchLiftingModes[i] = BatchLiftingMode::CONSTANT;
+    Value outerValue;
+    if (info.isConstantAcrossIterations(operand, outerValue)) {
+      if (outerValue) {
+        SplatElementsAttr splat;
+        if (matchPattern(operand, m_Constant(&splat))) {
+          batchLiftingModes[i] = BatchLiftingMode::CONSTANT;
+        } else {
+          batchLiftingModes[i] = BatchLiftingMode::DEFINED_OUTSIDE_WHILE;
+        }
+        batchOperands[i] = outerValue;
       } else {
-        batchLiftingModes[i] = BatchLiftingMode::DEFINED_OUTSIDE_WHILE;
+        hoistedDims[i] = cast<mlir::OpResult>(operand).getResultNumber();
+        batchLiftingModes[i] = BatchLiftingMode::NEEDS_HOISTING_OUTSIDE_WHILE;
+        batchOperands[i] = operand;
       }
-      batchOperands[i] = outerValue;
       continue;
     }
 
     auto defOp = operand.getDefiningOp();
-    if (!defOp)
+    if (!defOp) {
       return false;
+    }
 
     Operation *dsOp;
     bool mustBeIntermediateReshape = false;
@@ -998,14 +1005,15 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
   rewriter.setInsertionPointToStart(&entryBlock);
 
   IRMapping mapper;
-  for (int i = 0; i < op->getNumOperands(); i++) {
-    auto operand = op->getOperand(i);
-    if (batchLiftingModes[i] == BatchLiftingMode::CONSTANT) {
+  size_t argIdx = 0;
+  for (auto [batchLiftMode, operand] :
+       llvm::zip(batchLiftingModes, op->getOperands())) {
+    if (batchLiftMode == BatchLiftingMode::CONSTANT) {
       auto clonedConst = rewriter.clone(*operand.getDefiningOp());
      mapper.map(operand, clonedConst->getResult(0));
      continue;
    }
-    mapper.map(operand, entryBlock.getArguments()[i]);
+    mapper.map(operand, entryBlock.getArguments()[argIdx++]);
   }
 
   auto unbatchedOp = rewriter.clone(*op, mapper);
@@ -1015,8 +1023,9 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
   rewriter.setInsertionPoint(whileOp);
 
   SmallVector<Value> newOperands;
-  for (auto [consType, baseOp, sliceDim, sliceInfo] : llvm::zip(
-           batchLiftingModes, batchOperands, sliceDims, mappedSliceInfos)) {
+  for (auto [consType, baseOp, sliceDim, sliceInfo, hoistDim] :
+       llvm::zip(batchLiftingModes, batchOperands, sliceDims, mappedSliceInfos,
+                 hoistedDims)) {
     auto operandType = cast<RankedTensorType>(baseOp.getType());
     int operandRank = cast<RankedTensorType>(baseOp.getType()).getRank();
@@ -1069,6 +1078,11 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
       newOperands.push_back(newOperand);
       break;
     }
+    case BatchLiftingMode::NEEDS_HOISTING_OUTSIDE_WHILE: {
+      auto hoisted = rewriter.clone(*baseOp.getDefiningOp());
+      baseOp = hoisted->getResult(hoistDim);
+      // intentionally fallthrough
+    }
     case BatchLiftingMode::DEFINED_OUTSIDE_WHILE: {
       auto operandShape = operandType.getShape();
       SmallVector<int64_t> newOperandShape(operandRank + 1);
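
The new NEEDS_HOISTING_OUTSIDE_WHILE mode closes the loop on the WhileLoopInfo change above: when isConstantAcrossIterations succeeds without producing an outerValue, the operand's defining op is cloned just before the while op (safe, since all of its operands are defined outside the loop), and control deliberately falls through into the DEFINED_OUTSIDE_WHILE case, which broadcasts the hoisted result along the new batch dimension like any other outer value. Roughly, for the %inv value from the earlier sketch (hypothetical IR, assuming a batch size of 10):

// Hoisted clone inserted before the while op:
%hoisted = stablehlo.add %a, %b : tensor<4xf32>
// Then broadcast with a leading batch dimension, exactly as the
// DEFINED_OUTSIDE_WHILE case does for values already outside the loop:
%batched = stablehlo.broadcast_in_dim %hoisted, dims = [1] : (tensor<4xf32>) -> tensor<10x4xf32>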

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 1 addition & 0 deletions
@@ -213,6 +213,7 @@ struct GreedyWhileLoopBatchFission
     DYNAMIC_SLICE,
     DEFINED_OUTSIDE_WHILE,
     CONSTANT,
+    NEEDS_HOISTING_OUTSIDE_WHILE,
   };
 
   enum class IsValidForBatchingResult {

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 2 additions & 1 deletion
@@ -25203,7 +25203,8 @@ struct WhileIsCopySimplify
   SmallVector<int64_t> inductionVarDimensions;
 
   for (auto [i, startIndex] : llvm::enumerate(startIndices)) {
-    if (info.isConstantAcrossIterations(startIndex))
+    // we could hoist the other dimensions but licm should fix this
+    if (info.isConstantAcrossIterations(startIndex, false))
       continue;
 
     if (!affineIndexInfo.contains(startIndex))
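
Here the new checkOperands flag is turned off: WhileIsCopySimplify only wants to skip start indices that are genuinely available outside the loop, so an index that is merely invariant by its operands is left for LICM to hoist first, after which definedOutside makes it constant in the strict sense. A hypothetical start index this pattern now declines to skip (invented names and shapes):

// %idx is computed inside the loop body from values defined outside it;
// with checkOperands = false it is no longer treated as constant here, so
// the pattern defers to LICM to hoist it before the simplification applies.
%idx = stablehlo.add %outer0, %outer1 : tensor<i32>
%s = stablehlo.dynamic_slice %buf, %iterArg32, %idx, sizes = [1, 4] : (tensor<8x4xf32>, tensor<i32>, tensor<i32>) -> tensor<1x4xf32>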

test/lit_tests/autobatching/dot_general_loop.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: enzymexlamlir-opt --enzyme-hlo-opt --auto-batching --inline --enzyme-hlo-generate-td="patterns=reshape_dynamic_slice(1);reshape_licm(1);transpose_dynamic_slice;transpose_licm(1);while_is_copy_simplify;reshape_elementwise(1);elementwise_licm(1)" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s
+// RUN: enzymexlamlir-opt --auto-batching --enzyme-hlo-opt %s | FileCheck %s
 
 module {
   func.func @main(%arg0: tensor<3x5x10xf32> {enzymexla.memory_effects = []}, %arg1: tensor<4x3xf32> {enzymexla.memory_effects = []}) -> tensor<4x5x10xf32> attributes {enzymexla.memory_effects = []} {

test/lit_tests/autobatching/higher_order_post_diff.mlir

Lines changed: 27 additions & 26 deletions
@@ -87,39 +87,40 @@ func.func @main(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>, %arg2: tensor<3x5x
 // CHECK-NEXT: %8 = stablehlo.multiply %7, %4 : tensor<15x5x3xf32>
 // CHECK-NEXT: %9 = stablehlo.multiply %8, %cst : tensor<15x5x3xf32>
 // CHECK-NEXT: %10 = stablehlo.multiply %cst_4, %6 : tensor<5x3xf32>
-// CHECK-NEXT: %11 = stablehlo.multiply %6, %6 : tensor<5x3xf32>
-// CHECK-NEXT: %12 = stablehlo.broadcast_in_dim %10, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
-// CHECK-NEXT: %13 = stablehlo.multiply %9, %12 : tensor<15x5x3xf32>
-// CHECK-NEXT: %14 = stablehlo.multiply %11, %cst_3 : tensor<5x3xf32>
+// CHECK-NEXT: %11 = stablehlo.broadcast_in_dim %10, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
+// CHECK-NEXT: %12 = stablehlo.multiply %9, %11 : tensor<15x5x3xf32>
+// CHECK-NEXT: %13 = stablehlo.multiply %6, %6 : tensor<5x3xf32>
+// CHECK-NEXT: %14 = stablehlo.multiply %13, %cst_3 : tensor<5x3xf32>
 // CHECK-NEXT: %15 = stablehlo.add %14, %cst_2 : tensor<5x3xf32>
 // CHECK-NEXT: %16 = stablehlo.broadcast_in_dim %15, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
 // CHECK-NEXT: %17 = stablehlo.multiply %2, %16 : tensor<15x5x3xf32>
-// CHECK-NEXT: %18 = stablehlo.add %17, %13 : tensor<15x5x3xf32>
+// CHECK-NEXT: %18 = stablehlo.add %17, %12 : tensor<15x5x3xf32>
 // CHECK-NEXT: %19 = stablehlo.multiply %10, %15 : tensor<5x3xf32>
 // CHECK-NEXT: %20 = stablehlo.logistic %19 : tensor<5x3xf32>
 // CHECK-NEXT: %21 = stablehlo.broadcast_in_dim %20, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
 // CHECK-NEXT: %22 = stablehlo.multiply %1, %21 : tensor<15x5x3xf32>
-// CHECK-NEXT: %23 = stablehlo.subtract %cst_2, %20 : tensor<5x3xf32>
-// CHECK-NEXT: %24 = stablehlo.multiply %20, %23 : tensor<5x3xf32>
-// CHECK-NEXT: %25 = stablehlo.broadcast_in_dim %24, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
-// CHECK-NEXT: %26 = stablehlo.multiply %18, %25 : tensor<15x5x3xf32>
-// CHECK-NEXT: %27 = stablehlo.multiply %26, %7 : tensor<15x5x3xf32>
-// CHECK-NEXT: %28 = stablehlo.add %22, %27 : tensor<15x5x3xf32>
-// CHECK-NEXT: %29:2 = stablehlo.while(%iterArg = %c_7, %iterArg_12 = %cst_9) : tensor<i64>, tensor<3x5xf32>
+// CHECK-NEXT: %23 = stablehlo.logistic %19 : tensor<5x3xf32>
+// CHECK-NEXT: %24 = stablehlo.subtract %cst_2, %23 : tensor<5x3xf32>
+// CHECK-NEXT: %25 = stablehlo.multiply %23, %24 : tensor<5x3xf32>
+// CHECK-NEXT: %26 = stablehlo.broadcast_in_dim %25, dims = [1, 2] : (tensor<5x3xf32>) -> tensor<15x5x3xf32>
+// CHECK-NEXT: %27 = stablehlo.multiply %18, %26 : tensor<15x5x3xf32>
+// CHECK-NEXT: %28 = stablehlo.multiply %27, %7 : tensor<15x5x3xf32>
+// CHECK-NEXT: %29 = stablehlo.add %22, %28 : tensor<15x5x3xf32>
+// CHECK-NEXT: %30:2 = stablehlo.while(%iterArg = %c_7, %iterArg_12 = %cst_9) : tensor<i64>, tensor<3x5xf32>
 // CHECK-NEXT: cond {
-// CHECK-NEXT: %30 = stablehlo.compare LT, %iterArg, %c_11 : (tensor<i64>, tensor<i64>) -> tensor<i1>
-// CHECK-NEXT: stablehlo.return %30 : tensor<i1>
+// CHECK-NEXT: %31 = stablehlo.compare LT, %iterArg, %c_11 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+// CHECK-NEXT: stablehlo.return %31 : tensor<i1>
 // CHECK-NEXT: } do {
-// CHECK-NEXT: %30 = stablehlo.add %c_8, %iterArg : tensor<i64>
-// CHECK-NEXT: %31 = stablehlo.remainder %iterArg, %c_5 : tensor<i64>
-// CHECK-NEXT: %32 = stablehlo.add %31, %c_8 : tensor<i64>
-// CHECK-NEXT: %33 = stablehlo.convert %32 : (tensor<i64>) -> tensor<i32>
-// CHECK-NEXT: %34 = stablehlo.subtract %33, %c_6 : tensor<i32>
-// CHECK-NEXT: %35 = stablehlo.convert %34 : (tensor<i32>) -> tensor<i64>
-// CHECK-NEXT: %36 = stablehlo.dynamic_slice %28, %iterArg, %35, %c_7, sizes = [1, 1, 1] : (tensor<15x5x3xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x1xf32>
-// CHECK-NEXT: %37 = stablehlo.reshape %36 : (tensor<1x1x1xf32>) -> tensor<1x1xf32>
-// CHECK-NEXT: %38 = stablehlo.dynamic_update_slice %iterArg_12, %37, %c, %34 : (tensor<3x5xf32>, tensor<1x1xf32>, tensor<i32>, tensor<i32>) -> tensor<3x5xf32>
-// CHECK-NEXT: stablehlo.return %30, %38 : tensor<i64>, tensor<3x5xf32>
+// CHECK-NEXT: %31 = stablehlo.add %c_8, %iterArg : tensor<i64>
+// CHECK-NEXT: %32 = stablehlo.remainder %iterArg, %c_5 : tensor<i64>
+// CHECK-NEXT: %33 = stablehlo.add %32, %c_8 : tensor<i64>
+// CHECK-NEXT: %34 = stablehlo.convert %33 : (tensor<i64>) -> tensor<i32>
+// CHECK-NEXT: %35 = stablehlo.subtract %34, %c_6 : tensor<i32>
+// CHECK-NEXT: %36 = stablehlo.convert %35 : (tensor<i32>) -> tensor<i64>
+// CHECK-NEXT: %37 = stablehlo.dynamic_slice %29, %iterArg, %36, %c_7, sizes = [1, 1, 1] : (tensor<15x5x3xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: %38 = stablehlo.reshape %37 : (tensor<1x1x1xf32>) -> tensor<1x1xf32>
+// CHECK-NEXT: %39 = stablehlo.dynamic_update_slice %iterArg_12, %38, %c, %35 : (tensor<3x5xf32>, tensor<1x1xf32>, tensor<i32>, tensor<i32>) -> tensor<3x5xf32>
+// CHECK-NEXT: stablehlo.return %31, %39 : tensor<i64>, tensor<3x5xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: return %29#1 : tensor<3x5xf32>
-// CHECK-NEXT: }
+// CHECK-NEXT: return %30#1 : tensor<3x5xf32>
+// CHECK-NEXT: }
