Commit 670aaa3

[Bug fix] Fix pd for x1 thinking (#4433)
1 parent 8e392f0 commit 670aaa3

File tree

5 files changed: +14 −8 lines changed

fastdeploy/engine/common_engine.py
fastdeploy/engine/request.py
fastdeploy/engine/sched/resource_manager_v1.py
fastdeploy/model_executor/pre_and_post_process.py
fastdeploy/scheduler/dp_scheduler.py

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 3 deletions
@@ -697,9 +697,7 @@ def _fetch_request():
                     time.sleep(0.001)
                     continue
                 if self.cfg.scheduler_config.splitwise_role != "mixed":
-                    if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
-                        not is_fetching
-                    ):
+                    if not is_fetching:
                         get_request_pool.submit(_fetch_request)

                 else:
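
Note: in splitwise (prefill/decode-disaggregated) roles, the only remaining guard before submitting a prefetch task is whether one is already in flight; the FD_EP_MAX_PREFETCH_TASK_NUM threshold on unhandled requests no longer gates the submit. Below is a standalone sketch of the simplified gate, not FastDeploy code: maybe_dispatch, the lock, and the stubbed fetch body are illustrative only.

# Standalone sketch of the simplified prefetch gate; names below are invented
# for illustration (the real _fetch_request pulls requests from the scheduler
# and resets is_fetching itself).
import threading
import time
from concurrent.futures import ThreadPoolExecutor

is_fetching = False
lock = threading.Lock()
pool = ThreadPoolExecutor(max_workers=1)


def _fetch_request():
    global is_fetching
    try:
        time.sleep(0.01)  # stand-in for pulling a batch of requests
    finally:
        with lock:
            is_fetching = False


def maybe_dispatch(splitwise_role: str) -> None:
    """Submit a prefetch task unless one is already running."""
    global is_fetching
    if splitwise_role != "mixed":
        with lock:
            if not is_fetching:  # the only remaining condition after the fix
                is_fetching = True
                pool.submit(_fetch_request)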

fastdeploy/engine/request.py

Lines changed: 5 additions & 1 deletion
@@ -75,6 +75,7 @@ def __init__(
         structural_tag: Optional[Any] = None,
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
+        reasoning_max_tokens: Optional[int] = None,
         trace_carrier: dict = dict(),
         dp_rank: Optional[int] = None,
         chat_template: Optional[str] = None,
@@ -125,6 +126,7 @@ def __init__(
         self.multimodal_img_boundaries = None

         self.enable_thinking = enable_thinking
+        self.reasoning_max_tokens = reasoning_max_tokens
         self.trace_carrier = trace_carrier

         self.chat_template = chat_template
@@ -188,7 +190,8 @@ def from_dict(cls, d: dict):
             guided_grammar=d.get("guided_grammar", None),
             structural_tag=d.get("structural_tag", None),
             guided_json_object=d.get("guided_json_object", None),
-            enable_thinking=d.get("enable_thinking", True),
+            enable_thinking=d.get("enable_thinking", False),
+            reasoning_max_tokens=d.get("reasoning_max_tokens", None),
             trace_carrier=d.get("trace_carrier", {}),
             chat_template=d.get("chat_template", None),
             num_computed_tokens=d.get("num_computed_tokens", 0),
@@ -239,6 +242,7 @@ def to_dict(self) -> dict:
             "disaggregate_info": self.disaggregate_info,
             "draft_token_ids": self.draft_token_ids,
             "enable_thinking": self.enable_thinking,
+            "reasoning_max_tokens": self.reasoning_max_tokens,
             "trace_carrier": self.trace_carrier,
             "chat_template": self.chat_template,
             "num_computed_tokens": self.num_computed_tokens,

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 2 additions & 0 deletions
@@ -796,6 +796,8 @@ def preallocate_resource_in_d(self, request: Request):
             return False
         if self.available_batch() == 0:
             return False
+        if request.reasoning_max_tokens is not None:
+            request.reasoning_max_tokens -= 1
         request.need_prefill_tokens = len(request.prompt_token_ids)
         need_prealloc_prefill_blocks = (
             request.need_prefill_tokens + self.config.cache_config.block_size - 1
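
Note: when the decode instance preallocates resources for a request handed over from prefill, it now decrements reasoning_max_tokens by one if a budget is set. Reading the hunk alone, this looks like compensation for a token already accounted for on the prefill side, but the diff does not state the rationale. The sketch below only illustrates the guard-then-decrement pattern; DummyRequest and adjust_reasoning_budget are made-up names.

# Illustrative sketch of the decode-side budget adjustment; not FastDeploy code.
from typing import Optional


class DummyRequest:
    def __init__(self, reasoning_max_tokens: Optional[int] = None):
        self.reasoning_max_tokens = reasoning_max_tokens


def adjust_reasoning_budget(request: DummyRequest) -> None:
    # Guard on None first: requests without a reasoning budget stay "unlimited"
    # and are never turned into a numeric limit by the decrement.
    if request.reasoning_max_tokens is not None:
        request.reasoning_max_tokens -= 1


r = DummyRequest(reasoning_max_tokens=128)
adjust_reasoning_budget(r)
assert r.reasoning_max_tokens == 127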

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 1 addition & 4 deletions
@@ -216,10 +216,7 @@ def post_process_normal(
             model_output.reasoning_index,
         )

-        stop_wo_think = (
-            (sampler_output.sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
-            | (model_output.reasoning_index == 0)
-        ) & (model_output.need_think_end > 0)
+        stop_wo_think = ((model_output.reasoning_index == 0)) & (model_output.need_think_end > 0)

         stop_wo_think = stop_wo_think & thinking_mask
         sampler_output.sampled_token_ids = paddle.where(
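
Note: after this change, stop_wo_think fires only when the reasoning budget is exhausted (reasoning_index == 0) and the sequence still owes an end-of-thinking token; sampling an EOS token by itself no longer ends the thinking phase. The truncated paddle.where call presumably substitutes the sampled token where the mask is true, but that continuation is not shown in the hunk, so the NumPy sketch below illustrates the masking logic only, with invented shapes and token ids.

# NumPy stand-in for the paddle tensor logic; all values and think_end_id are
# invented for illustration.
import numpy as np

reasoning_index = np.array([[3], [0], [0]])   # remaining reasoning-token budget per sequence
need_think_end = np.array([[1], [1], [0]])    # 1 = sequence still owes an end-of-thinking token
thinking_mask = np.array([[True], [True], [True]])
sampled_token_ids = np.array([[101], [205], [307]])
think_end_id = 999

# Only an exhausted budget can now force the end of thinking; EOS alone cannot.
stop_wo_think = (reasoning_index == 0) & (need_think_end > 0) & thinking_mask
sampled_token_ids = np.where(stop_wo_think, think_end_id, sampled_token_ids)

# Row 1 becomes 999 (forced end-of-thinking); rows 0 and 2 are unchanged.
print(sampled_token_ids)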

fastdeploy/scheduler/dp_scheduler.py

Lines changed: 5 additions & 0 deletions
@@ -174,13 +174,18 @@ def get_requests(
             ):
                 break
             else:
+                required_total_blocks = 0
                 batch_ids = self.requests_not_empty.wait_for(
                     lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
                     0.005,
                 )
                 if batch_ids:
                     for request_id in batch_ids:
                         request = self.requests[request_id]
+                        required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
+                        required_total_blocks += required_input_blocks + reserved_output_blocks
+                        if required_total_blocks > available_blocks:
+                            break
                         requests.append(request.raw)
                         self.ids_read_cursor += 1
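
Note: the added bookkeeping turns this branch of get_requests into a budgeted pull. Each candidate request contributes the KV-cache blocks needed for its prompt plus a reserved output allowance, and the loop stops before admitting a request that would push the running total past available_blocks. The sketch below shows that loop in isolation; PendingRequest, take_requests, and the ceil-division helper are simplified stand-ins, not the DPScheduler API.

# Standalone sketch of the block-budgeting check; names and values are illustrative.
import math
from dataclasses import dataclass


@dataclass
class PendingRequest:
    request_id: str
    prompt_tokens_len: int


def calc_required_blocks(token_len: int, block_size: int) -> int:
    # Blocks needed to hold token_len tokens in a KV cache paged by block_size.
    return math.ceil(token_len / block_size)


def take_requests(pending, block_size, reserved_output_blocks, available_blocks):
    taken = []
    required_total_blocks = 0
    for req in pending:
        required_input_blocks = calc_required_blocks(req.prompt_tokens_len, block_size)
        required_total_blocks += required_input_blocks + reserved_output_blocks
        if required_total_blocks > available_blocks:
            # Admitting this request would overrun the cache budget: stop here
            # and leave it (and everything after it) for a later round.
            break
        taken.append(req)
    return taken


batch = [PendingRequest("a", 1000), PendingRequest("b", 4000), PendingRequest("c", 500)]
print([r.request_id for r in take_requests(batch, block_size=64, reserved_output_blocks=8, available_blocks=60)])
# -> ['a']: request "b" would exceed the budget, so the pull stops before it.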
