diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 15fa063e5..73c85daca 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -307,10 +307,11 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range, max_num_bat # filter rules for buckets # prompt def not_over_max_model_len(bs, query, ctx): - smaller_than_limit = (query + ctx * block_size) <= max_model_len + smaller_than_limit = (query + ctx * block_size) <= max_model_len + block_size * (max_num_prefill_seqs - 1) if not smaller_than_limit: omitted_buckets.add( - ("condition: (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx)) + ("condition: (query + ctx * block_size) <= max_model_len + block_size * max_num_prefill_seqs", + "-> bs, query, ctx: ", bs, query, ctx)) return smaller_than_limit def not_over_max_num_batched_tokens(bs, query, ctx):