Distributed checkpoints (expert)
################################

Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.


----


*****************************
Save a distributed checkpoint
*****************************

The distributed checkpoint format can be enabled when you train with the :doc:`FSDP strategy <../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
    strategy = FSDPStrategy(state_dict_type="sharded")

    # 2. Pass the strategy to the Trainer
    trainer = L.Trainer(devices=2, strategy=strategy, ...)

    # 3. Run the trainer
    trainer.fit(model)


With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
This reduces memory peaks and speeds up saving to disk.

.. collapse:: Full example

    .. code-block:: python

        import lightning as L
        from lightning.pytorch.strategies import FSDPStrategy
        from lightning.pytorch.demos import LightningTransformer

        model = LightningTransformer()

        strategy = FSDPStrategy(state_dict_type="sharded")
        trainer = L.Trainer(
            accelerator="cuda",
            devices=4,
            strategy=strategy,
            max_steps=3,
        )
        trainer.fit(model)


    Check the contents of the checkpoint folder:

    .. code-block:: bash

        ls -a lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt/

    .. code-block::

        epoch=0-step=3.ckpt/
        ├── __0_0.distcp
        ├── __1_0.distcp
        ├── __2_0.distcp
        ├── __3_0.distcp
        ├── .metadata
        └── meta.pt

    The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
    is roughly 1/4 of the total size of the checkpoint since the script distributes the model across 4 GPUs.
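
    If you want to double-check the shard sizes from Python, here is a minimal sketch; it assumes the default
    ``lightning_logs`` output path from the run above, so adapt the folder to your own setup.

    .. code-block:: python

        from pathlib import Path

        # Checkpoint folder from the example run above -- adjust to your own path
        ckpt_dir = Path("lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")

        # Print the size of each per-process shard in megabytes
        for shard in sorted(ckpt_dir.glob("*.distcp")):
            print(shard.name, f"{shard.stat().st_size / 1e6:.1f} MB")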


----


*****************************
Load a distributed checkpoint
*****************************

You can easily load a distributed checkpoint in Trainer if your script uses :doc:`FSDP <../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
    strategy = FSDPStrategy(state_dict_type="sharded")

    # 2. Pass the strategy to the Trainer
    trainer = L.Trainer(devices=2, strategy=strategy, ...)

    # 3. Set the checkpoint path to load
    trainer.fit(model, ckpt_path="path/to/checkpoint")

Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.

.. collapse:: Full example

    .. code-block:: python

        import lightning as L
        from lightning.pytorch.strategies import FSDPStrategy
        from lightning.pytorch.demos import LightningTransformer

        model = LightningTransformer()

        strategy = FSDPStrategy(state_dict_type="sharded")
        trainer = L.Trainer(
            accelerator="cuda",
            devices=2,
            strategy=strategy,
            max_steps=5,
        )
        trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")


.. important::

    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Trainer at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.


----


.. _Convert dist-checkpoint:

********************************
Convert a distributed checkpoint
********************************

Coming soon.
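
In the meantime, a minimal sketch of a manual conversion is shown below. It is an assumption-laden example, not the
official Trainer workflow: it relies on recent PyTorch releases providing
``torch.distributed.checkpoint.format_utils.dcp_to_torch_save`` and uses the checkpoint folder produced in the save
example above. It consolidates the ``.distcp`` shards into a single ``torch.save`` file that you can load without FSDP.

.. code-block:: python

    import torch
    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    # Paths are illustrative: the sharded checkpoint folder from the example above and an output file
    sharded_ckpt = "lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt"
    single_file = "consolidated.ckpt"

    # Consolidate the distributed checkpoint (.distcp shards) into one regular checkpoint file
    dcp_to_torch_save(sharded_ckpt, single_file)

    # The consolidated file can now be loaded like any torch.save checkpoint, without FSDP
    state = torch.load(single_file)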