chisel: support in-place ops via MemoryEffectOpInterface#8639
chisel: support in-place ops via MemoryEffectOpInterface#8639ndrakulicTT wants to merge 10 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR updates chisel’s notion of in-place mutation by deriving mutated operands from MLIR’s MemoryEffectOpInterface (instead of a hand-maintained Python table), and extends the post-op numerics validation flow to validate both SSA results and mutated operands uniformly (including pool refresh/eviction behavior).
Changes:
- Add a new Python binding (
ttmlir.util.get_write_effect_operand_indices) to query operand indices withMemoryEffects::Write. - Replace the static in-place-op table in chisel with an interface-driven
get_inplace_vals(op)and unify post-op numerics validation for SSA + in-place operands. - Add/adjust goldens and integration tests for in-place validation paths (e.g., cache update ops, batch_norm_training).
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
python/Util.cpp |
Adds C++/nanobind helper to expose write-effect operand indices to Python. |
tools/chisel/chisel/ops.py |
Replaces static in-place operand table with MemoryEffectOpInterface-driven detection. |
tools/chisel/chisel/callbacks.py |
Unifies post-op validation loop for SSA outputs + in-place mutated operands; evicts stale pool entries for no-golden in-place ops. |
tools/chisel/chisel/executor.py |
Removes execute_golden_from_pool and updates golden execution to use new in-place operand detection. |
tools/chisel/chisel/op_configs.py |
Removes no-golden flags for ops that now have working in-place goldens/validation path. |
tools/golden/mapping.py |
Adds chisel golden for ttnn.batch_norm_training; fixes paged cache goldens to use operand dtype for in-place ops with no SSA results. |
test/python/chisel/test_golden_execution.py |
Updates test to use new get_inplace_vals API and adjusts messages accordingly. |
test/python/chisel/test_builder_chisel_integration.py |
Adds end-to-end in-place numerics record tests for update_cache and batch_norm_training. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| #include "mlir/Interfaces/SideEffectInterfaces.h" | ||
| #include "ttmlir/Bindings/Python/TTMLIRModule.h" | ||
| #include <nanobind/stl/vector.h> | ||
| #include <optional> | ||
| #include <variant> | ||
|
|
||
| namespace mlir::ttmlir::python { | ||
|
|
||
| static std::optional<std::vector<int64_t>> | ||
| collectWriteEffectOperandIndices(mlir::Operation *op) { | ||
| auto iface = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(op); | ||
| if (!iface) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| llvm::SmallVector<mlir::MemoryEffects::EffectInstance> effects; | ||
| iface.getEffects(effects); | ||
|
|
||
| std::vector<int64_t> indices; | ||
| for (const auto &eff : effects) { | ||
| if (!mlir::isa<mlir::MemoryEffects::Write>(eff.getEffect())) { | ||
| continue; | ||
| } | ||
| if (mlir::OpOperand *operand = eff.getEffectValue<mlir::OpOperand *>()) { | ||
| indices.push_back(static_cast<int64_t>(operand->getOperandNumber())); | ||
| } | ||
| } | ||
| return indices; | ||
| } | ||
|
|
||
| void populateUtilModule(nb::module_ &m) { |
| ssas = {r.ssa for r in for_mode} | ||
| assert len(ssas) >= 2, ( | ||
| f"expected numerics records for the SSA result plus at least one " | ||
| f"in-place operand on ttnn.batch_norm_training in {mode} mode, " | ||
| f"got ssas={ssas}" | ||
| ) |
azecevicTT
left a comment
There was a problem hiding this comment.
python/Util.cpp LGTM, one comment inline.
|
|
||
| namespace mlir::ttmlir::python { | ||
|
|
||
| static std::optional<std::vector<int64_t>> |
There was a problem hiding this comment.
Return type is a bit inconsistent. Function returns nullopt if op doesn't implement MemoryEffectOpInterface, but returns an empty vector if it does but there isn't any operand with a MemoryEffects::Write. I'm not sure if it makes sense to discriminate between these two cases.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #8639 +/- ##
==========================================
+ Coverage 70.85% 70.95% +0.09%
==========================================
Files 497 503 +6
Lines 93246 93594 +348
==========================================
+ Hits 66069 66405 +336
- Misses 27177 27189 +12 ☔ View full report in Codecov by Sentry. |
mmilosevicTT
left a comment
There was a problem hiding this comment.
LGTM, apart from changes that other guys requested
| modes: list[NumericsMode] = [] | ||
| if ctx.checks_config.isolation: | ||
| iso_outs = execute_golden_with_ssa_inputs(op, ctx.stashed_inputs, asm_state) | ||
| else: | ||
| iso_outs = [None] * len(mlir_op_outputs) | ||
|
|
||
| modes.append(NumericsMode.ISOLATED) | ||
| if ctx.checks_config.accumulation: | ||
| accum_outs = execute_golden_from_pool(op, ctx.golden_tensor_pool, asm_state) | ||
| else: | ||
| accum_outs = [None] * len(mlir_op_outputs) | ||
| modes.append(NumericsMode.ACCUMULATED) | ||
|
|
||
| for mlir_output, output_ref, iso_out, accum_out in zip( | ||
| mlir_op_outputs, ctx.output_refs, iso_outs, accum_outs, strict=True | ||
| ): | ||
| device_tensor = _validate_and_retrieve_tensor(ctx, mlir_output, output_ref) | ||
| ssa = mlir_output.get_name(asm_state) | ||
| if len(modes) == 0: | ||
| return |
There was a problem hiding this comment.
| modes: list[NumericsMode] = [] | |
| if ctx.checks_config.isolation: | |
| iso_outs = execute_golden_with_ssa_inputs(op, ctx.stashed_inputs, asm_state) | |
| else: | |
| iso_outs = [None] * len(mlir_op_outputs) | |
| modes.append(NumericsMode.ISOLATED) | |
| if ctx.checks_config.accumulation: | |
| accum_outs = execute_golden_from_pool(op, ctx.golden_tensor_pool, asm_state) | |
| else: | |
| accum_outs = [None] * len(mlir_op_outputs) | |
| modes.append(NumericsMode.ACCUMULATED) | |
| for mlir_output, output_ref, iso_out, accum_out in zip( | |
| mlir_op_outputs, ctx.output_refs, iso_outs, accum_outs, strict=True | |
| ): | |
| device_tensor = _validate_and_retrieve_tensor(ctx, mlir_output, output_ref) | |
| ssa = mlir_output.get_name(asm_state) | |
| if len(modes) == 0: | |
| return | |
| if not (ctx.checks_config.isolation or ctx.checks_config.accumulated): | |
| return | |
| modes: list[NumericsMode] = [] | |
| if ctx.checks_config.isolation: | |
| modes.append(NumericsMode.ISOLATED) | |
| if ctx.checks_config.accumulation: | |
| modes.append(NumericsMode.ACCUMULATED) |
Or I would even fail if chisel context is run without any mode set.
| if mode is NumericsMode.ISOLATED: | ||
| continue | ||
|
|
||
| ctx.golden_tensor_pool[ssa] = golden_out |
There was a problem hiding this comment.
| if mode is NumericsMode.ISOLATED: | |
| continue | |
| ctx.golden_tensor_pool[ssa] = golden_out | |
| if mode is NumericsMode.ACCUMULATED: | |
| ctx.golden_tensor_pool[ssa] = golden_out |
Summary
Closes #8385
Replaces the hand-maintained
_CHISEL_INPLACE_OPStable with an IR-derived view of in-place mutation, and teaches chisel's per-op flow to validate mutated operands alongside SSA results.What changes
python/Util.cpp): newget_write_effect_operand_indicesthat returns flat operand indices on which an op declaresMemoryEffects::Write, orNonewhen the op doesn't implementMemoryEffectOpInterface.chisel/ops.py:get_inplace_vals(op)now derives mutated tensor operands from the binding above instead of the static dict.The
_CHISEL_INPLACE_OPStable andget_inplace_operandshelper are removed.chisel/callbacks.py—_default_post_op: SSA results and in-place mutated operands are validated in a single loop.Each enabled numerics mode (isolated / accumulated) produces a numerics record per result and per mutated operand.
The pool is refreshed for in-place operands when a golden is registered.
chisel/callbacks.py—_evict_inplace_no_golden: for no-golden ops that declare MemWrite, the pooled golden for each mutated operand is dropped (with agolden_evictedrecord), preventing false PCC failures on downstream ops.chisel/executor.py: dropsexecute_golden_from_pool— the unified post-op loop incallbacks.pyhandles pool refresh inline.chisel/op_configs.py:UpdateCacheOp,PagedUpdateCacheOp,FillCacheOp,PagedFillCacheOp, andBatchNormTrainingOpno longer need to be markedno_golden=True— they have working chisel goldens and run through the in-place validation path.test_chisel_records_update_cache_inplaceandtest_chisel_records_batch_norm_training_inplaceexercising the new path end-to-end (numerics records under both isolation and accumulation, no spurious golden_evicted records).Follow-ups
Two follow-up PRs are planned on top of this one:
Goal: stop redundantly pulling the same SSA to host across consumer ops and across iso/accum modes.
A follow-up will factor the common scaffolding into shared helpers / parametrize across ops.