diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index a0f287294..63ff0e050 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -4,6 +4,7 @@
 E2E Online tests for Qwen3-Omni model with video input and audio output.
 """
 
+import concurrent.futures
 import os
 import socket
 import subprocess
@@ -167,40 +168,55 @@ def dummy_messages_from_video_data(
 @pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_video_to_audio(
+def test_video_to_audio_concurrent(
     client: openai.OpenAI,
     omni_server,
     base64_encoded_video: str,
 ) -> None:
-    """Test processing video, generating audio output via OpenAI API."""
+    """Test processing video with multiple concurrent completions, generating audio output via OpenAI API."""
     # Create data URL for the base64 encoded video
     video_data_url = f"data:video/mp4;base64,{base64_encoded_video}"
     messages = dummy_messages_from_video_data(video_data_url)
 
-    # Test single completion
-    chat_completion = client.chat.completions.create(
-        model=omni_server.model,
-        messages=messages,
-    )
-
-    assert len(chat_completion.choices) == 2  # 1 for text output, 1 for audio output
-
-    # Verify text output
-    text_choice = chat_completion.choices[0]
-    assert text_choice.finish_reason == "length"
-
-    # Verify we got a response
-    text_message = text_choice.message
-    assert text_message.content is not None and len(text_message.content) >= 10
-    assert text_message.role == "assistant"
-
-    # Verify audio output
-    audio_choice = chat_completion.choices[1]
-    assert audio_choice.finish_reason == "stop"
-    audio_message = audio_choice.message
-
-    # Check if audio was generated
-    if hasattr(audio_message, "audio") and audio_message.audio:
-        assert audio_message.audio.data is not None
-        assert len(audio_message.audio.data) > 0
+    # Test multiple concurrent completions
+    num_concurrent_requests = 5
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
+        # Submit multiple completion requests concurrently
+        futures = [
+            executor.submit(
+                client.chat.completions.create,
+                model=omni_server.model,
+                messages=messages,
+            )
+            for _ in range(num_concurrent_requests)
+        ]
+
+        # Wait for all requests to complete and collect results
+        chat_completions = [future.result() for future in concurrent.futures.as_completed(futures)]
+
+    # Verify all completions succeeded
+    assert len(chat_completions) == num_concurrent_requests
+
+    for chat_completion in chat_completions:
+        assert len(chat_completion.choices) == 2  # 1 for text output, 1 for audio output
+
+        # Verify text output
+        text_choice = chat_completion.choices[0]
+        assert text_choice.finish_reason == "length"
+
+        # Verify we got a response
+        text_message = text_choice.message
+        assert text_message.content is not None and len(text_message.content) >= 10
+        assert text_message.role == "assistant"
+
+        # Verify audio output
+        audio_choice = chat_completion.choices[1]
+        assert audio_choice.finish_reason == "stop"
+        audio_message = audio_choice.message
+
+        # Check if audio was generated
+        if hasattr(audio_message, "audio") and audio_message.audio:
+            assert audio_message.audio.data is not None
+            assert len(audio_message.audio.data) > 0
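For reference, the rewritten test follows a standard fan-out/fan-in pattern: submit N identical blocking calls to a thread pool, then gather the results as they complete. Below is a minimal standalone sketch of the same pattern; `send_request` and `NUM_REQUESTS` are illustrative placeholders, not names from the test suite.

```python
import concurrent.futures

NUM_REQUESTS = 5  # mirrors num_concurrent_requests in the test


def send_request(i: int) -> str:
    """Illustrative stand-in for a blocking call like client.chat.completions.create."""
    return f"response-{i}"


with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_REQUESTS) as executor:
    # Fan out: one future per request.
    futures = [executor.submit(send_request, i) for i in range(NUM_REQUESTS)]
    # Fan in: as_completed yields futures in completion order, not submission order.
    results = [f.result() for f in concurrent.futures.as_completed(futures)]

assert len(results) == NUM_REQUESTS
```

Because `as_completed` returns results in completion order rather than submission order, the per-completion assertions in the test are intentionally order-independent.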
diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index 4e44b6524..503c8395c 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -16,6 +16,7 @@
 import logging
 import multiprocessing as mp
 import os
+import queue
 import sys
 import traceback
 from typing import Any
@@ -1135,19 +1136,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
                 )
             except Exception as e:
                 _logging.getLogger(__name__).warning("[Stage-%s] Failed to send stage ready signal: %s", stage_id, e)
-
+        generation_out_q = asyncio.Queue()
         # Batch processing loop
-        while True:
-            task = in_q.get()
-            _recv_dequeue_ts = _time.time()
-            if task is None:
-                _logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
-                break
-
-            _rx_bytes_by_rid: dict[Any, int] = {}
-            _rx_decode_ms_by_rid: dict[Any, float] = {}
-            _in_flight_ms_by_rid: dict[Any, float] = {}
+        _rx_bytes_by_rid: dict[Any, int] = {}
+        _rx_decode_ms_by_rid: dict[Any, float] = {}
+        _in_flight_ms_by_rid: dict[Any, float] = {}
+
+        async def generation_single_request(task: dict[str, Any]):
+            _recv_dequeue_ts = _time.time()
             rid = task["request_id"]
             try:
                 sent_ts = float(task.get("sent_ts", None)) if isinstance(task, dict) else None
@@ -1157,62 +1153,101 @@ def filter(self, record: _logging.LogRecord) -> bool:
                 _in_flight_ms_by_rid[rid] = 0.0
             except Exception:
                 _in_flight_ms_by_rid[rid] = 0.0
-            ein, _rx_metrics = try_recv_via_connector(
-                task=task,
-                connectors=connectors,
-                stage_id=stage_id,
-            )
-            if ein is None or _rx_metrics is None:
-                raise RuntimeError(
-                    f"[Stage-{stage_id}] Missing connector payload for request {rid}. "
-                    "Ensure connectors are configured for all incoming edges."
-                )
-            _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0))
-            _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0))
-
-            sampling_params = task["sampling_params"]
-            _logging.getLogger(__name__).debug("[Stage-%s] Received batch size=1, request_ids=%s", stage_id, rid)
-            print("--------------------------------", flush=True)
-            print(f"[Stage-{stage_id}] Received batch size=1, request_ids={rid}", flush=True)
-            print("--------------------------------", flush=True)
             try:
-                _batch_seq += 1
+                ein, _rx_metrics = try_recv_via_connector(
+                    task=task,
+                    connectors=connectors,
+                    stage_id=stage_id,
+                )
+                if ein is None or _rx_metrics is None:
+                    raise RuntimeError(
+                        f"[Stage-{stage_id}] Missing connector payload for request {rid}. "
+                        "Ensure connectors are configured for all incoming edges."
+                    )
+                _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0))
+                _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0))
+
+                sampling_params = task["sampling_params"]
+                _logging.getLogger(__name__).debug("[Stage-%s] Received batch size=1, request_ids=%s", stage_id, rid)
+                print("--------------------------------", flush=True)
+                print(f"[Stage-{stage_id}] Received batch size=1, request_ids={rid}", flush=True)
+                print("--------------------------------", flush=True)
                 _gen_t0 = _time.time()
                 if isinstance(ein, list):
                     ein = ein[0]
-
                 async for res in stage_engine.generate(ein, sampling_params, rid):
                     gen_output = res
                 _gen_t1 = _time.time()
                 _gen_ms = (_gen_t1 - _gen_t0) * 1000.0
+                await generation_out_q.put((rid, gen_output, _gen_ms))
+            except Exception as e:
+                _logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e)
+                out_q.put(
+                    {
+                        "request_id": rid,
+                        "stage_id": stage_id,
+                        "error": str(e),
+                    }
+                )
-                r_outputs = [gen_output]
-                _num_tokens = count_tokens_from_outputs(r_outputs)
-                _agg_total_tokens += _num_tokens
-                _agg_total_gen_time_ms += _gen_ms
-
-                if _stats_file:
-                    _avg_tokens_per_s = (
-                        (_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms) if _agg_total_gen_time_ms > 0 else 0.0
-                    )
-                    log_stage_running_avg(
-                        _stats_file,
-                        stage_id,
-                        int(_agg_total_tokens),
-                        float(_agg_total_gen_time_ms),
-                        float(_avg_tokens_per_s),
-                    )
+        _batch_gen_t0 = _time.time()
+        while True:
+            try:
+                task = in_q.get_nowait()
+                if task is None:
+                    _logging.getLogger(__name__).debug("[Stage-%s] Received shutdown signal", stage_id)
+                    break
+                asyncio.create_task(generation_single_request(task))
+            except queue.Empty:
+                await asyncio.sleep(0.001)
+            batch_request_outputs: list[Any] = []
+            batch_request_ids: list[Any] = []
+            _gen_ms_list = []
+            while True:
+                try:
+                    rids, gen_output, _gen_ms = generation_out_q.get_nowait()
+                    _num_tokens = count_tokens_from_outputs([gen_output])
+                    batch_request_outputs.append(gen_output)
+                    _gen_ms_list.append(_gen_ms)
+                    batch_request_ids.append(rids)
+                    _agg_total_tokens += _num_tokens
+                except asyncio.QueueEmpty:
+                    await asyncio.sleep(0.001)
+                    break
+
+            if not batch_request_outputs:
+                continue
+            _batch_seq += 1
+            if _stats_file:
+                _batch_gen_t1 = _time.time()
+                _agg_total_gen_time_ms += (_batch_gen_t1 - _batch_gen_t0) * 1000
+                _batch_gen_t0 = _batch_gen_t1
+                _avg_tokens_per_s = (
+                    (_agg_total_tokens * 1000.0 / _agg_total_gen_time_ms) if _agg_total_gen_time_ms > 0 else 0.0
+                )
+                log_stage_running_avg(
+                    _stats_file,
+                    stage_id,
+                    int(_agg_total_tokens),
+                    float(_agg_total_gen_time_ms),
+                    float(_avg_tokens_per_s),
+                )
+                logger.info("[Stage-%s] Running avg: %s tokens/s", stage_id, _avg_tokens_per_s)
+                for rid, _gen_ms in zip(batch_request_ids, _gen_ms_list):
                     log_stage_batch_stats(_stats_file, stage_id, 1, float(_gen_ms), [rid])
+            logger.info("[Stage-%s] Sending outputs to main process", stage_id)
+            for rid, output, _gen_ms in zip(batch_request_ids, batch_request_outputs, _gen_ms_list):
                 try:
+                    r_outputs = [output]
                     use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes)
                     _metrics = {
                         "num_tokens_out": int(count_tokens_from_outputs(r_outputs)),
                         "stage_gen_time_ms": _gen_ms,
                         "batch_id": int(_batch_seq),
-                        "rx_decode_time_ms": float(_rx_decode_ms_by_rid.get(rid, 0.0)),
-                        "rx_transfer_bytes": int(_rx_bytes_by_rid.get(rid, 0)),
-                        "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.get(rid, 0.0)),
+                        "rx_decode_time_ms": float(_rx_decode_ms_by_rid.pop(rid, 0.0)),
+                        "rx_transfer_bytes": int(_rx_bytes_by_rid.pop(rid, 0)),
+                        "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.pop(rid, 0.0)),
                     }
                     if _stats_file:
                         compute_and_log_stage_request_stats(
@@ -1266,23 +1301,14 @@ def filter(self, record: _logging.LogRecord) -> bool:
                             "metrics": {
                                 "num_tokens_out": int(count_tokens_from_outputs(r_outputs)),
                                 "stage_gen_time_ms": _gen_ms,
-                                "rx_decode_time_ms": float(_rx_decode_ms_by_rid.get(rid, 0.0)),
-                                "rx_transfer_bytes": int(_rx_bytes_by_rid.get(rid, 0)),
-                                "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.get(rid, 0.0)),
+                                "rx_decode_time_ms": float(_rx_decode_ms_by_rid.pop(rid, 0.0)),
+                                "rx_transfer_bytes": int(_rx_bytes_by_rid.pop(rid, 0)),
+                                "rx_in_flight_time_ms": float(_in_flight_ms_by_rid.pop(rid, 0.0)),
                             },
                         }
                     )
                     _logging.getLogger(__name__).debug("[Stage-%s] Enqueued result for request %s to downstream", stage_id, rid)
-            except Exception as e:
-                _logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e)
-                out_q.put(
-                    {
-                        "request_id": rid,
-                        "stage_id": stage_id,
-                        "error": str(e),
-                    }
-                )
         print("--------------------------------", flush=True)
         print(f"[Stage-{stage_id}] Stage worker exiting", flush=True)
         print("--------------------------------", flush=True)