vllm_omni/entrypoints/omni_stage.py (107 additions, 1 deletion)
@@ -15,7 +15,8 @@
import logging
import multiprocessing as mp
import os
from typing import Any, Optional, Union
import uuid
from typing import Any, Callable, Optional, TypeVar, Union

from vllm.inputs import TextPrompt
from vllm.inputs.preprocess import InputPreprocessor
@@ -38,6 +39,8 @@

logger = init_logger(__name__)

_R = TypeVar("_R")


class OmniStage:
"""Stage manager for orchestrating a single stage in the omni pipeline.
@@ -320,6 +323,64 @@ def process_engine_inputs(
stage_list, engine_input_source, prompt, self.requires_multimodal_data
)

def collective_rpc(
self,
method: str | Callable[..., _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
"""Execute an RPC call on all workers via the stage engine.

Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.

If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.

Returns:
A list containing the results from each worker.

Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc"

# Submit collective_rpc task to worker
rpc_id = str(uuid.uuid4())
self._in_q.put({
"type": "collective_rpc",
"rpc_id": rpc_id,
"method": method,
"timeout": timeout,
"args": args,
"kwargs": kwargs,
})

# Wait for result from worker
import time
start_time = time.time()
while True:
if timeout is not None and (time.time() - start_time) > timeout:
raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")

result = self.try_collect()
if result is not None:
if result.get("type") == "collective_rpc_result":
if result.get("rpc_id") == rpc_id:
if "error" in result:
Comment on lines +374 to +378

P1: RPC wait loop discards non-RPC outputs

While waiting for a matching collective_rpc_result, the new collective_rpc method pops items off self._out_q via try_collect() and ignores everything that is not the target RPC result. That silently drops unrelated stage outputs (e.g., pending inference responses or status messages), so issuing a collective RPC while other work is in flight will lose those messages and leave their callers waiting for replies that never arrive. A possible mitigation is sketched after the end of this method.

raise RuntimeError(f"collective_rpc failed: {result['error']}")
return result["result"]

time.sleep(0.001) # Small sleep to avoid busy waiting
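
A possible mitigation for the dropped-output issue flagged in the review comment above, as a sketch only: it assumes OmniStage gains a deferred-output buffer (for example a collections.deque named self._deferred_outputs) that try_collect() or the normal output consumers drain before reading self._out_q again; neither exists in this PR. The inner loop body would then keep non-matching messages instead of discarding them:

    result = self.try_collect()
    if result is not None:
        if (result.get("type") == "collective_rpc_result"
                and result.get("rpc_id") == rpc_id):
            if "error" in result:
                raise RuntimeError(f"collective_rpc failed: {result['error']}")
            return result["result"]
        # Not the RPC reply being waited on: stash it so regular stage
        # outputs can still be delivered to their consumers later.
        self._deferred_outputs.append(result)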

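Separately, a minimal usage sketch for the new OmniStage.collective_rpc API (illustrative only and not part of this PR; the stage object, its attached queues, and the "get_device_name" worker method name are assumptions):

    # stage is an OmniStage whose in/out queues have already been attached.
    # Call a worker method by name on every worker; the result list holds
    # one entry per worker. "get_device_name" is a hypothetical method name.
    names = stage.collective_rpc("get_device_name", timeout=30.0)

    # Alternatively, ship a callable; each worker runs it with the worker
    # object passed as the first argument.
    def _worker_type(worker):
        return type(worker).__name__

    types = stage.collective_rpc(_worker_type)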

def _stage_worker(
model: str,
@@ -448,12 +509,35 @@ def filter(self, record: _logging.LogRecord) -> bool:

# Batch processing loop
while True:

task = in_q.get()
_recv_dequeue_ts = _time.time()
if task is None:
_logging.getLogger(__name__).error("[Stage-%s] Received shutdown signal", stage_id)
break

# Handle collective_rpc requests
if isinstance(task, dict) and task.get("type") == "collective_rpc":
rpc_id = task.get("rpc_id")
method = task.get("method")
timeout = task.get("timeout")
args = task.get("args")
kwargs = task.get("kwargs")
try:
result = stage_engine.collective_rpc(method, timeout, args, kwargs)
out_q.put({
Comment on lines +526 to +528

P1: collective_rpc task calls missing engine method

The worker loop dispatches collective RPCs to stage_engine.collective_rpc(...), but OmniStageLLM (and the rest of the repo, checked with rg "collective_rpc") defines no such method. Any collective_rpc task will therefore raise an AttributeError in the worker, propagate back as an error reply, and cause OmniStage.collective_rpc to raise a RuntimeError, leaving the new API unusable. A possible delegating implementation is sketched after this handler block.

"type": "collective_rpc_result",
"rpc_id": rpc_id,
"result": result,
})
except Exception as e:
out_q.put({
"type": "collective_rpc_result",
"rpc_id": rpc_id,
"error": str(e),
})
continue
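
For the missing engine method flagged in the review comment above, a minimal delegating shape is sketched here under assumptions: OmniStageLLM is assumed to wrap a vllm.LLM instance in a self.llm attribute (the attribute name is hypothetical), and vLLM's LLM.collective_rpc is assumed to be available in the pinned vLLM version. Note that the async worker loop further below awaits this call, so an async engine wrapper would need an awaitable counterpart.

    # Hypothetical method on OmniStageLLM; self.llm is an assumed attribute
    # holding the underlying vllm.LLM instance.
    def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
        # Fan the call out to every worker and return one result per worker.
        return self.llm.collective_rpc(
            method, timeout=timeout, args=args, kwargs=kwargs or {}
        )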

max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
print(f"[Stage-{stage_id}] Max batch size: {max_batch_size}")
batch_tasks: list[dict[str, Any]] = [task]
@@ -775,6 +859,28 @@ def filter(self, record: _logging.LogRecord) -> bool:
_logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
break

# Handle collective_rpc requests
if isinstance(task, dict) and task.get("type") == "collective_rpc":
rpc_id = task.get("rpc_id")
method = task.get("method")
timeout = task.get("timeout")
args = task.get("args")
kwargs = task.get("kwargs")
try:
result = await stage_engine.collective_rpc(method, timeout, args, kwargs)
out_q.put({
"type": "collective_rpc_result",
"rpc_id": rpc_id,
"result": result,
})
except Exception as e:
out_q.put({
"type": "collective_rpc_result",
"rpc_id": rpc_id,
"error": str(e),
})
continue

_rx_bytes_by_rid: dict[Any, int] = {}
_rx_decode_ms_by_rid: dict[Any, float] = {}
_in_flight_ms_by_rid: dict[Any, float] = {}