diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index c119645677..7341f6122a 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -251,6 +251,7 @@ void initBindings(pybind11::module_& m)
             "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
         .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
+        .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
         .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
         .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
         .def_property_readonly(
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 8df2aad1d7..2b1578d28b 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -161,6 +161,7 @@ def __init__(self,
         self.enable_attention_dp = model_engine.enable_attention_dp
         self.decoder = decoder
         self.dist = dist
+        self.enable_overlap_scheduler = enable_overlap_scheduler

         # Draft model for certain spec decode algorithms, e.g. EAGLE3
         self.draft_model_engine = draft_model_engine
@@ -1441,6 +1442,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             req.state = LlmRequestState.GENERATION_IN_PROGRESS
             req.context_current_position = req.prompt_len
             req.decoding_iter = 1
+            req.py_decoding_iter = 1
             first_gen_tokens = req.context_phase_params.first_gen_tokens
             req.py_draft_tokens = req.context_phase_params.draft_tokens
             beam_width = req.sampling_config.beam_width
@@ -1866,6 +1868,15 @@ def _handle_responses(self):
                     requests_to_terminate.append(request)
                     continue

+                if request.is_generation_only_request:
+                    # If the request is still in transmission, we don't need to emit a response yet.
+                    # Also, with the overlap scheduler, skip the first iteration, since the first token has already been emitted by the context server.
+                    if request.is_disagg_generation_transmission_in_progress or (
+                            self.enable_overlap_scheduler
+                            and request.py_decoding_iter <= 1):
+                        new_active_requests.append(request)
+                        continue
+
                 request.draft_tokens = request.py_draft_tokens
                 request.decoding_iter = request.py_decoding_iter
                 response = request.create_response(False, self.dist.rank)
diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py
index 5acb5b71fb..bf30c6c754 100644
--- a/tensorrt_llm/executor/proxy.py
+++ b/tensorrt_llm/executor/proxy.py
@@ -369,7 +369,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
-            executor=self)
+            executor=self,
+            disaggregated_params=request.disaggregated_params)

         self._results[request.id] = result
         self.request_queue.put(request)
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index 39809f143b..a3a337863d 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -175,12 +175,8 @@ def _handle_sequence(self, finish_reasons, response_tensors,
         else:
             output.token_ids.extend(
                 response_tensors.output_token_ids[src_idx])
-            # In PD, the first generation response will return 2 tokens
-            # Skip output the first generated token in generation response
-            # TODO: We should have a better way to handle this when enable
-            # beam search with PD.
-            if not self.sampling_params.use_beam_search and \
-                len(response_tensors.output_token_ids[src_idx]) == 2:
+            # In PD streaming mode, ignore the first token, since it has already been returned by the context server.
+            if self.disaggregated_params is not None and self.disaggregated_params.request_type == "generation_only" and self._streaming and self.decoding_iter == 2:
                 output._last_token_ids_len = 1

         if response_tensors.cum_log_probs is not None:
@@ -352,10 +348,12 @@ class GenerationResult(GenerationResultBase):
         executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
     '''

-    def __init__(self,
-                 generation_request: "GenerationRequest",
-                 background_error_handler: Optional[Callable] = None,
-                 executor: Optional["GenerationExecutor"] = None) -> None:
+    def __init__(
+            self,
+            generation_request: "GenerationRequest",
+            background_error_handler: Optional[Callable] = None,
+            executor: Optional["GenerationExecutor"] = None,
+            disaggregated_params: Optional[DisaggregatedParams] = None) -> None:
         super().__init__(
             generation_request.id,
             generation_request.sampling_params,
@@ -364,6 +362,7 @@ def __init__(self,
         )
         self._generation_request = generation_request
         self._streaming = generation_request.streaming
+        self.disaggregated_params = disaggregated_params

         # for aborting the request
         self._executor: Optional[weakref.ReferenceType[
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index d045f2dfce..09759fede5 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -417,7 +417,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
-            executor=self)
+            executor=self,
+            disaggregated_params=request.disaggregated_params)

         self._results[client_id] = result

diff --git a/tests/integration/defs/disaggregated/sanity_check.sh b/tests/integration/defs/disaggregated/sanity_check.sh
index b419434000..566be3123c 100644
--- a/tests/integration/defs/disaggregated/sanity_check.sh
+++ b/tests/integration/defs/disaggregated/sanity_check.sh
@@ -117,6 +117,7 @@ fi

 if [[ "${TEST_DESC}" != "gen_only" ]]; then
     expected_strings=("The capital of Germany is Berlin" "Asyncio is a Python library")
+    not_expected_strings=("Berlin Berlin")
    if [[ "${TEST_DESC}" =~ "deepseek_v3_lite" ]]; then
         expected_strings=("Berlin" "Asyncio is a")
     fi
@@ -130,4 +131,10 @@
             grep "${expected_string}" output_streaming_chat.json
         fi
     done
+
+    for not_expected_string in "${not_expected_strings[@]}"; do
+        ! grep -q "${not_expected_string}" output.json
+        ! grep -q "${not_expected_string}" output_streaming.json
+    done
+
 fi
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml
index 565764ec69..dcd6db9f9d 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml
@@ -1,8 +1,8 @@
 hostname: localhost
 port: 8000
 model: DeepSeek-V3-Lite/fp8
-free_gpu_memory_fraction: 0.25
 backend: "pytorch"
+free_gpu_memory_fraction: 0.2
 context_servers:
   num_instances: 1
   tensor_parallel_size: 2
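
Note: below is a minimal, self-contained sketch of the response gating this patch adds to _handle_responses. The _Req dataclass and should_emit_response helper are hypothetical illustrations, not TensorRT-LLM APIs; the sketch assumes py_decoding_iter counts generation iterations from 1, as set in _prepare_disagg_gen_transmission_complete.

from dataclasses import dataclass


@dataclass
class _Req:
    # Hypothetical stand-in for the LlmRequest properties used above.
    is_generation_only_request: bool
    is_disagg_generation_transmission_in_progress: bool
    py_decoding_iter: int


def should_emit_response(req: _Req, enable_overlap_scheduler: bool) -> bool:
    """Mirror the new skip conditions: stay silent while the disaggregated
    transmission is in flight, and skip the first overlap-scheduler
    iteration, whose token the context server already returned."""
    if not req.is_generation_only_request:
        return True
    if req.is_disagg_generation_transmission_in_progress:
        return False
    if enable_overlap_scheduler and req.py_decoding_iter <= 1:
        return False
    return True


# The first overlap iteration of a generation-only request is suppressed,
# which avoids the duplicated first token ("Berlin Berlin") that the
# sanity check now greps against; the second iteration emits normally.
assert not should_emit_response(_Req(True, False, 1), enable_overlap_scheduler=True)
assert should_emit_response(_Req(True, False, 2), enable_overlap_scheduler=True)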