Skip to content

Torch Function modes x torch.compile tutorial #3320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions recipes_source/recipes_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
:link: ../recipes/amx.html
:tags: Model-Optimization

.. (beta) Utilizing Torch Function modes with torch.compile

.. customcarditem::
:header: (beta) Utilizing Torch Function modes with torch.compile
:card_description: Override torch operators with Torch Function modes and torch.compile
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../recipes/torch_compile_torch_function_modes.html
:tags: Model-Optimization

.. (beta) Compiling the Optimizer with torch.compile

.. customcarditem::
Expand Down
77 changes: 77 additions & 0 deletions recipes_source/torch_compile_torch_function_modes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
(beta) Utilizing Torch Function modes with torch.compile
============================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# This recipe covers how to use a key torch extensibility point,
# torch function modes, in tandem with ``torch.compile`` to override
# the behavior of torch operators, also know as **ops**, at trace time, with no runtime overhead.
#
# .. note::
#
# This recipe requires PyTorch 2.7.0 or later.


#####################################################################
# Rewriting a torch op (torch.add -> torch.mul)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Rewriting a torch op (torch.add -> torch.mul)
# Rewriting a torch op (``torch.add`` -> ``torch.mul``)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use torch function modes to rewrite occurences
# of addition with multiply instead. This type of override can be common
# if a certain backend has a custom implementation that should be dispatched
# for a given op.
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode, Note: ``BaseTorchFunctionMode``
# implements the actual invocation of func(..)
class AddToMultiplyMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if func == torch.Tensor.add:
func = torch.mul

return super().__torch_function__(func, types, args, kwargs)

@torch.compile()
def test_fn(x, y):
return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

# The mode can also be used within the compiled region as well like this:

@torch.compile()
def test_fn(x, y):
with AddToMultiplyMode():
return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

######################################################################
# Conclusion
# ~~~~~~~~~~
# In this recipe we demonstrated how to override the behavior of ``torch.*`` operators
# using torch function modes from within ``torch.compile``. This enables users to utilize
# the extensibility benefits of torch function modes without the runtime overhead
# of calling torch function on every op invocation.
#
# * See `Extending Torch API with Modes <https://pytorch.org/docs/stable/notes/extending.html#extending-all-torch-api-with-modes>`__ for other examples and background on Torch Function modes.