Commit 48a5f0a

agunapal authored
Add a recipe for showcasing torch.export flow for 4 models, with unique challenges and solutions (#3180)
* Added a recipe for showcasing torch.export flow for 4 models. --------- Co-authored-by: Svetlana Karslioglu <[email protected]> Co-authored-by: Angela Yi <[email protected]>
1 parent f7d06b6 commit 48a5f0a

File tree

2 files changed: +339 −0

recipes_source/recipes_index.rst

+8
@@ -157,6 +157,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/torch_export_aoti_python.html
    :tags: Basics

+.. customcarditem::
+   :header: Demonstration of torch.export flow, common challenges and the solutions to address them
+   :card_description: Learn how to export models for popular use cases
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_export_challenges_solutions.html
+   :tags: Compiler,TorchCompile
+
 .. Interpretability

 .. customcarditem::

@@ -472,3 +479,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/distributed_optim_torchscript
    /recipes/mobile_interpreter
    /recipes/distributed_comm_debug_mode
+   /recipes/torch_export_challenges_solutions

recipes_source/torch_export_challenges_solutions (new file)

+331

@@ -0,0 +1,331 @@

Demonstration of torch.export flow, common challenges and the solutions to address them
==========================================================================================

**Authors:** `Ankith Gunapal <https://github.com/agunapal>`__, `Jordi Ramon <https://github.com/JordiFB>`__, `Marcos Carranza <https://github.com/macarran>`__

In the `Introduction to torch.export Tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__, we learned how to use `torch.export <https://pytorch.org/docs/stable/export.html>`__.
This tutorial expands on the previous one and explores the process of exporting popular models with code, as well as addresses common challenges that may arise with ``torch.export``.

In this tutorial, you will learn how to export models for these use cases:

* Video classifier (`MViT <https://pytorch.org/vision/main/models/video_mvit.html>`__)
* Automatic Speech Recognition (`OpenAI Whisper-Tiny <https://huggingface.co/openai/whisper-tiny>`__)
* Image Captioning (`BLIP <https://github.com/salesforce/BLIP>`__)
* Promptable Image Segmentation (`SAM2 <https://ai.meta.com/sam2/>`__)

Each of the four models was chosen to demonstrate unique features of ``torch.export``, as well as some practical considerations
and issues faced in the implementation.

Prerequisites
-------------

* PyTorch 2.4 or later
* Basic understanding of ``torch.export`` and PyTorch Eager inference.


Key requirement for ``torch.export``: No graph break
------------------------------------------------------

`torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__ speeds up PyTorch code by JIT-compiling it into optimized kernels. It optimizes the given model
using ``TorchDynamo`` and creates an optimized graph, which is then lowered to the hardware using the backend specified in the API.
When TorchDynamo encounters unsupported Python features, it breaks the computation graph, lets the default Python interpreter
handle the unsupported code, and then resumes capturing the graph. This break in the computation graph is called a `graph break <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torchdynamo-and-fx-graphs>`__.

One of the key differences between ``torch.export`` and ``torch.compile`` is that ``torch.export`` doesn’t support graph breaks,
which means that the entire model or the part of the model that you are exporting needs to be a single graph. This is because handling graph breaks
involves interpreting the unsupported operation with default Python evaluation, which is incompatible with what ``torch.export`` is
designed for. You can read more about the differences between the various PyTorch frameworks in this `link <https://pytorch.org/docs/main/export.html#existing-frameworks>`__.

You can identify graph breaks in your program by using the following command:

.. code:: sh

    TORCH_LOGS="graph_breaks" python <file_name>.py
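
If you prefer a programmatic summary of where and why graph breaks happen, ``torch._dynamo.explain`` can report them. The snippet below is a minimal sketch on a toy function; the explicit ``torch._dynamo.graph_break()`` call is only there to produce a break to report:

.. code:: python

    import torch

    def toy_fn(x):
        x = x * 2
        torch._dynamo.graph_break()  # deliberately force a graph break for demonstration
        return x + 1

    # ``explain`` runs TorchDynamo on the function and summarizes the captured graphs,
    # the number of graph breaks, and the reason for each break.
    explanation = torch._dynamo.explain(toy_fn)(torch.randn(4))
    print(explanation)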

You will need to modify your program to get rid of graph breaks. Once resolved, you are ready to export the model.
PyTorch runs `nightly benchmarks <https://hud.pytorch.org/benchmark/compilers>`__ for ``torch.compile`` on popular HuggingFace and TIMM models.
Most of these models have no graph breaks.

The models in this recipe have no graph breaks, but fail with ``torch.export``.

Video Classification
--------------------

MViT is a class of models based on `MultiScale Vision Transformers <https://arxiv.org/abs/2104.11227>`__. This model has been trained for video classification using the `Kinetics-400 Dataset <https://arxiv.org/abs/1705.06950>`__.
This model, with a relevant dataset, can be used for action recognition in the context of gaming.


The code below exports MViT by tracing with ``batch_size=2`` and then checks if the ExportedProgram can run with ``batch_size=4``.

.. code:: python

    import numpy as np
    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb

    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)
    # Transpose to get [batch_size, 3, num_frames, height, width].
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

    # Export the model.
    exported_program = torch.export.export(
        model,
        (input_frames,),
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()

Error: Static batch size
~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: sh

    raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4


By default, the exporting flow will trace the program assuming that all input shapes are static, so if you run the program with
input shapes that are different from the ones you used while tracing, you will run into an error.

Solution
~~~~~~~~

To address the error, we specify the first dimension of the input (``batch_size``) to be dynamic and provide the expected range of ``batch_size``.
In the corrected example shown below, we specify that the expected ``batch_size`` can range from 1 to 16.
One detail to notice is that ``min=2`` is not a bug; the reason is explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__. A detailed description of dynamic shapes
for ``torch.export`` can be found in the export tutorial. The code shown below demonstrates how to export MViT with dynamic batch sizes:

.. code:: python

    import numpy as np
    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb


    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)

    # Transpose to get [batch_size, 3, num_frames, height, width].
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

    # Export the model.
    batch_dim = torch.export.Dim("batch", min=2, max=16)
    exported_program = torch.export.export(
        model,
        (input_frames,),
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()
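
With the dynamic dimension in place, the same ``ExportedProgram`` serves any batch size in the declared range without re-exporting. As a quick sanity check, the sketch below reuses ``exported_program`` from above and runs a few in-range batch sizes; a batch size outside ``[2, 16]`` would raise an error instead:

.. code:: python

    # Any batch size within the declared [2, 16] range should now work.
    for bs in (2, 8, 16):
        frames = np.transpose(torch.randn(bs, 16, 224, 224, 3), (0, 4, 1, 2, 3))
        out = exported_program.module()(frames)
        print(f"batch_size={bs} -> output shape {tuple(out.shape)}")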

Automatic Speech Recognition
----------------------------

**Automatic Speech Recognition** (ASR) is the use of machine learning to transcribe spoken language into text.
`Whisper <https://arxiv.org/abs/2212.04356>`__ is a Transformer-based encoder-decoder model from OpenAI, which was trained on 680k hours of labelled data for ASR and speech translation.
The code below tries to export the ``whisper-tiny`` model for ASR.


.. code:: python

    import torch
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from datasets import load_dataset

    # load model
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # dummy inputs for exporting the model
    input_features = torch.randn(1, 80, 3000)
    attention_mask = torch.ones(1, 3000)
    decoder_input_ids = torch.tensor([[1, 1, 1, 1]]) * model.config.decoder_start_token_id

    model.eval()

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))

Error: strict tracing with TorchDynamo
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: console

    torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'


By default, ``torch.export`` traces your code using `TorchDynamo <https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html>`__, a bytecode analysis engine, which symbolically analyzes your code and builds a graph.
This analysis provides a stronger guarantee about safety, but not all Python code is supported. When we export the ``whisper-tiny`` model using the
default strict mode, it typically returns an error in Dynamo due to an unsupported feature. To understand why this errors in Dynamo, you can refer to this `GitHub issue <https://github.com/pytorch/pytorch/issues/144906>`__.

Solution
~~~~~~~~

To address the above error, ``torch.export`` supports ``non_strict`` mode, where the program is traced using the Python interpreter, which works similarly to
PyTorch eager execution. The only difference is that all ``Tensor`` objects will be replaced by ``ProxyTensors``, which will record all their operations into
a graph. By using ``strict=False``, we are able to export the program.

.. code:: python

    import torch
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from datasets import load_dataset

    # load model
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # dummy inputs for exporting the model
    input_features = torch.randn(1, 80, 3000)
    attention_mask = torch.ones(1, 3000)
    decoder_input_ids = torch.tensor([[1, 1, 1, 1]]) * model.config.decoder_start_token_id

    model.eval()

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)
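
Once the export succeeds, you can sanity-check the result by running the exported module on the same dummy inputs and comparing against eager execution. The sketch below assumes the exported module preserves the eager output structure (``transformers`` model outputs are registered as pytree nodes), so the logits are accessible the same way:

.. code:: python

    with torch.no_grad():
        eager_out = model(input_features, attention_mask, decoder_input_ids)
    exported_out = exported_program.module()(input_features, attention_mask, decoder_input_ids)

    # The logits should match (up to numerical noise) if the export captured the model faithfully.
    print(torch.allclose(eager_out.logits, exported_out.logits, atol=1e-4))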

Image Captioning
----------------

**Image Captioning** is the task of defining the contents of an image in words. In the context of gaming, Image Captioning can be used to enhance the
gameplay experience by dynamically generating text descriptions of the various game objects in the scene, thereby providing the gamer with additional
details. `BLIP <https://arxiv.org/pdf/2201.12086>`__ is a popular model for Image Captioning `released by SalesForce Research <https://github.com/salesforce/BLIP>`__. The code below tries to export BLIP with ``batch_size=1``.


.. code:: python

    import torch
    from models.blip import blip_decoder

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_size = 384
    image = torch.randn(1, 3, 384, 384).to(device)
    caption_input = ""

    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(image, caption_input,), strict=False)

Error: Cannot mutate tensors with frozen storage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

While exporting a model, it might fail because the model implementation might contain certain Python operations which are not yet supported by ``torch.export``.
Some of these failures may have a workaround. BLIP is an example where the original model errors out, and the issue can be resolved by making a small change in the code.
``torch.export`` lists the common cases of supported and unsupported operations in `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ and shows how you can modify your code to make it export compatible.

.. code:: console

    File "/BLIP/models/blip.py", line 112, in forward
      text.input_ids[:,0] = self.tokenizer.bos_token_id
    File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
      outs_unwrapped = func._op_dk(
    RuntimeError: cannot mutate tensors with frozen storage

Solution
~~~~~~~~

Clone the `tensor <https://github.com/salesforce/BLIP/blob/main/models/blip.py#L112>`__ where export fails.

.. code:: python

    text.input_ids = text.input_ids.clone()  # clone the tensor
    text.input_ids[:,0] = self.tokenizer.bos_token_id

.. note::
    This constraint has been relaxed in PyTorch 2.7 nightlies and should work out of the box in PyTorch 2.7.
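
The same clone-before-mutate pattern applies outside of BLIP. The toy module below is a hypothetical, self-contained sketch of the pattern: it overwrites the first token id inside ``forward``, and cloning first keeps the in-place write export-friendly:

.. code:: python

    import torch

    class PrependBOS(torch.nn.Module):
        """Toy module that overwrites the first token id, mirroring BLIP's forward."""
        def __init__(self, bos_token_id: int = 101):
            super().__init__()
            self.bos_token_id = bos_token_id

        def forward(self, input_ids):
            input_ids = input_ids.clone()   # clone before the in-place write
            input_ids[:, 0] = self.bos_token_id
            return input_ids

    ep = torch.export.export(PrependBOS(), (torch.zeros(2, 5, dtype=torch.long),), strict=False)
    print(ep.module()(torch.zeros(2, 5, dtype=torch.long)))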

Promptable Image Segmentation
-----------------------------

**Image segmentation** is a computer vision technique that divides a digital image into distinct groups of pixels, or segments, based on their characteristics.
The `Segment Anything Model (SAM) <https://ai.meta.com/blog/segment-anything-foundation-model-image-segmentation/>`__ introduced promptable image segmentation, which predicts object masks given prompts that indicate the desired object. `SAM 2 <https://ai.meta.com/sam2/>`__ is
the first unified model for segmenting objects across images and videos. The `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ class provides an easy interface for prompting
the model. The model can take as input both point and box prompts, as well as masks from the previous iteration of prediction. Since SAM2 provides strong
zero-shot performance for object tracking, it can be used for tracking game objects in a scene.


The tensor operations in the ``predict`` method of `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ happen in the `_predict <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L291>`__ method, so we try to export ``_predict`` directly, as shown below.

.. code:: python

    ep = torch.export.export(
        self._predict,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )

Error: Model is not of type ``torch.nn.Module``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.export`` expects the module to be of type ``torch.nn.Module``. However, what we are trying to export is a bound method, not a module, so it fails with the following error:

.. code:: console

    Traceback (most recent call last):
      File "/sam2/image_predict.py", line 20, in <module>
        masks, scores, _ = predictor.predict(
      File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
        ep = torch.export.export(
      File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
        raise ValueError(
    ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

Solution
~~~~~~~~

We write a helper class that inherits from ``torch.nn.Module`` and calls ``_predict`` in its ``forward`` method. The complete code can be found `here <https://github.com/anijain2305/sam2/blob/ued/sam2/sam2_image_predictor.py#L293-L311>`__.

.. code:: python

    class ExportHelper(torch.nn.Module):
        def __init__(self):
            super().__init__()

        # ``_`` takes the place of the helper's own ``self``; the ``self`` used below is
        # the enclosing SAM2ImagePredictor instance, since this class is defined inside ``predict``.
        def forward(_, *args, **kwargs):
            return self._predict(*args, **kwargs)

    model_to_export = ExportHelper()
    ep = torch.export.export(
        model_to_export,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )
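
After export, ``ep.module()`` can stand in for the original ``_predict`` call inside ``predict``. The line below sketches that substitution, assuming ``_predict`` returns the masks, IoU predictions, and low-resolution masks as in the upstream implementation:

.. code:: python

    masks, iou_predictions, low_res_masks = ep.module()(
        unnorm_coords, labels, unnorm_box, mask_input, multimask_output,
        return_logits=return_logits,
    )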

Conclusion
----------

In this tutorial, we have learned how to use ``torch.export`` to export models for popular use cases by addressing challenges through correct configuration and simple code modifications.
Once you are able to export a model, you can lower the ``ExportedProgram`` into your hardware using `AOTInductor <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html>`__ in the case of servers and `ExecuTorch <https://pytorch.org/executorch/stable/index.html>`__ in the case of edge devices.
To learn more about ``AOTInductor`` (AOTI), please refer to the `AOTI tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`__.
To learn more about ``ExecuTorch``, please refer to the `ExecuTorch tutorial <https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html>`__.
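
As a pointer toward that next step, the sketch below shows one way the MViT ``ExportedProgram`` from this recipe could be lowered with AOTInductor. It assumes PyTorch 2.6 or later, where ``torch._inductor.aoti_compile_and_package`` and ``torch._inductor.aoti_load_package`` are available; see the AOTI tutorial for the authoritative flow.

.. code:: python

    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b

    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT).eval()
    frames = torch.randn(2, 16, 224, 224, 3).permute(0, 4, 1, 2, 3)

    with torch.no_grad():
        ep = torch.export.export(model, (frames,))
        # Compile the ExportedProgram ahead of time into a self-contained .pt2 package ...
        package_path = torch._inductor.aoti_compile_and_package(ep, package_path="mvit.pt2")

    # ... which can then be loaded and run without the original Python model code.
    compiled = torch._inductor.aoti_load_package(package_path)
    out = compiled(frames)
    print(out.shape)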
