Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
namespace nvfuser {

TensorIndexer::TensorIndexer(IdModel& id_model) : id_model_(id_model) {
NVF_ERROR(isSupported(id_model.fusion()));

buildLoopIndexMap();

if (isDebugDumpEnabled(DebugDumpOption::IndexingVerbose)) {
Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/expr_eval_sched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) {
// TODO: remove IndexPutAccumulateOp
if (exprs.front()
->isOneOf<
GatherOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding GatherOp here routes ALL gather operations (including exact-sized takeAlongAxis) to ExprEval/ATen evaluation. The PR title says "but not takeAlongAxis", suggesting exact gather should still be compiled. Consider filtering to only accept non-exact gather:

Suggested change
GatherOp,
!exprs.front()->isa<GatherOp>() || !exprs.front()->as<GatherOp>()->exactSizes() ? GatherOp : void,

Otherwise, please clarify whether the performance regression for takeAlongAxis is intentional.

ScatterOp,
SdpaFwdOp,
SdpaBwdOp,
Expand Down
10 changes: 10 additions & 0 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) {
return false;
}

// Support for non-exact gather was dropped when the legacy indexer was
// deprecated
if (std::ranges::any_of(
ir_utils::getOpsOfType<GatherOp>(fusion),
[](GatherOp* gather) { return !gather->exactSizes(); })) {
scheduler_debug_utils::canScheduleRejectReason(
scheduler_type, "Non-exact gather ops");
return false;
}

// Fusions with `MatmulOp, LinearOp, MmaOp` can only be accepted by Matmul
// scheduler.
if (scheduler_type != SchedulerType::Matmul &&
Expand Down
8 changes: 5 additions & 3 deletions tests/cpp/test_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ TEST_F(GatherTest, TakeAlongAxisIntermediateTensorReduction1) {

validateSegmentation(
executor_cache.getMostRecentKernelRuntime(),
{SchedulerType::Reduction, SchedulerType::PointWise});
{SchedulerType::Reduction, SchedulerType::ExprEval});

testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
}
Expand Down Expand Up @@ -1127,7 +1127,8 @@ TEST_F(GatherTest, TakeAlongAxisCrossEntropyLoss) {
}

// Test grouped reduction on IterType::GatherScatter
TEST_F(GatherTest, GatherIterGoupedReduction) {
// Codegen support of non-exact gather dropped
TEST_F(GatherTest, DISABLED_GatherIterGoupedReduction) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still plan to support this later? I'm wondering whether we should remove the tests instead.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me remove it

const int max_dim_size = 128;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
Expand Down Expand Up @@ -1211,7 +1212,8 @@ TEST_F(GatherTest, GatherIterGoupedReduction) {
lparams);
}

TEST_F(GatherTest, SameTvUsedAsLookupAndIndex) {
// Codegen support of non-exact gather dropped
TEST_F(GatherTest, DISABLED_SameTvUsedAsLookupAndIndex) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);
Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/test_persistent_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,9 @@ TEST_F(PersistentBufferTest, BufferGatherLookupTv) {
auto tv2 = sum(tv1, {1});
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = broadcast(index_tv, {false, true});
auto tv5 = gather(tv0, 1, tv4);
// Use takeAlongAxis rather than gather as codegen does not support
// the latter
auto tv5 = takeAlongAxis(tv0, tv4, 1);
auto tv6 = maybeCastOp(DataType::BFloat16, tv5);
auto tv7 = add(tv3, tv6);
auto tv8 = add(tv1, tv7);
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/test_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2563,7 +2563,7 @@ TEST_F(ReductionTest, CrossEntropyGatherPattern) {
fusion.addInput(labels);

auto tv2 = broadcast(labels, {false, true});
auto tv3 = gather(log_probs, 1, tv2);
auto tv3 = takeAlongAxis(log_probs, tv2, 1);
auto tv4 = squeeze(tv3, std::vector<bool>({false, true}));

fusion.addOutput(tv4);
Expand Down