Skip to content

Commit d42a6a4

Browse files
ji-huazhong and qgallouedec
authored and committed
🧗 Add Ascend NPU support for vLLM server (#3286)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 912d005 commit d42a6a4

3 files changed

Lines changed: 23 additions & 5 deletions

File tree

trl/extras/vllm_client.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
import torch
2121
from torch import nn
2222

23-
from ..import_utils import is_requests_available, is_vllm_available
23+
from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available
2424

2525

2626
if is_requests_available():
@@ -32,6 +32,9 @@
3232
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
3333
from vllm.distributed.utils import StatelessProcessGroup
3434

35+
if is_vllm_ascend_available():
36+
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
37+
3538

3639
logger = logging.getLogger(__name__)
3740

@@ -212,7 +215,7 @@ def init_communicator(self):
212215

213216
# Set up the communication group for weight broadcasting
214217
pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
215-
self.pynccl_comm = PyNcclCommunicator(pg, device="cuda:0")
218+
self.pynccl_comm = PyNcclCommunicator(pg, device=0)
216219

217220
def update_named_param(self, name: str, weights: torch.Tensor):
218221
"""
@@ -231,7 +234,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
231234
raise Exception(f"Request failed: {response.status_code}, {response.text}")
232235

233236
# Broadcast the weights to the other processes
234-
self.pynccl_comm.broadcast(weights, src=self.rank, stream=torch.cuda.current_stream())
237+
self.pynccl_comm.broadcast(weights, src=self.rank)
235238
self.pynccl_comm.group.barrier()
236239

237240
def update_model_params(self, model: nn.Module):

trl/import_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
_unsloth_available = _is_package_available("unsloth")
3838
_uvicorn_available = _is_package_available("uvicorn")
3939
_vllm_available = _is_package_available("vllm")
40+
_vllm_ascend_available = _is_package_available("vllm_ascend")
4041
_joblib_available = _is_package_available("joblib")
4142

4243

@@ -88,6 +89,10 @@ def is_vllm_available() -> bool:
8889
return _vllm_available
8990

9091

92+
def is_vllm_ascend_available() -> bool:
93+
return _vllm_ascend_available
94+
95+
9196
def is_joblib_available() -> bool:
9297
return _joblib_available
9398

trl/scripts/vllm_serve.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,13 @@
2222
import torch
2323

2424
from trl import TrlParser
25-
from trl.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available, is_vllm_available
25+
from trl.import_utils import (
26+
is_fastapi_available,
27+
is_pydantic_available,
28+
is_uvicorn_available,
29+
is_vllm_ascend_available,
30+
is_vllm_available,
31+
)
2632

2733

2834
if is_fastapi_available():
@@ -44,6 +50,10 @@
4450
from vllm.distributed.utils import StatelessProcessGroup
4551
from vllm.sampling_params import GuidedDecodingParams
4652

53+
if is_vllm_ascend_available():
54+
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
55+
56+
4757
logger = logging.getLogger(__name__)
4858

4959
# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following
@@ -114,7 +124,7 @@ def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]
114124
weight = torch.empty(shape, dtype=dtype, device=self.device)
115125

116126
# Use NCCL to broadcast the updated weights from the client (src) to all workers.
117-
self.pynccl_comm.broadcast(weight, src=self.client_rank, stream=torch.cuda.current_stream())
127+
self.pynccl_comm.broadcast(weight, src=self.client_rank)
118128
self.pynccl_comm.group.barrier()
119129

120130
# Load the received weights into the model.

0 commit comments

Comments
 (0)