-
Notifications
You must be signed in to change notification settings - Fork 238
RPC support for OmniLLM #355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| import logging | ||
| import multiprocessing as mp | ||
| import os | ||
| from typing import Any, Optional, Union | ||
| from typing import Any, Callable, Optional, TypeVar, Union | ||
|
|
||
| from vllm.inputs import TextPrompt | ||
| from vllm.inputs.preprocess import InputPreprocessor | ||
|
|
@@ -38,6 +38,8 @@ | |
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| _R = TypeVar("_R") | ||
|
|
||
|
|
||
| class OmniStage: | ||
| """Stage manager for orchestrating a single stage in the omni pipeline. | ||
|
|
@@ -320,6 +322,63 @@ def process_engine_inputs( | |
| stage_list, engine_input_source, prompt, self.requires_multimodal_data | ||
| ) | ||
|
|
||
| def collective_rpc( | ||
| self, | ||
| method: str | Callable[..., _R], | ||
| timeout: float | None = None, | ||
| args: tuple = (), | ||
| kwargs: dict[str, Any] | None = None, | ||
| ) -> list[_R]: | ||
| """Execute an RPC call on all workers via the stage engine. | ||
|
|
||
| Args: | ||
| method: Name of the worker method to execute, or a callable that | ||
| is serialized and sent to all workers to execute. | ||
|
|
||
| If the method is a callable, it should accept an additional | ||
| `self` argument, in addition to the arguments passed in `args` | ||
| and `kwargs`. The `self` argument will be the worker object. | ||
| timeout: Maximum time in seconds to wait for execution. Raises a | ||
| [`TimeoutError`][] on timeout. `None` means wait indefinitely. | ||
| args: Positional arguments to pass to the worker method. | ||
| kwargs: Keyword arguments to pass to the worker method. | ||
|
|
||
| Returns: | ||
| A list containing the results from each worker. | ||
|
|
||
| Note: | ||
| It is recommended to use this API to only pass control messages, | ||
| and set up data-plane communication to pass data. | ||
| """ | ||
| assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc" | ||
|
|
||
| # Submit collective_rpc task to worker | ||
| import uuid | ||
| rpc_id = str(uuid.uuid4()) | ||
| self._in_q.put({ | ||
| "type": "collective_rpc", | ||
| "rpc_id": rpc_id, | ||
| "method": method, | ||
| "timeout": timeout, | ||
| "args": args, | ||
| "kwargs": kwargs, | ||
| }) | ||
|
|
||
| # Wait for result from worker | ||
| import time | ||
| start_time = time.time() | ||
| while True: | ||
| if timeout is not None and (time.time() - start_time) > timeout: | ||
| raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") | ||
|
|
||
| result = self.try_collect() | ||
| if result is not None and result.get("type") == "collective_rpc_result" and result.get("rpc_id") == rpc_id: | ||
| if "error" in result: | ||
| raise RuntimeError(f"collective_rpc failed: {result['error']}") | ||
| return result["result"] | ||
|
||
|
|
||
| time.sleep(0.001) # Small sleep to avoid busy waiting | ||
|
|
||
|
|
||
| def _stage_worker( | ||
| model: str, | ||
|
|
@@ -448,12 +507,41 @@ def filter(self, record: _logging.LogRecord) -> bool: | |
|
|
||
| # Batch processing loop | ||
| while True: | ||
|
|
||
| task = in_q.get() | ||
| _recv_dequeue_ts = _time.time() | ||
| if task is None: | ||
| _logging.getLogger(__name__).error("[Stage-%s] Received shutdown signal", stage_id) | ||
| break | ||
|
|
||
| # Handle collective_rpc requests | ||
| if isinstance(task, dict) and task.get("type") == "collective_rpc": | ||
| rpc_id = task.get("rpc_id") | ||
| method = task.get("method") | ||
| timeout = task.get("timeout") | ||
| args = task.get("args", ()) | ||
| kwargs = task.get("kwargs") | ||
| try: | ||
| _logging.getLogger(__name__).debug( | ||
| "[Stage-%s] Executing collective_rpc: method=%s", stage_id, method | ||
| ) | ||
| result = stage_engine.collective_rpc(method, timeout, args, kwargs) | ||
| out_q.put({ | ||
|
||
| "type": "collective_rpc_result", | ||
| "rpc_id": rpc_id, | ||
| "result": result, | ||
| }) | ||
| except Exception as e: | ||
| _logging.getLogger(__name__).exception( | ||
| "[Stage-%s] collective_rpc failed: %s", stage_id, e | ||
| ) | ||
| out_q.put({ | ||
| "type": "collective_rpc_result", | ||
| "rpc_id": rpc_id, | ||
| "error": str(e), | ||
| }) | ||
| continue | ||
|
|
||
| max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) | ||
| print(f"[Stage-{stage_id}] Max batch size: {max_batch_size}") | ||
| batch_tasks: list[dict[str, Any]] = [task] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we use lazy import here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, it should be a draft [WIP] but thanks for the review. Will fix it