diff --git a/examples/online_serving/text_to_video/README.md b/examples/online_serving/text_to_video/README.md new file mode 100644 index 000000000..f22c445de --- /dev/null +++ b/examples/online_serving/text_to_video/README.md @@ -0,0 +1,145 @@ +# Text-To-Video + +This example demonstrates how to deploy Wan2.2 text-to-video model for online generation using vLLM-Omni. + +## Start Server + +### Basic Start + +```bash +vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8093 --boundary-ratio 0.875 --flow-shift 5.0 +``` + +Notes: +- `flow-shift`: 5.0 for 720p, 12.0 for 480p (Wan2.2 recommendation). +- `boundary-ratio`: 0.875 for Wan2.2 low/high DiT split. + +### Start with Parameters + +Or use the startup script: + +```bash +bash run_server.sh +``` + +## API Calls + +### Method 1: Using curl + +```bash +# Basic text-to-video generation +bash run_curl_text_to_video.sh + +# Or execute directly +curl -s http://localhost:8093/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + {"role": "user", "content": "A cinematic shot of a flying kite over the ocean."} + ], + "extra_body": { + "height": 720, + "width": 1280, + "num_frames": 81, + "fps": 24, + "num_inference_steps": 40, + "guidance_scale": 4.0, + "guidance_scale_2": 4.0, + "seed": 42 + } + }' | jq -r '.choices[0].message.content[0].video_url.url' | cut -d',' -f2 | base64 -d > output.mp4 +``` + +### Method 2: Using Python Client + +```bash +python openai_chat_client.py --prompt "A cinematic shot of a flying kite over the ocean." --output output.mp4 +``` + +## Request Format + +### Simple Text Generation + +```json +{ + "messages": [ + {"role": "user", "content": "A cinematic shot of a flying kite over the ocean."} + ] +} +``` + +### Generation with Parameters + +Use `extra_body` to pass generation parameters: + +```json +{ + "messages": [ + {"role": "user", "content": "A cinematic shot of a flying kite over the ocean."} + ], + "extra_body": { + "height": 720, + "width": 1280, + "num_frames": 81, + "fps": 24, + "num_inference_steps": 40, + "guidance_scale": 4.0, + "guidance_scale_2": 4.0, + "seed": 42, + "negative_prompt": "" + } +} +``` + +## Generation Parameters (extra_body) + +| Parameter | Type | Default | Description | +| ------------------------ | ----- | ------- | ------------------------------------- | +| `height` | int | None | Video height in pixels | +| `width` | int | None | Video width in pixels | +| `num_frames` | int | None | Number of frames | +| `fps` | int | 24 | Frames per second for exported MP4 | +| `num_inference_steps` | int | 40 | Number of denoising steps | +| `guidance_scale` | float | 4.0 | CFG guidance scale (low noise) | +| `guidance_scale_2` | float | 4.0 | CFG guidance scale (high noise) | +| `seed` | int | None | Random seed (reproducible) | +| `negative_prompt` | str | None | Negative prompt | +| `num_outputs_per_prompt` | int | 1 | Number of videos to generate | + +## Response Format + +```json +{ + "id": "chatcmpl-xxx", + "created": 1234567890, + "model": "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": [{ + "type": "video_url", + "video_url": { + "url": "data:video/mp4;base64,..." 
+ } + }] + }, + "finish_reason": "stop" + }], + "usage": {...} +} +``` + +## Extract Video + +```bash +cat response.json | jq -r '.choices[0].message.content[0].video_url.url' | cut -d',' -f2 | base64 -d > output.mp4 +``` + +## File Description + +| File | Description | +| --------------------------- | ---------------------- | +| `run_server.sh` | Server startup script | +| `run_curl_text_to_video.sh` | curl example | +| `openai_chat_client.py` | Python client | diff --git a/examples/online_serving/text_to_video/openai_chat_client.py b/examples/online_serving/text_to_video/openai_chat_client.py new file mode 100644 index 000000000..0f4b5248d --- /dev/null +++ b/examples/online_serving/text_to_video/openai_chat_client.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +Wan2.2 OpenAI-compatible chat client for text-to-video generation. + +Usage: + python openai_chat_client.py --prompt "A cinematic shot..." --output output.mp4 + python openai_chat_client.py --num-frames 81 --fps 24 --height 720 --width 1280 +""" + +import argparse +import base64 +from pathlib import Path + +import requests + + +def generate_video( + prompt: str, + server_url: str = "http://localhost:8093", + height: int | None = None, + width: int | None = None, + num_frames: int | None = None, + fps: int | None = None, + steps: int | None = None, + guidance_scale: float | None = None, + guidance_scale_2: float | None = None, + seed: int | None = None, + negative_prompt: str | None = None, +) -> bytes | None: + """Generate a video using the chat completions API.""" + messages = [{"role": "user", "content": prompt}] + + extra_body = {} + if height is not None: + extra_body["height"] = height + if width is not None: + extra_body["width"] = width + if num_frames is not None: + extra_body["num_frames"] = num_frames + if fps is not None: + extra_body["fps"] = fps + if steps is not None: + extra_body["num_inference_steps"] = steps + if guidance_scale is not None: + extra_body["guidance_scale"] = guidance_scale + if guidance_scale_2 is not None: + extra_body["guidance_scale_2"] = guidance_scale_2 + if seed is not None: + extra_body["seed"] = seed + if negative_prompt: + extra_body["negative_prompt"] = negative_prompt + + payload = {"messages": messages} + if extra_body: + payload["extra_body"] = extra_body + + try: + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=600, + ) + response.raise_for_status() + data = response.json() + + content = data["choices"][0]["message"]["content"] + if isinstance(content, list): + for item in content: + video_url = item.get("video_url", {}).get("url", "") + if video_url.startswith("data:video"): + _, b64_data = video_url.split(",", 1) + return base64.b64decode(b64_data) + + print(f"Unexpected response format: {content}") + return None + + except Exception as e: + print(f"Error: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser(description="Wan2.2 chat client") + parser.add_argument("--prompt", "-p", default="A cinematic shot of a flying kite over the ocean.") + parser.add_argument("--output", "-o", default="wan22_output.mp4", help="Output file") + parser.add_argument("--server", "-s", default="http://localhost:8093", help="Server URL") + parser.add_argument("--height", type=int, default=720, help="Video height") + parser.add_argument("--width", type=int, default=1280, help="Video width") + parser.add_argument("--num-frames", type=int, default=81, help="Number of frames") + 
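+    # The remaining defaults mirror the README example: 24 fps, 40 denoising steps, CFG 4.0, seed 42.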
parser.add_argument("--fps", type=int, default=24, help="Frames per second") + parser.add_argument("--steps", type=int, default=40, help="Inference steps") + parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG scale (low noise)") + parser.add_argument("--cfg-scale-high", type=float, default=None, help="CFG scale (high noise)") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--negative", default="", help="Negative prompt") + + args = parser.parse_args() + + print(f"Generating video for: {args.prompt}") + + video_bytes = generate_video( + prompt=args.prompt, + server_url=args.server, + height=args.height, + width=args.width, + num_frames=args.num_frames, + fps=args.fps, + steps=args.steps, + guidance_scale=args.cfg_scale, + guidance_scale_2=args.cfg_scale_high, + seed=args.seed, + negative_prompt=args.negative, + ) + + if video_bytes: + output_path = Path(args.output) + output_path.write_bytes(video_bytes) + print(f"Video saved to: {output_path}") + print(f"Size: {len(video_bytes) / 1024 / 1024:.2f} MB") + else: + print("Failed to generate video") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/text_to_video/run_curl_text_to_video.sh b/examples/online_serving/text_to_video/run_curl_text_to_video.sh new file mode 100644 index 000000000..9ce8ee640 --- /dev/null +++ b/examples/online_serving/text_to_video/run_curl_text_to_video.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Wan2.2 text-to-video curl example + +SERVER="${SERVER:-http://localhost:8093}" +PROMPT="${PROMPT:-A cinematic shot of a flying kite over the ocean.}" +OUTPUT="${OUTPUT:-wan22_output.mp4}" + +HEIGHT="${HEIGHT:-720}" +WIDTH="${WIDTH:-1280}" +NUM_FRAMES="${NUM_FRAMES:-81}" +FPS="${FPS:-24}" +STEPS="${STEPS:-40}" +GUIDANCE_SCALE="${GUIDANCE_SCALE:-4.0}" +GUIDANCE_SCALE_2="${GUIDANCE_SCALE_2:-4.0}" +SEED="${SEED:-42}" + +echo "Generating video..." 
+echo "Prompt: $PROMPT" +echo "Output: $OUTPUT" + +curl -s "$SERVER/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"messages\": [ + {\"role\": \"user\", \"content\": \"$PROMPT\"} + ], + \"extra_body\": { + \"height\": $HEIGHT, + \"width\": $WIDTH, + \"num_frames\": $NUM_FRAMES, + \"fps\": $FPS, + \"num_inference_steps\": $STEPS, + \"guidance_scale\": $GUIDANCE_SCALE, + \"guidance_scale_2\": $GUIDANCE_SCALE_2, + \"seed\": $SEED, + \"num_outputs_per_prompt\": 1 + } + }" | jq -r '.choices[0].message.content[0].video_url.url' | cut -d',' -f2 | base64 -d > "$OUTPUT" + +if [ -f "$OUTPUT" ]; then + echo "Video saved to: $OUTPUT" + echo "Size: $(du -h "$OUTPUT" | cut -f1)" +else + echo "Failed to generate video" + exit 1 +fi diff --git a/examples/online_serving/text_to_video/run_server.sh b/examples/online_serving/text_to_video/run_server.sh new file mode 100644 index 000000000..a51a91fbe --- /dev/null +++ b/examples/online_serving/text_to_video/run_server.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Wan2.2 text-to-video server + +MODEL="${MODEL:-Wan-AI/Wan2.2-T2V-A14B-Diffusers}" +PORT="${PORT:-8093}" +BOUNDARY_RATIO="${BOUNDARY_RATIO:-0.875}" +FLOW_SHIFT="${FLOW_SHIFT:-5.0}" + +echo "Starting server for: $MODEL" +echo "Port: $PORT" +echo "boundary_ratio: $BOUNDARY_RATIO" +echo "flow_shift: $FLOW_SHIFT" + +vllm serve "$MODEL" --omni --port "$PORT" --boundary-ratio "$BOUNDARY_RATIO" --flow-shift "$FLOW_SHIFT" diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 14f3bc0f6..9d0259dae 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -15,6 +15,8 @@ from dataclasses import fields from typing import Any +import numpy as np +import torch from PIL import Image from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict @@ -32,7 +34,7 @@ class AsyncOmniDiffusion: This class provides an asynchronous interface for running diffusion models, enabling concurrent request handling. It wraps the DiffusionEngine and - provides async methods for image generation. + provides async methods for image/video generation. Args: model: Model name or path to load @@ -127,10 +129,10 @@ async def generate( seed: int | None = None, **kwargs: Any, ) -> OmniRequestOutput: - """Generate images asynchronously from a text prompt. + """Generate images or videos asynchronously from a text prompt. 
Args: - prompt: Text prompt describing the desired image + prompt: Text prompt describing the desired output request_id: Optional unique identifier for tracking the request num_inference_steps: Number of denoising steps (default: 50) guidance_scale: Classifier-free guidance scale (default: 7.5) @@ -142,7 +144,7 @@ async def generate( **kwargs: Additional generation parameters Returns: - OmniRequestOutput containing generated images + OmniRequestOutput containing generated images or videos Raises: RuntimeError: If generation fails @@ -186,18 +188,12 @@ async def generate( return result # Process results if not OmniRequestOutput - images: list[Image.Image] = [] - if result is not None: - if isinstance(result, list): - for item in result: - if isinstance(item, Image.Image): - images.append(item) - elif isinstance(result, Image.Image): - images.append(result) + images, videos = self._extract_diffusion_outputs(result) return OmniRequestOutput.from_diffusion( request_id=request_id, images=images, + videos=videos, prompt=prompt, metrics={ "num_inference_steps": num_inference_steps, @@ -211,7 +207,7 @@ async def generate_stream( request_id: str | None = None, **kwargs: Any, ) -> AsyncGenerator[OmniRequestOutput, None]: - """Generate images with streaming progress updates. + """Generate diffusion outputs with streaming progress updates. Currently, diffusion models don't support true streaming, so this yields a single result after generation completes. Future implementations @@ -228,6 +224,62 @@ async def generate_stream( result = await self.generate(prompt=prompt, request_id=request_id, **kwargs) yield result + def _extract_diffusion_outputs(self, result: Any) -> tuple[list[Image.Image], list[Any]]: + """Split diffusion outputs into image or video outputs.""" + images: list[Image.Image] = [] + videos: list[Any] = [] + + if result is None: + return images, videos + + if isinstance(result, Image.Image): + return [result], videos + + if isinstance(result, list): + if result and all(isinstance(item, Image.Image) for item in result): + return list(result), videos + for item in result: + if isinstance(item, Image.Image): + images.append(item) + else: + videos.extend(self._normalize_video_items(item)) + return images, videos + + if isinstance(result, (np.ndarray, torch.Tensor)): + videos.extend(self._normalize_video_items(result)) + + return images, videos + + def _normalize_video_items(self, item: Any) -> list[Any]: + """Normalize possible video outputs into a list of video arrays/tensors.""" + if item is None: + return [] + + if isinstance(item, list): + videos: list[Any] = [] + for sub_item in item: + videos.extend(self._normalize_video_items(sub_item)) + return videos + + if isinstance(item, torch.Tensor): + item = item.detach().cpu() + + if isinstance(item, np.ndarray): + if item.ndim == 5: + return [item[i] for i in range(item.shape[0])] + if item.ndim in (4, 3): + return [item] + return [item] + + if isinstance(item, torch.Tensor): + if item.ndim == 5: + return [item[i] for i in range(item.shape[0])] + if item.ndim in (4, 3): + return [item] + return [item] + + return [] + def close(self) -> None: """Close the engine and release resources. 
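For context, here is a minimal sketch of how the new `videos` plumbing in `AsyncOmniDiffusion` would be exercised end to end. It assumes the class is importable from `vllm_omni.entrypoints.async_omni_diffusion`, that the constructor takes the model path as its `model` argument (per the class docstring), and that a video pipeline such as Wan2.2 returns frame tensors rather than PIL images; it is an illustration, not part of this change.

```python
import asyncio

from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion


async def main() -> None:
    diffusion = AsyncOmniDiffusion(model="Wan-AI/Wan2.2-T2V-A14B-Diffusers")
    try:
        result = await diffusion.generate(
            prompt="A cinematic shot of a flying kite over the ocean.",
            num_inference_steps=40,
            guidance_scale=4.0,
            seed=42,
            num_frames=81,  # video-only parameter, forwarded via **kwargs
        )
        # Video pipelines populate `videos`; image pipelines populate `images`.
        print(result.final_output_type, result.num_videos, result.num_images)
    finally:
        diffusion.close()


asyncio.run(main())
```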
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 9a8407d84..c4954279a 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -1017,14 +1017,14 @@ async def _create_diffusion_chat_completion( request: ChatCompletionRequest, raw_request: Request | None = None, ) -> ChatCompletionResponse | ErrorResponse: - """Generate images via chat completion interface for diffusion models. + """Generate images or videos via chat completion interface for diffusion models. Args: request: Chat completion request raw_request: Raw FastAPI request object Returns: - ChatCompletionResponse with generated images or ErrorResponse + ChatCompletionResponse with generated images/videos or ErrorResponse """ try: request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" @@ -1074,6 +1074,7 @@ async def _create_diffusion_chat_completion( # Text-to-video parameters (ref: text_to_video.py) num_frames = extra_body.get("num_frames") guidance_scale_2 = extra_body.get("guidance_scale_2") # For video high-noise CFG + fps = extra_body.get("fps") logger.info( "Diffusion chat request %s: prompt=%r, ref_images=%d, params=%s", @@ -1114,6 +1115,8 @@ async def _create_diffusion_chat_completion( gen_kwargs["num_frames"] = num_frames if guidance_scale_2 is not None: gen_kwargs["guidance_scale_2"] = guidance_scale_2 + if fps is not None: + gen_kwargs["fps"] = fps # Add reference image if provided if pil_images: @@ -1157,6 +1160,59 @@ async def _create_diffusion_chat_completion( if result.request_output["images"]: images = result.request_output["images"] + # Convert videos to base64 content if present + videos = getattr(result, "videos", []) + if videos: + fps_value = int(fps) if fps is not None else 24 + video_contents: list[dict[str, Any]] = [] + for video in videos: + frames = self._normalize_video_frames(video) + if not frames: + logger.warning("Video generation completed but no frames were returned.") + continue + video_url = self._encode_video_frames_to_data_url(frames, fps_value) + video_contents.append( + { + "type": "video_url", + "video_url": {"url": video_url}, + } + ) + + content: str | list[dict[str, Any]] + if not video_contents: + content = "Video generation completed but no videos were produced." 
+ else: + content = video_contents + + message = ChatMessage.model_construct(role="assistant", content=content) + choice = ChatCompletionResponseChoice.model_construct( + index=0, + message=message, + finish_reason="stop", + logprobs=None, + stop_reason=None, + ) + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=self._diffusion_model_name, + choices=[choice], + usage=UsageInfo( + prompt_tokens=len(prompt.split()), + completion_tokens=1, + total_tokens=len(prompt.split()) + 1, + ), + ) + + logger.info( + "Diffusion chat completed for request %s: %d videos", + request_id, + len(video_contents), + ) + + return response + # Convert images to base64 content image_contents: list[dict[str, Any]] = [] for img in images: @@ -1222,7 +1278,7 @@ async def _create_diffusion_chat_completion( except Exception as e: logger.exception("Diffusion chat completion failed: %s", e) return self._create_error_response( - f"Image generation failed: {str(e)}", + f"Diffusion generation failed: {str(e)}", status_code=500, ) @@ -1281,6 +1337,65 @@ def _extract_diffusion_prompt_and_images( prompt = " ".join(prompt_parts).strip() return prompt, images + def _normalize_video_frames(self, video: Any) -> list[Any]: + """Normalize video output into a list of frames for export_to_video.""" + import numpy as np + import torch + + video_array = None + if isinstance(video, torch.Tensor): + video_tensor = video.detach().cpu() + if video_tensor.dim() == 5: + # [B, C, F, H, W] or [B, F, H, W, C] + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + # [C, F, H, W] -> [F, H, W, C] + video_tensor = video_tensor.permute(1, 2, 3, 0) + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + elif isinstance(video, np.ndarray): + video_array = video + if video_array.ndim == 5: + video_array = video_array[0] + if video_array.ndim == 4 and video_array.shape[-1] not in (1, 3, 4) and video_array.shape[0] in (1, 3, 4): + video_array = np.transpose(video_array, (1, 2, 3, 0)) + if video_array.dtype.kind == "f": + if video_array.min() < 0: + video_array = np.clip(video_array, -1, 1) + video_array = (video_array + 1) / 2 + elif video_array.max() > 1.0: + video_array = np.clip(video_array, 0, 255) / 255.0 + else: + video_array = np.clip(video_array, 0, 1) + + if video_array is None: + return [] + + if video_array.ndim == 3: + video_array = np.expand_dims(video_array, 0) + if video_array.ndim != 4: + return [] + + return list(video_array) + + def _encode_video_frames_to_data_url(self, frames: list[Any], fps: int) -> str: + """Encode video frames into a base64 MP4 data URL.""" + from diffusers.utils import export_to_video + from pathlib import Path + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) / "output.mp4" + export_to_video(frames, str(output_path), fps=fps) + video_bytes = output_path.read_bytes() + + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + return f"data:video/mp4;base64,{video_base64}" + def _create_error_response( self, message: str, diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 86f41a64d..e537ffbdb 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -36,6 +36,7 @@ class OmniRequestOutput: final_output_type: Type of output ("text", "image", "audio", "latents") request_output: The 
underlying RequestOutput from the stage (pipeline mode) images: List of generated PIL images (diffusion mode) + videos: List of generated videos (diffusion mode) prompt: The prompt used for generation (diffusion mode) latents: Optional tensor of latent representations (diffusion mode) metrics: Optional dictionary of generation metrics @@ -51,6 +52,7 @@ class OmniRequestOutput: # Diffusion model fields images: list[Image.Image] = field(default_factory=list) + videos: list[Any] = field(default_factory=list) prompt: str | None = None latents: torch.Tensor | None = None metrics: dict[str, Any] = field(default_factory=dict) @@ -84,7 +86,8 @@ def from_pipeline( def from_diffusion( cls, request_id: str, - images: list[Image.Image], + images: list[Image.Image] | None = None, + videos: list[Any] | None = None, prompt: str | None = None, metrics: dict[str, Any] | None = None, latents: torch.Tensor | None = None, @@ -94,6 +97,7 @@ def from_diffusion( Args: request_id: Request identifier images: Generated images + videos: Generated videos prompt: The prompt used metrics: Generation metrics latents: Optional latent tensors @@ -101,10 +105,14 @@ def from_diffusion( Returns: OmniRequestOutput configured for diffusion mode """ + image_list = images or [] + video_list = videos or [] + final_output_type = "video" if video_list else "image" return cls( request_id=request_id, - final_output_type="image", - images=images, + final_output_type=final_output_type, + images=image_list, + videos=video_list, prompt=prompt, latents=latents, metrics=metrics or {}, @@ -116,10 +124,19 @@ def num_images(self) -> int: """Return the number of generated images.""" return len(self.images) + @property + def num_videos(self) -> int: + """Return the number of generated videos.""" + return len(self.videos) + @property def is_diffusion_output(self) -> bool: """Check if this is a diffusion model output.""" - return len(self.images) > 0 or self.final_output_type == "image" + return ( + len(self.images) > 0 + or len(self.videos) > 0 + or self.final_output_type in {"image", "video"} + ) @property def is_pipeline_output(self) -> bool: @@ -138,6 +155,7 @@ def to_dict(self) -> dict[str, Any]: result.update( { "num_images": self.num_images, + "num_videos": self.num_videos, "prompt": self.prompt, "metrics": self.metrics, }
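To make the new `from_diffusion()` contract concrete, here is a small illustrative check. The frame array is a dummy placeholder, and the import path follows `vllm_omni/outputs.py` above; unset dataclass fields are assumed to keep their defaults.

```python
import numpy as np
from PIL import Image

from vllm_omni.outputs import OmniRequestOutput

# Dummy 81-frame clip in [F, H, W, C] layout with values in [0, 1].
frames = np.zeros((81, 480, 832, 3), dtype=np.float32)

video_out = OmniRequestOutput.from_diffusion(request_id="demo-video", videos=[frames])
assert video_out.final_output_type == "video"
assert video_out.num_videos == 1 and video_out.num_images == 0
assert video_out.is_diffusion_output

image_out = OmniRequestOutput.from_diffusion(
    request_id="demo-image", images=[Image.new("RGB", (64, 64))]
)
assert image_out.final_output_type == "image"
assert image_out.num_images == 1 and image_out.num_videos == 0
```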