1 change: 1 addition & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -251,6 +251,7 @@ void initBindings(pybind11::module_& m)
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
.def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
.def_property_readonly(
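
The hunk above only adds a read-only pybind11 property. A minimal sketch of how the new flag could be consumed on the Python side, assuming `req` is any bound LlmRequest instance; the classify_request helper is hypothetical and not part of this PR:

def classify_request(req):
    """Hypothetical helper: label a request by its disaggregated-serving role."""
    if req.is_context_only_request:
        return "context"      # prefill-only request handled by the context server
    if req.is_generation_only_request:
        return "generation"   # decode-only request handled by the generation server
    return "monolithic"       # request that runs both phases on one server
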
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/decoder.py
@@ -281,6 +281,7 @@ def update_requests(self, scheduled_requests: ScheduledRequests,
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[idx]
num_tokens = request.add_new_token(new_token, beam_idx)
request.decoding_iter += 1
self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx)
request.py_decoding_iter += 1
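
For context, a minimal self-contained sketch of the bookkeeping this hunk relies on: the per-request iteration counter advances once per generated token, which is what later lets the executor recognize the first decoding iteration. The _Request class is a stand-in, not the real LlmRequest:

class _Request:
    """Stand-in for LlmRequest, tracking only what this sketch needs."""

    def __init__(self):
        self.tokens = []
        self.py_decoding_iter = 0

    def add_new_token(self, token):
        self.tokens.append(token)
        return len(self.tokens)


def on_new_token(request, new_token):
    num_tokens = request.add_new_token(new_token)
    request.py_decoding_iter += 1   # iteration 1 corresponds to the first generated token
    return num_tokens
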
11 changes: 10 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1401,7 +1401,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
if req.is_disagg_generation_transmission_complete:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.context_current_position = req.prompt_len
req.decoding_iter = 1
req.py_decoding_iter = 1
first_gen_tokens = req.context_phase_params.first_gen_tokens
req.py_draft_tokens = req.context_phase_params.draft_tokens
beam_width = req.sampling_config.beam_width
@@ -1821,6 +1821,15 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

# Unify the behavior of the overlapping and non-overlapping cases:
# return the response only after we have the second generated token
if self.kv_cache_transceiver is not None:
if request.state == LlmRequestState.GENERATION_IN_PROGRESS and \
request.is_generation_only_request:
if request.py_decoding_iter == 1:
new_active_requests.append(request)
continue

request.draft_tokens = request.py_draft_tokens
request.decoding_iter = request.py_decoding_iter
response = request.create_response(False, self.dist.rank)
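
A minimal sketch of the gating rule introduced above, assuming the surrounding executor loop; GENERATION_IN_PROGRESS is a stand-in for the LlmRequestState enum value, and should_defer_response is a hypothetical helper, not TensorRT-LLM API:

GENERATION_IN_PROGRESS = "GENERATION_IN_PROGRESS"  # stand-in for LlmRequestState

def should_defer_response(request, kv_cache_transceiver):
    """True when a disaggregated generation-only request is still on its first
    decoding iteration, so its response is held back until the second token and
    the overlapping and non-overlapping schedulers surface tokens at the same point."""
    return (kv_cache_transceiver is not None
            and request.state == GENERATION_IN_PROGRESS
            and request.is_generation_only_request
            and request.py_decoding_iter == 1)

In the loop above this corresponds to appending the request to new_active_requests and skipping response creation for that iteration.
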
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/proxy.py
@@ -367,7 +367,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self)
executor=self,
disaggregated_params=request.disaggregated_params)
self._results[request.id] = result

self.request_queue.put(request)
13 changes: 8 additions & 5 deletions tensorrt_llm/executor/result.py
@@ -179,7 +179,7 @@ def _handle_sequence(self, finish_reasons, response_tensors,
# Skip outputting the first generated token in the generation response
# TODO: We should have a better way to handle this when enabling
# beam search with PD.
if not self.sampling_params.use_beam_search and \
if self.disaggregated_params is not None and \
len(response_tensors.output_token_ids[src_idx]) == 2:
output._last_token_ids_len = 1
Comment on lines +182 to 184

Collaborator:
I don't think we should rely on the len of output_token_ids since for spec decoding, we could have 2 tokens even after the first gen token. Can you have a look at: #3427. I think it's a more general fix.

Collaborator Author:
Thank you. I also think that should be a good solution. I will close this PR.


@@ -352,10 +352,12 @@ class GenerationResult(GenerationResultBase):
executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
'''

def __init__(self,
generation_request: "GenerationRequest",
background_error_handler: Optional[Callable] = None,
executor: Optional["GenerationExecutor"] = None) -> None:
def __init__(
self,
generation_request: "GenerationRequest",
background_error_handler: Optional[Callable] = None,
executor: Optional["GenerationExecutor"] = None,
disaggregated_params: Optional[DisaggregatedParams] = None) -> None:
super().__init__(
generation_request.id,
generation_request.sampling_params,
@@ -364,6 +366,7 @@ def __init__(self,
)
self._generation_request = generation_request
self._streaming = generation_request.streaming
self.disaggregated_params = disaggregated_params

# for aborting the request
self._executor: Optional[weakref.ReferenceType[
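
To illustrate the reviewer's objection above: the len(...) == 2 heuristic only distinguishes the first generation response when exactly one token is produced per step. A rough sketch with made-up token values, assuming a speculative decoding step that accepts one draft token:

first_response_tokens = [11, 22]       # first gen token from the context phase plus one new token
later_spec_step_tokens = [33, 44]      # a later step that accepted one draft token also carries 2

def looks_like_first_response(output_token_ids):
    return len(output_token_ids) == 2  # the heuristic used in the hunk above

assert looks_like_first_response(first_response_tokens)
assert looks_like_first_response(later_spec_step_tokens)   # false positive under spec decoding
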
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/worker.py
@@ -419,7 +419,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self)
executor=self,
disaggregated_params=request.disaggregated_params)

self._results[client_id] = result

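
The proxy and the worker make the same change: submit() now forwards the request's disaggregated_params into the GenerationResult so that _handle_sequence can recognize disaggregated responses. A condensed sketch with simplified stand-in classes rather than the real executor types:

from typing import Any, Optional

class _GenerationResult:                      # simplified stand-in for GenerationResult
    def __init__(self, generation_request, executor=None,
                 disaggregated_params: Optional[Any] = None):
        self._generation_request = generation_request
        self._executor = executor
        self.disaggregated_params = disaggregated_params

def submit(executor, request):
    # Forward the request's disaggregated params so the result can detect the
    # "first generation token already produced by the context phase" case.
    return _GenerationResult(
        request,
        executor=executor,
        disaggregated_params=request.disaggregated_params)
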
18 changes: 16 additions & 2 deletions tests/integration/defs/disaggregated/sanity_check.sh
@@ -1,5 +1,6 @@
#!/bin/bash
set -x
set -e
pkill -9 -f trtllm-serve || true

rm -rf output.json || true
@@ -116,7 +117,20 @@ if [[ "${TEST_DESC}" != "gen_only" ]]; then
fi

for expected_string in "${expected_strings[@]}"; do
grep "${expected_string}" output.json
grep "${expected_string}" output_streaming.json
grep "${expected_string}" output.json
grep "${expected_string}" output_streaming.json

# Check that the first token is not duplicated in the streaming output for ds-v3-lite
if [[ "${TEST_DESC}" =~ "deepseek_v3_lite" ]]; then
if [ "$expected_string" != "${expected_strings[0]}" ]; then
continue
fi

first_word=$(echo "$expected_string" | awk '{print $1}')
count=$(grep -o "$first_word" output_streaming.json | wc -l)
if [ "$count" -ne 1 ]; then
exit 1
fi
fi
done
fi
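
The new shell check guards against the first streamed token being emitted twice. The same assertion, written as a small Python sketch; the file path and expected word are placeholders:

def assert_first_token_not_duplicated(streaming_output_path, expected_first_word):
    """Fail if the first expected word appears more than once in the streamed output."""
    with open(streaming_output_path, "r") as f:
        text = f.read()
    occurrences = text.count(expected_first_word)
    if occurrences != 1:
        raise AssertionError(
            f"expected exactly one occurrence of {expected_first_word!r}, got {occurrences}")
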