2020import torch
2121from torch import nn
2222
23- from ..import_utils import is_requests_available , is_vllm_available
23+ from ..import_utils import is_requests_available , is_vllm_ascend_available , is_vllm_available
2424
2525
2626if is_requests_available ():
3232 from vllm .distributed .device_communicators .pynccl import PyNcclCommunicator
3333 from vllm .distributed .utils import StatelessProcessGroup
3434
35+ if is_vllm_ascend_available ():
36+ from vllm_ascend .distributed .device_communicators .pyhccl import PyHcclCommunicator as PyNcclCommunicator
37+
3538
3639logger = logging .getLogger (__name__ )
3740
@@ -212,7 +215,7 @@ def init_communicator(self):
212215
213216 # Set up the communication group for weight broadcasting
214217 pg = StatelessProcessGroup .create (host = self .host , port = self .group_port , rank = self .rank , world_size = world_size )
215- self .pynccl_comm = PyNcclCommunicator (pg , device = "cuda:0" )
218+ self .pynccl_comm = PyNcclCommunicator (pg , device = 0 )
216219
217220 def update_named_param (self , name : str , weights : torch .Tensor ):
218221 """
@@ -231,7 +234,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
231234 raise Exception (f"Request failed: { response .status_code } , { response .text } " )
232235
233236 # Broadcast the weights to the other processes
234- self .pynccl_comm .broadcast (weights , src = self .rank , stream = torch . cuda . current_stream () )
237+ self .pynccl_comm .broadcast (weights , src = self .rank )
235238 self .pynccl_comm .group .barrier ()
236239
237240 def update_model_params (self , model : nn .Module ):
0 commit comments