diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index 41dda9e6f..30be322a3 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -4,7 +4,9 @@
 import multiprocessing as mp
 import time
 import weakref
+from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Any
 
 from vllm.logger import init_logger
 
@@ -195,6 +197,77 @@ def _launch_workers(self, broadcast_handle):
     def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
         return scheduler.add_req(requests)
 
+    def collective_rpc(
+        self,
+        method: str | Callable,
+        timeout: float | None = None,
+        args: tuple = (),
+        kwargs: dict | None = None,
+        unique_reply_rank: int | None = None,
+    ) -> Any:
+        """Call a method on worker processes and wait for the results.
+
+        Args:
+            method: The method name (str) or callable to execute on workers
+            timeout: Optional timeout in seconds
+            args: Positional arguments for the method
+            kwargs: Keyword arguments for the method
+            unique_reply_rank: If set, only get the reply from this rank
+
+        Returns:
+            A single result if unique_reply_rank is set, otherwise a list of results
+        """
+        if self._closed:
+            raise RuntimeError("DiffusionEngine is closed.")
+
+        deadline = None if timeout is None else time.monotonic() + timeout
+        kwargs = kwargs or {}
+
+        # Only string method names are supported for now; workers resolve them via getattr
+        assert isinstance(method, str)
+        send_method = method
+
+        # Prepare RPC request message
+        rpc_request = {
+            "type": "rpc",
+            "method": send_method,
+            "args": args,
+            "kwargs": kwargs,
+            "output_rank": unique_reply_rank,
+        }
+
+        try:
+            # Broadcast RPC request to all workers via the unified message queue
+            scheduler.mq.enqueue(rpc_request)
+
+            # Determine how many workers we expect responses from
+            num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus
+
+            responses = []
+            for _ in range(num_responses):
+                dequeue_timeout = None if deadline is None else (deadline - time.monotonic())
+                try:
+                    if scheduler.result_mq is None:
+                        raise RuntimeError("Result queue not initialized")
+
+                    response = scheduler.result_mq.dequeue(timeout=dequeue_timeout)
+
+                    # Check whether the response indicates an error
+                    if isinstance(response, dict) and response.get("status") == "error":
+                        raise RuntimeError(
+                            f"Worker failed with error '{response.get('error')}', "
+                            "please check the stack trace above for the root cause"
+                        )
+
+                    responses.append(response)
+                except TimeoutError as e:
+                    raise TimeoutError(f"RPC call to {method} timed out.") from e
+
+            return responses[0] if unique_reply_rank is not None else responses
+
+        except Exception as e:
+            logger.error(f"RPC call failed: {e}")
+            raise
+
     def _dummy_run(self):
         """A dummy run to warm up the model."""
         prompt = "dummy run"
diff --git a/vllm_omni/diffusion/scheduler.py b/vllm_omni/diffusion/scheduler.py
index 4d45812b5..7f7902b80 100644
--- a/vllm_omni/diffusion/scheduler.py
+++ b/vllm_omni/diffusion/scheduler.py
@@ -29,7 +29,7 @@ def initialize(self, od_config: OmniDiffusionConfig):
         self.od_config = od_config
         self.context = zmq.Context()  # Standard synchronous context
 
