@@ -822,13 +822,12 @@ def pair_gossip_nonblocking(tensor: torch.Tensor, target_rank: int, self_weight:
822
822
823
823
def poll (handle : int ) -> bool :
824
824
"""
825
- Polls an allreduce, allgather or broadcast handle to determine whether underlying
826
- nonblocking operation has completed. After `poll()` returns `True`, `synchronize ()`
825
+ Polls an allreduce, neighbor_allreduce, etc operation handle to determine whether underlying
826
+ nonblocking operation has completed. After `poll()` returns `True`, `wait ()`
827
827
will return without blocking.
828
828
829
829
Arguments:
830
- handle: A handle returned by an allreduce, allgather, broadcast, neighbor_allgather,
831
- and neighbro_allreduce nonblocking operation.
830
+ handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
832
831
833
832
Returns:
834
833
A flag indicating whether the operation has completed.
@@ -838,12 +837,12 @@ def poll(handle: int) -> bool:
838
837
839
838
def synchronize (handle : int ) -> torch .Tensor :
840
839
"""
841
- Synchronizes an nonblocking allreduce, allgather or broadcast operation until
840
+ Wait an allreduce, neighbor_allreduce, etc operation until
842
841
it's completed. Returns the result of the operation.
842
+ It is the same function as `wait()`.
843
843
844
844
Args:
845
- handle: A handle returned by an allreduce, allgather or broadcast nonblocking
846
- operation.
845
+ handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
847
846
848
847
Returns:
849
848
torch.Tensor: An output tensor of the operation.
@@ -855,6 +854,21 @@ def synchronize(handle: int) -> torch.Tensor:
855
854
return output
856
855
857
856
857
+ def wait (handle : int ) -> torch .Tensor :
858
+ """
859
+ Wait an allreduce, neighbor_allreduce, etc operation until
860
+ it's completed. Returns the result of the operation.
861
+ It is just alias of `synchronize()` function.
862
+
863
+ Args:
864
+ handle: A handle returned by an allreduce, neighbor_allreduce, etc. nonblocking operation.
865
+
866
+ Returns:
867
+ torch.Tensor: An output tensor of the operation.
868
+ """
869
+ return synchronize (handle )
870
+
871
+
858
872
def barrier ():
859
873
"""Barrier function to sychronize all MPI processes.
860
874
0 commit comments