diff --git a/docs/api/README.md b/docs/api/README.md index 0332dceff..cbb027913 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -32,6 +32,7 @@ Input data structures for multi-modal inputs. Engine classes for offline and online inference. +- [vllm_omni.diffusion.diffusion_engine.BackgroundResources][] - [vllm_omni.diffusion.diffusion_engine.DiffusionEngine][] - [vllm_omni.engine.AdditionalInformationEntry][] - [vllm_omni.engine.AdditionalInformationPayload][] @@ -57,6 +58,7 @@ Core scheduling and caching components. Model execution components. +- [vllm_omni.model_executor.custom_process_mixin.CustomProcessMixin][] - [vllm_omni.model_executor.models.output_templates.OmniOutput][] - [vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni.Qwen2_5OmniForConditionalGeneration][] - [vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_talker.Qwen2_5OmniTalkerForConditionalGeneration][] diff --git a/docs/user_guide/examples/online_serving/qwen2_5_omni.md b/docs/user_guide/examples/online_serving/qwen2_5_omni.md index 867d44e16..98a3d4362 100644 --- a/docs/user_guide/examples/online_serving/qwen2_5_omni.md +++ b/docs/user_guide/examples/online_serving/qwen2_5_omni.md @@ -74,84 +74,19 @@ bash run_curl_multimodal_generation.sh mixed_modalities ``` ## Modality control - -You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. - -### Supported modalities - -| Modalities | Output | -|------------|--------| -| `["text"]` | Text only | -| `["audio"]` | Text + Audio | -| `["text", "audio"]` | Text + Audio | -| Not specified | Text + Audio (default) | - -### Using curl - -#### Text only - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Qwen/Qwen2.5-Omni-7B", - "messages": [{"role": "user", "content": "Describe vLLM in brief."}], - "modalities": ["text"] - }' -``` - -#### Text + Audio - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Qwen/Qwen2.5-Omni-7B", - "messages": [{"role": "user", "content": "Describe vLLM in brief."}], - "modalities": ["audio"] - }' -``` - -### Using Python client - +If you want to control output modalities, e.g. only output text, you can run the command below: ```bash python openai_chat_completion_client_for_multimodal_generation.py \ --query-type mixed_modalities \ --modalities text ``` -### Using OpenAI Python SDK - -#### Text only - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Qwen/Qwen2.5-Omni-7B", - messages=[{"role": "user", "content": "Describe vLLM in brief."}], - modalities=["text"] -) -print(response.choices[0].message.content) -``` - -#### Text + Audio - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Qwen/Qwen2.5-Omni-7B", - messages=[{"role": "user", "content": "Describe vLLM in brief."}], - modalities=["audio"] -) -# Response contains two choices: one with text, one with audio -print(response.choices[0].message.content) # Text response -print(response.choices[1].message.audio) # Audio response +## Streaming Output +If you want to enable streaming output, please set the argument as below. 
The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --stream ``` ## Run Local Web UI Demo diff --git a/docs/user_guide/examples/online_serving/qwen3_omni.md b/docs/user_guide/examples/online_serving/qwen3_omni.md index 5b6c3a7b2..2e6bb3795 100644 --- a/docs/user_guide/examples/online_serving/qwen3_omni.md +++ b/docs/user_guide/examples/online_serving/qwen3_omni.md @@ -82,84 +82,19 @@ sudo apt install ffmpeg ``` ## Modality control - -You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. - -### Supported modalities - -| Modalities | Output | -|------------|--------| -| `["text"]` | Text only | -| `["audio"]` | Text + Audio | -| `["text", "audio"]` | Text + Audio | -| Not specified | Text + Audio (default) | - -### Using curl - -#### Text only - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", - "messages": [{"role": "user", "content": "Describe vLLM in brief."}], - "modalities": ["text"] - }' -``` - -#### Text + Audio - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", - "messages": [{"role": "user", "content": "Describe vLLM in brief."}], - "modalities": ["audio"] - }' -``` - -### Using Python client - +If you want to control output modalities, e.g. only output text, you can run the command below: ```bash python openai_chat_completion_client_for_multimodal_generation.py \ --query-type use_image \ --modalities text ``` -### Using OpenAI Python SDK - -#### Text only - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Qwen/Qwen3-Omni-30B-A3B-Instruct", - messages=[{"role": "user", "content": "Describe vLLM in brief."}], - modalities=["text"] -) -print(response.choices[0].message.content) -``` - -#### Text + Audio - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Qwen/Qwen3-Omni-30B-A3B-Instruct", - messages=[{"role": "user", "content": "Describe vLLM in brief."}], - modalities=["audio"] -) -# Response contains two choices: one with text, one with audio -print(response.choices[0].message.content) # Text response -print(response.choices[1].message.audio) # Audio response +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. 
+```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --stream ``` ## Run Local Web UI Demo diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index bfd5324c5..4c116bb8b 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -377,12 +377,12 @@ def main(args): for i, prompt in enumerate(prompts): prompt["modalities"] = output_modalities - omni_outputs = omni_llm.generate(prompts, sampling_params_list) + omni_generator = omni_llm.generate(prompts, sampling_params_list) # Determine output directory: prefer --output-dir; fallback to --output-wav output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav os.makedirs(output_dir, exist_ok=True) - for stage_outputs in omni_outputs: + for stage_outputs in omni_generator: if stage_outputs.final_output_type == "text": for output in stage_outputs.request_output: request_id = output.request_id diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 753f7cc36..8eae6d941 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -233,12 +233,12 @@ def main(args): for i, prompt in enumerate(prompts): prompt["modalities"] = output_modalities - omni_outputs = omni_llm.generate(prompts, sampling_params_list) + omni_generator = omni_llm.generate(prompts, sampling_params_list) # Determine output directory: prefer --output-dir; fallback to --output-wav output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav os.makedirs(output_dir, exist_ok=True) - for stage_outputs in omni_outputs: + for stage_outputs in omni_generator: if stage_outputs.final_output_type == "text": for output in stage_outputs.request_output: request_id = output.request_id diff --git a/examples/online_serving/qwen2_5_omni/README.md b/examples/online_serving/qwen2_5_omni/README.md index ebaf6a2ca..fd010346c 100644 --- a/examples/online_serving/qwen2_5_omni/README.md +++ b/examples/online_serving/qwen2_5_omni/README.md @@ -78,6 +78,14 @@ python openai_chat_completion_client_for_multimodal_generation.py \ --modalities text ``` +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type mixed_modalities \ + --stream +``` + ## Run Local Web UI Demo This Web UI demo allows users to interact with the model through a web browser. 
diff --git a/examples/online_serving/qwen2_5_omni/gradio_demo.py b/examples/online_serving/qwen2_5_omni/gradio_demo.py index 82775ac95..a6d3a67e8 100644 --- a/examples/online_serving/qwen2_5_omni/gradio_demo.py +++ b/examples/online_serving/qwen2_5_omni/gradio_demo.py @@ -23,16 +23,18 @@ "top_p": 1.0, "top_k": -1, "max_tokens": 2048, + "seed": SEED, "detokenize": True, "repetition_penalty": 1.1, }, "talker": { - "temperature": 0.0, - "top_p": 1.0, - "top_k": -1, + "temperature": 0.9, + "top_p": 0.8, + "top_k": 40, "max_tokens": 2048, + "seed": SEED, "detokenize": True, - "repetition_penalty": 1.1, + "repetition_penalty": 1.05, "stop_token_ids": [8294], }, "code2wav": { @@ -40,6 +42,7 @@ "top_p": 1.0, "top_k": -1, "max_tokens": 2048, + "seed": SEED, "detokenize": True, "repetition_penalty": 1.1, }, @@ -241,10 +244,12 @@ def run_inference_api( video_file: str | None = None, use_audio_in_video: bool = False, output_modalities: str | None = None, + stream: bool = False, ): """Run inference using OpenAI API client with multimodal support.""" if not user_prompt.strip() and not audio_file and not image_file and not video_file: - return "Please provide at least a text prompt or multimodal input.", None + yield "Please provide at least a text prompt or multimodal input.", None + return try: # Build message content list @@ -324,7 +329,7 @@ def run_inference_api( extra_body["mm_processor_kwargs"] = mm_processor_kwargs # Parse output modalities - if output_modalities is not None: + if output_modalities and output_modalities.strip(): output_modalities_list = [m.strip() for m in output_modalities.split(",")] else: output_modalities_list = None @@ -335,29 +340,71 @@ def run_inference_api( model=model, modalities=output_modalities_list, extra_body=extra_body, + stream=stream, ) - # Extract outputs - text_outputs: list[str] = [] - audio_output = None - - for choice in chat_completion.choices: - if choice.message.content: - text_outputs.append(choice.message.content) - if choice.message.audio: - # Decode base64 audio - audio_data = base64.b64decode(choice.message.audio.data) - # Load audio from bytes - audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) - # Convert to mono if needed - if audio_np.ndim > 1: - audio_np = audio_np[:, 0] - audio_output = (int(sample_rate), audio_np.astype(np.float32)) - - text_response = "\n\n".join(text_outputs) if text_outputs else "No text output." - return text_response, audio_output + if not stream: + # Non-streaming mode: extract outputs and yield once + text_outputs: list[str] = [] + audio_output = None + + for choice in chat_completion.choices: + if choice.message.content: + text_outputs.append(choice.message.content) + if choice.message.audio: + # Decode base64 audio + audio_data = base64.b64decode(choice.message.audio.data) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + + text_response = "\n\n".join(text_outputs) if text_outputs else "No text output." 
+ yield text_response, audio_output + else: + # Streaming mode: yield incremental updates + text_content = "" + audio_output = None + + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + # Handle audio modality + if getattr(chunk, "modality", None) == "audio" and content: + try: + # Decode base64 audio + audio_data = base64.b64decode(content) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + # Yield current text and audio + yield text_content if text_content else "", audio_output + except Exception: # pylint: disable=broad-except + # If audio processing fails, just yield text + yield text_content if text_content else "", None + + # Handle text modality + elif getattr(chunk, "modality", None) == "text": + if content: + text_content += content + # Yield updated text content (keep existing audio if any) + yield text_content, audio_output + + # Final yield with accumulated text and last audio (if any) + yield text_content if text_content else "No text output.", audio_output + except Exception as exc: # pylint: disable=broad-except - return f"Inference failed: {exc}", None + error_msg = f"Inference failed: {exc}" + yield error_msg, None def build_interface( @@ -374,8 +421,10 @@ def run_inference( video_file: str | None, use_audio_in_video: bool, output_modalities: str | None = None, + stream: bool = False, ): - return run_inference_api( + # Always yield from the API function to maintain consistent generator behavior + yield from run_inference_api( client, model, sampling_params_dict, @@ -385,6 +434,7 @@ def run_inference( video_file, use_audio_in_video, output_modalities, + stream, ) css = """ @@ -455,8 +505,15 @@ def run_inference( with gr.Row(): output_modalities = gr.Textbox( label="Output Modalities", + value=None, placeholder="For example: text, image, video. 
Use comma to separate multiple modalities.", lines=1, + scale=2, + ) + stream_checkbox = gr.Checkbox( + label="Stream output", + value=False, + info="Enable streaming to see output as it's generated.", scale=1, ) @@ -474,7 +531,15 @@ def run_inference( generate_btn.click( fn=run_inference, - inputs=[input_box, audio_input, image_input, video_input, use_audio_in_video_checkbox, output_modalities], + inputs=[ + input_box, + audio_input, + image_input, + video_input, + use_audio_in_video_checkbox, + output_modalities, + stream_checkbox, + ], outputs=[text_output, audio_output], ) demo.queue() diff --git a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py index e5ef42c8c..a25e97ebf 100644 --- a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -354,19 +354,43 @@ def run_multimodal_generation(args) -> None: model=model_name, modalities=output_modalities, extra_body=extra_body, + stream=args.stream, ) count = 0 - for choice in chat_completion.choices: - if choice.message.audio: - audio_data = base64.b64decode(choice.message.audio.data) - audio_file_path = f"audio_{count}.wav" - with open(audio_file_path, "wb") as f: - f.write(audio_data) - print(f"Audio saved to {audio_file_path}") - count += 1 - elif choice.message.content: - print("Chat completion output from text:", choice.message.content) + if not args.stream: + for choice in chat_completion.choices: + if choice.message.audio: + audio_data = base64.b64decode(choice.message.audio.data) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"Audio saved to {audio_file_path}") + count += 1 + elif choice.message.content: + print("Chat completion output from text:", choice.message.content) + else: + printed_content = False + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + if getattr(chunk, "modality", None) == "audio" and content: + audio_data = base64.b64decode(content) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"\nAudio saved to {audio_file_path}") + count += 1 + + elif getattr(chunk, "modality", None) == "text": + if not printed_content: + printed_content = True + print("\ncontent:", end="", flush=True) + print(content, end="", flush=True) def parse_args(): @@ -413,6 +437,11 @@ def parse_args(): default=None, help="Output modalities to use for the prompts.", ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream the response.", + ) return parser.parse_args() diff --git a/examples/online_serving/qwen3_omni/README.md b/examples/online_serving/qwen3_omni/README.md index 41e954264..b9b825a1e 100644 --- a/examples/online_serving/qwen3_omni/README.md +++ b/examples/online_serving/qwen3_omni/README.md @@ -86,6 +86,14 @@ python openai_chat_completion_client_for_multimodal_generation.py \ --modalities text ``` +## Streaming Output +If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. 
+```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --stream +``` + ## Run Local Web UI Demo This Web UI demo allows users to interact with the model through a web browser. diff --git a/examples/online_serving/qwen3_omni/gradio_demo.py b/examples/online_serving/qwen3_omni/gradio_demo.py index eee0c75f6..76c4e311a 100644 --- a/examples/online_serving/qwen3_omni/gradio_demo.py +++ b/examples/online_serving/qwen3_omni/gradio_demo.py @@ -244,10 +244,11 @@ def run_inference_api( video_file: str | None = None, use_audio_in_video: bool = False, output_modalities: str | None = None, + stream: bool = False, ): """Run inference using OpenAI API client with multimodal support.""" if not user_prompt.strip() and not audio_file and not image_file and not video_file: - return "Please provide at least a text prompt or multimodal input.", None + yield "Please provide at least a text prompt or multimodal input.", None try: # Build message content list @@ -327,7 +328,7 @@ def run_inference_api( extra_body["mm_processor_kwargs"] = mm_processor_kwargs # Parse output modalities - if output_modalities is not None: + if output_modalities and output_modalities.strip(): output_modalities_list = [m.strip() for m in output_modalities.split(",")] else: output_modalities_list = None @@ -338,29 +339,71 @@ def run_inference_api( model=model, modalities=output_modalities_list, extra_body=extra_body, + stream=stream, ) - # Extract outputs - text_outputs: list[str] = [] - audio_output = None - - for choice in chat_completion.choices: - if choice.message.content: - text_outputs.append(choice.message.content) - if choice.message.audio: - # Decode base64 audio - audio_data = base64.b64decode(choice.message.audio.data) - # Load audio from bytes - audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) - # Convert to mono if needed - if audio_np.ndim > 1: - audio_np = audio_np[:, 0] - audio_output = (int(sample_rate), audio_np.astype(np.float32)) - - text_response = "\n\n".join(text_outputs) if text_outputs else "No text output." - return text_response, audio_output + if not stream: + # Non-streaming mode: extract outputs and yield once + text_outputs: list[str] = [] + audio_output = None + + for choice in chat_completion.choices: + if choice.message.content: + text_outputs.append(choice.message.content) + if choice.message.audio: + # Decode base64 audio + audio_data = base64.b64decode(choice.message.audio.data) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + + text_response = "\n\n".join(text_outputs) if text_outputs else "No text output." 
+ yield text_response, audio_output + else: + # Streaming mode: yield incremental updates + text_content = "" + audio_output = None + + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + # Handle audio modality + if getattr(chunk, "modality", None) == "audio" and content: + try: + # Decode base64 audio + audio_data = base64.b64decode(content) + # Load audio from bytes + audio_np, sample_rate = sf.read(io.BytesIO(audio_data)) + # Convert to mono if needed + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + audio_output = (int(sample_rate), audio_np.astype(np.float32)) + # Yield current text and audio + yield text_content if text_content else "", audio_output + except Exception: # pylint: disable=broad-except + # If audio processing fails, just yield text + yield text_content if text_content else "", None + + # Handle text modality + elif getattr(chunk, "modality", None) == "text": + if content: + text_content += content + # Yield updated text content (keep existing audio if any) + yield text_content, audio_output + + # Final yield with accumulated text and last audio (if any) + yield text_content if text_content else "No text output.", audio_output + except Exception as exc: # pylint: disable=broad-except - return f"Inference failed: {exc}", None + error_msg = f"Inference failed: {exc}" + yield error_msg, None def build_interface( @@ -377,8 +420,10 @@ def run_inference( video_file: str | None, use_audio_in_video: bool, output_modalities: str | None = None, + stream: bool = False, ): - return run_inference_api( + # Always yield from the API function to maintain consistent generator behavior + yield from run_inference_api( client, model, sampling_params_dict, @@ -388,6 +433,7 @@ def run_inference( video_file, use_audio_in_video, output_modalities, + stream, ) css = """ @@ -458,8 +504,15 @@ def run_inference( with gr.Row(): output_modalities = gr.Textbox( label="Output Modalities", + value=None, placeholder="For example: text, image, video. 
Use comma to separate multiple modalities.", lines=1, + scale=2, + ) + stream_checkbox = gr.Checkbox( + label="Stream output", + value=False, + info="Enable streaming to see output as it's generated.", scale=1, ) @@ -477,7 +530,15 @@ def run_inference( generate_btn.click( fn=run_inference, - inputs=[input_box, audio_input, image_input, video_input, use_audio_in_video_checkbox, output_modalities], + inputs=[ + input_box, + audio_input, + image_input, + video_input, + use_audio_in_video_checkbox, + output_modalities, + stream_checkbox, + ], outputs=[text_output, audio_output], ) demo.queue() diff --git a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py index 7ebf8c7f4..6438fa0fd 100644 --- a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py @@ -342,19 +342,43 @@ def run_multimodal_generation(args) -> None: model=model_name, modalities=output_modalities, extra_body=extra_body, + stream=args.stream, ) count = 0 - for choice in chat_completion.choices: - if choice.message.audio: - audio_data = base64.b64decode(choice.message.audio.data) - audio_file_path = f"audio_{count}.wav" - with open(audio_file_path, "wb") as f: - f.write(audio_data) - print(f"Audio saved to {audio_file_path}") - count += 1 - elif choice.message.content: - print("Chat completion output from text:", choice.message.content) + if not args.stream: + for choice in chat_completion.choices: + if choice.message.audio: + audio_data = base64.b64decode(choice.message.audio.data) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"Audio saved to {audio_file_path}") + count += 1 + elif choice.message.content: + print("Chat completion output from text:", choice.message.content) + else: + printed_content = False + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + if getattr(chunk, "modality", None) == "audio" and content: + audio_data = base64.b64decode(content) + audio_file_path = f"audio_{count}.wav" + with open(audio_file_path, "wb") as f: + f.write(audio_data) + print(f"\nAudio saved to {audio_file_path}") + count += 1 + + elif getattr(chunk, "modality", None) == "text": + if not printed_content: + printed_content = True + print("\ncontent:", end="", flush=True) + print(content, end="", flush=True) def parse_args(): @@ -408,6 +432,11 @@ def parse_args(): default=None, help="Output modalities to use for the prompts.", ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream the response.", + ) return parser.parse_args() diff --git a/mkdocs.yml b/mkdocs.yml index 4d4fbbb60..dea2f2566 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -113,7 +113,8 @@ plugins: - https://docs.aiohttp.org/en/stable/objects.inv - https://pillow.readthedocs.io/en/stable/objects.inv - https://numpy.org/doc/stable/objects.inv - - https://pytorch.org/docs/stable/objects.inv + # Temporarily disabled due to decompression errors + # - https://pytorch.org/docs/stable/objects.inv - https://psutil.readthedocs.io/en/stable/objects.inv markdown_extensions: diff --git a/tests/e2e/offline_inference/conftest.py b/tests/e2e/offline_inference/conftest.py index a24c63bff..26611540e 100644 --- 
a/tests/e2e/offline_inference/conftest.py +++ b/tests/e2e/offline_inference/conftest.py @@ -4,6 +4,7 @@ Pytest configuration and fixtures for vllm-omni tests. """ +from collections.abc import Generator from typing import Any import pytest @@ -189,7 +190,7 @@ def generate( self, prompts: list[dict[str, Any]], sampling_params_list: list[SamplingParams] | None = None, - ) -> list[Any]: + ) -> Generator[Any, None, None]: """ Generate outputs for the given prompts. @@ -205,7 +206,7 @@ def generate( if sampling_params_list is None: sampling_params_list = self.get_default_sampling_params_list() - return self.omni.generate(prompts, sampling_params_list) + yield from self.omni.generate(prompts, sampling_params_list) def generate_multimodal( self, @@ -217,7 +218,7 @@ def generate_multimodal( videos: PromptVideoInput = None, mm_processor_kwargs: dict[str, Any] | None = None, modalities: list[str] | None = None, - ) -> list[Any]: + ) -> Generator[Any, None, None]: """ Convenience method to generate with multimodal inputs. @@ -242,7 +243,7 @@ def generate_multimodal( mm_processor_kwargs=mm_processor_kwargs, modalities=modalities, ) - return self.generate(omni_inputs, sampling_params_list) + yield from self.generate(omni_inputs, sampling_params_list) def generate_audio( self, @@ -251,7 +252,7 @@ def generate_audio( system_prompt: str | None = None, audios: PromptAudioInput = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[Any]: + ) -> Generator[Any, None, None]: """ Convenience method to generate with multimodal inputs. Args: @@ -269,7 +270,7 @@ def generate_audio( audios=audios, mm_processor_kwargs=mm_processor_kwargs, ) - return self.generate(omni_inputs, sampling_params_list) + yield from self.generate(omni_inputs, sampling_params_list) def generate_video( self, @@ -278,7 +279,7 @@ def generate_video( system_prompt: str | None = None, videos: PromptVideoInput = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[Any]: + ) -> Generator[Any, None, None]: """ Convenience method to generate with multimodal inputs. Args: @@ -296,7 +297,7 @@ def generate_video( videos=videos, mm_processor_kwargs=mm_processor_kwargs, ) - return self.generate(omni_inputs, sampling_params_list) + yield from self.generate(omni_inputs, sampling_params_list) def generate_image( self, @@ -305,7 +306,7 @@ def generate_image( system_prompt: str | None = None, images: PromptImageInput = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[Any]: + ) -> Generator[Any, None, None]: """ Convenience method to generate with multimodal inputs. 
Args: @@ -323,7 +324,7 @@ def generate_image( images=images, mm_processor_kwargs=mm_processor_kwargs, ) - return self.generate(omni_inputs, sampling_params_list) + yield from self.generate(omni_inputs, sampling_params_list) def __enter__(self): """Context manager entry.""" diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni.py b/tests/e2e/offline_inference/test_qwen2_5_omni.py index 63eea1ba2..eba50059b 100644 --- a/tests/e2e/offline_inference/test_qwen2_5_omni.py +++ b/tests/e2e/offline_inference/test_qwen2_5_omni.py @@ -58,9 +58,6 @@ def test_mixed_modalities_to_audio(omni_runner: type[OmniRunner], test_config: t videos=video, ) - # Verify we got outputs from multiple stages - assert len(outputs) > 0 - # Find and verify text output (thinker stage) text_output = None for stage_output in outputs: @@ -113,15 +110,10 @@ def test_mixed_modalities_to_text_only(omni_runner: type[OmniRunner], test_confi modalities=modalities, ) - # Verify we got outputs from multiple stages - assert len(outputs) > 0 - - for stage_output in outputs: - assert stage_output.final_output_type != "audio" - # Find and verify text output (thinker stage) text_output = None for stage_output in outputs: + assert stage_output.final_output_type != "audio" if stage_output.final_output_type == "text": text_output = stage_output break diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py index 945b6eaef..f27af3de4 100644 --- a/tests/e2e/offline_inference/test_qwen3_omni.py +++ b/tests/e2e/offline_inference/test_qwen3_omni.py @@ -37,9 +37,6 @@ def test_video_to_audio(omni_runner: type[OmniRunner], test_config) -> None: videos=video, ) - # Verify we got outputs from multiple stages - assert len(outputs) > 0 - # Find and verify text output (thinker stage) text_output = None for stage_output in outputs: diff --git a/tests/e2e/offline_inference/utils.py b/tests/e2e/offline_inference/utils.py index c491c10b9..3113599a3 100644 --- a/tests/e2e/offline_inference/utils.py +++ b/tests/e2e/offline_inference/utils.py @@ -110,10 +110,10 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: ): exc_info = cloudpickle.load(f) - if (original_exception := exc_info.get("pickled_exception")) is not None: + original_exception = exc_info.get("pickled_exception") + if original_exception is not None and isinstance(original_exception, Exception): # Re-raise the actual exception object if it was # successfully pickled. - assert isinstance(original_exception, Exception) raise original_exception if (original_tb := exc_info.get("traceback")) is not None: diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index 63ff0e050..a9e0a370c 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -4,8 +4,10 @@ E2E Online tests for Qwen3-Omni model with video input and audio output. 
""" +import base64 import concurrent.futures import os +import signal import socket import subprocess import sys @@ -70,6 +72,7 @@ def _start_server(self) -> None: cmd, env=env, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root + start_new_session=True, # Create a new process group to enable killing all child processes ) # Wait for server to be ready @@ -95,12 +98,37 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if self.proc: - self.proc.terminate() try: - self.proc.wait(timeout=30) - except subprocess.TimeoutExpired: - self.proc.kill() - self.proc.wait() + # Get the process group ID to kill all child processes + pgid = os.getpgid(self.proc.pid) + # Ignore SIGTERM signal itself to avoid killing the test process + old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) + try: + # Terminate the entire process group (kills all child processes) + os.killpg(pgid, signal.SIGTERM) + # Wait for the process to terminate + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + # If graceful termination fails, force kill + os.killpg(pgid, signal.SIGKILL) + self.proc.wait() + finally: + # Restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + except (ProcessLookupError, OSError): + # Process group may not exist if process already terminated + # Try to clean up the process directly + try: + self.proc.terminate() + try: + self.proc.wait(timeout=10) + except subprocess.TimeoutExpired: + self.proc.kill() + self.proc.wait() + except (ProcessLookupError, OSError): + # Process already terminated, nothing to do + pass @pytest.fixture @@ -126,8 +154,6 @@ def client(omni_server): @pytest.fixture(scope="session") def base64_encoded_video() -> str: """Base64 encoded video for testing.""" - import base64 - video = VideoAsset(name="baby_reading", num_frames=4) with open(video.video_path, "rb") as f: content = f.read() @@ -220,3 +246,40 @@ def test_video_to_audio_concurrent( if hasattr(audio_message, "audio") and audio_message.audio: assert audio_message.audio.data is not None assert len(audio_message.audio.data) > 0 + + # Test streaming completion + chat_completion = client.chat.completions.create( + model=omni_server.model, + messages=messages, + stream=True, + ) + + # Collect text and audio data from stream + text_content = "" + audio_data = None + + for chunk in chat_completion: + for choice in chunk.choices: + if hasattr(choice, "delta"): + content = getattr(choice.delta, "content", None) + else: + content = None + + modality = getattr(chunk, "modality", None) + + if modality == "audio" and content: + # Audio chunk - decode base64 content + if audio_data is None: + audio_data = base64.b64decode(content) + else: + audio_data += base64.b64decode(content) + elif modality == "text" and content: + # Text chunk - accumulate text content + text_content += content if content else "" + + # Verify text output + assert text_content is not None and len(text_content) >= 2 + + # Verify audio output + assert audio_data is not None + assert len(audio_data) > 0 diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index 1b36ed509..045621cc1 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -586,7 +586,7 @@ def _fake_loader(model: str): llm = OmniLLM(model="any", init_timeout=1) with pytest.raises(ValueError): - llm.generate(prompts=["hi"], sampling_params_list=[]) + list(llm.generate(prompts=["hi"], 
sampling_params_list=[])) def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): @@ -679,7 +679,7 @@ def _fake_loader(model: str): # Use dicts instead of object() for serializable sampling params sampling_params_list = [{"temperature": 0.7}, {"temperature": 0.8}] prompts = ["hi"] - outputs = llm.generate(prompts=prompts, sampling_params_list=sampling_params_list) + outputs = list(llm.generate(prompts=prompts, sampling_params_list=sampling_params_list)) # Both stages have final_output=True, so should aggregate two OmniRequestOutput assert len(outputs) == 2 @@ -773,7 +773,7 @@ def _fake_loader(model: str): ) # Use dicts instead of object() for serializable sampling params - outputs = llm.generate(prompts=["p"], sampling_params_list=[{"temperature": 0.7}, {"temperature": 0.8}]) + outputs = list(llm.generate(prompts=["p"], sampling_params_list=[{"temperature": 0.7}, {"temperature": 0.8}])) assert outputs == [] @@ -829,7 +829,7 @@ def _fake_loader(model: str): llm = OmniLLM(model="any", init_timeout=1) with pytest.raises(ValueError): - llm.generate(prompts=["p"], sampling_params_list=None) + list(llm.generate(prompts=["p"], sampling_params_list=None)) def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): @@ -974,7 +974,7 @@ def _fake_loader(model: str): # Generate should handle error gracefully (log but continue) # Use dict instead of object() for serializable sampling params sampling_params_list = [{"temperature": 0.7}] - outputs = llm.generate(prompts=["hi"], sampling_params_list=sampling_params_list) + outputs = list(llm.generate(prompts=["hi"], sampling_params_list=sampling_params_list)) # Should return final output (error was logged but didn't stop processing) assert isinstance(outputs, list) # Since final_output=True, should have one output diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index f57a90465..640395a66 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -382,80 +382,84 @@ async def generate( logger.debug("[Orchestrator] Entering scheduling loop: stages=%d", num_stages) for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): - result = await req_state.queue.get() - assert stage_id == req_state.stage_id + finished = False + while not finished: + result = await req_state.queue.get() + assert stage_id == req_state.stage_id + + req_id = result.get("request_id") + if "error" in result: + logger.error( + "Stage %s error on request %s: %s", + stage_id, + req_id, + result["error"], + ) + raise RuntimeError(result) # Request Finished due to error - req_id = result.get("request_id") - if "error" in result: - logger.error( - "Stage %s error on request %s: %s", + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + # Mark last output time for this stage whenever we receive outputs + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) + try: + _m = result.get("metrics") + if _m is not None: + metrics.on_stage_metrics(stage_id, req_id, _m) + except Exception as e: + logger.exception( + "[Orchestrator] Failed to process metrics for stage %s, \ + req %s: %s", + stage_id, + req_id, + e, + ) + logger.debug( + "[Orchestrator] Stage-%s completed request %s; \ + forwarding or finalizing", stage_id, req_id, - result["error"], ) - raise RuntimeError(result) # Request Finished due to error + stage.set_engine_outputs(engine_outputs) - engine_outputs = _load(result, 
obj_key="engine_outputs", shm_key="engine_outputs_shm") - # Mark last output time for this stage whenever we receive outputs - metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) - try: - _m = result.get("metrics") - if _m is not None: - metrics.on_stage_metrics(stage_id, req_id, _m) - except Exception as e: - logger.exception( - "[Orchestrator] Failed to process metrics for stage %s, \ - req %s: %s", - stage_id, - req_id, - e, - ) - logger.debug( - "[Orchestrator] Stage-%s completed request %s; \ - forwarding or finalizing", - stage_id, - req_id, - ) - stage.set_engine_outputs(engine_outputs) + if isinstance(engine_outputs, list): + engine_outputs = engine_outputs[0] - if getattr(stage, "final_output", False): - logger.debug( - "[Orchestrator] Request %s finalized at stage-%s", - req_id, - stage_id, - ) + finished = engine_outputs.finished - # End-to-end timing and time-per-token for final output - # (only once per request at the designated final stage) - try: - rid_key = str(req_id) - if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done: - metrics.on_finalize_request( - stage_id, - req_id, - engine_outputs, - _req_start_ts.get(req_id, _wall_start_ts), - ) - except Exception as e: - logger.exception( - "[Orchestrator] Finalize request handling error for \ - req %s at stage %s: %s", + if getattr(stage, "final_output", False): + logger.debug( + "[Orchestrator] Request %s finalized at stage-%s", req_id, stage_id, - e, ) - if isinstance(engine_outputs, list): - engine_outputs = engine_outputs[0] - yield OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, - request_output=engine_outputs, - ) + # End-to-end timing and time-per-token for final output + # (only once per request at the designated final stage) + try: + rid_key = str(req_id) + if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done: + metrics.on_finalize_request( + stage_id, + req_id, + [engine_outputs], + _req_start_ts.get(req_id, _wall_start_ts), + ) + except Exception as e: + logger.exception( + "[Orchestrator] Finalize request handling error for \ + req %s at stage %s: %s", + req_id, + stage_id, + e, + ) + yield OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + ) # Forward to next stage if there is one next_stage_id = stage_id + 1 - if next_stage_id <= final_stage_id_for_e2e: + if next_stage_id <= final_stage_id_for_e2e and finished: next_stage: OmniStage = self.stage_list[next_stage_id] next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) sp_next: SamplingParams = sampling_params_list[next_stage_id] diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index bfca19573..d9f102900 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -2,7 +2,7 @@ import os import time import uuid -from collections.abc import Sequence +from collections.abc import Generator, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -232,7 +232,7 @@ def generate( self, prompts: PromptType | Sequence[PromptType], sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, - ) -> list[OmniRequestOutput]: + ) -> Generator[OmniRequestOutput, None, None]: """Generate outputs for the given prompts. 
Processes prompts through all stages in the pipeline and returns @@ -255,7 +255,7 @@ def generate( ValueError: If sampling_params_list is None or has incorrect length. """ try: - return self._run_generation(prompts, sampling_params_list) + yield from self._run_generation(prompts, sampling_params_list) except Exception as e: logger.exception("[Orchestrator] Failed to run generation: %s", e) raise e @@ -266,7 +266,7 @@ def _run_generation( self, prompts: PromptType | Sequence[PromptType], sampling_params_list: SamplingParams | Sequence[SamplingParams] | None = None, - ) -> list[OmniRequestOutput]: + ) -> Generator[OmniRequestOutput, None, None]: logger.debug("[Orchestrator] generate() called") if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") @@ -279,8 +279,6 @@ def _run_generation( else: request_prompts = list(prompts) - final_outputs: list[OmniRequestOutput] = [] - # Orchestrator keeps stage objects for input derivation num_stages = len(self.stage_list) @@ -389,13 +387,6 @@ def _run_generation( stage.set_engine_outputs(engine_outputs) if getattr(stage, "final_output", False): - final_outputs.append( - OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, # type: ignore[attr-defined] - request_output=engine_outputs, - ) - ) logger.debug( "[Orchestrator] Request %s finalized at stage-%s", req_id, @@ -421,6 +412,12 @@ def _run_generation( e, ) + yield OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, # type: ignore[attr-defined] + request_output=engine_outputs, + ) + next_stage_id = stage_id + 1 if next_stage_id <= final_stage_id_to_prompt[req_id]: next_stage: OmniStage = self.stage_list[next_stage_id] @@ -484,8 +481,6 @@ def _run_generation( except Exception as e: logger.exception("[Orchestrator] Failed to build/log summary: %s", e) - return final_outputs - def _wait_for_stages_ready(self, timeout: int = 120) -> None: deadline = time.time() + max(0, int(timeout)) num_stages = len(self.stage_list) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 503c8395c..b56039e35 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -1175,11 +1175,12 @@ async def generation_single_request(task: dict[str, Any]): _gen_t0 = _time.time() if isinstance(ein, list): ein = ein[0] + async for res in stage_engine.generate(ein, sampling_params, rid): gen_output = res - _gen_t1 = _time.time() - _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 - await generation_out_q.put((rid, gen_output, _gen_ms)) + _gen_t1 = _time.time() + _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 + await generation_out_q.put((rid, gen_output, _gen_ms)) except Exception as e: _logging.getLogger(__name__).exception("[Stage-%s] Failed on request %s: %s", stage_id, rid, e) out_q.put( @@ -1236,7 +1237,6 @@ async def generation_single_request(task: dict[str, Any]): for rid, _gen_ms in zip(batch_request_ids, _gen_ms_list): log_stage_batch_stats(_stats_file, stage_id, 1, float(_gen_ms), [rid]) - logger.info("[Stage-%s] Sending outputs to main process", stage_id) for rid, output, _gen_ms in zip(batch_request_ids, batch_request_outputs, _gen_ms_list): try: r_outputs = [output] diff --git a/vllm_omni/entrypoints/openai/protocol/__init__.py b/vllm_omni/entrypoints/openai/protocol/__init__.py index b17b648eb..da65e1817 100644 --- a/vllm_omni/entrypoints/openai/protocol/__init__.py +++ b/vllm_omni/entrypoints/openai/protocol/__init__.py @@ -1,6 +1,7 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm_omni.entrypoints.openai.protocol.chat_completion import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.images import ( ImageData, ImageGenerationRequest, @@ -13,4 +14,5 @@ "ImageGenerationRequest", "ImageGenerationResponse", "ResponseFormat", + "OmniChatCompletionStreamResponse", ] diff --git a/vllm_omni/entrypoints/openai/protocol/chat_completion.py b/vllm_omni/entrypoints/openai/protocol/chat_completion.py new file mode 100644 index 000000000..9f6076249 --- /dev/null +++ b/vllm_omni/entrypoints/openai/protocol/chat_completion.py @@ -0,0 +1,5 @@ +from vllm.entrypoints.openai.protocol import ChatCompletionStreamResponse + + +class OmniChatCompletionStreamResponse(ChatCompletionStreamResponse): + modality: str | None = "text" diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 422a38788..136ec0c78 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Final, Optional import jinja2 from fastapi import Request @@ -29,13 +29,18 @@ make_tool_call_id, resolve_chat_template_content_format, ) -from vllm.entrypoints.harmony_utils import parse_chat_output +from vllm.entrypoints.harmony_utils import get_streamable_parser_for_assistant, parse_chat_output from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition, @@ -56,6 +61,8 @@ ) from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls +from vllm.entrypoints.utils import should_include_usage from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -68,9 +75,11 @@ truncate_tool_call_ids, validate_request_params, ) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.collection_utils import as_list from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures +from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.outputs import OmniRequestOutput if TYPE_CHECKING: @@ -224,6 +233,9 @@ async def create_chat_completion( raw_request.state.request_metadata = request_metadata output_modalities = getattr(request, "modalities", self.engine_client.output_modalities) + request.modalities = ( + output_modalities if output_modalities is not None else self.engine_client.output_modalities + ) # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[RequestOutput, None]] = [] @@ -264,7 +276,15 @@ async def create_chat_completion( # Streaming response if request.stream: - raise RuntimeError("Not support streaming output now.") + return self.chat_completion_stream_generator( + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + ) try: return await self.chat_completion_full_generator( @@ -514,6 +534,732 @@ def _log_inputs( lora_request, ) + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ): + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration_dict = {} + assert hasattr(request, "modalities") and request.modalities is not None, ( + "Streaming request must specify output modalities" + ) + for modality in request.modalities: + first_iteration_dict[modality] = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + num_cached_tokens = None + if self.use_harmony: + harmony_parsers = [get_streamable_parser_for_assistant() for _ in range(num_choices)] + harmony_tools_streamed = [False] * num_choices + tools_streamed = [False] * num_choices + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = not tool_choice_function_name and self._should_stream_with_auto_tool_parsing(request) + + all_previous_token_ids: list[list[int]] | None + function_name_returned = [False] * num_choices + if self.tool_call_id_type == "kimi_k2": + history_tool_call_cnt = get_history_tool_calls_cnt(conversation) + else: + history_tool_call_cnt = 0 + + # Always track previous_texts for comprehensive output logging + previous_texts = [""] * num_choices + + # Only one of these will be used, thus previous_texts and + # all_previous_token_ids will not be used twice in the same iteration. 
+ if tool_choice_auto or self.reasoning_parser: + # These are only required in "auto" tool choice case + all_previous_token_ids = [[]] * num_choices + # For reasoning parser and tool call all enabled + added_content_delta_arr = [False] * num_choices + reasoning_end_arr = [False] * num_choices + else: + all_previous_token_ids = None + + try: + if self.reasoning_parser: + reasoning_parser = self.reasoning_parser( + tokenizer, + chat_template_kwargs=request.chat_template_kwargs, # type: ignore + ) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: list[ToolParser | None] = [self.tool_parser(tokenizer)] * num_choices + else: + tool_parsers = [None] * num_choices + except Exception as e: + logger.exception("Error in tool parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + stream_options = request.stream_options + include_usage, include_continuous_usage = should_include_usage(stream_options, self.enable_force_include_usage) + + try: + async for omni_res in result_generator: + final_output_type = omni_res.final_output_type + res = omni_res.request_output + if final_output_type not in first_iteration_dict: + logger.warning(f"final output type: {final_output_type} is not needed by the request") + continue + + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). 
+ if first_iteration_dict[final_output_type] and final_output_type == "text": + num_cached_tokens = res.num_cached_tokens + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None, + ) + + # return prompt_token_ids at the first chunk ever + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + prompt_token_ids=(res.prompt_token_ids if request.return_token_ids else None), + modality=final_output_type, + ) + + # if continuous usage stats are requested, add it + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo: + last_msg_content: str | list[dict[str, str]] = "" + if conversation and "content" in conversation[-1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + logprobs=None, + finish_reason=None, + ) + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + modality=final_output_type, + ) + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration_dict[final_output_type] = False + + if final_output_type == "text": + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + if self.use_harmony: + harmony_parser = harmony_parsers[i] + prev_recipient = harmony_parser.current_recipient + delta_text = "" + for token_id in output.token_ids: + harmony_parser.process(token_id) + delta_text += harmony_parser.last_content_delta or "" + cur_channel = harmony_parser.current_channel + cur_recipient = harmony_parser.current_recipient + else: + # output.text is cumulative, extract only the delta portion + previous_text = previous_texts[i] if previous_texts else "" + if output.text is not None: + delta_text = output.text[len(previous_text) :] + else: + delta_text = "" + + if not delta_text and not output.token_ids and not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + delta_message: DeltaMessage | None + + # just update previous_texts and previous_token_ids + if tool_choice_auto or self.reasoning_parser: + assert previous_texts is 
not None + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + # avoid the None + list error. + if previous_token_ids: + current_token_ids = previous_token_ids + as_list(output.token_ids) + else: + current_token_ids = as_list(output.token_ids) + + if self.use_harmony: + if cur_channel == "final": + delta_message = DeltaMessage(content=delta_text) + elif cur_channel == "analysis": + if request.include_reasoning: + delta_message = DeltaMessage(reasoning=delta_text) + else: + delta_message = None + elif ( + cur_channel == "commentary" and cur_recipient and cur_recipient.startswith("functions.") + ): + # Count completed tool calls to determine index + base_index = 0 + for msg in harmony_parser.messages: + if ( + msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith("functions.") + ): + base_index += 1 + + if prev_recipient != cur_recipient: + tool_name = cur_recipient.split("functions.", 1)[1] + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ] + ) + elif delta_text: + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall(arguments=delta_text), + ) + ] + ) + else: + delta_message = None + + if delta_message is not None: + harmony_tools_streamed[i] = True + else: + delta_message = None + # handle streaming deltas for tools with named tool_choice + elif tool_choice_function_name: + if ( + self.reasoning_parser + and not reasoning_end_arr[i] + and not reasoning_parser.is_reasoning_end(previous_token_ids) + ): + assert reasoning_parser is not None + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) + # When encountering think end id in delta_token_ids + # or think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Only keep 'content', remove 'reasoning'. 
+ if reasoning_parser.is_reasoning_end(as_list(output.token_ids)) or ( + res.prompt_token_ids and reasoning_parser.is_reasoning_end(res.prompt_token_ids) + ): + reasoning_end_arr[i] = True + if delta_message and delta_message.content: + # This need to be added to next `delta_text` + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + else: + # Just to add remaining `content` + if self.reasoning_parser: + delta_text = previous_text + delta_text + current_text = "" + + if function_name_returned[i]: + delta_tool_call = DeltaToolCall( + function=DeltaFunctionCall(arguments=delta_text), + index=i, + ) + else: + delta_tool_call = DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text, + ), + index=i, + ) + function_name_returned[i] = True + + delta_message = DeltaMessage( + tool_calls=[ + delta_tool_call, + ] + ) + tools_streamed[i] = True + + elif request.tool_choice == "required": + assert previous_texts is not None + previous_text = previous_texts[i] + current_text = previous_text + delta_text + fn_name_returned = function_name_returned[i] + output_token_ids = as_list(output.token_ids) + + if ( + self.reasoning_parser is not None + and not reasoning_end_arr[i] + and res.prompt_token_ids + and reasoning_parser.is_reasoning_end(res.prompt_token_ids) + ): + reasoning_end_arr[i] = True + + if self.reasoning_parser and not reasoning_end_arr[i]: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + # reasoning ended + current_text = "" + + else: + # either finished reasoning or no reasoning at all + content = current_text + + delta_message, function_name_returned[i] = self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt, + ) + if ( + delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0].id is not None + ): + history_tool_call_cnt += 1 + tools_streamed[i] = True + + # handle streaming deltas for tools with "auto" tool choice + # and reasoning parser + elif tool_choice_auto and self.reasoning_parser: + assert tool_parser is not None + assert reasoning_parser is not None + assert added_content_delta_arr is not None + assert reasoning_end_arr is not None + output_token_ids = as_list(output.token_ids) + if not reasoning_end_arr[i]: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) + # When encountering think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning'. 
+ if res.prompt_token_ids and reasoning_parser.is_reasoning_end(res.prompt_token_ids): + reasoning_end_arr[i] = True + current_token_ids = output_token_ids + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning'. + if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + current_token_ids = reasoning_parser.extract_content_ids(output_token_ids) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + # handle tool calls only after reasoning is done, + else: + delta_token_ids = output_token_ids + # First time to tool call, + # add the remaining text and token ids + # to delta from previous + if not added_content_delta_arr[i]: + added_content_delta_arr[i] = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + + # when only reasoning + elif self.reasoning_parser: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # update the previous values for the next iteration + if (tool_choice_auto or self.reasoning_parser) and not self.use_harmony: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + else: + # Update for comprehensive logging even in simple case + assert previous_texts is not None + previous_texts[i] += delta_text + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. 
because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + if output.finish_reason is None: + continue + else: + delta_message = DeltaMessage() + + # Log streaming delta if output logging is enabled + if self.enable_log_outputs and self.request_logger: + delta_content = "" + if delta_message.content: + delta_content = delta_message.content + elif delta_message.tool_calls: + delta_content = "".join( + tc.function.arguments + for tc in delta_message.tool_calls + if tc.function and tc.function.arguments + ) + + if delta_content: + self.request_logger.log_outputs( + request_id=request_id, + outputs=delta_content, + output_token_ids=as_list(output.token_ids), + finish_reason=output.finish_reason, + is_streaming=True, + delta=True, + ) + + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using structured outputs + auto_tools_called = False + if tool_parser: + auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens(delta_message, output) and tool_parser: + latest_delta_len = 0 + if ( + isinstance( + delta_message.tool_calls[0].function, + DeltaFunctionCall, + ) + ) and isinstance(delta_message.tool_calls[0].function.arguments, str): + latest_delta_len = len(delta_message.tool_calls[0].function.arguments) + + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get("arguments", {}), + ensure_ascii=False, + ) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace(actual_call, "", 1) + # set that as a delta message + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=index, + function=DeltaFunctionCall(arguments=remaining_call).model_dump( + exclude_none=True + ), + ) + ] + ) + + # Send the finish response for each request.n only once + # In OpenAI's API, when a tool is called, the + # finish_reason is: + # "tool_calls" for "auto" or "required" tool calls, + # and "stop" for named tool calls. 
+ if ( + auto_tools_called + or (tools_streamed[i] and not tool_choice_function_name) + or (self.use_harmony and harmony_tools_streamed[i]) + ): + finish_reason_ = "tool_calls" + else: + finish_reason_ = output.finish_reason if output.finish_reason else "stop" + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=finish_reason_, + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + + finish_reason_sent[i] = True + + choice_data = maybe_filter_parallel_tool_calls(choice_data, request) + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + modality=final_output_type, + ) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + elif final_output_type == "audio": + choices_data = self._create_audio_choice(omni_res, role, request, stream=True) + chunk = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=choices_data, + model=model_name, + modality=final_output_type, + ) + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + else: + logger.warning(f"Unsupported streaming final output type: {final_output_type}") + continue + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if include_usage: + completion_tokens = sum(previous_num_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + if self.enable_prompt_tokens_details and num_cached_tokens: + final_usage.prompt_tokens_details = PromptTokenUsageInfo(cached_tokens=num_cached_tokens) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json(exclude_unset=True, exclude_none=True) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens, + ) + + # Log complete streaming response if output logging is enabled + if self.enable_log_outputs and self.request_logger: + # Log the complete response for each choice + for i in range(num_choices): + full_text = ( + previous_texts[i] + if previous_texts and i < len(previous_texts) + else f"" + ) + self.request_logger.log_outputs( + request_id=request_id, + outputs=full_text, + output_token_ids=None, # Consider also logging all token IDs + finish_reason="streaming_complete", + is_streaming=True, + delta=False, + ) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in chat completion stream 
generator.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + async def chat_completion_full_generator( self, request: ChatCompletionRequest, @@ -549,6 +1295,9 @@ async def chat_completion_full_generator( for omni_outputs in final_outputs: choices_data = [] + if omni_outputs.request_output is not None and not getattr(omni_outputs.request_output, "finished", False): + continue + if omni_outputs.final_output_type == "text": ( choices_data, @@ -558,9 +1307,9 @@ async def chat_completion_full_generator( kv_transfer_params, ) = self._create_text_choice(request, omni_outputs, tokenizer, conversation, role) elif omni_outputs.final_output_type == "audio": - choices_data = self._create_audio_choice(omni_outputs, role) + choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False) elif omni_outputs.final_output_type == "image": - choices_data = self._create_image_choice(omni_outputs, role) + choices_data = self._create_image_choice(omni_outputs, role, request, stream=False) else: logger.warning(f"Unsupported final output type: {omni_outputs.final_output_type}") continue @@ -853,7 +1602,9 @@ def _create_text_choice( return choices, usage, prompt_logprobs, prompt_token_ids, kv_transfer_params - def _create_audio_choice(self, omni_outputs: OmniRequestOutput, role: str): + def _create_audio_choice( + self, omni_outputs: OmniRequestOutput, role: str, request: ChatCompletionRequest, stream: bool = False + ): choices: list[ChatCompletionResponseChoice] = [] final_res = omni_outputs.request_output audio_tensor = final_res.multimodal_output["audio"].float().detach().cpu().numpy() @@ -894,17 +1645,29 @@ def _create_audio_choice(self, omni_outputs: OmniRequestOutput, role: str): ) for output in final_res.outputs: - choice_data = ChatCompletionResponseChoice( - index=output.index, - message=ChatMessage(role=role, audio=audio_obj), - logprobs=None, - finish_reason="stop", - stop_reason=None, - ) + if stream: + choice_data = ChatCompletionResponseStreamChoice( + index=output.index, + delta=DeltaMessage(role=role, content=audio_base64), + logprobs=None, + finish_reason="stop", + stop_reason=output.stop_reason, + token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + ) + else: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, audio=audio_obj), + logprobs=None, + finish_reason="stop", + stop_reason=None, + ) choices.append(choice_data) return choices - def _create_image_choice(self, omni_outputs: OmniRequestOutput, role: str): + def _create_image_choice( + self, omni_outputs: OmniRequestOutput, role: str, request: ChatCompletionRequest, stream: bool = False + ): """Create chat completion response choices for image output. Converts image tensor or PIL Image output from diffusion models diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 2b1759f71..429c5ed54 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -421,10 +421,6 @@ def forward( ) ) - # # Remove EOS token if present - # if code[-1] == TALKER_CODEC_EOS_TOKEN_ID: - # code = code[:-1] - # Generate audio from codec codes audio_tensors = [] for code in codes: