Skip to content

Commit 62b1abd

Browse files
committed
feat: more operation hoisting
1 parent 1335d13 commit 62b1abd

File tree

2 files changed

+62
-15
lines changed

2 files changed

+62
-15
lines changed

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -776,16 +776,13 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
776776
sliceOp, affineIndexInfoMap, whileBody, whileOp);
777777

778778
if (isValidForBatchingResult(result.result)) {
779-
candidateSlices.push_back(DynamicSliceInfo{
780-
sliceOp, result.dimensions, false, {}, affineIndexInfo, false});
779+
candidateSlices.push_back(
780+
DynamicSliceInfo{sliceOp, result.dimensions, false, false, {}});
781781
}
782782
}
783783
}
784784
}
785785

786-
if (candidateSlices.empty())
787-
return rewriter.notifyMatchFailure(whileOp, "no candidate slices found");
788-
789786
bool anyOpRewritten = false;
790787

791788
// iota [idx] where iota starts at 0 and iter var also starts at 0
@@ -817,19 +814,23 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
817814
// indexing with `scale * indVar + offset`
818815
// result = scale * indVar + (iotaStart + offset)
819816

817+
auto affineIndexInfo =
818+
affineIndexInfoMap[slice.sliceOp
819+
.getStartIndices()[slice.dimensions[0]]];
820+
820821
auto scalarType = RankedTensorType::get({}, opElemType);
821822

822-
if (!slice.affineIndexInfo.scale.isOne()) {
823+
if (!affineIndexInfo.scale.isOne()) {
823824
newOperand = stablehlo::MulOp::create(
824825
rewriter, slice.sliceOp.getLoc(),
825826
stablehlo::ConstantOp::create(
826827
rewriter, slice.sliceOp.getLoc(), scalarType,
827828
cast<ElementsAttr>(makeAttr(
828-
scalarType, slice.affineIndexInfo.scale.getSExtValue()))),
829+
scalarType, affineIndexInfo.scale.getSExtValue()))),
829830
newOperand);
830831
}
831832

832-
auto indexOffset = slice.affineIndexInfo.offset.getSExtValue();
833+
auto indexOffset = affineIndexInfo.offset.getSExtValue();
833834
auto iotaStart = iotaDetection.value().start;
834835
auto offset = indexOffset + iotaStart;
835836

@@ -877,13 +878,23 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
877878

878879
for (auto user : op->getUsers()) {
879880
userOpToSlicesMap[user].push_back(
880-
DynamicSliceInfo{ds.sliceOp, ds.dimensions, true, reshapeShape,
881-
ds.affineIndexInfo, needsManualReshape});
881+
DynamicSliceInfo{ds.sliceOp, ds.dimensions, true,
882+
needsManualReshape, reshapeShape});
882883
}
883884
}
884885
}
885886
}
886887

888+
// for certain operations on index variables it is more efficient to hoist
889+
// those out of the loop and then perform indirect indexing
890+
for (auto &[val, slices] : affineIndexInfoMap) {
891+
for (auto user : val.getUsers()) {
892+
if (isa<stablehlo::CompareOp, stablehlo::BroadcastInDimOp>(user)) {
893+
userOpToSlicesMap[user].push_back(DynamicSliceInfo());
894+
}
895+
}
896+
}
897+
887898
if (userOpToSlicesMap.empty())
888899
return anyOpRewritten ? success() : failure();
889900

@@ -907,6 +918,7 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
907918
ArrayRef<DynamicSliceInfo> sliceOps, Operation *op,
908919
WhileLoopInfo info) const {
909920
auto moduleOp = op->getParentOfType<ModuleOp>();
921+
auto affineIndexInfoMap = info.getAffineIndexInfo();
910922

911923
SmallVector<BatchLiftingMode> batchLiftingModes(op->getNumOperands());
912924
SmallVector<Value> batchOperands(op->getNumOperands());
@@ -934,6 +946,12 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
934946
continue;
935947
}
936948

949+
if (affineIndexInfoMap.contains(operand)) {
950+
batchLiftingModes[i] = BatchLiftingMode::AFFINE_INDEX;
951+
batchOperands[i] = operand;
952+
continue;
953+
}
954+
937955
auto defOp = operand.getDefiningOp();
938956
if (!defOp) {
939957
return false;
@@ -1101,10 +1119,39 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
11011119
break;
11021120
}
11031121
case BatchLiftingMode::CONSTANT: {
1104-
continue; // copied into the function body no need to include in operands
1122+
break; // copied into the function body no need to include in operands
11051123
}
1106-
default: {
1107-
assert(false && "not implemented");
1124+
case BatchLiftingMode::AFFINE_INDEX: {
1125+
auto hoistedTy = RankedTensorType::get({info.getConstantNumIters()},
1126+
operandType.getElementType());
1127+
Value loopIndices = stablehlo::IotaOp::create(
1128+
rewriter, whileOp->getLoc(),
1129+
RankedTensorType::get({info.getConstantNumIters()},
1130+
operandType.getElementType()),
1131+
0);
1132+
1133+
auto createConst = [&](int64_t val) {
1134+
return stablehlo::ConstantOp::create(
1135+
rewriter, whileOp->getLoc(), hoistedTy,
1136+
cast<ElementsAttr>(makeAttr(hoistedTy, val)));
1137+
};
1138+
1139+
auto startVal = createConst(info.getConstantStart().value());
1140+
auto stepVal = createConst(info.getConstantStep().value());
1141+
loopIndices = stablehlo::AddOp::create(
1142+
rewriter, whileOp->getLoc(), loopIndices,
1143+
stablehlo::MulOp::create(rewriter, whileOp->getLoc(), stepVal,
1144+
startVal));
1145+
1146+
auto affineIndexInfo = info.getAffineIndexInfo()[baseOp];
1147+
auto scale = createConst(affineIndexInfo.scale.getSExtValue());
1148+
auto offset = createConst(affineIndexInfo.offset.getSExtValue());
1149+
auto res = stablehlo::AddOp::create(
1150+
rewriter, whileOp->getLoc(),
1151+
stablehlo::MulOp::create(rewriter, whileOp->getLoc(), scale,
1152+
loopIndices),
1153+
offset);
1154+
newOperands.push_back(res);
11081155
break;
11091156
}
11101157
}

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ struct GreedyWhileLoopBatchFission
214214
DEFINED_OUTSIDE_WHILE,
215215
CONSTANT,
216216
NEEDS_HOISTING_OUTSIDE_WHILE,
217+
AFFINE_INDEX,
217218
};
218219

219220
enum class IsValidForBatchingResult {
@@ -231,9 +232,8 @@ struct GreedyWhileLoopBatchFission
231232
mlir::stablehlo::DynamicSliceOp sliceOp;
232233
llvm::SmallVector<int64_t> dimensions;
233234
bool intermediateReshape;
234-
llvm::SmallVector<int64_t> reshapeShape;
235-
mlir::enzyme::WhileLoopInfo::AffineIndexInfo affineIndexInfo;
236235
bool needsManualReshape;
236+
llvm::SmallVector<int64_t> reshapeShape;
237237
};
238238

239239
struct ValidBatchingInfo {

0 commit comments

Comments
 (0)