46 changes: 46 additions & 0 deletions tests/e2e/test_rpc_collective.py
@@ -0,0 +1,46 @@
import os
@ZJY0516 (Collaborator), Dec 22, 2025:
The test is a little strange here. I don't think we need an e2e test here. We can test it after #376 lands. cc @SamitHuang

@SamitHuang (Collaborator), Dec 23, 2025:
Yes, tests for this PR can be covered by the tests in #376.

Collaborator:
@knlnguyen1802 Could you please remove this file? Once that's done, we can merge this PR.

import sys
from pathlib import Path

import pytest
import torch

from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.request import OmniDiffusionRequest

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

models = ["Tongyi-MAI/Z-Image-Turbo"]


@pytest.mark.parametrize("model_name", models)
def test_diffusion_model(model_name: str):
m = Omni(model=model_name)
# high resolution may cause OOM on L4
height = 256
width = 256
request = OmniDiffusionRequest(
prompt="a photo of a cat sitting on a laptop keyboard",
height=height,
width=width,
num_inference_steps=2,
guidance_scale=0.0,
generator=torch.Generator("cuda").manual_seed(42),
num_outputs_per_prompt=2,
)
results = m.instance.engine.collective_rpc(
method="generate",
args=([request],),
kwargs={},
unique_reply_rank=0,
)
assert isinstance(results, DiffusionOutput)
assert results.output.shape[2] == width
assert results.output.shape[3] == height
84 changes: 81 additions & 3 deletions vllm_omni/diffusion/diffusion_engine.py
@@ -2,11 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import multiprocessing as mp
import pickle
import time
from collections.abc import Callable
from typing import Any

import cloudpickle
from vllm.logger import init_logger

from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, OmniDiffusionConfig
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.registry import get_diffusion_post_process_func, get_diffusion_pre_process_func
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.scheduler import scheduler
@@ -151,16 +155,90 @@ def _launch_workers(self, broadcast_handle):
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
return scheduler.add_req(requests)

def collective_rpc(
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
unique_reply_rank: int | None = None,
) -> Any:
"""Call a method on worker processes and get results immediately.

Args:
method: The method name (str) or callable to execute on workers
timeout: Optional timeout in seconds
args: Positional arguments for the method
kwargs: Keyword arguments for the method
unique_reply_rank: If set, only get reply from this rank

Returns:
Single result if unique_reply_rank is provided, otherwise list of results
"""
if self._closed:
raise RuntimeError("DiffusionEngine is closed.")

deadline = None if timeout is None else time.monotonic() + timeout
kwargs = kwargs or {}

# Prepare the method to send
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)

# Prepare RPC request message
rpc_request = {
"type": "rpc",
"method": send_method,
"args": args,
"kwargs": kwargs,
"output_rank": unique_reply_rank,
}

try:
# Broadcast RPC request to all workers via unified message queue
scheduler.mq.enqueue(rpc_request)

# Determine which workers we expect responses from
num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus

responses = []
for _ in range(num_responses):
dequeue_timeout = None if deadline is None else (deadline - time.monotonic())
try:
if scheduler.result_mq is None:
raise RuntimeError("Result queue not initialized")

response = scheduler.result_mq.dequeue(timeout=dequeue_timeout)

# Check if response indicates an error
if isinstance(response, dict) and response.get("status") == "error":
raise RuntimeError(
f"Worker failed with error '{response.get('error')}', "
"please check the stack trace above for the root cause"
)

responses.append(response)
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e

return responses[0] if unique_reply_rank is not None else responses

except Exception as e:
logger.error(f"RPC call failed: {e}")
raise

def close(self, *, timeout_s: float = 30.0) -> None:
if self._closed:
return
self._closed = True

# Send shutdown signal to worker processes via broadcast queue
# Send shutdown signal to worker processes via unified broadcast queue
try:
if getattr(scheduler, "mq", None) is not None:
for _ in range(self.od_config.num_gpus or 1):
scheduler.mq.enqueue(SHUTDOWN_MESSAGE)
scheduler.mq.enqueue({"type": "shutdown"})
except Exception as exc: # pragma: no cover - best effort cleanup
logger.warning("Failed to send shutdown signal: %s", exc)

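Usage note (a minimal sketch, not part of this PR): the e2e test above calls collective_rpc with a string method name, but per the code path above a plain callable is also accepted; it is cloudpickled and, per the gpu_worker changes below, invoked on the worker as fn(worker, *args, **kwargs). The sketch assumes the Omni/engine attribute path used in the e2e test (m.instance.engine) and a single-GPU run; the helper _gpu_name is hypothetical.

from vllm_omni import Omni


def _gpu_name(worker, note: str) -> str:
    # Hypothetical helper: runs inside the GPU worker process; `worker` is the
    # worker instance that execute_rpc passes as the first positional argument.
    import torch

    return f"{note}: cuda:{torch.cuda.current_device()} ({torch.cuda.get_device_name()})"


m = Omni(model="Tongyi-MAI/Z-Image-Turbo")
name = m.instance.engine.collective_rpc(
    method=_gpu_name,       # callables are cloudpickled by collective_rpc
    args=("rank-0 device",),
    timeout=60,
    unique_reply_rank=0,    # collect the reply from rank 0 only
)
print(name)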
16 changes: 13 additions & 3 deletions vllm_omni/diffusion/scheduler.py
@@ -28,7 +28,7 @@ def initialize(self, od_config: OmniDiffusionConfig):
self.od_config = od_config
self.context = zmq.Context() # Standard synchronous context

# Initialize MessageQueue for broadcasting requests
# Initialize single MessageQueue for all message types (generation & RPC)
# Assuming all readers are local for now as per current launch_engine implementation
self.mq = MessageQueue(
n_reader=od_config.num_gpus,
@@ -50,9 +50,19 @@ def get_broadcast_handle(self):
def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
"""Sends a request to the scheduler and waits for the response."""
try:
# Broadcast request to all workers
self.mq.enqueue(requests)
# Prepare RPC request for generation
rpc_request = {
"type": "rpc",
"method": "generate",
"args": (requests,),
"kwargs": {},
"output_rank": 0,
}

# Broadcast RPC request to all workers
self.mq.enqueue(rpc_request)
# Wait for result from Rank 0 (or whoever sends it)

if self.result_mq is None:
raise RuntimeError("Result queue not initialized")

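Both scheduler.add_req above and DiffusionEngine.collective_rpc build the same message dict by hand ("type", "method", "args", "kwargs", "output_rank"). A minimal sketch of a shared envelope helper that could keep the two call sites and worker_busy_loop in sync; the RpcEnvelope name is hypothetical and not part of this PR.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class RpcEnvelope:
    """Hypothetical helper mirroring the dict shape the workers dispatch on."""

    method: str | bytes                # method name, or a cloudpickled callable
    args: tuple = ()
    kwargs: dict[str, Any] = field(default_factory=dict)
    output_rank: int | None = None     # rank expected to reply, or None for all

    def as_message(self) -> dict[str, Any]:
        # Matches the keys checked in worker_busy_loop / execute_rpc.
        return {
            "type": "rpc",
            "method": self.method,
            "args": self.args,
            "kwargs": self.kwargs,
            "output_rank": self.output_rank,
        }


# e.g. the generation path in add_req could then become:
# self.mq.enqueue(RpcEnvelope(method="generate", args=(requests,), output_rank=0).as_message())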
114 changes: 86 additions & 28 deletions vllm_omni/diffusion/worker/gpu_worker.py
@@ -4,6 +4,7 @@
import os
import time

import cloudpickle
import torch
import zmq
from vllm.config import LoadConfig, VllmConfig, set_current_vllm_config
@@ -17,7 +18,6 @@

from vllm_omni.diffusion.cache.selector import get_cache_backend
from vllm_omni.diffusion.data import (
SHUTDOWN_MESSAGE,
DiffusionOutput,
OmniDiffusionConfig,
)
@@ -91,6 +91,18 @@ def init_device_and_model(self) -> None:
if self.cache_backend is not None:
self.cache_backend.enable(self.pipeline)

def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
"""
Generate output for the given requests.

Args:
requests: List of diffusion requests

Returns:
DiffusionOutput with generated results
"""
return self.execute_model(requests, self.od_config)

@torch.inference_mode()
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
"""
@@ -130,7 +142,7 @@ def __init__(
# Inter-process Communication
self.context = zmq.Context(io_threads=2)

# Initialize MessageQueue reader from handle
# Initialize MessageQueue reader from handle (unified for generation & RPC)
self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)

self.result_mq = None
@@ -162,55 +174,101 @@ def return_result(self, output: DiffusionOutput):
if self.result_mq is not None:
self.result_mq.enqueue(output)

def recv_reqs(self):
def recv_message(self):
"""
Receive requests from broadcast queue
Receive unified messages (RPC requests, shutdown) from broadcast queue.
Uses indefinite=True to block until a message arrives.
"""
return self.mq.dequeue(indefinite=True)

def execute_rpc(self, rpc_request: dict):
"""Execute an RPC request and return the result."""
try:
method = rpc_request["method"]
args = rpc_request.get("args", ())
kwargs = rpc_request.get("kwargs", {})
output_rank = rpc_request.get("output_rank")

# Only execute if we should reply (either output_rank is None or matches our rank)
if output_rank is not None and output_rank != self.gpu_id:
return None

# Deserialize method if it's a callable
if isinstance(method, bytes):
method = cloudpickle.loads(method)

# Execute the method
if isinstance(method, str):
# Method is a string, call it on the worker
func = getattr(self.worker, method)
result = func(*args, **kwargs)
else:
# Method is a callable
result = method(self.worker, *args, **kwargs)

return result
except Exception as e:
logger.error(f"Error executing RPC: {e}", exc_info=True)
return {"status": "error", "error": str(e)}

# TODO: queueing, cancellation
def worker_busy_loop(self) -> None:
Collaborator:
I think we should redesign this, because it has become more complex. We can address the redesign in a separate follow-up task if you don't have time. Here is a good example of worker_busy_loop: https://github.com/vllm-project/vllm/blob/c02a2705f9ceeb00b5d32453621f997b2ceafbea/vllm/v1/executor/multiproc_executor.py#L806

Contributor (author):
Agreed; the redesign is WIP and will need a more structured RFC.

Collaborator:
Just confirming — is this already WIP?

@knlnguyen1802 (Contributor, author), Dec 22, 2025:
This PR is not WIP, but the redesign mentioned above is in progress.

Contributor (author):
@ZJY0516 It's ready now. Could you take another look? Thanks.

"""Main busy loop for Multiprocessing Workers"""

logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")

while self._running:
reqs = None
# 1: receive requests
# Receive unified message (generation request, RPC request, or shutdown)
msg = None
try:
reqs = self.recv_reqs()
msg = self.recv_message()
except Exception as e:
logger.error(
f"Error receiving requests in scheduler event loop: {e}",
f"Error receiving message in worker loop: {e}",
exc_info=True,
)
continue

if reqs == SHUTDOWN_MESSAGE:
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue
if reqs is None:
if msg is None:
logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
continue

# 2: execute, make sure a reply is always sent
try:
output = self.worker.execute_model(reqs, self.od_config)
except Exception as e:
logger.error(
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))

try:
self.return_result(output)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
# Route message based on type
if isinstance(msg, dict) and msg.get("type") == "rpc":
# Handle RPC request
try:
result = self.execute_rpc(msg)
if result is not None and self.gpu_id == 0:
self.return_result(result)
Review comment on lines 249 to 252 (P1): collective_rpc waits for replies other ranks never send.

collective_rpc expects a reply from each worker unless unique_reply_rank is set, but in worker_busy_loop only rank 0 enqueues RPC responses (the self.gpu_id == 0 gate) and other ranks drop their results because they lack a result queue. On multi-GPU runs, any RPC targeting a non-zero rank or a broadcast call with unique_reply_rank=None will block or time out waiting for responses that are never sent.

except Exception as e:
logger.error(f"Error processing RPC: {e}", exc_info=True)
if self.gpu_id == 0:
self.return_result({"status": "error", "error": str(e)})

elif isinstance(msg, dict) and msg.get("type") == "shutdown":
# Handle shutdown message
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue

else:
# Handle generation request (OmniDiffusionRequest list)
try:
output = self.worker.execute_model(msg, self.od_config)
except Exception as e:
logger.error(
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))

try:
self.return_result(output)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
continue

logger.info("event loop terminated.")
try:
self.worker.shutdown()
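Relating to the P1 review note above and the worker_busy_loop redesign discussion: a minimal sketch of how the RPC branch could gate replies on the request's output_rank instead of hard-coding rank 0. It is written as a hypothetical _handle_rpc method on the worker-process class and assumes every rank owns a result queue, which is not the case in this PR, so treat it as a direction for the follow-up redesign rather than a drop-in fix.

    def _handle_rpc(self, msg: dict) -> None:
        """Sketch of an RPC branch in which every addressed rank replies.

        Assumes self.result_mq exists on every rank (it does not in this PR),
        so collective_rpc with unique_reply_rank=None receives num_gpus responses.
        """
        output_rank = msg.get("output_rank")
        should_reply = output_rank is None or output_rank == self.gpu_id
        try:
            result = self.execute_rpc(msg)
            if should_reply:
                self.return_result(result)
        except Exception as e:
            logger.error(f"Error handling RPC: {e}", exc_info=True)
            if should_reply:
                self.return_result({"status": "error", "error": str(e)})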