diff --git a/docs/source-fabric/api/fabric_methods.rst b/docs/source-fabric/api/fabric_methods.rst
index 87b22578c1202..fb03e0f61f6bd 100644
--- a/docs/source-fabric/api/fabric_methods.rst
+++ b/docs/source-fabric/api/fabric_methods.rst
@@ -40,6 +40,7 @@ Moves the model and optimizer to the correct device automatically.
 
     model = nn.Linear(32, 64)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)
 
     # Set up model and optimizer for accelerated training
     model, optimizer = fabric.setup(model, optimizer)
@@ -47,6 +48,9 @@ Moves the model and optimizer to the correct device automatically.
     # If you don't want Fabric to set the device
     model, optimizer = fabric.setup(model, optimizer, move_to_device=False)
 
+    # If you want to additionally register a learning rate scheduler with compatible strategies such as DeepSpeed
+    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)
+
 The setup method also prepares the model for the selected precision choice so that operations during ``forward()``
 get cast automatically. Advanced users should read :doc:`the notes on models wrapped by Fabric <../api/wrappers>`.
diff --git a/docs/source-fabric/api/wrappers.rst b/docs/source-fabric/api/wrappers.rst
index e87874eb08666..8b20e1906072e 100644
--- a/docs/source-fabric/api/wrappers.rst
+++ b/docs/source-fabric/api/wrappers.rst
@@ -124,7 +124,7 @@ If you were to run this model in Fabric with multiple devices (DDP or FSDP), you
     # OK: Calling the model directly
     output = model(torch.randn(10))
 
-    # OK: Calling the model's forward (equivalent to the abvoe)
+    # OK: Calling the model's forward (equivalent to the above)
     output = model.forward(torch.randn(10))
 
     # ERROR: Calling another method that calls forward indirectly
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 058e5e7c40751..96e73ea6ffb8e 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -32,6 +32,7 @@
 from lightning_utilities.core.overrides import is_overridden
 from torch import Tensor
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
 import lightning.fabric
@@ -206,6 +207,7 @@ def setup(
         self,
         module: nn.Module,
         *optimizers: Optimizer,
+        scheduler: Optional[_LRScheduler] = None,
         move_to_device: bool = True,
         _reapply_compile: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
@@ -214,6 +216,7 @@ def setup(
         Args:
             module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
+            scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
             _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -222,7 +225,8 @@ def setup(
                 FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
 
         Returns:
-            The tuple containing wrapped module and the optimizers, in the same order they were passed in.
+            The tuple containing wrapped module, optimizers, and an optional learning rate scheduler,
+            in the same order they were passed in.
 
""" self._validate_setup(module, optimizers) @@ -236,8 +240,8 @@ def setup( # Let accelerator/plugin wrap and connect the models and optimizers if optimizers: - module, optimizers = self._strategy.setup_module_and_optimizers( # type: ignore[assignment] - module, list(optimizers) + module, optimizers, scheduler = self._strategy.setup_module_and_optimizers( # type: ignore[assignment] + module, list(optimizers), scheduler ) else: module = self._strategy.setup_module(module) @@ -266,7 +270,7 @@ def setup( if optimizers: # join both types in a tuple for API convenience - return (module, *optimizers) + return (module, *optimizers, scheduler) if scheduler is not None else (module, *optimizers) return module def setup_module( diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 41820c1cc433f..4b31c7a42fd64 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -27,6 +27,7 @@ from lightning_utilities.core.imports import RequirementCache from torch.nn import Module from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from typing_extensions import override from lightning.fabric.accelerators import Accelerator, CUDAAccelerator @@ -316,15 +317,14 @@ def model(self) -> "DeepSpeedEngine": @override def setup_module_and_optimizers( - self, module: Module, optimizers: list[Optimizer] - ) -> tuple["DeepSpeedEngine", list[Optimizer]]: - """Set up a model and multiple optimizers together. - - Currently, only a single optimizer is supported. + self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None + ) -> tuple["DeepSpeedEngine", list[Optimizer], Any]: + """Set up a model and multiple optimizers together, along with an optional learning rate scheduler. Currently, + only a single optimizer is supported. Return: - The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single - deepspeed optimizer. + The model wrapped into a :class:`deepspeed.DeepSpeedEngine`, a list with a single + deepspeed optimizer, and an optional learning rate scheduler. """ if len(optimizers) != 1: @@ -332,9 +332,9 @@ def setup_module_and_optimizers( f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) - self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0]) + self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler) self._set_deepspeed_activation_checkpointing() - return self._deepspeed_engine, [optimizer] + return self._deepspeed_engine, [optimizer], scheduler @override def setup_module(self, module: Module) -> "DeepSpeedEngine": @@ -343,7 +343,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine": For training, see :meth:`setup_module_and_optimizers`. """ - self._deepspeed_engine, _ = self._initialize_engine(module) + self._deepspeed_engine, _, _ = self._initialize_engine(module) return self._deepspeed_engine @override @@ -596,10 +596,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: ) def _initialize_engine( - self, - model: Module, - optimizer: Optional[Optimizer] = None, - ) -> tuple["DeepSpeedEngine", Optimizer]: + self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None + ) -> tuple["DeepSpeedEngine", Optimizer, Any]: """Initialize one model and one optimizer with an optional learning rate scheduler. 
         This calls ``deepspeed.initialize`` internally.
@@ -608,15 +606,16 @@ def _initialize_engine(
         import deepspeed
 
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
-        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
+            lr_scheduler=scheduler,
             dist_init_required=False,
         )
-        return deepspeed_engine, deepspeed_optimizer
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler
 
     @override
     def setup_environment(self) -> None:
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 9dd5b2c62d4c9..b2f548c49056d 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -33,6 +33,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from typing_extensions import TypeGuard, override
 
 from lightning.fabric.accelerators import Accelerator
@@ -261,8 +262,8 @@ def setup_environment(self) -> None:
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple[Module, list[Optimizer]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
         """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
         module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
         use_orig_params = self._fsdp_kwargs.get("use_orig_params")
@@ -274,7 +275,7 @@ def setup_module_and_optimizers(
                 " call `setup_optimizer`."
             )
         module = self.setup_module(module)
-        return module, optimizers
+        return module, optimizers, scheduler
 
     @override
     def setup_module(self, module: Module) -> Module:
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 4daad9b954b2f..1788ffc757809 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -21,6 +21,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
 from lightning.fabric.accelerators import Accelerator
@@ -145,8 +146,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont
         return stack
 
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple[Module, list[Optimizer]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
         """Set up a model and multiple optimizers together.
 
         The returned objects are expected to be in the same order they were passed in. The default implementation will
@@ -155,7 +156,7 @@ def setup_module_and_optimizers(
         """
         module = self.setup_module(module)
         optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
-        return module, optimizers
+        return module, optimizers, scheduler
 
     def setup_module(self, module: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 935ef72713bcc..41c63dd01f620 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -21,6 +21,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from typing_extensions import override
 
@@ -196,8 +197,8 @@ def setup_environment(self) -> None:
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple[Module, list[Optimizer]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
         """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
         raise NotImplementedError(
             f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py
index 032ee63cd4721..d24021fb27b31 100644
--- a/tests/tests_fabric/strategies/test_deepspeed.py
+++ b/tests/tests_fabric/strategies/test_deepspeed.py
@@ -137,6 +137,7 @@ def test_deepspeed_setup_module(init_mock):
         model=model,
         model_parameters=ANY,
         optimizer=None,
+        lr_scheduler=None,
         dist_init_required=False,
     )
 
diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py
index 78622adf66fa6..d044626bf8389 100644
--- a/tests/tests_fabric/strategies/test_model_parallel.py
+++ b/tests/tests_fabric/strategies/test_model_parallel.py
@@ -102,7 +102,7 @@ def test_parallelize_fn_call():
     strategy = ModelParallelStrategy(parallelize_fn=parallelize_fn)
     strategy._device_mesh = Mock()
     strategy.parallel_devices = [torch.device("cpu")]
-    model_setup, [optimizer_setup] = strategy.setup_module_and_optimizers(model, [optimizer])
+    model_setup, [optimizer_setup], _ = strategy.setup_module_and_optimizers(model, [optimizer])
     assert model_setup is parallel_model_mock
     assert optimizer_setup is optimizer
     parallelize_fn.assert_called_with(model, strategy.device_mesh)
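Example usage (illustrative sketch, not part of the diff above): assuming the change lands as shown, the new ``scheduler`` argument of ``fabric.setup()`` could be used as below. The accelerator, device count, ``"deepspeed"`` strategy choice, and the ``LinearLR`` schedule are placeholder choices; note that ``scheduler`` is keyword-only because it follows ``*optimizers`` in the new signature.

.. code-block:: python

    import torch
    import torch.nn as nn

    from lightning.fabric import Fabric

    # Illustrative configuration; only DeepSpeed currently registers the scheduler
    # with its engine (via ``deepspeed.initialize``), per the diff above.
    fabric = Fabric(accelerator="cuda", devices=2, strategy="deepspeed")
    fabric.launch()

    model = nn.Linear(32, 64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)

    # ``scheduler`` must be passed by keyword; since a scheduler was given,
    # ``setup`` returns the tuple (model, optimizer, scheduler).
    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)

With strategies other than DeepSpeed (base ``Strategy``, FSDP), the diff passes the scheduler through unchanged, so stepping it in the training loop remains the user's responsibility.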