Skip to content

Commit

Permalink
Add support to multiple process groups by syncing across ranks
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev authored and TaekyungHeo committed Aug 5, 2024
1 parent fba0236 commit 8e3bdb7
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions et_replay/comm/backend/pytorch_dist_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from itertools import cycle
from time import sleep
from typing import List, Optional
import pickle

import numpy as np
import torch
Expand Down Expand Up @@ -1008,7 +1009,8 @@ def get_new_pg(self, group_ranks, backend):
ranks=group_ranks, backend=backend
)
else:
return dist.new_group(ranks=group_ranks, backend=backend)
pg = dist.new_group(ranks=group_ranks, backend=backend)
return pg

def tensor_list_to_numpy(self, tensorList):
if isinstance(tensorList, list):
Expand Down Expand Up @@ -1070,9 +1072,24 @@ def initialize_backend(
def initialize_groups(self, backend="gloo"):
groups = {}
world_size = self.get_world_size()
global_rank = self.get_global_rank()

# sync pgs across ranks to fix hang with multiple comm groups
# because new_group() functions requires that all processes in the main group enter,
# even if they are not going to be members of the group.
# Assumption: pg_name is unique and consistent for all ranks
sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
sync_store.set(str(global_rank), pickle.dumps(self.commsParams.groupRanks))
torch.distributed.barrier()
group_ranks_sync = self.commsParams.groupRanks.copy()
for i in range(self.get_world_size()):
if i == global_rank:
continue
bytes = sync_store.get(str(i))
group_ranks_sync.update(pickle.loads(bytes))

# create additional groups
for pg_id, group_ranks in self.commsParams.groupRanks.items():
for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
if (
len(group_ranks) > world_size
): # this means that --auto-shrink is enabled, only use default pg
Expand All @@ -1084,11 +1101,9 @@ def initialize_groups(self, backend="gloo"):
pg = self.get_default_group()
else:
pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
global_rank = self.get_global_rank()
if global_rank in group_ranks:
logger.info(
f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}"
)
logger.info(
f"initialized_group: create new group pg_id {pg_id} {pg} with {group_ranks}"
)
groups[pg_id] = pg

# if additional groups are created, overwrite the default groups list
Expand Down

0 comments on commit 8e3bdb7

Please sign in to comment.