-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Torch Function modes x torch.compile tutorial #3320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
74a5076
Initial commit
mlazos c51e070
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 52be71f
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 7014f3a
Update recipes_source/torch_compile_torch_function_modes.py
mlazos ede8d60
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 2e7943a
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 3d3a691
Update recipes_source/torch_compile_torch_function_modes.py
mlazos f9ab2eb
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 2a64b21
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 50ab48e
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 3c2efb6
Update recipes_source/torch_compile_torch_function_modes.py
mlazos 1c6eb64
Merge branch 'main' into mlazos/tf-modes-tutorial
mlazos 56dee0c
Fix metadata
mlazos 50eda9b
Merge branch 'main' into mlazos/tf-modes-tutorial
AlannaBurke 6fe3b64
Apply suggestions from code review
svekars 7c0928b
Merge branch 'main' into mlazos/tf-modes-tutorial
svekars 4e854d5
Merge branch 'main' into mlazos/tf-modes-tutorial
svekars f737c55
Merge branch 'main' into mlazos/tf-modes-tutorial
svekars File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
""" | ||
(beta) Utilizing Torch Function modes with torch.compile | ||
============================================================ | ||
|
||
**Author:** `Michael Lazos <https://github.com/mlazos>`_ | ||
""" | ||
|
||
######################################################### | ||
# This recipe covers how to use a key torch extensibility point, | ||
# torch function modes, in tandem with ``torch.compile`` to override | ||
# the behavior of torch operators, also know as **ops**, at trace time, with no runtime overhead. | ||
# | ||
# .. note:: | ||
# | ||
# This recipe requires PyTorch 2.7.0 or later. | ||
|
||
|
||
##################################################################### | ||
# Rewriting a torch op (torch.add -> torch.mul) | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# For this example, we'll use torch function modes to rewrite occurences | ||
# of addition with multiply instead. This type of override can be common | ||
# if a certain backend has a custom implementation that should be dispatched | ||
# for a given op. | ||
import torch | ||
|
||
# exit cleanly if we are on a device that doesn't support ``torch.compile`` | ||
if torch.cuda.get_device_capability() < (7, 0): | ||
print("Exiting because torch.compile is not supported on this device.") | ||
import sys | ||
sys.exit(0) | ||
|
||
from torch.overrides import BaseTorchFunctionMode | ||
|
||
# Define our mode, Note: ``BaseTorchFunctionMode`` | ||
# implements the actual invocation of func(..) | ||
class AddToMultiplyMode(BaseTorchFunctionMode): | ||
def __torch_function__(self, func, types, args=(), kwargs=None): | ||
if func == torch.Tensor.add: | ||
func = torch.mul | ||
|
||
return super().__torch_function__(func, types, args, kwargs) | ||
|
||
@torch.compile() | ||
def test_fn(x, y): | ||
return x + y * x # Note: infix operators map to torch.Tensor.* methods | ||
|
||
x = torch.rand(2, 2) | ||
y = torch.rand_like(x) | ||
|
||
with AddToMultiplyMode(): | ||
z = test_fn(x, y) | ||
|
||
assert torch.allclose(z, x * y * x) | ||
|
||
# The mode can also be used within the compiled region as well like this: | ||
|
||
@torch.compile() | ||
def test_fn(x, y): | ||
with AddToMultiplyMode(): | ||
return x + y * x # Note: infix operators map to torch.Tensor.* methods | ||
|
||
x = torch.rand(2, 2) | ||
y = torch.rand_like(x) | ||
z = test_fn(x, y) | ||
|
||
assert torch.allclose(z, x * y * x) | ||
|
||
###################################################################### | ||
# Conclusion | ||
# ~~~~~~~~~~ | ||
# In this recipe we demonstrated how to override the behavior of ``torch.*`` operators | ||
# using torch function modes from within ``torch.compile``. This enables users to utilize | ||
# the extensibility benefits of torch function modes without the runtime overhead | ||
# of calling torch function on every op invocation. | ||
# | ||
# * See `Extending Torch API with Modes <https://pytorch.org/docs/stable/notes/extending.html#extending-all-torch-api-with-modes>`__ for other examples and background on Torch Function modes. |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.