Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
import multiprocessing as mp
import os
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, TypeVar, Union

from vllm.inputs import TextPrompt
from vllm.inputs.preprocess import InputPreprocessor
Expand All @@ -38,6 +38,8 @@

logger = init_logger(__name__)

_R = TypeVar("_R")


class OmniStage:
"""Stage manager for orchestrating a single stage in the omni pipeline.
Expand Down Expand Up @@ -320,6 +322,63 @@ def process_engine_inputs(
stage_list, engine_input_source, prompt, self.requires_multimodal_data
)

def collective_rpc(
    self,
    method: str | Callable[..., _R],
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict[str, Any] | None = None,
) -> list[_R]:
    """Execute an RPC call on all workers via the stage engine.

    Args:
        method: Name of the worker method to execute, or a callable that
            is serialized and sent to all workers to execute.

            If the method is a callable, it should accept an additional
            `self` argument, in addition to the arguments passed in `args`
            and `kwargs`. The `self` argument will be the worker object.
            NOTE(review): a callable must be picklable to cross the
            process queue — confirm against the worker transport.
        timeout: Maximum time in seconds to wait for execution. Raises a
            [`TimeoutError`][] on timeout. `None` means wait indefinitely.
        args: Positional arguments to pass to the worker method.
        kwargs: Keyword arguments to pass to the worker method.

    Returns:
        A list containing the results from each worker.

    Raises:
        RuntimeError: If the queues are not attached, or the worker
            reports an error while executing the RPC.
        TimeoutError: If no matching result arrives within `timeout`.

    Note:
        It is recommended to use this API to only pass control messages,
        and set up data-plane communication to pass data.
    """
    # Runtime validation, not assert: asserts are stripped under `-O`.
    if self._in_q is None or self._out_q is None:
        raise RuntimeError("Queues must be attached before collective_rpc")

    # TODO: hoist these to module-level imports once the module header
    # is touched; function-local here to keep this edit self-contained.
    import time
    import uuid

    # Submit collective_rpc task to the worker, tagged with a unique id
    # so the matching result can be picked out of the shared out-queue.
    rpc_id = str(uuid.uuid4())
    self._in_q.put({
        "type": "collective_rpc",
        "rpc_id": rpc_id,
        "method": method,
        "timeout": timeout,
        "args": args,
        "kwargs": kwargs,
    })

    # Monotonic deadline: immune to wall-clock adjustments, unlike
    # the previous time.time()-based accounting.
    deadline = None if timeout is None else time.monotonic() + timeout

    # Items pulled from the out-queue that are NOT our RPC result are
    # buffered and requeued, never dropped — otherwise normal stage
    # outputs arriving while the RPC is pending would be lost and their
    # requests would hang.
    pending: list[Any] = []
    try:
        while True:
            if deadline is not None and time.monotonic() > deadline:
                raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")

            result = self.try_collect()
            if result is None:
                time.sleep(0.001)  # Small sleep to avoid busy waiting
                continue

            if (
                isinstance(result, dict)
                and result.get("type") == "collective_rpc_result"
                and result.get("rpc_id") == rpc_id
            ):
                if "error" in result:
                    raise RuntimeError(f"collective_rpc failed: {result['error']}")
                return result["result"]

            # Unrelated message (e.g. a normal stage output): keep it.
            pending.append(result)
    finally:
        # Requeue everything we intercepted, on success, error, or timeout.
        for item in pending:
            self._out_q.put(item)


def _stage_worker(
model: str,
Expand Down Expand Up @@ -448,12 +507,41 @@ def filter(self, record: _logging.LogRecord) -> bool:

# Batch processing loop
while True:

task = in_q.get()
_recv_dequeue_ts = _time.time()
if task is None:
_logging.getLogger(__name__).error("[Stage-%s] Received shutdown signal", stage_id)
break

# Handle collective_rpc requests
if isinstance(task, dict) and task.get("type") == "collective_rpc":
rpc_id = task.get("rpc_id")
method = task.get("method")
timeout = task.get("timeout")
args = task.get("args", ())
kwargs = task.get("kwargs")
try:
_logging.getLogger(__name__).debug(
"[Stage-%s] Executing collective_rpc: method=%s", stage_id, method
)
result = stage_engine.collective_rpc(method, timeout, args, kwargs)
out_q.put({

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Worker RPC handler calls missing method

The collective_rpc path in the worker invokes stage_engine.collective_rpc(...), but the stage engine is an OmniStageLLM (created at line 500) and there is no collective_rpc implementation anywhere in that class or the rest of the repo. Every collective_rpc task will therefore raise an AttributeError and be returned as an error response, so the new API never actually executes on the worker.

Useful? React with 👍 / 👎.

"type": "collective_rpc_result",
"rpc_id": rpc_id,
"result": result,
})
except Exception as e:
_logging.getLogger(__name__).exception(
"[Stage-%s] collective_rpc failed: %s", stage_id, e
)
out_q.put({
"type": "collective_rpc_result",
"rpc_id": rpc_id,
"error": str(e),
})
continue

max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
print(f"[Stage-{stage_id}] Max batch size: {max_batch_size}")
batch_tasks: list[dict[str, Any]] = [task]
Expand Down