72 changes: 44 additions & 28 deletions tests/e2e/online_serving/test_qwen3_omni.py
@@ -4,6 +4,7 @@
 E2E Online tests for Qwen3-Omni model with video input and audio output.
 """
 
+import concurrent.futures
 import os
 import socket
 import subprocess
@@ -167,40 +168,55 @@ def dummy_messages_from_video_data(
 
 
 @pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_video_to_audio(
+def test_video_to_audio_concurrent(
     client: openai.OpenAI,
     omni_server,
     base64_encoded_video: str,
 ) -> None:
-    """Test processing video, generating audio output via OpenAI API."""
+    """Test processing video with multiple concurrent completions, generating audio output via OpenAI API."""
     # Create data URL for the base64 encoded video
     video_data_url = f"data:video/mp4;base64,{base64_encoded_video}"
 
     messages = dummy_messages_from_video_data(video_data_url)
 
-    # Test single completion
-    chat_completion = client.chat.completions.create(
-        model=omni_server.model,
-        messages=messages,
-    )
-
-    assert len(chat_completion.choices) == 2  # 1 for text output, 1 for audio output
-
-    # Verify text output
-    text_choice = chat_completion.choices[0]
-    assert text_choice.finish_reason == "length"
-
-    # Verify we got a response
-    text_message = text_choice.message
-    assert text_message.content is not None and len(text_message.content) >= 10
-    assert text_message.role == "assistant"
-
-    # Verify audio output
-    audio_choice = chat_completion.choices[1]
-    assert audio_choice.finish_reason == "stop"
-    audio_message = audio_choice.message
-
-    # Check if audio was generated
-    if hasattr(audio_message, "audio") and audio_message.audio:
-        assert audio_message.audio.data is not None
-        assert len(audio_message.audio.data) > 0
+    # Test multiple concurrent completions
+    num_concurrent_requests = 5
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
+        # Submit multiple completion requests concurrently
+        futures = [
+            executor.submit(
+                client.chat.completions.create,
+                model=omni_server.model,
+                messages=messages,
+            )
+            for _ in range(num_concurrent_requests)
+        ]
+
+        # Wait for all requests to complete and collect results
+        chat_completions = [future.result() for future in concurrent.futures.as_completed(futures)]
+
+    # Verify all completions succeeded
+    assert len(chat_completions) == num_concurrent_requests
+
+    for chat_completion in chat_completions:
+        assert len(chat_completion.choices) == 2  # 1 for text output, 1 for audio output
+
+        # Verify text output
+        text_choice = chat_completion.choices[0]
+        assert text_choice.finish_reason == "length"
+
+        # Verify we got a response
+        text_message = text_choice.message
+        assert text_message.content is not None and len(text_message.content) >= 10
+        assert text_message.role == "assistant"
+
+        # Verify audio output
+        audio_choice = chat_completion.choices[1]
+        assert audio_choice.finish_reason == "stop"
+        audio_message = audio_choice.message
+
+        # Check if audio was generated
+        if hasattr(audio_message, "audio") and audio_message.audio:
+            assert audio_message.audio.data is not None
+            assert len(audio_message.audio.data) > 0
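
For context on the pattern the updated test uses: `ThreadPoolExecutor` fans out the blocking client calls, and `as_completed` yields futures in completion order rather than submission order, which is why the test asserts the same invariants on every result instead of relying on ordering. A self-contained sketch of the same fan-out, with a stub standing in for `client.chat.completions.create` (the stub name and timings are illustrative only):

```python
import concurrent.futures
import time


def fake_completion(i: int) -> str:
    """Stand-in for the blocking client.chat.completions.create(...) call."""
    time.sleep(0.05)  # simulate an HTTP round-trip
    return f"completion-{i}"


num_concurrent_requests = 5
with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
    futures = [executor.submit(fake_completion, i) for i in range(num_concurrent_requests)]
    # as_completed yields futures in finish order, not submission order.
    results = [f.result() for f in concurrent.futures.as_completed(futures)]

assert len(results) == num_concurrent_requests
```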
150 changes: 88 additions & 62 deletions vllm_omni/entrypoints/omni_stage.py
@@ -16,6 +16,7 @@
 import logging
 import multiprocessing as mp
 import os
+import queue
 import sys
 import traceback
 from typing import Any
@@ -1135,19 +1136,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
             )
         except Exception as e:
             _logging.getLogger(__name__).warning("[Stage-%s] Failed to send stage ready signal: %s", stage_id, e)
 
-        # Batch processing loop
-        while True:
-            task = in_q.get()
-            _recv_dequeue_ts = _time.time()
-            if task is None:
-                _logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
-                break
-
-            _rx_bytes_by_rid: dict[Any, int] = {}
-            _rx_decode_ms_by_rid: dict[Any, float] = {}
-            _in_flight_ms_by_rid: dict[Any, float] = {}
+        generation_out_q = asyncio.Queue()
+        _rx_bytes_by_rid: dict[Any, int] = {}
+        _rx_decode_ms_by_rid: dict[Any, float] = {}
+        _in_flight_ms_by_rid: dict[Any, float] = {}
+
+        async def generation_single_request(task: dict[str, Any]):
+            _recv_dequeue_ts = _time.time()
             rid = task["request_id"]
             try:
                 sent_ts = float(task.get("sent_ts", None)) if isinstance(task, dict) else None
@@ -1157,62 +1153,101 @@ def filter(self, record: _logging.LogRecord) -> bool:
                 _in_flight_ms_by_rid[rid] = 0.0
             except Exception:
                 _in_flight_ms_by_rid[rid] = 0.0
-            ein, _rx_metrics = try_recv_via_connector(
-                task=task,
-                connectors=connectors,
-                stage_id=stage_id,
-            )
-            if ein is None or _rx_metrics is None:
-                raise RuntimeError(
-                    f"[Stage-{stage_id}] Missing connector payload for request {rid}. "
-                    "Ensure connectors are configured for all incoming edges."
-                )
-            _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0))
-            _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0))
-
-            sampling_params = task["sampling_params"]
-            _logging.getLogger(__name__).debug("[Stage-%s] Received batch size=1, request_ids=%s", stage_id, rid)
-            print("--------------------------------", flush=True)
-            print(f"[Stage-{stage_id}] Received batch size=1, request_ids={rid}", flush=True)
-            print("--------------------------------", flush=True)
             try:
-                _batch_seq += 1
+                ein, _rx_metrics = try_recv_via_connector(
+                    task=task,
+                    connectors=connectors,
+                    stage_id=stage_id,
+                )
+                if ein is None or _rx_metrics is None:
+                    raise RuntimeError(
+                        f"[Stage-{stage_id}] Missing connector payload for request {rid}. "
+                        "Ensure connectors are configured for all incoming edges."
+                    )
+                _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0))
+                _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0))
+
+                sampling_params = task["sampling_params"]
+                _logging.getLogger(__name__).debug("[Stage-%s] Received batch size=1, request_ids=%s", stage_id, rid)
+                print("--------------------------------", flush=True)
+                print(f"[Stage-{stage_id}] Received batch size=1, request_ids={rid}", flush=True)
+                print("--------------------------------", flush=True)
                 _gen_t0 = _time.time()
                 if isinstance(ein, list):
                     ein = ein[0]
 
                 async for res in stage_engine.generate(ein, sampling_params, rid):
                     gen_output = res
                 _gen_t1 = _time.time()
                 _gen_ms = (_gen_t1 - _gen_t0) * 1000.0
+                await generation_out_q.put((rid, gen_output, _gen_ms))
+            except Exception as e:
Review comment (Contributor):
It seems that _generation_tasks_by_rid[rid] is not cleaned up when an exception occurs.
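
If that is right, one way to address it is a `finally` clause so the tracking entry is dropped on both the success and error paths. A minimal runnable sketch; `_generation_tasks_by_rid` is the reviewer's name and does not appear in the diff shown here, and the task body is stubbed:

```python
import asyncio

# Hypothetical: the reviewer's _generation_tasks_by_rid, assumed to map
# request_id -> asyncio.Task; it is not shown in this diff.
_generation_tasks_by_rid: dict[str, asyncio.Task] = {}


async def generation_single_request(task: dict) -> None:
    rid = task["request_id"]
    try:
        await asyncio.sleep(0)  # stand-in for recv + stage_engine.generate(...)
    except Exception as e:
        print({"request_id": rid, "error": str(e)})  # stand-in for out_q.put(...)
    finally:
        # Drop the bookkeeping entry whether generation succeeded or failed,
        # so failed requests do not leak tracking state.
        _generation_tasks_by_rid.pop(rid, None)


async def main() -> None:
    t = asyncio.create_task(generation_single_request({"request_id": "r1"}))
    _generation_tasks_by_rid["r1"] = t
    await t
    assert "r1" not in _generation_tasks_by_rid


asyncio.run(main())
```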

+                _logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e)
+                out_q.put(
+                    {
+                        "request_id": rid,
+                        "stage_id": stage_id,
+                        "error": str(e),
+                    }
+                )
 
-                r_outputs = [gen_output]
-                _num_tokens = count_tokens_from_outputs(r_outputs)
-                _agg_total_tokens += _num_tokens
-                _agg_total_gen_time_ms += _gen_ms
-
-                if _stats_file:
-                    _avg_tokens_per_s = (
-                        (_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms) if _agg_total_gen_time_ms > 0 else 0.0
-                    )
-                    log_stage_running_avg(
-                        _stats_file,
-                        stage_id,
-                        int(_agg_total_tokens),
-                        float(_agg_total_gen_time_ms),
-                        float(_avg_tokens_per_s),
-                    )
+        _batch_gen_t0 = _time.time()
+        while True:
+            try:
+                task = in_q.get_nowait()
+                if task is None:
+                    _logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
+                    break
+                asyncio.create_task(generation_single_request(task))
+            except queue.Empty:
+                await asyncio.sleep(0.001)
+            batch_request_outputs: list[Any] = []
+            batch_request_ids: list[Any] = []
+            _gen_ms_list = []
+            while True:
+                try:
+                    rids, gen_output, _gen_ms = generation_out_q.get_nowait()
+                    _num_tokens = count_tokens_from_outputs([gen_output])
+                    batch_request_outputs.append(gen_output)
+                    _gen_ms_list.append(_gen_ms)
+                    batch_request_ids.append(rids)
+                    _agg_total_tokens += _num_tokens
+                except asyncio.QueueEmpty:
+                    await asyncio.sleep(0.001)
+                    break
+
+            if not batch_request_outputs:
+                continue
+            _batch_seq += 1
+            if _stats_file:
+                _batch_gen_t1 = _time.time()
+                _agg_total_gen_time_ms += (_batch_gen_t1 - _batch_gen_t0) * 1000
+                _batch_gen_t0 = _batch_gen_t1
+                _avg_tokens_per_s = (
+                    (_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms) if _agg_total_gen_time_ms > 0 else 0.0
+                )
+                log_stage_running_avg(
+                    _stats_file,
+                    stage_id,
+                    int(_agg_total_tokens),
+                    float(_agg_total_gen_time_ms),
+                    float(_avg_tokens_per_s),
+                )
+                logger.info("[Stage-%s] Running avg: %s tokens/s", stage_id, _avg_tokens_per_s)
+                for rid, _gen_ms in zip(batch_request_ids, _gen_ms_list):
+                    log_stage_batch_stats(_stats_file, stage_id, 1, float(_gen_ms), [rid])
+
+            logger.info("[Stage-%s] Sending outputs to main process", stage_id)
+            for rid, output, _gen_ms in zip(batch_request_ids, batch_request_outputs, _gen_ms_list):
+                try:
+                    r_outputs = [output]
                     use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes)
                     _metrics = {
                         "num_tokens_out": int(count_tokens_from_outputs(r_outputs)),
                         "stage_gen_time_ms": _gen_ms,
                         "batch_id": int(_batch_seq),
-                        "rx_decode_time_ms": float(_rx_decode_ms_by_rid.get(rid, 0.0)),
-                        "rx_transfer_bytes": int(_rx_bytes_by_rid.get(rid, 0)),
-                        "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.get(rid, 0.0)),
+                        "rx_decode_time_ms": float(_rx_decode_ms_by_rid.pop(rid, 0.0)),
+                        "rx_transfer_bytes": int(_rx_bytes_by_rid.pop(rid, 0)),
+                        "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.pop(rid, 0.0)),
                     }
                     if _stats_file:
                         compute_and_log_stage_request_stats(
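
A note on the `.get(...)` → `.pop(...)` change above (and in the next hunk): the `_rx_*_by_rid` dicts are now created once, before the per-request coroutines, instead of fresh on each loop iteration, so each entry must be removed when its metrics are emitted or the dicts grow for the lifetime of the worker. `dict.pop(key, default)` reads and deletes in one step:

```python
metrics_by_rid: dict[str, float] = {"r1": 3.2}

# .get would leave the entry behind; .pop returns it and removes it,
# falling back to the default once the key is gone.
assert metrics_by_rid.pop("r1", 0.0) == 3.2
assert metrics_by_rid.pop("r1", 0.0) == 0.0
assert metrics_by_rid == {}
```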
@@ -1266,23 +1301,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
                             "metrics": {
                                 "num_tokens_out": int(count_tokens_from_outputs(r_outputs)),
                                 "stage_gen_time_ms": _gen_ms,
-                                "rx_decode_time_ms": float(_rx_decode_ms_by_rid.get(rid, 0.0)),
-                                "rx_transfer_bytes": int(_rx_bytes_by_rid.get(rid, 0)),
-                                "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.get(rid, 0.0)),
+                                "rx_decode_time_ms": float(_rx_decode_ms_by_rid.pop(rid, 0.0)),
+                                "rx_transfer_bytes": int(_rx_bytes_by_rid.pop(rid, 0)),
+                                "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.pop(rid, 0.0)),
                             },
                         }
                     )
                     _logging.getLogger(__name__).debug("[Stage-%s] Enqueued result for request %s to downstream", stage_id, rid)
 
-            except Exception as e:
-                _logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e)
-                out_q.put(
-                    {
-                        "request_id": rid,
-                        "stage_id": stage_id,
-                        "error": str(e),
-                    }
-                )
         print("--------------------------------", flush=True)
         print(f"[Stage-{stage_id}] Stage worker exiting", flush=True)
         print("--------------------------------", flush=True)