Demonstration of torch.export flow, common challenges and the solutions to address them
========================================================================================
**Authors:** `Ankith Gunapal <https://github.com/agunapal>`__, `Jordi Ramon <https://github.com/JordiFB>`__, `Marcos Carranza <https://github.com/macarran>`__

In the `Introduction to torch.export Tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__, we learned how to use `torch.export <https://pytorch.org/docs/stable/export.html>`__.
This tutorial expands on the previous one and explores the process of exporting popular models with code, as well as addressing common challenges that may arise with ``torch.export``.

In this tutorial, you will learn how to export models for these use cases:

* Video classification (`MViT <https://pytorch.org/vision/main/models/video_mvit.html>`__)
* Automatic Speech Recognition (`OpenAI Whisper-Tiny <https://huggingface.co/openai/whisper-tiny>`__)
* Image Captioning (`BLIP <https://github.com/salesforce/BLIP>`__)
* Promptable Image Segmentation (`SAM2 <https://ai.meta.com/sam2/>`__)

Each of the four models was chosen to demonstrate unique features of ``torch.export``, as well as some practical considerations
and issues faced in the implementation.

Prerequisites
-------------

* PyTorch 2.4 or later
* Basic understanding of ``torch.export`` and PyTorch eager inference


Key requirement for ``torch.export``: No graph break
----------------------------------------------------

`torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__ speeds up PyTorch code by JIT-compiling it into optimized kernels. It optimizes the given model
using ``TorchDynamo`` and creates an optimized graph, which is then lowered to the hardware using the backend specified in the API.
When TorchDynamo encounters unsupported Python features, it breaks the computation graph, lets the default Python interpreter
handle the unsupported code, and then resumes capturing the graph. This break in the computation graph is called a `graph break <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torchdynamo-and-fx-graphs>`__.

One of the key differences between ``torch.export`` and ``torch.compile`` is that ``torch.export`` does not support graph breaks,
which means that the entire model, or the part of the model that you are exporting, needs to be a single graph. This is because handling graph breaks
involves interpreting the unsupported operation with default Python evaluation, which is incompatible with what ``torch.export`` is
designed for. You can read more about the differences between the various PyTorch frameworks in this `link <https://pytorch.org/docs/main/export.html#existing-frameworks>`__.
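
As an illustration, the following is a minimal sketch (not one of the models covered in this recipe) of a module with data-dependent Python control flow: ``torch.compile`` handles the branch by inserting a graph break and falling back to eager execution, whereas ``torch.export`` requires a single graph and raises an error.

.. code:: python

    import torch

    class BranchyModel(torch.nn.Module):
        def forward(self, x):
            # Python-level control flow that depends on a tensor value cannot be
            # captured into a single graph by TorchDynamo.
            if x.sum() > 0:
                return x * 2
            return x - 1

    model = BranchyModel()
    x = torch.randn(4)

    # torch.compile works: it breaks the graph at the branch and runs that part eagerly.
    print(torch.compile(model)(x))

    # torch.export needs a single graph, so the same branch makes it fail.
    try:
        torch.export.export(model, (x,))
    except Exception as e:
        print(f"torch.export failed: {type(e).__name__}")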

You can identify graph breaks in your program by using the following command:

.. code:: sh

   TORCH_LOGS="graph_breaks" python <file_name>.py

You will need to modify your program to get rid of graph breaks. Once they are resolved, you are ready to export the model.
PyTorch runs `nightly benchmarks <https://hud.pytorch.org/benchmark/compilers>`__ for ``torch.compile`` on popular HuggingFace and TIMM models.
Most of these models have no graph breaks.

The models in this recipe have no graph breaks but still fail with ``torch.export``.

Video Classification
--------------------

MViT is a class of models based on `MultiScale Vision Transformers <https://arxiv.org/abs/2104.11227>`__. This model has been trained for video classification using the `Kinetics-400 Dataset <https://arxiv.org/abs/1705.06950>`__.
This model with a relevant dataset can be used for action recognition in the context of gaming.


The code below exports MViT by tracing with ``batch_size=2`` and then checks whether the ExportedProgram can run with ``batch_size=4``.

.. code:: python

    import numpy as np
    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb

    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)
    # Transpose to get [batch_size, 3, num_frames, height, width].
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

    # Export the model.
    exported_program = torch.export.export(
        model,
        (input_frames,),
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()


Error: Static batch size
~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: sh

    raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4


By default, the exporting flow will trace the program assuming that all input shapes are static, so if you run the program with
input shapes that are different from the ones you used while tracing, you will run into an error.

Solution
~~~~~~~~

To address the error, we specify the first dimension of the input (``batch_size``) to be dynamic, along with the expected range of ``batch_size``.
In the corrected example shown below, we specify that the expected ``batch_size`` can range from 2 to 16.
Note that ``min=2`` is not a bug; the reason the minimum cannot be 1 is explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__. A detailed description of dynamic shapes
for ``torch.export`` can be found in the export tutorial. The code shown below demonstrates how to export MViT with dynamic batch sizes:

.. code:: python

    import numpy as np
    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb


    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)

    # Transpose to get [batch_size, 3, num_frames, height, width].
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

    # Export the model with the batch dimension marked as dynamic.
    batch_dim = torch.export.Dim("batch", min=2, max=16)
    exported_program = torch.export.export(
        model,
        (input_frames,),
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()

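
Once the export succeeds, the ``ExportedProgram`` can also be serialized and reloaded later. The snippet below is a brief sketch that reuses the variables from the block above (the ``mvit.pt2`` file name is only an example):

.. code:: python

    # Save the ExportedProgram produced above and load it back, for example in another process.
    torch.export.save(exported_program, "mvit.pt2")
    loaded_program = torch.export.load("mvit.pt2")

    # The reloaded program still accepts any batch size in the declared range.
    loaded_program.module()(input_frames)
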
Automatic Speech Recognition
----------------------------

**Automatic Speech Recognition** (ASR) is the use of machine learning to transcribe spoken language into text.
`Whisper <https://arxiv.org/abs/2212.04356>`__ is a Transformer-based encoder-decoder model from OpenAI, which was trained on 680k hours of labelled data for ASR and speech translation.
The code below tries to export the ``whisper-tiny`` model for ASR.


.. code:: python

    import torch
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from datasets import load_dataset

    # Load the model.
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # Dummy inputs for exporting the model.
    input_features = torch.randn(1, 80, 3000)
    attention_mask = torch.ones(1, 3000)
    decoder_input_ids = torch.tensor([[1, 1, 1, 1]]) * model.config.decoder_start_token_id

    model.eval()

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))


Error: Strict tracing with TorchDynamo
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: console

    torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'


By default, ``torch.export`` traces your code using `TorchDynamo <https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html>`__, a bytecode analysis engine, which symbolically analyzes your code and builds a graph.
This analysis provides a stronger safety guarantee, but not all Python code is supported. When we export the ``whisper-tiny`` model using the
default strict mode, it typically returns an error in Dynamo due to an unsupported feature. To understand why this fails in Dynamo, you can refer to this `GitHub issue <https://github.com/pytorch/pytorch/issues/144906>`__.

Solution
~~~~~~~~

To address the above error, ``torch.export`` supports non-strict mode, where the program is traced using the Python interpreter, which works similarly to
PyTorch eager execution. The only difference is that all ``Tensor`` objects are replaced by ``ProxyTensors``, which record all their operations into
a graph. By passing ``strict=False``, we are able to export the program.

.. code:: python

    import torch
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from datasets import load_dataset

    # Load the model.
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # Dummy inputs for exporting the model.
    input_features = torch.randn(1, 80, 3000)
    attention_mask = torch.ones(1, 3000)
    decoder_input_ids = torch.tensor([[1, 1, 1, 1]]) * model.config.decoder_start_token_id

    model.eval()

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)

Image Captioning
----------------

**Image Captioning** is the task of describing the contents of an image in words. In the context of gaming, Image Captioning can be used to enhance the
gameplay experience by dynamically generating text descriptions of the various game objects in the scene, thereby providing the gamer with additional
details. `BLIP <https://arxiv.org/pdf/2201.12086>`__ is a popular model for Image Captioning `released by Salesforce Research <https://github.com/salesforce/BLIP>`__. The code below tries to export BLIP with ``batch_size=1``.


.. code:: python

    import torch
    from models.blip import blip_decoder

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_size = 384
    image = torch.randn(1, 3, 384, 384).to(device)
    caption_input = ""

    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(image, caption_input,), strict=False)


Error: Cannot mutate tensors with frozen storage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Exporting a model might fail because the model implementation contains certain Python operations which are not yet supported by ``torch.export``.
Some of these failures have a workaround. BLIP is an example where the original model errors out: the forward pass updates the tokenized input in place, which export cannot handle, and the failure can be resolved by making a small change in the code.
``torch.export`` lists the common cases of supported and unsupported operations in `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ and shows how you can modify your code to make it export-compatible.

.. code:: console

    File "/BLIP/models/blip.py", line 112, in forward
        text.input_ids[:,0] = self.tokenizer.bos_token_id
    File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
        outs_unwrapped = func._op_dk(
    RuntimeError: cannot mutate tensors with frozen storage


Solution
~~~~~~~~

Clone the `tensor <https://github.com/salesforce/BLIP/blob/main/models/blip.py#L112>`__ at the line where export fails.

.. code:: python

    text.input_ids = text.input_ids.clone()  # Clone the tensor before mutating it in place.
    text.input_ids[:,0] = self.tokenizer.bos_token_id

.. note::
    This constraint has been relaxed in PyTorch 2.7 nightlies. This should work out of the box in PyTorch 2.7.

Promptable Image Segmentation
-----------------------------

**Image segmentation** is a computer vision technique that divides a digital image into distinct groups of pixels, or segments, based on their characteristics.
The `Segment Anything Model (SAM) <https://ai.meta.com/blog/segment-anything-foundation-model-image-segmentation/>`__ introduced promptable image segmentation, which predicts object masks given prompts that indicate the desired object. `SAM 2 <https://ai.meta.com/sam2/>`__ is
the first unified model for segmenting objects across images and videos. The `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ class provides an easy interface for prompting
the model. It can take as input both point and box prompts, as well as masks from the previous iteration of prediction. Since SAM2 provides strong
zero-shot performance for object tracking, it can be used for tracking game objects in a scene.


The tensor operations in the ``predict`` method of `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ happen in the `_predict <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L291>`__ method, so we try to export ``_predict`` directly, as shown below.

.. code:: python

    ep = torch.export.export(
        self._predict,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )


Error: Model is not of type ``torch.nn.Module``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.export`` expects the module to be of type ``torch.nn.Module``. However, what we are trying to export here is a method bound to the predictor instance, not a module, so export fails with the following error.

.. code:: console

    Traceback (most recent call last):
      File "/sam2/image_predict.py", line 20, in <module>
        masks, scores, _ = predictor.predict(
      File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
        ep = torch.export.export(
      File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
        raise ValueError(
    ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.


Solution
~~~~~~~~

We write a helper class that inherits from ``torch.nn.Module`` and calls the ``_predict`` method from the ``forward`` method of the class. The complete code can be found `here <https://github.com/anijain2305/sam2/blob/ued/sam2/sam2_image_predictor.py#L293-L311>`__.

.. code:: python

    class ExportHelper(torch.nn.Module):
        def __init__(self):
            super().__init__()

        # The first parameter is deliberately named ``_`` so that ``self`` inside
        # the body refers to the enclosing SAM2ImagePredictor instance via the closure.
        def forward(_, *args, **kwargs):
            return self._predict(*args, **kwargs)

    model_to_export = ExportHelper()
    ep = torch.export.export(
        model_to_export,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )

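After export, the ``ExportedProgram`` can stand in for the original ``_predict`` call inside ``predict``. The following is a rough sketch reusing the variables from the block above; the return values mirror those of ``_predict`` in the upstream code (masks, IoU predictions, and low-resolution masks).

.. code:: python

    # Call the exported graph with the same kind of inputs used at export time.
    masks, iou_predictions, low_res_masks = ep.module()(
        unnorm_coords, labels, unnorm_box, mask_input, multimask_output,
        return_logits=return_logits,
    )
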
Conclusion
----------

In this tutorial, we have learned how to use ``torch.export`` to export models for popular use cases by addressing challenges through correct configuration and simple code modifications.
Once you are able to export a model, you can lower the ``ExportedProgram`` to your hardware using `AOTInductor <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html>`__ for servers and `ExecuTorch <https://pytorch.org/executorch/stable/index.html>`__ for edge devices, as sketched below.
To learn more about ``AOTInductor`` (AOTI), please refer to the `AOTI tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`__.
To learn more about ``ExecuTorch``, please refer to the `ExecuTorch tutorial <https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html>`__.
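
The snippet below is a minimal sketch of the AOTInductor path, assuming PyTorch 2.6 or later and reusing the MViT ``exported_program`` and ``input_frames`` from earlier (the ``mvit_aoti.pt2`` path is only an example):

.. code:: python

    import torch
    from torch._inductor import aoti_compile_and_package, aoti_load_package

    # Compile the ExportedProgram ahead of time and package it as a .pt2 archive.
    package_path = aoti_compile_and_package(exported_program, package_path="mvit_aoti.pt2")

    # Later, load the compiled artifact and run it without the original Python model.
    compiled_model = aoti_load_package(package_path)
    output = compiled_model(input_frames)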