-
Notifications
You must be signed in to change notification settings - Fork 209
RPC support for OmniLLM #355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,8 @@ | |
| import logging | ||
| import multiprocessing as mp | ||
| import os | ||
| from typing import Any, Optional, Union | ||
| import uuid | ||
| from typing import Any, Callable, Optional, TypeVar, Union | ||
|
|
||
| from vllm.inputs import TextPrompt | ||
| from vllm.inputs.preprocess import InputPreprocessor | ||
|
|
@@ -38,6 +39,8 @@ | |
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| _R = TypeVar("_R") | ||
|
|
||
|
|
||
| class OmniStage: | ||
| """Stage manager for orchestrating a single stage in the omni pipeline. | ||
|
|
@@ -320,6 +323,64 @@ def process_engine_inputs( | |
| stage_list, engine_input_source, prompt, self.requires_multimodal_data | ||
| ) | ||
|
|
||
def collective_rpc(
    self,
    method: "str | Callable[..., _R]",
    timeout: "float | None" = None,
    args: tuple = (),
    kwargs: "dict[str, Any] | None" = None,
) -> "list[_R]":
    """Execute an RPC call on all workers via the stage engine.

    The request is enqueued on the stage input queue tagged with a fresh
    UUID, and this method blocks until the matching
    ``collective_rpc_result`` message comes back on the output queue.
    Messages popped while waiting that belong to other consumers are
    re-queued rather than dropped.

    Args:
        method: Name of the worker method to execute, or a callable that
            is serialized and sent to all workers to execute.

            If the method is a callable, it should accept an additional
            `self` argument, in addition to the arguments passed in `args`
            and `kwargs`. The `self` argument will be the worker object.
        timeout: Maximum time in seconds to wait for execution. Raises a
            [`TimeoutError`][] on timeout. `None` means wait indefinitely.
        args: Positional arguments to pass to the worker method.
        kwargs: Keyword arguments to pass to the worker method.

    Returns:
        A list containing the results from each worker.

    Raises:
        TimeoutError: If no matching result arrives within `timeout`.
        RuntimeError: If the worker reports an error for this RPC.

    Note:
        It is recommended to use this API to only pass control messages,
        and set up data-plane communication to pass data.
    """
    import time

    assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc"

    # Submit collective_rpc task to worker, tagged with a unique id so the
    # reply can be matched even when other messages are interleaved.
    rpc_id = str(uuid.uuid4())
    self._in_q.put({
        "type": "collective_rpc",
        "rpc_id": rpc_id,
        "method": method,
        "timeout": timeout,
        "args": args,
        "kwargs": kwargs,
    })

    # Anything collected from the output queue that is NOT our RPC result
    # (e.g. pending inference responses or status messages) is buffered
    # here and put back afterwards instead of being silently dropped.
    unrelated: list = []
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = None if timeout is None else time.monotonic() + timeout
    try:
        while True:
            if deadline is not None and time.monotonic() > deadline:
                raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")

            msg = self.try_collect()
            if msg is None:
                time.sleep(0.001)  # Small sleep to avoid busy waiting
                continue

            # Guard with isinstance: try_collect may yield non-dict payloads.
            is_match = (
                isinstance(msg, dict)
                and msg.get("type") == "collective_rpc_result"
                and msg.get("rpc_id") == rpc_id
            )
            if not is_match:
                unrelated.append(msg)
                continue

            if "error" in msg:
                raise RuntimeError(f"collective_rpc failed: {msg['error']}")
            return msg["result"]
    finally:
        # Return messages that belong to other consumers to the queue
        # (on success, timeout, or error alike).
        for other in unrelated:
            self._out_q.put(other)
|
|
||
|
|
||
| def _stage_worker( | ||
| model: str, | ||
|
|
@@ -448,12 +509,35 @@ def filter(self, record: _logging.LogRecord) -> bool: | |
|
|
||
| # Batch processing loop | ||
| while True: | ||
|
|
||
| task = in_q.get() | ||
| _recv_dequeue_ts = _time.time() | ||
| if task is None: | ||
| _logging.getLogger(__name__).error("[Stage-%s] Received shutdown signal", stage_id) | ||
| break | ||
|
|
||
| # Handle collective_rpc requests | ||
| if isinstance(task, dict) and task.get("type") == "collective_rpc": | ||
| rpc_id = task.get("rpc_id") | ||
| method = task.get("method") | ||
| timeout = task.get("timeout") | ||
| args = task.get("args") | ||
| kwargs = task.get("kwargs") | ||
| try: | ||
| result = stage_engine.collective_rpc(method, timeout, args, kwargs) | ||
| out_q.put({ | ||
|
Comment on lines
+526
to
+528
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The worker loop dispatches collective RPCs to Useful? React with 👍 / 👎. |
||
| "type": "collective_rpc_result", | ||
| "rpc_id": rpc_id, | ||
| "result": result, | ||
| }) | ||
| except Exception as e: | ||
| out_q.put({ | ||
| "type": "collective_rpc_result", | ||
| "rpc_id": rpc_id, | ||
| "error": str(e), | ||
| }) | ||
| continue | ||
|
|
||
| max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) | ||
| print(f"[Stage-{stage_id}] Max batch size: {max_batch_size}") | ||
| batch_tasks: list[dict[str, Any]] = [task] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While waiting for a matching `collective_rpc_result`, the new `collective_rpc` method pops any item from `self._out_q` via `try_collect()` and ignores everything that is not the target RPC result. That silently drops unrelated stage outputs (e.g., pending inference responses or status messages), so issuing a collective RPC while other work is in-flight will lose those messages and leave callers hanging without responses. Useful? React with 👍 / 👎.