diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py index cfd75068..6e64804b 100644 --- a/et_replay/comm/backend/pytorch_dist_backend.py +++ b/et_replay/comm/backend/pytorch_dist_backend.py @@ -8,6 +8,7 @@ from itertools import cycle from time import sleep from typing import List, Optional +import pickle import numpy as np import torch @@ -1008,7 +1009,8 @@ def get_new_pg(self, group_ranks, backend): ranks=group_ranks, backend=backend ) else: - return dist.new_group(ranks=group_ranks, backend=backend) + pg = dist.new_group(ranks=group_ranks, backend=backend) + return pg def tensor_list_to_numpy(self, tensorList): if isinstance(tensorList, list): @@ -1070,9 +1072,24 @@ def initialize_backend( def initialize_groups(self, backend="gloo"): groups = {} world_size = self.get_world_size() + global_rank = self.get_global_rank() + + # sync pgs across ranks to fix hang with multiple comm groups + # because new_group() requires that all processes in the main group enter, + # even if they are not going to be members of the group.
+ # Assumption: pg_name is unique and consistent for all ranks + sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store) + sync_store.set(str(global_rank), pickle.dumps(self.commsParams.groupRanks)) + torch.distributed.barrier() + group_ranks_sync = self.commsParams.groupRanks.copy() + for i in range(self.get_world_size()): + if i == global_rank: + continue + bytes = sync_store.get(str(i)) + group_ranks_sync.update(pickle.loads(bytes)) # create additional groups - for pg_id, group_ranks in self.commsParams.groupRanks.items(): + for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items(): if ( len(group_ranks) > world_size ): # this means that --auto-shrink is enabled, only use default pg @@ -1084,11 +1101,9 @@ def initialize_groups(self, backend="gloo"): pg = self.get_default_group() else: pg = self.get_new_pg(group_ranks=group_ranks, backend=backend) - global_rank = self.get_global_rank() - if global_rank in group_ranks: - logger.info( - f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}" - ) + logger.info( + f"initialized_group: create new group pg_id {pg_id} {pg} with {group_ranks}" + ) groups[pg_id] = pg # if additional groups are created, overwrite the default groups list