
Commit ffeb18f

fix bug done
1 parent 41ebfec commit ffeb18f

File tree

3 files changed: 143 additions & 5 deletions

src/parallax/server/executor.py

Lines changed: 86 additions & 0 deletions
@@ -446,6 +446,57 @@ def recv_requests_from_peer(self) -> List[Request]:

         return recv_reqs

+    def _compute_expected_intermediate_tokens(self, scheduler_output: Any) -> Optional[int]:
+        """Estimate the padded token count expected by vLLM for this batch."""
+        if scheduler_output is None:
+            return None
+
+        total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None)
+        if total_tokens is None:
+            return None
+
+        try:
+            total_tokens = int(total_tokens)
+        except (TypeError, ValueError):
+            return None
+
+        model_runner = getattr(self, "model_runner", None)
+        if model_runner is None:
+            return None
+
+        get_num_input_tokens = getattr(model_runner, "_get_num_input_tokens", None)
+        get_dp_padding = getattr(model_runner, "get_dp_padding", None)
+        if get_num_input_tokens is None or get_dp_padding is None:
+            return None
+
+        num_input_tokens = get_num_input_tokens(total_tokens)
+        num_pad, _ = get_dp_padding(num_input_tokens)
+        return num_input_tokens + num_pad
+
+    @staticmethod
+    def _pad_or_trim_tensor(tensor: torch.Tensor, target_len: int) -> torch.Tensor:
+        if target_len < 0:
+            return tensor
+        current_len = tensor.shape[0]
+        if current_len == target_len:
+            return tensor
+        if current_len > target_len:
+            return tensor[:target_len]
+        pad_shape = (target_len - current_len,) + tensor.shape[1:]
+        pad = tensor.new_zeros(pad_shape)
+        return torch.cat((tensor, pad), dim=0)
+
+    def _resize_intermediate_tensors(self, intermediate_tensors, target_len: Optional[int]):
+        if intermediate_tensors is None or target_len is None:
+            return intermediate_tensors
+        if target_len < 0:
+            return intermediate_tensors
+
+        # Create a list to avoid "dictionary changed size during iteration".
+        for key, tensor in list(intermediate_tensors.items()):
+            intermediate_tensors[key] = self._pad_or_trim_tensor(tensor, target_len)
+        return intermediate_tensors
+
     def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
         """
         Prepares inputs for CUDA backends from a batch of prefill requests.
@@ -459,6 +510,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s

         # Prepare PP proxy tensors (common for both backends when not first peer)
         pp_proxy_tensors = None
+        pp_proxy_initial_tokens = None
         if not self.is_first_peer:
             # Concatenate hidden states from all requests
             # For vLLM, we need to flatten to (total_tokens, hidden_size)
@@ -478,6 +530,7 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s

             # Concatenate along sequence dimension to get (total_tokens, hidden_size)
             hidden_states = torch.cat(hidden_states_list, dim=0)
+            pp_proxy_initial_tokens = hidden_states.shape[0]

             # Create residual tensor with same shape
             residual = torch.zeros(
@@ -515,6 +568,29 @@ def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[s

         schedule_outputs_prefill = form_vllm_batch_prefill(batched_requests, self.model_runner)

+        if not self.is_first_peer and pp_proxy_tensors is not None:
+            target_tokens = self._compute_expected_intermediate_tokens(schedule_outputs_prefill)
+            if target_tokens is not None:
+                before = pp_proxy_tensors["hidden_states"].shape[0]
+                pp_proxy_tensors = self._resize_intermediate_tensors(
+                    pp_proxy_tensors, target_tokens
+                )
+                after = pp_proxy_tensors["hidden_states"].shape[0]
+                if after != before:
+                    logger.debug(
+                        "PP Proxy: resized hidden_states from %d to %d tokens (requested=%s, initial=%s)",
+                        before,
+                        after,
+                        target_tokens,
+                        pp_proxy_initial_tokens,
+                    )
+
+        if not self.is_first_peer and pp_proxy_tensors is not None:
+            logger.debug(
+                "PP Proxy: hidden_states shape after adjustment: %s",
+                tuple(pp_proxy_tensors["hidden_states"].shape),
+            )
+
         ret = {
             "scheduler_output": schedule_outputs_prefill,
             "pp_proxy_tensors": pp_proxy_tensors,
@@ -572,6 +648,7 @@ def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[st

             # Concatenate along sequence dimension to get (total_tokens, hidden_size)
             hidden_states = torch.cat(hidden_states_list, dim=0)
+            pp_proxy_initial_tokens = hidden_states.shape[0]

             # Create residual tensor with same shape
             residual = torch.zeros(
@@ -918,6 +995,10 @@ def _handle_cuda_input_requests(self, requests: List[Request]):

             assert req.next_token_id is not None
             original_req.commit_new_token(req.next_token_id)
+            logger.debug(
+                f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, "
+                f"output_ids now has {len(original_req.output_ids)} tokens"
+            )
             if len(req.routing_table) > 0:
                 original_req.routing_table = req.routing_table

@@ -1102,6 +1183,8 @@ def _prepare_next_single_request(self, request: Request, hidden_states: Any) ->
             assert isinstance(
                 request, IntermediateRequest
             ), "Last peer must receive an IntermediateRequest."
+            logger.info(f"hidden_states shape: {hidden_states.shape}")
+            logger.info(f"hidden_states: {hidden_states}")
             if self.device == "cuda":
                 assert hidden_states.dtype in (
                     torch.int64,
@@ -1143,6 +1226,7 @@ def _prepare_next_batch_requests(
         for i, src_request in enumerate(requests):
             if self.is_last_peer:
                 # Last peer gets a 1D array of token IDs
+                logger.info(f"hidden_states: {hidden_states}")
                 hidden_state_for_req = hidden_states[i : i + 1]
             else:
                 # Other peers get a 3D array of hidden states
@@ -1217,6 +1301,7 @@ def _process_batch_cuda(
         import torch

         sampled_token_ids = output.sampled_token_ids
+        logger.info(f"sampled_token_ids: {sampled_token_ids}")
         if isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0:
             # Convert to tensor: pad sequences to same length
             max_len = max(len(seq) for seq in sampled_token_ids)
@@ -1498,6 +1583,7 @@ def run_loop(self):
                 output = self.process_batch(
                     prepared_inputs, return_decoded_tokens=self.is_last_peer
                 )
+                logger.info(f"output: {output}")
                 # Update metrics with per-layer latency sample (throttled by decode steps)
                 if batch_type == "decode_batch":
                     try:
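For readers who want to sanity-check the new padding/trimming helpers outside the executor, here is a small standalone sketch. It is not part of the commit; the function name pad_or_trim and the example shapes are made up, and it simply mirrors the first-dimension pad-or-trim behaviour of _pad_or_trim_tensor:

import torch

def pad_or_trim(tensor: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad or trim along dim 0 so the tensor ends up with target_len rows."""
    if target_len < 0 or tensor.shape[0] == target_len:
        return tensor
    if tensor.shape[0] > target_len:
        return tensor[:target_len]
    pad = tensor.new_zeros((target_len - tensor.shape[0],) + tensor.shape[1:])
    return torch.cat((tensor, pad), dim=0)

hidden = torch.randn(5, 4096)        # 5 tokens of hidden states
print(pad_or_trim(hidden, 8).shape)  # torch.Size([8, 4096]) -- zero-padded
print(pad_or_trim(hidden, 3).shape)  # torch.Size([3, 4096]) -- trimmed

This is the mechanism the executor uses to reconcile the hidden-state buffer received from the previous pipeline stage with the padded token count it computes for the batch.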

src/parallax/vllm/batch_info.py

Lines changed: 38 additions & 2 deletions
@@ -159,6 +159,7 @@ def form_vllm_batch_prefill(
 def form_vllm_batch_decode(
     batched_requests: List[Request],
     model_runner: Any = None,
+    scheduler: Any = None,
 ) -> Optional[SchedulerOutput]:
     if not batched_requests:
         return None
@@ -183,7 +184,32 @@ def form_vllm_batch_decode(
     for req in batched_requests:
         req_ids.append(req.request_id)
         resumed_from_preemption.append(False)
+
+        # For GPU workers (non-first peer), IntermediateRequest doesn't have output_ids
+        # We need to get it from vLLM's CachedRequestState in model_runner
         output_ids = getattr(req, "output_ids", None) or []
+
+        # If this request doesn't have output_ids (IntermediateRequest case),
+        # try to get it from model_runner's cached request state (vLLM internal state)
+        if not output_ids and hasattr(model_runner, "requests"):
+            cached_req_state = model_runner.requests.get(req.request_id)
+            if cached_req_state is not None:
+                output_ids = getattr(cached_req_state, "output_token_ids", [])
+                logger.debug(
+                    f"[Decode] Retrieved output_token_ids from vLLM CachedRequestState for "
+                    f"{req.request_id}: len={len(output_ids)}"
+                )
+
+        # Fallback: try scheduler if available
+        if not output_ids and scheduler is not None:
+            running_req = scheduler.get_running_request(req.request_id)
+            if running_req is not None:
+                output_ids = getattr(running_req, "output_ids", None) or []
+                logger.debug(
+                    f"[Decode] Retrieved output_ids from scheduler for {req.request_id}: "
+                    f"len={len(output_ids)}"
+                )
+
         if output_ids:
             last_token = output_ids[-1]
             new_token_ids.append([last_token])
@@ -196,13 +222,23 @@ def form_vllm_batch_decode(
         vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True)

         prompt_ids = getattr(req, "input_ids", None) or []
-        output_ids = getattr(req, "output_ids", None) or []
+        # For decode stage, computed_token_count should be the total number of tokens
+        # that have been processed (including all output tokens).
+        # In pipeline parallelism, this must match what GPU worker expects.
         if output_ids:
-            computed_token_count = len(prompt_ids) + len(output_ids) - 1
+            # All tokens (prompt + all generated outputs) have been computed
+            computed_token_count = len(prompt_ids) + len(output_ids)
         else:
+            # First decode step: only prompt has been computed
             computed_token_count = len(prompt_ids)
         vllm_req.num_computed_tokens = computed_token_count

+        # Debug logging to track state synchronization
+        logger.debug(
+            f"[Decode] req_id={req.request_id}, prompt_len={len(prompt_ids)}, "
+            f"output_len={len(output_ids)}, computed_tokens={computed_token_count}"
+        )
+
         new_blocks = kv_cache_manager.allocate_slots(
             request=vllm_req,
             num_new_tokens=1,
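A quick way to see what the computed-token change above does (dropping the "- 1") is to isolate the accounting in a tiny helper. The sketch below is not part of the commit; the name computed_token_count is hypothetical and simply mirrors the branch added in the hunk:

from typing import List

def computed_token_count(prompt_ids: List[int], output_ids: List[int]) -> int:
    """Tokens already processed before this decode step."""
    if output_ids:
        # Prompt plus every token generated so far.
        return len(prompt_ids) + len(output_ids)
    # First decode step: only the prompt has been computed.
    return len(prompt_ids)

# Example: a 7-token prompt with 3 generated tokens.
assert computed_token_count(list(range(7)), [11, 12, 13]) == 10  # was 9 with the old "- 1"
assert computed_token_count(list(range(7)), []) == 7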

src/parallax/vllm/model_runner.py

Lines changed: 19 additions & 3 deletions
@@ -319,7 +319,21 @@ def custom_get_pp_indices(num_layers: int, rank: int, world_size: int):
         finally:
             vllm.distributed.utils.get_pp_indices = original_get_pp_indices

-        logger.debug("Model loaded successfully with partial layers")
+    def execute_model(self, scheduler_output, intermediate_tensors=None):
+        """
+        Execute the model with the given scheduler output and intermediate tensors.
+        If this is not the first peer, and the intermediate_tensors buffer is not initialized,
+        initialize it.
+        """
+        if not self.is_first_peer and self.intermediate_tensors is None:
+            self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=self.max_num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )
+            logger.debug("Successfully initialized intermediate_tensors buffer")
+
+        return super().execute_model(scheduler_output, intermediate_tensors)


 def initialize_vllm_model_runner(
@@ -348,15 +362,17 @@ def initialize_vllm_model_runner(
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
-
+
     num_hidden_layers = config.get("num_hidden_layers")
     is_first_peer = start_layer == 0
     is_last_peer = end_layer == num_hidden_layers

     # Apply Parallax vLLM monkey patches for pipeline parallelism
     try:
         apply_parallax_vllm_monkey_patch(is_first_stage=is_first_peer, is_last_stage=is_last_peer)
-        logger.debug(f"Applied Parallax vLLM monkey patches: is_first_stage={is_first_peer}, is_last_stage={is_last_peer}")
+        logger.debug(
+            f"Applied Parallax vLLM monkey patches: is_first_stage={is_first_peer}, is_last_stage={is_last_peer}"
+        )
     except Exception as e:
         logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e)

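The execute_model override in the first hunk of this file is a lazy-initialization pattern: a non-first pipeline stage allocates its intermediate-tensor buffer on the first call and then defers to the base class. A minimal, commit-independent sketch of the same pattern (all class, attribute, and method names here are illustrative, not the real vLLM runner API) could look like:

import torch

class LazyIntermediateRunner:
    """Toy stand-in for a pipeline-stage model runner; names are illustrative only."""

    def __init__(self, is_first_peer: bool, max_num_tokens: int, hidden_size: int):
        self.is_first_peer = is_first_peer
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.intermediate_tensors = None

    def execute_model(self, scheduler_output, intermediate_tensors=None):
        # Non-first stages need a buffer for incoming hidden states; allocate it
        # lazily on the first call instead of at construction time.
        if not self.is_first_peer and self.intermediate_tensors is None:
            self.intermediate_tensors = {
                "hidden_states": torch.zeros(self.max_num_tokens, self.hidden_size),
                "residual": torch.zeros(self.max_num_tokens, self.hidden_size),
            }
        return self._forward(scheduler_output, intermediate_tensors)

    def _forward(self, scheduler_output, intermediate_tensors):
        # Stand-in for super().execute_model(...) in the real model runner.
        return scheduler_output, intermediate_tensors

Allocating on first use keeps first-stage peers, which never consume intermediate tensors, from holding a buffer they do not need.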
