@@ -776,16 +776,13 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
776776 sliceOp, affineIndexInfoMap, whileBody, whileOp);
777777
778778 if (isValidForBatchingResult (result.result )) {
779- candidateSlices.push_back (DynamicSliceInfo{
780- sliceOp, result.dimensions , false , {}, affineIndexInfo, false });
779+ candidateSlices.push_back (
780+ DynamicSliceInfo{ sliceOp, result.dimensions , false , false , {} });
781781 }
782782 }
783783 }
784784 }
785785
786- if (candidateSlices.empty ())
787- return rewriter.notifyMatchFailure (whileOp, " no candidate slices found" );
788-
789786 bool anyOpRewritten = false ;
790787
791788 // iota [idx] where iota starts at 0 and iter var also starts at 0
@@ -817,19 +814,23 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
817814 // indexing with `scale * indVar + offset`
818815 // result = scale * indVar + (iotaStart + offset)
819816
817+ auto affineIndexInfo =
818+ affineIndexInfoMap[slice.sliceOp
819+ .getStartIndices ()[slice.dimensions [0 ]]];
820+
820821 auto scalarType = RankedTensorType::get ({}, opElemType);
821822
822- if (!slice. affineIndexInfo .scale .isOne ()) {
823+ if (!affineIndexInfo.scale .isOne ()) {
823824 newOperand = stablehlo::MulOp::create (
824825 rewriter, slice.sliceOp .getLoc (),
825826 stablehlo::ConstantOp::create (
826827 rewriter, slice.sliceOp .getLoc (), scalarType,
827828 cast<ElementsAttr>(makeAttr (
828- scalarType, slice. affineIndexInfo .scale .getSExtValue ()))),
829+ scalarType, affineIndexInfo.scale .getSExtValue ()))),
829830 newOperand);
830831 }
831832
832- auto indexOffset = slice. affineIndexInfo .offset .getSExtValue ();
833+ auto indexOffset = affineIndexInfo.offset .getSExtValue ();
833834 auto iotaStart = iotaDetection.value ().start ;
834835 auto offset = indexOffset + iotaStart;
835836
@@ -877,13 +878,23 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
877878
878879 for (auto user : op->getUsers ()) {
879880 userOpToSlicesMap[user].push_back (
880- DynamicSliceInfo{ds.sliceOp , ds.dimensions , true , reshapeShape,
881- ds. affineIndexInfo , needsManualReshape });
881+ DynamicSliceInfo{ds.sliceOp , ds.dimensions , true ,
882+ needsManualReshape, reshapeShape });
882883 }
883884 }
884885 }
885886 }
886887
888+ // for certain operations on index variables it is more efficient to hoist
889+ // those out of the loop and then perform indirect indexing
890+ for (auto &[val, slices] : affineIndexInfoMap) {
891+ for (auto user : val.getUsers ()) {
892+ if (isa<stablehlo::CompareOp, stablehlo::BroadcastInDimOp>(user)) {
893+ userOpToSlicesMap[user].push_back (DynamicSliceInfo ());
894+ }
895+ }
896+ }
897+
887898 if (userOpToSlicesMap.empty ())
888899 return anyOpRewritten ? success () : failure ();
889900
@@ -907,6 +918,7 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
907918 ArrayRef<DynamicSliceInfo> sliceOps, Operation *op,
908919 WhileLoopInfo info) const {
909920 auto moduleOp = op->getParentOfType <ModuleOp>();
921+ auto affineIndexInfoMap = info.getAffineIndexInfo ();
910922
911923 SmallVector<BatchLiftingMode> batchLiftingModes (op->getNumOperands ());
912924 SmallVector<Value> batchOperands (op->getNumOperands ());
@@ -934,6 +946,12 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
934946 continue ;
935947 }
936948
949+ if (affineIndexInfoMap.contains (operand)) {
950+ batchLiftingModes[i] = BatchLiftingMode::AFFINE_INDEX;
951+ batchOperands[i] = operand;
952+ continue ;
953+ }
954+
937955 auto defOp = operand.getDefiningOp ();
938956 if (!defOp) {
939957 return false ;
@@ -1101,10 +1119,39 @@ bool GreedyWhileLoopBatchFission::liftOperationByBatching(
11011119 break ;
11021120 }
11031121 case BatchLiftingMode::CONSTANT: {
1104- continue ; // copied into the function body no need to include in operands
1122+ break ; // copied into the function body no need to include in operands
11051123 }
1106- default : {
1107- assert (false && " not implemented" );
1124+ case BatchLiftingMode::AFFINE_INDEX: {
1125+ auto hoistedTy = RankedTensorType::get ({info.getConstantNumIters ()},
1126+ operandType.getElementType ());
1127+ Value loopIndices = stablehlo::IotaOp::create (
1128+ rewriter, whileOp->getLoc (),
1129+ RankedTensorType::get ({info.getConstantNumIters ()},
1130+ operandType.getElementType ()),
1131+ 0 );
1132+
1133+ auto createConst = [&](int64_t val) {
1134+ return stablehlo::ConstantOp::create (
1135+ rewriter, whileOp->getLoc (), hoistedTy,
1136+ cast<ElementsAttr>(makeAttr (hoistedTy, val)));
1137+ };
1138+
1139+ auto startVal = createConst (info.getConstantStart ().value ());
1140+ auto stepVal = createConst (info.getConstantStep ().value ());
1141+ loopIndices = stablehlo::AddOp::create (
1142+ rewriter, whileOp->getLoc (), loopIndices,
1143+ stablehlo::MulOp::create (rewriter, whileOp->getLoc (), stepVal,
1144+ startVal));
1145+
1146+ auto affineIndexInfo = info.getAffineIndexInfo ()[baseOp];
1147+ auto scale = createConst (affineIndexInfo.scale .getSExtValue ());
1148+ auto offset = createConst (affineIndexInfo.offset .getSExtValue ());
1149+ auto res = stablehlo::AddOp::create (
1150+ rewriter, whileOp->getLoc (),
1151+ stablehlo::MulOp::create (rewriter, whileOp->getLoc (), scale,
1152+ loopIndices),
1153+ offset);
1154+ newOperands.push_back (res);
11081155 break ;
11091156 }
11101157 }
0 commit comments