Skip to content

Commit 115b909

Browse files
committed
Add bf.wait alias for synchronize
1 parent 52199e0 commit 115b909

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

bluefog/torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from bluefog.torch.mpi_ops import neighbor_allreduce, neighbor_allreduce_nonblocking
4949
from bluefog.torch.mpi_ops import hierarchical_neighbor_allreduce
5050
from bluefog.torch.mpi_ops import hierarchical_neighbor_allreduce_nonblocking
51-
from bluefog.torch.mpi_ops import poll, synchronize, barrier
51+
from bluefog.torch.mpi_ops import poll, synchronize, wait, barrier
5252

5353
from bluefog.torch.mpi_ops import win_create, win_free
5454
from bluefog.torch.mpi_ops import win_update, win_update_then_collect

bluefog/torch/mpi_ops.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -822,13 +822,12 @@ def pair_gossip_nonblocking(tensor: torch.Tensor, target_rank: int, self_weight:
822822

823823
def poll(handle: int) -> bool:
824824
"""
825-
Polls an allreduce, allgather or broadcast handle to determine whether underlying
826-
nonblocking operation has completed. After `poll()` returns `True`, `synchronize()`
825+
Polls an allreduce, neighbor_allreduce, etc operation handle to determine whether underlying
826+
nonblocking operation has completed. After `poll()` returns `True`, `wait()`
827827
will return without blocking.
828828
829829
Arguments:
830-
handle: A handle returned by an allreduce, allgather, broadcast, neighbor_allgather,
831-
and neighbro_allreduce nonblocking operation.
830+
handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
832831
833832
Returns:
834833
A flag indicating whether the operation has completed.
@@ -838,12 +837,12 @@ def poll(handle: int) -> bool:
838837

839838
def synchronize(handle: int) -> torch.Tensor:
840839
"""
841-
Synchronizes an nonblocking allreduce, allgather or broadcast operation until
840+
Wait an allreduce, neighbor_allreduce, etc operation until
842841
it's completed. Returns the result of the operation.
842+
It is the same function as `wait()`.
843843
844844
Args:
845-
handle: A handle returned by an allreduce, allgather or broadcast nonblocking
846-
operation.
845+
handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
847846
848847
Returns:
849848
torch.Tensor: An output tensor of the operation.
@@ -855,6 +854,21 @@ def synchronize(handle: int) -> torch.Tensor:
855854
return output
856855

857856

857+
def wait(handle: int) -> torch.Tensor:
858+
"""
859+
Wait an allreduce, neighbor_allreduce, etc operation until
860+
it's completed. Returns the result of the operation.
861+
It is just alias of `synchronize()` function.
862+
863+
Args:
864+
handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
865+
866+
Returns:
867+
torch.Tensor: An output tensor of the operation.
868+
"""
869+
return synchronize(handle)
870+
871+
858872
def barrier():
859873
"""Barrier function to sychronize all MPI processes.
860874

0 commit comments

Comments
 (0)