From 5f17c7c8bd3a238914dbe949ba5728e2ee3e3833 Mon Sep 17 00:00:00 2001
From: knlnguyen1802
Date: Thu, 18 Dec 2025 11:05:41 +0800
Subject: [PATCH 1/3] Try to do rpc

Signed-off-by: knlnguyen1802
---
 vllm_omni/entrypoints/omni_stage.py | 90 ++++++++++++++++++++++++++++-
 1 file changed, 89 insertions(+), 1 deletion(-)

diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index 5150ea27e..2d985717b 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -15,7 +15,7 @@
 import logging
 import multiprocessing as mp
 import os
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 
 from vllm.inputs import TextPrompt
 from vllm.inputs.preprocess import InputPreprocessor
@@ -38,6 +38,8 @@
 logger = init_logger(__name__)
 
+_R = TypeVar("_R")
+
 
 class OmniStage:
     """Stage manager for orchestrating a single stage in the omni pipeline.
@@ -320,6 +322,63 @@ def process_engine_inputs(
             stage_list, engine_input_source, prompt, self.requires_multimodal_data
         )
 
+    def collective_rpc(
+        self,
+        method: str | Callable[..., _R],
+        timeout: float | None = None,
+        args: tuple = (),
+        kwargs: dict[str, Any] | None = None,
+    ) -> list[_R]:
+        """Execute an RPC call on all workers via the stage engine.
+
+        Args:
+            method: Name of the worker method to execute, or a callable that
+                is serialized and sent to all workers to execute.
+
+                If the method is a callable, it should accept an additional
+                `self` argument, in addition to the arguments passed in `args`
+                and `kwargs`. The `self` argument will be the worker object.
+            timeout: Maximum time in seconds to wait for execution. Raises a
+                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
+            args: Positional arguments to pass to the worker method.
+            kwargs: Keyword arguments to pass to the worker method.
+
+        Returns:
+            A list containing the results from each worker.
+
+        Note:
+            It is recommended to use this API to only pass control messages,
+            and set up data-plane communication to pass data.
+        """
+        assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc"
+
+        # Submit collective_rpc task to worker
+        import uuid
+        rpc_id = str(uuid.uuid4())
+        self._in_q.put({
+            "type": "collective_rpc",
+            "rpc_id": rpc_id,
+            "method": method,
+            "timeout": timeout,
+            "args": args,
+            "kwargs": kwargs,
+        })
+
+        # Wait for result from worker
+        import time
+        start_time = time.time()
+        while True:
+            if timeout is not None and (time.time() - start_time) > timeout:
+                raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")
+
+            result = self.try_collect()
+            if result is not None and result.get("type") == "collective_rpc_result" and result.get("rpc_id") == rpc_id:
+                if "error" in result:
+                    raise RuntimeError(f"collective_rpc failed: {result['error']}")
+                return result["result"]
+
+            time.sleep(0.001)  # Small sleep to avoid busy waiting
 
 
 def _stage_worker(
     model: str,
@@ -448,12 +507,41 @@ def filter(self, record: _logging.LogRecord) -> bool:
 
     # Batch processing loop
    while True:
+
        task = in_q.get()
        _recv_dequeue_ts = _time.time()
        if task is None:
            _logging.getLogger(__name__).error("[Stage-%s] Received shutdown signal", stage_id)
            break
 
+        # Handle collective_rpc requests
+        if isinstance(task, dict) and task.get("type") == "collective_rpc":
+            rpc_id = task.get("rpc_id")
+            method = task.get("method")
+            timeout = task.get("timeout")
+            args = task.get("args", ())
+            kwargs = task.get("kwargs")
+            try:
+                _logging.getLogger(__name__).debug(
+                    "[Stage-%s] Executing collective_rpc: method=%s", stage_id, method
+                )
+                result = stage_engine.collective_rpc(method, timeout, args, kwargs)
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    "result": result,
+                })
+            except Exception as e:
+                _logging.getLogger(__name__).exception(
+                    "[Stage-%s] collective_rpc failed: %s", stage_id, e
+                )
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    "error": str(e),
+                })
+            continue
+
        max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
        print(f"[Stage-{stage_id}] Max batch size: {max_batch_size}")
        batch_tasks: list[dict[str, Any]] = [task]

From 1ec072c56abcfcccd85c399c1bd520f14a04c99b Mon Sep 17 00:00:00 2001
From: knlnguyen1802
Date: Thu, 18 Dec 2025 15:25:04 +0800
Subject: [PATCH 2/3] Implement rpc

Signed-off-by: knlnguyen1802
---
 vllm_omni/entrypoints/omni_stage.py | 22 +++++++++------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index 2d985717b..3098f7c56 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -15,6 +15,7 @@
 import logging
 import multiprocessing as mp
 import os
+import uuid
 from typing import Any, Callable, Optional, TypeVar, Union
 
 from vllm.inputs import TextPrompt
@@ -353,7 +354,6 @@ def collective_rpc(
         assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc"
 
         # Submit collective_rpc task to worker
-        import uuid
         rpc_id = str(uuid.uuid4())
         self._in_q.put({
             "type": "collective_rpc",
@@ -372,11 +372,13 @@ def collective_rpc(
                 raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")
 
             result = self.try_collect()
-            if result is not None and result.get("type") == "collective_rpc_result" and result.get("rpc_id") == rpc_id:
-                if "error" in result:
-                    raise RuntimeError(f"collective_rpc failed: {result['error']}")
-                return result["result"]
-
+            if result is not None:
+                if result.get("type") == "collective_rpc_result":
+                    if result.get("rpc_id") == rpc_id:
+                        if "error" in result:
+                            raise RuntimeError(f"collective_rpc failed: {result['error']}")
+                        return result["result"]
+
             time.sleep(0.001)  # Small sleep to avoid busy waiting
@@ -519,12 +521,9 @@ def filter(self, record: _logging.LogRecord) -> bool:
             rpc_id = task.get("rpc_id")
             method = task.get("method")
             timeout = task.get("timeout")
-            args = task.get("args", ())
+            args = task.get("args")
             kwargs = task.get("kwargs")
             try:
-                _logging.getLogger(__name__).debug(
-                    "[Stage-%s] Executing collective_rpc: method=%s", stage_id, method
-                )
                 result = stage_engine.collective_rpc(method, timeout, args, kwargs)
                 out_q.put({
                     "type": "collective_rpc_result",
                     "rpc_id": rpc_id,
                     "result": result,
                 })
             except Exception as e:
-                _logging.getLogger(__name__).exception(
-                    "[Stage-%s] collective_rpc failed: %s", stage_id, e
-                )
                 out_q.put({
                     "type": "collective_rpc_result",
                     "rpc_id": rpc_id,
                     "error": str(e),
                 })

From a6ea2116a982e8944894c622a794ad134ffbc8cf Mon Sep 17 00:00:00 2001
From: knlnguyen1802
Date: Thu, 18 Dec 2025 15:31:28 +0800
Subject: [PATCH 3/3] Support for async engine

Signed-off-by: knlnguyen1802
---
 vllm_omni/entrypoints/omni_stage.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index 3098f7c56..0e3c8b4bf 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -859,6 +859,28 @@ def filter(self, record: _logging.LogRecord) -> bool:
             _logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
             break
 
+        # Handle collective_rpc requests
+        if isinstance(task, dict) and task.get("type") == "collective_rpc":
+            rpc_id = task.get("rpc_id")
+            method = task.get("method")
+            timeout = task.get("timeout")
+            args = task.get("args")
+            kwargs = task.get("kwargs")
+            try:
+                result = await stage_engine.collective_rpc(method, timeout, args, kwargs)
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    "result": result,
+                })
+            except Exception as e:
+                out_q.put({
+                    "type": "collective_rpc_result",
+                    "rpc_id": rpc_id,
+                    "error": str(e),
+                })
+            continue
+
         _rx_bytes_by_rid: dict[Any, int] = {}
         _rx_decode_ms_by_rid: dict[Any, float] = {}
         _in_flight_ms_by_rid: dict[Any, float] = {}
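
Taken together, the three patches define a small request/reply protocol over the stage's input and output queues: the caller enqueues a dict of the form {"type": "collective_rpc", "rpc_id": ..., "method": ..., "timeout": ..., "args": ..., "kwargs": ...} and then polls for a matching {"type": "collective_rpc_result", "rpc_id": ...} reply carrying either a "result" or an "error". Below is a minimal, self-contained sketch of that handshake, with plain queue.Queue objects standing in for the stage's multiprocessing queues; handle_one_task and the "report_device_count" method name are illustrative stand-ins, not vllm_omni code:

# Standalone sketch of the collective_rpc handshake added in this series.
# queue.Queue replaces the stage's mp queues; handle_one_task mirrors the
# branch added to _stage_worker; "report_device_count" is a made-up method
# name used only for illustration.
import queue
import uuid

in_q: queue.Queue = queue.Queue()
out_q: queue.Queue = queue.Queue()


def handle_one_task() -> None:
    # Worker side: pop one task, run it, reply keyed by the caller's rpc_id.
    task = in_q.get()
    if isinstance(task, dict) and task.get("type") == "collective_rpc":
        try:
            # A real stage would call stage_engine.collective_rpc(...) here.
            result = [f"{task['method']}(args={task['args']}, kwargs={task['kwargs']})"]
            out_q.put({"type": "collective_rpc_result",
                       "rpc_id": task["rpc_id"],
                       "result": result})
        except Exception as e:
            out_q.put({"type": "collective_rpc_result",
                       "rpc_id": task["rpc_id"],
                       "error": str(e)})


# Caller side, mirroring OmniStage.collective_rpc: submit, then wait for the
# reply whose rpc_id matches the request.
rpc_id = str(uuid.uuid4())
in_q.put({"type": "collective_rpc", "rpc_id": rpc_id,
          "method": "report_device_count", "timeout": None,
          "args": (), "kwargs": {}})
handle_one_task()
reply = out_q.get()
assert reply["rpc_id"] == rpc_id
if "error" in reply:
    raise RuntimeError(f"collective_rpc failed: {reply['error']}")
print(reply["result"])

Note that in the actual patches the caller does not block on the output queue; it polls try_collect with a 1 ms sleep so the same loop can also enforce the optional timeout.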