Stage3: Use new torch grad accumulation hooks API (#6773)
* This commit addresses DeepSpeed issue [#6718](#6718).
* The existing code registered its gradient-reduction hook on the `grad_acc` autograd node of each parameter. Constructs such as `param.data = replicated_tensor.data` used in `allgather_params(..)` are compiled into `param.set()`, so the hook attached to the `grad_acc` node is never called.
* Starting with PyTorch 2.1 there is a new, more robust hook API on the parameter itself: `param.register_post_accumulate_grad_hook(..)` (a short sketch of this API follows the commit message).
* This commit uses the appropriate API depending on the PyTorch version.
* It also disables compile support for PyTorch versions < 2.1.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
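
For context, a minimal sketch of the PyTorch >= 2.1 per-parameter hook API that the change adopts (illustrative only; the tensor, hook name, and print below are not from the commit):

```python
import torch

w = torch.nn.Parameter(torch.randn(3))

# PyTorch >= 2.1: the hook is attached to the parameter itself rather than to
# its AccumulateGrad autograd node, so it is not affected by later swaps of the
# parameter's storage (e.g. when an all-gather re-assigns param.data).
def on_grad_ready(param):
    print("accumulated grad norm:", param.grad.norm().item())

handle = w.register_post_accumulate_grad_hook(on_grad_ready)

(w * 2).sum().backward()  # the hook fires after w.grad has been accumulated
handle.remove()
```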
3 people authored Jan 3, 2025
1 parent 3573858 commit 456c9ac
Showing 3 changed files with 13 additions and 6 deletions.
deepspeed/runtime/compiler.py: 3 changes (2 additions, 1 deletion)
```diff
@@ -4,6 +4,7 @@
 # DeepSpeed Team
 
 import torch
+from deepspeed.utils.torch import required_torch_version
 
 try:
     from torch.compiler import is_compiling as torch_is_compiling
@@ -16,7 +17,7 @@
 
 
 def is_compile_supported():
-    return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
+    return required_torch_version(min_version=2.1)
 
 
 def disable(func):
```
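
As a usage note, here is a hedged sketch of gating `torch.compile` on the revised check; the `maybe_compile` helper is hypothetical and not part of DeepSpeed:

```python
import torch
from deepspeed.runtime.compiler import is_compile_supported

def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical convenience wrapper: only invoke torch.compile when DeepSpeed
    # reports compile support, i.e. PyTorch >= 2.1 after this change.
    if is_compile_supported():
        return torch.compile(module)
    return module

model = maybe_compile(torch.nn.Linear(8, 8))
```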
deepspeed/runtime/zero/stage3.py: 7 changes (2 additions, 5 deletions)
```diff
@@ -16,6 +16,7 @@
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
+from deepspeed.utils.torch import register_grad_hook
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
 from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
@@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self):
 
     def create_reduce_and_remove_grad_hooks(self):
         print_rank_0(f'[Begin] Create gradient reduction hooks')
-        self.grad_accs = []
         self.leaf_parameters = defaultdict(list)
         for i, param_group in enumerate(self.fp16_groups):
             for param in param_group:
@@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self):
 
                     #print(f"After all gather {param.device}, {param.shape}")
                     def wrapper(param):
-                        param_tmp = param.expand_as(param)
-                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
                         @instrument_w_nvtx
                         def reduce_partition_and_remove_grads(*notneeded):
                             self.reduce_ready_partitions_and_remove_grads(param)
 
-                        self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
-                        self.grad_accs.append(grad_acc)
+                        self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))
 
                     #print(f"param grad fn {param.expand_as(param).grad_fn}")
                     if z3_leaf_parameter(param):
```
deepspeed/utils/torch.py: 9 changes (9 additions, 0 deletions)
```diff
@@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None):
         return False
 
     return True
+
+
+def register_grad_hook(param, hook):
+    if required_torch_version(min_version=2.1):
+        return param.register_post_accumulate_grad_hook(hook)
+    else:
+        param_tmp = param.expand_as(param)
+        grad_acc = param_tmp.grad_fn.next_functions[0][0]
+        return grad_acc.register_hook(hook)
```
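
For completeness, a small self-contained usage sketch for the new helper; the toy model, hook, and print are illustrative only and assume `deepspeed.utils.torch` is importable:

```python
import torch
from deepspeed.utils.torch import register_grad_hook

model = torch.nn.Linear(4, 2)
handles = []

def make_hook(name):
    # Close over the parameter name; the hook ignores its arguments so it works
    # with both the post-accumulate API (called with the param) and the older
    # grad_acc node hook (called with autograd grad tuples).
    def hook(*unused):
        print(f"gradient accumulated for {name}")
    return hook

for name, p in model.named_parameters():
    handles.append(register_grad_hook(p, make_hook(name)))

model(torch.randn(3, 4)).sum().backward()  # prints once per parameter

for h in handles:
    h.remove()
```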
