From e037ad90ed6ee06a1c32dabd962a615b7515e3fb Mon Sep 17 00:00:00 2001
From: Felicity Liao
Date: Tue, 11 Mar 2025 14:20:13 -0700
Subject: [PATCH] DO NOT LAND

Differential Revision: D70993764
---
 torchrec/distributed/model_parallel.py | 49 +++++++++-----------------
 1 file changed, 17 insertions(+), 32 deletions(-)

diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 1358def72..16b58c196 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -10,7 +10,7 @@
 import abc
 import copy
 import logging as logger
-from collections import defaultdict, OrderedDict
+from collections import OrderedDict
 from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Set, Tuple, Type
 
 import torch
@@ -754,57 +754,42 @@ def sync(self, include_optimizer_state: bool = True) -> None:
             include_optimizer_state (bool): Flag to include optimizer state syncing upon call
         """
         assert self._replica_pg is not None, "replica_pg is not initialized!"
-        all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
-
-        for emb_kernel in self._modules_to_sync:
+        all_weights: List[torch.Tensor] = [
+            w
+            for emb_kernel in self._modules_to_sync
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-            for w in emb_kernel.split_embedding_weights():
-                all_weights_by_dtype[w.dtype].append(w)
+            for w in emb_kernel.split_embedding_weights()
+        ]
 
         opts = None
         if self._custom_all_reduce is None:
             opts = dist.AllreduceCoalescedOptions()
             opts.reduceOp = dist.ReduceOp.AVG
-        self._allreduce_tensors(all_weights_by_dtype, opts)
+        self._allreduce_tensors(all_weights, opts)
 
         if include_optimizer_state:
-            optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = (
-                defaultdict(list)
-            )
+            optimizer_tensors = []
             for emb_kernel in self._modules_to_sync:
                 # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                 optimizer_states = emb_kernel.get_optimizer_state()
-                for state in optimizer_states:
-                    opt_tensor = state["sum"]
-                    optimizer_tensors_by_dtype[opt_tensor.dtype].append(opt_tensor)
-            if optimizer_tensors_by_dtype:
-                self._allreduce_tensors(optimizer_tensors_by_dtype, opts)
+                optimizer_tensors.extend([state["sum"] for state in optimizer_states])
+            if optimizer_tensors:
+                self._allreduce_tensors(optimizer_tensors, opts)
 
     def _allreduce_tensors(
         self,
-        tensors_dict: Dict[torch.dtype, List[torch.Tensor]],
+        tensors: List[torch.Tensor],
         opts: Optional[dist.AllreduceCoalescedOptions] = None,
     ) -> None:
         """
         Helper to perform all reduce on given tensors, uses custom all reduce function if provided
-        We perform all reduce per tensor dtype per collective constraints.
         """
-
-        def custom_all_reduce(tensors: List[torch.Tensor]) -> None:
-            # pyre-ignore[29]
+        if self._custom_all_reduce is not None:
+            # pyre-ignore[6]
             self._custom_all_reduce(tensors)
-
-        def default_allreduce(tensor_list: List[torch.Tensor]) -> None:
-            self._replica_pg.allreduce_coalesced(tensor_list, opts=opts).wait()
-
-        allreduce = (
-            custom_all_reduce
-            if self._custom_all_reduce is not None
-            else default_allreduce
-        )
-
-        for tensor_list in tensors_dict.values():
-            allreduce(tensor_list)
+        else:
+            handle = self._replica_pg.allreduce_coalesced(tensors, opts=opts)
+            handle.wait()
 
     def set_all_reduce_hook(
     self,
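
Note (not part of the patch to apply): the sketch below is a minimal, hypothetical illustration of the flattened flow on the + side of this diff, where all embedding weights go into a single coalesced all-reduce instead of one call per dtype. The function name sync_weights and its arguments are made up for illustration and are not torchrec APIs; dist.AllreduceCoalescedOptions, dist.ReduceOp.AVG, and ProcessGroup.allreduce_coalesced are the torch.distributed calls already used in the diff. The sketch assumes the backing process group accepts a mixed-dtype tensor list in one coalesced call.

# Hypothetical sketch of the flattened all-reduce flow: one coalesced
# collective over the flat tensor list, or a user-supplied hook if present.
# `sync_weights` and its parameters are illustrative only.
from typing import Callable, List, Optional

import torch
import torch.distributed as dist


def sync_weights(
    weights: List[torch.Tensor],
    replica_pg: dist.ProcessGroup,
    custom_all_reduce: Optional[Callable[[List[torch.Tensor]], None]] = None,
) -> None:
    if custom_all_reduce is not None:
        # The hook receives the full flat list and owns the reduction,
        # including any dtype grouping its backend may require.
        custom_all_reduce(weights)
    else:
        opts = dist.AllreduceCoalescedOptions()
        opts.reduceOp = dist.ReduceOp.AVG
        # Single coalesced collective; assumes the process group tolerates a
        # mixed-dtype tensor list in one call.
        replica_pg.allreduce_coalesced(weights, opts=opts).wait()

The docstring line removed by this diff ("We perform all reduce per tensor dtype per collective constraints") suggests some backends may still require per-dtype grouping; under this flattened flow, a custom all-reduce hook would be the place to reintroduce it if needed.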