@@ -175,12 +175,8 @@ def _handle_sequence(self, finish_reasons, response_tensors,
175175 else :
176176 output .token_ids .extend (response_tensors .output_token_ids [src_idx ])
177177
178- # In PD, the first generation response will return 2 tokens
179- # Skip output the first generated token in generation response
180- # TODO: We should have a better way to handle this when enable
181- # beam search with PD.
182- if not self .sampling_params .use_beam_search and \
183- len (response_tensors .output_token_ids [src_idx ]) == 2 :
178+ # In PD, the first token should be ignored in streaming mode, since it's already been returned by the context server
179+ if self .disaggregated_params is not None and self .disaggregated_params .request_type == "generation_only" and self ._streaming and self .decoding_iter == 2 :
184180 output ._last_token_ids_len = 1
185181
186182 if response_tensors .cum_log_probs is not None :
@@ -352,10 +348,12 @@ class GenerationResult(GenerationResultBase):
352348 executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
353349 '''
354350
355- def __init__ (self ,
356- generation_request : "GenerationRequest" ,
357- background_error_handler : Optional [Callable ] = None ,
358- executor : Optional ["GenerationExecutor" ] = None ) -> None :
351+ def __init__ (
352+ self ,
353+ generation_request : "GenerationRequest" ,
354+ background_error_handler : Optional [Callable ] = None ,
355+ executor : Optional ["GenerationExecutor" ] = None ,
356+ disaggregated_params : Optional [DisaggregatedParams ] = None ) -> None :
359357 super ().__init__ (
360358 generation_request .id ,
361359 generation_request .sampling_params ,
@@ -364,6 +362,7 @@ def __init__(self,
364362 )
365363 self ._generation_request = generation_request
366364 self ._streaming = generation_request .streaming
365+ self .disaggregated_params = disaggregated_params
367366
368367 # for aborting the request
369368 self ._executor : Optional [weakref .ReferenceType [
0 commit comments