docs/api/README.md (4 changes: 2 additions & 2 deletions)

@@ -4,9 +4,9 @@

 Main entry points for vLLM-Omni inference and serving.
 
-- [vllm_omni.entrypoints.async_diffusion.AsyncOmniDiffusion][]
+- [vllm_omni.entrypoints.async_omni_diffusion.AsyncOmniDiffusion][]
 - [vllm_omni.entrypoints.async_omni.AsyncOmni][]
-- [vllm_omni.entrypoints.async_omni.AsyncOmniStageLLM][]
+- [vllm_omni.entrypoints.async_omni_llm.AsyncOmniLLM][]
 - [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalContentParser][]
 - [vllm_omni.entrypoints.chat_utils.OmniAsyncMultiModalItemTracker][]
 - [vllm_omni.entrypoints.chat_utils.parse_chat_messages_futures][]
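Both entries above are renames: the first moves AsyncOmniDiffusion from async_diffusion to async_omni_diffusion, and the second also renames the class itself (AsyncOmniStageLLM becomes AsyncOmniLLM in async_omni_llm). A minimal sketch of the updated imports; only the module and class names come from this diff, anything around them is illustrative:

```python
# Import paths after this PR; usage beyond the imports is hypothetical.
from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion
from vllm_omni.entrypoints.async_omni_llm import AsyncOmniLLM
```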
examples/offline_inference/image_to_image/image_edit.py (123 changes: 83 additions & 40 deletions)

@@ -61,7 +61,7 @@
 import torch
 from PIL import Image
 
-from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.diffusion.data import DiffusionParallelConfig, logger
 from vllm_omni.entrypoints.omni import Omni
 from vllm_omni.utils.platform_utils import detect_device_type, is_npu

@@ -317,45 +317,88 @@ def main():
print(f" Parallel configuration: ulysses_degree={args.ulysses_degree}")
print(f"{'=' * 60}\n")

generation_start = time.perf_counter()
# Generate edited image
images = omni.generate(
prompt=args.prompt,
pil_image=input_image,
negative_prompt=args.negative_prompt,
generator=generator,
true_cfg_scale=args.cfg_scale,
guidance_scale=args.guidance_scale,
num_inference_steps=args.num_inference_steps,
num_outputs_per_prompt=args.num_outputs_per_prompt,
layers=args.layers,
)
generation_end = time.perf_counter()
generation_time = generation_end - generation_start

# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

# Save output image(s)
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".png"
stem = output_path.stem or "output_image_edit"

if args.num_outputs_per_prompt <= 1:
img = images[0]
img = img if isinstance(img, list) else [img]
for sub_idx, sub_img in enumerate(img):
save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}"
sub_img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
else:
for idx, img in enumerate(images):
img = img if isinstance(img, list) else [img]
for sub_idx, sub_img in enumerate(img):
save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}"
sub_img.save(save_path)
print(f"Saved edited image to {os.path.abspath(save_path)}")
+    try:
+        generation_start = time.perf_counter()
+        # Generate edited image
+        generate_kwargs = {
+            "prompt": args.prompt,
+            "pil_image": input_image,
+            "negative_prompt": args.negative_prompt,
+            "generator": generator,
+            "true_cfg_scale": args.cfg_scale,
+            "guidance_scale": args.guidance_scale,
+            "num_inference_steps": args.num_inference_steps,
+            "num_outputs_per_prompt": args.num_outputs_per_prompt,
+            "layers": args.layers,
+            "resolution": args.resolution,
+        }
+
+        outputs = omni.generate(**generate_kwargs)
+        generation_end = time.perf_counter()
+        generation_time = generation_end - generation_start
+
+        # Print profiling results
+        print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")
+
+        if not outputs:
+            raise ValueError("No output generated from omni.generate()")
+        logger.info("Outputs: %s", outputs)
+
+        # Extract images from OmniRequestOutput
+        # Handle both OmniRequestOutput list and direct images list
+        images = []
+        if isinstance(outputs, list) and len(outputs) > 0:
+            first_output = outputs[0]
+            # Check if it's OmniRequestOutput with images attribute
+            if hasattr(first_output, "images") and first_output.images:
+                images = first_output.images
+            elif hasattr(first_output, "request_output") and first_output.request_output:
+                req_out = first_output.request_output
+                if isinstance(req_out, list):
+                    req_out = req_out[0]
+                if hasattr(req_out, "images"):
+                    images = req_out.images or []
+            # Check if outputs is already a list of images
+            elif isinstance(first_output, Image.Image):
+                images = outputs
+        elif isinstance(outputs, Image.Image):
+            images = [outputs]
+
+        if not images:
+            raise ValueError("No images found in omni.generate() output")
+
+        # Save output image(s)
+        output_path = Path(args.output)
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        suffix = output_path.suffix or ".png"
+        stem = output_path.stem or "output_image_edit"
+
+        # Handle layered output (each image may be a list of layers)
+        if args.num_outputs_per_prompt <= 1:
+            img = images[0]
+            # Check if this is a layered output (list of images)
+            if isinstance(img, list):
+                for sub_idx, sub_img in enumerate(img):
+                    save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}"
+                    sub_img.save(save_path)
+                    print(f"Saved edited image to {os.path.abspath(save_path)}")
+            else:
+                img.save(output_path)
+                print(f"Saved edited image to {os.path.abspath(output_path)}")
+        else:
+            for idx, img in enumerate(images):
+                # Check if this is a layered output (list of images)
+                if isinstance(img, list):
+                    for sub_idx, sub_img in enumerate(img):
+                        save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}"
+                        sub_img.save(save_path)
+                        print(f"Saved edited image to {os.path.abspath(save_path)}")
+                else:
+                    save_path = output_path.parent / f"{stem}_{idx}{suffix}"
+                    img.save(save_path)
+                    print(f"Saved edited image to {os.path.abspath(save_path)}")
+    finally:
+        omni.close()


 if __name__ == "__main__":
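The extraction block above duck-types three possible return shapes from omni.generate(): a list of OmniRequestOutput objects (carrying either an images attribute or a nested request_output), a list of PIL images, or a single image. The same logic factored into a standalone helper, as a sketch; the attribute names images and request_output are taken from the diff, the helper itself is hypothetical:

```python
from PIL import Image


def extract_images(outputs) -> list:
    """Best-effort unwrapping of omni.generate() results into a flat image list."""
    if isinstance(outputs, Image.Image):
        return [outputs]
    if isinstance(outputs, list) and outputs:
        first = outputs[0]
        if isinstance(first, Image.Image):
            return list(outputs)
        # OmniRequestOutput with a direct images attribute
        if getattr(first, "images", None):
            return first.images
        # OmniRequestOutput wrapping one or more request_output entries
        req_out = getattr(first, "request_output", None)
        if req_out:
            if isinstance(req_out, list):
                req_out = req_out[0]
            return getattr(req_out, "images", None) or []
    return []
```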
examples/offline_inference/text_to_image/text_to_image.py (20 changes: 15 additions & 5 deletions)

@@ -7,7 +7,7 @@

 import torch
 
-from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.diffusion.data import DiffusionParallelConfig, logger
 from vllm_omni.entrypoints.omni import Omni
 from vllm_omni.utils.platform_utils import detect_device_type, is_npu

@@ -20,7 +20,7 @@ def parse_args() -> argparse.Namespace:
help="Diffusion model name or local path. Supported models: Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo",
)
parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.")
parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.")
parser.add_argument("--seed", type=int, default=142, help="Random seed for deterministic results.")
parser.add_argument(
"--cfg_scale",
type=float,
@@ -32,7 +32,7 @@
     parser.add_argument(
         "--output",
         type=str,
-        default="qwen_image_output.png",
+        default="qwen_image_output1.png",
         help="Path to save the generated image (PNG).",
     )
     parser.add_argument(
@@ -127,7 +127,7 @@ def main():
print(f"{'=' * 60}\n")

generation_start = time.perf_counter()
images = omni.generate(
outputs = omni.generate(
args.prompt,
height=args.height,
width=args.width,
@@ -142,11 +142,19 @@
# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

# Extract images from OmniRequestOutput
# omni.generate() returns list[OmniRequestOutput], extract images from the first output
if not outputs or len(outputs) == 0:
raise ValueError("No output generated from omni.generate()")
logger.info(f"Outputs: {outputs}")

images = outputs[0].request_output[0].images

output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".png"
stem = output_path.stem or "qwen_image_output"
if args.num_images_per_prompt <= 1:
if len(images) <= 1:
images[0].save(output_path)
print(f"Saved generated image to {output_path}")
else:
@@ -155,6 +163,8 @@
             img.save(save_path)
             print(f"Saved generated image to {save_path}")
 
+    omni.close()
+

 if __name__ == "__main__":
     main()
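Note that unlike image_edit.py above, this script calls omni.close() only on the success path, so an exception during generation skips cleanup. A condensed sketch of the exception-safe variant; the Omni constructor arguments are assumptions, not shown in this diff:

```python
from vllm_omni.entrypoints.omni import Omni

omni = Omni(model="Qwen/Qwen-Image")  # constructor signature assumed
try:
    outputs = omni.generate("a cup of coffee on the table", num_inference_steps=50)
    # Chained indexing raises IndexError/AttributeError if any stage returned nothing
    images = outputs[0].request_output[0].images
    images[0].save("qwen_image_output.png")
finally:
    omni.close()  # runs even when generation fails
```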
examples/online_serving/text_to_image/gradio_demo.py (26 changes: 20 additions & 6 deletions)

@@ -24,6 +24,7 @@ def generate_image(
     seed: int | None,
     negative_prompt: str,
     server_url: str,
+    num_outputs_per_prompt: int = 1,
 ) -> Image.Image | None:
     """Generate an image using the chat completions API."""
     messages = [{"role": "user", "content": prompt}]
@@ -39,6 +40,8 @@
extra_body["seed"] = seed
if negative_prompt:
extra_body["negative_prompt"] = negative_prompt
# 与 run_curl_text_to_image.sh 保持一致,始终发送 num_outputs_per_prompt
extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt

     # Build request payload
     payload = {"messages": messages, "extra_body": extra_body}
@@ -109,7 +112,8 @@ def create_demo(server_url: str):
label="Inference Steps",
minimum=10,
maximum=100,
value=50,
# 默认步数与 run_curl_text_to_image.sh 对齐为 100
value=100,
step=5,
)
cfg_scale = gr.Slider(
@@ -138,16 +142,26 @@
             # Examples
             gr.Examples(
                 examples=[
-                    ["A beautiful landscape painting with misty mountains", "", 1024, 1024, 50, 4.0, 42],
-                    ["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 50, 4.0, 123],
-                    ["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 50, 4.0, 456],
-                    ["Chinese ink painting of bamboo forest with a house", "", 768, 1024, 50, 4.0, 789],
+                    ["A beautiful landscape painting with misty mountains", "", 1024, 1024, 100, 4.0, 42],
+                    ["A cute cat sitting on a windowsill with sunlight", "", 1024, 1024, 100, 4.0, 123],
+                    ["Cyberpunk style futuristic city with neon lights", "blurry, low quality", 1024, 768, 100, 4.0, 456],
+                    ["Chinese ink painting of bamboo forest with a house", "", 768, 1024, 100, 4.0, 789],
                 ],
                 inputs=[prompt, negative_prompt, height, width, steps, cfg_scale, seed],
             )
 
         generate_btn.click(
-            fn=lambda p, h, w, st, c, se, n: generate_image(p, h, w, st, c, se if se >= 0 else None, n, server_url),
+            fn=lambda p, h, w, st, c, se, n: generate_image(
+                p,
+                h,
+                w,
+                st,
+                c,
+                se if se >= 0 else None,
+                n,
+                server_url,
+                1,
+            ),
             inputs=[prompt, height, width, steps, cfg_scale, seed, negative_prompt],
             outputs=[output_image],
         )
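With these changes the demo always sends num_outputs_per_prompt, so its request body matches the curl script's shape. A sketch of the resulting payload; the field names come from this diff, the values are examples:

```python
# Payload shape the demo now builds. It is POSTed to
# f"{server_url}/v1/chat/completions", per the curl script below.
payload = {
    "messages": [{"role": "user", "content": "a cup of coffee on the table"}],
    "extra_body": {
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 100,
        "true_cfg_scale": 4.0,
        "seed": 42,
        "num_outputs_per_prompt": 1,  # now always present
    },
}
```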
(file name not captured)

@@ -57,8 +57,7 @@ def generate_image(
extra_body["seed"] = seed
if negative_prompt:
extra_body["negative_prompt"] = negative_prompt
if num_outputs_per_prompt > 1:
extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt
extra_body["num_outputs_per_prompt"] = num_outputs_per_prompt

# Build request payload
payload = {"messages": messages}
run_curl_text_to_image.sh

@@ -2,8 +2,9 @@
 # Qwen-Image text-to-image curl example
 
 SERVER="${SERVER:-http://localhost:8091}"
-PROMPT="${PROMPT:-a cup of coffee on the table}"
-OUTPUT="${OUTPUT:-qwen_image_output.png}"
+PROMPT="${PROMPT:-a good boy in the ocean}"
+CURRENT_TIME=$(date +%Y%m%d%H%M%S)
+OUTPUT="${OUTPUT:-qwen_image_output_${CURRENT_TIME}.png}"
 
 echo "Generating image..."
 echo "Prompt: $PROMPT"
@@ -18,12 +19,12 @@ curl -s "$SERVER/v1/chat/completions" \
\"extra_body\": {
\"height\": 1024,
\"width\": 1024,
\"num_inference_steps\": 50,
\"num_inference_steps\": 100,
\"true_cfg_scale\": 4.0,
\"seed\": 42,
\"num_outputs_per_prompt\": 1
}
}" | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2 | base64 -d > "$OUTPUT"
}" | jq -r '.choices[0].message.content[0].image_url.url' | sed 's/^data:image[^,]*,\s*//' | base64 -d > "$OUTPUT"

 if [ -f "$OUTPUT" ]; then
     echo "Image saved to: $OUTPUT"
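The sed change is the robustness fix here: cut -d',' -f2 works only when the URL has exactly the expected single-comma data-URL shape, while the sed expression strips just a leading data:image...;base64, prefix, tolerates whitespace after the comma, and leaves bare base64 strings untouched (note that \s in sed is a GNU extension). The same idea in Python, as a sketch; the helper name is hypothetical:

```python
import base64
import re


def save_data_url(data_url: str, path: str) -> None:
    """Strip an optional data:image/...;base64, prefix and decode the rest."""
    payload = re.sub(r"^data:image/[^,]*,\s*", "", data_url)
    with open(path, "wb") as f:
        f.write(base64.b64decode(payload))
```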
vllm_omni/config/model.py (28 changes: 24 additions & 4 deletions)

@@ -87,7 +87,17 @@ def draw_hf_text_config(self):
         # we need to draw the text config from the corresponding model stage.
         if self.hf_config_name is None:
             return get_hf_text_config(self.hf_config)
-        return getattr(self.hf_config, self.hf_config_name).get_text_config()
+        try:
+            # Try to get the stage-specific config (e.g., thinker_config, talker_config)
+            stage_config = getattr(self.hf_config, self.hf_config_name)
+            return stage_config.get_text_config()
+        except AttributeError:
+            # Fallback: if the attribute doesn't exist, use the default get_hf_text_config
+            logger.warning(
+                f"Config attribute '{self.hf_config_name}' not found in hf_config, "
+                "falling back to default get_hf_text_config"
+            )
+            return get_hf_text_config(self.hf_config)

     def __post_init__(
         self,
@@ -173,9 +183,19 @@ def __post_init__(
         self.hf_text_config = self.draw_hf_text_config()
         self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None)
         self.encoder_config = self._get_encoder_config()
-        self.hf_image_processor_config = get_hf_image_processor_config(
-            self.model, hf_token=self.hf_token, revision=self.revision
-        )
+        # Try to load image processor config, but allow it to fail for stages that don't need it
+        try:
+            self.hf_image_processor_config = get_hf_image_processor_config(
+                self.model, hf_token=self.hf_token, revision=self.revision
+            )
+        except (OSError, ValueError, IndexError) as e:
+            # Some stages (e.g., code2wav, talker) don't need image processor
+            # Log warning but allow initialization to continue
+            logger.warning(
+                f"Failed to load image processor config for model '{self.model}': {e}. "
+                "This is expected for stages that don't require image processing."
+            )
+            self.hf_image_processor_config = None

         architectures = self.architectures
         registry = self.registry
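Both fallbacks added to this file follow the same pattern: try the stage-specific lookup, log a warning, and degrade to a usable default instead of raising. A toy sketch of the draw_hf_text_config behavior using a stand-in object; SimpleNamespace is purely illustrative, the real hf_config is a transformers config:

```python
from types import SimpleNamespace

hf_config = SimpleNamespace()  # deliberately lacks a thinker_config attribute

try:
    # Stage-specific path (e.g., hf_config_name = "thinker_config")
    text_config = hf_config.thinker_config.get_text_config()
except AttributeError:
    # Fallback path this PR adds: warn and use the top-level config instead
    text_config = hf_config

print(text_config)  # SimpleNamespace() from the fallback branch
```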