17 changes: 17 additions & 0 deletions .buildkite/pipeline.yml
@@ -54,6 +54,23 @@ steps:
          volumes:
            - "/fsx/hf_cache:/fsx/hf_cache"

- label: "Audio Generation Model Test"
timeout_in_minutes: 15
depends_on: image-build
commands:
- pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
agents:
queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
plugins:
- docker#v5.2.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
always-pull: true
propagate-environment: true
environment:
- "HF_HOME=/fsx/hf_cache"
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"

- label: "Diffusion Cache Backend Test"
timeout_in_minutes: 15
depends_on: image-build
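The new step mirrors the surrounding GPU jobs (same CI image, HF cache mount, and single-L4 queue). The same test can be run outside CI, assuming a CUDA GPU and an installed vllm-omni environment, with the pytest invocation from the step:

pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py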
200 changes: 200 additions & 0 deletions examples/offline_inference/text_to_audio/text_to_audio.py
@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Example script for text-to-audio generation using Stable Audio Open.

This script demonstrates how to generate audio from text prompts using
the Stable Audio Open model with vLLM-Omni.

Usage:
    python text_to_audio.py --prompt "The sound of a dog barking"
    python text_to_audio.py --prompt "A piano playing a gentle melody" --audio_length 10.0
    python text_to_audio.py --prompt "Thunder and rain sounds" --negative_prompt "Low quality"
"""

import argparse
import time
from pathlib import Path

import numpy as np
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate audio with Stable Audio Open.")
    parser.add_argument(
        "--model",
        default="stabilityai/stable-audio-open-1.0",
        help="Stable Audio model name or local path.",
    )
    parser.add_argument(
        "--prompt",
        default="The sound of a hammer hitting a wooden surface.",
        help="Text prompt for audio generation.",
    )
    parser.add_argument(
        "--negative_prompt",
        default="Low quality.",
        help="Negative prompt for classifier-free guidance.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for deterministic results.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.0,
        help="Classifier-free guidance scale.",
    )
    parser.add_argument(
        "--audio_start",
        type=float,
        default=0.0,
        help="Audio start time in seconds.",
    )
    parser.add_argument(
        "--audio_length",
        type=float,
        default=10.0,
        help="Audio length in seconds (max ~47s for stable-audio-open-1.0).",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=100,
        help="Number of denoising steps for the diffusion sampler.",
    )
    parser.add_argument(
        "--num_waveforms",
        type=int,
        default=1,
        help="Number of audio waveforms to generate for the given prompt.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="stable_audio_output.wav",
        help="Path to save the generated audio (WAV format).",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=44100,
        help="Sample rate for output audio (Stable Audio uses 44100 Hz).",
    )
    return parser.parse_args()


def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 44100):
"""Save audio data to a WAV file."""
try:
import soundfile as sf

sf.write(output_path, audio_data, sample_rate)
except ImportError:
try:
import scipy.io.wavfile as wav

# Ensure audio is in the correct format for scipy
if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
# Normalize to int16 range
audio_data = np.clip(audio_data, -1.0, 1.0)
audio_data = (audio_data * 32767).astype(np.int16)
wav.write(output_path, sample_rate, audio_data)
except ImportError:
raise ImportError(
"Either 'soundfile' or 'scipy' is required to save audio files. "
"Install with: pip install soundfile or pip install scipy"
)
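
# Illustrative usage of save_audio (an assumption added for documentation; this
# snippet is not called by the script): the helper expects float audio shaped
# [samples, channels], e.g. a one-second 440 Hz stereo tone:
#
#     t = np.linspace(0.0, 1.0, 44100, endpoint=False)
#     tone = np.stack([np.sin(2 * np.pi * 440.0 * t)] * 2, axis=1)
#     save_audio(tone.astype(np.float32), "tone_440hz.wav", 44100)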


def main():
    args = parse_args()
    device = detect_device_type()
    generator = torch.Generator(device=device).manual_seed(args.seed)

    print(f"\n{'=' * 60}")
    print("Stable Audio Open - Text-to-Audio Generation")
    print(f"{'=' * 60}")
    print(f" Model: {args.model}")
    print(f" Prompt: {args.prompt}")
    print(f" Negative prompt: {args.negative_prompt}")
    print(f" Audio length: {args.audio_length}s")
    print(f" Inference steps: {args.num_inference_steps}")
    print(f" Guidance scale: {args.guidance_scale}")
    print(f" Seed: {args.seed}")
    print(f"{'=' * 60}\n")

    # Initialize Omni with Stable Audio model
    omni = Omni(model=args.model)

    # Calculate audio end time
    audio_end_in_s = args.audio_start + args.audio_length

    # Time profiling for generation
    generation_start = time.perf_counter()

    # Generate audio
    audio = omni.generate(
        args.prompt,
        negative_prompt=args.negative_prompt,
        generator=generator,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        num_outputs_per_prompt=args.num_waveforms,
        extra={
            "audio_start_in_s": args.audio_start,
            "audio_end_in_s": audio_end_in_s,
        },
    )

    generation_end = time.perf_counter()
    generation_time = generation_end - generation_start

    print(f"Total generation time: {generation_time:.2f} seconds")

    # Process and save audio
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    suffix = output_path.suffix or ".wav"
    stem = output_path.stem or "stable_audio_output"

    # Handle different output formats
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().float().numpy()

    # Audio shape is typically [batch, channels, samples] or [channels, samples]
    if audio.ndim == 3:
        # [batch, channels, samples]
        if args.num_waveforms <= 1:
            audio_data = audio[0].T  # [samples, channels]
            save_audio(audio_data, str(output_path), args.sample_rate)
            print(f"Saved generated audio to {output_path}")
        else:
            for idx in range(audio.shape[0]):
                audio_data = audio[idx].T  # [samples, channels]
                save_path = output_path.parent / f"{stem}_{idx}{suffix}"
                save_audio(audio_data, str(save_path), args.sample_rate)
                print(f"Saved generated audio to {save_path}")
    elif audio.ndim == 2:
        # [channels, samples]
        audio_data = audio.T  # [samples, channels]
        save_audio(audio_data, str(output_path), args.sample_rate)
        print(f"Saved generated audio to {output_path}")
    else:
        # [samples] - mono audio
        save_audio(audio, str(output_path), args.sample_rate)
        print(f"Saved generated audio to {output_path}")

    print(f"\nGenerated {args.audio_length}s of audio at {args.sample_rate} Hz")


if __name__ == "__main__":
    main()
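A minimal sketch condensing the script's core flow (model name, generate() keywords, and the extra timing keys are all taken from this diff; treat it as illustrative rather than a stable API reference):

import torch
from vllm_omni.entrypoints.omni import Omni

omni = Omni(model="stabilityai/stable-audio-open-1.0")
audio = omni.generate(
    "The sound of a dog barking",
    negative_prompt="Low quality.",
    generator=torch.Generator("cuda").manual_seed(42),
    guidance_scale=7.0,
    num_inference_steps=100,
    num_outputs_per_prompt=1,
    extra={"audio_start_in_s": 0.0, "audio_end_in_s": 10.0},
)
# Per the e2e test below, audio comes back as an np.ndarray shaped
# (batch, channels, samples).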
52 changes: 52 additions & 0 deletions tests/e2e/offline_inference/test_stable_audio_model.py
@@ -0,0 +1,52 @@
import os
import sys
from pathlib import Path

import numpy as np
import pytest
import torch

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[3]  # tests/e2e/offline_inference -> repo root
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

models = ["stabilityai/stable-audio-open-1.0"]


@pytest.mark.parametrize("model_name", models)
def test_stable_audio_model(model_name: str):
    m = Omni(model=model_name)

    # Use minimal settings for testing
    # Generate a short 2-second audio clip with minimal inference steps
    audio_start_in_s = 0.0
    audio_end_in_s = 2.0  # Short duration for fast testing
    sample_rate = 44100  # Stable Audio uses 44100 Hz

    audio = m.generate(
        "The sound of a dog barking",
        negative_prompt="Low quality.",
        num_inference_steps=4,  # Minimal steps for speed
        guidance_scale=7.0,
        generator=torch.Generator("cuda").manual_seed(42),
        num_outputs_per_prompt=1,
        extra={
            "audio_start_in_s": audio_start_in_s,
            "audio_end_in_s": audio_end_in_s,
        },
    )

    assert audio is not None
    assert isinstance(audio, np.ndarray)
    # audio shape: (batch, channels, samples)
    # For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples
    assert audio.ndim == 3
    assert audio.shape[0] == 1  # batch size
    assert audio.shape[1] == 2  # stereo channels
    expected_samples = int((audio_end_in_s - audio_start_in_s) * sample_rate)
    assert audio.shape[2] == expected_samples  # 88200 samples for 2 seconds
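The final shape assertion is just duration-times-rate arithmetic; the same numbers checked standalone (a sketch using only values from the test):

# 2.0 s of audio at 44100 Hz
expected_samples = int((2.0 - 0.0) * 44100)
assert expected_samples == 88200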
18 changes: 18 additions & 0 deletions vllm_omni/diffusion/models/stable_audio/__init__.py
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Stable Audio Open model support for vLLM-Omni."""

from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import (
    StableAudioPipeline,
    get_stable_audio_post_process_func,
)
from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import (
    StableAudioDiTModel,
)

__all__ = [
    "StableAudioDiTModel",
    "StableAudioPipeline",
    "get_stable_audio_post_process_func",
]
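Given the re-exports above, downstream code can import from the subpackage rather than the deep module paths (a sketch; how vllm_omni wires this package into model resolution is not shown in this diff):

from vllm_omni.diffusion.models.stable_audio import (
    StableAudioDiTModel,
    StableAudioPipeline,
    get_stable_audio_post_process_func,
)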