@@ -446,6 +446,57 @@ def recv_requests_from_peer(self) -> List[Request]:
 
         return recv_reqs
 
+    def _compute_expected_intermediate_tokens(self, scheduler_output: Any) -> Optional[int]:
+        """Estimate the padded token count expected by vLLM for this batch."""
+        if scheduler_output is None:
+            return None
+
+        total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None)
+        if total_tokens is None:
+            return None
+
+        try:
+            total_tokens = int(total_tokens)
+        except (TypeError, ValueError):
+            return None
+
+        model_runner = getattr(self, "model_runner", None)
+        if model_runner is None:
+            return None
+
+        get_num_input_tokens = getattr(model_runner, "_get_num_input_tokens", None)
+        get_dp_padding = getattr(model_runner, "get_dp_padding", None)
+        if get_num_input_tokens is None or get_dp_padding is None:
+            return None
+
+        num_input_tokens = get_num_input_tokens(total_tokens)
+        num_pad, _ = get_dp_padding(num_input_tokens)
+        return num_input_tokens + num_pad
+
+    @staticmethod
+    def _pad_or_trim_tensor(tensor: torch.Tensor, target_len: int) -> torch.Tensor:
+        if target_len < 0:
+            return tensor
+        current_len = tensor.shape[0]
+        if current_len == target_len:
+            return tensor
+        if current_len > target_len:
+            return tensor[:target_len]
+        pad_shape = (target_len - current_len,) + tensor.shape[1:]
+        pad = tensor.new_zeros(pad_shape)
+        return torch.cat((tensor, pad), dim=0)
+
+    def _resize_intermediate_tensors(self, intermediate_tensors, target_len: Optional[int]):
+        if intermediate_tensors is None or target_len is None:
+            return intermediate_tensors
+        if target_len < 0:
+            return intermediate_tensors
+
+        # Create a list to avoid "dictionary changed size during iteration".
+        for key, tensor in list(intermediate_tensors.items()):
+            intermediate_tensors[key] = self._pad_or_trim_tensor(tensor, target_len)
+        return intermediate_tensors
+
     def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
         """
         Prepares inputs for CUDA backends from a batch of prefill requests.
@@ -459,6 +510,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
 
         # Prepare PP proxy tensors (common for both backends when not first peer)
         pp_proxy_tensors = None
+        pp_proxy_initial_tokens = None
         if not self.is_first_peer:
             # Concatenate hidden states from all requests
             # For vLLM, we need to flatten to (total_tokens, hidden_size)
@@ -478,6 +530,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
 
             # Concatenate along sequence dimension to get (total_tokens, hidden_size)
             hidden_states = torch.cat(hidden_states_list, dim=0)
+            pp_proxy_initial_tokens = hidden_states.shape[0]
 
             # Create residual tensor with same shape
             residual = torch.zeros(
@@ -515,6 +568,29 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
 
         schedule_outputs_prefill = form_vllm_batch_prefill(batched_requests, self.model_runner)
 
+        if not self.is_first_peer and pp_proxy_tensors is not None:
+            target_tokens = self._compute_expected_intermediate_tokens(schedule_outputs_prefill)
+            if target_tokens is not None:
+                before = pp_proxy_tensors["hidden_states"].shape[0]
+                pp_proxy_tensors = self._resize_intermediate_tensors(
+                    pp_proxy_tensors, target_tokens
+                )
+                after = pp_proxy_tensors["hidden_states"].shape[0]
+                if after != before:
+                    logger.debug(
+                        "PP Proxy: resized hidden_states from %d to %d tokens (requested=%s, initial=%s)",
+                        before,
+                        after,
+                        target_tokens,
+                        pp_proxy_initial_tokens,
+                    )
+
+        if not self.is_first_peer and pp_proxy_tensors is not None:
+            logger.debug(
+                "PP Proxy: hidden_states shape after adjustment: %s",
+                tuple(pp_proxy_tensors["hidden_states"].shape),
+            )
+
         ret = {
             "scheduler_output": schedule_outputs_prefill,
             "pp_proxy_tensors": pp_proxy_tensors,
@@ -572,6 +648,7 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st
 
             # Concatenate along sequence dimension to get (total_tokens, hidden_size)
             hidden_states = torch.cat(hidden_states_list, dim=0)
+            pp_proxy_initial_tokens = hidden_states.shape[0]
 
             # Create residual tensor with same shape
             residual = torch.zeros(
@@ -918,6 +995,10 @@ def _handle_cuda_input_requests(self, requests: List[Request]):
 
             assert req.next_token_id is not None
             original_req.commit_new_token(req.next_token_id)
+            logger.debug(
+                f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, "
+                f"output_ids now has {len(original_req.output_ids)} tokens"
+            )
             if len(req.routing_table) > 0:
                 original_req.routing_table = req.routing_table
 
@@ -1102,6 +1183,8 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
             assert isinstance(
                 request, IntermediateRequest
             ), "Last peer must receive an IntermediateRequest."
+            logger.info(f"hidden_states shape: {hidden_states.shape}")
+            logger.info(f"hidden_states: {hidden_states}")
             if self.device == "cuda":
                 assert hidden_states.dtype in (
                     torch.int64,
@@ -1143,6 +1226,7 @@ def _prepare_next_batch_requests(
         for i, src_request in enumerate(requests):
             if self.is_last_peer:
                 # Last peer gets a 1D array of token IDs
+                logger.info(f"hidden_states: {hidden_states}")
                 hidden_state_for_req = hidden_states[i : i + 1]
             else:
                 # Other peers get a 3D array of hidden states
@@ -1217,6 +1301,7 @@ def _process_batch_cuda(
         import torch
 
         sampled_token_ids = output.sampled_token_ids
+        logger.info(f"sampled_token_ids: {sampled_token_ids}")
         if isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0:
             # Convert to tensor: pad sequences to same length
             max_len = max(len(seq) for seq in sampled_token_ids)
@@ -1498,6 +1583,7 @@ def run_loop(self):
                 output = self.process_batch(
                     prepared_inputs, return_decoded_tokens=self.is_last_peer
                 )
+                logger.info(f"output: {output}")
                 # Update metrics with per-layer latency sample (throttled by decode steps)
                 if batch_type == "decode_batch":
                     try:
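For illustration only, not part of the patch above: a minimal standalone sketch of the pad-or-trim rule that `_resize_intermediate_tensors` applies to each PP proxy tensor. The hidden size of 4096 and the target of 8 tokens are made-up values standing in for the DP-padded count returned by `_compute_expected_intermediate_tokens`.

import torch

def pad_or_trim(tensor: torch.Tensor, target_len: int) -> torch.Tensor:
    # Same rule as _pad_or_trim_tensor above: grow dim 0 with zero rows or slice it down.
    if target_len < 0 or tensor.shape[0] == target_len:
        return tensor
    if tensor.shape[0] > target_len:
        return tensor[:target_len]
    pad = tensor.new_zeros((target_len - tensor.shape[0],) + tensor.shape[1:])
    return torch.cat((tensor, pad), dim=0)

hidden = torch.randn(5, 4096)        # 5 scheduled tokens, hypothetical hidden_size of 4096
print(pad_or_trim(hidden, 8).shape)  # torch.Size([8, 4096]) -- zero rows appended up to the padded length
print(pad_or_trim(hidden, 3).shape)  # torch.Size([3, 4096]) -- trimmed back down to the target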