6 changes: 6 additions & 0 deletions tests/unit_tests/test_bucketing.py
@@ -39,6 +39,12 @@ def test_warmup_range():
     assert result == [2, 4, 8, 16, 32, 64, 128]
 
 
+def test_warmup_range_with_one():
+    config = (1, 64, 128)
+    result = linear.warmup_range(config)
+    assert result == [1, 2, 4, 8, 16, 32, 64, 128]
+
+
 def test_generate_prompt_buckets():
     max_num_batched_tokens = 2048
     block_size = 64
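The two tests together pin down the shape of the warmup range: powers of two from the configured minimum up to the step, then linear multiples of the step. A minimal sketch of that behavior, assuming exactly this doubling-then-linear rule (not the actual `linear.warmup_range` implementation, whose internals are outside this diff):

```python
# Hedged sketch of the bucket shape the tests above assert; the real
# linear.warmup_range lives in vllm_gaudi and may differ internally.
def warmup_range_sketch(config):
    bmin, bstep, bmax = config
    ramp_up = []                     # powers of two: bmin, 2*bmin, ...
    value = bmin
    while value < bstep:
        ramp_up.append(value)
        value *= 2
    stable = list(range(bstep, bmax + 1, bstep))  # bstep, 2*bstep, ...
    return ramp_up + stable

assert warmup_range_sketch((2, 64, 128)) == [2, 4, 8, 16, 32, 64, 128]
assert warmup_range_sketch((1, 64, 128)) == [1, 2, 4, 8, 16, 32, 64, 128]
```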
18 changes: 13 additions & 5 deletions vllm_gaudi/extension/bucketing/exponential.py
@@ -30,13 +30,20 @@ def check_for_user_flags(self, phase):

     def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
         self.check_for_user_flags('prompt')
+        if getattr(get_config(), 'VLLM_PROMPT_QUERY_BUCKET_MIN') == 1:
+            query_min = 1
+            logger().warning(
+                "Setting VLLM_PROMPT_QUERY_BUCKET_MIN=1 is recommended only on the decode instance "
+                "in a P/D disaggregation scenario.")
+        else:
+            query_min = block_size
         use_merged_prefill = get_config().merged_prefill
 
         # cfgs shape: [min, step, max, limit]
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
         prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
         max_prompt_seq_limit = math.ceil(math.log2(max_num_batched_tokens))
-        prompt_query_bucket_cfg = [block_size, block_size, max_num_batched_tokens, max_prompt_seq_limit]
+        prompt_query_bucket_cfg = [query_min, block_size, max_num_batched_tokens, max_prompt_seq_limit]
         max_ctx = max(1, math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size))
         max_prompt_ctx_limit = 2 if max_ctx == 1 else math.ceil(math.log2(max_ctx)) + 1
         prompt_ctx_bucket_cfg = [0, 1, max_ctx, max_prompt_ctx_limit]
@@ -124,8 +131,9 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=False):
""" # noqa: E501

bmin, bstep, bmax, num_buckets = config
add_zero_bucket = bmin == 0
if add_zero_bucket:
add_zero_or_one_bucket = bmin in [0, 1]
if add_zero_or_one_bucket:
bmin_origin = bmin
bmin = bstep
linear_buckets = set(np.arange(bmin, bmax + 1, step=bstep))
assert num_buckets > 0, "num_buckets must be a positive integer"
@@ -174,6 +182,6 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=False):
         '''
         if bucket not in buckets:
             buckets.add(bucket)
-    if add_zero_bucket:
-        buckets.add(0)
+    if add_zero_or_one_bucket:
+        buckets.add(bmin_origin)
     return list(sorted(buckets))
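Taken together, the two hunks wire the env flag through to the warmup range: with VLLM_PROMPT_QUERY_BUCKET_MIN=1, `prompt_query_bucket_cfg` starts at 1, which now takes the renamed `add_zero_or_one_bucket` path. A quick trace under assumed values (block_size=64 and max_num_batched_tokens=2048 are illustrative, not defaults from this PR):

```python
# Illustrative trace mirroring the diff above; all values are assumptions.
block_size, max_num_batched_tokens = 64, 2048
query_min = 1                                   # VLLM_PROMPT_QUERY_BUCKET_MIN=1
cfg = (query_min, block_size, max_num_batched_tokens, 11)  # [min, step, max, limit]

bmin, bstep, bmax, num_buckets = cfg
add_zero_or_one_bucket = bmin in [0, 1]         # True, since bmin == 1
if add_zero_or_one_bucket:
    bmin_origin = bmin                          # remember the requested minimum
    bmin = bstep                                # generate the main range from 64 up
# ... exponential bucket generation over [64, 2048] elided ...
buckets = {64, 128, 256, 512, 1024, 2048}       # stand-in for the generated set
if add_zero_or_one_bucket:
    buckets.add(bmin_origin)                    # re-add 1 as the single-token bucket
print(sorted(buckets))                          # [1, 64, 128, 256, 512, 1024, 2048]
```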
5 changes: 5 additions & 0 deletions vllm_gaudi/v1/worker/hpu_dp_utils.py
@@ -22,6 +22,11 @@ def make(
         dp_size = vllm_config.parallel_config.data_parallel_size
         tp_size = vllm_config.parallel_config.tensor_parallel_size
 
+        if num_tokens % tp_size != 0:
+            # Round num_tokens up to the next multiple of tp_size so the
+            # token count splits evenly across ranks for sequence-parallel MoE.
+            num_tokens = (num_tokens // tp_size + 1) * tp_size
+
         num_tokens_across_dp = num_tokens * dp_size
 
         dtype = vllm_config.model_config.dtype
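The padding above is the usual round-up-to-a-multiple idiom; a quick standalone check of the formula (the tp_size values below are illustrative):

```python
# Standalone check of the round-up formula used above.
def pad_to_tp_multiple(num_tokens: int, tp_size: int) -> int:
    if num_tokens % tp_size != 0:
        # Round up so the token count splits evenly across TP ranks.
        num_tokens = (num_tokens // tp_size + 1) * tp_size
    return num_tokens

assert pad_to_tp_multiple(130, 8) == 136   # 130 -> next multiple of 8
assert pad_to_tp_multiple(128, 8) == 128   # already divisible: unchanged
assert pad_to_tp_multiple(1, 4) == 4       # small batches get padded too
```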
6 changes: 3 additions & 3 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -1974,10 +1974,10 @@ def _form_unified_prefill_batch(self, contents):

     def _create_dummy_prefill_batch_contents(self, num_prefills: int) -> list[PrefillInputData]:
         req_id = str(-1)
-        context_len = 0
-        query_len = 128
+        context_len = 127 if has_kv_transfer_group() else 0
+        query_len = 1 if has_kv_transfer_group() else 128
         prompt_tokens = 128
-        token_ids = list(int(i) for i in range(prompt_tokens))
+        token_ids = list(int(i) for i in range(query_len))
         num_blocks = round_up(context_len + query_len, self.block_size) // self.block_size
         blocks = [0] * num_blocks
         num_output_logits = context_len + query_len - prompt_tokens + 1
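Both branches keep `context_len + query_len` at 128 tokens, so the block count and logit count are unchanged; a worked check of the arithmetic (block_size=128 is an assumption for illustration):

```python
# Worked check of the dummy-batch sizes above; block_size is an assumption.
from math import ceil

def dummy_sizes(kv_transfer: bool, block_size: int = 128):
    context_len = 127 if kv_transfer else 0
    query_len = 1 if kv_transfer else 128
    prompt_tokens = 128
    token_ids = [int(i) for i in range(query_len)]   # now sized by query_len
    num_blocks = ceil((context_len + query_len) / block_size)
    num_output_logits = context_len + query_len - prompt_tokens + 1
    return len(token_ids), num_blocks, num_output_logits

assert dummy_sizes(kv_transfer=True) == (1, 1, 1)     # 1 query token, 1 logit
assert dummy_sizes(kv_transfer=False) == (128, 1, 1)  # full 128-token prompt
```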