diff --git a/docs/api/ir_passes.md b/docs/api/ir_passes.md index b88fe2bc..e6cf6423 100644 --- a/docs/api/ir_passes.md +++ b/docs/api/ir_passes.md @@ -6,13 +6,60 @@ ## Use built-in passes -Common, reusable passes are implemented in `onnx_ir.passes.common`. You can use {py:class}`onnx_ir.passes.Sequential ` to chain passes or use {py:class}`onnx_ir.passes.PassManager ` which supports early stopping if no changes are made. +Common, reusable passes are implemented in {py:module}`onnx_ir.passes.common`. You can use {py:class}`onnx_ir.passes.Sequential ` to chain passes or use {py:class}`onnx_ir.passes.PassManager ` which supports early stopping if no changes are made. + +### Example + +```py +import onnx_ir as ir +import onnx_ir.passes.common as common_passes + +model = ir.load("model.onnx") + +# You can chain passes with ir.passes.Sequential +passes = ir.passes.Sequential( + common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024), + common_passes.CommonSubexpressionEliminationPass(), +) +result = passes(model) + +# Or you can run passes individually. Passing in result or result.model has the same effect +result = common_passes.ClearMetadataAndDocStringPass()(result) + +print("The model was modified:", result.modified) +ir.save(result.model, "model.onnx") +``` + +For more advanced use cases, you can use {py:class}`onnx_ir.passes.PassManager ` to orchestrate passes with automatic iteration until convergence: + +```py +model = ir.load("model.onnx") + +passes = ir.passes.PassManager( + [ + # Pass managers can be nested + ir.passes.PassManager( + [ + common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024), + common_passes.CommonSubexpressionEliminationPass(), + ], + steps=2, + early_stop=True, + ), + common_passes.ClearMetadataAndDocStringPass(), + ], + steps=2, + early_stop=False, +) + +result = passes(model) +``` ## Pass infrastructure -Inherent {py:class}`onnx_ir.passes.InPlacePass ` or {py:class}`onnx_ir.passes.FunctionalPass ` to define a pass. You will need to implement the `call` method which returns a {py:class}`onnx_ir.passes.PassResult `. +Inherit from {py:class}`onnx_ir.passes.InPlacePass ` or {py:class}`onnx_ir.passes.FunctionalPass ` to define a pass. You will need to implement the `call` method which returns a {py:class}`onnx_ir.passes.PassResult `. -Alternatively, inherent the base class `onnx_ir.passes.PassBase ` and override the two properties `changes_input` and `in_place` to set properties of the pass. +Alternatively, inherit from the base class {py:class}`onnx_ir.passes.PassBase ` and override the two properties `changes_input` and `in_place` to set properties of the pass. ```{eval-rst} .. autosummary:: diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 9b323e6b..243e8996 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -2207,19 +2207,9 @@ def Neg(self, operand) -> Value: ... # noqa: N802 def set_value_magic_handler(handler: _OpHandlerProtocol | None) -> _OpHandlerProtocol | None: """Set the magic handler for Value arithmetic methods. - This context manager sets the magic handler for Value arithmetic methods - within the context. After exiting the context, the magic handler is reset - to None. - Framework authors can implement custom context managers that set the magic handler to enable arithmetic operations on Values. - Args: - handler: The magic handler to set. - - Returns: - The previous magic handler. - Example:: class MyOpHandler: def Add(self, lhs, rhs): @@ -2234,6 +2224,12 @@ def graph_context(graph): yield finally: onnx_ir.set_value_magic_handler(old_handler) + + Args: + handler: The magic handler to set. + + Returns: + The previous magic handler. """ old_handler = WithArithmeticMethods._magic_handler WithArithmeticMethods._magic_handler = handler @@ -2315,7 +2311,7 @@ class Value(WithArithmeticMethods, _protocols.ValueProtocol, _display.PrettyPrin For consistency, none of the other comparison operators are included. .. versionadded:: 0.1.14 - Value now supports arithmetic magic methods within the context manager + Value now supports arithmetic magic methods when a handler is set via :func:`onnx_ir.set_value_magic_handler`. """ @@ -2710,10 +2706,11 @@ def merge_shapes(self, other: Shape | None, /) -> None: """Merge the shape of this value with another shape to update the existing shape, with the current shape's dimensions taking precedence. Two dimensions are merged as follows: - - If both dimensions are equal, the merged dimension is the same. - - If one dimension is SymbolicDim and the other is concrete, the merged dimension is the concrete one. - - If both dimensions are SymbolicDim, a named symbolic dimension (non-None value) is preferred over an unnamed one (None value). - - In all other cases where the dimensions differ, the current shape's dimension is taken (a warning is emitted when both are concrete integers). + + * If both dimensions are equal, the merged dimension is the same. + * If one dimension is SymbolicDim and the other is concrete, the merged dimension is the concrete one. + * If both dimensions are SymbolicDim, a named symbolic dimension (non-None value) is preferred over an unnamed one (None value). + * In all other cases where the dimensions differ, the current shape's dimension is taken (a warning is emitted when both are concrete integers). .. versionadded:: 0.1.14 diff --git a/src/onnx_ir/passes/_pass_infra.py b/src/onnx_ir/passes/_pass_infra.py index ca1dd51a..60daa438 100644 --- a/src/onnx_ir/passes/_pass_infra.py +++ b/src/onnx_ir/passes/_pass_infra.py @@ -210,7 +210,19 @@ def changes_input(self) -> Literal[False]: class Sequential(PassBase): - """Run a sequence of passes in order.""" + """Run a sequence of passes in order. + + Example:: + import onnx_ir as ir + import onnx_ir.passes.common as common_passes + + passes = ir.passes.Sequential( + common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024), + common_passes.CommonSubexpressionEliminationPass(), + common_passes.ClearMetadataAndDocStringPass(), + ) + result = passes(model) + """ def __init__(self, *passes: PassBase): if not passes: @@ -254,6 +266,31 @@ class PassManager(Sequential): The PassManager is a Pass that runs a sequence of passes on a model. + Example:: + import onnx_ir as ir + import onnx_ir.passes.common as common_passes + + model = ir.load("model.onnx") + passes = ir.passes.PassManager( + [ + # Pass managers can be nested + ir.passes.PassManager( + [ + common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024), + common_passes.CommonSubexpressionEliminationPass(), + ], + steps=2, + early_stop=True, + ), + common_passes.ClearMetadataAndDocStringPass(), + ], + steps=2, + early_stop=False, + ) + + # Apply the passes to the model + result = passes(model) + Attributes: passes: The passes to run. steps: The number of times to run the passes.