RPC support for OmniDiffusion #371
base: main
@@ -0,0 +1,46 @@
import os
import sys
from pathlib import Path

import pytest
import torch

from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.request import OmniDiffusionRequest

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

models = ["Tongyi-MAI/Z-Image-Turbo"]


@pytest.mark.parametrize("model_name", models)
def test_diffusion_model(model_name: str):
    m = Omni(model=model_name)
    # high resolution may cause OOM on L4
    height = 256
    width = 256
    request = OmniDiffusionRequest(
        prompt="a photo of a cat sitting on a laptop keyboard",
        height=height,
        width=width,
        num_inference_steps=2,
        guidance_scale=0.0,
        generator=torch.Generator("cuda").manual_seed(42),
        num_outputs_per_prompt=2,
    )
    results = m.instance.engine.collective_rpc(
        method="generate",
        args=([request],),
        kwargs={},
        unique_reply_rank=0,
    )
    assert isinstance(results, DiffusionOutput)
    assert results.output.shape[2] == width
    assert results.output.shape[3] == height
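
For orientation, the collective_rpc call in this test is presumably packaged into a plain dict and broadcast to every worker over the shared-memory MessageQueue. The sketch below is a reconstruction based only on the keys that execute_rpc() and worker_busy_loop() read in the worker diff further down; the exact mapping of unique_reply_rank to output_rank is an assumption, since the engine-side packing code is not shown in this PR.

# Hedged sketch, not code from this PR: the approximate message the engine
# would broadcast for the collective_rpc call above, reconstructed from the
# keys the worker-side handler reads.
rpc_message = {
    "type": "rpc",          # worker_busy_loop() routes type == "rpc" to execute_rpc()
    "method": "generate",   # a str is resolved via getattr(self.worker, "generate")
    "args": ([request],),   # `request` is the OmniDiffusionRequest built in the test above
    "kwargs": {},
    "output_rank": 0,       # assumed to carry unique_reply_rank; non-matching ranks return None
}
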
@@ -4,6 +4,7 @@
 import os
 import time

+import cloudpickle
 import torch
 import zmq
 from vllm.config import LoadConfig, VllmConfig, set_current_vllm_config
@@ -17,7 +18,6 @@

 from vllm_omni.diffusion.cache.selector import get_cache_backend
 from vllm_omni.diffusion.data import (
-    SHUTDOWN_MESSAGE,
     DiffusionOutput,
     OmniDiffusionConfig,
 )
@@ -91,6 +91,18 @@ def init_device_and_model(self) -> None:
         if self.cache_backend is not None:
             self.cache_backend.enable(self.pipeline)

+    def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
+        """
+        Generate output for the given requests.
+
+        Args:
+            requests: List of diffusion requests
+
+        Returns:
+            DiffusionOutput with generated results
+        """
+        return self.execute_model(requests, self.od_config)
+
     @torch.inference_mode()
     def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
         """
@@ -130,7 +142,7 @@ def __init__(
         # Inter-process Communication
         self.context = zmq.Context(io_threads=2)

-        # Initialize MessageQueue reader from handle
+        # Initialize MessageQueue reader from handle (unified for generation & RPC)
         self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)

         self.result_mq = None
@@ -162,55 +174,101 @@ def return_result(self, output: DiffusionOutput):
         if self.result_mq is not None:
             self.result_mq.enqueue(output)

-    def recv_reqs(self):
+    def recv_message(self):
         """
-        Receive requests from broadcast queue
+        Receive unified messages (RPC requests, shutdown) from broadcast queue.
+        Uses indefinite=True to block until a message arrives.
         """
         return self.mq.dequeue(indefinite=True)

+    def execute_rpc(self, rpc_request: dict):
+        """Execute an RPC request and return the result."""
+        try:
+            method = rpc_request["method"]
+            args = rpc_request.get("args", ())
+            kwargs = rpc_request.get("kwargs", {})
+            output_rank = rpc_request.get("output_rank")
+
+            # Only execute if we should reply (either output_rank is None or matches our rank)
+            if output_rank is not None and output_rank != self.gpu_id:
+                return None
+
+            # Deserialize method if it's a callable
+            if isinstance(method, bytes):
+                method = cloudpickle.loads(method)
+
+            # Execute the method
+            if isinstance(method, str):
+                # Method is a string, call it on the worker
+                func = getattr(self.worker, method)
+                result = func(*args, **kwargs)
+            else:
+                # Method is a callable
+                result = method(self.worker, *args, **kwargs)
+
+            return result
+        except Exception as e:
+            logger.error(f"Error executing RPC: {e}", exc_info=True)
+            return {"status": "error", "error": str(e)}
+
     # TODO: queueing, cancellation
     def worker_busy_loop(self) -> None:

Collaborator
I think we should redesign this, because it has become more complex. We can address the redesign in a separate follow-up task if you don't have time. Here is a good example of
Contributor (Author)
Agreed; the redesign is WIP and will need a more structured RFC.
Collaborator
Just confirming — is this already WIP?
Contributor (Author)
This PR is not WIP, but the redesign you mentioned above is in progress.
Contributor (Author)
@ZJY0516 It's ready now. Could you take a look again? Thanks!

         """Main busy loop for Multiprocessing Workers"""

         logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")

         while self._running:
-            reqs = None
-            # 1: receive requests
+            # Receive unified message (generation request, RPC request, or shutdown)
+            msg = None
             try:
-                reqs = self.recv_reqs()
+                msg = self.recv_message()
             except Exception as e:
                 logger.error(
-                    f"Error receiving requests in scheduler event loop: {e}",
+                    f"Error receiving message in worker loop: {e}",
                     exc_info=True,
                 )
                 continue

-            if reqs == SHUTDOWN_MESSAGE:
-                logger.info("Worker %s: Received shutdown message", self.gpu_id)
-                self._running = False
-                continue
-            if reqs is None:
+            if msg is None:
                 logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
                 continue

-            # 2: execute, make sure a reply is always sent
-            try:
-                output = self.worker.execute_model(reqs, self.od_config)
-            except Exception as e:
-                logger.error(
-                    f"Error executing forward in event loop: {e}",
-                    exc_info=True,
-                )
-                output = DiffusionOutput(error=str(e))
-
-            try:
-                self.return_result(output)
-            except zmq.ZMQError as e:
-                # Reply failed; log and keep loop alive to accept future requests
-                logger.error(f"ZMQ error sending reply: {e}")
+            # Route message based on type
+            if isinstance(msg, dict) and msg.get("type") == "rpc":
+                # Handle RPC request
+                try:
+                    result = self.execute_rpc(msg)
+                    if result is not None and self.gpu_id == 0:
+                        self.return_result(result)
+                except Exception as e:
+                    logger.error(f"Error processing RPC: {e}", exc_info=True)
+                    if self.gpu_id == 0:
+                        self.return_result({"status": "error", "error": str(e)})
+
+            elif isinstance(msg, dict) and msg.get("type") == "shutdown":
+                # Handle shutdown message
+                logger.info("Worker %s: Received shutdown message", self.gpu_id)
+                self._running = False
+                continue
+
+            else:
+                # Handle generation request (OmniDiffusionRequest list)
+                try:
+                    output = self.worker.execute_model(msg, self.od_config)
+                except Exception as e:
+                    logger.error(
+                        f"Error executing forward in event loop: {e}",
+                        exc_info=True,
+                    )
+                    output = DiffusionOutput(error=str(e))
+
+                try:
+                    self.return_result(output)
+                except zmq.ZMQError as e:
+                    # Reply failed; log and keep loop alive to accept future requests
+                    logger.error(f"ZMQ error sending reply: {e}")
             continue

         logger.info("event loop terminated.")
         try:
             self.worker.shutdown()
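
For completeness, here is a hedged sketch of the other payload shapes the unified loop distinguishes. Both dicts below are reconstructed from the worker-side checks in this diff (the msg.get("type") == "shutdown" branch and the cloudpickle branch of execute_rpc()); the engine-side code that builds them is not part of this PR, and _ping is a purely hypothetical callable used for illustration.

import cloudpickle

# Shutdown: any dict whose "type" is "shutdown" stops the busy loop.
shutdown_message = {"type": "shutdown"}


def _ping(worker):
    # Hypothetical callable; execute_rpc() invokes it as method(self.worker, *args, **kwargs).
    return "pong"


# RPC with a pickled callable: execute_rpc() sees bytes, so it calls
# cloudpickle.loads() before running the function against the worker object.
rpc_callable_message = {
    "type": "rpc",
    "method": cloudpickle.dumps(_ping),
    "args": (),
    "kwargs": {},
    "output_rank": 0,
}

# Any other payload (e.g. a list of OmniDiffusionRequest) falls through to the
# generation branch and is handed to self.worker.execute_model(...).
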
The test is a little strange here; I don't think we need an e2e test. We can test it after #376 lands. cc @SamitHuang

Yes, the tests for this PR can be covered by the tests in #376.

@knlnguyen1802 Could you please remove this file? Once that's done, we can merge this PR.