
fix: Fixing issue with first gen token being returned twice in streaming #3427


Merged · 9 commits · Apr 14, 2025
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -251,6 +251,7 @@ void initBindings(pybind11::module_& m)
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
.def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
.def_property_readonly(
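For reference, the new property sits alongside the existing request-state accessors on the Python side. A minimal sketch of how it might be consumed (the classify_request helper below is hypothetical, not part of this PR; it only assumes a bound LlmRequest object req):

    # Hypothetical helper: route a request by its disaggregated role using
    # the read-only properties exposed by these bindings.
    def classify_request(req) -> str:
        if req.is_context_only_request:
            return "context_only"      # prefill handled by the context server
        if req.is_generation_only_request:
            return "generation_only"   # decode handled by the generation server
        return "regular"               # conventional, non-disaggregated request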
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -161,6 +161,7 @@ def __init__(self,
self.enable_attention_dp = model_engine.enable_attention_dp
self.decoder = decoder
self.dist = dist
+ self.enable_overlap_scheduler = enable_overlap_scheduler

# Draft model for certain spec decode algorithms, e.g. EAGLE3
self.draft_model_engine = draft_model_engine
@@ -1441,6 +1442,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.context_current_position = req.prompt_len
req.decoding_iter = 1
+ req.py_decoding_iter = 1
first_gen_tokens = req.context_phase_params.first_gen_tokens
req.py_draft_tokens = req.context_phase_params.draft_tokens
beam_width = req.sampling_config.beam_width
@@ -1866,6 +1868,15 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

+ if request.is_generation_only_request:
+     # If the request is still in transmission, we don't need to emit a response
+     # Also, with the overlap scheduler, skip the first iteration since the first token has already been emitted by the context server
+     if request.is_disagg_generation_transmission_in_progress or (
+             self.enable_overlap_scheduler
+             and request.py_decoding_iter <= 1):
+         new_active_requests.append(request)
+         continue

request.draft_tokens = request.py_draft_tokens
request.decoding_iter = request.py_decoding_iter
response = request.create_response(False, self.dist.rank)
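The guard above is the heart of the fix on the executor side. A hedged restatement as a standalone predicate (illustrative only; the helper name and signature are not part of the PR):

    # Illustrative restatement of the new guard in _handle_responses.
    def should_withhold_response(request, enable_overlap_scheduler: bool) -> bool:
        if not request.is_generation_only_request:
            return False
        # The disaggregated transmission from the context server is still in flight.
        if request.is_disagg_generation_transmission_in_progress:
            return True
        # With the overlap scheduler, the first decoding iteration carries the
        # token the context server already emitted, so it is suppressed too.
        return enable_overlap_scheduler and request.py_decoding_iter <= 1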
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/proxy.py
@@ -369,7 +369,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
- executor=self)
+ executor=self,
+ disaggregated_params=request.disaggregated_params)
self._results[request.id] = result

self.request_queue.put(request)
19 changes: 9 additions & 10 deletions tensorrt_llm/executor/result.py
@@ -175,12 +175,8 @@ def _handle_sequence(self, finish_reasons, response_tensors,
else:
output.token_ids.extend(response_tensors.output_token_ids[src_idx])

- # In PD, the first generation response will return 2 tokens
- # Skip output the first generated token in generation response
- # TODO: We should have a better way to handle this when enable
- # beam search with PD.
- if not self.sampling_params.use_beam_search and \
-         len(response_tensors.output_token_ids[src_idx]) == 2:
+ # In PD, the first token should be ignored in streaming mode, since it's already been returned by the context server
+ if self.disaggregated_params is not None and self.disaggregated_params.request_type == "generation_only" and self._streaming and self.decoding_iter == 2:
output._last_token_ids_len = 1

if response_tensors.cum_log_probs is not None:
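To make the off-by-one concrete: in disaggregated serving the context server streams the first generated token itself, so the generation server's first surfaced chunk (decoding_iter == 2) would repeat it. A hedged sketch with invented token ids (the helper is illustrative, not the library's API):

    # Illustrative only; token ids are made up.
    def dedup_first_gen_chunk(chunk, decoding_iter, is_pd_generation, streaming):
        """Drop the leading token of the first streamed generation chunk,
        since the context server already returned it."""
        if is_pd_generation and streaming and decoding_iter == 2:
            return chunk[1:]
        return chunk

    assert dedup_first_gen_chunk([17, 42], 2, True, True) == [42]  # duplicate dropped
    assert dedup_first_gen_chunk([99], 3, True, True) == [99]      # later chunks untouched

Unlike the old heuristic, which keyed on a chunk length of exactly two tokens and could misfire, the new condition identifies the duplicate by the request type carried in disaggregated_params.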
@@ -352,10 +348,12 @@ class GenerationResult(GenerationResultBase):
executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
'''

- def __init__(self,
-              generation_request: "GenerationRequest",
-              background_error_handler: Optional[Callable] = None,
-              executor: Optional["GenerationExecutor"] = None) -> None:
+ def __init__(
+         self,
+         generation_request: "GenerationRequest",
+         background_error_handler: Optional[Callable] = None,
+         executor: Optional["GenerationExecutor"] = None,
+         disaggregated_params: Optional[DisaggregatedParams] = None) -> None:
super().__init__(
generation_request.id,
generation_request.sampling_params,
@@ -364,6 +362,7 @@ def __init__(self,
)
self._generation_request = generation_request
self._streaming = generation_request.streaming
+ self.disaggregated_params = disaggregated_params

# for aborting the request
self._executor: Optional[weakref.ReferenceType[
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/worker.py
@@ -417,7 +417,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
- executor=self)
+ executor=self,
+ disaggregated_params=request.disaggregated_params)

self._results[client_id] = result

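Both submit() call sites (proxy.py above and worker.py here) now thread the request's disaggregated parameters into the result; this is what lets _handle_sequence in result.py test request_type. A minimal sketch, assuming the DisaggregatedParams import path (the request_type value comes straight from the diff):

    from tensorrt_llm.disaggregated_params import DisaggregatedParams  # path assumed

    # A generation-only request on the client side would carry:
    params = DisaggregatedParams(request_type="generation_only")
    # submit() forwards request.disaggregated_params unchanged, so downstream:
    #   result.disaggregated_params.request_type == "generation_only"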
7 changes: 7 additions & 0 deletions tests/integration/defs/disaggregated/sanity_check.sh
@@ -117,6 +117,7 @@ fi

if [[ "${TEST_DESC}" != "gen_only" ]]; then
expected_strings=("The capital of Germany is Berlin" "Asyncio is a Python library")
+ not_expected_strings=("Berlin Berlin")
if [[ "${TEST_DESC}" =~ "deepseek_v3_lite" ]]; then
expected_strings=("Berlin" "Asyncio is a")
fi
@@ -130,4 +131,10 @@ if [[ "${TEST_DESC}" != "gen_only" ]]; then
grep "${expected_string}" output_streaming_chat.json
fi
done

+ for not_expected_string in "${not_expected_strings[@]}"; do
+     grep -v "${not_expected_string}" output.json
+     grep -v "${not_expected_string}" output_streaming.json
+ done

fi
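The new loop asserts that the duplicated-token marker never shows up in the outputs. A hedged Python equivalent of the same check (file names and the marker string are taken from the script; the rest is illustrative):

    # Python equivalent of the new shell assertion, for illustration.
    import pathlib

    for path in ("output.json", "output_streaming.json"):
        text = pathlib.Path(path).read_text()
        assert "Berlin Berlin" not in text, f"duplicated first token in {path}"

Note that grep -v exits zero whenever at least one line lacks the pattern, so the Python form above is the stricter assertion.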
@@ -1,8 +1,8 @@
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/fp8
- free_gpu_memory_fraction: 0.25
backend: "pytorch"
+ free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
tensor_parallel_size: 2