Skip to content

Commit 77abd29

Browse files
authored
Use TensorIndexer by default (#5828)
This PR enables use of the new indexer by default. It still falls back to the legacy indexer for fusions with unsupported ops. I'll remove the NVFUSER_ENABLE option setting for `id_model` in a follow-up PR.
1 parent ac59eba commit 77abd29

File tree

1 file changed

+1
-67
lines changed

1 file changed

+1
-67
lines changed

csrc/device_lower/lower2device.cpp

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -295,73 +295,7 @@ namespace {
295295
// given Fusion
296296
IdModelOptions getIdModelOptions(Fusion* fusion) {
297297
IdModelOptions options;
298-
299-
for (auto expr : fusion->exprs()) {
300-
if (auto ldst = dynamic_cast<LoadStoreOp*>(expr)) {
301-
if (ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
302-
options.setTensorIndexer(true);
303-
continue;
304-
}
305-
} else if (
306-
expr->isOneOf<ArgsortOp, PadOp, ScanOp, ScatterOp, SliceOp, TopKOp>()) {
307-
options.setTensorIndexer(true);
308-
continue;
309-
} else if (auto reshape = dynamic_cast<ReshapeOp*>(expr)) {
310-
// The legacy indexer has an issue when an expand broadcast is
311-
// involved in reshape transformations. Enable both tensor and
312-
// predicate indexing if found
313-
314-
auto producer_tv = reshape->in();
315-
auto consumer_tv = reshape->out();
316-
317-
// Find expanded producer IDs. Note that corresponding consumer IDs do
318-
// not inherit the iteration type and are no longer expanded IDs, so the
319-
// producer domain needs to be checked to find expanded IDs.
320-
std::unordered_set<IterDomain*> expanded_ids;
321-
std::copy_if(
322-
producer_tv->getLogicalDomain().begin(),
323-
producer_tv->getLogicalDomain().end(),
324-
std::inserter(expanded_ids, expanded_ids.end()),
325-
[](IterDomain* logical_id) {
326-
return logical_id->isBroadcast() && logical_id->hasExpandedExtent();
327-
});
328-
329-
if (expanded_ids.empty()) {
330-
continue;
331-
}
332-
333-
// Find corresponding consumer root IDs
334-
auto c2p = PairwiseLogicalDomainMap(producer_tv, consumer_tv)
335-
.mapConsumerToProducer();
336-
std::unordered_set<Val*> consumer_expanded_root_ids;
337-
for (auto consumer_root_id : consumer_tv->getRootDomain()) {
338-
auto producer_logical_id = c2p.at(consumer_root_id);
339-
if (expanded_ids.count(producer_logical_id)) {
340-
consumer_expanded_root_ids.insert(consumer_root_id);
341-
}
342-
}
343-
344-
auto reshape_exprs = DependencyCheck::getAllExprsBetween(
345-
{consumer_tv->getRootDomain().begin(),
346-
consumer_tv->getRootDomain().end()},
347-
{consumer_tv->getLogicalDomain().begin(),
348-
consumer_tv->getLogicalDomain().end()});
349-
350-
if (std::any_of(
351-
reshape_exprs.begin(),
352-
reshape_exprs.end(),
353-
[&consumer_expanded_root_ids](Expr* expr) {
354-
return std::any_of(
355-
expr->inputs().begin(),
356-
expr->inputs().end(),
357-
[&](Val* input) {
358-
return consumer_expanded_root_ids.count(input);
359-
});
360-
})) {
361-
options.setTensorIndexer(true);
362-
}
363-
}
364-
}
298+
options.setTensorIndexer(true);
365299

366300
// If not supported, disable use of TensorIndexer by default. It is
367301
// still used if explicitly opted-in (see, for example,

0 commit comments

Comments
 (0)