
chisel: support in-place ops via MemoryEffectOpInterface#8639

Open
ndrakulicTT wants to merge 10 commits into main from ndrakulic/python-inplace-ops


@ndrakulicTT ndrakulicTT commented May 28, 2026

Summary

Closes #8385

Replaces the hand-maintained _CHISEL_INPLACE_OPS table with an IR-derived view of in-place mutation, and teaches chisel's per-op flow to validate mutated operands alongside SSA results.

What changes

  • C++ binding (python/Util.cpp): new get_write_effect_operand_indices that returns flat operand indices on which an op declares MemoryEffects::Write, or None when the op doesn't implement MemoryEffectOpInterface.
  • chisel/ops.py: get_inplace_vals(op) now derives mutated tensor operands from the binding above instead of the static dict.
    The _CHISEL_INPLACE_OPS table and get_inplace_operands helper are removed.
  • chisel/callbacks.py (_default_post_op): SSA results and in-place mutated operands are validated in a single loop.
    Each enabled numerics mode (isolated / accumulated) produces one numerics record per result and per mutated operand.
    The golden tensor pool entry for each in-place operand is refreshed when a golden is registered.
  • chisel/callbacks.py (_evict_inplace_no_golden): for no-golden ops that declare MemWrite, the pooled golden for each mutated operand is dropped (with a golden_evicted record), preventing false PCC failures on downstream ops.
  • chisel/executor.py: drops execute_golden_from_pool — the unified post-op loop in callbacks.py handles pool refresh inline.
  • chisel/op_configs.py: UpdateCacheOp, PagedUpdateCacheOp, FillCacheOp, PagedFillCacheOp, and BatchNormTrainingOp no longer need to be marked no_golden=True — they have working chisel goldens and run through the in-place validation path.
  • Tests: new test_chisel_records_update_cache_inplace and test_chisel_records_batch_norm_training_inplace exercising the new path end-to-end (numerics records under both isolation and accumulation, no spurious golden_evicted records).
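
To make the flow described in the list above concrete, here is a minimal, self-contained sketch. It is not the code from this PR: FakeOp, post_op, and the record tuples are hypothetical stand-ins, write_indices stands in for whatever get_write_effect_operand_indices would return, and the isolated / accumulated split is left out for brevity.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeOp:
    """Hypothetical stand-in for an MLIR op; not the real ttmlir API."""

    name: str
    operands: list  # SSA names of the operands
    results: list  # SSA names of the results
    write_indices: Optional[list]  # assumed return value of the C++ binding


def get_inplace_vals(op):
    """Return the operands the op declares a write effect on."""
    indices = op.write_indices
    if indices is None:  # op does not implement the memory-effect interface
        return []
    return [op.operands[i] for i in indices]


def post_op(op, golden_fn, pool, records):
    """Validate results and mutated operands together, or evict stale goldens."""
    mutated = get_inplace_vals(op)
    if golden_fn is None:
        # No golden available: the pooled value no longer matches the device,
        # so drop it rather than let downstream ops compare against it.
        for ssa in mutated:
            if ssa in pool:
                del pool[ssa]
                records.append(("golden_evicted", ssa))
        return
    goldens = golden_fn(op)  # assumed to map SSA name -> golden value
    for ssa in op.results + mutated:
        records.append(("numerics", ssa))
        pool[ssa] = goldens[ssa]  # refresh the pool, in-place operands included


op = FakeOp("ttnn.update_cache", ["%cache", "%input"], [], [0])
pool, records = {"%cache": "stale"}, []
post_op(op, lambda o: {"%cache": "fresh"}, pool, records)
print(pool, records)
```

In this sketch an op with no SSA results still produces a numerics record, because its mutated operand is treated as a validation target like any result.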

Follow-ups

Two follow-up PRs are planned on top of this one:

  1. Device tensor pool — reintroduce the program-scoped device-tensor cache that was prototyped on this branch and then removed before merge.
    Goal: stop redundantly pulling the same SSA to host across consumer ops and across iso/accum modes.
  2. Test refactor — the two new in-place tests in this PR share most of their structure with the existing chisel integration tests (session setup, record filtering, PCC assertions).
    A follow-up will factor the common scaffolding into shared helpers / parametrize across ops.

Copilot AI review requested due to automatic review settings May 28, 2026 09:58
@ndrakulicTT ndrakulicTT requested a review from a team as a code owner May 28, 2026 09:58

Copilot AI left a comment


Pull request overview

This PR updates chisel’s notion of in-place mutation by deriving mutated operands from MLIR’s MemoryEffectOpInterface (instead of a hand-maintained Python table), and extends the post-op numerics validation flow to validate both SSA results and mutated operands uniformly (including pool refresh/eviction behavior).

Changes:

  • Add a new Python binding (ttmlir.util.get_write_effect_operand_indices) to query operand indices with MemoryEffects::Write.
  • Replace the static in-place-op table in chisel with an interface-driven get_inplace_vals(op) and unify post-op numerics validation for SSA + in-place operands.
  • Add/adjust goldens and integration tests for in-place validation paths (e.g., cache update ops, batch_norm_training).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

File | Description
--- | ---
python/Util.cpp | Adds C++/nanobind helper to expose write-effect operand indices to Python.
tools/chisel/chisel/ops.py | Replaces static in-place operand table with MemoryEffectOpInterface-driven detection.
tools/chisel/chisel/callbacks.py | Unifies post-op validation loop for SSA outputs + in-place mutated operands; evicts stale pool entries for no-golden in-place ops.
tools/chisel/chisel/executor.py | Removes execute_golden_from_pool and updates golden execution to use new in-place operand detection.
tools/chisel/chisel/op_configs.py | Removes no-golden flags for ops that now have working in-place goldens/validation path.
tools/golden/mapping.py | Adds chisel golden for ttnn.batch_norm_training; fixes paged cache goldens to use operand dtype for in-place ops with no SSA results.
test/python/chisel/test_golden_execution.py | Updates test to use new get_inplace_vals API and adjusts messages accordingly.
test/python/chisel/test_builder_chisel_integration.py | Adds end-to-end in-place numerics record tests for update_cache and batch_norm_training.


Comment thread python/Util.cpp
Comment on lines +5 to 35
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include <nanobind/stl/vector.h>
#include <optional>
#include <variant>

namespace mlir::ttmlir::python {

static std::optional<std::vector<int64_t>>
collectWriteEffectOperandIndices(mlir::Operation *op) {
  auto iface = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  if (!iface) {
    return std::nullopt;
  }

  llvm::SmallVector<mlir::MemoryEffects::EffectInstance> effects;
  iface.getEffects(effects);

  std::vector<int64_t> indices;
  for (const auto &eff : effects) {
    if (!mlir::isa<mlir::MemoryEffects::Write>(eff.getEffect())) {
      continue;
    }
    if (mlir::OpOperand *operand = eff.getEffectValue<mlir::OpOperand *>()) {
      indices.push_back(static_cast<int64_t>(operand->getOperandNumber()));
    }
  }
  return indices;
}

void populateUtilModule(nb::module_ &m) {
Comment on lines +592 to +597
ssas = {r.ssa for r in for_mode}
assert len(ssas) >= 2, (
    f"expected numerics records for the SSA result plus at least one "
    f"in-place operand on ttnn.batch_norm_training in {mode} mode, "
    f"got ssas={ssas}"
)

@azecevicTT azecevicTT left a comment


python/Util.cpp LGTM, one comment inline.

Comment thread python/Util.cpp

namespace mlir::ttmlir::python {

static std::optional<std::vector<int64_t>>

The return type is a bit inconsistent. The function returns nullopt if the op doesn't implement MemoryEffectOpInterface, but returns an empty vector if it does and no operand carries a MemoryEffects::Write. I'm not sure it makes sense to discriminate between these two cases.
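
For illustration, the two cases the comment refers to look like this from the Python side. This is a sketch under assumptions: classify_write_effects and collapse are hypothetical helpers, and the argument stands in for the value the binding returns (None or a list of operand indices).

```python
from typing import Optional


def classify_write_effects(indices: Optional[list]) -> str:
    """Name the three states a caller can observe from the binding's result."""
    if indices is None:
        # The op does not implement the memory-effect interface at all.
        return "unknown"
    if not indices:
        # The interface is implemented, but no write effect is tied to an operand.
        return "no-operand-writes"
    return "in-place"


def collapse(indices: Optional[list]) -> list:
    """Treat None and [] the same, as the comment suggests may be enough."""
    return indices or []


print(classify_write_effects(None), classify_write_effects([]), classify_write_effects([0]))
```

If no caller ever branches on "unknown" versus "no-operand-writes", the collapsed form is sufficient and the binding could return a plain list.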

@codecov-commenter

Codecov Report

❌ Patch coverage is 9.52381% with 19 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.95%. Comparing base (3bae50d) to head (2ba2b74).
⚠️ Report is 13 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines | Patch % | Lines
--- | --- | ---
python/Util.cpp | 9.52% | 19 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #8639      +/-   ##
==========================================
+ Coverage   70.85%   70.95%   +0.09%     
==========================================
  Files         497      503       +6     
  Lines       93246    93594     +348     
==========================================
+ Hits        66069    66405     +336     
- Misses      27177    27189      +12     



@mmilosevicTT mmilosevicTT left a comment


LGTM, apart from changes that other guys requested

Comment on lines +209 to +216
+ modes: list[NumericsMode] = []
  if ctx.checks_config.isolation:
-     iso_outs = execute_golden_with_ssa_inputs(op, ctx.stashed_inputs, asm_state)
- else:
-     iso_outs = [None] * len(mlir_op_outputs)
-
+     modes.append(NumericsMode.ISOLATED)
  if ctx.checks_config.accumulation:
-     accum_outs = execute_golden_from_pool(op, ctx.golden_tensor_pool, asm_state)
- else:
-     accum_outs = [None] * len(mlir_op_outputs)
+     modes.append(NumericsMode.ACCUMULATED)

- for mlir_output, output_ref, iso_out, accum_out in zip(
-     mlir_op_outputs, ctx.output_refs, iso_outs, accum_outs, strict=True
- ):
-     device_tensor = _validate_and_retrieve_tensor(ctx, mlir_output, output_ref)
-     ssa = mlir_output.get_name(asm_state)
+ if len(modes) == 0:
+     return

Suggested change
- modes: list[NumericsMode] = []
- if ctx.checks_config.isolation:
-     modes.append(NumericsMode.ISOLATED)
- if ctx.checks_config.accumulation:
-     modes.append(NumericsMode.ACCUMULATED)
-
- if len(modes) == 0:
-     return
+ if not (ctx.checks_config.isolation or ctx.checks_config.accumulation):
+     return
+ modes: list[NumericsMode] = []
+ if ctx.checks_config.isolation:
+     modes.append(NumericsMode.ISOLATED)
+ if ctx.checks_config.accumulation:
+     modes.append(NumericsMode.ACCUMULATED)

Alternatively, I would even fail outright if the chisel context is run without any mode set.
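
A minimal sketch of that fail-fast alternative, assuming a config object with isolation and accumulation flags. ChecksConfig, enabled_modes, and the strict parameter are hypothetical names for illustration, and NumericsMode is redefined here only so the example runs on its own.

```python
from dataclasses import dataclass
from enum import Enum, auto


class NumericsMode(Enum):
    ISOLATED = auto()
    ACCUMULATED = auto()


@dataclass
class ChecksConfig:
    """Hypothetical stand-in for ctx.checks_config."""

    isolation: bool = False
    accumulation: bool = False


def enabled_modes(cfg: ChecksConfig, strict: bool = False) -> list:
    """Collect the enabled modes; with strict=True an empty selection is an error."""
    modes = []
    if cfg.isolation:
        modes.append(NumericsMode.ISOLATED)
    if cfg.accumulation:
        modes.append(NumericsMode.ACCUMULATED)
    if strict and not modes:
        raise ValueError("chisel context has no numerics mode enabled")
    return modes


print(enabled_modes(ChecksConfig(isolation=True, accumulation=True)))
```

With strict left at its default, the caller gets an empty list and can return early, which matches the behavior in the suggested change above. Raising instead surfaces a misconfigured context at the first op rather than silently producing no records.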

Comment on lines +249 to +252
if mode is NumericsMode.ISOLATED:
    continue

ctx.golden_tensor_pool[ssa] = golden_out

Suggested change
- if mode is NumericsMode.ISOLATED:
-     continue
- ctx.golden_tensor_pool[ssa] = golden_out
+ if mode is NumericsMode.ACCUMULATED:
+     ctx.golden_tensor_pool[ssa] = golden_out
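
With only two modes, the two forms behave the same. The positive guard differs in that any statement added later in the loop body still runs for the isolated mode, whereas `continue` skips it. A small sketch, where refresh_pool and its arguments are hypothetical and NumericsMode is redefined so the example is self-contained:

```python
from enum import Enum, auto


class NumericsMode(Enum):
    ISOLATED = auto()
    ACCUMULATED = auto()


def refresh_pool(pool: dict, goldens_by_mode: dict, ssa: str) -> dict:
    """Store the golden for `ssa` only when it came from the accumulated mode."""
    for mode, golden_out in goldens_by_mode.items():
        # Positive guard: isolated-mode goldens never enter the pool, and code
        # placed after this `if` would still execute for both modes.
        if mode is NumericsMode.ACCUMULATED:
            pool[ssa] = golden_out
    return pool


goldens = {NumericsMode.ISOLATED: "iso", NumericsMode.ACCUMULATED: "acc"}
print(refresh_pool({}, goldens, "%0"))
```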

Development

Successfully merging this pull request may close these issues.

[python] Derive in-place / MemWrite operand list from ODS
