diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index 5150ea27e..0e3c8b4bf 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -15,7 +15,8 @@
 import logging
 import multiprocessing as mp
 import os
-from typing import Any, Optional, Union
+import uuid
+from typing import Any, Callable, Optional, TypeVar, Union
 
 from vllm.inputs import TextPrompt
 from vllm.inputs.preprocess import InputPreprocessor
@@ -38,6 +39,8 @@
 
 logger = init_logger(__name__)
 
+_R = TypeVar("_R")
+
 
 class OmniStage:
     """Stage manager for orchestrating a single stage in the omni pipeline.
@@ -320,6 +323,75 @@ def process_engine_inputs(
                 stage_list, engine_input_source, prompt, self.requires_multimodal_data
             )
 
+    def collective_rpc(
+        self,
+        method: str | Callable[..., _R],
+        timeout: float | None = None,
+        args: tuple = (),
+        kwargs: dict[str, Any] | None = None,
+    ) -> list[_R]:
+        """Execute an RPC call on all workers via the stage engine.
+
+        Args:
+            method: Name of the worker method to execute, or a callable that
+                is serialized and sent to all workers to execute.
+
+                If the method is a callable, it should accept an additional
+                `self` argument, in addition to the arguments passed in `args`
+                and `kwargs`. The `self` argument will be the worker object.
+            timeout: Maximum time in seconds to wait for execution. Raises a
+                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
+            args: Positional arguments to pass to the worker method.
+            kwargs: Keyword arguments to pass to the worker method.
+
+        Returns:
+            A list containing the results from each worker.
+
+        Raises:
+            TimeoutError: If no matching reply arrives within `timeout` seconds.
+            RuntimeError: If the stage worker reports that the call failed.
+
+        Note:
+            It is recommended to use this API to only pass control messages,
+            and set up data-plane communication to pass data.
+        """
+        import time
+
+        assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc"
+
+        # Tag the request with a unique id so that the matching reply can be
+        # recognised among the messages coming back from the stage worker.
+        rpc_id = str(uuid.uuid4())
+        self._in_q.put({
+            "type": "collective_rpc",
+            "rpc_id": rpc_id,
+            "method": method,
+            "timeout": timeout,
+            "args": args,
+            "kwargs": kwargs,
+        })
+
+        # Poll for the reply. The deadline uses the monotonic clock so that
+        # wall-clock adjustments can neither shorten nor extend the timeout.
+        deadline = None if timeout is None else time.monotonic() + timeout
+        while True:
+            result = self.try_collect()
+            if result is not None:
+                if result.get("type") == "collective_rpc_result" and result.get("rpc_id") == rpc_id:
+                    if "error" in result:
+                        raise RuntimeError(f"collective_rpc failed: {result['error']}")
+                    return result["result"]
+                # NOTE(review): presumably try_collect() consumes the message,
+                # so anything that is not our reply is lost here. Surface it
+                # instead of dropping it silently - TODO confirm and re-route.
+                logger.warning("collective_rpc(%s) discarded an unrelated message while waiting", rpc_id)
+
+            if deadline is not None and time.monotonic() > deadline:
+                raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")
+
+            if result is None:
+                time.sleep(0.001)  # Small sleep to avoid busy waiting
+
 
 def _stage_worker(
     model: str,
@@ -448,12 +520,39 @@ def filter(self, record: _logging.LogRecord) -> bool:
 
     # Batch processing loop
     while True:
+
         task = in_q.get()
         _recv_dequeue_ts = _time.time()
         if task is None:
             _logging.getLogger(__name__).error("[Stage-%s] Received shutdown signal", stage_id)
             break
 
+        # Handle collective_rpc requests
+        if isinstance(task, dict) and task.get("type") == "collective_rpc":
+            rpc_id = task.get("rpc_id")
+            method = task.get("method")
+            timeout = task.get("timeout")
+            args = task.get("args") or ()
+            kwargs = task.get("kwargs")
+            try:
+                result = stage_engine.collective_rpc(method, timeout, args, kwargs)
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    "result": result,
+                })
+            except Exception as e:
+                # Keep the traceback in the stage log; only a summary string
+                # travels back over the queue.
+                _logging.getLogger(__name__).exception("[Stage-%s] collective_rpc failed", stage_id)
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    # Include the exception type: str(e) alone can be empty or ambiguous.
+                    "error": f"{type(e).__name__}: {e}",
+                })
+            continue
+
         max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
         print(f"[Stage-{stage_id}] Max batch size: {max_batch_size}")
         batch_tasks: list[dict[str, Any]] = [task]
@@ -775,6 +874,32 @@ def filter(self, record: _logging.LogRecord) -> bool:
             _logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
             break
 
+        # Handle collective_rpc requests
+        if isinstance(task, dict) and task.get("type") == "collective_rpc":
+            rpc_id = task.get("rpc_id")
+            method = task.get("method")
+            timeout = task.get("timeout")
+            args = task.get("args") or ()
+            kwargs = task.get("kwargs")
+            try:
+                result = await stage_engine.collective_rpc(method, timeout, args, kwargs)
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    "result": result,
+                })
+            except Exception as e:
+                # Keep the traceback in the stage log; only a summary string
+                # travels back over the queue.
+                _logging.getLogger(__name__).exception("[Stage-%s] collective_rpc failed", stage_id)
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    # Include the exception type: str(e) alone can be empty or ambiguous.
+                    "error": f"{type(e).__name__}: {e}",
+                })
+            continue
+
         _rx_bytes_by_rid: dict[Any, int] = {}
         _rx_decode_ms_by_rid: dict[Any, float] = {}
         _in_flight_ms_by_rid: dict[Any, float] = {}