diff --git a/docs/api/README.md b/docs/api/README.md index 0332dceff..cf40a9577 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -4,8 +4,10 @@ Main entry points for vLLM-Omni inference and serving. + +- [vllm_omni.entrypoints.async_omni_diffusion.AsyncOmniDiffusion][] - [vllm_omni.entrypoints.async_omni.AsyncOmni][] -- [vllm_omni.entrypoints.async_omni.AsyncOmniStageLLM][] +- [vllm_omni.entrypoints.async_omni_llm.AsyncOmniLLM][] - [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalContentParser][] - [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalItemTracker][] - [vllm_omni.entrypoints.chat_utils.parse_chat_messages_futures][] diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index deff854e6..20921a8d0 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -61,7 +61,7 @@ import torch from PIL import Image -from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig, logger from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -317,45 +317,79 @@ def main(): print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}") print(f"{'=' * 60}\n") - generation_start = time.perf_counter() - # Generate edited image - images = omni.generate( - prompt=args.prompt, - pil_image=input_image, - negative_prompt=args.negative_prompt, - generator=generator, - true_cfg_scale=args.cfg_scale, - guidance_scale=args.guidance_scale, - num_inference_steps=args.num_inference_steps, - num_outputs_per_prompt=args.num_outputs_per_prompt, - layers=args.layers, - ) - generation_end = time.perf_counter() - generation_time = generation_end - generation_start - - # Print profiling results - print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") - - # Save output image(s) - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - suffix = output_path.suffix or ".png" - stem = output_path.stem or "output_image_edit" - - if args.num_outputs_per_prompt <= 1: - img = images[0] - img = img if isinstance(img, list) else [img] - for sub_idx, sub_img in enumerate(img): - save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}" - sub_img.save(save_path) - print(f"Saved edited image to {os.path.abspath(save_path)}") - else: - for idx, img in enumerate(images): - img = img if isinstance(img, list) else [img] - for sub_idx, sub_img in enumerate(img): - save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}" - sub_img.save(save_path) - print(f"Saved edited image to {os.path.abspath(save_path)}") + try: + generation_start = time.perf_counter() + # Generate edited image + generate_kwargs = { + "prompt": args.prompt, + "pil_image": input_image, + "negative_prompt": args.negative_prompt, + "generator": generator, + "true_cfg_scale": args.cfg_scale, + "guidance_scale": args.guidance_scale, + "num_inference_steps": args.num_inference_steps, + "num_outputs_per_prompt": args.num_outputs_per_prompt, + "layers": args.layers, + "resolution": args.resolution, + } + + outputs = omni.generate(**generate_kwargs) + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + # Print profiling results + print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + + if not outputs: + raise 
ValueError("No output generated from omni.generate()") + logger.info("Outputs: %s", outputs) + + # Extract images from OmniRequestOutput + # omni.generate() returns list[OmniRequestOutput], extract images from request_output[0]['images'] + first_output = outputs[0] + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, dict) or "images" not in req_out: + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out["images"] + if not images: + raise ValueError("No images found in request_output") + + # Save output image(s) + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "output_image_edit" + + # Handle layered output (each image may be a list of layers) + if args.num_outputs_per_prompt <= 1: + img = images[0] + # Check if this is a layered output (list of images) + if isinstance(img, list): + for sub_idx, sub_img in enumerate(img): + save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}" + sub_img.save(save_path) + print(f"Saved edited image to {os.path.abspath(save_path)}") + else: + img.save(output_path) + print(f"Saved edited image to {os.path.abspath(output_path)}") + else: + for idx, img in enumerate(images): + # Check if this is a layered output (list of images) + if isinstance(img, list): + for sub_idx, sub_img in enumerate(img): + save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}" + sub_img.save(save_path) + print(f"Saved edited image to {os.path.abspath(save_path)}") + else: + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved edited image to {os.path.abspath(save_path)}") + finally: + omni.close() if __name__ == "__main__": diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 753f7cc36..7b13e10e8 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -273,6 +273,7 @@ def main(args): # Save audio file with explicit WAV format sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV") print(f"Request ID: {request_id}, Saved audio to {output_wav}") + omni_llm.close() def parse_args(): diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 21e752254..1b5e1c069 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -7,7 +7,7 @@ import torch -from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig, logger from vllm_omni.entrypoints.omni import Omni from vllm_omni.utils.platform_utils import detect_device_type, is_npu @@ -20,7 +20,7 @@ def parse_args() -> argparse.Namespace: help="Diffusion model name or local path. 
Supported models: Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo", ) parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") - parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") + parser.add_argument("--seed", type=int, default=142, help="Random seed for deterministic results.") parser.add_argument( "--cfg_scale", type=float, @@ -127,7 +127,7 @@ def main(): print(f"{'=' * 60}\n") generation_start = time.perf_counter() - images = omni.generate( + outputs = omni.generate( args.prompt, height=args.height, width=args.width, @@ -142,11 +142,30 @@ def main(): # Print profiling results print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") + # Extract images from OmniRequestOutput + # omni.generate() returns list[OmniRequestOutput], extract images from the first output + if not outputs or len(outputs) == 0: + raise ValueError("No output generated from omni.generate()") + logger.info(f"Outputs: {outputs}") + + # Extract images from request_output[0]['images'] + first_output = outputs[0] + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, dict) or "images" not in req_out: + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out["images"] + if not images: + raise ValueError("No images found in request_output") + output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) suffix = output_path.suffix or ".png" stem = output_path.stem or "qwen_image_output" - if args.num_images_per_prompt <= 1: + if len(images) <= 1: images[0].save(output_path) print(f"Saved generated image to {output_path}") else: @@ -155,6 +174,8 @@ def main(): img.save(save_path) print(f"Saved generated image to {save_path}") + omni.close() + if __name__ == "__main__": main() diff --git a/examples/online_serving/text_to_image/gradio_demo.py b/examples/online_serving/text_to_image/gradio_demo.py index 20d59fea7..608db9a23 100644 --- a/examples/online_serving/text_to_image/gradio_demo.py +++ b/examples/online_serving/text_to_image/gradio_demo.py @@ -24,6 +24,7 @@ def generate_image( seed: int | None, negative_prompt: str, server_url: str, + num_outputs_per_prompt: int = 1, ) -> Image.Image | None: """Generate an image using the chat completions API.""" messages = [{"role": "user", "content": prompt}] @@ -39,6 +40,8 @@ def generate_image( extra_body["seed"] = seed if negative_prompt: extra_body["negative_prompt"] = negative_prompt + # Keep consistent with run_curl_text_to_image.sh, always send num_outputs_per_prompt + extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt # Build request payload payload = {"messages": messages, "extra_body": extra_body} @@ -109,7 +112,8 @@ def create_demo(server_url: str): label="Inference Steps", minimum=10, maximum=100, - value=50, + # Default steps aligned with run_curl_text_to_image.sh to 100 + value=100, step=5, ) cfg_scale = gr.Slider( @@ -138,16 +142,26 @@ def create_demo(server_url: str): # Examples gr.Examples( examples=[ - ["A beautiful landscape painting with misty mountains", "", 1024, 1024, 50, 4.0, 42], - ["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 50, 4.0, 123], - ["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 50, 4.0, 456], - ["Chinese 
ink painting of bamboo forest with a house", "", 768, 1024, 50, 4.0, 789], + ["A beautiful landscape painting with misty mountains", "", 1024, 1024, 100, 4.0, 42], + ["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 100, 4.0, 123], + ["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 100, 4.0, 456], + ["Chinese ink painting of bamboo forest with a house", "", 768, 1024, 100, 4.0, 789], ], inputs=[prompt, negative_prompt, height, width, steps, cfg_scale, seed], ) generate_btn.click( - fn=lambda p, h, w, st, c, se, n: generate_image(p, h, w, st, c, se if se >= 0 else None, n, server_url), + fn=lambda p, h, w, st, c, se, n: generate_image( + p, + h, + w, + st, + c, + se if se >= 0 else None, + n, + server_url, + 1, + ), inputs=[prompt, height, width, steps, cfg_scale, seed, negative_prompt], outputs=[output_image], ) diff --git a/examples/online_serving/text_to_image/openai_chat_client.py b/examples/online_serving/text_to_image/openai_chat_client.py index cc5e00139..c529bf203 100644 --- a/examples/online_serving/text_to_image/openai_chat_client.py +++ b/examples/online_serving/text_to_image/openai_chat_client.py @@ -57,8 +57,7 @@ def generate_image( extra_body["seed"] = seed if negative_prompt: extra_body["negative_prompt"] = negative_prompt - if num_outputs_per_prompt > 1: - extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt + extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt # Build request payload payload = {"messages": messages} diff --git a/examples/online_serving/text_to_image/run_curl_text_to_image.sh b/examples/online_serving/text_to_image/run_curl_text_to_image.sh index d90d11c5a..2f559d563 100755 --- a/examples/online_serving/text_to_image/run_curl_text_to_image.sh +++ b/examples/online_serving/text_to_image/run_curl_text_to_image.sh @@ -2,8 +2,9 @@ # Qwen-Image text-to-image curl example SERVER="${SERVER:-http://localhost:8091}" -PROMPT="${PROMPT:-a cup of coffee on the table}" -OUTPUT="${OUTPUT:-qwen_image_output.png}" +PROMPT="${PROMPT:-a good boy in the ocean}" +CURRENT_TIME=$(date +%Y%m%d%H%M%S) +OUTPUT="${OUTPUT:-qwen_image_output_${CURRENT_TIME}.png}" echo "Generating image..." 
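+# Note: the chat/completions response embeds the image as a base64 "data:image/...;base64,<payload>" URL;
+# the jq | sed | base64 -d pipeline below strips that prefix and decodes the payload into $OUTPUT.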
echo "Prompt: $PROMPT" @@ -18,12 +19,12 @@ curl -s "$SERVER/v1/chat/completions" \ \"extra_body\": { \"height\": 1024, \"width\": 1024, - \"num_inference_steps\": 50, + \"num_inference_steps\": 100, \"true_cfg_scale\": 4.0, \"seed\": 42, \"num_outputs_per_prompt\": 1 } - }" | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2 | base64 -d > "$OUTPUT" + }" | jq -r '.choices[0].message.content[0].image_url.url' | sed 's/^data:image[^,]*,\s*//' | base64 -d > "$OUTPUT" if [ -f "$OUTPUT" ]; then echo "Image saved to: $OUTPUT" diff --git a/tests/e2e/offline_inference/conftest.py b/tests/e2e/offline_inference/conftest.py index a24c63bff..6322ce766 100644 --- a/tests/e2e/offline_inference/conftest.py +++ b/tests/e2e/offline_inference/conftest.py @@ -69,7 +69,7 @@ def get_default_sampling_params_list(self) -> list[SamplingParams]: Returns: List of SamplingParams with default decoding for each stage """ - return [st.default_sampling_params for st in self.omni.instance.stage_list] + return [st.default_sampling_params for st in self.omni.stage_list] def get_omni_inputs( self, @@ -337,8 +337,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): def close(self): """Close and cleanup the Omni instance.""" - if hasattr(self.omni.instance, "close"): - self.omni.instance.close() + if hasattr(self.omni, "close"): + self.omni.close() @pytest.fixture(scope="session") diff --git a/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml index fec6169d8..3033426aa 100644 --- a/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml +++ b/tests/e2e/offline_inference/stage_configs/qwen3_omni_ci.yaml @@ -81,6 +81,7 @@ stage_args: engine_output_type: audio # Final output: audio waveform gpu_memory_utilization: 0.1 distributed_executor_backend: "mp" + max_num_seqs: 1 max_num_batched_tokens: 1000000 hf_config_name: thinker_config load_format: dummy diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py index 0113b20e2..9f1a6ee07 100644 --- a/tests/e2e/offline_inference/test_cache_dit.py +++ b/tests/e2e/offline_inference/test_cache_dit.py @@ -51,7 +51,7 @@ def test_cache_dit(model_name: str): width = 256 num_inference_steps = 4 # Minimal steps for fast test - images = m.generate( + outputs = m.generate( "a photo of a cat sitting on a laptop keyboard", height=height, width=width, @@ -60,6 +60,17 @@ def test_cache_dit(model_name: str): generator=torch.Generator("cuda").manual_seed(42), num_outputs_per_prompt=1, # Single output for speed ) + # Extract images from request_output[0]['images'] + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, dict) or "images" not in req_out: + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out["images"] # Verify generation succeeded assert images is not None @@ -67,3 +78,5 @@ def test_cache_dit(model_name: str): # Check image size assert images[0].width == width assert images[0].height == height + # manually close the Omni instance + m.close() diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py index e2d732b65..93a802394 100644 --- a/tests/e2e/offline_inference/test_sequence_parallel.py 
+++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -78,7 +78,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in dtype=dtype, ) try: - baseline_images = baseline.generate( + outputs = baseline.generate( PROMPT, height=height, width=width, @@ -87,6 +87,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in generator=torch.Generator(get_device_name()).manual_seed(seed), num_outputs_per_prompt=1, ) + baseline_images = outputs[0].request_output[0]["images"] finally: baseline.close() @@ -103,7 +104,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in dtype=dtype, ) try: - sp_images = sp.generate( + outputs = sp.generate( PROMPT, height=height, width=width, @@ -112,6 +113,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in generator=torch.Generator(get_device_name()).manual_seed(seed), num_outputs_per_prompt=1, ) + sp_images = outputs[0].request_output[0]["images"] finally: sp.close() diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py index 61f0c7b12..d448b920c 100644 --- a/tests/e2e/offline_inference/test_t2i_model.py +++ b/tests/e2e/offline_inference/test_t2i_model.py @@ -37,7 +37,7 @@ def test_diffusion_model(model_name: str): # high resolution may cause OOM on L4 height = 256 width = 256 - images = m.generate( + outputs = m.generate( "a photo of a cat sitting on a laptop keyboard", height=height, width=width, @@ -46,8 +46,22 @@ def test_diffusion_model(model_name: str): generator=torch.Generator("cuda").manual_seed(42), num_outputs_per_prompt=2, ) + # Extract images from request_output[0]['images'] + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, dict) or "images" not in req_out: + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out["images"] + assert len(images) == 2 # check image size assert images[0].width == width assert images[0].height == height images[0].save("image_output.png") + # manually close the Omni instance + m.close() diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py index b7d5f7229..2e81c87fa 100644 --- a/tests/e2e/offline_inference/test_t2v_model.py +++ b/tests/e2e/offline_inference/test_t2v_model.py @@ -30,7 +30,7 @@ def test_video_diffusion_model(model_name: str): height = 480 width = 640 num_frames = 5 - frames = m.generate( + outputs = m.generate( "A cat sitting on a table", height=height, width=width, @@ -39,6 +39,16 @@ def test_video_diffusion_model(model_name: str): guidance_scale=1.0, generator=torch.Generator("cuda").manual_seed(42), ) + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, dict) or "images" not in req_out: + raise ValueError("Invalid request_output structure or missing 'images' key") + + frames = req_out["images"][0] assert frames is not None assert hasattr(frames, "shape") @@ -46,3 +56,5 @@ def test_video_diffusion_model(model_name: str): assert 
frames.shape[1] == num_frames assert frames.shape[2] == height assert frames.shape[3] == width + # manually close the Omni instance + m.close() diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py index 9e925f5e4..55ed072a2 100644 --- a/tests/e2e/offline_inference/test_teacache.py +++ b/tests/e2e/offline_inference/test_teacache.py @@ -47,7 +47,7 @@ def test_teacache(model_name: str): width = 256 num_inference_steps = 4 # Minimal steps for fast test - images = m.generate( + outputs = m.generate( "a photo of a cat sitting on a laptop keyboard", height=height, width=width, @@ -56,6 +56,17 @@ def test_teacache(model_name: str): generator=torch.Generator("cuda").manual_seed(42), num_outputs_per_prompt=1, # Single output for speed ) + # Extract images from request_output[0]['images'] + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, dict) or "images" not in req_out: + raise ValueError("Invalid request_output structure or missing 'images' key") + + images = req_out["images"] # Verify generation succeeded assert images is not None @@ -63,3 +74,5 @@ def test_teacache(model_name: str): # Check image size assert images[0].width == width assert images[0].height == height + # manually close the Omni instance + m.close() diff --git a/tests/e2e/online_serving/test_i2i_multi_image_input.py b/tests/e2e/online_serving/test_i2i_multi_image_input.py index 0d5c9a5c5..dbaf9f53f 100644 --- a/tests/e2e/online_serving/test_i2i_multi_image_input.py +++ b/tests/e2e/online_serving/test_i2i_multi_image_input.py @@ -6,6 +6,7 @@ import base64 import os +import signal import socket import subprocess import sys @@ -66,6 +67,7 @@ def _start_server(self) -> None: cmd, env=env, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root + start_new_session=True, ) # Wait for server to be ready @@ -91,11 +93,18 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if self.proc: - self.proc.terminate() + try: + os.killpg(self.proc.pid, signal.SIGTERM) + except ProcessLookupError: + pass + try: self.proc.wait(timeout=30) except subprocess.TimeoutExpired: - self.proc.kill() + try: + os.killpg(self.proc.pid, signal.SIGKILL) + except ProcessLookupError: + pass self.proc.wait() diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index 63ff0e050..b90e03555 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -5,7 +5,9 @@ """ import concurrent.futures +import ctypes import os +import signal import socket import subprocess import sys @@ -65,11 +67,22 @@ def _start_server(self) -> None: str(self.port), ] + self.serve_args + # Helper to ensure child process dies when parent dies + libc = ctypes.CDLL("libc.so.6") + + def preexec_fn(): + # Ensure the child process receives SIGTERM when the parent (this test runner) dies. + # This prevents orphaned processes if the test is killed unexpectedly. 
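+            # PR_SET_PDEATHSIG is option 1 from <linux/prctl.h>; prctl(2) and the
+            # libc.so.6 lookup above are Linux-only, so this guard has no effect on other platforms.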
+ PR_SET_PDEATHSIG = 1 + libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM) + print(f"Launching OmniServer with: {' '.join(cmd)}") self.proc = subprocess.Popen( cmd, env=env, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root + start_new_session=True, + preexec_fn=preexec_fn, ) # Wait for server to be ready @@ -95,11 +108,18 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if self.proc: - self.proc.terminate() + try: + os.killpg(self.proc.pid, signal.SIGTERM) + except ProcessLookupError: + pass + try: self.proc.wait(timeout=30) except subprocess.TimeoutExpired: - self.proc.kill() + try: + os.killpg(self.proc.pid, signal.SIGKILL) + except ProcessLookupError: + pass self.proc.wait() diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 524aed726..183fa0519 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -132,8 +132,9 @@ def test_client(mock_async_diffusion): app.include_router(router) # Set up app state with diffusion engine - app.state.diffusion_engine = mock_async_diffusion - app.state.diffusion_model_name = "Qwen/Qwen-Image" + app.state.engine_client = mock_async_diffusion + app.state.stage_configs = [{"stage_type": "diffusion"}] + app.state.served_model_names = "Qwen/Qwen-Image" return TestClient(app) diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index c6dd04241..ce36b863d 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -83,6 +83,8 @@ def __init__(self, config): self.stage_id = getattr(config, "stage_id", 0) self.engine_args = config.engine_args self.model_stage = getattr(config.engine_args, "model_stage", None) + # set default sampling params + self.default_sampling_params = {"temperature": 1.0} # Allow configuring final_output and final_output_type self.final_output = config.final_output if hasattr(config, "final_output") else False self.final_output_type = getattr(config, "final_output_type", None) @@ -171,18 +173,6 @@ def generate(self, prompts, sampling_params): yield from self._outputs -class _FakeStageLLM: - """Replace OmniStageLLM to avoid constructing real engine.""" - - def __init__(self, **kwargs): - # Allow injecting custom fake outputs, default returns single placeholder output - fake_outputs = kwargs.get("_fake_outputs", [[{"text": "ok"}]]) - self._fake_engine = _FakeEngine(fake_outputs) - - def generate(self, prompts, sampling_params, use_tqdm=False): - yield from self._fake_engine.generate(prompts, sampling_params) - - @pytest.fixture def fake_stage_config(): return { @@ -298,22 +288,22 @@ def _fake_load(result, obj_key, shm_key): def _fake_set(obj): return str(obj).encode() - monkeypatch.setattr("vllm_omni.entrypoints.omni_llm._encode", _fake_encode, raising=False) - monkeypatch.setattr("vllm_omni.entrypoints.omni_llm._load", _fake_load, raising=False) - monkeypatch.setattr("vllm_omni.entrypoints.omni_llm._set", _fake_set, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) + monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) def _setup_log_mocks(monkeypatch): """Helper function to set up logging and stats mocks.""" # Mock init_stats_paths to return None files (no stats logging) monkeypatch.setattr( - 
"vllm_omni.entrypoints.omni_llm.init_stats_paths", + "vllm_omni.entrypoints.omni.init_stats_paths", lambda enable_stats, log_file: (None, None), raising=False, ) # Mock configure_orchestrator_logger to do nothing monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.configure_orchestrator_logger", + "vllm_omni.entrypoints.omni.configure_orchestrator_logger", lambda logger, log_file: None, raising=False, ) @@ -340,7 +330,7 @@ def build_and_log_summary(self, final_stage_id): return "Fake summary" monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.OrchestratorMetrics", + "vllm_omni.entrypoints.omni.OrchestratorMetrics", _FakeOrchestratorMetrics, raising=False, ) @@ -482,7 +472,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -496,7 +487,7 @@ def _fake_loader(model: str): # Mock remove_old_logs to avoid bug (called before stage_list is created) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -516,28 +507,28 @@ def _fake_loader(model: str): ) # Import the module after mocks are set - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module # Patch the imported function and class in the module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni - llm = OmniLLM(model="any", init_timeout=1) + omni = Omni(model="any", init_timeout=1) # Verify: auto-loaded stage_configs and stage_list have consistent count - assert isinstance(llm.stage_configs, list) - assert len(llm.stage_configs) == 2 - assert len(llm.stage_list) == 2 + assert isinstance(omni.stage_configs, list) + assert len(omni.stage_configs) == 2 + assert len(omni.stage_list) == 2 # Verify: each Stage is _FakeStage instance - for st in llm.stage_list: + for st in omni.stage_list: assert isinstance(st, _FakeStage) # Verify: queues are attached - for st in llm.stage_list: + for st in omni.stage_list: assert st._in_q is not None assert st._out_q is not None # Verify: all stages are ready - assert len(llm._stages_ready) == 2 + assert len(omni._stages_ready) == 2 def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): @@ -550,7 +541,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -562,7 +554,7 @@ def _fake_loader(model: str): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -573,7 +565,7 @@ def _fake_loader(model: str): raising=False, ) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", 
_fake_loader, raising=False, ) @@ -583,16 +575,16 @@ def _fake_loader(model: str): raising=False, ) - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni - llm = OmniLLM(model="any", init_timeout=1) + omni = Omni(model="any", init_timeout=1) with pytest.raises(ValueError): - llm.generate(prompts=["hi"], sampling_params_list=[]) + omni.generate(prompts=["hi"], sampling_params_list=[]) def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): @@ -608,7 +600,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -620,7 +613,7 @@ def _fake_loader(model: str): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -631,7 +624,7 @@ def _fake_loader(model: str): raising=False, ) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", _fake_loader, raising=False, ) @@ -641,19 +634,19 @@ def _fake_loader(model: str): raising=False, ) - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_llm_module, "uuid", uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni - llm = OmniLLM(model="any", init_timeout=1) + omni = Omni(model="any", init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" @@ -662,7 +655,7 @@ def _fake_loader(model: str): # Note: We put results before calling generate, which simulates worker processes # that have already completed. The polling loop will collect them in stage order. # Stage 0 output (will be collected first) - llm.stage_list[0]._out_q.put_nowait( + omni.stage_list[0]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 0, "text": "s0"}], @@ -674,7 +667,7 @@ def _fake_loader(model: str): # but for testing we pre-populate it. The polling loop processes stages # in order, so stage 0 result will be collected first, then forwarded, # then stage 1 result will be collected. 
- llm.stage_list[1]._out_q.put_nowait( + omni.stage_list[1]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 1, "text": "s1"}], @@ -685,17 +678,17 @@ def _fake_loader(model: str): # Use dicts instead of object() for serializable sampling params sampling_params_list = [{"temperature": 0.7}, {"temperature": 0.8}] prompts = ["hi"] - outputs = llm.generate(prompts=prompts, sampling_params_list=sampling_params_list) + outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) # Both stages have final_output=True, so should aggregate two OmniRequestOutput assert len(outputs) == 2 # Verify stage outputs are set - assert llm.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] - assert llm.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] + assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] + assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] # Verify stage 0 input queue received the task - assert not llm.stage_list[0]._in_q.empty() + assert not omni.stage_list[0]._in_q.empty() # Verify stage 1 received forwarded task (process_engine_inputs was called) - assert llm.stage_list[1].process_engine_inputs([], []) is not None + assert omni.stage_list[1].process_engine_inputs([], []) is not None def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): @@ -712,7 +705,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -724,7 +718,7 @@ def _fake_loader(model: str): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -735,7 +729,83 @@ def _fake_loader(model: str): raising=False, ) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.load_stage_configs_from_model", + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg: _FakeStage(cfg), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + + # Mock uuid.uuid4() to return a predictable value for request ID generation + test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") + monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) + + from vllm_omni.entrypoints.omni import Omni + + omni = Omni(model="any", init_timeout=1) + + # Generate the expected request ID format: "0_" + expected_request_id = f"0_{test_uuid}" + + # Simulate worker behavior: put results into output queues + omni.stage_list[0]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 0}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + omni.stage_list[1]._out_q.put_nowait( + { + "request_id": expected_request_id, + "engine_outputs": [{"stage": 1}], + "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, + } + ) + + # Use dicts instead of object() for serializable sampling params + outputs = omni.generate(prompts=["p"], sampling_params_list=[{"temperature": 0.7}, {"temperature": 0.8}]) + assert outputs == [] + + +def test_generate_sampling_params_none_use_default(monkeypatch, 
fake_stage_config): + """Test that generate uses default sampling params when sampling_params_list is None.""" + stage_cfg0 = dict(fake_stage_config) + stage_cfg1 = dict(fake_stage_config) + stage_cfg0["final_output"] = False + stage_cfg1["final_output"] = False + + def _fake_loader(model: str): + return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] + + import sys + + for module_name in [ + "vllm_omni.entrypoints.utils", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", + "vllm_omni.entrypoints.omni_stage", + ]: + if module_name in sys.modules: + del sys.modules[module_name] + + _setup_engine_mocks(monkeypatch) + _setup_multiprocessing_mocks(monkeypatch) + _setup_ipc_mocks(monkeypatch) + _setup_log_mocks(monkeypatch) + + monkeypatch.setattr( + "vllm_omni.entrypoints.log_utils.remove_old_logs", + lambda log_file, num_stages: None, + raising=False, + ) + + monkeypatch.setattr( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", _fake_loader, raising=False, ) @@ -745,42 +815,40 @@ def _fake_loader(model: str): raising=False, ) - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_llm_module, "uuid", uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni - llm = OmniLLM(model="any", init_timeout=1) + omni = Omni(model="any", init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" # Simulate worker behavior: put results into output queues - llm.stage_list[0]._out_q.put_nowait( + omni.stage_list[0]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 0}], "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, } ) - llm.stage_list[1]._out_q.put_nowait( + omni.stage_list[1]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 1}], "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, } ) - - # Use dicts instead of object() for serializable sampling params - outputs = llm.generate(prompts=["p"], sampling_params_list=[{"temperature": 0.7}, {"temperature": 0.8}]) - assert outputs == [] + # Use the default sampling params + omni.generate(prompts=["p"], sampling_params_list=None) def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): @@ -793,7 +861,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -805,7 +874,7 @@ def _fake_loader(model: str): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -832,17 +901,17 @@ def init_stage_worker(self, *args, **kwargs): 
raising=False, ) - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStageNoReady(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStageNoReady(cfg)) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni # Use very short timeout - llm = OmniLLM(model="any", init_timeout=0.01) + omni = Omni(model="any", init_timeout=0.01) # Verify that no stages are ready - assert len(llm._stages_ready) == 0 + assert len(omni._stages_ready) == 0 def test_generate_handles_error_messages(monkeypatch, fake_stage_config): @@ -855,7 +924,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -867,7 +937,7 @@ def _fake_loader(model: str): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -877,36 +947,31 @@ def _fake_loader(model: str): _fake_loader, raising=False, ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) monkeypatch.setattr( "vllm_omni.entrypoints.omni_stage.OmniStage", lambda cfg: _FakeStage(cfg), raising=False, ) - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) # Mock uuid.uuid4() to return a predictable value for request ID generation test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_llm_module, "uuid", uuid) + monkeypatch.setattr(omni_module, "uuid", uuid) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni - llm = OmniLLM(model="any", init_timeout=1) + omni = Omni(model="any", init_timeout=1) # Generate the expected request ID format: "0_" expected_request_id = f"0_{test_uuid}" # Put error message in output queue - llm.stage_list[0]._out_q.put_nowait( + omni.stage_list[0]._out_q.put_nowait( { "request_id": expected_request_id, "error": "test error", @@ -914,7 +979,7 @@ def _fake_loader(model: str): ) # Also put a valid result after error to allow the loop to complete # (error handling continues the loop, so we need a valid result to finish) - llm.stage_list[0]._out_q.put_nowait( + omni.stage_list[0]._out_q.put_nowait( { "request_id": expected_request_id, "engine_outputs": [{"stage": 0, "text": "result"}], @@ -925,7 +990,7 @@ def _fake_loader(model: str): # Generate should handle error gracefully (log but continue) # Use dict instead of object() for serializable sampling params sampling_params_list = [{"temperature": 0.7}] - outputs = llm.generate(prompts=["hi"], 
sampling_params_list=sampling_params_list) + outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) # Should return final output (error was logged but didn't stop processing) assert isinstance(outputs, list) # Since final_output=True, should have one output @@ -942,7 +1007,8 @@ def _fake_loader(model: str): for module_name in [ "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni_llm", + "vllm_omni.entrypoints.omni", + "vllm_omni.entrypoints.log_utils", "vllm_omni.entrypoints.omni_stage", ]: if module_name in sys.modules: @@ -954,7 +1020,7 @@ def _fake_loader(model: str): _setup_log_mocks(monkeypatch) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.remove_old_logs", + "vllm_omni.entrypoints.log_utils.remove_old_logs", lambda log_file, num_stages: None, raising=False, ) @@ -965,7 +1031,7 @@ def _fake_loader(model: str): raising=False, ) monkeypatch.setattr( - "vllm_omni.entrypoints.omni_llm.load_stage_configs_from_model", + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", _fake_loader, raising=False, ) @@ -975,22 +1041,22 @@ def _fake_loader(model: str): raising=False, ) - import vllm_omni.entrypoints.omni_llm as omni_llm_module + import vllm_omni.entrypoints.omni as omni_module - monkeypatch.setattr(omni_llm_module, "load_stage_configs_from_model", _fake_loader) - monkeypatch.setattr(omni_llm_module, "OmniStage", lambda cfg: _FakeStage(cfg)) + monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader) + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg: _FakeStage(cfg)) - from vllm_omni.entrypoints.omni_llm import OmniLLM + from vllm_omni.entrypoints.omni import Omni - llm = OmniLLM(model="any", init_timeout=1) + omni = Omni(model="any", init_timeout=1) # Call close - llm.close() + omni.close() # Verify shutdown signal (None) was sent to input queue # Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe) try: - shutdown_signal = llm.stage_list[0]._in_q.get_nowait() + shutdown_signal = omni.stage_list[0]._in_q.get_nowait() assert shutdown_signal is None except Empty: # If queue was already empty or only had stage_ready, that's also acceptable @@ -998,4 +1064,4 @@ def _fake_loader(model: str): pass # Verify stop_stage_worker was called (process should be set) - assert llm.stage_list[0]._proc is not None + assert omni.stage_list[0]._proc is not None diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 8222c3076..e074689c9 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -87,7 +87,17 @@ def draw_hf_text_config(self): # we need to draw the text config from the corresponding model stage. 
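+        # If hf_config does not expose the named sub-config (e.g. thinker_config / talker_config),
+        # the AttributeError fallback below reverts to the top-level text config.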
if self.hf_config_name is None: return get_hf_text_config(self.hf_config) - return getattr(self.hf_config, self.hf_config_name).get_text_config() + try: + # Try to get the stage-specific config (e.g., thinker_config, talker_config) + stage_config = getattr(self.hf_config, self.hf_config_name) + return stage_config.get_text_config() + except AttributeError: + # Fallback: if the attribute doesn't exist, use the default get_hf_text_config + logger.warning( + f"Config attribute '{self.hf_config_name}' not found in hf_config, " + "falling back to default get_hf_text_config" + ) + return get_hf_text_config(self.hf_config) def __post_init__( self, @@ -173,9 +183,19 @@ def __post_init__( self.hf_text_config = self.draw_hf_text_config() self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) self.encoder_config = self._get_encoder_config() - self.hf_image_processor_config = get_hf_image_processor_config( - self.model, hf_token=self.hf_token, revision=self.revision - ) + # Try to load image processor config, but allow it to fail for stages that don't need it + try: + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=self.hf_token, revision=self.revision + ) + except (OSError, ValueError, IndexError) as e: + # Some stages (e.g., code2wav, talker) don't need image processor + # Log warning but allow initialization to continue + logger.warning( + f"Failed to load image processor config for model '{self.model}': {e}. " + "This is expected for stages that don't require image processing." + ) + self.hf_image_processor_config = None architectures = self.architectures registry = self.registry diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 664ca09cb..bef96a464 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -75,6 +75,21 @@ def __post_init__(self) -> None: * self.cfg_parallel_size ) + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DiffusionParallelConfig": + """ + Create DiffusionParallelConfig from a dictionary. 
+ + Args: + data: Dictionary containing parallel configuration parameters + + Returns: + DiffusionParallelConfig instance with parameters set from dict + """ + if not isinstance(data, dict): + raise TypeError(f"Expected parallel config dict, got {type(data)!r}") + return cls(**data) + @dataclass class TransformerConfig: @@ -221,7 +236,7 @@ def __getattr__(self, item: str) -> Any: @dataclass class OmniDiffusionConfig: # Model and path configuration (for convenience) - model: str + model: str | None = None model_class_name: str | None = None @@ -380,6 +395,15 @@ def __post_init__(self): # TODO: remove hard code initial_master_port = (self.master_port or 30005) + random.randint(0, 100) self.master_port = self.settle_port(initial_master_port, 37) + + # Convert parallel_config dict to DiffusionParallelConfig if needed + # This must be done before accessing parallel_config.world_size + if isinstance(self.parallel_config, dict): + self.parallel_config = DiffusionParallelConfig.from_dict(self.parallel_config) + elif not isinstance(self.parallel_config, DiffusionParallelConfig): + # If it's neither dict nor DiffusionParallelConfig, use default config + self.parallel_config = DiffusionParallelConfig() + if self.num_gpus is None: if self.parallel_config is not None: self.num_gpus = self.parallel_config.world_size diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 30be322a3..9a5e395ae 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -14,6 +14,7 @@ from vllm_omni.diffusion.registry import get_diffusion_post_process_func, get_diffusion_pre_process_func from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.scheduler import Scheduler, scheduler +from vllm_omni.outputs import OmniRequestOutput from vllm_omni.utils.platform_utils import get_diffusion_worker_class logger = init_logger(__name__) @@ -84,12 +85,86 @@ def step(self, requests: list[OmniDiffusionRequest]): raise Exception(f"{output.error}") logger.info("Generation completed successfully.") + if output.output is None: + logger.warning("Output is None, returning empty OmniRequestOutput") + # Return empty output for the first request + if len(requests) > 0: + request = requests[0] + request_id = request.request_id or "" + prompt = request.prompt + if isinstance(prompt, list): + prompt = prompt[0] if prompt else None + return OmniRequestOutput.from_diffusion( + request_id=request_id, + images=[], + prompt=prompt, + metrics={}, + latents=None, + ) + return None + postprocess_start_time = time.time() - result = self.post_process_func(output.output) if self.post_process_func is not None else output.output + images = self.post_process_func(output.output) if self.post_process_func is not None else output.output postprocess_time = time.time() - postprocess_start_time logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds") - return result + # Convert to OmniRequestOutput format + # Ensure images is a list + if not isinstance(images, list): + images = [images] if images is not None else [] + + # Handle single request or multiple requests + if len(requests) == 1: + # Single request: return single OmniRequestOutput + request = requests[0] + request_id = request.request_id or "" + prompt = request.prompt + if isinstance(prompt, list): + prompt = prompt[0] if prompt else None + + metrics = {} + if output.trajectory_timesteps is not None: + metrics["trajectory_timesteps"] = output.trajectory_timesteps + + return 
OmniRequestOutput.from_diffusion( + request_id=request_id, + images=images, + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + ) + else: + # Multiple requests: return list of OmniRequestOutput + # Split images based on num_outputs_per_prompt for each request + results = [] + image_idx = 0 + + for request in requests: + request_id = request.request_id or "" + prompt = request.prompt + if isinstance(prompt, list): + prompt = prompt[0] if prompt else None + + # Get images for this request + num_outputs = request.num_outputs_per_prompt + request_images = images[image_idx : image_idx + num_outputs] if image_idx < len(images) else [] + image_idx += num_outputs + + metrics = {} + if output.trajectory_timesteps is not None: + metrics["trajectory_timesteps"] = output.trajectory_timesteps + + results.append( + OmniRequestOutput.from_diffusion( + request_id=request_id, + images=request_images, + prompt=prompt, + metrics=metrics, + latents=output.trajectory_latents, + ) + ) + + return results except Exception as e: logger.error(f"Generation failed: {e}") return None diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_worker.py index 75cb6ca0c..13780dff0 100644 --- a/vllm_omni/diffusion/worker/gpu_worker.py +++ b/vllm_omni/diffusion/worker/gpu_worker.py @@ -124,6 +124,8 @@ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusi Execute a forward pass. """ assert self.pipeline is not None + if not reqs or len(reqs) == 0: + raise ValueError("Cannot execute model with empty request list") # TODO: dealing with first req for now req = reqs[0] @@ -238,7 +240,7 @@ def worker_busy_loop(self) -> None: ) continue - if msg is None: + if msg is None or len(msg) == 0: logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id) continue diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 57324374f..6b6a9c727 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -3,6 +3,7 @@ from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeTextConfig from vllm.engine.arg_utils import EngineArgs from vllm.logger import init_logger +from vllm.transformers_utils.config import get_hf_text_config from vllm.v1.engine.async_llm import AsyncEngineArgs from vllm_omni.config import OmniModelConfig @@ -47,7 +48,19 @@ def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: # transformers' get_text_config method is used to get the text config from thinker_config. # to handle the case that each model stage has their own text config, # we need to draw the text config from the corresponding model stage. 
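+        # Mirrors the fallback in OmniModelConfig.draw_hf_text_config (vllm_omni/config/model.py):
+        # revert to the top-level text config when the named sub-config is absent.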
- return getattr(config_dict["hf_config"], config_dict["hf_config_name"]).get_text_config() + hf_config = config_dict["hf_config"] + hf_config_name = config_dict["hf_config_name"] + try: + # Try to get the stage-specific config (e.g., thinker_config, talker_config) + stage_config = getattr(hf_config, hf_config_name) + return stage_config.get_text_config() + except AttributeError: + # Fallback: if the attribute doesn't exist, use the default get_hf_text_config + logger.warning( + f"Config attribute '{hf_config_name}' not found in hf_config, " + "falling back to default get_hf_text_config" + ) + return get_hf_text_config(hf_config) def _ensure_omni_models_registered(self): if hasattr(self, "_omni_models_registered"): @@ -117,7 +130,19 @@ def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: # transformers' get_text_config method is used to get the text config from thinker_config. # to handle the case that each model stage has their own text config, # we need to draw the text config from the corresponding model stage. - return getattr(config_dict["hf_config"], config_dict["hf_config_name"]).get_text_config() + hf_config = config_dict["hf_config"] + hf_config_name = config_dict["hf_config_name"] + try: + # Try to get the stage-specific config (e.g., thinker_config, talker_config) + stage_config = getattr(hf_config, hf_config_name) + return stage_config.get_text_config() + except AttributeError: + # Fallback: if the attribute doesn't exist, use the default get_hf_text_config + logger.warning( + f"Config attribute '{hf_config_name}' not found in hf_config, " + "falling back to default get_hf_text_config" + ) + return get_hf_text_config(hf_config) def _ensure_omni_models_registered(self): if hasattr(self, "_omni_models_registered"): diff --git a/vllm_omni/entrypoints/__init__.py b/vllm_omni/entrypoints/__init__.py index 1670d8681..8d0ee51a5 100644 --- a/vllm_omni/entrypoints/__init__.py +++ b/vllm_omni/entrypoints/__init__.py @@ -10,8 +10,8 @@ - Omni: Unified entrypoint that auto-selects between LLM and Diffusion """ -from vllm_omni.entrypoints.async_diffusion import AsyncOmniDiffusion from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion from vllm_omni.entrypoints.omni import Omni __all__ = [ diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index e59b02778..51b434714 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -1,41 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import multiprocessing as mp import os -import socket import time -from argparse import Namespace +import uuid from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict from pprint import pformat from typing import Any -import torch - -# External library imports (vLLM) -import vllm.envs as envs +from omegaconf import OmegaConf from vllm.config import VllmConfig -from vllm.engine.protocol import EngineClient from vllm.inputs import PromptType from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.plugins.io_processors import get_io_processor from vllm.sampling_params import SamplingParams from vllm.tokenizers 
import TokenizerLike, init_tokenizer_from_config -from vllm.tracing import init_tracer -from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.usage.usage_lib import UsageContext -from vllm.utils.func_utils import deprecate_kwargs -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager # Internal imports (our code) from vllm_omni.config import OmniModelConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, @@ -46,9 +36,7 @@ get_ray_queue_class, try_close_ray, ) -from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs from vllm_omni.engine.input_processor import OmniInputProcessor -from vllm_omni.engine.output_processor import MultimodalOutputProcessor from vllm_omni.entrypoints.client_request_state import ClientRequestState from vllm_omni.entrypoints.log_utils import ( OrchestratorMetrics, @@ -66,18 +54,32 @@ logger = init_logger(__name__) -class AsyncOmni(EngineClient): - """Async entry point for vLLM-Omni inference. +def _dummy_snapshot_download(model_id): + return model_id + + +def omni_snapshot_download(model_id) -> str: + # TODO: this is just a workaround for quickly use modelscope, we should support + # modelscope in weight loading feature instead of using `snapshot_download` + if os.environ.get("VLLM_USE_MODELSCOPE", False): + from modelscope.hub.snapshot_download import snapshot_download + + return snapshot_download(model_id) + else: + return _dummy_snapshot_download(model_id) + - This class provides an asynchronous interface for running multi-modal - comprehension and generation models. It orchestrates multiple - stages in a pipeline, where each stage runs in a separate process. - Designed for use with async/await patterns and streaming generation. +class AsyncOmni: + """Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models. + + Similar to the Omni class, but provides an asynchronous interface supporting + asynchronous LLM and Diffusion models. Args: - model: Model name or path to load - cli_args: Namespace object containing command-line arguments. - Expected attributes include: + *args: Variable length argument list. + - args[0] (str): Model name or path to load. + **kwargs: Arbitrary keyword arguments. + - model (str): Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -88,10 +90,10 @@ class AsyncOmni(EngineClient): for IPC. Objects larger than this threshold will use shared memory. - batch_timeout: Timeout in seconds for batching requests within a stage - init_timeout: Timeout in seconds for waiting for all stages to initialize - **kwargs: Additional keyword arguments passed to stage engines + - Additional keyword arguments passed to stage engines. Example: - >>> async_llm = AsyncOmni(model="Qwen/Qwen2.5-Omni-7B", cli_args=args) + >>> async_llm = AsyncOmni(model="Qwen/Qwen2.5-Omni-7B") >>> async for output in async_llm.generate( ... prompt="Hello", ... 
request_id="req-1", @@ -100,45 +102,525 @@ class AsyncOmni(EngineClient): ... print(output) """ - def __init__( - self, - model: str, - cli_args: Namespace, - **kwargs: Any, - ): - self.worker_backend = getattr(cli_args, "worker_backend", "multi_process") - self.ray_address = getattr(cli_args, "ray_address", "") + def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: + model = args[0] if args else kwargs.get("model", "") + assert model != "", "Null model id detected, please specify a model id." + model = omni_snapshot_download(model) + if args: + args[0] = model + elif kwargs.get("model", "") != "": + kwargs["model"] = model + + # Stage management attributes + self.stage_list: list[OmniStage] = [] + self._stage_in_queues: list[mp.Queue] = [] + self._stage_out_queues: list[mp.Queue] = [] + self._stages_ready: set[int] = set() self._ray_pg = None + self._queue_cls = None + self._ctx = None + + # Pause/resume control attributes + self._pause_cond: asyncio.Condition = asyncio.Condition() + self._paused: bool = False - self.batch_timeout = cli_args.batch_timeout - self._enable_stats: bool = bool(cli_args.log_stats) + # Request state tracking + self.request_states: dict[str, ClientRequestState] = {} + self.output_handler: asyncio.Task | None = None - # Pause / resume state for async RL workflows. - self._pause_cond = asyncio.Condition() - self._paused = False + # Initialize stages - each stage will create appropriate instances based on stage_type + # Stage workers will automatically create AsyncOmniLLM or AsyncOmniDiffusion instances + # based on stage_type in YAML config (handled in omni_stage.py) + logger.info(f"Initializing async stages for model: {model}") + # Use kwargs-based initialization logic to avoid conflicts with old _initialize_stages signature + self._initialize_stages_from_kwargs(model, kwargs) - base_engine_args = AsyncOmniEngineArgs.from_cli_args(cli_args).__dict__.copy() + def _initialize_stages_from_kwargs(self, model: str, kwargs: dict[str, Any]) -> None: + """Initialize stage list management. - if cli_args.stage_configs_path is None: + Each stage will create appropriate instances (AsyncOmniLLM or AsyncOmniDiffusion) + based on stage_type in YAML config. 
+ """ + init_sleep_seconds = kwargs.get("init_sleep_seconds", 20) + shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) + init_timeout = kwargs.get("init_timeout", 300) + worker_backend = kwargs.get("worker_backend", "multi_process") + ray_address = kwargs.get("ray_address", None) + batch_timeout = kwargs.get("batch_timeout", 10) + stage_configs_path = kwargs.get("stage_configs_path", None) + log_stats = kwargs.get("log_stats", False) + + # Load stage configs from YAML + if stage_configs_path is None: self.config_path = resolve_model_config_path(model) - self.stage_configs = load_stage_configs_from_model(model, base_engine_args) + self.stage_configs = load_stage_configs_from_model(model) + if not self.stage_configs: + default_stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": { + "process": True, + "devices": "0", + "max_batch_size": 1, + }, + "engine_args": { + "parallel_config": DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + tensor_parallel_size=1, + sequence_parallel_size=1, + ulysses_degree=1, + ring_degree=1, + cfg_parallel_size=1, + ), + "vae_use_slicing": kwargs.get("vae_use_slicing", False), + "vae_use_tiling": kwargs.get("vae_use_tiling", False), + "cache_backend": kwargs.get("cache_backend", "none"), + "cache_config": kwargs.get("cache_config", None), + }, + "final_output": True, + "final_output_type": "image", + } + ] + default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" + self.stage_configs = OmegaConf.create(default_stage_cfg) else: - self.config_path = cli_args.stage_configs_path - self.stage_configs = load_stage_configs_from_yaml(cli_args.stage_configs_path, base_engine_args) - - shm_threshold_bytes = cli_args.shm_threshold_bytes - self.output_handler: asyncio.Task | None = None - self.request_states: dict[str, ClientRequestState] = {} # request_id -> state + self.config_path = stage_configs_path + self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) # Initialize connectors self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( - self.config_path, worker_backend=self.worker_backend, shm_threshold_bytes=shm_threshold_bytes + self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes ) - self.stage_list: list[OmniStage] = [] - self.default_sampling_params_list: list[SamplingParams] = [] + # Initialize stats paths + self._enable_stats: bool = bool(log_stats) + + self.worker_backend = worker_backend + self.ray_address = ray_address + self.batch_timeout = batch_timeout + + # Build OmniStage instances in parallel, preserving original order + def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: + idx, cfg = idx_cfg + return idx, OmniStage(cfg) + + with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor: + futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)] + results: list[tuple[int, OmniStage]] = [] + for fut in as_completed(futures): + results.append(fut.result()) + results.sort(key=lambda x: x[0]) + self.stage_list = [st for _, st in results] + + self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] + self.output_modalities = [st.final_output_type for st in self.stage_list] + logger.debug("[AsyncOrchestrator] Loaded %d stages", len(self.stage_list)) + + if self.worker_backend == "ray": + self._queue_cls = get_ray_queue_class() + else: + self._ctx = mp.get_context("spawn") + 
self._queue_cls = lambda: self._ctx.Queue(maxsize=0) + + self._init_sleep_seconds = max(0, int(init_sleep_seconds)) + self._shm_threshold_bytes = max(0, int(shm_threshold_bytes)) + self._start_stages(model) + # Wait for all stages to report readiness before seeding + self._wait_for_stages_ready(timeout=init_timeout) + + def _start_stages(self, model: str) -> None: + """Start all stage processes.""" + if self.worker_backend == "ray": + # Initialize Ray cluster + self._ray_pg = create_placement_group( + number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" + ) + + for stage_id, stage in enumerate(self.stage_list): + in_q = self._queue_cls() + out_q = self._queue_cls() + self._stage_in_queues.append(in_q) + self._stage_out_queues.append(out_q) + stage.attach_queues(in_q, out_q) + + stage_connectors_config = get_stage_connector_config( + self.omni_transfer_config, + stage_id, + ) + + stage.init_stage_worker( + model, + shm_threshold_bytes=self._shm_threshold_bytes, + ctx=self._ctx if self.worker_backend != "ray" else None, + batch_timeout=self.batch_timeout, + connectors_config=stage_connectors_config, + worker_backend=self.worker_backend, + ray_placement_group=self._ray_pg, + ) + + logger.debug("[AsyncOrchestrator] Stage-%s process started", stage_id) + time.sleep(self._init_sleep_seconds) + + def _wait_for_stages_ready(self, timeout: int = 120) -> None: + """Wait for all stages to report readiness.""" + deadline = time.time() + max(0, int(timeout)) + num_stages = len(self.stage_list) + while len(self._stages_ready) < num_stages and time.time() < deadline: + progressed = False + for stage_id, stage in enumerate(self.stage_list): + if stage_id in self._stages_ready: + continue + result = stage.try_collect() + if result is None: + continue + progressed = True + if result.get("type") == "stage_ready": + self._stages_ready.add(stage_id) + # Store vllm_config received from worker process (may be None for diffusion stages) + vllm_config = result.get("vllm_config") + if vllm_config is not None: + stage.set_vllm_config(vllm_config) + tokenizer = result.get("tokenizer") + if tokenizer is not None: + stage.set_tokenizer(tokenizer) + is_tracing_enabled = result.get("is_tracing_enabled") + if is_tracing_enabled is not None: + stage.set_is_tracing_enabled(is_tracing_enabled) + logger.info("[AsyncOrchestrator] Stage-%s reported ready", stage_id) + else: + # No user data should arrive before seeding; ignore other messages + pass + if not progressed: + time.sleep(0.01) + if len(self._stages_ready) < num_stages: + not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) + logger.warning( + "[AsyncOrchestrator] Initialization timeout: only %s/%s stages are ready; not ready: %s", + len(self._stages_ready), + num_stages, + not_ready, + ) + # Provide actionable suggestions before shutdown + try: + suggestions = [ + "Verify GPU/device assignment in config (runtime.devices) is correct.", + "Check GPU/host memory availability; reduce model or batch size if needed.", + "Check model weights path and network reachability (if loading remotely).", + "Increase initialization wait time (init_sleep_seconds or call-site timeout).", + ] + logger.error( + "[AsyncOrchestrator] Stage initialization failed, shutting down. 
Suggestions:\n- %s", + "\n- ".join(suggestions), + ) + except Exception: + # Best-effort logging of suggestions + logger.error( + "[AsyncOrchestrator] Stage initialization failed and an error occurred while logging suggestions", + ) + elif len(self._stages_ready) == num_stages: + logger.info("[AsyncOrchestrator] All stages initialized successfully") + # Initialize input_processor, io_processor, and model_config for API server compatibility + # Find the first LLM stage (with vllm_config) to get vllm_config and tokenizer + for stage in self.stage_list: + if stage.vllm_config is not None and stage.tokenizer is not None: + try: + vllm_config = stage.vllm_config + tokenizer = stage.tokenizer + + # Initialize input_processor + self.input_processor = OmniInputProcessor( + vllm_config=vllm_config, + tokenizer=tokenizer, + ) + + # Initialize model_config + self.model_config = vllm_config.model_config + + # Initialize io_processor + io_processor_plugin = self.model_config.io_processor_plugin + self.io_processor = get_io_processor(vllm_config, io_processor_plugin) + + logger.info( + "[AsyncOrchestrator] Initialized input_processor, " + "io_processor, and model_config from stage-%s", + stage.stage_id, + ) + break + except Exception as e: + logger.warning( + "[AsyncOrchestrator] Failed to initialize processors from stage-%s: %s", + stage.stage_id, + e, + ) + # If no LLM stage found, set processors to None + if not hasattr(self, "input_processor") or self.input_processor is None: + logger.warning( + "[AsyncOrchestrator] No LLM stage found, processors will not be available. " + "This may cause issues with OpenAIServingModels." + ) + self.input_processor = None + self.io_processor = None + self.model_config = None + + async def _run_generation_async( + self, + prompts: PromptType | Sequence[PromptType] | OmniDiffusionRequest | Sequence[OmniDiffusionRequest], + sampling_params_list: Any | Sequence[Any] | None = None, + ) -> AsyncGenerator[OmniRequestOutput, None]: + """Asynchronously run pipeline generation.""" + logger.debug("[AsyncOrchestrator] generate() called") + if sampling_params_list is None: + raise ValueError("sampling_params_list is required for pipelined generation") + + # Normalize sampling_params_list to a list + if not isinstance(sampling_params_list, (list, tuple)): + sampling_params_list = [sampling_params_list] + else: + sampling_params_list = list(sampling_params_list) + + if len(sampling_params_list) != len(self.stage_list): + raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + + # Normalize prompts to a list for per-request iteration + if not isinstance(prompts, (list, tuple)): + request_prompts: list[PromptType] = [prompts] + else: + request_prompts = list(prompts) + + # Create async output queue monitors for each stage + stage_monitors = [] + for stage_id, stage in enumerate(self.stage_list): + monitor = asyncio.create_task(self._monitor_stage_outputs_async(stage_id, stage)) + stage_monitors.append(monitor) + + # Orchestrator keeps stage objects for input derivation + num_stages = len(self.stage_list) + + # Generate globally unique request IDs and map them to original prompts + request_ids: list[str] = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] + request_id_to_prompt: dict[str, PromptType] = {rid: p for rid, p in zip(request_ids, request_prompts)} + + # Track start time for each request for end-to-end timing + _req_start_ts: dict[str, float] = {} + _wall_start_ts: float = time.time() + + # Determine final stage for 
E2E stats (highest stage_id with final_output=True; fallback to last stage) + final_stage_id_to_prompt: dict[str, int] = {} + for rid, prompt in request_id_to_prompt.items(): + if isinstance(prompt, dict): + prompt_modalities = prompt.get("modalities", None) + else: + prompt_modalities = None + final_stage_id_for_e2e = get_final_stage_id_for_e2e( + prompt_modalities, self.output_modalities, self.stage_list + ) + final_stage_id_to_prompt[rid] = final_stage_id_for_e2e + + # Metrics/aggregation helper + metrics = OrchestratorMetrics( + num_stages, + self._enable_stats, + self._stats_file, + self._overall_stats_file, + _wall_start_ts, + ) + + # Seed all requests into stage-0 queue + logger.debug("[AsyncOrchestrator] Seeding %d requests into stage-0", len(request_prompts)) + # Mark first input time for stage-0 + metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() + + for req_id, prompt in request_id_to_prompt.items(): + sp0 = sampling_params_list[0] # type: ignore[index] + task = { + "request_id": req_id, + "engine_inputs": prompt, + "sampling_params": sp0, + } + logger.info(f"task: {task}") + self.stage_list[0].submit(task) + _req_start_ts[req_id] = time.time() + logger.debug("[AsyncOrchestrator] Enqueued request %s to stage-0", req_id) + + # For each stage, forward results in stage order; collect final results at the end + # We pipeline by continuously polling output queues + remaining_by_stage: list[int] = [len(request_prompts)] + [0] * (num_stages - 1) + completed_requests = 0 + total_requests = len(request_prompts) + + logger.debug( + "[AsyncOrchestrator] Entering scheduling loop: total_requests=%d, stages=%d", + total_requests, + num_stages, + ) + + while completed_requests < total_requests: + # Asynchronously wait for output from any stage + await asyncio.sleep(0.001) # Brief sleep to avoid CPU overload + + for stage_id, stage in enumerate(self.stage_list): + result = stage.try_collect() + if result is None: + continue + + req_id = result.get("request_id") + if "error" in result: + logger.error( + "Stage %s error on request %s: %s", + stage_id, + req_id, + result["error"], + ) + continue + + if result.get("type") == "stage_ready": + # If stage initialization is slower than expected, wait briefly and retry + await asyncio.sleep(0.05) + continue + + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + # Mark last output time for this stage whenever we receive outputs + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) + + try: + _m = result.get("metrics") + if _m is not None: + metrics.on_stage_metrics(stage_id, req_id, _m) + except Exception as e: + logger.exception( + "[AsyncOrchestrator] Failed to process metrics for stage %s, req %s: %s", + stage_id, + req_id, + e, + ) + + logger.debug( + "[AsyncOrchestrator] Stage-%s completed request %s; forwarding or finalizing", + stage_id, + req_id, + ) + stage.set_engine_outputs(engine_outputs) + + if getattr(stage, "final_output", False): + # Handle diffusion outputs that already contain images + if stage.final_output_type == "image": + # Extract images from engine_outputs if it's an OmniRequestOutput + images = [] + output_to_check = engine_outputs[0] if isinstance(engine_outputs, list) else engine_outputs + if isinstance(output_to_check, OmniRequestOutput) and output_to_check.images: + images = output_to_check.images + elif hasattr(output_to_check, "images") and output_to_check.images: + images = output_to_check.images + final_output = 
OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + images=images, + ) + else: + final_output = OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, # type: ignore[attr-defined] + request_output=engine_outputs, + ) + yield final_output + logger.debug( + "[AsyncOrchestrator] Request %s finalized at stage-%s", + req_id, + stage_id, + ) + + # End-to-end timing and time-per-token for final output + # (only once per request at the designated final stage) + try: + rid_key = str(req_id) + if stage_id == final_stage_id_to_prompt[req_id] and rid_key not in metrics.e2e_done: + metrics.on_finalize_request( + stage_id, + req_id, + engine_outputs, + _req_start_ts.get(req_id, _wall_start_ts), + ) + except Exception as e: + logger.exception( + "[AsyncOrchestrator] Finalize request handling error for req %s at stage %s: %s", + req_id, + stage_id, + e, + ) + + next_stage_id = stage_id + 1 + if next_stage_id <= final_stage_id_to_prompt[req_id]: + next_stage: OmniStage = self.stage_list[next_stage_id] + try: + next_inputs = next_stage.process_engine_inputs(self.stage_list, [request_id_to_prompt[req_id]]) + except Exception as e: + logger.exception( + "[AsyncOrchestrator] Process engine inputs error for req %s at stage %s: %s", + req_id, + next_stage_id, + e, + ) + continue + sp_next = sampling_params_list[next_stage_id] # type: ignore[index] + + # Check connector for this edge + connector_key = (str(stage_id), str(next_stage_id)) + connector = self.connectors.get(connector_key) + sent_via_connector = False + if connector: + sent_via_connector = try_send_via_connector( + connector=connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=req_id, + next_inputs=next_inputs, + sampling_params=sp_next, + original_prompt=request_id_to_prompt[req_id], + next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + metrics=metrics, + ) + + if not sent_via_connector: + raise RuntimeError( + f"[AsyncOrchestrator] Failed to send request {req_id} " + f"to stage-{next_stage_id} via connector. Configure a connector " + "for this edge or inspect connector logs for details." 
+ ) + logger.debug( + "[AsyncOrchestrator] Forwarded request %s to stage-%s", + req_id, + next_stage_id, + ) + remaining_by_stage[next_stage_id] += 1 + else: + completed_requests += 1 + logger.debug( + "[AsyncOrchestrator] Request %s fully completed (%d/%d)", + req_id, + completed_requests, + total_requests, + ) - self._initialize_stages(model, cli_args.init_sleep_seconds, cli_args.shm_threshold_bytes, cli_args.init_timeout) + logger.debug("[AsyncOrchestrator] All requests completed") + + # Summarize and print statistics + try: + summary = metrics.build_and_log_summary(final_stage_id_to_prompt) + logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + except Exception as e: + logger.exception("[AsyncOrchestrator] Failed to build/log summary: %s", e) + + async def _monitor_stage_outputs_async(self, stage_id: int, stage: OmniStage) -> None: + """Asynchronously monitor stage output queue.""" + while True: + result = stage.try_collect() + if result is not None: + # Put result into async queue for main loop processing + pass + await asyncio.sleep(0.001) # Here is a duplicated init for the first stage to get the tokenizer, # input_processor, and io_processor @@ -436,11 +918,25 @@ async def generate( if isinstance(engine_outputs, list): engine_outputs = engine_outputs[0] - yield OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, - request_output=engine_outputs, - ) + # Handle diffusion outputs that already contain images + if stage.final_output_type == "image": + images = [] + if isinstance(engine_outputs, OmniRequestOutput) and engine_outputs.images: + images = engine_outputs.images + elif hasattr(engine_outputs, "images") and engine_outputs.images: + images = engine_outputs.images + yield OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + images=images, + ) + else: + yield OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + ) # Forward to next stage if there is one next_stage_id = stage_id + 1 @@ -743,191 +1239,3 @@ async def is_paused(self) -> bool: async with self._pause_cond: return self._paused - - -class AsyncOmniStageLLM(AsyncLLM): - """Async single-stage LLM engine for use within a stage worker process. - - This class extends the base vLLM AsyncLLM class with omni-specific - processors for handling multimodal inputs and outputs. It is used - internally by AsyncOmniStage workers and should not be instantiated - directly by users. - - Args: - engine_args: AsyncOmniEngineArgs containing engine configuration - vllm_config: Global vLLM configuration - executor_class: Executor implementation class, e.g. MultiprocExecutor - log_stats: Whether to log statistics - usage_context: Usage context of the LLM (default: ENGINE_CONTEXT) - mm_registry: Multi-modal registry for processing multimodal inputs - use_cached_outputs: Whether to use cached outputs - log_requests: Whether to log requests - start_engine_loop: Whether to start the engine loop automatically - stat_loggers: Customized stat loggers for the engine. - If not provided, default stat loggers will be used. - Note: Stat logger interface may change in V1. 
- client_addresses: Optional dictionary mapping client names to addresses - client_count: Total number of clients (default: 1) - client_index: Index of this client (default: 0) - """ - - def __init__( - self, - engine_args: AsyncOmniEngineArgs, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - use_cached_outputs: bool = False, - log_requests: bool = True, - start_engine_loop: bool = True, - stat_loggers: list[StatLoggerFactory] | None = None, - client_addresses: dict[str, str] | None = None, - client_count: int = 1, - client_index: int = 0, - ) -> None: - """ - Create an AsyncLLM. - - Args: - vllm_config: global configuration. - executor_class: an Executor impl, e.g. MultiprocExecutor. - log_stats: Whether to log stats. - usage_context: Usage context of the LLM. - mm_registry: Multi-modal registry. - use_cached_outputs: Whether to use cached outputs. - log_requests: Whether to log requests. - start_engine_loop: Whether to start the engine loop. - stat_loggers: customized stat loggers for the engine. - If not provided, default stat loggers will be used. - PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE - IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. - - Returns: - None - """ - # Ensure we can serialize custom transformer configs - maybe_register_config_serialize_by_value() - - self.model_config = vllm_config.model_config - self.vllm_config = vllm_config - self.observability_config = vllm_config.observability_config - self.log_requests = log_requests - - self.log_stats = log_stats or (stat_loggers is not None) - if not log_stats and stat_loggers is not None: - logger.info( - "AsyncLLM created with log_stats=False and non-empty custom logger list; " - "enabling logging without default stat loggers" - ) - - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - # Tokenizer (+ ensure liveness if running in another process). - tokenizer = init_tokenizer_from_config(model_config=vllm_config.model_config) - - # InputProcessor (converts Inputs --> EngineCoreRequests). - self.input_processor = OmniInputProcessor( - vllm_config=vllm_config, - tokenizer=tokenizer, - mm_registry=mm_registry, - ) - - # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = MultimodalOutputProcessor( - tokenizer=tokenizer, - log_stats=self.log_stats, - engine_core_output_type=engine_args.engine_output_type, - ) - - if self.observability_config.otlp_traces_endpoint is not None: - tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) - self.output_processor.tracer = tracer - - # Pause / resume state for async RL workflows. - self._pause_cond = asyncio.Condition() - self._paused = False - - # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_async_mp_client( - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=self.log_stats, - client_addresses=client_addresses, - client_count=client_count, - client_index=client_index, - ) - - # Loggers. 
- self.logger_manager: StatLoggerManager | None = None - if self.log_stats: - self.logger_manager = StatLoggerManager( - vllm_config=vllm_config, - engine_idxs=self.engine_core.engine_ranks_managed, - custom_stat_loggers=stat_loggers, - enable_default_loggers=log_stats, - client_count=client_count, - ) - self.logger_manager.log_engine_initialized() - - self.output_handler: asyncio.Task | None = None - try: - # Start output handler eagerly if we are in the asyncio eventloop. - asyncio.get_running_loop() - self._run_output_handler() - except RuntimeError: - pass - - if envs.VLLM_TORCH_PROFILER_DIR: - logger.info( - "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 - envs.VLLM_TORCH_PROFILER_DIR, - ) - worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - ], - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True - ), - ) - else: - self.profiler = None - - @classmethod - @deprecate_kwargs( - "disable_log_requests", - additional_message=("This argument will have no effect. Use `enable_log_requests` instead."), - ) - def from_vllm_config( - cls, - vllm_config: VllmConfig, - engine_args: AsyncOmniEngineArgs, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: list[StatLoggerFactory] | None = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - client_addresses: dict[str, str] | None = None, - client_count: int = 1, - client_index: int = 0, - disable_log_requests: bool = True, # Deprecated, will be removed - ) -> "AsyncLLM": - # Create the LLMEngine. 
- return cls( - vllm_config=vllm_config, - executor_class=Executor.get_class(vllm_config), - start_engine_loop=start_engine_loop, - stat_loggers=stat_loggers, - log_requests=enable_log_requests, - log_stats=not disable_log_stats, - usage_context=usage_context, - client_addresses=client_addresses, - client_count=client_count, - client_index=client_index, - engine_args=engine_args, - ) diff --git a/vllm_omni/entrypoints/async_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py similarity index 96% rename from vllm_omni/entrypoints/async_diffusion.py rename to vllm_omni/entrypoints/async_omni_diffusion.py index 079432eda..14f3bc0f6 100644 --- a/vllm_omni/entrypoints/async_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -178,7 +178,14 @@ async def generate( logger.error("Generation failed for request %s: %s", request_id, e) raise RuntimeError(f"Diffusion generation failed: {e}") from e - # Process results + # Check if result is already OmniRequestOutput + if isinstance(result, OmniRequestOutput): + # Update request_id if needed + if not result.request_id: + result.request_id = request_id + return result + + # Process results if not OmniRequestOutput images: list[Image.Image] = [] if result is not None: if isinstance(result, list): @@ -188,12 +195,6 @@ async def generate( elif isinstance(result, Image.Image): images.append(result) - logger.debug( - "Generation completed for request %s, produced %d images", - request_id, - len(images), - ) - return OmniRequestOutput.from_diffusion( request_id=request_id, images=images, diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py new file mode 100644 index 000000000..a1433d4f3 --- /dev/null +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import os +import socket +from typing import TYPE_CHECKING + +import torch +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.tokenizers import init_tokenizer_from_config +from vllm.tracing import init_tracer +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.usage.usage_lib import UsageContext +from vllm.utils.func_utils import deprecate_kwargs +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager + +from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs +from vllm_omni.engine.input_processor import OmniInputProcessor +from vllm_omni.engine.output_processor import MultimodalOutputProcessor + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class AsyncOmniLLM(AsyncLLM): + """Async single-stage LLM engine for use within a stage worker process. + + This class extends the base vLLM AsyncLLM class with omni-specific + processors for handling multimodal inputs and outputs. It is used + internally by AsyncOmniStage workers and should not be instantiated + directly by users. + + Args: + engine_args: AsyncOmniEngineArgs containing engine configuration + vllm_config: Global vLLM configuration + executor_class: Executor implementation class, e.g. 
MultiprocExecutor + log_stats: Whether to log statistics + usage_context: Usage context of the LLM (default: ENGINE_CONTEXT) + mm_registry: Multi-modal registry for processing multimodal inputs + use_cached_outputs: Whether to use cached outputs + log_requests: Whether to log requests + start_engine_loop: Whether to start the engine loop automatically + stat_loggers: Customized stat loggers for the engine. + If not provided, default stat loggers will be used. + Note: Stat logger interface may change in V1. + client_addresses: Optional dictionary mapping client names to addresses + client_count: Total number of clients (default: 1) + client_index: Index of this client (default: 0) + """ + + def __init__( + self, + engine_args: AsyncOmniEngineArgs, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + log_requests: bool = True, + start_engine_loop: bool = True, + stat_loggers: list[StatLoggerFactory] | None = None, + client_addresses: dict[str, str] | None = None, + client_count: int = 1, + client_index: int = 0, + ) -> None: + """ + Create an AsyncOmniLLM. + + Args: + vllm_config: global configuration. + executor_class: an Executor impl, e.g. MultiprocExecutor. + log_stats: Whether to log stats. + usage_context: Usage context of the LLM. + mm_registry: Multi-modal registry. + use_cached_outputs: Whether to use cached outputs. + log_requests: Whether to log requests. + start_engine_loop: Whether to start the engine loop. + stat_loggers: customized stat loggers for the engine. + If not provided, default stat loggers will be used. + PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE + IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + + Returns: + None + """ + # Ensure we can serialize custom transformer configs + maybe_register_config_serialize_by_value() + + self.model_config = vllm_config.model_config + self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config + self.log_requests = log_requests + + self.log_stats = log_stats or (stat_loggers is not None) + if not log_stats and stat_loggers is not None: + logger.info( + "AsyncLLM created with log_stats=False and non-empty custom logger list; " + "enabling logging without default stat loggers" + ) + + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + # Tokenizer (+ ensure liveness if running in another process). + tokenizer = init_tokenizer_from_config(model_config=vllm_config.model_config) + + # InputProcessor (converts Inputs --> EngineCoreRequests). + self.input_processor = OmniInputProcessor( + vllm_config=vllm_config, + tokenizer=tokenizer, + mm_registry=mm_registry, + ) + + # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). + self.output_processor = MultimodalOutputProcessor( + tokenizer=tokenizer, + log_stats=self.log_stats, + engine_core_output_type=engine_args.engine_output_type, + ) + + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + self.output_processor.tracer = tracer + + # Pause / resume state for async RL workflows. + self._pause_cond = asyncio.Condition() + self._paused = False + + # EngineCore (starts the engine in background process). 
+ self.engine_core = EngineCoreClient.make_async_mp_client( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=self.log_stats, + client_addresses=client_addresses, + client_count=client_count, + client_index=client_index, + ) + + # Loggers. + self.logger_manager: StatLoggerManager | None = None + if self.log_stats: + self.logger_manager = StatLoggerManager( + vllm_config=vllm_config, + engine_idxs=self.engine_core.engine_ranks_managed, + custom_stat_loggers=stat_loggers, + enable_default_loggers=log_stats, + client_count=client_count, + ) + self.logger_manager.log_engine_initialized() + + self.output_handler: asyncio.Task | None = None + try: + # Start output handler eagerly if we are in the asyncio eventloop. + asyncio.get_running_loop() + self._run_output_handler() + except RuntimeError: + pass + + if envs.VLLM_TORCH_PROFILER_DIR: + logger.info( + "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 + envs.VLLM_TORCH_PROFILER_DIR, + ) + worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True + ), + ) + else: + self.profiler = None + + @classmethod + @deprecate_kwargs( + "disable_log_requests", + additional_message=("This argument will have no effect. Use `enable_log_requests` instead."), + ) + def from_vllm_config( + cls, + vllm_config: VllmConfig, + engine_args: AsyncOmniEngineArgs, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: list[StatLoggerFactory] | None = None, + enable_log_requests: bool = False, + disable_log_stats: bool = False, + client_addresses: dict[str, str] | None = None, + client_count: int = 1, + client_index: int = 0, + disable_log_requests: bool = True, # Deprecated, will be removed + ) -> "AsyncLLM": + # Create the LLMEngine. 
+ return cls( + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + start_engine_loop=start_engine_loop, + stat_loggers=stat_loggers, + log_requests=enable_log_requests, + log_stats=not disable_log_stats, + usage_context=usage_context, + client_addresses=client_addresses, + client_count=client_count, + client_index=client_index, + engine_args=engine_args, + ) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 02499525f..a35fa35f8 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -1,10 +1,44 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing as mp import os +import time +import uuid +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict +from pprint import pformat +from typing import Any -from vllm_omni.diffusion.utils.hf_utils import is_diffusion_model -from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion -from vllm_omni.entrypoints.omni_llm import OmniLLM +import msgspec +from omegaconf import OmegaConf +from vllm.inputs import PromptType +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams + +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.distributed.omni_connectors import ( + get_stage_connector_config, + initialize_orchestrator_connectors, +) +from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector +from vllm_omni.distributed.ray_utils.utils import ( + create_placement_group, + get_ray_queue_class, + try_close_ray, +) +from vllm_omni.entrypoints.log_utils import OrchestratorMetrics +from vllm_omni.entrypoints.omni_stage import OmniStage +from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load +from vllm_omni.entrypoints.utils import ( + get_final_stage_id_for_e2e, + load_stage_configs_from_model, + load_stage_configs_from_yaml, + resolve_model_config_path, +) +from vllm_omni.outputs import OmniRequestOutput + +logger = init_logger(__name__) def _dummy_snapshot_download(model_id): @@ -23,9 +57,25 @@ def omni_snapshot_download(model_id) -> str: class Omni: - """Unified entrypoint for both LLM and Diffusion models for better usability.""" + """Unified entrypoint for both LLM and Diffusion models for better usability. + + Args: + *args: Variable length argument list. + - args[0]: Model name or path to load. + **kwargs: Arbitrary keyword arguments. + - model: Model name or path to load (if not in args). + - stage_configs_path: Optional path to YAML file containing stage + configurations. If None, configurations are loaded from the model. + - log_stats: Whether to enable statistics logging + be written to files with stage-specific suffixes. + + Example: + >>> omni = Omni(model="Qwen/Qwen2.5-Omni-7B") + >>> outputs = omni.generate(prompts="Hello, world!", sampling_params_list=[SamplingParams()]) + >>> print(outputs) + """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: model = args[0] if args else kwargs.get("model", "") assert model != "", "Null model id detected, please specify a model id." 
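# Editor's note (illustrative, not part of this patch): the snapshot-download call right below
# mirrors omni_snapshot_download() shown earlier in this diff, which switches to ModelScope when
# VLLM_USE_MODELSCOPE is set to any non-empty string. Opting in, under that assumption:
#
#     import os
#     os.environ["VLLM_USE_MODELSCOPE"] = "1"  # truthy string routes downloads through ModelScope's snapshot_download
#     omni = Omni(model="Qwen/Qwen2.5-Omni-7B")  # model id is resolved to a local snapshot path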
model = omni_snapshot_download(model) @@ -33,28 +83,490 @@ def __init__(self, *args, **kwargs): args[0] = model elif kwargs.get("model", "") != "": kwargs["model"] = model - if is_diffusion_model(model): - self.instance: OmniLLM | OmniDiffusion = OmniDiffusion(*args, **kwargs) + + # Stage management attributes + self.stage_list: list[OmniStage] = [] + self._stage_in_queues: list[mp.Queue] = [] + self._stage_out_queues: list[mp.Queue] = [] + self._stages_ready: set[int] = set() + self._ray_pg = None + self._queue_cls = None + self._ctx = None + + # Initialize stages - each stage will create appropriate instance based on stage_type + # Stage workers will automatically create OmniLLM or OmniDiffusion instances + # based on stage_type in YAML config (handled in omni_stage.py) + logger.info(f"Initializing stages for model: {model}") + self._initialize_stages(model, kwargs) + + def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: + """Initialize stage list management. + + Each stage will create appropriate instance (OmniLLM or OmniDiffusion) + based on stage_type in YAML config (handled in omni_stage.py). + """ + init_sleep_seconds = kwargs.get("init_sleep_seconds", 20) + shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) + init_timeout = kwargs.get("init_timeout", 300) + worker_backend = kwargs.get("worker_backend", "multi_process") + ray_address = kwargs.get("ray_address", None) + batch_timeout = kwargs.get("batch_timeout", 10) + stage_configs_path = kwargs.get("stage_configs_path", None) + log_stats = kwargs.get("log_stats", False) + + # Load stage configurations from YAML + if stage_configs_path is None: + self.config_path = resolve_model_config_path(model) + self.stage_configs = load_stage_configs_from_model(model) + if not self.stage_configs: + # TODO: hack here, convert dtype to string to avoid non-premitive omegaconf create error. + if "dtype" in kwargs: + kwargs["dtype"] = str(kwargs["dtype"]) + # TODO: hack, calculate devices based on parallel config. 
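# Editor's note (illustrative, not part of this patch): the device-list hack below derives
# runtime.devices from kwargs["parallel_config"].world_size, so a multi-GPU run of the default
# single-stage diffusion config would be driven entirely by constructor kwargs, e.g.:
#
#     from vllm_omni.diffusion.data import DiffusionParallelConfig
#     from vllm_omni.entrypoints.omni import Omni
#
#     omni = Omni(
#         model="<diffusion model id>",  # placeholder
#         parallel_config=DiffusionParallelConfig(tensor_parallel_size=2),  # world_size of 2 assumed here
#     )
#
# Whether a DiffusionParallelConfig instance survives OmegaConf.create(kwargs) unchanged is an
# assumption; the surrounding TODOs suggest non-primitive kwargs may still need special handling.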
+ devices = "0" + if "parallel_config" in kwargs: + num_devices = kwargs["parallel_config"].world_size + for i in range(1, num_devices): + devices += f",{i}" + logger.info(f"model: {model}, kwargs: {kwargs}") + default_stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": { + "process": True, + "devices": devices, + "max_batch_size": 1, + }, + "engine_args": OmegaConf.create(kwargs), + "final_output": True, + "final_output_type": "image", + } + ] + default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" + self.stage_configs = OmegaConf.create(default_stage_cfg) + else: + self.config_path = stage_configs_path + self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) + + # Initialize connectors + self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( + self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes + ) + + # Initialize stats paths + self._enable_stats: bool = bool(log_stats) + + self.worker_backend = worker_backend + self.ray_address = ray_address + self.batch_timeout = batch_timeout + + # Build OmniStage instances in parallel, preserve original order + def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: + idx, cfg = idx_cfg + return idx, OmniStage(cfg) + + with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor: + futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)] + results: list[tuple[int, OmniStage]] = [] + for fut in as_completed(futures): + results.append(fut.result()) + results.sort(key=lambda x: x[0]) + self.stage_list = [st for _, st in results] + self.output_modalities = [st.final_output_type for st in self.stage_list] + logger.debug("[Orchestrator] Loaded %d stages", len(self.stage_list)) + + if self.worker_backend == "ray": + self._queue_cls = get_ray_queue_class() + else: + self._ctx = mp.get_context("spawn") + self._queue_cls = lambda: self._ctx.Queue(maxsize=0) + + self._init_sleep_seconds = max(0, int(init_sleep_seconds)) + self._shm_threshold_bytes = max(0, int(shm_threshold_bytes)) + self._start_stages(model) + # Wait for all stages to report readiness before seeding + self._wait_for_stages_ready(timeout=init_timeout) + + def _start_stages(self, model: str) -> None: + """Start all stage processes.""" + if self.worker_backend == "ray": + # Initialize Ray Cluster + self._ray_pg = create_placement_group( + number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" + ) + + for stage_id, stage in enumerate(self.stage_list): + in_q = self._queue_cls() + out_q = self._queue_cls() + self._stage_in_queues.append(in_q) + self._stage_out_queues.append(out_q) + stage.attach_queues(in_q, out_q) + + stage_connectors_config = get_stage_connector_config( + self.omni_transfer_config, + stage_id, + ) + + stage.init_stage_worker( + model, + shm_threshold_bytes=self._shm_threshold_bytes, + ctx=self._ctx if self.worker_backend != "ray" else None, + batch_timeout=self.batch_timeout, + connectors_config=stage_connectors_config, + worker_backend=self.worker_backend, + ray_placement_group=self._ray_pg, + ) + + logger.debug("[Orchestrator] Stage-%s process started", stage_id) + time.sleep(self._init_sleep_seconds) + + def _wait_for_stages_ready(self, timeout: int = 120) -> None: + """Wait for all stages to report readiness.""" + deadline = time.time() + max(0, int(timeout)) + num_stages = len(self.stage_list) + while len(self._stages_ready) < 
num_stages and time.time() < deadline: + progressed = False + for stage_id, stage in enumerate(self.stage_list): + if stage_id in self._stages_ready: + continue + result = stage.try_collect() + if result is None: + continue + progressed = True + if result.get("type") == "stage_ready": + self._stages_ready.add(stage_id) + logger.info("[Orchestrator] Stage-%s reported ready", stage_id) + else: + # No user data should arrive before seeding; ignore other messages + pass + if not progressed: + time.sleep(0.01) + if len(self._stages_ready) < num_stages: + not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) + logger.warning( + "[Orchestrator] Initialization timeout: only %s/%s stages are ready; not ready: %s", + len(self._stages_ready), + num_stages, + not_ready, + ) + # Provide actionable suggestions before shutdown + try: + suggestions = [ + "Verify GPU/device assignment in config (runtime.devices) is correct.", + "Check GPU/host memory availability; reduce model or batch size if needed.", + "Check model weights path and network reachability (if loading remotely).", + "Increase initialization wait time (init_sleep_seconds or call-site timeout).", + ] + logger.error( + "[Orchestrator] Stage initialization failed, shutting down. Suggestions:\n- %s", + "\n- ".join(suggestions), + ) + except Exception: + # Best-effort logging of suggestions + logger.error( + "[Orchestrator] Stage initialization failed and an error occurred while logging suggestions", + ) + elif len(self._stages_ready) == num_stages: + logger.info("[Orchestrator] All stages initialized successfully") + + def generate(self, *args: Any, **kwargs: dict[str, Any]) -> list[OmniRequestOutput]: + """Generate outputs for the given prompts. + + Orchestrates the multi-stage pipeline based on YAML configuration. + Each stage will use OmniLLM or OmniDiffusion based on stage_type. + + Args: + *args: Variable length argument list. + - args[0]: Input prompts for generation. + - args[1]: Optional list of per-stage parameters. + **kwargs: Arbitrary keyword arguments. + - prompt: Input prompts for generation (if not in args). + - sampling_params_list: Optional list of per-stage parameters (if not in args). + + Returns: + List of OmniRequestOutput objects, one for each input prompt. + Each output contains the stage_id, final_output_type, and + the request_output from the final stage. + + Raises: + ValueError: If sampling_params_list is None or has incorrect length. 
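# Editor's note: a minimal sketch, not part of this patch, of the two calling conventions
# implemented below. The model id, prompt, and parameter values are placeholders, and a
# single-stage configuration is assumed so one SamplingParams entry suffices.
from vllm.sampling_params import SamplingParams

from vllm_omni.entrypoints.omni import Omni

omni = Omni(model="Qwen/Qwen2.5-Omni-7B")

# Explicit per-stage parameters: one entry per configured stage, in stage order.
outputs = omni.generate(prompts="Hello, world!", sampling_params_list=[SamplingParams(max_tokens=64)])

# Or rely on the merge path below: extra kwargs are folded into each stage's defaults
# (SamplingParams for LLM stages, plain dicts for diffusion stages).
outputs = omni.generate(prompts="Hello, world!", max_tokens=64)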
+ """ + prompts = args[0] if args else kwargs.get("prompts") + sampling_params_list = args[1] if len(args) > 1 else kwargs.get("sampling_params_list") + if prompts is None: + if kwargs.get("prompt") is None: + raise ValueError("prompts is required for generation") + prompts = kwargs.get("prompt") + + if sampling_params_list is None: + omni_params_kwargs = {k: v for k, v in kwargs.items() if k not in ["prompts", "sampling_params_list"]} + + per_stage_params: list[Any] = [] + for stage in self.stage_list: + stage_type = getattr(stage, "stage_type", "llm") + default_dict = msgspec.to_builtins(getattr(stage, "default_sampling_params", {})) + # Merge user-provided kwargs + merged = {**default_dict, **omni_params_kwargs} + if stage_type == "diffusion": + # Diffusion only needs to keep diff params, will be used via OmniDiffusionRequest + per_stage_params.append(merged) + else: + # LLM directly constructs SamplingParams + per_stage_params.append(SamplingParams(**merged)) + + sampling_params_list = per_stage_params + return self._run_generation(prompts, sampling_params_list) + + def _run_generation( + self, + prompts: PromptType | Sequence[PromptType] | OmniDiffusionRequest | Sequence[OmniDiffusionRequest], + sampling_params_list: Any | Sequence[Any] | None = None, + ) -> list[OmniRequestOutput]: + """Run generation through all stages in the pipeline.""" + logger.debug("[Orchestrator] generate() called") + if sampling_params_list is None: + raise ValueError("sampling_params_list is required for pipelined generation") + + # Normalize sampling_params_list to a list + if not isinstance(sampling_params_list, (list, tuple)): + sampling_params_list = [sampling_params_list] + else: + sampling_params_list = list(sampling_params_list) + + if len(sampling_params_list) != len(self.stage_list): + raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + + # Normalize prompts to a list for per-request iteration + if not isinstance(prompts, (list, tuple)): + request_prompts: list[PromptType] = [prompts] else: - self.instance: OmniLLM | OmniDiffusion = OmniLLM(*args, **kwargs) + request_prompts = list(prompts) + + final_outputs: list[OmniRequestOutput] = [] + + # Orchestrator keeps stage objects for input derivation + num_stages = len(self.stage_list) + + # Generate globally unique request IDs and map them to original prompts + request_ids: list[str] = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] + request_id_to_prompt: dict[str, PromptType] = {rid: p for rid, p in zip(request_ids, request_prompts)} + + # Track per-request start time for end-to-end timing + _req_start_ts: dict[str, float] = {} + _wall_start_ts: float = time.time() + + # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) + final_stage_id_to_prompt: dict[str, int] = {} + for rid, prompt in request_id_to_prompt.items(): + if isinstance(prompt, dict): + prompt_modalities = prompt.get("modalities", None) + else: + prompt_modalities = None + final_stage_id_for_e2e = get_final_stage_id_for_e2e( + prompt_modalities, self.output_modalities, self.stage_list + ) + final_stage_id_to_prompt[rid] = final_stage_id_for_e2e + + # Metrics/aggregation helper + metrics = OrchestratorMetrics( + num_stages, + self._enable_stats, + _wall_start_ts, + ) + + # Seed stage-0 queue with all requests + logger.debug("[Orchestrator] Seeding %d requests into stage-0", len(request_prompts)) + # Mark first input time for stage-0 + metrics.stage_first_ts[0] = 
metrics.stage_first_ts[0] or time.time() - def __getattr__(self, name): - """Delegate attribute access to the chosen backend instance.""" - return getattr(self.instance, name) + for req_id, prompt in request_id_to_prompt.items(): + sp0 = sampling_params_list[0] # type: ignore[index] + task = { + "request_id": req_id, + "engine_inputs": prompt, + "sampling_params": sp0, + } + self.stage_list[0].submit(task) + _req_start_ts[req_id] = time.time() + logger.debug("[Orchestrator] Enqueued request %s to stage-0", req_id) - def generate(self, *args, **kwargs): - """Convenience wrapper to call `generate` on the backend if available.""" - if hasattr(self.instance, "generate"): - return getattr(self.instance, "generate")(*args, **kwargs) - raise AttributeError(f"'{self.instance.__class__.__name__}' has no attribute 'generate'") + # For each stage, forward results to next stage; collect finals at the end + # We pipeline by continually polling output queues in stage order + remaining_by_stage: list[int] = [len(request_prompts)] + [0] * (num_stages - 1) + completed_requests = 0 + total_requests = len(request_prompts) + + logger.debug( + "[Orchestrator] Entering scheduling loop: total_requests=%d, stages=%d", + total_requests, + num_stages, + ) + while completed_requests < total_requests: + made_progress = False + for stage_id, stage in enumerate(self.stage_list): + result = stage.try_collect() + if result is None: + continue + + made_progress = True + req_id = result.get("request_id") + if "error" in result: + logger.error( + "Stage %s error on request %s: %s", + stage_id, + req_id, + result["error"], + ) + continue + + if result.get("type") == "stage_ready": + # Only happens when stage is initialized slower than expected, + # so we wait for a short time and try again + time.sleep(0.05) + continue + + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + # Mark last output time for this stage whenever we receive outputs + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) + try: + _m = asdict(result.get("metrics")) + if _m is not None: + metrics.on_stage_metrics(stage_id, req_id, _m) + except Exception as e: + logger.exception( + "[Orchestrator] Failed to process metrics for stage %s, req %s: %s", + stage_id, + req_id, + e, + ) + logger.debug( + "[Orchestrator] Stage-%s completed request %s; forwarding or finalizing", + stage_id, + req_id, + ) + stage.set_engine_outputs(engine_outputs) + + if getattr(stage, "final_output", False): + final_outputs.append( + OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, # type: ignore[attr-defined] + request_output=engine_outputs, + ) + ) + logger.debug( + "[Orchestrator] Request %s finalized at stage-%s", + req_id, + stage_id, + ) + + # End-to-end timing and time-per-token for final output + # (only once per request at the designated final stage) + try: + rid_key = str(req_id) + if stage_id == final_stage_id_to_prompt[req_id] and rid_key not in metrics.e2e_done: + metrics.on_finalize_request( + stage_id, + req_id, + engine_outputs, + _req_start_ts.get(req_id, _wall_start_ts), + ) + except Exception as e: + logger.exception( + "[Orchestrator] Finalize request handling error for req %s at stage %s: %s", + req_id, + stage_id, + e, + ) + + next_stage_id = stage_id + 1 + if next_stage_id <= final_stage_id_to_prompt[req_id]: + next_stage: OmniStage = self.stage_list[next_stage_id] + try: + next_inputs = next_stage.process_engine_inputs(self.stage_list, 
[request_id_to_prompt[req_id]]) + except Exception as e: + logger.exception( + "[Orchestrator] Process engine inputs error for req %s at stage %s: %s", + req_id, + next_stage_id, + e, + ) + continue + sp_next = sampling_params_list[next_stage_id] # type: ignore[index] + + # Check if we have a connector for this edge + connector_key = (str(stage_id), str(next_stage_id)) + connector = self.connectors.get(connector_key) + sent_via_connector = False + if connector: + sent_via_connector = try_send_via_connector( + connector=connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=req_id, + next_inputs=next_inputs, + sampling_params=sp_next, + original_prompt=request_id_to_prompt[req_id], + next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + metrics=metrics, + ) + + if not sent_via_connector: + raise RuntimeError( + f"[Orchestrator] Failed to send request {req_id} to stage-{next_stage_id} via connector. " + "Configure a connector for this edge or inspect connector logs for details." + ) + logger.debug( + "[Orchestrator] Forwarded request %s to stage-%s", + req_id, + next_stage_id, + ) + remaining_by_stage[next_stage_id] += 1 + else: + completed_requests += 1 + logger.debug( + "[Orchestrator] Request %s fully completed (%d/%d)", + req_id, + completed_requests, + total_requests, + ) + + if not made_progress: + time.sleep(0.005) + logger.debug("[Orchestrator] All requests completed") + + # Summarize and print stats + try: + summary = metrics.build_and_log_summary(final_stage_id_to_prompt) + logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + except Exception as e: + logger.exception("[Orchestrator] Failed to build/log summary: %s", e) + + return final_outputs def close(self) -> None: - close_method = getattr(self.instance, "close", None) - if callable(close_method): - close_method() + """Close all stage processes and clean up resources.""" + # Close stages if they exist (for LLM models) + if self.stage_list: + for q in self._stage_in_queues: + try: + q.put_nowait(None) + except Exception as e: + logger.warning( + "[Orchestrator] Failed to send shutdown signal to stage input queue: %s", + e, + ) + for stage in self.stage_list: + try: + stage.stop_stage_worker() + except Exception as e: + logger.warning("[Orchestrator] Failed to stop stage worker: %s", e) + + try_close_ray(self._ray_pg) def __del__(self): # pragma: no cover - best effort cleanup try: self.close() except Exception: - pass + logger.debug("[Orchestrator] __del__ close() raised", exc_info=True) diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index d2dcba697..05a48feee 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -1,11 +1,3 @@ -import multiprocessing as mp -import os -import time -import uuid -from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import asdict -from pprint import pformat from typing import Any import cloudpickle @@ -14,74 +6,57 @@ # External library imports (vLLM) from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field from vllm.entrypoints.llm import LLM -from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.plugins.io_processors import get_io_processor -from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext from vllm.utils.counter import Counter from vllm.v1.engine.llm_engine import LLMEngine -from vllm_omni.distributed.omni_connectors import 
( - get_stage_connector_config, - initialize_orchestrator_connectors, -) +from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors # Internal imports (our code) -from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector -from vllm_omni.distributed.ray_utils.utils import ( - create_placement_group, - get_ray_queue_class, - try_close_ray, -) from vllm_omni.engine.arg_utils import OmniEngineArgs from vllm_omni.engine.input_processor import OmniInputProcessor from vllm_omni.engine.output_processor import MultimodalOutputProcessor -from vllm_omni.entrypoints.log_utils import ( - OrchestratorMetrics, -) -from vllm_omni.entrypoints.omni_stage import OmniStage -from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load from vllm_omni.entrypoints.utils import ( - get_final_stage_id_for_e2e, load_stage_configs_from_model, load_stage_configs_from_yaml, resolve_model_config_path, ) -from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) -class OmniLLM: +class OmniLLM(LLM): """Main entry point for vLLM-Omni inference. - This class provides a high-level interface for running multi-modal - comprehension and generation models. It orchestrates multiple - stages in a pipeline, where each stage runs in a separate process - with copy-based IPC (Queues). Downstream stages start when upstream - stages finish (window=-1), and each stage processes requests serially - (max_inflight=1) but pipelines across stages. + This class extends the base vLLM LLM class with omni-specific + processors for handling multimodal inputs and outputs. It provides + configuration loading for multi-stage pipelines, while stage management + is handled by the Omni class. Args: model: Model name or path to load stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. log_stats: Whether to enable statistics logging + compilation_config: Optional compilation configuration. Can be an + integer (compilation level), dict, or CompilationConfig instance. + hf_overrides: Optional HuggingFace model configuration overrides + structured_outputs_config: Optional structured outputs configuration. + Can be a dict or StructuredOutputsConfig instance. init_sleep_seconds: Number of seconds to sleep between starting - each stage process during initialization + each stage process during initialization (used by Omni class) shm_threshold_bytes: Threshold in bytes for using shared memory for IPC. Objects larger than this threshold will use shared memory. batch_timeout: Timeout in seconds for batching requests within a stage init_timeout: Timeout in seconds for waiting for all stages to initialize - **kwargs: Additional keyword arguments passed to stage engines + **kwargs: Additional keyword arguments passed to the base LLM class + and engine Example: >>> llm = OmniLLM(model="Qwen/Qwen2.5-Omni-7B") - >>> outputs = llm.generate( - ... prompts="Hello", - ... sampling_params_list=[SamplingParams(), SamplingParams()] - ... 
) + >>> # Stage management is handled by Omni class """ def __init__( @@ -89,19 +64,23 @@ def __init__( model: str, stage_configs_path: str | None = None, log_stats: bool = False, + compilation_config: int | dict[str, Any] | CompilationConfig | None = None, + hf_overrides: dict[str, Any] | None = None, + structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, init_sleep_seconds: int = 20, shm_threshold_bytes: int = 65536, batch_timeout: int = 10, init_timeout: int = 300, **kwargs: Any, ): + """LLM constructor with omni-specific configuration loading.""" + # Store stage management parameters (used by Omni class) self.worker_backend = kwargs.get("worker_backend", "multi_process") self.ray_address = kwargs.get("ray_address", None) - self._ray_pg = None self.batch_timeout = batch_timeout self._enable_stats: bool = bool(log_stats) - # Do NOT call super().__init__ to avoid creating OmniStageLLM instances in parent. + # Load stage configurations if stage_configs_path is None: self.config_path = resolve_model_config_path(model) self.stage_configs = load_stage_configs_from_model(model) @@ -114,448 +93,7 @@ def __init__( self.config_path, worker_backend=self.worker_backend, shm_threshold_bytes=shm_threshold_bytes ) - self.stage_list: list[OmniStage] = [] - self.default_sampling_params_list: list[SamplingParams] = [] - - self._initialize_stages(model, init_sleep_seconds, shm_threshold_bytes, init_timeout) - - def _initialize_stages( - self, - model: str, - init_sleep_seconds: int, - shm_threshold_bytes: int, - init_timeout: int, - ) -> None: - self.stage_list: list[OmniStage] = [] - self.default_sampling_params_list: list[SamplingParams] = [] - - # Build OmniStage instances in parallel, preserve original order - def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: - idx, cfg = idx_cfg - return idx, OmniStage(cfg) - - with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor: - futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)] - results: list[tuple[int, OmniStage]] = [] - for fut in as_completed(futures): - results.append(fut.result()) - results.sort(key=lambda x: x[0]) - self.stage_list = [st for _, st in results] - self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] - self.output_modalities = [st.final_output_type for st in self.stage_list] - logger.debug("[Orchestrator] Loaded %d stages", len(self.stage_list)) - - if self.worker_backend == "ray": - self._queue_cls = get_ray_queue_class() - else: - self._ctx = mp.get_context("spawn") - self._queue_cls = lambda: self._ctx.Queue(maxsize=0) - - self._stage_in_queues: list[mp.Queue] = [] - self._stage_out_queues: list[mp.Queue] = [] - self._init_sleep_seconds = max(0, int(init_sleep_seconds)) - self._shm_threshold_bytes = max(0, int(shm_threshold_bytes)) - self._start_stages(model) - # Wait for all stages to report readiness before seeding - self._stages_ready: set[int] = set() - self._wait_for_stages_ready(timeout=init_timeout) - - def _start_stages(self, model: str) -> None: - if self.worker_backend == "ray": - # Initialize Ray Cluster - self._ray_pg = create_placement_group( - number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" - ) - - for stage_id, stage in enumerate(self.stage_list): - in_q = self._queue_cls() - out_q = self._queue_cls() - self._stage_in_queues.append(in_q) - self._stage_out_queues.append(out_q) - stage.attach_queues(in_q, 
out_q) - - stage_connectors_config = get_stage_connector_config( - self.omni_transfer_config, - stage_id, - ) - - stage.init_stage_worker( - model, - shm_threshold_bytes=self._shm_threshold_bytes, - ctx=self._ctx if self.worker_backend != "ray" else None, - batch_timeout=self.batch_timeout, - connectors_config=stage_connectors_config, - worker_backend=self.worker_backend, - ray_placement_group=self._ray_pg, - ) - - logger.debug("[Orchestrator] Stage-%s process started", stage_id) - time.sleep(self._init_sleep_seconds) - - def close(self) -> None: - """Close all stage processes and clean up resources. - - Sends shutdown signals to all stage input queues and stops - all stage worker processes. This method should be called - when done using the OmniLLM instance. - """ - for q in self._stage_in_queues: - try: - q.put_nowait(None) - except Exception as e: - logger.warning( - "[Orchestrator] Failed to send shutdown signal to stage input queue: %s", - e, - ) - for stage in self.stage_list: - try: - stage.stop_stage_worker() - except Exception as e: - logger.warning("[Orchestrator] Failed to stop stage worker: %s", e) - - try_close_ray(self._ray_pg) - - def __del__(self) -> None: # best-effort - try: - self.close() - except Exception as e: - logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) - - def generate( - self, - prompts: PromptType | Sequence[PromptType], - sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, - ) -> list[OmniRequestOutput]: - """Generate outputs for the given prompts. - - Processes prompts through all stages in the pipeline and returns - the final outputs. Each stage uses its corresponding sampling - parameters from the sampling_params_list. - - Args: - prompts: Single prompt or sequence of prompts to process. - Can be text strings, token IDs, or multimodal prompts. - sampling_params_list: List of SamplingParams, one for each stage. - Must have the same length as the number of stages. - Required for pipelined generation. - - Returns: - List of OmniRequestOutput objects, one for each input prompt. - Each output contains the stage_id, final_output_type, and - the request_output from the final stage. - - Raises: - ValueError: If sampling_params_list is None or has incorrect length. 
- """ - try: - if sampling_params_list is None: - sampling_params_list = self.default_sampling_params_list - return self._run_generation(prompts, sampling_params_list) - except Exception as e: - logger.exception("[Orchestrator] Failed to run generation: %s", e) - raise e - finally: - self.close() - - def _run_generation( - self, - prompts: PromptType | Sequence[PromptType], - sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, - ) -> list[OmniRequestOutput]: - logger.debug("[Orchestrator] generate() called") - if sampling_params_list is None: - raise ValueError("sampling_params_list is required for pipelined generation") - if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") - - # Normalize prompts to a list for per-request iteration - if not isinstance(prompts, (list, tuple)): - request_prompts: list[PromptType] = [prompts] - else: - request_prompts = list(prompts) - - final_outputs: list[OmniRequestOutput] = [] - - # Orchestrator keeps stage objects for input derivation - num_stages = len(self.stage_list) - - # Generate globally unique request IDs and map them to original prompts - request_ids: list[str] = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] - request_id_to_prompt: dict[str, PromptType] = {rid: p for rid, p in zip(request_ids, request_prompts)} - - # Track per-request start time for end-to-end timing - _req_start_ts: dict[str, float] = {} - _wall_start_ts: float = time.time() - - # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) - final_stage_id_to_prompt = {} - for rid, prompt in request_id_to_prompt.items(): - if isinstance(prompt, dict): - prompt_modalities = prompt.get("modalities", None) - else: - prompt_modalities = None - final_stage_id_for_e2e = get_final_stage_id_for_e2e( - prompt_modalities, self.output_modalities, self.stage_list - ) - final_stage_id_to_prompt[rid] = final_stage_id_for_e2e - - # Metrics/aggregation helper - metrics = OrchestratorMetrics( - num_stages, - self._enable_stats, - _wall_start_ts, - ) - - # Seed stage-0 queue with all requests - logger.debug("[Orchestrator] Seeding %d requests into stage-0", len(request_prompts)) - # Mark first input time for stage-0 - metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() - - for req_id, prompt in request_id_to_prompt.items(): - sp0: SamplingParams = sampling_params_list[0] # type: ignore[index] - task = { - "request_id": req_id, - "engine_inputs": prompt, - "sampling_params": sp0, - } - self.stage_list[0].submit(task) - _req_start_ts[req_id] = time.time() - logger.debug("[Orchestrator] Enqueued request %s to stage-0", req_id) - - # For each stage, forward results to next stage; collect finals at the end - # We pipeline by continually polling output queues in stage order - remaining_by_stage: list[int] = [len(request_prompts)] + [0] * (num_stages - 1) - completed_requests = 0 - total_requests = len(request_prompts) - - logger.debug( - "[Orchestrator] Entering scheduling loop: total_requests=%d, stages=%d", - total_requests, - num_stages, - ) - while completed_requests < total_requests: - made_progress = False - for stage_id, stage in enumerate(self.stage_list): - result = stage.try_collect() - if result is None: - continue - - made_progress = True - req_id = result.get("request_id") - if "error" in result: - error_msg = result.get("error", "Unknown error") - error_tb = result.get("error_tb", 
"") - logger.error( - "Stage %s error on request %s: %s\n%s", - stage_id, - req_id, - error_msg, - error_tb, - ) - continue - - if result.get("type") == "stage_ready": - # Only happens when stage is initialized slower than expected, - # so we wait for a short time and try again - time.sleep(0.05) - continue - - engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") - # Mark last output time for this stage whenever we receive outputs - metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) - try: - _m = asdict(result.get("metrics")) - if _m is not None: - metrics.on_stage_metrics(stage_id, req_id, _m) - except Exception as e: - logger.exception( - "[Orchestrator] Failed to process metrics for stage %s, req %s: %s", - stage_id, - req_id, - e, - ) - logger.debug( - "[Orchestrator] Stage-%s completed request %s; forwarding or finalizing", - stage_id, - req_id, - ) - stage.set_engine_outputs(engine_outputs) - - if getattr(stage, "final_output", False): - final_outputs.append( - OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, # type: ignore[attr-defined] - request_output=engine_outputs, - ) - ) - logger.debug( - "[Orchestrator] Request %s finalized at stage-%s", - req_id, - stage_id, - ) - - # End-to-end timing and time-per-token for final output - # (only once per request at the designated final stage) - try: - rid_key = str(req_id) - if stage_id == final_stage_id_to_prompt[req_id] and rid_key not in metrics.e2e_done: - metrics.on_finalize_request( - stage_id, - req_id, - engine_outputs, - _req_start_ts.get(req_id, _wall_start_ts), - ) - except Exception as e: - logger.exception( - "[Orchestrator] Finalize request handling error for req %s at stage %s: %s", - req_id, - stage_id, - e, - ) - - next_stage_id = stage_id + 1 - if next_stage_id <= final_stage_id_to_prompt[req_id]: - next_stage: OmniStage = self.stage_list[next_stage_id] - try: - next_inputs = next_stage.process_engine_inputs(self.stage_list, [request_id_to_prompt[req_id]]) - except Exception as e: - logger.exception( - "[Orchestrator] Process engine inputs error for req %s at stage %s: %s", - req_id, - next_stage_id, - e, - ) - continue - sp_next: SamplingParams = sampling_params_list[next_stage_id] # type: ignore[index] - - # Check if we have a connector for this edge - connector_key = (str(stage_id), str(next_stage_id)) - connector = self.connectors.get(connector_key) - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=req_id, - next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=request_id_to_prompt[req_id], - next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - raise RuntimeError( - f"[Orchestrator] Failed to send request {req_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." 
- ) - logger.debug( - "[Orchestrator] Forwarded request %s to stage-%s", - req_id, - next_stage_id, - ) - remaining_by_stage[next_stage_id] += 1 - else: - completed_requests += 1 - logger.debug( - "[Orchestrator] Request %s fully completed (%d/%d)", - req_id, - completed_requests, - total_requests, - ) - - if not made_progress: - time.sleep(0.005) - logger.debug("[Orchestrator] All requests completed") - - # Summarize and print stats - try: - summary = metrics.build_and_log_summary(final_stage_id_to_prompt) - logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) - except Exception as e: - logger.exception("[Orchestrator] Failed to build/log summary: %s", e) - - return final_outputs - - def _wait_for_stages_ready(self, timeout: int = 120) -> None: - deadline = time.time() + max(0, int(timeout)) - num_stages = len(self.stage_list) - while len(self._stages_ready) < num_stages and time.time() < deadline: - progressed = False - for stage_id, stage in enumerate(self.stage_list): - if stage_id in self._stages_ready: - continue - result = stage.try_collect() - if result is None: - continue - progressed = True - if result.get("type") == "stage_ready": - self._stages_ready.add(stage_id) - logger.info("[Orchestrator] Stage-%s reported ready", stage_id) - else: - # No user data should arrive before seeding; ignore other messages - pass - if not progressed: - time.sleep(0.01) - if len(self._stages_ready) < num_stages: - not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) - logger.warning( - "[Orchestrator] Initialization timeout: only %s/%s stages are ready; not ready: %s", - len(self._stages_ready), - num_stages, - not_ready, - ) - # Provide actionable suggestions before shutdown - try: - suggestions = [ - "Verify GPU/device assignment in config (runtime.devices) is correct.", - "Check GPU/host memory availability; reduce model or batch size if needed.", - "Check model weights path and network reachability (if loading remotely).", - "Increase initialization wait time (init_sleep_seconds or call-site timeout).", - ] - logger.error( - "[Orchestrator] Stage initialization failed, shutting down. Suggestions:\n- %s", - "\n- ".join(suggestions), - ) - except Exception: - # Best-effort logging of suggestions - logger.error( - "[Orchestrator] Stage initialization failed and an error occurred while logging suggestions", - ) - elif len(self._stages_ready) == num_stages: - logger.info("[Orchestrator] All stages initialized successfully") - - -class OmniStageLLM(LLM): - """Single-stage LLM engine for use within a stage worker process. - - This class extends the base vLLM LLM class with omni-specific - processors for handling multimodal inputs and outputs. It is used - internally by OmniStage workers and should not be instantiated directly - by users. - - Args: - model: Model name or path to load - compilation_config: Optional compilation configuration. Can be an - integer (compilation level), dict, or CompilationConfig instance. - hf_overrides: Optional HuggingFace model configuration overrides - structured_outputs_config: Optional structured outputs configuration. - Can be a dict or StructuredOutputsConfig instance. 
- **kwargs: Additional keyword arguments passed to the base LLM class - and engine - """ - - def __init__( - self, - model: str, - compilation_config: int | dict[str, Any] | CompilationConfig | None = None, - hf_overrides: dict[str, Any] | None = None, - structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, - **kwargs: Any, - ): - """LLM constructor.""" + # Initialize LLM engine if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True @@ -578,8 +116,6 @@ def __init__( raw_config_dict, e, ) - # Consider re-raising a more specific vLLM error or ValueError - # to provide better context to the user. raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e if compilation_config is not None: @@ -637,3 +173,20 @@ def __init__( self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin) self.model_config = self.llm_engine.model_config self.input_processor = self.llm_engine.input_processor + + def close(self) -> None: + """Close resources. + + Note: Stage management is now handled by Omni class. + This method closes the LLM engine but not stages. + """ + # Close the LLM engine if it exists + if hasattr(self, "llm_engine") and self.llm_engine is not None: + if hasattr(self.llm_engine, "shutdown"): + self.llm_engine.shutdown() + + def __del__(self) -> None: # best-effort + try: + self.close() + except Exception as e: + logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 692c122f8..4652a921d 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -1,10 +1,6 @@ """ Stage manager for orchestrating multiple engines in vLLM-Omni. -Enhanced to encapsulate per-stage process lifecycle and worker logic -(device setup, LLM init, batching, shared-memory IPC), while preserving -the original input processing utilities for cross-stage data wiring. - Enhanced to encapsulate per-stage process lifecycle and worker logic (device setup, LLM init, batching, shared-memory IPC), while preserving the original input processing utilities for cross-stage data wiring. @@ -30,18 +26,55 @@ from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.llm_engine import LLMEngine +from vllm_omni.distributed.omni_connectors import build_stage_connectors +from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector from vllm_omni.distributed.ray_utils.utils import kill_ray_actor, start_ray_actor from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs +from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion +from vllm_omni.entrypoints.async_omni_llm import AsyncOmniLLM +from vllm_omni.entrypoints.log_utils import count_tokens_from_outputs +from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion +from vllm_omni.entrypoints.omni_llm import OmniLLM from vllm_omni.entrypoints.stage_utils import ( _to_dict, maybe_dump_to_shm, set_stage_devices, ) from vllm_omni.inputs.data import OmniTokensPrompt +from vllm_omni.utils import detect_device_type logger = init_logger(__name__) +def prepare_sampling_params(sampling_params: Any, stage_type: str) -> Any: + """Prepare sampling parameters for the given stage type. 
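+
+    LLM stages expect a vLLM SamplingParams instance, while diffusion stages
+    consume a plain kwargs dict (with "prompt" and "request_id" removed, since
+    those are passed explicitly); this helper normalizes user-supplied
+    parameters into the right form for each stage type.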
+ + Args: + sampling_params: Raw sampling parameters (dict or SamplingParams) + stage_type: Either "llm" or "diffusion" + + Returns: + Processed sampling parameters ready for engine consumption + """ + if stage_type == "diffusion": + # For diffusion stages: extract kwargs, handling different input types + if isinstance(sampling_params, dict): + diffusion_kwargs = dict(sampling_params) + else: + diffusion_kwargs = getattr(sampling_params, "__dict__", {}) or {} + + # Remove 'prompt' and 'request_id' to avoid conflict with explicit arguments + diffusion_kwargs.pop("prompt", None) + diffusion_kwargs.pop("request_id", None) + return diffusion_kwargs + + else: # stage_type == "llm" + # For LLM stages: ensure we have a SamplingParams object + if isinstance(sampling_params, dict): + return SamplingParams(**sampling_params) + return sampling_params + + class OmniStage: """Stage manager for orchestrating a single stage in the omni pipeline. @@ -55,6 +88,7 @@ class OmniStage: """ def __init__(self, stage_config: Any): + logger.info(f"[OmniStage] stage_config: {stage_config}") self.stage_config = stage_config self.engine = None self.async_engine = None @@ -67,9 +101,11 @@ def __init__(self, stage_config: Any): self.model_stage = stage_config.engine_args.model_stage self.requires_multimodal_data = getattr(stage_config.runtime, "requires_multimodal_data", False) self.engine_input_source = getattr(stage_config, "engine_input_source", []) - self.engine_output_type = stage_config.engine_args.engine_output_type + self.engine_output_type = getattr(stage_config.engine_args, "engine_output_type", None) self.engine_outputs = None self.is_comprehension = getattr(stage_config, "is_comprehension", False) + # Support for different stage types: "llm" (default) or "diffusion" + self.stage_type = getattr(stage_config, "stage_type", "llm") if hasattr(stage_config, "custom_process_input_func"): # Import the module specified in the config (already a full module path) module_path, func_name = stage_config.custom_process_input_func.rsplit(".", 1) @@ -81,7 +117,9 @@ def __init__(self, stage_config: Any): self.final_output = getattr(stage_config, "final_output", False) self.final_output_type = getattr(stage_config, "final_output_type", None) default_sampling_params = getattr(stage_config, "default_sampling_params", {}) - self.default_sampling_params = SamplingParams(**_to_dict(default_sampling_params)) + # For LLM stage, this can directly be a SamplingParams-compatible dict; + # For diffusion stage, this only serves as default values for diffusion kwargs. 
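+        # Illustrative only: an LLM stage might default to something like
+        # {"temperature": 0.7, "max_tokens": 2048}, while a diffusion stage
+        # might default to {"num_inference_steps": 50, "guidance_scale": 4.0}.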
+ self.default_sampling_params = _to_dict(default_sampling_params) # Runtime orchestration state (added) self._in_q: mp.Queue | None = None self._out_q: mp.Queue | None = None @@ -204,6 +242,7 @@ def init_stage_worker( "runtime": runtime_cfg, "shm_threshold_bytes": self._shm_threshold_bytes, "connectors_config": connectors_config or {}, + "stage_type": self.stage_type, } try: old_env = os.environ.get("VLLM_LOGGING_PREFIX") @@ -369,19 +408,17 @@ def _stage_worker( batch_timeout: int = 10, ) -> None: """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" + # Use local aliases to avoid conflicts with global imports in worker process + logger.info(f"Starting stage worker with model: {model}") import os as _os import time as _time - from vllm_omni.distributed.omni_connectors import build_stage_connectors - from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector - from vllm_omni.entrypoints.omni_llm import OmniStageLLM - # no inline JSONL/serialization imports; logging handled by utilities - stage_id = stage_payload["stage_id"] engine_args = stage_payload.get("engine_args", {}) runtime_cfg = stage_payload.get("runtime", {}) shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) + stage_type = stage_payload.get("stage_type", "llm") # Aggregates for running average _agg_total_tokens = 0 @@ -392,8 +429,6 @@ def _stage_worker( # Device mapping device_type = None try: - from vllm_omni.utils import detect_device_type - device_type = detect_device_type() set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: @@ -408,11 +443,20 @@ def _stage_worker( import torch if torch.cuda.is_available(): - # Get all parallel sizes from engine_args (defaults to 1) - tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) - pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) - data_parallel_size = engine_args.get("data_parallel_size", 1) - prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) + # Get all parallel sizes from engine_args or parallel_config (defaults to 1) + if "parallel_config" in engine_args: + parallel_config = engine_args["parallel_config"] + tensor_parallel_size = parallel_config.get("tensor_parallel_size", 1) + pipeline_parallel_size = parallel_config.get("pipeline_parallel_size", 1) + data_parallel_size = parallel_config.get("data_parallel_size", 1) + prefill_context_parallel_size = 1 # not used for diffusion + sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) + else: + tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) + data_parallel_size = engine_args.get("data_parallel_size", 1) + prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) + sequence_parallel_size = 1 # not use in omni model # Calculate total number of devices needed for this stage # For a single stage worker: @@ -420,10 +464,15 @@ def _stage_worker( # - PP: splits layers across pipelinestages, but each stage uses TP devices # - DP: replicates model, but each replica uses TP devices # - PCP: context parallelism, typically uses TP devices - # The number of devices per stage is determined by TP * PP * DP * PCP size + # - SP: sequence parallelism, typically uses TP devices + # The number of devices per stage is determined by TP * PP * DP * PCP * SP size # (PP/DP/PCP are 
higher-level parallelism that don't add devices per stage) num_devices_per_stage = ( - tensor_parallel_size * pipeline_parallel_size * data_parallel_size * prefill_context_parallel_size + tensor_parallel_size + * pipeline_parallel_size + * data_parallel_size + * prefill_context_parallel_size + * sequence_parallel_size ) # Get physical device IDs from CUDA_VISIBLE_DEVICES @@ -450,11 +499,12 @@ def _stage_worker( devices_to_lock = sorted(devices_to_lock) logger.debug( - "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d; will lock %d devices: %s", + "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d; will lock %d devices: %s", tensor_parallel_size, pipeline_parallel_size, data_parallel_size, prefill_context_parallel_size, + sequence_parallel_size, num_devices_to_lock, devices_to_lock, ) @@ -513,12 +563,16 @@ def _stage_worker( lock_files = acquired_lock_fds except Exception as e: - logger.debug("Failed to set up sequential initialization lock: %s", e) - - # Init LLM - logger.debug("Initializing engine with args keys=%s", list(engine_args.keys())) + logger.debug("[Stage-%s] Failed to set up sequential initialization lock: %s", stage_id, e) + # Init engine based on stage_type + logger.debug("[Stage-%s] Initializing %s engine with args keys=%s", stage_id, stage_type, list(engine_args.keys())) try: - stage_engine = OmniStageLLM(model=model, **engine_args) + if stage_type == "diffusion": + engine_args.pop("model_stage") + stage_engine = OmniDiffusion(**engine_args) + else: + # Default to LLM engine + stage_engine = OmniLLM(model=model, **engine_args) finally: # Release all locks by closing file descriptors # Locks are automatically released when file descriptors are closed @@ -554,6 +608,7 @@ def _stage_worker( # Batch processing loop while True: task = in_q.get() + _recv_dequeue_ts = _time.time() if task is None: logger.error("Received shutdown signal") @@ -622,8 +677,12 @@ def _stage_worker( batch_engine_inputs.extend(ein) elif isinstance(ein, dict): batch_engine_inputs.append(ein) + elif isinstance(ein, str): + # For diffusion stage-0, ein might be a string prompt directly + batch_engine_inputs.append(ein) else: - logger.error("Invalid engine input type: %s", type(ein)) + # For other types (e.g., OmniTokensPrompt, TextPrompt), append as-is + batch_engine_inputs.append(ein) sampling_params = batch_tasks[0]["sampling_params"] logger.debug( "Received batch size=%d, request_ids=%s", @@ -634,8 +693,59 @@ def _stage_worker( _batch_seq += 1 gen_outputs: list[Any] = [] _gen_t0 = _time.time() - for ro in stage_engine.generate(batch_engine_inputs, sampling_params, use_tqdm=False): - gen_outputs.append(ro) + if stage_type == "diffusion": + # For diffusion, batch_engine_inputs should be prompts (strings) + # Convert to list of strings if needed + prompts = [] + for ein in batch_engine_inputs: + if isinstance(ein, str): + prompts.append(ein) + elif isinstance(ein, dict) and "prompt" in ein: + prompts.append(ein["prompt"]) + elif hasattr(ein, "prompt"): + prompts.append(ein.prompt) + else: + prompts.append(str(ein)) + # Prepare diffusion kwargs from sampling parameters + diffusion_kwargs = prepare_sampling_params(sampling_params, "diffusion") + # Diffusion generate returns results directly, not an iterator + diffusion_results = stage_engine.generate(prompts, **diffusion_kwargs) + # Convert to list format compatible with LLM outputs + # Ensure each result has a request_id for proper mapping + if isinstance(diffusion_results, list): + gen_outputs = diffusion_results + # Assign request_ids if not present + for 
idx, result in enumerate(gen_outputs): + if not hasattr(result, "request_id") or result.request_id is None: + if idx < len(batch_request_ids): + if hasattr(result, "request_id"): + result.request_id = batch_request_ids[idx] + else: + # Create a wrapper object if result doesn't support request_id + from types import SimpleNamespace + + wrapped = SimpleNamespace() + wrapped.request_id = batch_request_ids[idx] + wrapped.output = result + gen_outputs[idx] = wrapped + else: + gen_outputs = [diffusion_results] + # Assign request_id to single result + if len(batch_request_ids) > 0: + if hasattr(gen_outputs[0], "request_id"): + gen_outputs[0].request_id = batch_request_ids[0] + else: + from types import SimpleNamespace + + wrapped = SimpleNamespace() + wrapped.request_id = batch_request_ids[0] + wrapped.output = gen_outputs[0] + gen_outputs[0] = wrapped + else: + # LLM engine: use vLLM native SamplingParams + llm_sampling_params = prepare_sampling_params(sampling_params, "llm") + for ro in stage_engine.generate(batch_engine_inputs, llm_sampling_params, use_tqdm=False): + gen_outputs.append(ro) _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 logger.debug(f"Generate done: batch={len(batch_tasks)}, req_ids={batch_request_ids}, gen_ms={_gen_ms:.1f}") @@ -739,20 +849,16 @@ async def _stage_worker_async( batch_timeout: int = 10, ) -> None: """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" + # Use local aliases to avoid conflicts with global imports in worker process import os as _os import time as _time - from vllm_omni.distributed.omni_connectors import build_stage_connectors - from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector - from vllm_omni.entrypoints.async_omni import AsyncOmniStageLLM - - # no inline JSONL/serialization imports; logging handled by utilities - stage_id = stage_payload["stage_id"] engine_args = stage_payload.get("engine_args", {}) runtime_cfg = stage_payload.get("runtime", {}) shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) + stage_type = stage_payload.get("stage_type", "llm") in_q = omni_stage._in_q out_q = omni_stage._out_q @@ -901,20 +1007,38 @@ async def _stage_worker_async( except Exception as e: logger.debug("Failed to set up sequential initialization lock: %s", e) - # Init LLM + # Init engine based on stage_type logger.debug( - "Initializing engine with args keys=%s", + "[Stage-%s] Initializing %s engine with args keys=%s", + stage_id, + stage_type, list(engine_args.keys()), ) try: - omni_engine_args = AsyncOmniEngineArgs(**engine_args) - usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = omni_engine_args.create_engine_config(usage_context=usage_context) - stage_engine = AsyncOmniStageLLM.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - engine_args=omni_engine_args, - ) + if stage_type == "diffusion": + # For diffusion, we need to extract diffusion-specific config + od_config = engine_args.get("od_config", {}) + if not od_config: + # Create default config from engine_args + od_config = {"model": model} + # Copy relevant diffusion args + for key in ["model", "device", "dtype", "enable_cpu_offload"]: + if key in engine_args: + od_config[key] = engine_args[key] + logger.debug(f"[Stage-%s] Initializing diffusion engine with config: {od_config}", stage_id) + stage_engine = AsyncOmniDiffusion( + model=model, od_config=od_config, **{k: v for k, v in engine_args.items() if k != "od_config"} 
+ ) + vllm_config = None # Diffusion doesn't use vllm_config + else: + omni_engine_args = AsyncOmniEngineArgs(model=model, **engine_args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = omni_engine_args.create_engine_config(usage_context=usage_context) + stage_engine = AsyncOmniLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + engine_args=omni_engine_args, + ) finally: # Release all locks by closing file descriptors # Locks are automatically released when file descriptors are closed @@ -926,9 +1050,10 @@ async def _stage_worker_async( except (OSError, ValueError): pass omni_stage.set_async_engine(stage_engine) - # Don't keep the dummy data in memory - await stage_engine.reset_mm_cache() - logger.debug("Engine initialized") + # Don't keep the dummy data in memory (only for LLM engines) + if stage_type != "diffusion": + await stage_engine.reset_mm_cache() + logger.debug("[Stage-%s] Engine initialized", stage_id) # Signal readiness to orchestrator and send vllm_config back to main process try: # Send vllm_config back to main process so it can be accessed via @@ -936,16 +1061,16 @@ async def _stage_worker_async( # in the worker process # input_preprocessor = await stage_engine.get_input_preprocessor() - out_q.put( - { - "type": "stage_ready", - "stage_id": stage_id, - "vllm_config": vllm_config, - "tokenizer": getattr(stage_engine, "tokenizer", None), - "is_tracing_enabled": await stage_engine.is_tracing_enabled(), - # "input_preprocessor": input_preprocessor, - } - ) + stage_ready_payload = { + "type": "stage_ready", + "stage_id": stage_id, + "vllm_config": vllm_config, + "tokenizer": getattr(stage_engine, "tokenizer", None), + } + # Only add is_tracing_enabled for LLM engines + if stage_type != "diffusion": + stage_ready_payload["is_tracing_enabled"] = await stage_engine.is_tracing_enabled() + out_q.put(stage_ready_payload) except Exception as e: logger.warning("Failed to send stage ready signal: %s", e) generation_out_q = asyncio.Queue() @@ -985,8 +1110,30 @@ async def generation_single_request(task: dict[str, Any]): _gen_t0 = _time.time() if isinstance(ein, list): ein = ein[0] - async for res in stage_engine.generate(ein, sampling_params, rid): - gen_output = res + + if stage_type == "diffusion": + # For diffusion, ein should be prompts (strings) + # Convert to string if needed + if isinstance(ein, str): + prompt = ein + elif isinstance(ein, dict) and "prompt" in ein: + prompt = ein["prompt"] + elif hasattr(ein, "prompt"): + prompt = ein.prompt + else: + prompt = str(ein) + + # Prepare diffusion kwargs from sampling parameters + diffusion_kwargs = prepare_sampling_params(sampling_params, "diffusion") + # AsyncOmniDiffusion.generate returns a single result, not an async generator + gen_output = await stage_engine.generate(prompt=prompt, request_id=rid, **diffusion_kwargs) + else: + # LLM stages: ensure using SamplingParams + llm_sampling_params = prepare_sampling_params(sampling_params, "llm") + gen_output = None + async for res in stage_engine.generate(ein, llm_sampling_params, rid): + gen_output = res + _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 await generation_out_q.put((rid, gen_output, _gen_ms)) @@ -1103,7 +1250,6 @@ def make_request_stats( ): from vllm_omni.entrypoints.log_utils import ( StageRequestMetrics, - count_tokens_from_outputs, ) num_tokens_out = count_tokens_from_outputs(req_output) diff --git a/vllm_omni/entrypoints/openai/__init__.py b/vllm_omni/entrypoints/openai/__init__.py index 58a96cc0e..e27cb238c 
100644 --- a/vllm_omni/entrypoints/openai/__init__.py +++ b/vllm_omni/entrypoints/openai/__init__.py @@ -6,16 +6,12 @@ Provides: - omni_run_server: Main server entry point (auto-detects model type) -- omni_run_diffusion_server: Server for diffusion models - OmniOpenAIServingChat: Unified chat completion handler for both LLM and diffusion models """ from vllm_omni.entrypoints.openai.api_server import ( - build_async_diffusion, build_async_omni, - omni_diffusion_init_app_state, omni_init_app_state, - omni_run_diffusion_server, omni_run_server, ) from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat @@ -23,11 +19,8 @@ __all__ = [ # Server functions "omni_run_server", - "omni_run_diffusion_server", "build_async_omni", - "build_async_diffusion", "omni_init_app_state", - "omni_diffusion_init_app_state", # Serving classes "OmniOpenAIServingChat", ] diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 61f4079ab..04f41c9ea 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1,4 +1,3 @@ -import json import multiprocessing import multiprocessing.forkserver as forkserver import os @@ -8,7 +7,6 @@ from argparse import Namespace from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from dataclasses import fields from http import HTTPStatus from typing import Any @@ -43,9 +41,6 @@ from vllm.tokenizers import MistralTokenizer from vllm.utils.system_utils import decorate_logs -from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig -from vllm_omni.diffusion.utils.hf_utils import is_diffusion_model -from vllm_omni.entrypoints.async_diffusion import AsyncOmniDiffusion from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, @@ -67,71 +62,23 @@ async def omni_run_server(args, **uvicorn_kwargs) -> None: """Run a single-worker API server. - Automatically detects if the model is a diffusion model and routes - to the appropriate server implementation. + Unified entry point that automatically handles both LLM and Diffusion models + through AsyncOmni, which manages multi-stage pipelines. """ + # Suppress Pydantic serialization warnings globally for multimodal content + # (e.g., when ChatMessage.content is a list instead of str) + import warnings as warnings_module - # Add process-specific prefix to stdout and stderr. - decorate_logs("APIServer") - - listen_address, sock = setup_server(args) - - # Check if model is a diffusion model - if is_diffusion_model(args.model): - logger.info("Detected diffusion model, starting diffusion API server") - await omni_run_diffusion_server_worker(listen_address, sock, args, **uvicorn_kwargs) - else: - await omni_run_server_worker(listen_address, sock, args, **uvicorn_kwargs) - - -async def omni_run_diffusion_server(args, **uvicorn_kwargs) -> None: - """Run a diffusion model API server.""" + warnings_module.filterwarnings("ignore", message=".*Pydantic.*serialization.*", category=UserWarning) + warnings_module.filterwarnings("ignore", message=".*PydanticSerializationUnexpectedValue.*", category=UserWarning) # Add process-specific prefix to stdout and stderr. 
- decorate_logs("DiffusionAPIServer") + decorate_logs("APIServer") listen_address, sock = setup_server(args) - await omni_run_diffusion_server_worker(listen_address, sock, args, **uvicorn_kwargs) - -async def omni_run_diffusion_server_worker(listen_address, sock, args, **uvicorn_kwargs) -> None: - """Run a diffusion model API server worker.""" - - # Load logging config for uvicorn if specified - log_config = load_log_config(args.log_config_file) - if log_config is not None: - uvicorn_kwargs["log_config"] = log_config - - async with build_async_diffusion(args) as diffusion_engine: - app = build_app(args) - - await omni_diffusion_init_app_state(diffusion_engine, app.state, args) - - logger.info("Starting vLLM Diffusion API server on %s", listen_address) - - shutdown_task = await serve_http( - app, - sock=sock, - enable_ssl_refresh=getattr(args, "enable_ssl_refresh", False), - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - access_log=not getattr(args, "disable_uvicorn_access_log", False), - timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, - ssl_keyfile=getattr(args, "ssl_keyfile", None), - ssl_certfile=getattr(args, "ssl_certfile", None), - ssl_ca_certs=getattr(args, "ssl_ca_certs", None), - ssl_cert_reqs=getattr(args, "ssl_cert_reqs", 0), - h11_max_incomplete_event_size=getattr(args, "h11_max_incomplete_event_size", None), - h11_max_header_count=getattr(args, "h11_max_header_count", None), - **uvicorn_kwargs, - ) - - # NB: Await server shutdown only after the backend context is exited - try: - await shutdown_task - finally: - sock.close() + # Unified use of omni_run_server_worker, AsyncOmni automatically handles LLM and Diffusion models + await omni_run_server_worker(listen_address, sock, args, **uvicorn_kwargs) async def omni_run_server_worker(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None: @@ -155,11 +102,19 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, vllm_config = await engine_client.get_vllm_config() await omni_init_app_state(engine_client, vllm_config, app.state, args) - logger.info( - "Starting vLLM API server %d on %s", - vllm_config.parallel_config._api_process_rank, - listen_address, - ) + # Check if pure diffusion mode (vllm_config will be None) + is_pure_diffusion = vllm_config is None + if is_pure_diffusion: + logger.info( + "Starting vLLM API server (pure diffusion mode) on %s", + listen_address, + ) + else: + logger.info( + "Starting vLLM API server %d on %s", + vllm_config.parallel_config._api_process_rank, + listen_address, + ) shutdown_task = await serve_http( app, sock=sock, @@ -227,78 +182,11 @@ async def build_async_omni( yield async_omni -@asynccontextmanager -async def build_async_diffusion( - args: Namespace, - **kwargs: Any, -) -> AsyncIterator[AsyncOmniDiffusion]: - """Build an AsyncOmniDiffusion instance from command-line arguments. - - Creates an async context manager that yields an AsyncOmniDiffusion - instance configured from the provided arguments. 
- - Args: - args: Parsed command-line arguments containing model and configuration - **kwargs: Additional keyword arguments passed to AsyncOmniDiffusion - - Yields: - AsyncOmniDiffusion instance ready for use - """ - diffusion_engine: AsyncOmniDiffusion | None = None - - try: - # Build diffusion kwargs by extracting matching OmniDiffusionConfig fields from args - config_field_names = {f.name for f in fields(OmniDiffusionConfig)} - diffusion_kwargs: dict[str, Any] = {"model": args.model} - - # Diffusion parallelism configuration (e.g. `--usp 2`). - parallel_config_kwargs: dict[str, Any] = {} - for field in fields(DiffusionParallelConfig): - if not hasattr(args, field.name): - continue - value = getattr(args, field.name) - if value is None: - continue - parallel_config_kwargs[field.name] = value - if parallel_config_kwargs: - diffusion_kwargs["parallel_config"] = DiffusionParallelConfig(**parallel_config_kwargs) - - for field_name in config_field_names: - if not hasattr(args, field_name): - continue - value = getattr(args, field_name) - if value is None: - continue - # Special handling for cache_config JSON string - if field_name == "cache_config" and isinstance(value, str): - try: - value = json.loads(value) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse cache_config JSON: {e}") - continue - diffusion_kwargs[field_name] = value - - diffusion_kwargs.update(kwargs) - logger.info(f"diffusion_kwargs: {diffusion_kwargs}") - logger.info( - "Building AsyncOmniDiffusion with model=%s, num_gpus=%s", - args.model, - diffusion_kwargs.get("num_gpus", 1), - ) - diffusion_engine = AsyncOmniDiffusion(**diffusion_kwargs) - - yield diffusion_engine - finally: - if diffusion_engine: - diffusion_engine.shutdown() - - @asynccontextmanager async def build_async_omni_from_stage_config( args: Namespace, *, disable_frontend_multiprocessing: bool = False, - client_config: dict[str, Any] | None = None, ) -> AsyncIterator[EngineClient]: """Create AsyncOmni from stage configuration. @@ -326,7 +214,11 @@ async def build_async_omni_from_stage_config( async_omni: EngineClient | None = None try: - async_omni = AsyncOmni(model=args.model, cli_args=args) + # Convert args Namespace to kwargs dict for AsyncOmni to use + kwargs = vars(args).copy() + # Remove model as it will be passed separately + kwargs.pop("model", None) + async_omni = AsyncOmni(model=args.model, **kwargs) # # Don't keep the dummy data in memory # await async_llm.reset_mm_cache() @@ -339,7 +231,7 @@ async def build_async_omni_from_stage_config( async def omni_init_app_state( engine_client: EngineClient, - vllm_config: VllmConfig, + vllm_config: VllmConfig | None, state: State, args: Namespace, ) -> None: @@ -347,13 +239,25 @@ async def omni_init_app_state( Sets up the application state with model information, request logger, and other server configuration needed for handling API requests. + Automatically detects pure diffusion mode (single diffusion stage) and + handles it appropriately. 
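+    In pure diffusion mode, the vLLM-specific state (vllm_config, chat
+    template resolution, LoRA modules) is skipped and a diffusion-only chat
+    handler is installed via OmniOpenAIServingChat.for_diffusion.
+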
Args: engine_client: Engine client instance (AsyncOmni) - vllm_config: vLLM configuration object + vllm_config: vLLM configuration object (may be None for pure diffusion) state: FastAPI application state object to initialize args: Parsed command-line arguments """ + # Detect if it's pure Diffusion mode (single stage and is Diffusion) + is_pure_diffusion = False + if hasattr(engine_client, "stage_configs") and engine_client.stage_configs: + stage_configs = engine_client.stage_configs + if len(stage_configs) == 1: + stage_type = stage_configs[0].get("stage_type", "llm") + if stage_type == "diffusion": + is_pure_diffusion = True + logger.info("Detected pure diffusion mode (single diffusion stage)") + if args.served_model_name is not None: served_model_names = args.served_model_name else: @@ -367,36 +271,62 @@ async def omni_init_app_state( base_model_paths = [BaseModelPath(name=name, model_path=args.model) for name in served_model_names] state.engine_client = engine_client state.log_stats = not args.disable_log_stats - state.vllm_config = vllm_config - _model_config = vllm_config.model_config - state.log_stats = not args.disable_log_stats # For omni models - state.stage_configs = engine_client.stage_configs + state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None + + # Pure Diffusion mode: use simplified initialization logic + if is_pure_diffusion: + model_name = served_model_names[0] if served_model_names else args.model + state.vllm_config = None + state.diffusion_engine = engine_client + + # Use for_diffusion method to create chat handler + state.openai_serving_chat = OmniOpenAIServingChat.for_diffusion( + diffusion_engine=engine_client, # type: ignore + model_name=model_name, + ) + + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) + state.server_load_metrics = 0 + logger.info("Pure diffusion API server initialized for model: %s", model_name) + return + + # LLM or multi-stage mode: use standard initialization logic + if vllm_config is None: + # Try to get vllm_config from engine_client + vllm_config = await engine_client.get_vllm_config() + if vllm_config is None: + logger.warning("vllm_config is None, some features may not work correctly") + + state.vllm_config = vllm_config + if vllm_config is not None: + _model_config = vllm_config.model_config resolved_chat_template = load_chat_template(args.chat_template) - if resolved_chat_template is not None: + if resolved_chat_template is not None and vllm_config is not None: # Get the tokenizer to check official template tokenizer = await engine_client.get_tokenizer() - if isinstance(tokenizer, MistralTokenizer): - # The warning is logged in resolve_mistral_chat_template. - resolved_chat_template = resolve_mistral_chat_template(chat_template=resolved_chat_template) - else: - hf_chat_template = resolve_hf_chat_template( - tokenizer=tokenizer, - chat_template=None, - tools=None, - model_config=vllm_config.model_config, - ) - - if hf_chat_template != resolved_chat_template: - logger.warning( - "Using supplied chat template: %s\nIt is different from official chat template '%s'. This discrepancy may lead to performance degradation.", # noqa: E501 - resolved_chat_template, - args.model, + if tokenizer is not None: + if isinstance(tokenizer, MistralTokenizer): + # The warning is logged in resolve_mistral_chat_template. 
+ resolved_chat_template = resolve_mistral_chat_template(chat_template=resolved_chat_template) + else: + hf_chat_template = resolve_hf_chat_template( + tokenizer=tokenizer, + chat_template=None, + tools=None, + model_config=vllm_config.model_config, ) + if hf_chat_template != resolved_chat_template: + logger.warning( + "Using supplied chat template: %s\nIt is different from official chat template '%s'. This discrepancy may lead to performance degradation.", # noqa: E501 + resolved_chat_template, + args.model, + ) + if args.tool_server == "demo": tool_server: ToolServer | None = DemoToolServer() assert isinstance(tool_server, DemoToolServer) @@ -408,7 +338,9 @@ async def omni_init_app_state( tool_server = None # Merge default_mm_loras into the static lora_modules - default_mm_loras = vllm_config.lora_config.default_mm_loras if vllm_config.lora_config is not None else {} + default_mm_loras = {} + if vllm_config is not None and vllm_config.lora_config is not None: + default_mm_loras = vllm_config.lora_config.default_mm_loras lora_modules = args.lora_modules if default_mm_loras: @@ -424,6 +356,57 @@ async def omni_init_app_state( else: lora_modules += default_mm_lora_paths + # Ensure input_processor, io_processor, and model_config exist for OpenAIServingModels compatibility + if ( + not hasattr(engine_client, "input_processor") + or engine_client.input_processor is None + or not hasattr(engine_client, "io_processor") + or engine_client.io_processor is None + or not hasattr(engine_client, "model_config") + or engine_client.model_config is None + ): + if vllm_config is not None: + # Try to initialize processors if vllm_config is available + try: + from vllm.plugins.io_processors import get_io_processor + + from vllm_omni.engine.input_processor import OmniInputProcessor + + tokenizer = await engine_client.get_tokenizer() + if tokenizer is not None: + # Initialize input_processor + if not hasattr(engine_client, "input_processor") or engine_client.input_processor is None: + engine_client.input_processor = OmniInputProcessor( + vllm_config=vllm_config, + tokenizer=tokenizer, + ) + logger.info("Initialized input_processor for AsyncOmni") + + # Initialize model_config + if not hasattr(engine_client, "model_config") or engine_client.model_config is None: + engine_client.model_config = vllm_config.model_config + logger.info("Initialized model_config for AsyncOmni") + + # Initialize io_processor + if not hasattr(engine_client, "io_processor") or engine_client.io_processor is None: + model_config = ( + engine_client.model_config + if hasattr(engine_client, "model_config") + else vllm_config.model_config + ) + io_processor_plugin = model_config.io_processor_plugin + engine_client.io_processor = get_io_processor(vllm_config, io_processor_plugin) + logger.info("Initialized io_processor for AsyncOmni") + else: + logger.warning("Cannot initialize processors: tokenizer is None. OpenAIServingModels may fail.") + except Exception as e: + logger.warning( + "Failed to initialize processors for AsyncOmni: %s. OpenAIServingModels may fail.", + e, + ) + else: + logger.warning("Cannot initialize processors: vllm_config is None. 
OpenAIServingModels may fail.") + state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, base_model_paths=base_model_paths, @@ -457,50 +440,6 @@ def Omnichat(request: Request) -> OmniOpenAIServingChat | None: return request.app.state.openai_serving_chat -async def omni_diffusion_init_app_state( - diffusion_engine: AsyncOmniDiffusion, - state: State, - args: Namespace, -) -> None: - """Initialize the FastAPI application state for diffusion model API server. - - Sets up the application state with diffusion model information and - chat completion handler for image generation via /v1/chat/completions. - - Args: - diffusion_engine: AsyncOmniDiffusion engine instance - state: FastAPI application state object to initialize - args: Parsed command-line arguments - """ - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - - model_name = served_model_names[0] if served_model_names else args.model - - state.diffusion_engine = diffusion_engine - state.diffusion_model_name = model_name # Store for image endpoints - state.log_stats = not getattr(args, "disable_log_stats", False) - - # Initialize chat handler with diffusion engine (uses /v1/chat/completions endpoint) - # Note: Request-level parameters (num_inference_steps, guidance_scale, seed, height, width, etc.) - # are passed per-request via the API, not as server defaults - state.openai_serving_chat = OmniOpenAIServingChat.for_diffusion( - diffusion_engine=diffusion_engine, - model_name=model_name, - ) - - # Set other handlers to None for diffusion-only mode - state.engine_client = None - state.vllm_config = None - - state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) - state.server_load_metrics = 0 - - logger.info("Diffusion API server initialized for model: %s", model_name) - - @router.post( "/v1/chat/completions", dependencies=[Depends(validate_json_request)], @@ -529,7 +468,30 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) elif isinstance(generator, ChatCompletionResponse): - return JSONResponse(content=generator.model_dump()) + # Completely bypass Pydantic serialization warnings for multimodal content + # by converting to dict first, then serializing with warnings suppressed + import json as json_lib + import warnings as warnings_module + + # Temporarily suppress ALL Pydantic UserWarnings during serialization + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning) + warnings_module.filterwarnings("ignore", message=".*Pydantic.*", category=UserWarning) + try: + # Use serialize_as_any=True to bypass type checking + response_dict = generator.model_dump(mode="json", serialize_as_any=True, warnings="none") + return JSONResponse(content=response_dict) + except Exception: + # Fallback: convert to JSON string and parse back to avoid any serialization issues + try: + response_json = generator.model_dump_json(warnings="none", serialize_as_any=True) + response_dict = json_lib.loads(response_json) + return JSONResponse(content=response_dict) + except Exception: + # Last resort: regular dump with warnings suppressed + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning) + return JSONResponse(content=generator.model_dump(mode="json", warnings="none")) return StreamingResponse(content=generator, media_type="text/event-stream") @@ -551,6 +513,7 @@ async def generate_images(request: 
ImageGenerationRequest, raw_request: Request) """Generate images from text prompts using diffusion models. OpenAI DALL-E compatible endpoint for text-to-image generation. + Only supports multi-stage omni mode with diffusion stages. Args: request: Image generation request with prompt and parameters @@ -562,16 +525,56 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) Raises: HTTPException: For validation errors, missing engine, or generation failures """ - # Get diffusion engine from app state - diffusion_engine: AsyncOmniDiffusion | None = getattr(raw_request.app.state, "diffusion_engine", None) - if diffusion_engine is None: + # Get engine client (AsyncOmni) from app state + engine_client: EngineClient | None = getattr(raw_request.app.state, "engine_client", None) + if engine_client is None or not hasattr(engine_client, "stage_list"): raise HTTPException( status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, - detail="Diffusion engine not initialized. Start server with a diffusion model.", + detail="Multi-stage engine not initialized. Start server with a multi-stage omni model.", ) - # Get server's loaded model - model_name = getattr(raw_request.app.state, "diffusion_model_name", "unknown") + # Check if there's a diffusion stage + stage_configs = getattr(raw_request.app.state, "stage_configs", None) + if not stage_configs: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="Stage configs not found. Start server with a multi-stage omni model.", + ) + + # Check for diffusion stage + has_diffusion_stage = False + for stage in stage_configs: + # Handle both dict and OmegaConf objects + stage_type = None + if isinstance(stage, dict): + stage_type = stage.get("stage_type", "llm") + elif hasattr(stage, "get"): + stage_type = stage.get("stage_type", "llm") + elif hasattr(stage, "stage_type"): + stage_type = stage.stage_type + else: + # Fallback: try to access as dict-like + try: + stage_type = stage["stage_type"] if "stage_type" in stage else "llm" + except (TypeError, KeyError): + stage_type = "llm" + + if stage_type == "diffusion": + has_diffusion_stage = True + break + + if not has_diffusion_stage: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="No diffusion stage found in multi-stage pipeline.", + ) + + # Get server's loaded model name + serving_models = getattr(raw_request.app.state, "openai_serving_models", None) + if serving_models and hasattr(serving_models, "base_model_paths") and serving_models.base_model_paths: + model_name = serving_models.base_model_paths[0].name + else: + model_name = "unknown" # Validate model field (warn if mismatch, don't error) if request.model is not None and request.model != model_name: @@ -607,14 +610,21 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) gen_params["true_cfg_scale"] = request.true_cfg_scale if request.seed is not None: gen_params["seed"] = request.seed + gen_params["request_id"] = f"img_gen_{int(time.time())}" logger.info(f"Generating {request.n} image(s) {size_str}") - # Generate images using AsyncOmniDiffusion - result = await diffusion_engine.generate(**gen_params) + # Generate images using AsyncOmni (multi-stage mode) + result = await engine_client.generate(**gen_params) + + if result is None: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail="No output generated from multi-stage pipeline.", + ) # Extract images from result - images = result.images if hasattr(result, 
"images") else [] + images = result.images if hasattr(result, "images") and result.images else [] logger.info(f"Successfully generated {len(images)} image(s)") diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 422a38788..9a8407d84 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -74,7 +74,7 @@ from vllm_omni.outputs import OmniRequestOutput if TYPE_CHECKING: - from vllm_omni.entrypoints.async_diffusion import AsyncOmniDiffusion + from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion logger = init_logger(__name__) @@ -477,6 +477,8 @@ def _build_sampling_params_list_from_request( sampling_params_list = [] for idx, default_params in enumerate(default_params_list): + if isinstance(default_params, dict): + default_params = SamplingParams(**default_params) if idx == comprehension_idx: params = self._apply_request_overrides(default_params, request) sampling_params_list.append(params) @@ -985,9 +987,21 @@ def _create_image_choice(self, omni_outputs: OmniRequestOutput, role: str): content = [{"type": "text", "text": "Image generation completed but no images were produced."}] # Create response choice + # Use model_construct to bypass validation for multimodal content + # (ChatMessage.content only accepts str, but we need list for images) + # Then use object.__setattr__ to directly set the field, bypassing Pydantic's type checking + import warnings as warnings_module + + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning, module="pydantic") + message = ChatMessage.model_construct(role=role) + object.__setattr__(message, "content", content) + # Mark content as set in fields_set to ensure proper serialization + if hasattr(message, "__pydantic_fields_set__"): + message.__pydantic_fields_set__.add("content") choice_data = ChatCompletionResponseChoice( index=0, - message=ChatMessage(role=role, content=content), + message=message, logprobs=None, finish_reason="stop", stop_reason=None, @@ -1108,6 +1122,9 @@ async def _create_diffusion_chat_completion( else: od_config = getattr(self._diffusion_engine, "od_config", None) supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) + if od_config is None: + # TODO: entry is asyncOmni. We hack the od config here. 
+ supports_multimodal_inputs = True if supports_multimodal_inputs: gen_kwargs["pil_image"] = pil_images else: @@ -1119,11 +1136,30 @@ async def _create_diffusion_chat_completion( ) # Generate image - result = await self._diffusion_engine.generate(**gen_kwargs) + # Handle both AsyncOmniDiffusion (returns OmniRequestOutput) and AsyncOmni (returns AsyncGenerator) + if hasattr(self._diffusion_engine, "stage_list"): + # AsyncOmni: iterate through async generator to get final output + result = None + async for output in self._diffusion_engine.generate( + prompt=gen_kwargs["prompt"], + request_id=gen_kwargs.get("request_id"), + sampling_params_list=[gen_kwargs], # Pass as single-stage params + ): + result = output + if result is None: + return self._create_error_response("No output generated from AsyncOmni") + else: + # AsyncOmniDiffusion: direct call + result = await self._diffusion_engine.generate(**gen_kwargs) + # Extract images from result + # Handle nested OmniRequestOutput structure where images might be in request_output + images: list[Image.Image] = [] + if result.request_output["images"]: + images = result.request_output["images"] # Convert images to base64 content image_contents: list[dict[str, Any]] = [] - for img in result.images: + for img in images: with BytesIO() as buffer: img.save(buffer, format="PNG") img_bytes = buffer.getvalue() @@ -1145,7 +1181,16 @@ async def _create_diffusion_chat_completion( # Use model_construct to bypass validation for multimodal content # (ChatMessage.content only accepts str, but we need list for images) - message = ChatMessage.model_construct(role="assistant", content=content) + # Then use object.__setattr__ to directly set the field, bypassing Pydantic's type checking + import warnings as warnings_module + + with warnings_module.catch_warnings(): + warnings_module.filterwarnings("ignore", category=UserWarning, module="pydantic") + message = ChatMessage.model_construct(role="assistant") + object.__setattr__(message, "content", content) + # Mark content as set in fields_set to ensure proper serialization + if hasattr(message, "__pydantic_fields_set__"): + message.__pydantic_fields_set__.add("content") choice = ChatCompletionResponseChoice.model_construct( index=0, message=message, @@ -1169,7 +1214,7 @@ async def _create_diffusion_chat_completion( logger.info( "Diffusion chat completed for request %s: %d images", request_id, - len(result.images), + len(images), ) return response diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 82c8a4d34..770f695c6 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -6,7 +6,8 @@ from omegaconf import OmegaConf from vllm.logger import init_logger -from vllm.transformers_utils.config import get_config +from vllm.transformers_utils.config import get_config, get_hf_file_to_dict +from vllm.transformers_utils.repo_utils import file_or_path_exists from vllm_omni.utils import detect_device_type, is_rocm @@ -16,6 +17,23 @@ logger = init_logger(__name__) +def _try_get_class_name_from_diffusers_config(model: str) -> str | None: + """Try to get class name from diffusers model configuration files. 
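+
+    Diffusers-style repositories do not ship a transformers ``model_type``; the pipeline class is
+    recorded under the ``_class_name`` key of ``model_index.json`` instead, and that value is used
+    as the stage-config lookup key.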
+ + Args: + model: Model name or path + + Returns: + Model type string if found, None otherwise + """ + model_index = get_hf_file_to_dict("model_index.json", model, revision=None) + if model_index and isinstance(model_index, dict) and "_class_name" in model_index: + logger.debug(f"Found model_type '{model_index['_class_name']}' in model_index.json") + return model_index["_class_name"] + + return None + + def _convert_dataclasses_to_dict(obj: Any) -> Any: """Recursively convert non-serializable objects to OmegaConf-compatible types. @@ -79,10 +97,28 @@ def resolve_model_config_path(model: str) -> str: String path to the stage configuration file Raises: + ValueError: If model_type cannot be determined FileNotFoundError: If no stage config file exists for the model type """ - hf_config = get_config(model, trust_remote_code=True) - model_type = hf_config.model_type + # Try to get config from standard transformers format first + try: + hf_config = get_config(model, trust_remote_code=True) + model_type = hf_config.model_type + except (ValueError, Exception): + # If standard transformers format fails, try diffusers format + if file_or_path_exists(model, "model_index.json", revision=None): + model_type = _try_get_class_name_from_diffusers_config(model) + if model_type is None: + raise ValueError( + f"Could not determine model_type for diffusers model: {model}. " + f"Please ensure the model has 'model_type' in transformer/config.json or model_index.json" + ) + else: + raise ValueError( + f"Could not determine model_type for model: {model}. " + f"Model is not in standard transformers format and does not have model_index.json. " + f"Please ensure the model has proper configuration files with 'model_type' field" + ) device_type = detect_device_type() # Try device-specific config first @@ -98,7 +134,7 @@ def resolve_model_config_path(model: str) -> str: stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" stage_config_path = PROJECT_ROOT / stage_config_file if not os.path.exists(stage_config_path): - raise FileNotFoundError(f"Stage config file {stage_config_path} not found") + return None return str(stage_config_path) @@ -121,6 +157,8 @@ def load_stage_configs_from_model(model: str, base_engine_args: dict | None = No if base_engine_args is None: base_engine_args = {} stage_config_path = resolve_model_config_path(model) + if stage_config_path is None: + return [] stage_configs = load_stage_configs_from_yaml(config_path=stage_config_path, base_engine_args=base_engine_args) return stage_configs diff --git a/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml index 2812e99a0..49085276f 100644 --- a/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml +++ b/vllm_omni/model_executor/stage_configs/npu/qwen2_5_omni.yaml @@ -1,6 +1,7 @@ # stage config for running qwen2.5-omni with architecture of OmniLLM. 
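+# `stage_type` selects the runtime for each stage: "llm" launches the OmniLLM runtime and is the
+# default when the key is omitted, while "diffusion" selects the diffusion runtime. A hypothetical
+# diffusion stage entry would look roughly like the sketch below (illustrative only, not part of
+# this pipeline):
+#   - stage_id: 0
+#     stage_type: diffusion
+#     runtime:
+#       devices: "0"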
stage_args: - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage @@ -27,6 +28,7 @@ stage_args: detokenize: True repetition_penalty: 1.1 - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true devices: "1" @@ -53,6 +55,7 @@ stage_args: repetition_penalty: 1.05 stop_token_ids: [8294] - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true devices: "0" # Example: use a different NPU than the previous stage; use "0" if single NPU diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml index b89171b4c..fc84f485e 100644 --- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml @@ -3,6 +3,7 @@ # The following config has been verified on 2x H100-80G GPU. stage_args: - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) @@ -31,6 +32,7 @@ stage_args: repetition_penalty: 1.1 - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true devices: "1" @@ -59,6 +61,7 @@ stage_args: stop_token_ids: [8294] - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml index 5d5de6a95..f2e4d03dd 100644 --- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml @@ -3,6 +3,7 @@ # The following config has been verified on 1x H100-80G GPU. stage_args: - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) @@ -32,6 +33,7 @@ stage_args: output_connectors: to_stage_1: mooncake_connector - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true devices: "1" @@ -63,6 +65,7 @@ stage_args: output_connectors: to_stage_2: mooncake_connector - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: process: true devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml index 73f65ecb5..e077cf529 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml @@ -6,6 +6,7 @@ # The following config has been verified on 2x H100-80G GPUs. 
stage_args: - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "0,1" max_batch_size: 1 @@ -36,6 +37,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "1" max_batch_size: 1 @@ -67,6 +69,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "0" max_batch_size: 1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml index eb4ca3886..4abd3835a 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml @@ -6,6 +6,7 @@ # The following config has been verified on 2x H100-80G GPUs. stage_args: - stage_id: 0 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "0,1" max_batch_size: 1 @@ -38,6 +39,7 @@ stage_args: to_stage_1: connector_of_mooncake - stage_id: 1 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "1" max_batch_size: 1 @@ -73,6 +75,7 @@ stage_args: to_stage_2: connector_of_mooncake - stage_id: 2 + stage_type: llm # Use llm stage type to launch OmniLLM runtime: devices: "0" max_batch_size: 1 diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 1fe3112d7..86f41a64d 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -151,3 +151,23 @@ def to_dict(self) -> dict[str, Any]: ) return result + + def __repr__(self) -> str: + """Custom repr to properly show image count instead of image objects.""" + # For images, show count instead of full list + images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]" + + # Build repr string + parts = [ + f"request_id={self.request_id!r}", + f"finished={self.finished}", + f"stage_id={self.stage_id}", + f"final_output_type={self.final_output_type!r}", + f"request_output={self.request_output}", + f"images={images_repr}", + f"prompt={self.prompt!r}", + f"latents={self.latents}", + f"metrics={self.metrics}", + ] + + return f"OmniRequestOutput({', '.join(parts)})" diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py index 4748d1fd3..9e058addb 100644 --- a/vllm_omni/worker/gpu_ar_worker.py +++ b/vllm_omni/worker/gpu_ar_worker.py @@ -2,6 +2,7 @@ import os import torch +from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.utils.mem_constants import GiB_bytes @@ -12,6 +13,8 @@ from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner +logger = init_logger(__name__) + class GPUARWorker(GPUWorker): """GPU worker for autoregressive omni model stages.