@@ -295,73 +295,7 @@ namespace {
295295// given Fusion
296296IdModelOptions getIdModelOptions (Fusion* fusion) {
297297 IdModelOptions options;
298-
299- for (auto expr : fusion->exprs ()) {
300- if (auto ldst = dynamic_cast <LoadStoreOp*>(expr)) {
301- if (ldst->opType () == LoadStoreOpType::CpAsyncBulk) {
302- options.setTensorIndexer (true );
303- continue ;
304- }
305- } else if (
306- expr->isOneOf <ArgsortOp, PadOp, ScanOp, ScatterOp, SliceOp, TopKOp>()) {
307- options.setTensorIndexer (true );
308- continue ;
309- } else if (auto reshape = dynamic_cast <ReshapeOp*>(expr)) {
310- // The legacy indexer has an issue when an expand broadcast is
311- // involved in reshape transformations. Enable both tensor and
312- // predicate indexing if found
313-
314- auto producer_tv = reshape->in ();
315- auto consumer_tv = reshape->out ();
316-
317- // Find expanded producer IDs. Note that corresponding consumer IDs do
318- // not inherit the iteration type and are no longer expanded IDs, so the
319- // producer domain needs to be checked to find expanded IDs.
320- std::unordered_set<IterDomain*> expanded_ids;
321- std::copy_if (
322- producer_tv->getLogicalDomain ().begin (),
323- producer_tv->getLogicalDomain ().end (),
324- std::inserter (expanded_ids, expanded_ids.end ()),
325- [](IterDomain* logical_id) {
326- return logical_id->isBroadcast () && logical_id->hasExpandedExtent ();
327- });
328-
329- if (expanded_ids.empty ()) {
330- continue ;
331- }
332-
333- // Find corresponding consumer root IDs
334- auto c2p = PairwiseLogicalDomainMap (producer_tv, consumer_tv)
335- .mapConsumerToProducer ();
336- std::unordered_set<Val*> consumer_expanded_root_ids;
337- for (auto consumer_root_id : consumer_tv->getRootDomain ()) {
338- auto producer_logical_id = c2p.at (consumer_root_id);
339- if (expanded_ids.count (producer_logical_id)) {
340- consumer_expanded_root_ids.insert (consumer_root_id);
341- }
342- }
343-
344- auto reshape_exprs = DependencyCheck::getAllExprsBetween (
345- {consumer_tv->getRootDomain ().begin (),
346- consumer_tv->getRootDomain ().end ()},
347- {consumer_tv->getLogicalDomain ().begin (),
348- consumer_tv->getLogicalDomain ().end ()});
349-
350- if (std::any_of (
351- reshape_exprs.begin (),
352- reshape_exprs.end (),
353- [&consumer_expanded_root_ids](Expr* expr) {
354- return std::any_of (
355- expr->inputs ().begin (),
356- expr->inputs ().end (),
357- [&](Val* input) {
358- return consumer_expanded_root_ids.count (input);
359- });
360- })) {
361- options.setTensorIndexer (true );
362- }
363- }
364- }
298+ options.setTensorIndexer (true );
365299
366300 // If not supported, disable use of TensorIndexer by default. It is
367301 // still used if explicitly opted-in (see, for example,
0 commit comments