Merged
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -251,6 +251,7 @@ void initBindings(pybind11::module_& m)
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
.def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
.def_property_readonly(
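The new binding mirrors the existing `is_context_only_request` property, so Python-side executor code can now distinguish context-only, generation-only, and regular requests. A minimal sketch of such a check, assuming a request object that exposes these properties (the helper and the stand-in class below are illustrative, not part of the library):

```python
# Illustrative only: a stand-in exposing the same read-only properties the
# pybind11 bindings add to GenLlmReq. Property names match the bindings;
# everything else here is hypothetical.
from dataclasses import dataclass

@dataclass
class FakeRequest:
    is_context_only_request: bool = False
    is_generation_only_request: bool = False  # newly exposed by this change

def classify_request(req) -> str:
    """Label a request by its role in disaggregated (PD) serving."""
    if req.is_context_only_request:
        return "context_only"      # prefill/context phase only
    if req.is_generation_only_request:
        return "generation_only"   # decode/generation phase only
    return "full"                  # regular, non-disaggregated request

assert classify_request(FakeRequest(is_generation_only_request=True)) == "generation_only"
```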
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -161,6 +161,7 @@ def __init__(self,
self.enable_attention_dp = model_engine.enable_attention_dp
self.decoder = decoder
self.dist = dist
self.enable_overlap_scheduler = enable_overlap_scheduler

# Draft model for certain spec decode algorithms, e.g. EAGLE3
self.draft_model_engine = draft_model_engine
@@ -1441,6 +1442,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.context_current_position = req.prompt_len
req.decoding_iter = 1
req.py_decoding_iter = 1
first_gen_tokens = req.context_phase_params.first_gen_tokens
req.py_draft_tokens = req.context_phase_params.draft_tokens
beam_width = req.sampling_config.beam_width
@@ -1866,6 +1868,15 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

if request.is_generation_only_request:
# If the request is still in transmission, we don't need to emit a response yet.
# Also, on the first iteration with the overlap scheduler, skip the response since the first token has already been emitted by the context server.
if request.is_disagg_generation_transmission_in_progress or (
self.enable_overlap_scheduler
and request.py_decoding_iter <= 1):
new_active_requests.append(request)
continue

request.draft_tokens = request.py_draft_tokens
request.decoding_iter = request.py_decoding_iter
response = request.create_response(False, self.dist.rank)
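The block above gates response emission for generation-only requests. A free-standing sketch of that predicate, assuming the same request attributes as in the diff (in the actual change the logic is inlined in `_handle_responses`, not a separate helper):

```python
def should_skip_response(request, enable_overlap_scheduler: bool) -> bool:
    """Sketch of the gating added for generation-only (disaggregated) requests.

    A response is withheld while the KV cache is still being transmitted and,
    when the overlap scheduler is enabled, on the first decoding iteration,
    because the context server has already emitted the first token.
    """
    if not request.is_generation_only_request:
        return False
    return (request.is_disagg_generation_transmission_in_progress
            or (enable_overlap_scheduler and request.py_decoding_iter <= 1))
```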
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/proxy.py
@@ -369,7 +369,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self)
executor=self,
disaggregated_params=request.disaggregated_params)
self._results[request.id] = result

self.request_queue.put(request)
19 changes: 9 additions & 10 deletions tensorrt_llm/executor/result.py
@@ -175,12 +175,8 @@ def _handle_sequence(self, finish_reasons, response_tensors,
else:
output.token_ids.extend(response_tensors.output_token_ids[src_idx])

# In PD, the first generation response will return 2 tokens
# Skip output the first generated token in generation response
# TODO: We should have a better way to handle this when enable
# beam search with PD.
if not self.sampling_params.use_beam_search and \
len(response_tensors.output_token_ids[src_idx]) == 2:
# In PD, the first token should be ignored in streaming mode, since it's already been returned by the context server
if self.disaggregated_params is not None and self.disaggregated_params.request_type == "generation_only" and self._streaming and self.decoding_iter == 2:
output._last_token_ids_len = 1

if response_tensors.cum_log_probs is not None:
@@ -352,10 +348,12 @@ class GenerationResult(GenerationResultBase):
executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
'''

def __init__(self,
generation_request: "GenerationRequest",
background_error_handler: Optional[Callable] = None,
executor: Optional["GenerationExecutor"] = None) -> None:
def __init__(
self,
generation_request: "GenerationRequest",
background_error_handler: Optional[Callable] = None,
executor: Optional["GenerationExecutor"] = None,
disaggregated_params: Optional[DisaggregatedParams] = None) -> None:
super().__init__(
generation_request.id,
generation_request.sampling_params,
@@ -364,6 +362,7 @@ def __init__(self,
)
self._generation_request = generation_request
self._streaming = generation_request.streaming
self.disaggregated_params = disaggregated_params

# for aborting the request
self._executor: Optional[weakref.ReferenceType[
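In disaggregated streaming, the generation server's first response carries the token the context server already returned plus the newly decoded one, so only the newest token should reach the client. A rough illustration of that trimming (the function and its arguments are made up for clarity; the real code achieves the same effect by setting `output._last_token_ids_len = 1`):

```python
def tokens_to_stream(output_token_ids: list,
                     is_first_gen_only_streaming_response: bool) -> list:
    """Drop the duplicated first token from the generation server's first
    streaming response in PD mode; otherwise pass tokens through unchanged."""
    if is_first_gen_only_streaming_response:
        return output_token_ids[-1:]  # keep only the newly generated token
    return output_token_ids

# Example: the context server already returned token 11, so only 12 is new here.
assert tokens_to_stream([11, 12], True) == [12]
assert tokens_to_stream([11, 12], False) == [11, 12]
```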
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/worker.py
@@ -417,7 +417,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self)
executor=self,
disaggregated_params=request.disaggregated_params)

self._results[client_id] = result

7 changes: 7 additions & 0 deletions tests/integration/defs/disaggregated/sanity_check.sh
@@ -117,6 +117,7 @@ fi

if [[ "${TEST_DESC}" != "gen_only" ]]; then
expected_strings=("The capital of Germany is Berlin" "Asyncio is a Python library")
not_expected_strings=("Berlin Berlin")
if [[ "${TEST_DESC}" =~ "deepseek_v3_lite" ]]; then
expected_strings=("Berlin" "Asyncio is a")
fi
@@ -130,4 +131,10 @@ if [[ "${TEST_DESC}" != "gen_only" ]]; then
grep "${expected_string}" output_streaming_chat.json
fi
done

for not_expected_string in "${not_expected_strings[@]}"; do
grep -v "${not_expected_string}" output.json
grep -v "${not_expected_string}" output_streaming.json
done

fi
@@ -1,8 +1,8 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/fp8
free_gpu_memory_fraction: 0.25
backend: "pytorch"
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
tensor_parallel_size: 2