4 changes: 2 additions & 2 deletions docs/api/README.md
@@ -4,9 +4,9 @@

Main entry points for vLLM-Omni inference and serving.

- [vllm_omni.entrypoints.async_diffusion.AsyncOmniDiffusion][]
- [vllm_omni.entrypoints.async_omni_diffusion.AsyncOmniDiffusion][]
- [vllm_omni.entrypoints.async_omni.AsyncOmni][]
- [vllm_omni.entrypoints.async_omni.AsyncOmniStageLLM][]
- [vllm_omni.entrypoints.async_omni_llm.AsyncOmniLLM][]
- [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalContentParser][]
- [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalItemTracker][]
- [vllm_omni.entrypoints.chat_utils.parse_chat_messages_futures][]
114 changes: 74 additions & 40 deletions examples/offline_inference/image_to_image/image_edit.py
@@ -61,7 +61,7 @@
import torch
from PIL import Image

from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.data import DiffusionParallelConfig, logger
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type, is_npu

@@ -317,45 +317,79 @@ def main():
print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}")
print(f"{'=' * 60}\n")

generation_start = time.perf_counter()
# Generate edited image
images = omni.generate(
prompt=args.prompt,
pil_image=input_image,
negative_prompt=args.negative_prompt,
generator=generator,
true_cfg_scale=args.cfg_scale,
guidance_scale=args.guidance_scale,
num_inference_steps=args.num_inference_steps,
num_outputs_per_prompt=args.num_outputs_per_prompt,
layers=args.layers,
)
generation_end = time.perf_counter()
generation_time = generation_end - generation_start

# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

# Save output image(s)
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".png"
stem = output_path.stem or "output_image_edit"

if args.num_outputs_per_prompt <= 1:
img = images[0]
img = img if isinstance(img, list) else [img]
for sub_idx, sub_img in enumerate(img):
save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}"
sub_img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
else:
for idx, img in enumerate(images):
img = img if isinstance(img, list) else [img]
for sub_idx, sub_img in enumerate(img):
save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}"
sub_img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
try:
generation_start = time.perf_counter()
# Generate edited image
generate_kwargs = {
"prompt": args.prompt,
"pil_image": input_image,
"negative_prompt": args.negative_prompt,
"generator": generator,
"true_cfg_scale": args.cfg_scale,
"guidance_scale": args.guidance_scale,
"num_inference_steps": args.num_inference_steps,
"num_outputs_per_prompt": args.num_outputs_per_prompt,
"layers": args.layers,
"resolution": args.resolution,
}

outputs = omni.generate(**generate_kwargs)
generation_end = time.perf_counter()
generation_time = generation_end - generation_start

# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

if not outputs:
raise ValueError("No output generated from omni.generate()")
logger.info("Outputs: %s", outputs)

# Extract images from OmniRequestOutput
# omni.generate() returns list[OmniRequestOutput]; extract images from request_output[0]['images']
first_output = outputs[0]
if not hasattr(first_output, "request_output") or not first_output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")

req_out = first_output.request_output[0]
if not isinstance(req_out, dict) or "images" not in req_out:
raise ValueError("Invalid request_output structure or missing 'images' key")

images = req_out["images"]
if not images:
raise ValueError("No images found in request_output")

# Save output image(s)
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".png"
stem = output_path.stem or "output_image_edit"

# Handle layered output (each image may be a list of layers)
if args.num_outputs_per_prompt <= 1:
img = images[0]
# Check if this is a layered output (list of images)
if isinstance(img, list):
for sub_idx, sub_img in enumerate(img):
save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}"
sub_img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
else:
img.save(output_path)
print(f"Saved edited image to {os.path.abspath(output_path)}")
else:
for idx, img in enumerate(images):
# Check if this is a layered output (list of images)
if isinstance(img, list):
for sub_idx, sub_img in enumerate(img):
save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}"
sub_img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
else:
save_path = output_path.parent / f"{stem}_{idx}{suffix}"
img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
finally:
omni.close()


if __name__ == "__main__":
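Note: the edited example now unpacks OmniRequestOutput by hand before saving images. A minimal sketch of that extraction as a reusable helper is shown below; it assumes the same `outputs[0].request_output[0]["images"]` layout the diff relies on, and `extract_images` is a hypothetical name, not part of the vLLM-Omni API.

```python
# Hypothetical helper mirroring the checks added in image_edit.py; not part of vLLM-Omni.
def extract_images(outputs) -> list:
    """Return the images list from the first OmniRequestOutput.

    Entries may be single PIL images or lists of layer images.
    """
    if not outputs:
        raise ValueError("No output generated from omni.generate()")
    first_output = outputs[0]
    if not getattr(first_output, "request_output", None):
        raise ValueError("No request_output found in OmniRequestOutput")
    req_out = first_output.request_output[0]
    if not isinstance(req_out, dict) or "images" not in req_out:
        raise ValueError("Invalid request_output structure or missing 'images' key")
    images = req_out["images"]
    if not images:
        raise ValueError("No images found in request_output")
    return images
```

The same extraction pattern reappears below in text_to_image.py, test_cache_dit.py, and test_sequence_parallel.py.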
1 change: 1 addition & 0 deletions examples/offline_inference/qwen3_omni/end2end.py
@@ -273,6 +273,7 @@ def main(args):
# Save audio file with explicit WAV format
sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV")
print(f"Request ID: {request_id}, Saved audio to {output_wav}")
omni_llm.close()


def parse_args():
29 changes: 25 additions & 4 deletions examples/offline_inference/text_to_image/text_to_image.py
@@ -7,7 +7,7 @@

import torch

from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.data import DiffusionParallelConfig, logger
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type, is_npu

@@ -20,7 +20,7 @@ def parse_args() -> argparse.Namespace:
help="Diffusion model name or local path. Supported models: Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo",
)
parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.")
parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.")
parser.add_argument("--seed", type=int, default=142, help="Random seed for deterministic results.")
parser.add_argument(
"--cfg_scale",
type=float,
@@ -127,7 +127,7 @@ def main():
print(f"{'=' * 60}\n")

generation_start = time.perf_counter()
images = omni.generate(
outputs = omni.generate(
args.prompt,
height=args.height,
width=args.width,
@@ -142,11 +142,30 @@
# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

# Extract images from OmniRequestOutput
# omni.generate() returns list[OmniRequestOutput]; extract images from the first output
if not outputs:
raise ValueError("No output generated from omni.generate()")
logger.info("Outputs: %s", outputs)

# Extract images from request_output[0]['images']
first_output = outputs[0]
if not hasattr(first_output, "request_output") or not first_output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")

req_out = first_output.request_output[0]
if not isinstance(req_out, dict) or "images" not in req_out:
raise ValueError("Invalid request_output structure or missing 'images' key")

images = req_out["images"]
if not images:
raise ValueError("No images found in request_output")

output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".png"
stem = output_path.stem or "qwen_image_output"
if args.num_images_per_prompt <= 1:
if len(images) <= 1:
images[0].save(output_path)
print(f"Saved generated image to {output_path}")
else:
@@ -155,6 +174,8 @@ def main():
img.save(save_path)
print(f"Saved generated image to {save_path}")

omni.close()


if __name__ == "__main__":
main()
26 changes: 20 additions & 6 deletions examples/online_serving/text_to_image/gradio_demo.py
@@ -24,6 +24,7 @@ def generate_image(
seed: int | None,
negative_prompt: str,
server_url: str,
num_outputs_per_prompt: int = 1,
) -> Image.Image | None:
"""Generate an image using the chat completions API."""
messages = [{"role": "user", "content": prompt}]
@@ -39,6 +40,8 @@
extra_body["seed"] = seed
if negative_prompt:
extra_body["negative_prompt"] = negative_prompt
# Keep consistent with run_curl_text_to_image.sh: always send num_outputs_per_prompt
extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt

# Build request payload
payload = {"messages": messages, "extra_body": extra_body}
@@ -109,7 +112,8 @@ def create_demo(server_url: str):
label="Inference Steps",
minimum=10,
maximum=100,
value=50,
# Default steps aligned with run_curl_text_to_image.sh (100)
value=100,
step=5,
)
cfg_scale = gr.Slider(
@@ -138,16 +142,26 @@
# Examples
gr.Examples(
examples=[
["A beautiful landscape painting with misty mountains", "", 1024, 1024, 50, 4.0, 42],
["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 50, 4.0, 123],
["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 50, 4.0, 456],
["Chinese ink painting of bamboo forest with a house", "", 768, 1024, 50, 4.0, 789],
["A beautiful landscape painting with misty mountains", "", 1024, 1024, 100, 4.0, 42],
["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 100, 4.0, 123],
["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 100, 4.0, 456],
["Chinese ink painting of bamboo forest with a house", "", 768, 1024, 100, 4.0, 789],
],
inputs=[prompt, negative_prompt, height, width, steps, cfg_scale, seed],
)

generate_btn.click(
fn=lambda p, h, w, st, c, se, n: generate_image(p, h, w, st, c, se if se >= 0 else None, n, server_url),
fn=lambda p, h, w, st, c, se, n: generate_image(
p,
h,
w,
st,
c,
se if se >= 0 else None,
n,
server_url,
1,
),
inputs=[prompt, height, width, steps, cfg_scale, seed, negative_prompt],
outputs=[output_image],
)
@@ -57,8 +57,7 @@ def generate_image(
extra_body["seed"] = seed
if negative_prompt:
extra_body["negative_prompt"] = negative_prompt
if num_outputs_per_prompt > 1:
extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt
extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt

# Build request payload
payload = {"messages": messages}
@@ -2,8 +2,9 @@
# Qwen-Image text-to-image curl example

SERVER="${SERVER:-http://localhost:8091}"
PROMPT="${PROMPT:-a cup of coffee on the table}"
OUTPUT="${OUTPUT:-qwen_image_output.png}"
PROMPT="${PROMPT:-a good boy in the ocean}"
CURRENT_TIME=$(date +%Y%m%d%H%M%S)
OUTPUT="${OUTPUT:-qwen_image_output_${CURRENT_TIME}.png}"

echo "Generating image..."
echo "Prompt: $PROMPT"
@@ -18,12 +19,12 @@ curl -s "$SERVER/v1/chat/completions" \
\"extra_body\": {
\"height\": 1024,
\"width\": 1024,
\"num_inference_steps\": 50,
\"num_inference_steps\": 100,
\"true_cfg_scale\": 4.0,
\"seed\": 42,
\"num_outputs_per_prompt\": 1
}
}" | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2 | base64 -d > "$OUTPUT"
}" | jq -r '.choices[0].message.content[0].image_url.url' | sed 's/^data:image[^,]*,\s*//' | base64 -d > "$OUTPUT"

if [ -f "$OUTPUT" ]; then
echo "Image saved to: $OUTPUT"
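For reference, a Python equivalent of the curl + jq + sed + base64 pipeline above, as a minimal sketch: the endpoint, the extra_body fields, and the response path (`choices[0].message.content[0].image_url.url`) are taken from the script; everything else (the `requests` usage, the output file name) is illustrative.

```python
# Sketch of the shell pipeline in Python; assumes the same server and fields as run_curl_text_to_image.sh.
import base64

import requests

payload = {
    "messages": [{"role": "user", "content": "a good boy in the ocean"}],
    "extra_body": {
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 100,
        "true_cfg_scale": 4.0,
        "seed": 42,
        "num_outputs_per_prompt": 1,
    },
}
resp = requests.post("http://localhost:8091/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()

data_url = resp.json()["choices"][0]["message"]["content"][0]["image_url"]["url"]
_, b64_payload = data_url.split(",", 1)  # same effect as the sed prefix strip
with open("qwen_image_output.png", "wb") as f:
    f.write(base64.b64decode(b64_payload))
```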
6 changes: 3 additions & 3 deletions tests/e2e/offline_inference/conftest.py
@@ -69,7 +69,7 @@ def get_default_sampling_params_list(self) -> list[SamplingParams]:
Returns:
List of SamplingParams with default decoding for each stage
"""
return [st.default_sampling_params for st in self.omni.instance.stage_list]
return [st.default_sampling_params for st in self.omni.stage_list]

def get_omni_inputs(
self,
@@ -337,8 +337,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def close(self):
"""Close and cleanup the Omni instance."""
if hasattr(self.omni.instance, "close"):
self.omni.instance.close()
if hasattr(self.omni, "close"):
self.omni.close()


@pytest.fixture(scope="session")
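The conftest wrapper now reaches `stage_list` and `close()` on the Omni object directly rather than through `.instance`. A hedged sketch of how a fixture could guarantee cleanup against that surface follows; the `omni_session` wrapper and its arguments are illustrative, and only `Omni` and `close()` come from this diff.

```python
# Illustrative cleanup wrapper; assumes Omni exposes close() directly, as this diff shows.
from contextlib import contextmanager

from vllm_omni.entrypoints.omni import Omni


@contextmanager
def omni_session(*args, **kwargs):
    omni = Omni(*args, **kwargs)
    try:
        yield omni
    finally:
        if hasattr(omni, "close"):
            omni.close()
```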
15 changes: 14 additions & 1 deletion tests/e2e/offline_inference/test_cache_dit.py
@@ -51,7 +51,7 @@ def test_cache_dit(model_name: str):
width = 256
num_inference_steps = 4 # Minimal steps for fast test

images = m.generate(
outputs = m.generate(
"a photo of a cat sitting on a laptop keyboard",
height=height,
width=width,
@@ -60,10 +60,23 @@
generator=torch.Generator("cuda").manual_seed(42),
num_outputs_per_prompt=1, # Single output for speed
)
# Extract images from request_output[0]['images']
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")

req_out = first_output.request_output[0]
if not isinstance(req_out, dict) or "images" not in req_out:
raise ValueError("Invalid request_output structure or missing 'images' key")

images = req_out["images"]

# Verify generation succeeded
assert images is not None
assert len(images) == 1
# Check image size
assert images[0].width == width
assert images[0].height == height
# manually close the Omni instance
m.close()
6 changes: 4 additions & 2 deletions tests/e2e/offline_inference/test_sequence_parallel.py
@@ -78,7 +78,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in
dtype=dtype,
)
try:
baseline_images = baseline.generate(
outputs = baseline.generate(
PROMPT,
height=height,
width=width,
@@ -87,6 +87,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in
generator=torch.Generator(get_device_name()).manual_seed(seed),
num_outputs_per_prompt=1,
)
baseline_images = outputs[0].request_output[0]["images"]
finally:
baseline.close()

Expand All @@ -103,7 +104,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in
dtype=dtype,
)
try:
sp_images = sp.generate(
outputs = sp.generate(
PROMPT,
height=height,
width=width,
@@ -112,6 +113,7 @@ def test_sequence_parallel(model_name: str, ulysses_degree: int, ring_degree: in
generator=torch.Generator(get_device_name()).manual_seed(seed),
num_outputs_per_prompt=1,
)
sp_images = outputs[0].request_output[0]["images"]
finally:
sp.close()

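Both runs now extract their images from `request_output[0]["images"]` before being compared. The comparison itself lies outside this diff, so the sketch below is illustrative only: a mean-absolute-difference check between the baseline and sequence-parallel outputs, with an arbitrary tolerance.

```python
# Illustrative only; the real assertion in test_sequence_parallel.py is not shown in this diff.
import numpy as np

baseline_arr = np.asarray(baseline_images[0], dtype=np.float32)
sp_arr = np.asarray(sp_images[0], dtype=np.float32)
mae = float(np.abs(baseline_arr - sp_arr).mean())
assert mae < 1.0, f"sequence-parallel output diverged from baseline (MAE={mae:.3f})"
```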