19 changes: 19 additions & 0 deletions deepspeed/runtime/compiler.py
@@ -4,8 +4,10 @@
 # DeepSpeed Team
 
 import torch
+import contextlib
 import functools
 from deepspeed.utils.torch import required_torch_version
+from deepspeed.accelerator import get_accelerator
 
 try:
     from torch.compiler import is_compiling as torch_is_compiling
@@ -16,6 +18,11 @@
     # Torch does not have compiler support
     torch_is_compiling = lambda: False
 
+if required_torch_version(min_version="2.6.0a"):
+    from torch._dynamo.compiled_autograd import _enable as compiled_autograd_enable
+else:
+    from torch._dynamo.compiled_autograd import enable as compiled_autograd_enable
+
 
 def is_compile_supported():
     return required_torch_version(min_version=2.1)
@@ -71,3 +78,15 @@ def wrapper(*args, **kwargs):
 
 def is_compiling():
     return torch_is_compiling()
+
+
+@contextlib.contextmanager
+def compiled_autograd(enabled, kwargs):
+    try:
+        if enabled:
+            with compiled_autograd_enable(torch.compile(backend=get_accelerator().get_compile_backend(), **kwargs)):
+                yield
+        else:
+            yield
+    finally:
+        pass
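As a quick reference, here is a minimal sketch of how the new compiled_autograd() helper above could be exercised on its own. Only compiled_autograd itself comes from this patch; the model, input, and the enabled/kwargs values below are illustrative, and a PyTorch build that provides torch._dynamo.compiled_autograd is assumed.

import torch
from deepspeed.runtime.compiler import compiled_autograd

model = torch.compile(torch.nn.Linear(8, 8))
loss = model(torch.randn(4, 8)).sum()

# enabled=True traces the backward pass under compiled autograd and recompiles it
# with the accelerator's compile backend; enabled=False makes the context a no-op.
with compiled_autograd(enabled=True, kwargs={}):
    loss.backward()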
22 changes: 18 additions & 4 deletions deepspeed/runtime/engine.py
@@ -105,7 +105,7 @@
 
 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import is_compile_supported
+from .compiler import is_compile_supported, compiled_autograd
 from ..ops.adam import FusedAdam
 from ..moe.sharded_moe import TopKGate, MOELayer
 from ..moe.layer import MoE
@@ -420,6 +420,9 @@ def __init__(self,
         self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather)
         self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states)
 
+        self._is_compiled_autograd_enabled = False
+        self._compile_kwargs = {}
+
     def _optimized_linear_offload_setup(self):
         self.optimized_linear_base_weight_sharding = False
         self.optimized_linear_lora_enabled = False
@@ -2359,8 +2362,9 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
 
         self._start_timers(self.engine_timers.backward_timers)
         loss = self._backward_prologue(loss, scale_wrt_gas)
-        self._do_optimizer_backward(loss, retain_graph)
-        self._backward_epilogue()
+        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
+            self._do_optimizer_backward(loss, retain_graph)
+            self._backward_epilogue()
 
         self._stop_timers(self.engine_timers.backward_timers)
 
         return loss
@@ -4078,7 +4082,11 @@ def empty_partition_cache(self):
         gc.collect()
         get_accelerator().empty_cache()
 
-    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
+    def compile(self,
+                backend=get_accelerator().get_compile_backend(),
+                compile_kwargs={},
+                schedule=None,
+                compiled_autograd_enabled=False) -> None:
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
@@ -4144,6 +4152,12 @@ def passes_name_to_fn(passes):
            raise
 
        self._is_compiled = True
+       if compiled_autograd_enabled:
+           if not self._deepcompile_active:
+               self._is_compiled_autograd_enabled = compiled_autograd_enabled
+               self._compile_kwargs = compile_kwargs
+           else:
+               logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.")
 
    def _set_deepcompile_active(self, active: bool) -> None:
        """Toggle DeepCompile runtime state and manage forward hooks accordingly."""
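For the engine-level entry point, a rough usage sketch of the new compiled_autograd_enabled flag follows. The flag name comes from the diff above; the model, DeepSpeed config, and tensor shapes are placeholders, and a single-device setup is assumed. When DeepCompile is active, compile() ignores the flag and logs a warning instead.

import torch
import deepspeed

model = torch.nn.Linear(16, 16)
ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# torch.compile the wrapped module and opt in to compiled autograd; backward()
# will then run inside the compiled_autograd context added to engine.backward().
engine.compile(compiled_autograd_enabled=True)

loss = engine(torch.randn(1, 16)).sum()
engine.backward(loss)
engine.step()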