Commit f146d77

fix: Fixing issue with first gen token being returned twice with streaming
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
1 parent 5616c0d commit f146d77
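
For background: in disaggregated (PD) serving, the context server already returns the first generated token as part of its response, so with streaming enabled the generation server's first response used to repeat that token and clients saw it twice. A minimal, purely illustrative sketch of the symptom (the names and strings below are hypothetical, chosen to match the "Berlin Berlin" check added to sanity_check.sh; this is not the TensorRT-LLM API):

# Illustration only: token streams as a client would concatenate them.
context_server_tokens = ["Berlin"]                                 # first token arrives with the context response
generation_server_stream = ["Berlin", " is", " the", " capital"]   # before the fix, it starts with the same token

print("".join(context_server_tokens + generation_server_stream))
# BerlinBerlin is the capital   <- first generated token duplicated

print("".join(context_server_tokens + generation_server_stream[1:]))
# Berlin is the capital         <- net effect after the fix: the duplicate is skipped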

8 files changed: +32 / -16 lines


cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
@@ -251,6 +251,7 @@ void initBindings(pybind11::module_& m)
             "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
         .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
+        .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
         .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
         .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
         .def_property_readonly(
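
The new binding mirrors the existing is_context_only_request property: def_property_readonly exposes the C++ predicate GenLlmReq::isGenerationOnlyRequest as a read-only attribute on the Python request object. A minimal sketch of how Python-side code can consume it (the request object and the handler below are assumed for illustration, not part of this diff):

# Sketch: detect generation-only (disaggregated) requests from Python.
if request.is_generation_only_request:        # backed by GenLlmReq::isGenerationOnlyRequest()
    defer_or_handle_disaggregated(request)    # hypothetical handler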

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 11 additions & 0 deletions
@@ -161,6 +161,7 @@ def __init__(self,
         self.enable_attention_dp = model_engine.enable_attention_dp
         self.decoder = decoder
         self.dist = dist
+        self.enable_overlap_scheduler = enable_overlap_scheduler
 
         # Draft model for certain spec decode algorithms, e.g. EAGLE3
         self.draft_model_engine = draft_model_engine
@@ -1424,6 +1425,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
            req.state = LlmRequestState.GENERATION_IN_PROGRESS
            req.context_current_position = req.prompt_len
            req.decoding_iter = 1
+           req.py_decoding_iter = 1
            first_gen_tokens = req.context_phase_params.first_gen_tokens
            req.py_draft_tokens = req.context_phase_params.draft_tokens
            beam_width = req.sampling_config.beam_width
@@ -1849,6 +1851,15 @@ def _handle_responses(self):
                requests_to_terminate.append(request)
                continue
 
+            if request.is_generation_only_request:
+                # If the request is still transmitting, we don't need to emit a response
+                # Also, for the first iteration with the overlap scheduler, skip since the first token has already been emitted by the context server
+                if request.is_disagg_generation_transmission_in_progress or (
+                        self.enable_overlap_scheduler
+                        and request.py_decoding_iter <= 1):
+                    new_active_requests.append(request)
+                    continue
+
            request.draft_tokens = request.py_draft_tokens
            request.decoding_iter = request.py_decoding_iter
            response = request.create_response(False, self.dist.rank)
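
Taken together, the three hunks thread the overlap-scheduler flag into the executor, start a generation-only request at py_decoding_iter = 1 once its KV-cache transmission completes, and suppress the response that would duplicate the token already sent by the context server. A compact restatement of the new skip condition as a standalone helper (the helper itself is hypothetical; the attribute names are the ones used in the diff):

def should_defer_response(request, enable_overlap_scheduler: bool) -> bool:
    """Sketch of the condition added to _handle_responses.

    A generation-only (disaggregated) request gets no response while its KV cache
    is still being transmitted, and with the overlap scheduler its first decoding
    iteration is skipped because the context server already emitted that token.
    """
    if not request.is_generation_only_request:
        return False
    if request.is_disagg_generation_transmission_in_progress:
        return True
    return enable_overlap_scheduler and request.py_decoding_iter <= 1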

tensorrt_llm/executor/proxy.py

Lines changed: 2 additions & 1 deletion
@@ -367,7 +367,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
-            executor=self)
+            executor=self,
+            disaggregated_params=request.disaggregated_params)
         self._results[request.id] = result
 
         self.request_queue.put(request)

tensorrt_llm/executor/result.py

Lines changed: 9 additions & 10 deletions
@@ -175,12 +175,8 @@ def _handle_sequence(self, finish_reasons, response_tensors,
         else:
             output.token_ids.extend(response_tensors.output_token_ids[src_idx])
 
-        # In PD, the first generation response will return 2 tokens
-        # Skip output the first generated token in generation response
-        # TODO: We should have a better way to handle this when enable
-        # beam search with PD.
-        if not self.sampling_params.use_beam_search and \
-            len(response_tensors.output_token_ids[src_idx]) == 2:
+        # In PD, the first token should be ignored in streaming mode, since it's already been returned by the context server
+        if self.disaggregated_params is not None and self.disaggregated_params.request_type == "generation_only" and self._streaming and self.decoding_iter == 2:
             output._last_token_ids_len = 1
 
         if response_tensors.cum_log_probs is not None:
@@ -352,10 +348,12 @@ class GenerationResult(GenerationResultBase):
         executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
     '''
 
-    def __init__(self,
-                 generation_request: "GenerationRequest",
-                 background_error_handler: Optional[Callable] = None,
-                 executor: Optional["GenerationExecutor"] = None) -> None:
+    def __init__(
+            self,
+            generation_request: "GenerationRequest",
+            background_error_handler: Optional[Callable] = None,
+            executor: Optional["GenerationExecutor"] = None,
+            disaggregated_params: Optional[DisaggregatedParams] = None) -> None:
         super().__init__(
             generation_request.id,
             generation_request.sampling_params,
@@ -364,6 +362,7 @@ def __init__(self,
         )
         self._generation_request = generation_request
         self._streaming = generation_request.streaming
+        self.disaggregated_params = disaggregated_params
 
         # for aborting the request
         self._executor: Optional[weakref.ReferenceType[
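
On the client side, GenerationResult now carries the request's DisaggregatedParams, and _handle_sequence uses them instead of the old two-token heuristic: the first streamed token of a generation-only request is dropped exactly once, at decoding iteration 2, because the context server already returned it. A sketch of the new predicate and of how the parameter is threaded through from proxy.py and worker.py (the helper name is hypothetical; the attribute names follow the diff):

from typing import Optional

def drops_duplicated_first_token(disaggregated_params: Optional["DisaggregatedParams"],
                                 streaming: bool, decoding_iter: int) -> bool:
    """Mirror of the condition now used in _handle_sequence: only streaming,
    generation-only (disaggregated) requests skip the token, at iteration 2."""
    return (disaggregated_params is not None
            and disaggregated_params.request_type == "generation_only"
            and streaming
            and decoding_iter == 2)

# Threaded in when the result object is created (see proxy.py and worker.py):
#   result = GenerationResult(request,
#                             background_error_handler=self._handle_background_error,
#                             executor=self,
#                             disaggregated_params=request.disaggregated_params)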

tensorrt_llm/executor/worker.py

Lines changed: 2 additions & 1 deletion
@@ -419,7 +419,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
-            executor=self)
+            executor=self,
+            disaggregated_params=request.disaggregated_params)
 
         self._results[client_id] = result
 
tests/integration/defs/disaggregated/sanity_check.sh

Lines changed: 7 additions & 0 deletions
@@ -111,6 +111,7 @@ fi
 
 if [[ "${TEST_DESC}" != "gen_only" ]]; then
     expected_strings=("The capital of Germany is Berlin" "Asyncio is a Python library")
+    not_expected_strings=("Berlin Berlin")
     if [[ "${TEST_DESC}" =~ "deepseek_v3_lite" ]]; then
         expected_strings=("Berlin" "Asyncio is a powerful tool")
     fi
@@ -119,4 +120,10 @@ if [[ "${TEST_DESC}" != "gen_only" ]]; then
         grep "${expected_string}" output.json
         grep "${expected_string}" output_streaming.json
     done
+
+    for not_expected_string in "${not_expected_strings[@]}"; do
+        if grep -q "${not_expected_string}" output.json; then exit 1; fi
+        if grep -q "${not_expected_string}" output_streaming.json; then exit 1; fi
+    done
+
 fi

tests/integration/defs/disaggregated/test_configs/disagg_config_overlap_dp.yaml

Lines changed: 0 additions & 3 deletions
@@ -3,9 +3,6 @@ hostname: localhost
 port: 8000
 backend: "pytorch"
 free_gpu_memory_fraction: 0.2
-speculative_config:
-  decoding_type: MTP
-  num_nextn_predict_layers: 1
 context_servers:
   num_instances: 1
   max_batch_size: 1

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -419,7 +419,6 @@ full:B40/perf/test_perf.py::test_perf[t5_large] SKIP (bert_attention_plugin does
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-disable_quant-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5174573)
 examples/test_mistral.py::test_llm_mistral_nemo_fp8_quantization_1gpu[Mistral-Nemo-12b-Base-summarization] SKIP (https://nvbugspro.nvidia.com/bug/5181262)
 examples/test_qwen.py::test_llm_qwen_moe_single_gpu_summary[qwen1.5_moe_a2.7b_chat-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha] SKIP (https://nvbugs/5180961)
-disaggregated/test_disaggregated.py::test_disaggregated_overlap_dp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5166600)
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5201168)
 unittest/_torch/multi_gpu_modeling -k "deepseek and tp2_pp2_ep1_nextn0_enable_dp" SKIP (https://nvbugspro.nvidia.com/bug/5206873)
 examples/test_multimodal.py::test_llm_multimodal_general[neva-22b-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5214245)
