diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index f48d048826b..2bd589c9e94 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -251,6 +251,7 @@ void initBindings(pybind11::module_& m) "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams) .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest) + .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished) .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState) .def_property_readonly( diff --git a/tensorrt_llm/_torch/pyexecutor/decoder.py b/tensorrt_llm/_torch/pyexecutor/decoder.py index 795683bde87..e27d0b2fbc2 100644 --- a/tensorrt_llm/_torch/pyexecutor/decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/decoder.py @@ -281,6 +281,7 @@ def update_requests(self, scheduled_requests: ScheduledRequests, if request.state != LlmRequestState.GENERATION_COMPLETE: new_token = new_tokens_list[idx] num_tokens = request.add_new_token(new_token, beam_idx) + request.decoding_iter += 1 self._handle_stop_criteria(request, new_token, num_tokens, beam_idx) request.py_decoding_iter += 1 diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index c0c66dd1e81..ccbe1c824de 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1401,7 +1401,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch): if req.is_disagg_generation_transmission_complete: req.state = LlmRequestState.GENERATION_IN_PROGRESS req.context_current_position = req.prompt_len - req.decoding_iter = 1 + req.py_decoding_iter = 1 first_gen_tokens = req.context_phase_params.first_gen_tokens req.py_draft_tokens = req.context_phase_params.draft_tokens beam_width = req.sampling_config.beam_width @@ -1821,6 +1821,15 @@ def _handle_responses(self): requests_to_terminate.append(request) continue + # Unify the behavior of overlapping and non-overlapping cases + # return the response after we have second generated token + if self.kv_cache_transceiver is not None: + if request.state == LlmRequestState.GENERATION_IN_PROGRESS and \ + request.is_generation_only_request: + if request.py_decoding_iter == 1: + new_active_requests.append(request) + continue + request.draft_tokens = request.py_draft_tokens request.decoding_iter = request.py_decoding_iter response = request.create_response(False, self.dist.rank) diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 8b5bbd804d9..0dc4378bb20 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -367,7 +367,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult: result = GenerationResult( request, background_error_handler=self._handle_background_error, - executor=self) + executor=self, + disaggregated_params=request.disaggregated_params) self._results[request.id] = result self.request_queue.put(request) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 39809f143bf..23d3e895bdc 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -179,7 +179,7 @@ def _handle_sequence(self, finish_reasons, response_tensors, # Skip output the first generated token in generation response # TODO: We should have a better way to handle this when enable # beam search with PD. - if not self.sampling_params.use_beam_search and \ + if self.disaggregated_params is not None and \ len(response_tensors.output_token_ids[src_idx]) == 2: output._last_token_ids_len = 1 @@ -352,10 +352,12 @@ class GenerationResult(GenerationResultBase): executor (GenerationExecutor, optional): The executor that created this result. Defaults to None. ''' - def __init__(self, - generation_request: "GenerationRequest", - background_error_handler: Optional[Callable] = None, - executor: Optional["GenerationExecutor"] = None) -> None: + def __init__( + self, + generation_request: "GenerationRequest", + background_error_handler: Optional[Callable] = None, + executor: Optional["GenerationExecutor"] = None, + disaggregated_params: Optional[DisaggregatedParams] = None) -> None: super().__init__( generation_request.id, generation_request.sampling_params, @@ -364,6 +366,7 @@ def __init__(self, ) self._generation_request = generation_request self._streaming = generation_request.streaming + self.disaggregated_params = disaggregated_params # for aborting the request self._executor: Optional[weakref.ReferenceType[ diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 452a0dee138..7351b8465bd 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -419,7 +419,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult: result = GenerationResult( request, background_error_handler=self._handle_background_error, - executor=self) + executor=self, + disaggregated_params=request.disaggregated_params) self._results[client_id] = result diff --git a/tests/integration/defs/disaggregated/sanity_check.sh b/tests/integration/defs/disaggregated/sanity_check.sh index d2919f528ac..6886a58f8e6 100644 --- a/tests/integration/defs/disaggregated/sanity_check.sh +++ b/tests/integration/defs/disaggregated/sanity_check.sh @@ -1,5 +1,6 @@ #!/bin/bash set -x +set -e pkill -9 -f trtllm-serve || true rm -rf output.json || true @@ -116,7 +117,20 @@ if [[ "${TEST_DESC}" != "gen_only" ]]; then fi for expected_string in "${expected_strings[@]}"; do - grep "${expected_string}" output.json - grep "${expected_string}" output_streaming.json + grep "${expected_string}" output.json + grep "${expected_string}" output_streaming.json + + # Check the double first token in streaming output for ds-v3-lite + if [[ "${TEST_DESC}" =~ "deepseek_v3_lite" ]]; then + if [ "$expected_string" != "${expected_strings[0]}" ]; then + continue + fi + + first_word=$(echo "$expected_string" | awk '{print $1}') + count=$(grep -o "$first_word" output_streaming.json | wc -l) + if [ "$count" -ne 1 ]; then + exit 1 + fi + fi done fi