
Commit 93c1ab0

Dedicated docs page for distributed checkpoints (Trainer) (Lightning-AI#19299)
1 parent 6655c4d commit 93c1ab0

File tree

6 files changed: +112 -89 lines


docs/source-fabric/advanced/model_parallel/fsdp.rst

Lines changed: 3 additions & 3 deletions

@@ -12,9 +12,9 @@ Even a single H100 GPU with 80 GB of VRAM (the biggest today) is not enough to t
 The memory consumption for training is generally made up of

 1. the model parameters,
-2. the layer activations (forward) and
-3. the gradients (backward).
-4. the optimizer states (e.g., Adam has two additional exponential averages per parameter),
+2. the layer activations (forward),
+3. the gradients (backward) and
+4. the optimizer states (e.g., Adam has two additional exponential averages per parameter).

 |
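As a rough illustration of the four memory contributions in the list above, here is a back-of-the-envelope sketch in plain Python. The 1B parameter count and fp32 precision are assumptions for illustration, and activation memory is left out because it depends on batch size and architecture:

.. code-block:: python

    # Back-of-the-envelope memory estimate for fp32 training with Adam.
    # The 1B parameter count is a hypothetical example; activations are excluded.
    num_params = 1_000_000_000
    bytes_per_value = 4  # fp32

    params_gb = num_params * bytes_per_value / 1e9            # 1. model parameters
    grads_gb = num_params * bytes_per_value / 1e9             # 3. gradients (same size as the parameters)
    adam_states_gb = 2 * num_params * bytes_per_value / 1e9   # 4. two exponential averages per parameter

    print(f"params={params_gb:.0f} GB, grads={grads_gb:.0f} GB, optimizer states={adam_states_gb:.0f} GB")
    # -> params=4 GB, grads=4 GB, optimizer states=8 GB, before counting 2. the activations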

docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,6 @@ The distributed checkpoint format is the default when you train with the :doc:`F

 With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
 This reduces memory peaks and speeds up the saving to disk.
-The resulting checkpoint folder will have this structure:

 .. collapse:: Full example

@@ -103,6 +102,7 @@ The resulting checkpoint folder will have this structure:
         ├── __1_0.distcp
         ├── __2_0.distcp
         ├── __3_0.distcp
+        ├── .metadata
         └── meta.pt

 The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
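The sentence above is cut off by the hunk; as in the Trainer version of this page further down, it goes on to note that each shard is roughly ``1/world_size`` of the total checkpoint size. A small sketch to check that on disk; the folder path is a hypothetical example, not one produced by this commit:

.. code-block:: python

    # Hypothetical check: each .distcp shard should be roughly 1/world_size of the shard total.
    from pathlib import Path

    checkpoint_dir = Path("my_checkpoint.ckpt")  # assumed folder saved with state_dict_type="sharded"
    shards = sorted(checkpoint_dir.glob("*.distcp"))
    total = sum(f.stat().st_size for f in shards)
    for f in shards:
        print(f"{f.name}: {f.stat().st_size / 1e6:.1f} MB ({f.stat().st_size / total:.0%} of all shards)")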

docs/source-pytorch/advanced/model_parallel/fsdp.rst

Lines changed: 7 additions & 5 deletions

@@ -16,9 +16,9 @@ Even a single H100 GPU with 80 GB of VRAM (the biggest today) is not enough to t
 The memory consumption for training is generally made up of

 1. the model parameters,
-2. the layer activations (forward) and
-3. the gradients (backward).
-4. the optimizer states (e.g., Adam has two additional exponential averages per parameter),
+2. the layer activations (forward),
+3. the gradients (backward) and
+4. the optimizer states (e.g., Adam has two additional exponential averages per parameter).

 |

@@ -200,7 +200,8 @@ Before:
 class LanguageModel(L.LightningModule):
     def __init__(self, vocab_size):
         super().__init__()
-        self.model = Transformer(vocab_size=vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)  # 1B parameters
+        # 1B parameters
+        self.model = Transformer(vocab_size=vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)

 After:

@@ -397,13 +398,14 @@ The resulting checkpoint folder will have this structure:
 ├── .metadata
 ├── __0_0.distcp
 ├── __1_0.distcp
+...
 └── meta.pt

 The “sharded” checkpoint format is the most efficient to save and load in Lightning.

 **Which checkpoint format should I use?**

-- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to convert the sharded checkpoint into a regular checkpoint file.
+- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to :doc:`convert the sharded checkpoint into a regular checkpoint file <../../common/checkpointing_expert>`.
 - ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.
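Putting the two bullets above into code, a minimal sketch of how the choice looks in a training script; the device count and other Trainer arguments are illustrative, not from this commit:

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # Portable, single-file checkpoint: suited to models below ~10B parameters and to fine-tuning
    strategy = FSDPStrategy(state_dict_type="full")

    # Sharded/distributed checkpoint: faster and more memory-friendly for very large pre-training runs,
    # but needs an extra conversion step before it can be used outside FSDP
    # strategy = FSDPStrategy(state_dict_type="sharded")

    trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)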

docs/source-pytorch/common/checkpointing.rst

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ Checkpointing

 .. displayitem::
    :header: Distributed checkpoints
-   :description: Customize checkpointing for custom distributed strategies and accelerators.
+   :description: Save and load very large models efficiently with distributed checkpoints
    :col_css: col-md-4
    :button_link: checkpointing_expert.html
    :height: 150
docs/source-pytorch/common/checkpointing_expert.rst

Lines changed: 92 additions & 79 deletions

@@ -6,121 +6,134 @@
 Distributed checkpoints (expert)
 ################################

-*********************************
-Writing your own Checkpoint class
-*********************************
+Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
+With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.

-We provide ``Checkpoint`` class, for easier subclassing. Users may want to subclass this class in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback.

+----

-***********************
-Customize Checkpointing
-***********************

-.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+*****************************
+Save a distributed checkpoint
+*****************************

-Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
-that is managed by the ``Strategy``. ``CheckpointIO`` is different from :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_save_checkpoint`
-and :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_load_checkpoint` methods as it determines how the checkpoint is saved/loaded to storage rather than
-what's saved in the checkpoint.
+The distributed checkpoint format can be enabled when you train with the :doc:`FSDP strategy <../advanced/model_parallel/fsdp>`.

+.. code-block:: python

-.. TODO:: I don't understand this...
+    import lightning as L
+    from lightning.pytorch.strategies import FSDPStrategy

-******************************
-Built-in Checkpoint IO Plugins
-******************************
+    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
+    strategy = FSDPStrategy(state_dict_type="sharded")

-.. list-table:: Built-in Checkpoint IO Plugins
-   :widths: 25 75
-   :header-rows: 1
+    # 2. Pass the strategy to the Trainer
+    trainer = L.Trainer(devices=2, strategy=strategy, ...)

-   * - Plugin
-     - Description
-   * - :class:`~lightning.pytorch.plugins.io.TorchCheckpointIO`
-     - CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
-       respectively, common for most use cases.
-   * - :class:`~lightning.pytorch.plugins.io.XLACheckpointIO`
-     - CheckpointIO that utilizes ``xm.save`` to save checkpoints for TPU training strategies.
-   * - :class:`~lightning.pytorch.plugins.io.AsyncCheckpointIO`
-     - ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
+    # 3. Run the trainer
+    trainer.fit(model)


-***************************
-Custom Checkpoint IO Plugin
-***************************
+With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
+This reduces memory peaks and speeds up the saving to disk.

-``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``Strategy`` as shown below:
+.. collapse:: Full example

-.. code-block:: python
+    .. code-block:: python

-    from lightning.pytorch import Trainer
-    from lightning.pytorch.callbacks import ModelCheckpoint
-    from lightning.pytorch.plugins import CheckpointIO
-    from lightning.pytorch.strategies import SingleDeviceStrategy
+        import lightning as L
+        from lightning.pytorch.strategies import FSDPStrategy
+        from lightning.pytorch.demos import LightningTransformer

+        model = LightningTransformer()

-    class CustomCheckpointIO(CheckpointIO):
-        def save_checkpoint(self, checkpoint, path, storage_options=None):
-            ...
+        strategy = FSDPStrategy(state_dict_type="sharded")
+        trainer = L.Trainer(
+            accelerator="cuda",
+            devices=4,
+            strategy=strategy,
+            max_steps=3,
+        )
+        trainer.fit(model)

-        def load_checkpoint(self, path, storage_options=None):
-            ...

-        def remove_checkpoint(self, path):
-            ...
+Check the contents of the checkpoint folder:

+.. code-block:: bash

-    custom_checkpoint_io = CustomCheckpointIO()
+    ls -a lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt/

-    # Either pass into the Trainer object
-    model = MyModel()
-    trainer = Trainer(
-        plugins=[custom_checkpoint_io],
-        callbacks=ModelCheckpoint(save_last=True),
-    )
-    trainer.fit(model)
+.. code-block::

-    # or pass into Strategy
-    model = MyModel()
-    device = torch.device("cpu")
-    trainer = Trainer(
-        strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
-        callbacks=ModelCheckpoint(save_last=True),
-    )
-    trainer.fit(model)
+    epoch=0-step=3.ckpt/
+    ├── __0_0.distcp
+    ├── __1_0.distcp
+    ├── __2_0.distcp
+    ├── __3_0.distcp
+    ├── .metadata
+    └── meta.pt

-.. note::
+The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
+is roughly 1/4 of the total size of the checkpoint since the script distributes the model across 4 GPUs.

-    Some ``Strategy``s like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.

+----

-**************************
-Asynchronous Checkpointing
-**************************

-.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+*****************************
+Load a distributed checkpoint
+*****************************

-To enable saving the checkpoints asynchronously without blocking your training, you can configure
-:class:`~lightning.pytorch.plugins.io.async_plugin.AsyncCheckpointIO` plugin to ``Trainer``.
+You can easily load a distributed checkpoint in Trainer if your script uses :doc:`FSDP <../advanced/model_parallel/fsdp>`.

 .. code-block:: python

-    from lightning.pytorch.plugins.io import AsyncCheckpointIO
+    import lightning as L
+    from lightning.pytorch.strategies import FSDPStrategy

+    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
+    strategy = FSDPStrategy(state_dict_type="sharded")

-    async_ckpt_io = AsyncCheckpointIO()
-    trainer = Trainer(plugins=[async_ckpt_io])
+    # 2. Pass the strategy to the Trainer
+    trainer = L.Trainer(devices=2, strategy=strategy, ...)

+    # 3. Set the checkpoint path to load
+    trainer.fit(model, ckpt_path="path/to/checkpoint")

-It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint but performs this operation asynchronously.
-By default, this base ``CheckpointIO`` will be set-up for you and all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``.
-But if you want the plugin to use your own custom base ``CheckpointIO`` and want the base to behave asynchronously, pass it as an argument while initializing ``AsyncCheckpointIO``.
+Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.

-.. code-block:: python
+.. collapse:: Full example
+
+    .. code-block:: python
+
+        import lightning as L
+        from lightning.pytorch.strategies import FSDPStrategy
+        from lightning.pytorch.demos import LightningTransformer
+
+        model = LightningTransformer()
+
+        strategy = FSDPStrategy(state_dict_type="sharded")
+        trainer = L.Trainer(
+            accelerator="cuda",
+            devices=2,
+            strategy=strategy,
+            max_steps=5,
+        )
+        trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")
+
+
+.. important::
+
+    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Trainer at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.
+
+
+----
+
+
+.. _Convert dist-checkpoint:

-    from lightning.pytorch.plugins.io import AsyncCheckpointIO
+********************************
+Convert a distributed checkpoint
+********************************

-    base_ckpt_io = MyCustomCheckpointIO()
-    async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
-    trainer = Trainer(plugins=[async_ckpt_io])
+Coming soon.
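To illustrate the note in the new page above that a distributed checkpoint can be loaded with a different world size, here is a hedged end-to-end sketch combining the save and load snippets from this page: save on 4 GPUs, then resume on 2 GPUs. The checkpoint path mirrors the example above and may differ on your machine; this is not the conversion step the page marks as coming soon.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy
    from lightning.pytorch.demos import LightningTransformer

    # Run 1: save a sharded checkpoint across 4 GPUs (as in the "Save" example above)
    model = LightningTransformer()
    trainer = L.Trainer(accelerator="cuda", devices=4, strategy=FSDPStrategy(state_dict_type="sharded"), max_steps=3)
    trainer.fit(model)

    # Run 2, e.g. in a separate job: resume from the same folder on only 2 GPUs.
    # The shards are re-distributed to the new world size when the checkpoint is loaded.
    model = LightningTransformer()
    trainer = L.Trainer(accelerator="cuda", devices=2, strategy=FSDPStrategy(state_dict_type="sharded"), max_steps=10)
    trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")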

docs/source-pytorch/glossary/index.rst

Lines changed: 8 additions & 0 deletions

@@ -11,6 +11,7 @@
     Console Logging <../common/console_logs>
     Debugging <../debug/debugging>
     DeepSpeed <../advanced/model_parallel/deepspeed>
+    Distributed Checkpoints <../common/checkpointing_expert>
     Early stopping <../common/early_stopping>
     Experiment manager (Logger) <../visualize/experiment_managers>
     Finetuning <../advanced/finetuning>

@@ -113,6 +114,13 @@ Glossary
    :button_link: ../advanced/model_parallel/deepspeed.html
    :height: 100

+.. displayitem::
+   :header: Distributed Checkpoints
+   :description: Save and load very large models efficiently with distributed checkpoints
+   :col_css: col-md-12
+   :button_link: ../common/checkpointing_expert.html
+   :height: 100
+
 .. displayitem::
    :header: Early stopping
    :description: Stop the training when no improvement is observed