-        # Initialize MessageQueue for broadcasting requests
+        # Initialize a single MessageQueue for all message types (generation & RPC)
         # Assuming all readers are local for now as per current launch_engine implementation
         self.mq = MessageQueue(
             n_reader=self.num_workers,
@@ -51,9 +51,20 @@ def get_broadcast_handle(self):
     def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
         """Sends a request to the scheduler and waits for the response."""
         try:
-            # Broadcast request to all workers
-            self.mq.enqueue(requests)
+            # Prepare RPC request for generation (all ranks execute; only rank 0 replies)
+            rpc_request = {
+                "type": "rpc",
+                "method": "generate",
+                "args": (requests,),
+                "kwargs": {},
+                "output_rank": 0,
+                "exec_all_ranks": True,
+            }
+
+            # Broadcast RPC request to all workers
+            self.mq.enqueue(rpc_request)
 
             # Wait for result from Rank 0 (or whoever sends it)
+
             if self.result_mq is None:
                 raise RuntimeError("Result queue not initialized")
diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py
index dfaafa5de..d05f96358 100644
--- a/vllm_omni/diffusion/worker/gpu_worker.py
+++ b/vllm_omni/diffusion/worker/gpu_worker.py
@@ -13,7 +13,6 @@
 from vllm_omni.diffusion.cache.selector import get_cache_backend
 from vllm_omni.diffusion.data import (
-    SHUTDOWN_MESSAGE,
     DiffusionOutput,
     OmniDiffusionConfig,
     set_current_omni_diffusion_config,
 )
@@ -107,6 +106,18 @@ def init_device_and_model(self) -> None:
         if self.cache_backend is not None:
             self.cache_backend.enable(self.pipeline)
 
+    def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
+        """
+        Generate output for the given requests.
+
+        Args:
+            requests: List of diffusion requests
+
+        Returns:
+            DiffusionOutput with the generated results
+        """
+        return self.execute_model(requests, self.od_config)
+
     @torch.inference_mode()
     def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
         """
@@ -141,7 +152,7 @@ def __init__(
         # Inter-process Communication
         self.context = zmq.Context(io_threads=2)
 
-        # Initialize MessageQueue reader from handle
+        # Initialize MessageQueue reader from handle (unified for generation & RPC)
         self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)
 
         self.result_mq = None
@@ -173,12 +184,39 @@ def return_result(self, output: DiffusionOutput):
         if self.result_mq is not None:
             self.result_mq.enqueue(output)
 
-    def recv_reqs(self):
+    def recv_message(self):
         """
-        Receive requests from broadcast queue
+        Receive unified messages (RPC requests, shutdown) from the broadcast queue.
+        Uses indefinite=True to block until a message arrives.
         """
         return self.mq.dequeue(indefinite=True)
 
+    def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]:
+        """Execute an RPC request and indicate whether to reply."""
+
+        method = rpc_request["method"]
+        args = rpc_request.get("args", ())
+        kwargs = rpc_request.get("kwargs", {})
+        output_rank = rpc_request.get("output_rank")
+        exec_all_ranks = rpc_request.get("exec_all_ranks", False)
+
+        should_execute = exec_all_ranks or output_rank is None or output_rank == self.gpu_id
+        should_reply = (output_rank is None or output_rank == self.gpu_id) and self.result_mq is not None
+
+        if not should_execute:
+            return None, False
+
+        try:
+            if isinstance(method, str):
+                func = getattr(self.worker, method)
+                result = func(*args, **kwargs)
+            else:
+                result = method(self.worker, *args, **kwargs)
+            return result, should_reply
+        except Exception as e:
+            logger.error(f"Error executing RPC: {e}", exc_info=True)
+            return {"status": "error", "error": str(e)}, should_reply
+
     # TODO: queueing, cancellation
     def worker_busy_loop(self) -> None:
         """Main busy loop for Multiprocessing Workers"""
@@ -186,42 +224,57 @@ def worker_busy_loop(self) -> None:
         logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")
 
         while self._running:
-            reqs = None
-            # 1: receive requests
+            # Receive a unified message (generation request, RPC request, or shutdown)
+            msg = None
             try:
-                reqs = self.recv_reqs()
+                msg = self.recv_message()
             except Exception as e:
                 logger.error(
-                    f"Error receiving requests in scheduler event loop: {e}",
+                    f"Error receiving message in worker loop: {e}",
                     exc_info=True,
                 )
                 continue
 
-            if reqs == SHUTDOWN_MESSAGE:
-                logger.info("Worker %s: Received shutdown message", self.gpu_id)
-                self._running = False
-                continue
-            if reqs is None:
+            if msg is None:
                 logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
                 continue
 
-            # 2: execute, make sure a reply is always sent
-            try:
-                output = self.worker.execute_model(reqs, self.od_config)
-            except Exception as e:
-                logger.error(
-                    f"Error executing forward in event loop: {e}",
-                    exc_info=True,
-                )
-                output = DiffusionOutput(error=str(e))
-
-            try:
-                self.return_result(output)
-            except zmq.ZMQError as e:
-                # Reply failed; log and keep loop alive to accept future requests
-                logger.error(f"ZMQ error sending reply: {e}")
+            # Route the message based on its type
+            if isinstance(msg, dict) and msg.get("type") == "rpc":
+                # Handle RPC request
+                try:
+                    result, should_reply = self.execute_rpc(msg)
+                    if should_reply:
+                        self.return_result(result)
+                except Exception as e:
+                    logger.error(f"Error processing RPC: {e}", exc_info=True)
+                    if self.result_mq is not None:
+                        self.return_result({"status": "error", "error": str(e)})
+
+            elif isinstance(msg, dict) and msg.get("type") == "shutdown":
+                # Handle shutdown message
+                logger.info("Worker %s: Received shutdown message", self.gpu_id)
+                self._running = False
                 continue
+            else:
+                # Handle generation request (list of OmniDiffusionRequest)
+                try:
+                    output = self.worker.execute_model(msg, self.od_config)
+                except Exception as e:
+                    logger.error(
+                        f"Error executing forward in event loop: {e}",
+                        exc_info=True,
+                    )
+                    output = DiffusionOutput(error=str(e))
+
+                try:
+                    self.return_result(output)
+                except zmq.ZMQError as e:
+                    # Reply failed; log and keep loop alive to accept future requests
+                    logger.error(f"ZMQ error sending reply: {e}")
+                    continue
+
         logger.info("event loop terminated.")
 
         try:
             self.worker.shutdown()