diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py
index 80cfc2fa3..b82214b0c 100644
--- a/tests/unit_tests/test_bucketing.py
+++ b/tests/unit_tests/test_bucketing.py
@@ -39,6 +39,12 @@ def test_warmup_range():
     assert result == [2, 4, 8, 16, 32, 64, 128]
 
 
+def test_warmup_range_with_one():
+    config = (1, 64, 128)
+    result = linear.warmup_range(config)
+    assert result == [1, 2, 4, 8, 16, 32, 64, 128]
+
+
 def test_generate_prompt_buckets():
     max_num_batched_tokens = 2048
     block_size = 64
diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py
index a1779a43c..87a4d1aee 100644
--- a/vllm_gaudi/extension/bucketing/exponential.py
+++ b/vllm_gaudi/extension/bucketing/exponential.py
@@ -30,13 +30,20 @@ def check_for_user_flags(self, phase):
 
     def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
         self.check_for_user_flags('prompt')
+        if getattr(get_config(), 'VLLM_PROMPT_QUERY_BUCKET_MIN') == 1:
+            query_min = 1
+            logger().warning(
+                "VLLM_PROMPT_QUERY_BUCKET_MIN=1 is only recommended on the decode instance in a P/D disaggregation scenario."
+            )
+        else:
+            query_min = block_size
         use_merged_prefill = get_config().merged_prefill
 
         # cfgs shape: [min, step, max, limit]
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
         prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
         max_prompt_seq_limit = math.ceil(math.log2(max_num_batched_tokens))
-        prompt_query_bucket_cfg = [block_size, block_size, max_num_batched_tokens, max_prompt_seq_limit]
+        prompt_query_bucket_cfg = [query_min, block_size, max_num_batched_tokens, max_prompt_seq_limit]
         max_ctx = max(1, math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size))
         max_prompt_ctx_limit = 2 if max_ctx == 1 else math.ceil(math.log2(max_ctx)) + 1
         prompt_ctx_bucket_cfg = [0, 1, max_ctx, max_prompt_ctx_limit]
@@ -124,8 +131,9 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
     """  # noqa: E501
 
     bmin, bstep, bmax, num_buckets = config
-    add_zero_bucket = bmin == 0
-    if add_zero_bucket:
+    add_zero_or_one_bucket = bmin in [0, 1]
+    if add_zero_or_one_bucket:
+        bmin_origin = bmin
         bmin = bstep
     linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
     assert num_buckets > 0, "num_buckets must be a positive integer"
@@ -174,6 +182,6 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
         '''
         if bucket not in buckets:
             buckets.add(bucket)
-    if add_zero_bucket:
-        buckets.add(0)
+    if add_zero_or_one_bucket:
+        buckets.add(bmin_origin)
     return list(sorted(buckets))
diff --git a/vllm_gaudi/v1/worker/hpu_dp_utils.py b/vllm_gaudi/v1/worker/hpu_dp_utils.py
index ed3dcc1d1..1002f364d 100644
--- a/vllm_gaudi/v1/worker/hpu_dp_utils.py
+++ b/vllm_gaudi/v1/worker/hpu_dp_utils.py
@@ -22,6 +22,11 @@ def make(
         dp_size = vllm_config.parallel_config.data_parallel_size
         tp_size = vllm_config.parallel_config.tensor_parallel_size
 
+        if vllm_config.parallel_config.use_sequence_parallel_moe and (num_tokens % tp_size != 0):
+            # round num_tokens up to the next multiple of tp_size, as
+            # required by sequence-parallel MoE
+            num_tokens = (num_tokens // tp_size + 1) * tp_size
+
         num_tokens_across_dp = num_tokens * dp_size
         dtype = vllm_config.model_config.dtype
 
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 4546d41fd..2f7a50dfb 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1974,10 +1974,10 @@ def _form_unified_prefill_batch(self, contents):
 
     def _create_dummy_prefill_batch_contents(self, num_prefills: int) -> list[PrefillInputData]:
         req_id = str(-1)
-        context_len = 0
-        query_len = 128
+        context_len = 127 if has_kv_transfer_group() else 0
+        query_len = 1 if has_kv_transfer_group() else 128
         prompt_tokens = 128
-        token_ids = list(int(i) for i in range(prompt_tokens))
+        token_ids = list(int(i) for i in range(query_len))
         num_blocks = round_up(context_len + query_len, self.block_size) // self.block_size
         blocks = [0] * num_blocks
         num_output_logits = context_len + query_len - prompt_tokens + 1
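
Reviewer's note (not part of the patch): a minimal sketch of the min-bucket handling introduced above. toy_bucket_range is a hypothetical, simplified stand-in for warmup_range_with_limit (it drops the exponential-limit logic and keeps only the linear range), but it shows why a bmin of 0 or 1 must be set aside before the aligned buckets are built and re-added afterwards.

# Illustrative sketch only, assuming numpy; toy_bucket_range is a
# hypothetical stand-in for warmup_range_with_limit without the
# exponential-limit logic.
import numpy as np


def toy_bucket_range(bmin: int, bstep: int, bmax: int) -> list[int]:
    # Seeding np.arange with bmin=1 would shift every bucket off block
    # alignment (1, 65, 129, ...), so a minimum of 0 or 1 is remembered
    # here and re-added after the aligned buckets are built.
    add_zero_or_one_bucket = bmin in (0, 1)
    if add_zero_or_one_bucket:
        bmin_origin = bmin
        bmin = bstep
    buckets = set(np.arange(bmin, bmax + 1, step=bstep).tolist())
    if add_zero_or_one_bucket:
        buckets.add(bmin_origin)
    return sorted(buckets)


# VLLM_PROMPT_QUERY_BUCKET_MIN=1 with block_size=64 keeps the 1-token
# bucket alongside the block-aligned ones:
assert toy_bucket_range(1, 64, 256) == [1, 64, 128, 192, 256]

# The hpu_dp_utils.py change uses the standard round-up-to-multiple idiom,
# e.g. num_tokens=100, tp_size=8 gives (100 // 8 + 1) * 8 == 104.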