diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index 84f2067a0dce..6c605658e2a3 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -4,8 +4,10 @@
 # DeepSpeed Team
 
 import torch
+import contextlib
 import functools
 from deepspeed.utils.torch import required_torch_version
+from deepspeed.accelerator import get_accelerator
 
 try:
     from torch.compiler import is_compiling as torch_is_compiling
@@ -16,6 +18,16 @@
         # Torch does not have compiler support
         torch_is_compiling = lambda: False
 
+try:
+    if required_torch_version(min_version="2.6.0a"):
+        from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
+    else:
+        from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable
+
+    _COMPILED_AUTOGRAD_AVAILABLE = True
+except (ImportError, ModuleNotFoundError):
+    _COMPILED_AUTOGRAD_AVAILABLE = False
+
 
 def is_compile_supported():
     return required_torch_version(min_version=2.1)
@@ -73,6 +85,22 @@ def is_compiling():
     return torch_is_compiling()
 
 
+@contextlib.contextmanager
+def compiled_autograd(enabled: bool, kwargs: dict):
+    if not enabled or not _COMPILED_AUTOGRAD_AVAILABLE:
+        yield
+        return
+
+    if torch_is_compiling():
+        yield
+        return
+
+    compiler_fn = torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs)
+
+    with compiled_autograd_enable(compiler_fn):
+        yield
+
+
 def dummy_decorator(func):
     return func
 
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 760d92f91146..5ec61ca79dc7 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -106,7 +106,7 @@
 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import is_compile_supported
+from .compiler import is_compile_supported, compiled_autograd
 from ..ops.adam import FusedAdam
 from ..moe.sharded_moe import TopKGate, MOELayer
 from ..moe.layer import MoE
@@ -446,6 +446,9 @@ def __init__(self,
             # See also: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook
             self.optimizer.register_grad_acc_post_hook(self._backward_post_hook)
 
+        self._is_compiled_autograd_enabled = False
+        self._compile_kwargs = {}
+
     def _optimized_linear_offload_setup(self):
         self.optimized_linear_base_weight_sharding = False
         self.optimized_linear_lora_enabled = False
@@ -2476,17 +2479,18 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
         elif self.torch_autocast_z0_gradscaler:
             loss = self.torch_autocast_z0_gradscaler.scale(loss)
 
-        if self.zero_optimization() or not self.amp_enabled():
-            loss.backward(**backward_kwargs)
-        elif self.amp_enabled():
-            # AMP requires delaying unscale when inside gradient accumulation boundaries
-            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
-            delay_unscale = not self.is_gradient_accumulation_boundary()
-            with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
-                scaled_loss.backward(**backward_kwargs)
+        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
+            if self.zero_optimization() or not self.amp_enabled():
+                loss.backward(**backward_kwargs)
+            elif self.amp_enabled():
+                # AMP requires delaying unscale when inside gradient accumulation boundaries
+                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+                delay_unscale = not self.is_gradient_accumulation_boundary()
+                with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
+                    scaled_loss.backward(**backward_kwargs)
 
-        # backward_epilogue is not called in a hook when self._support_torch_style_backward is False
-        self._backward_epilogue()
+            # backward_epilogue is not called in a hook when self._support_torch_style_backward is False
+            self._backward_epilogue()
 
         self._running_engine_backward = False
@@ -4205,7 +4209,11 @@ def empty_partition_cache(self):
         gc.collect()
         get_accelerator().empty_cache()
 
-    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
+    def compile(self,
+                backend=get_accelerator().get_compile_backend(),
+                compile_kwargs={},
+                schedule=None,
+                compiled_autograd_enabled=False) -> None:
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
@@ -4271,6 +4279,13 @@ def passes_name_to_fn(passes):
            raise
 
         self._is_compiled = True
+        self._compile_kwargs = compile_kwargs
+        if compiled_autograd_enabled:
+            if not self._deepcompile_active:
+                self._is_compiled_autograd_enabled = compiled_autograd_enabled
+            else:
+                logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
+                self._is_compiled_autograd_enabled = False
 
     def _set_deepcompile_active(self, active: bool) -> None:
         """Toggle DeepCompile runtime state and manage forward hooks accordingly."""
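
For reference, a minimal usage sketch of the new flag (not part of the patch). The config, model, tensors, and training step below are illustrative placeholders and assume a single-process run launched via the deepspeed launcher; only engine.compile(compiled_autograd_enabled=True) and the wrapping of engine.backward() come from this diff.

import torch
import torch.nn.functional as F
import deepspeed

# Illustrative config and model; any valid DeepSpeed setup works here.
ds_config = {"train_batch_size": 8, "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
model = torch.nn.Linear(1024, 1024)

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# New keyword from this diff: compile the module and also run backward passes
# under compiled autograd (ignored, with a warning, when DeepCompile is active).
engine.compile(compiled_autograd_enabled=True)

x = torch.randn(8, 1024, device=engine.device)
y = torch.randn(8, 1024, device=engine.device)

loss = F.mse_loss(engine(x), y)
engine.backward(loss)  # now runs inside the compiled_autograd() context manager
engine.step()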