Skip to content
53 changes: 50 additions & 3 deletions docs/api/ir_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,60 @@

## Use built-in passes

Common, reusable passes are implemented in `onnx_ir.passes.common`. You can use {py:class}`onnx_ir.passes.Sequential <onnx_ir.passes.Sequential>` to chain passes or use {py:class}`onnx_ir.passes.PassManager <onnx_ir.passes.PassManager>` which supports early stopping if no changes are made.
Common, reusable passes are implemented in {py:module}`onnx_ir.passes.common`. You can use {py:class}`onnx_ir.passes.Sequential <onnx_ir.passes.Sequential>` to chain passes or use {py:class}`onnx_ir.passes.PassManager <onnx_ir.passes.PassManager>` which supports early stopping if no changes are made.

### Example

```py
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

model = ir.load("model.onnx")

# You can chain passes with ir.passes.Sequential
passes = ir.passes.Sequential(
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
)
result = passes(model)

# Or you can run passes individually. Passing in result or result.model has the same effect
result = common_passes.ClearMetadataAndDocStringPass()(result)

print("The model was modified:", result.modified)
ir.save(result.model, "model.onnx")
```

For more advanced use cases, you can use {py:class}`onnx_ir.passes.PassManager <onnx_ir.passes.PassManager>` to orchestrate passes with automatic iteration until convergence:

```py
model = ir.load("model.onnx")

passes = ir.passes.PassManager(
[
# Pass managers can be nested
ir.passes.PassManager(
[
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
],
steps=2,
early_stop=True,
),
common_passes.ClearMetadataAndDocStringPass(),
],
steps=2,
early_stop=False,
)

result = passes(model)
```

## Pass infrastructure

Inherent {py:class}`onnx_ir.passes.InPlacePass <onnx_ir.passes.InPlacePass>` or {py:class}`onnx_ir.passes.FunctionalPass <onnx_ir.passes.FunctionalPass>` to define a pass. You will need to implement the `call` method which returns a {py:class}`onnx_ir.passes.PassResult <onnx_ir.passes.PassResult>`.
Inherit from {py:class}`onnx_ir.passes.InPlacePass <onnx_ir.passes.InPlacePass>` or {py:class}`onnx_ir.passes.FunctionalPass <onnx_ir.passes.FunctionalPass>` to define a pass. You will need to implement the `call` method which returns a {py:class}`onnx_ir.passes.PassResult <onnx_ir.passes.PassResult>`.

Alternatively, inherent the base class `onnx_ir.passes.PassBase <onnx_ir.passes.PassBase>` and override the two properties `changes_input` and `in_place` to set properties of the pass.
Alternatively, inherit from the base class {py:class}`onnx_ir.passes.PassBase <onnx_ir.passes.PassBase>` and override the two properties `changes_input` and `in_place` to set properties of the pass.

```{eval-rst}
.. autosummary::
Expand Down
27 changes: 12 additions & 15 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,19 +2207,9 @@ def Neg(self, operand) -> Value: ... # noqa: N802
def set_value_magic_handler(handler: _OpHandlerProtocol | None) -> _OpHandlerProtocol | None:
"""Set the magic handler for Value arithmetic methods.

This context manager sets the magic handler for Value arithmetic methods
within the context. After exiting the context, the magic handler is reset
to None.

Framework authors can implement custom context managers that set
the magic handler to enable arithmetic operations on Values.

Args:
handler: The magic handler to set.

Returns:
The previous magic handler.

Example::
class MyOpHandler:
def Add(self, lhs, rhs):
Expand All @@ -2234,6 +2224,12 @@ def graph_context(graph):
yield
finally:
onnx_ir.set_value_magic_handler(old_handler)

Args:
handler: The magic handler to set.

Returns:
The previous magic handler.
"""
old_handler = WithArithmeticMethods._magic_handler
WithArithmeticMethods._magic_handler = handler
Expand Down Expand Up @@ -2315,7 +2311,7 @@ class Value(WithArithmeticMethods, _protocols.ValueProtocol, _display.PrettyPrin
For consistency, none of the other comparison operators are included.

.. versionadded:: 0.1.14
Value now supports arithmetic magic methods within the context manager
Value now supports arithmetic magic methods when a handler is set via
:func:`onnx_ir.set_value_magic_handler`.
"""

Expand Down Expand Up @@ -2710,10 +2706,11 @@ def merge_shapes(self, other: Shape | None, /) -> None:
"""Merge the shape of this value with another shape to update the existing shape, with the current shape's dimensions taking precedence.

Two dimensions are merged as follows:
- If both dimensions are equal, the merged dimension is the same.
- If one dimension is SymbolicDim and the other is concrete, the merged dimension is the concrete one.
- If both dimensions are SymbolicDim, a named symbolic dimension (non-None value) is preferred over an unnamed one (None value).
- In all other cases where the dimensions differ, the current shape's dimension is taken (a warning is emitted when both are concrete integers).

* If both dimensions are equal, the merged dimension is the same.
* If one dimension is SymbolicDim and the other is concrete, the merged dimension is the concrete one.
* If both dimensions are SymbolicDim, a named symbolic dimension (non-None value) is preferred over an unnamed one (None value).
* In all other cases where the dimensions differ, the current shape's dimension is taken (a warning is emitted when both are concrete integers).

.. versionadded:: 0.1.14

Expand Down
39 changes: 38 additions & 1 deletion src/onnx_ir/passes/_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,19 @@ def changes_input(self) -> Literal[False]:


class Sequential(PassBase):
"""Run a sequence of passes in order."""
"""Run a sequence of passes in order.

Example::
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

passes = ir.passes.Sequential(
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
common_passes.ClearMetadataAndDocStringPass(),
)
result = passes(model)
"""

def __init__(self, *passes: PassBase):
if not passes:
Expand Down Expand Up @@ -254,6 +266,31 @@ class PassManager(Sequential):

The PassManager is a Pass that runs a sequence of passes on a model.

Example::
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

model = ir.load("model.onnx")
passes = ir.passes.PassManager(
[
# Pass managers can be nested
ir.passes.PassManager(
[
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
],
steps=2,
early_stop=True,
),
common_passes.ClearMetadataAndDocStringPass(),
],
steps=2,
early_stop=False,
)

# Apply the passes to the model
result = passes(model)

Attributes:
passes: The passes to run.
steps: The number of times to run the passes.
Expand Down
Loading