Skip to content
53 changes: 50 additions & 3 deletions docs/api/ir_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,60 @@

## Use built-in passes

Common, reusable passes are implemented in `onnx_ir.passes.common`. You can use {py:class}`onnx_ir.passes.Sequential <onnx_ir.passes.Sequential>` to chain passes or use {py:class}`onnx_ir.passes.PassManager <onnx_ir.passes.PassManager>` which supports early stopping if no changes are made.
Common, reusable passes are implemented in {py:module}`onnx_ir.passes.common`. You can use {py:class}`onnx_ir.passes.Sequential <onnx_ir.passes.Sequential>` to chain passes or use {py:class}`onnx_ir.passes.PassManager <onnx_ir.passes.PassManager>` which supports early stopping if no changes are made.

### Example

```py
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

model = ir.load("model.onnx")

# You can chain passes with ir.passes.Sequential
passes = ir.passes.Sequential(
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
)
result = passes(model)

# Or you can run passes individually. Passing in result or result.model has the same effect
result = common_passes.ClearMetadataAndDocStringPass()(result)

print("The model was modified:", result.modified)
ir.save(result.model, "model.onnx")
```

For more advanced use cases, you can use {py:class}`onnx_ir.passes.PassManager <onnx_ir.passes.PassManager>` to orchestrate passes with automatic iteration until convergence:

```py
model = ir.load("model.onnx")

passes = ir.passes.PassManager(
[
# Pass managers can be nested
ir.passes.PassManager(
[
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
]
steps=2,
early_stop=True,
),
common_passes.ClearMetadataAndDocStringPass(),
],
steps=2,
early_stop=False,
)

result = passes(model)
```

## Pass infrastructure

Inherent {py:class}`onnx_ir.passes.InPlacePass <onnx_ir.passes.InPlacePass>` or {py:class}`onnx_ir.passes.FunctionalPass <onnx_ir.passes.FunctionalPass>` to define a pass. You will need to implement the `call` method which returns a {py:class}`onnx_ir.passes.PassResult <onnx_ir.passes.PassResult>`.
Inherit from {py:class}`onnx_ir.passes.InPlacePass <onnx_ir.passes.InPlacePass>` or {py:class}`onnx_ir.passes.FunctionalPass <onnx_ir.passes.FunctionalPass>` to define a pass. You will need to implement the `call` method which returns a {py:class}`onnx_ir.passes.PassResult <onnx_ir.passes.PassResult>`.

Alternatively, inherent the base class `onnx_ir.passes.PassBase <onnx_ir.passes.PassBase>` and override the two properties `changes_input` and `in_place` to set properties of the pass.
Alternatively, inherit from the base class {py:class}`onnx_ir.passes.PassBase <onnx_ir.passes.PassBase>` and override the two properties `changes_input` and `in_place` to set properties of the pass.

```{eval-rst}
.. autosummary::
Expand Down
14 changes: 13 additions & 1 deletion src/onnx_ir/passes/_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,19 @@ def changes_input(self) -> Literal[False]:


class Sequential(PassBase):
"""Run a sequence of passes in order."""
"""Run a sequence of passes in order.

Example::
import onnx_ir as ir
import onnx_ir.passes.common as common_passes

passes = ir.passes.Sequential(
common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
common_passes.CommonSubexpressionEliminationPass(),
common_passes.ClearMetadataAndDocStringPass(),
)
result = passes(model)
"""

def __init__(self, *passes: PassBase):
if not passes:
Expand Down
Loading