
Commit 322bb1c

Fix dp sync after upstream change #24105 (#179)
- fix behavior of Lazy + `enforce_eager`, in which case HPU graph is NOT used
- disable device group for DP sync when HPU graph is used
- enable DP CI test again

Signed-off-by: Wuxun Zhang <[email protected]>
Co-authored-by: Chendi.Xue <[email protected]>
1 parent a3dce5c commit 322bb1c
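
The gist of the change, as a minimal standalone sketch: the helper names and plain boolean/int arguments below are hypothetical, while in the actual patch the same checks read `vllm_config`/`model_config` fields inside `HPUModelRunner.__init__` (see the diff below).

    # Hypothetical sketch of the gating introduced by this commit; not part of the patch.
    def hpu_graphs_enabled(lazy_mode: bool, enforce_eager: bool) -> bool:
        # Lazy mode with enforce_eager means HPU graphs are NOT used.
        return lazy_mode and not enforce_eager

    def should_disable_nccl_dp_sync(dp_size: int, lazy_mode: bool, enforce_eager: bool) -> bool:
        # The device group for DP synchronization is disabled only when HPU graphs
        # are in use, since the collective is not captured in the graph.
        return dp_size > 1 and hpu_graphs_enabled(lazy_mode, enforce_eager)

    # DP2 in lazy mode without enforce_eager -> graphs are used, so disable the group.
    assert should_disable_nccl_dp_sync(dp_size=2, lazy_mode=True, enforce_eager=False)
    assert not should_disable_nccl_dp_sync(dp_size=2, lazy_mode=True, enforce_eager=True)

In the patch itself, the positive case sets `vllm.envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION = True`.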

2 files changed: +19 -12 lines changed

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 9 additions & 10 deletions
@@ -209,13 +209,12 @@ if [ $? -ne 0 ]; then
 fi
 echo "Embedding-model-support for v1 successful"
 
-# Data Parallel failed with recent upstream changes
-# # DP2
-# echo "Testing data parallel size 2 with vllm-hpu plugin v1"
-# echo HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u vllm-gaudi/examples/data_parallel.py --dp-size 2 --tp-size 2
-# HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u vllm-gaudi/examples/data_parallel.py --dp-size 2 --tp-size 2
-# if [ $? -ne 0 ]; then
-# echo "Error: Test failed for data parallel size 2" >&2
-# exit -1
-# fi
-# echo "Test with data parallel size 2 passed"
+# DP2
+echo "Testing data parallel size 2 with vllm-hpu plugin v1"
+echo HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u vllm-gaudi/examples/data_parallel.py --dp-size 2 --tp-size 2
+HABANA_VISIBLE_DEVICES=all VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 python -u vllm-gaudi/examples/data_parallel.py --dp-size 2 --tp-size 2
+if [ $? -ne 0 ]; then
+    echo "Error: Test failed for data parallel size 2" >&2
+    exit -1
+fi
+echo "Test with data parallel size 2 passed"

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 10 additions & 2 deletions
@@ -730,6 +730,13 @@ def __init__(
         self._PAD_BLOCK_ID = -1
         self._tokenizer = init_tokenizer_from_configs(model_config=vllm_config.model_config)
 
+        if self.vllm_config.parallel_config.data_parallel_size > 1 and htorch.utils.internal.is_lazy(
+        ) and not self.model_config.enforce_eager:
+            from vllm import envs
+            # disable device group for dp synchronization when hpu graph is
+            # turned on since it's not captured and causes issues
+            envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION = True
+
         # TODO(madamczyk-intel): add a knob for that
         # TODO(madamczyk-intel): debug why increasing it lowers acc
         self.logits_rounding = 1
@@ -2230,7 +2237,7 @@ def _execute_model_generic(self,
         num_blocks = self._num_blocks(attn_metadata)
         self._check_config(batch_size, seq_len, num_blocks, attn_metadata, warmup_mode)
         additional_kwargs = {}
-        if htorch.utils.internal.is_lazy() and not self.model_config.enforce_eager:
+        if htorch.utils.internal.is_lazy():
             use_graphs = self._use_graphs()
             additional_kwargs.update({"bypass_hpu_graphs": not use_graphs})
         else:
@@ -2252,7 +2259,8 @@ def _execute_model_generic(self,
             kv_caches=kv_caches,
             inputs_embeds=inputs_embeds,
             model_mm_kwargs=model_mm_kwargs,
-            lora_mask=lora_mask)
+            lora_mask=lora_mask,
+            **additional_kwargs)
         # NOTE(kzawora): returning hidden_states is required in prompt logprobs
         # scenarios, as they will do logit processing on their own
         if self.use_aux_hidden_state_outputs:
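
For clarity, here is a simplified sketch of what the `_execute_model_generic` change does at the call site: in lazy mode, `bypass_hpu_graphs` is now always attached (derived from `use_graphs`) and forwarded into the model call via `**additional_kwargs`. The `run_model`/`dummy_model` callables below are stand-ins, not vllm-gaudi objects.

    # Simplified stand-in for the _execute_model_generic control flow.
    def run_model(model, is_lazy: bool, use_graphs: bool, **kwargs):
        additional_kwargs = {}
        if is_lazy:  # previously this also required `not enforce_eager`
            # bypass HPU graphs whenever graphs are not in use (e.g. enforce_eager)
            additional_kwargs["bypass_hpu_graphs"] = not use_graphs
        return model(**kwargs, **additional_kwargs)

    # Stand-in model that only reports whether graphs were bypassed.
    def dummy_model(input_ids=None, bypass_hpu_graphs=False, **_):
        return {"bypass_hpu_graphs": bypass_hpu_graphs}

    print(run_model(dummy_model, is_lazy=True, use_graphs=False, input_ids=[1, 2, 3]))
    # -> {'bypass_hpu_graphs': True}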
