diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 3ba45afc1..cd11759cd 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -54,6 +54,23 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Audio Generation Model Test" + timeout_in_minutes: 15 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Diffusion Cache Backend Test" timeout_in_minutes: 15 depends_on: image-build diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py new file mode 100644 index 000000000..0b172a077 --- /dev/null +++ b/examples/offline_inference/text_to_audio/text_to_audio.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script for text-to-audio generation using Stable Audio Open. + +This script demonstrates how to generate audio from text prompts using +the Stable Audio Open model with vLLM-Omni. + +Usage: + python text_to_audio.py --prompt "The sound of a dog barking" + python text_to_audio.py --prompt "A piano playing a gentle melody" --audio_length 10.0 + python text_to_audio.py --prompt "Thunder and rain sounds" --negative_prompt "Low quality" +""" + +import argparse +import time +from pathlib import Path + +import numpy as np +import torch + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.utils.platform_utils import detect_device_type + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate audio with Stable Audio Open.") + parser.add_argument( + "--model", + default="stabilityai/stable-audio-open-1.0", + help="Stable Audio model name or local path.", + ) + parser.add_argument( + "--prompt", + default="The sound of a hammer hitting a wooden surface.", + help="Text prompt for audio generation.", + ) + parser.add_argument( + "--negative_prompt", + default="Low quality.", + help="Negative prompt for classifier-free guidance.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for deterministic results.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7.0, + help="Classifier-free guidance scale.", + ) + parser.add_argument( + "--audio_start", + type=float, + default=0.0, + help="Audio start time in seconds.", + ) + parser.add_argument( + "--audio_length", + type=float, + default=10.0, + help="Audio length in seconds (max ~47s for stable-audio-open-1.0).", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--num_waveforms", + type=int, + default=1, + help="Number of audio waveforms to generate for the given prompt.", + ) + parser.add_argument( + "--output", + type=str, + default="stable_audio_output.wav", + help="Path to save the generated audio (WAV format).", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=44100, + help="Sample rate for output audio (Stable Audio uses 44100 Hz).", + ) + return parser.parse_args() + + +def save_audio(audio_data: np.ndarray, 
output_path: str, sample_rate: int = 44100): + """Save audio data to a WAV file.""" + try: + import soundfile as sf + + sf.write(output_path, audio_data, sample_rate) + except ImportError: + try: + import scipy.io.wavfile as wav + + # Ensure audio is in the correct format for scipy + if audio_data.dtype == np.float32 or audio_data.dtype == np.float64: + # Normalize to int16 range + audio_data = np.clip(audio_data, -1.0, 1.0) + audio_data = (audio_data * 32767).astype(np.int16) + wav.write(output_path, sample_rate, audio_data) + except ImportError: + raise ImportError( + "Either 'soundfile' or 'scipy' is required to save audio files. " + "Install with: pip install soundfile or pip install scipy" + ) + + +def main(): + args = parse_args() + device = detect_device_type() + generator = torch.Generator(device=device).manual_seed(args.seed) + + print(f"\n{'=' * 60}") + print("Stable Audio Open - Text-to-Audio Generation") + print(f"{'=' * 60}") + print(f" Model: {args.model}") + print(f" Prompt: {args.prompt}") + print(f" Negative prompt: {args.negative_prompt}") + print(f" Audio length: {args.audio_length}s") + print(f" Inference steps: {args.num_inference_steps}") + print(f" Guidance scale: {args.guidance_scale}") + print(f" Seed: {args.seed}") + print(f"{'=' * 60}\n") + + # Initialize Omni with Stable Audio model + omni = Omni(model=args.model) + + # Calculate audio end time + audio_end_in_s = args.audio_start + args.audio_length + + # Time profiling for generation + generation_start = time.perf_counter() + + # Generate audio + audio = omni.generate( + args.prompt, + negative_prompt=args.negative_prompt, + generator=generator, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + num_outputs_per_prompt=args.num_waveforms, + extra={ + "audio_start_in_s": args.audio_start, + "audio_end_in_s": audio_end_in_s, + }, + ) + + generation_end = time.perf_counter() + generation_time = generation_end - generation_start + + print(f"Total generation time: {generation_time:.2f} seconds") + + # Process and save audio + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".wav" + stem = output_path.stem or "stable_audio_output" + + # Handle different output formats + if isinstance(audio, torch.Tensor): + audio = audio.cpu().float().numpy() + + # Audio shape is typically [batch, channels, samples] or [channels, samples] + if audio.ndim == 3: + # [batch, channels, samples] + if args.num_waveforms <= 1: + audio_data = audio[0].T # [samples, channels] + save_audio(audio_data, str(output_path), args.sample_rate) + print(f"Saved generated audio to {output_path}") + else: + for idx in range(audio.shape[0]): + audio_data = audio[idx].T # [samples, channels] + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + save_audio(audio_data, str(save_path), args.sample_rate) + print(f"Saved generated audio to {save_path}") + elif audio.ndim == 2: + # [channels, samples] + audio_data = audio.T # [samples, channels] + save_audio(audio_data, str(output_path), args.sample_rate) + print(f"Saved generated audio to {output_path}") + else: + # [samples] - mono audio + save_audio(audio, str(output_path), args.sample_rate) + print(f"Saved generated audio to {output_path}") + + print(f"\nGenerated {args.audio_length}s of audio at {args.sample_rate} Hz") + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py new file 
mode 100644 index 000000000..ad4fce62a --- /dev/null +++ b/tests/e2e/offline_inference/test_stable_audio_model.py @@ -0,0 +1,52 @@ +import os +import sys +from pathlib import Path + +import numpy as np +import pytest +import torch + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +models = ["stabilityai/stable-audio-open-1.0"] + + +@pytest.mark.parametrize("model_name", models) +def test_stable_audio_model(model_name: str): + m = Omni(model=model_name) + + # Use minimal settings for testing + # Generate a short 2-second audio clip with minimal inference steps + audio_start_in_s = 0.0 + audio_end_in_s = 2.0 # Short duration for fast testing + sample_rate = 44100 # Stable Audio uses 44100 Hz + + audio = m.generate( + "The sound of a dog barking", + negative_prompt="Low quality.", + num_inference_steps=4, # Minimal steps for speed + guidance_scale=7.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + extra={ + "audio_start_in_s": audio_start_in_s, + "audio_end_in_s": audio_end_in_s, + }, + ) + + assert audio is not None + assert isinstance(audio, np.ndarray) + # audio shape: (batch, channels, samples) + # For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples + assert audio.ndim == 3 + assert audio.shape[0] == 1 # batch size + assert audio.shape[1] == 2 # stereo channels + expected_samples = int((audio_end_in_s - audio_start_in_s) * sample_rate) + assert audio.shape[2] == expected_samples # 88200 samples for 2 seconds diff --git a/vllm_omni/diffusion/models/stable_audio/__init__.py b/vllm_omni/diffusion/models/stable_audio/__init__.py new file mode 100644 index 000000000..baa986a0f --- /dev/null +++ b/vllm_omni/diffusion/models/stable_audio/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Stable Audio Open model support for vLLM-Omni.""" + +from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import ( + StableAudioPipeline, + get_stable_audio_post_process_func, +) +from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import ( + StableAudioDiTModel, +) + +__all__ = [ + "StableAudioDiTModel", + "StableAudioPipeline", + "get_stable_audio_post_process_func", +] diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py new file mode 100644 index 000000000..424ca9d89 --- /dev/null +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -0,0 +1,569 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Stable Audio Open Pipeline for vLLM-Omni. + +This module provides text-to-audio generation using the Stable Audio Open model +from Stability AI, integrated with the vLLM-Omni diffusion framework. 
+""" + +from __future__ import annotations + +import os +from collections.abc import Iterable + +import torch +from diffusers import AutoencoderOobleck +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.schedulers import CosineDPMSolverMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import T5EncoderModel, T5Tokenizer +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = init_logger(__name__) + + +def get_stable_audio_post_process_func( + od_config: OmniDiffusionConfig, +): + """ + Create post-processing function for Stable Audio output. + + Converts raw audio tensor to numpy array for saving. + """ + + def post_process_func( + audio: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return audio + if output_type == "pt": + return audio + # Convert to numpy + audio_np = audio.cpu().float().numpy() + return audio_np + + return post_process_func + + +class StableAudioPipeline(nn.Module): + """ + Pipeline for text-to-audio generation using Stable Audio Open. + + This pipeline generates audio from text prompts using the Stable Audio Open model + from Stability AI, integrated with vLLM-Omni's diffusion framework. 
+ + Args: + od_config: OmniDiffusion configuration object + prefix: Weight prefix for loading (default: "") + """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + + self.device = get_local_device() + dtype = getattr(od_config, "dtype", torch.float16) + + model = od_config.model + local_files_only = os.path.exists(model) + + # Set up weights sources for transformer + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + # Load tokenizer + self.tokenizer = T5Tokenizer.from_pretrained( + model, + subfolder="tokenizer", + local_files_only=local_files_only, + ) + + # Load text encoder + self.text_encoder = T5EncoderModel.from_pretrained( + model, + subfolder="text_encoder", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + + # Load VAE (AutoencoderOobleck for audio) + self.vae = AutoencoderOobleck.from_pretrained( + model, + subfolder="vae", + torch_dtype=torch.float32, + local_files_only=local_files_only, + ).to(self.device) + + # Load projection model (using diffusers implementation) + self.projection_model = StableAudioProjectionModel.from_pretrained( + model, + subfolder="projection_model", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + + # Initialize our custom transformer (weights loaded via load_weights) + self.transformer = StableAudioDiTModel(od_config=od_config) + + # Load scheduler + self.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained( + model, + subfolder="scheduler", + local_files_only=local_files_only, + ) + + # Compute rotary embedding dimension + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Cache backend (set by worker if needed) + self._cache_backend = None + + # Properties + self._guidance_scale = None + self._num_timesteps = None + self._current_timestep = None + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + def check_inputs( + self, + prompt: str | list[str] | None, + audio_start_in_s: float, + audio_end_in_s: float, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + """Validate input parameters.""" + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}` must be higher than `audio_start_in_s={audio_start_in_s}`" + ) + + min_val = self.projection_model.config.min_value + max_val = self.projection_model.config.max_value + + if audio_start_in_s < min_val or audio_start_in_s > max_val: + raise ValueError(f"`audio_start_in_s` must be between {min_val} and {max_val}, got {audio_start_in_s}") + + if audio_end_in_s < min_val or audio_end_in_s > max_val: + raise ValueError(f"`audio_end_in_s` must be between {min_val} and {max_val}, got {audio_end_in_s}") + + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`. 
Cannot leave both undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`. Please provide only one.") + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device, + do_classifier_free_guidance: bool, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + negative_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Encode text prompt to embeddings.""" + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # Tokenize + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # Encode + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + )[0] + + # Handle negative prompt for CFG + if do_classifier_free_guidance and negative_prompt is not None: + if isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt` has batch size {len(negative_prompt)}, but `prompt` " + f"has batch size {batch_size}. Please make sure they match." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + )[0] + + if negative_attention_mask is not None: + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), + negative_prompt_embeds, + 0.0, + ) + + # Concatenate for CFG + if do_classifier_free_guidance and negative_prompt_embeds is not None: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + # Project embeddings + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s: float, + audio_end_in_s: float, + device: torch.device, + do_classifier_free_guidance: bool, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Encode audio duration to conditioning tensors.""" + audio_start_in_s = [audio_start_in_s] if isinstance(audio_start_in_s, (int, float)) else audio_start_in_s + 
audio_end_in_s = [audio_end_in_s] if isinstance(audio_end_in_s, (int, float)) else audio_end_in_s + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + audio_start_in_s = torch.tensor([float(x) for x in audio_start_in_s]).to(device) + audio_end_in_s = torch.tensor([float(x) for x in audio_end_in_s]).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + def prepare_latents( + self, + batch_size: int, + num_channels_vae: int, + sample_size: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """Prepare initial latent noise.""" + shape = (batch_size, num_channels_vae, sample_size) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # Scale by scheduler's noise sigma + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + audio_end_in_s: float | None = None, + audio_start_in_s: float = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + num_waveforms_per_prompt: int = 1, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str = "np", + ) -> DiffusionOutput: + """ + Generate audio from text prompt. 
+ + Args: + req: OmniDiffusionRequest containing generation parameters + prompt: Text prompt for audio generation + negative_prompt: Negative prompt for CFG + audio_end_in_s: Audio end time in seconds (max ~47s for stable-audio-open-1.0) + audio_start_in_s: Audio start time in seconds + num_inference_steps: Number of denoising steps + guidance_scale: CFG scale + num_waveforms_per_prompt: Number of audio outputs per prompt + generator: Random generator for reproducibility + latents: Pre-generated latents + prompt_embeds: Pre-computed prompt embeddings + negative_prompt_embeds: Pre-computed negative prompt embeddings + output_type: Output format ("np", "pt", or "latent") + + Returns: + DiffusionOutput containing generated audio + """ + # Extract from request + prompt = req.prompt if req.prompt is not None else prompt + negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt + num_inference_steps = req.num_inference_steps or num_inference_steps + if req.guidance_scale_provided: + guidance_scale = req.guidance_scale + + if generator is None: + generator = req.generator + if generator is None and req.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.seed) + + # Get audio duration from request extra params or defaults + audio_start_in_s = req.extra.get("audio_start_in_s", audio_start_in_s) + audio_end_in_s = req.extra.get("audio_end_in_s", audio_end_in_s) + + # Calculate audio length + downsample_ratio = self.vae.hop_length + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"Requested audio length ({audio_end_in_s - audio_start_in_s}s) exceeds " + f"maximum ({max_audio_length_in_s}s)" + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # Validate inputs + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # Determine batch size + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.device + do_classifier_free_guidance = guidance_scale > 1.0 + self._guidance_scale = guidance_scale + + # Encode prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create combined embeddings + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], + dim=1, + ) + audio_duration_embeds = torch.cat( + [seconds_start_hidden_states, seconds_end_hidden_states], + dim=2, + ) + + # Handle CFG without negative prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like(text_audio_duration_embeds) 
+ text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], + dim=0, + ) + audio_duration_embeds = torch.cat( + [audio_duration_embeds, audio_duration_embeds], + dim=0, + ) + + # Duplicate for multiple waveforms per prompt + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # Prepare latents + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + ) + + # Prepare rotary embeddings and move to device + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + # Move rotary embeddings to device (returns tuple of cos, sin) + rotary_embedding = ( + rotary_embedding[0].to(device=device, dtype=latents.dtype), + rotary_embedding[1].to(device=device, dtype=latents.dtype), + ) + + # Denoising loop + for t in timesteps: + self._current_timestep = t + + # Expand latents for CFG + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Predict noise + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # Perform CFG + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Scheduler step + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + self._current_timestep = None + + # Decode + if output_type == "latent": + audio = latents + else: + # Convert latents to VAE dtype (VAE may use float32) + latents_for_vae = latents.to(dtype=self.vae.dtype) + audio = self.vae.decode(latents_for_vae).sample + + # Trim to requested length + audio = audio[:, :, waveform_start:waveform_end] + + return DiffusionOutput(output=audio) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py new file mode 100644 index 000000000..a4b9ac35e --- /dev/null +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -0,0 +1,603 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Stable Audio DiT Model for 
vLLM-Omni. +""" + +import math +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig + +logger = init_logger(__name__) + + +def apply_rotary_emb_stable_audio( + hidden_states: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors for Stable Audio. + + Args: + hidden_states: Input tensor of shape [B, S, H, D] where D is head_dim + freqs_cis: Tuple of (cos, sin) frequency tensors of shape [S, rotary_dim] + where rotary_dim = head_dim // 2 + + Returns: + Tensor with rotary embeddings applied to first rotary_dim dimensions only. + The remaining dimensions are left unchanged (pass-through). + """ + cos, sin = freqs_cis # [S, rotary_dim] + rotary_dim = cos.shape[-1] + + # Rotate only the first rotary_dim entries; leave the rest unchanged + x_rot = hidden_states[..., :rotary_dim] + x_pass = hidden_states[..., rotary_dim:] + + cos = cos[None, :, None, :] # [1, S, 1, rotary_dim] + sin = sin[None, :, None, :] # [1, S, 1, rotary_dim] + + # [B, S, H, rotary_dim] -> [B, S, H, 2, rotary_dim//2] -> two halves + x_real, x_imag = x_rot.reshape(*x_rot.shape[:-1], 2, rotary_dim // 2).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + x_rot = (x_rot.float() * cos + x_rotated.float() * sin).to(hidden_states.dtype) + return torch.cat([x_rot, x_pass], dim=-1) + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels. + + Matches diffusers StableAudioGaussianFourierProjection with: + - flip_sin_to_cos=True (output is [cos, sin] not [sin, cos]) + - log=False (no log transformation of input) + """ + + def __init__(self, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x shape: [batch] or [batch, 1] + # Output: [batch, embedding_size * 2] + x_proj = 2 * math.pi * x[:, None] @ self.weight[None, :] + # flip_sin_to_cos=True means cos comes first + return torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + + +class StableAudioSelfAttention(nn.Module): + """ + Optimized self-attention for Stable Audio using vLLM layers. + + Self-attention uses full attention (all heads for Q, K, V). + GQA is only used for cross-attention. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim + + # All projections use inner_dim for output + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=False) + self.to_k = ReplicatedLinear(dim, self.inner_dim, bias=False) + self.to_v = ReplicatedLinear(dim, self.inner_dim, bias=False) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=False), + nn.Dropout(dropout), + ] + ) + + # Full attention (no GQA for self-attention) + self.attn = Attention( + num_heads=num_attention_heads, + head_size=attention_head_dim, + softmax_scale=1.0 / (attention_head_dim**0.5), + causal=False, + num_kv_heads=num_attention_heads, # Same as query heads + ) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + # Projections - all output inner_dim + query, _ = self.to_q(hidden_states) + key, _ = self.to_k(hidden_states) + value, _ = self.to_v(hidden_states) + + # Reshape for multi-head attention (all use full heads) + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Apply rotary embeddings + if rotary_emb is not None: + query = apply_rotary_emb_stable_audio(query, rotary_emb) + key = apply_rotary_emb_stable_audio(key, rotary_emb) + + # Compute attention + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.view(batch_size, seq_len, self.inner_dim) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class StableAudioCrossAttention(nn.Module): + """ + Optimized cross-attention for Stable Audio using vLLM layers. + + For cross-attention: + - Q projection: outputs inner_dim (full heads) + - K/V projections: outputs kv_dim (reduced heads for GQA) + + GQA is handled by manually expanding K/V heads to match Q heads + since the SDPA backend doesn't handle this automatically. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_attention_heads + self.num_kv_heads = num_key_value_attention_heads + self.head_dim = attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim + self.kv_dim = num_key_value_attention_heads * attention_head_dim + + # Number of times to repeat KV heads + self.num_kv_groups = num_attention_heads // num_key_value_attention_heads + + # Q outputs inner_dim, K/V output kv_dim (GQA) + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=False) + self.to_k = ReplicatedLinear(cross_attention_dim, self.kv_dim, bias=False) + self.to_v = ReplicatedLinear(cross_attention_dim, self.kv_dim, bias=False) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=False), + nn.Dropout(dropout), + ] + ) + + # Use full heads for attention (KV will be expanded) + self.attn = Attention( + num_heads=num_attention_heads, + head_size=attention_head_dim, + softmax_scale=1.0 / (attention_head_dim**0.5), + causal=False, + num_kv_heads=num_attention_heads, # After expansion + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + encoder_seq_len = encoder_hidden_states.shape[1] + + # Projections + query, _ = self.to_q(hidden_states) + key, _ = self.to_k(encoder_hidden_states) + value, _ = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.view(batch_size, seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, encoder_seq_len, self.num_kv_heads, self.head_dim) + value = value.view(batch_size, encoder_seq_len, self.num_kv_heads, self.head_dim) + + # Expand K/V heads to match Q heads for GQA + # [B, S, kv_heads, D] -> [B, S, kv_heads, 1, D] -> [B, S, kv_heads, groups, D] -> [B, S, num_heads, D] + key = key.unsqueeze(3).expand(-1, -1, -1, self.num_kv_groups, -1) + key = key.reshape(batch_size, encoder_seq_len, self.num_heads, self.head_dim) + value = value.unsqueeze(3).expand(-1, -1, -1, self.num_kv_groups, -1) + value = value.reshape(batch_size, encoder_seq_len, self.num_heads, self.head_dim) + + # Compute attention + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.view(batch_size, seq_len, self.inner_dim) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class SwiGLU(nn.Module): + """SwiGLU activation - matches diffusers structure.""" + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class StableAudioFeedForward(nn.Module): + """ + Feed-forward network with SwiGLU activation for Stable Audio. + Matches diffusers FeedForward structure with activation_fn="swiglu". 
+ """ + + def __init__(self, dim: int, inner_dim: int, bias: bool = True): + super().__init__() + # Structure matches diffusers FeedForward: + # net.0 = SwiGLU (proj.weight, proj.bias) + # net.1 = Dropout + # net.2 = Linear (weight, bias) + self.net = nn.Sequential( + SwiGLU(dim, inner_dim, bias=bias), + nn.Dropout(0.0), + nn.Linear(inner_dim, dim, bias=bias), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.net(hidden_states) + + +class StableAudioDiTBlock(nn.Module): + """ + Stable Audio DiT block with self-attention, cross-attention, and FFN. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + ff_mult: int = 4, + ): + super().__init__() + + # Self-attention with layer norm + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True) + self.attn1 = StableAudioSelfAttention( + dim=dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + ) + + # Cross-attention with layer norm + self.norm2 = nn.LayerNorm(dim, elementwise_affine=True) + self.attn2 = StableAudioCrossAttention( + dim=dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + + # Feed-forward with SwiGLU activation + # inner_dim = dim * ff_mult (e.g., 1536 * 4 = 6144) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=True) + self.ff = StableAudioFeedForward(dim, inner_dim=dim * ff_mult) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + rotary_embedding: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Self-attention with skip connection + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn1(hidden_states, rotary_emb=rotary_embedding, attention_mask=attention_mask) + hidden_states = residual + hidden_states + + # Cross-attention with skip connection + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2( + hidden_states, + encoder_hidden_states, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + hidden_states = residual + hidden_states + + # Feed-forward with skip connection + residual = hidden_states + hidden_states = self.norm3(hidden_states) + hidden_states = self.ff(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class StableAudioDiTModel(nn.Module): + """ + Optimized Stable Audio DiT model using vLLM layers. + + This is an optimized version of the diffusers StableAudioDiTModel that uses + vLLM's efficient linear layers and attention implementations. 
+ + Architecture: + - Input: [B, in_channels, L] (e.g., [B, 64, L]) + - preprocess_conv: residual conv layer (keeps 64 channels) + - proj_in: projects 64 -> 1536 (inner_dim) + - Global+time embeddings prepended to sequence + - Transformer blocks work on 1536-dim + - proj_out: projects 1536 -> 64 (out_channels) + - postprocess_conv: residual conv layer (keeps 64 channels) + - Output: [B, out_channels, L] + """ + + def __init__( + self, + od_config: Optional[OmniDiffusionConfig] = None, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + + # inner_dim is the transformer hidden dimension + self.inner_dim = num_attention_heads * attention_head_dim + + # Store config for compatibility + self.config = type( + "Config", + (), + { + "sample_size": sample_size, + "in_channels": in_channels, + "out_channels": out_channels, + "num_layers": num_layers, + "attention_head_dim": attention_head_dim, + "num_attention_heads": num_attention_heads, + "num_key_value_attention_heads": num_key_value_attention_heads, + "cross_attention_dim": cross_attention_dim, + "time_proj_dim": time_proj_dim, + "global_states_input_dim": global_states_input_dim, + "cross_attention_input_dim": cross_attention_input_dim, + }, + )() + + # Time projection (Gaussian Fourier features) + # time_proj_dim is the OUTPUT dimension (after sin/cos concatenation) + # So embedding_size = time_proj_dim // 2 + self.time_proj = StableAudioGaussianFourierProjection(embedding_size=time_proj_dim // 2) + + # Timestep projection: time_proj_dim -> inner_dim + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + # Global states projection (for audio duration conditioning) + # Output is inner_dim, added to time embedding + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + # Cross-attention input projection + # Always use Sequential(Linear, SiLU, Linear) to match diffusers structure + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + # Pre-processing conv (residual connection) + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + + # Input projection: in_channels -> inner_dim (64 -> 1536) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + # Transformer blocks - work on inner_dim (1536) + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + # Output projection: inner_dim -> out_channels (1536 -> 64) + 
self.proj_out = nn.Linear(self.inner_dim, out_channels, bias=False) + + # Post-processing conv (residual connection) + self.postprocess_conv = nn.Conv1d(out_channels, out_channels, 1, bias=False) + + @property + def dtype(self) -> torch.dtype: + """Return the dtype of the model parameters.""" + return next(self.parameters()).dtype + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + global_hidden_states: Optional[torch.Tensor] = None, + rotary_embedding: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass of the Stable Audio DiT model. + + Args: + hidden_states: Input latent tensor [B, C, L] (C=in_channels=64) + timestep: Timestep tensor [B] or [1] + encoder_hidden_states: Text/condition embeddings [B, S, D] + global_hidden_states: Global conditioning (duration) [B, 1, D] + rotary_embedding: Precomputed rotary embeddings (cos, sin) + return_dict: Whether to return a dataclass or tuple + attention_mask: Attention mask for self-attention + encoder_attention_mask: Attention mask for cross-attention + + Returns: + Denoised latent tensor + """ + # Project cross-attention inputs + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + + # Global embedding projection [B, 1, D] -> [B, 1, inner_dim] + global_hidden_states = self.global_proj(global_hidden_states) + + # Time embedding: timestep -> time_proj -> timestep_proj + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + + # Combine global and time embeddings [B, 1, inner_dim] + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + # Pre-process with residual: [B, C, L] + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + + # Transpose: [B, C, L] -> [B, L, C] + hidden_states = hidden_states.transpose(1, 2) + + # Project to inner_dim: [B, L, C] -> [B, L, inner_dim] + hidden_states = self.proj_in(hidden_states) + + # Prepend global states to hidden states: [B, 1+L, inner_dim] + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=1) + + # Update attention mask if provided + if attention_mask is not None: + prepend_mask = torch.ones( + (hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=torch.bool, + ) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + # Transformer blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + + # Project back to out_channels: [B, 1+L, inner_dim] -> [B, 1+L, out_channels] + hidden_states = self.proj_out(hidden_states) + + # Transpose and remove prepended global token: [B, L, C] -> [B, C, L] + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + + # Post-process with residual: [B, C, L] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if return_dict: + return Transformer2DModelOutput(sample=hidden_states) + return (hidden_states,) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from a pretrained model. + + Maps diffusers weight names to our module structure. 
+
+        Returns:
+            Set of parameter names that were successfully loaded.
+        """
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        # Weight name mapping from diffusers to our implementation
+        name_mapping = {
+            # Timestep projection - diffusers uses index-based naming
+            "timestep_proj.linear_1.weight": "timestep_proj.0.weight",
+            "timestep_proj.linear_1.bias": "timestep_proj.0.bias",
+            "timestep_proj.linear_2.weight": "timestep_proj.2.weight",
+            "timestep_proj.linear_2.bias": "timestep_proj.2.bias",
+            # Global projection - diffusers uses index-based naming
+            "global_proj.linear_1.weight": "global_proj.0.weight",
+            "global_proj.linear_2.weight": "global_proj.2.weight",
+        }
+
+        for name, loaded_weight in weights:
+            # Apply name mapping if needed
+            mapped_name = name_mapping.get(name, name)
+
+            if mapped_name in params_dict:
+                param = params_dict[mapped_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(mapped_name)
+            else:
+                logger.debug(f"Skipping weight {name} - not found in model")
+
+        return loaded_params
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index 71140536d..c0ca5ca90 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -44,6 +44,11 @@
         "pipeline_wan2_2",
         "Wan22Pipeline",
     ),
+    "StableAudioPipeline": (
+        "stable_audio",
+        "pipeline_stable_audio",
+        "StableAudioPipeline",
+    ),
     "LongCatImagePipeline": (
         "longcat_image",
         "pipeline_longcat_image",
@@ -94,6 +99,7 @@ def initialize_model(
     "ZImagePipeline": "get_post_process_func",
     "OvisImagePipeline": "get_ovis_image_post_process_func",
     "WanPipeline": "get_wan22_post_process_func",
+    "StableAudioPipeline": "get_stable_audio_post_process_func",
     "LongCatImagePipeline": "get_longcat_image_post_process_func",
     "LongCatImageEditPipeline": "get_longcat_image_post_process_func",
 }
diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py
index bbceea350..cd31aba67 100644
--- a/vllm_omni/diffusion/request.py
+++ b/vllm_omni/diffusion/request.py
@@ -109,6 +109,7 @@ class OmniDiffusionRequest:
     # Scheduler parameters
     num_inference_steps: int = 50
     guidance_scale: float = 1.0
+    guidance_scale_provided: bool = False
     guidance_scale_2: float | None = None
     guidance_rescale: float = 0.0
     eta: float = 0.0
diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py
index 9f149f37b..125fad978 100644
--- a/vllm_omni/entrypoints/omni_diffusion.py
+++ b/vllm_omni/entrypoints/omni_diffusion.py
@@ -26,6 +26,9 @@ def prepare_requests(prompt: str | list[str], **kwargs):
         if key in field_names:
             init_kwargs[key] = value
+
+    if "guidance_scale" in kwargs:
+        init_kwargs["guidance_scale_provided"] = True
+
     return OmniDiffusionRequest(**init_kwargs)
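
The last two hunks work together: prepare_requests records when the caller explicitly passed guidance_scale, and StableAudioPipeline.forward keeps its audio-specific default (7.0) unless that flag is set. Below is a minimal, self-contained sketch of that flow; the dataclass and helper names are illustrative stand-ins for OmniDiffusionRequest and the real entrypoint/pipeline code, not part of the diff.

from dataclasses import dataclass, field


@dataclass
class _SketchRequest:
    guidance_scale: float = 1.0            # library-wide default
    guidance_scale_provided: bool = False
    extra: dict = field(default_factory=dict)


def _prepare(**kwargs) -> _SketchRequest:
    req = _SketchRequest(**{k: v for k, v in kwargs.items() if k == "guidance_scale"})
    if "guidance_scale" in kwargs:
        # Same idea as the omni_diffusion.py hunk: remember that the caller set it.
        req.guidance_scale_provided = True
    return req


def _resolve_guidance_scale(req: _SketchRequest, pipeline_default: float = 7.0) -> float:
    # Mirrors the check in StableAudioPipeline.forward: only override the
    # pipeline default when the caller provided a value explicitly.
    return req.guidance_scale if req.guidance_scale_provided else pipeline_default


assert _resolve_guidance_scale(_prepare()) == 7.0
assert _resolve_guidance_scale(_prepare(guidance_scale=3.5)) == 3.5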