73 changes: 73 additions & 0 deletions vllm_omni/diffusion/diffusion_engine.py
@@ -4,7 +4,9 @@
import multiprocessing as mp
import time
import weakref
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from vllm.logger import init_logger

@@ -195,6 +197,77 @@ def _launch_workers(self, broadcast_handle):
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
return scheduler.add_req(requests)

def collective_rpc(
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
unique_reply_rank: int | None = None,
) -> Any:
"""Call a method on worker processes and get results immediately.

Args:
method: The method name (str) or callable to execute on workers
timeout: Optional timeout in seconds
args: Positional arguments for the method
kwargs: Keyword arguments for the method
unique_reply_rank: If set, only get reply from this rank

Returns:
Single result if unique_reply_rank is provided, otherwise list of results
"""
if self._closed:
raise RuntimeError("DiffusionEngine is closed.")

deadline = None if timeout is None else time.monotonic() + timeout
kwargs = kwargs or {}

# Only string method names are currently supported; callables are
# rejected by this assert even though the signature allows them.
assert isinstance(method, str)
send_method = method

# Prepare RPC request message
rpc_request = {
"type": "rpc",
"method": send_method,
"args": args,
"kwargs": kwargs,
"output_rank": unique_reply_rank,
}

try:
# Broadcast RPC request to all workers via unified message queue
scheduler.mq.enqueue(rpc_request)

# Determine which workers we expect responses from
num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus

responses = []
for _ in range(num_responses):
dequeue_timeout = None if deadline is None else (deadline - time.monotonic())
try:
if scheduler.result_mq is None:
raise RuntimeError("Result queue not initialized")

response = scheduler.result_mq.dequeue(timeout=dequeue_timeout)

# Check if response indicates an error
if isinstance(response, dict) and response.get("status") == "error":
raise RuntimeError(
f"Worker failed with error '{response.get('error')}', "
"please check the stack trace above for the root cause"
)

responses.append(response)
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e

return responses[0] if unique_reply_rank is not None else responses

except Exception as e:
logger.error(f"RPC call failed: {e}")
raise

def _dummy_run(self):
"""A dummy run to warm up the model."""
prompt = "dummy run"
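A minimal usage sketch of the new collective_rpc API (not part of this diff): it assumes an already-constructed DiffusionEngine instance named engine and a requests list of OmniDiffusionRequest objects; generate is the worker method added in gpu_worker.py below.

# Sketch only: engine and requests are assumed to exist.
# Broadcast the call to all workers and collect one reply per GPU.
results = engine.collective_rpc("generate", args=(requests,))

# With unique_reply_rank set, only that rank executes the method and
# enqueues a reply, so a single result is returned.
output = engine.collective_rpc(
    "generate",
    args=(requests,),
    unique_reply_rank=0,
    timeout=60.0,
)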
17 changes: 14 additions & 3 deletions vllm_omni/diffusion/scheduler.py
@@ -29,7 +29,7 @@ def initialize(self, od_config: OmniDiffusionConfig):
self.od_config = od_config
self.context = zmq.Context() # Standard synchronous context

# Initialize MessageQueue for broadcasting requests
# Initialize single MessageQueue for all message types (generation & RPC)
# Assuming all readers are local for now as per current launch_engine implementation
self.mq = MessageQueue(
n_reader=self.num_workers,
@@ -51,9 +51,20 @@ def get_broadcast_handle(self):
def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
"""Sends a request to the scheduler and waits for the response."""
try:
# Broadcast request to all workers
self.mq.enqueue(requests)
# Prepare RPC request for generation
rpc_request = {
"type": "rpc",
"method": "generate",
"args": (requests,),
"kwargs": {},
"output_rank": 0,
"exec_all_ranks": True,
}

# Broadcast RPC request to all workers
self.mq.enqueue(rpc_request)
# Wait for the result from rank 0 (the output_rank set above)

if self.result_mq is None:
raise RuntimeError("Result queue not initialized")

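For reference, the message shapes carried over the unified broadcast queue, as implied by this diff; the shutdown payload is inferred from the msg.get("type") == "shutdown" check in gpu_worker.py below.

# Illustrative only: the generation request as wrapped by Scheduler.add_req().
generation_msg = {
    "type": "rpc",
    "method": "generate",    # worker method to invoke
    "args": (requests,),     # requests is a list[OmniDiffusionRequest]
    "kwargs": {},
    "output_rank": 0,        # only rank 0 enqueues the result
    "exec_all_ranks": True,  # every rank runs the method
}

# Illustrative only: the shutdown message handled by worker_busy_loop().
shutdown_msg = {"type": "shutdown"}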
109 changes: 81 additions & 28 deletions vllm_omni/diffusion/worker/gpu_worker.py
@@ -13,7 +13,6 @@

from vllm_omni.diffusion.cache.selector import get_cache_backend
from vllm_omni.diffusion.data import (
SHUTDOWN_MESSAGE,
DiffusionOutput,
OmniDiffusionConfig,
set_current_omni_diffusion_config,
@@ -107,6 +106,18 @@ def init_device_and_model(self) -> None:
if self.cache_backend is not None:
self.cache_backend.enable(self.pipeline)

def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
"""
Generate output for the given requests.

Args:
requests: List of diffusion requests

Returns:
DiffusionOutput with generated results
"""
return self.execute_model(requests, self.od_config)

@torch.inference_mode()
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
"""
Expand Down Expand Up @@ -141,7 +152,7 @@ def __init__(
# Inter-process Communication
self.context = zmq.Context(io_threads=2)

# Initialize MessageQueue reader from handle
# Initialize MessageQueue reader from handle (unified for generation & RPC)
self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)

self.result_mq = None
@@ -173,55 +184,97 @@ def return_result(self, output: DiffusionOutput):
if self.result_mq is not None:
self.result_mq.enqueue(output)

def recv_reqs(self):
def recv_message(self):
"""
Receive requests from broadcast queue
Receive unified messages (RPC requests, shutdown) from broadcast queue.
Uses indefinite=True to block until a message arrives.
"""
return self.mq.dequeue(indefinite=True)

def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]:
"""Execute an RPC request and indicate whether to reply."""

method = rpc_request["method"]
args = rpc_request.get("args", ())
kwargs = rpc_request.get("kwargs", {})
output_rank = rpc_request.get("output_rank")
exec_all_ranks = rpc_request.get("exec_all_ranks", False)

should_execute = exec_all_ranks or output_rank is None or output_rank == self.gpu_id
should_reply = (output_rank is None or output_rank == self.gpu_id) and self.result_mq is not None

if not should_execute:
return None, False

try:
if isinstance(method, str):
func = getattr(self.worker, method)
result = func(*args, **kwargs)
else:
result = method(self.worker, *args, **kwargs)
return result, should_reply
except Exception as e:
logger.error(f"Error executing RPC: {e}", exc_info=True)
return {"status": "error", "error": str(e)}, should_reply

# TODO: queueing, cancellation
def worker_busy_loop(self) -> None:
Collaborator:
I think we should redesign this, because it has become more complex.
We can address the redesign in a separate follow-up task if you don't have time.
Here is a good example of worker_busy_loop: https://github.com/vllm-project/vllm/blob/c02a2705f9ceeb00b5d32453621f997b2ceafbea/vllm/v1/executor/multiproc_executor.py#L806

Contributor Author:
Agreed; the redesign is WIP and will need a more structured RFC.

Collaborator:
Just confirming — is this already WIP?

Contributor Author (@knlnguyen1802, Dec 22, 2025):
This PR is not WIP, but the redesign you mentioned above is in progress.

Contributor Author:
@ZJY0516 It's ready now. Could you take a look again? Thanks.

"""Main busy loop for Multiprocessing Workers"""

logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")

while self._running:
reqs = None
# 1: receive requests
# Receive unified message (generation request, RPC request, or shutdown)
msg = None
try:
reqs = self.recv_reqs()
msg = self.recv_message()
except Exception as e:
logger.error(
f"Error receiving requests in scheduler event loop: {e}",
f"Error receiving message in worker loop: {e}",
exc_info=True,
)
continue

if reqs == SHUTDOWN_MESSAGE:
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue
if reqs is None:
if msg is None:
logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
continue

# 2: execute, make sure a reply is always sent
try:
output = self.worker.execute_model(reqs, self.od_config)
except Exception as e:
logger.error(
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))

try:
self.return_result(output)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
# Route message based on type
if isinstance(msg, dict) and msg.get("type") == "rpc":
# Handle RPC request
try:
result, should_reply = self.execute_rpc(msg)
if should_reply:
self.return_result(result)
except Exception as e:
logger.error(f"Error processing RPC: {e}", exc_info=True)
if self.result_mq is not None:
self.return_result({"status": "error", "error": str(e)})

elif isinstance(msg, dict) and msg.get("type") == "shutdown":
# Handle shutdown message
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue

else:
# Handle generation request (OmniDiffusionRequest list)
try:
output = self.worker.execute_model(msg, self.od_config)
except Exception as e:
logger.error(
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))

try:
self.return_result(output)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
continue

logger.info("event loop terminated.")
try:
self.worker.shutdown()
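To summarize the per-rank routing semantics of execute_rpc and the two reply paths above, a small self-contained sketch (illustrative only; rank values are arbitrary examples):

# Illustrative only: mirrors the should_execute / should_reply logic in execute_rpc.
def rpc_routing(gpu_id: int, output_rank: int | None, exec_all_ranks: bool,
                has_result_mq: bool = True) -> tuple[bool, bool]:
    should_execute = exec_all_ranks or output_rank is None or output_rank == gpu_id
    should_reply = (output_rank is None or output_rank == gpu_id) and has_result_mq
    return should_execute, should_reply

# collective_rpc with unique_reply_rank=0 on a 2-GPU setup:
assert rpc_routing(gpu_id=0, output_rank=0, exec_all_ranks=False) == (True, True)
assert rpc_routing(gpu_id=1, output_rank=0, exec_all_ranks=False) == (False, False)

# Generation via Scheduler.add_req() (output_rank=0, exec_all_ranks=True):
assert rpc_routing(gpu_id=0, output_rank=0, exec_all_ranks=True) == (True, True)
assert rpc_routing(gpu_id=1, output_rank=0, exec_all_ranks=True) == (True, False)

# collective_rpc with unique_reply_rank=None: every rank executes and replies.
assert rpc_routing(gpu_id=1, output_rank=None, exec_all_ranks=False) == (True, True)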