GPUDirect Storage prototype tutorial #3317

Draft · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions .jenkins/validate_tutorials_built.py
@@ -31,6 +31,7 @@
"prototype_source/vmap_recipe",
"prototype_source/torchscript_freezing",
"prototype_source/nestedtensor",
"prototype_source/gpu_direct_storage", # requires specific filesystem + GPUDirect Storage to be set up
Contributor:

Doesn't it run in compat mode with a random machine?

Contributor Author:

You need a specific filesystem.

"recipes_source/recipes/saving_and_loading_models_for_inference",
"recipes_source/recipes/saving_multiple_models_in_one_file",
"recipes_source/recipes/tensorboard_with_pytorch",
132 changes: 132 additions & 0 deletions prototype_source/gpu_direct_storage.py
@@ -0,0 +1,132 @@
"""
(prototype) Accelerating ``torch.save`` and ``torch.load`` with GPUDirect Storage
=================================================================================

GPUDirect Storage enables a direct data path for direct memory access transfers
between GPU memory and storage, avoiding a bounce buffer through the CPU.

In version **2.7**, we introduced prototype APIs in ``torch.cuda.gds`` that serve as thin wrappers around
the `cuFile APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
and can be used with ``torch.Tensor`` to achieve improved I/O performance.

In this tutorial, we demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with
checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem.

.. grid:: 2

.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
:class-card: card-prerequisites

* Understand how to use the ``torch.cuda.gds`` APIs in conjunction with
checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites

* PyTorch v2.7.0 or later
* GPUDirect Storage must be installed per
`the documentation <https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html>`_
* Ensure that the filesystem you are saving to or loading from supports GPUDirect Storage.
"""

################################################################################
# Using GPUDirect Storage with ``torch.save`` and ``torch.load``
# ================================================================
# GPUDirect Storage requires that storages be aligned to 4KB boundaries within the
# checkpoint file. You can set this alignment using
# ``torch.utils.serialization.config.save.storage_alignment``:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096

################################################################################
# The steps involved in the process are as follows:
#
# * Write the checkpoint file without any actual data. This reserves the space on disk.
# * Read the offsets for the storage associated with each tensor in the checkpoint using ``FakeTensor``.
# * Use ``GDSFile`` to write the appropriate data at these offsets.
#
# Given a state dictionary of tensors that are on the GPU, one can use the ``torch.serialization.skip_data`` context
# manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each ``torch.Storage``
# in the state dictionary, space will be reserved within the checkpoint for the storage bytes.

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
torch.save(sd, "checkpoint.pt")
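
################################################################################
# As an optional sanity check (this snippet is illustrative and not part of the
# original flow), we can confirm that space was reserved on disk even though no
# storage bytes were written:

import os

print(os.path.getsize("checkpoint.pt"))  # size already includes the reserved storage bytes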

################################################################################
# We can get the offsets that each storage should be written to within the checkpoint by loading under
# a ``FakeTensorMode``. A ``FakeTensor`` is a tensor that carries metadata (such as sizes,
# strides, dtype, and device) describing the tensor, but does not hold any storage bytes.
# The following snippet will not materialize any data but will tag each ``FakeTensor``
# with the offset within the checkpoint that corresponds to the tensor.
#
# If you are continuously saving the same state dictionary during training, you
# only need to obtain the offsets once, and the same offsets can be re-used. Similarly,
# if a tensor is going to be saved or loaded repeatedly, you can use
# ``torch.cuda.gds.gds_register_buffer``, which wraps ``cuFileBufRegister``, to register
# the storages as GDS buffers (a sketch of this follows the next snippet).
#
# Note that ``torch.cuda.gds.GdsFile.save_storage`` binds to the synchronous ``cuFileWrite`` API,
# so no synchronization is needed afterwards.


import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as mode:
fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# save_storage is a wrapper around `cuFileWrite`
f.save_storage(v.untyped_storage(), offset)
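
################################################################################
# As a minimal, illustrative sketch of the buffer registration mentioned above
# (optional, and assuming the same storages will be written or read repeatedly):

# Register each storage in the state dictionary as a GDS buffer
for v in sd.values():
    torch.cuda.gds.gds_register_buffer(v.untyped_storage())

# ... repeated f.save_storage(...) / f.load_storage(...) calls would go here ...

# Deregister the buffers once the repeated I/O is done
for v in sd.values():
    torch.cuda.gds.gds_deregister_buffer(v.untyped_storage())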


################################################################################
# We verify the correctness of the saved checkpoint by loading it with ``torch.load`` and comparing.

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])

################################################################################
# The loading flow is the inverse: you can use ``torch.load`` with the ``torch.serialization.skip_data`` context
# manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be
# created but their storages will be empty (as if the tensors were created via ``torch.empty``).

with torch.serialization.skip_data():
sd_loaded = torch.load("checkpoint.pt")

################################################################################
# We once again use the checkpoint offsets obtained via ``FakeTensorMode`` above to
# load the storage bytes, and ascertain that the loaded checkpoint matches the saved one.
#
# Similar to ``torch.cuda.gds.GdsFile.save_storage``, ``torch.cuda.gds.GdsFile.load_storage``
# binds to the synchronous ``cuFileRead`` API, so no synchronization is needed afterwards.

for k, v in sd_loaded.items():
assert not torch.equal(v, sd[k])
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# load_storage is a wrapper around `cuFileRead`
f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])

# Deleting the ``GdsFile`` releases the underlying cuFile handle
del f
Contributor:
Similar synchronization question as above

Contributor Author (@mikaylagawarecki, Apr 8, 2025):

I don't think synchronization is needed after the call, as cuFileRead/cuFileWrite are blocking operations that block until the IO is complete (https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufileread). You might need to synchronize before these ops (rather than after), though.
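
For example, a minimal illustrative sketch of synchronizing before the call (not from the PR itself):

    # ensure pending kernels producing these tensors have completed before
    # handing their storages to the blocking cuFile write
    torch.cuda.synchronize()
    f.save_storage(v.untyped_storage(), offset)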

[Screenshot attached: the linked cuFileRead documentation]


################################################################################
# Conclusion
# ==========
#
# In this tutorial, we demonstrated how to use the prototype ``torch.cuda.gds`` APIs
# in conjunction with ``torch.save`` and ``torch.load`` on a local filesystem. Please
# file an issue in the PyTorch GitHub repo if you have any feedback.
8 changes: 8 additions & 0 deletions prototype_source/prototype_index.rst
@@ -247,6 +247,14 @@ Prototype features are not available as part of binary distributions like PyPI o
:link: ../prototype/python_extension_autoload.html
:tags: Extending-PyTorch, Frontend-APIs

.. GPUDirect Storage
.. customcarditem::
:header: (prototype) Using GPUDirect Storage
:card_description: Learn how to use GPUDirect Storage in PyTorch.
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../prototype/gpu_direct_storage.html
:tags: GPUDirect-Storage

.. End of tutorial card section

.. raw:: html