From d256fe331f1fc20229663627c6a018e9550a1424 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:41:18 +0200 Subject: [PATCH] feat(runnable-rails): stream metadata in RunnableRails output Enhance streaming in RunnableRails to include generation metadata in streamed chunks. Skips END_OF_STREAM markers and updates chunk formatting to support metadata for AIMessageChunk outputs. This improves compatibility with consumers expecting metadata in streaming responses. fix fix --- .../integrations/langchain/runnable_rails.py | 57 +++-- tests/runnable_rails/test_metadata.py | 72 ++++++ tests/runnable_rails/test_streaming.py | 55 ++++- tool_calling_example.py | 227 ++++++++++++++++++ 4 files changed, 395 insertions(+), 16 deletions(-) create mode 100644 tool_calling_example.py diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py index a39e8aaa1..764930e2a 100644 --- a/nemoguardrails/integrations/langchain/runnable_rails.py +++ b/nemoguardrails/integrations/langchain/runnable_rails.py @@ -218,6 +218,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]: "role": "context", "content": { "passthrough_input": _input, + # We also set all the input variables as top level context variables **(_input if isinstance(_input, dict) else {}), }, }, @@ -838,7 +839,20 @@ async def astream( streaming_enabled = True try: - async for chunk in self.rails.stream_async(messages=input_messages): + from nemoguardrails.streaming import END_OF_STREAM + + async for chunk in self.rails.stream_async( + messages=input_messages, include_generation_metadata=True + ): + # Skip END_OF_STREAM markers + chunk_text = ( + chunk["text"] + if isinstance(chunk, dict) and "text" in chunk + else chunk + ) + if chunk_text is END_OF_STREAM: + continue + # Format the chunk based on the input type for streaming formatted_chunk = self._format_streaming_chunk(input, chunk) yield formatted_chunk @@ -846,26 +860,35 @@ async def astream( if streaming_enabled and hasattr(self.rails.llm, "streaming"): self.rails.llm.streaming = original_streaming - def _format_streaming_chunk(self, input: Any, chunk: str) -> Any: + def _format_streaming_chunk(self, input: Any, chunk) -> Any: """Format a streaming chunk based on the input type. 
Args: input: The original input - chunk: The current text chunk + chunk: The current chunk (string or dict with text/generation_info) Returns: The formatted streaming chunk (using AIMessageChunk for LangChain compatibility) """ + text_content = chunk + metadata = {} + + if isinstance(chunk, dict) and "text" in chunk: + text_content = chunk["text"] + generation_info = chunk.get("generation_info", {}) + + if generation_info: + metadata = generation_info.copy() if isinstance(input, ChatPromptValue): - return AIMessageChunk(content=chunk) + return AIMessageChunk(content=text_content, **metadata) elif isinstance(input, StringPromptValue): - return chunk + return text_content # String outputs don't support metadata elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)): - return AIMessageChunk(content=chunk) + return AIMessageChunk(content=text_content, **metadata) elif isinstance(input, list) and all( isinstance(msg, BaseMessage) for msg in input ): - return AIMessageChunk(content=chunk) + return AIMessageChunk(content=text_content, **metadata) elif isinstance(input, dict): output_key = self.passthrough_bot_output_key if self.passthrough_user_input_key in input or "input" in input: @@ -873,20 +896,26 @@ def _format_streaming_chunk(self, input: Any, chunk: str) -> Any: self.passthrough_user_input_key, input.get("input") ) if isinstance(user_input, str): - return {output_key: chunk} + return {output_key: text_content} elif isinstance(user_input, list): if all( isinstance(msg, dict) and "role" in msg for msg in user_input ): - return {output_key: {"role": "assistant", "content": chunk}} + return { + output_key: {"role": "assistant", "content": text_content} + } elif all(isinstance(msg, BaseMessage) for msg in user_input): - return {output_key: AIMessageChunk(content=chunk)} - return {output_key: chunk} + return { + output_key: AIMessageChunk(content=text_content, **metadata) + } + return {output_key: text_content} elif isinstance(user_input, BaseMessage): - return {output_key: AIMessageChunk(content=chunk)} - return {output_key: chunk} + return { + output_key: AIMessageChunk(content=text_content, **metadata) + } + return {output_key: text_content} elif isinstance(input, str): - return AIMessageChunk(content=chunk) + return AIMessageChunk(content=text_content, **metadata) else: raise ValueError(f"Unexpected input type: {type(input)}") diff --git a/tests/runnable_rails/test_metadata.py b/tests/runnable_rails/test_metadata.py index 310435158..cc9d5e35b 100644 --- a/tests/runnable_rails/test_metadata.py +++ b/tests/runnable_rails/test_metadata.py @@ -373,3 +373,75 @@ def test_partial_metadata(self, mock_rails_config): assert result.content == "Test response" assert result.additional_kwargs == {"custom_field": "value"} assert result.response_metadata is None or result.response_metadata == {} + + def test_streaming_metadata_preservation(self, mock_rails_config): + """Test that streaming preserves metadata in chunks.""" + from unittest.mock import AsyncMock + + mock_rails = AsyncMock() + mock_generation_response = Mock() + mock_generation_response.response = "Streaming response" + mock_generation_response.output_data = {} + mock_generation_response.tool_calls = None + mock_generation_response.llm_metadata = { + "additional_kwargs": {"finish_reason": "stop"}, + "response_metadata": {"model_name": "test-model"}, + } + + async def mock_stream(*args, **kwargs): + chunks = [ + { + "text": "Hello ", + "generation_info": {"model": "test-model", "finish_reason": "stop"}, + }, + { + "text": "world!", + 
"generation_info": {"model": "test-model", "finish_reason": "stop"}, + }, + ] + for chunk in chunks: + yield chunk + + mock_rails.stream_async = mock_stream + + runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True) + runnable_rails.rails = mock_rails + + chunks = list(runnable_rails.stream("Test input")) + + assert len(chunks) == 2 + for chunk in chunks: + assert hasattr(chunk, "content") + assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model") + assert hasattr(chunk, "response_metadata") or hasattr( + chunk, "finish_reason" + ) + + @pytest.mark.asyncio + async def test_async_streaming_metadata_preservation(self, mock_rails_config): + """Test that async streaming preserves metadata in chunks.""" + from unittest.mock import AsyncMock + + mock_rails = AsyncMock() + + async def mock_stream(*args, **kwargs): + chunks = [ + {"text": "Async ", "generation_info": {"model": "test-model"}}, + {"text": "stream!", "generation_info": {"model": "test-model"}}, + ] + for chunk in chunks: + yield chunk + + mock_rails.stream_async = mock_stream + + runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True) + runnable_rails.rails = mock_rails + + chunks = [] + async for chunk in runnable_rails.astream("Test input"): + chunks.append(chunk) + + assert len(chunks) == 2 + for chunk in chunks: + assert hasattr(chunk, "content") + assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model") diff --git a/tests/runnable_rails/test_streaming.py b/tests/runnable_rails/test_streaming.py index 2c3615816..27ebe8f58 100644 --- a/tests/runnable_rails/test_streaming.py +++ b/tests/runnable_rails/test_streaming.py @@ -60,12 +60,10 @@ async def _astream(self, messages, stop=None, run_manager=None, **kwargs): def test_runnable_rails_basic_streaming(): """Test basic synchronous streaming functionality.""" - # Configure a streaming LLM with a response llm = StreamingFakeLLM(responses=["Hello there! 
How can I help you today?"]) config = RailsConfig.from_content(config={"models": []}) rails = RunnableRails(config, llm=llm) - # Collect chunks from the stream chunks = [] for chunk in rails.stream("Hi there"): chunks.append(chunk) @@ -451,3 +449,56 @@ def test_streaming_with_different_input_types(): ), f"Failed for {input_type}: {full_content}" assert llm.streaming == False + + +def test_streaming_metadata_preservation(): + """Test that streaming chunks preserve metadata structure.""" + llm = FakeLLM(responses=["Test response"]) + config = RailsConfig.from_content(config={"models": []}) + model_with_rails = RunnableRails(config, llm=llm) + + chunks = [] + for chunk in model_with_rails.stream("Test input"): + chunks.append(chunk) + + assert len(chunks) > 0 + + for chunk in chunks: + assert hasattr(chunk, "content") + assert hasattr(chunk, "additional_kwargs") + assert hasattr(chunk, "response_metadata") + assert isinstance(chunk.additional_kwargs, dict) + assert isinstance(chunk.response_metadata, dict) + + +@pytest.mark.asyncio +async def test_async_streaming_metadata_preservation(): + """Test that async streaming chunks preserve metadata structure.""" + llm = FakeLLM(responses=["Test async response"]) + config = RailsConfig.from_content(config={"models": []}) + model_with_rails = RunnableRails(config, llm=llm) + + chunks = [] + async for chunk in model_with_rails.astream("Test input"): + chunks.append(chunk) + + assert len(chunks) > 0 + + for chunk in chunks: + assert hasattr(chunk, "content") + assert hasattr(chunk, "additional_kwargs") + assert hasattr(chunk, "response_metadata") + assert isinstance(chunk.additional_kwargs, dict) + assert isinstance(chunk.response_metadata, dict) + + +def test_streaming_chunk_types(): + """Test that streaming returns proper AIMessageChunk types.""" + llm = FakeLLM(responses=["Hello world"]) + config = RailsConfig.from_content(config={"models": []}) + model_with_rails = RunnableRails(config, llm=llm) + + chunks = list(model_with_rails.stream("Hi")) + + for chunk in chunks: + assert chunk.__class__.__name__ == "AIMessageChunk" diff --git a/tool_calling_example.py b/tool_calling_example.py new file mode 100644 index 000000000..7a05ee854 --- /dev/null +++ b/tool_calling_example.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: Tool calling with RunnableRails using hello_world config + +This demonstrates how RunnableRails now preserves tool calls and can be used +with real LLMs and tools. 
+""" + +import asyncio +import json + +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.tools import tool +from langchain_openai import ChatOpenAI + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails +from nemoguardrails.rails.llm.options import GenerationOptions + + +# Define some simple tools - using the exact format from LangChain docs +@tool +def add(a: int, b: int) -> int: + """Adds a and b.""" + return a + b + + +@tool +def multiply(a: int, b: int) -> int: + """Multiplies a and b.""" + return a * b + + +@tool +def get_weather(location: str) -> str: + """Get the current weather for a location.""" + # Mock weather data for demo + weather_data = { + "new york": "Sunny, 72°F", + "london": "Rainy, 15°C", + "tokyo": "Cloudy, 22°C", + "san francisco": "Foggy, 60°F", + } + return weather_data.get( + location.lower(), f"Weather data not available for {location}" + ) + + +async def demo_direct_llmrails_with_tools(): + """Demo using LLMRails directly with tool calls.""" + print("🔧 Demo 1: LLMRails with Tool Calls") + + # Load the hello_world config + config = RailsConfig.from_path("./examples/configs/nemoguards") + + # Create LLM with tools bound + llm = ChatOpenAI(model="gpt-4o", temperature=0) + llm_with_tools = llm.bind_tools([add, multiply, get_weather]) + + # Create LLMRails + rails = LLMRails(config=config, llm=llm_with_tools) + + print("\n📍 Query: 'What is 3 + 5?'") + + # Generate response with tool calls + result = await rails.generate_async( + messages=[{"role": "user", "content": "What is 3 + 5?"}], + options=GenerationOptions(), # This ensures we get GenerationResponse with tool_calls + ) + + print(f"Response: {result.response}") + if result.tool_calls: + print(f"Tool calls generated: {result.tool_calls}") + print( + f"Tool call format: {type(result.tool_calls[0])} - {result.tool_calls[0]}" + ) + + # Execute the tool call - try both formats + for tool_call in result.tool_calls: + print(f"Processing tool call: {tool_call}") + + # Check if it's the standard LangChain format + if "function" in tool_call: + tool_name = tool_call["function"]["name"] + args = json.loads(tool_call["function"]["arguments"]) + print(f"✅ Using standard LangChain format: {tool_name}, {args}") + else: + # Check if it's the simplified format + tool_name = tool_call.get("name") + args = tool_call.get("args", {}) + print(f"⚠️ Using simplified format: {tool_name}, {args}") + + if tool_name == "add": + result = add.invoke(args) + print(f"Tool execution result: {result}") + elif tool_name == "multiply": + result = multiply.invoke(args) + print(f"Tool execution result: {result}") + elif tool_name == "get_weather": + result = get_weather.invoke(args) + print(f"Tool execution result: {result}") + else: + print("No tool calls generated") + + +async def demo_runnable_rails_chain(): + """Demo using RunnableRails in a simple chain.""" + print("\n🔗 Demo 2: RunnableRails in a Chain") + + # Load config + config = RailsConfig.from_path("./examples/bots/hello_world/") + + # Create LLM with tools + llm = ChatOpenAI(model="gpt-4", temperature=0) + tools = [add, multiply, get_weather] + llm_with_tools = llm.bind_tools(tools) + + # Create RunnableRails + runnable_rails = RunnableRails( + config=config, + llm=llm_with_tools, + # tools=tools, # Register tools with Rails + passthrough=True, + ) + + print("\n🧮 Query: 'What is 6 * 7?'") + + # Simple invocation + result = runnable_rails.invoke("What is 6 * 7?") + 
print(f"Result: {result}") + + # Check if tool_calls are preserved in the response + if isinstance(result, dict) and "tool_calls" in result: + print(f"Tool calls preserved: {result['tool_calls']}") + + +async def demo_multi_turn_conversation(): + """Demo a multi-turn conversation with tool calls.""" + print("\n💬 Demo 3: Multi-turn Conversation with Tools") + + config = RailsConfig.from_path("./examples/bots/hello_world/") + llm = ChatOpenAI(model="gpt-4", temperature=0) + llm_with_tools = llm.bind_tools([get_weather, calculate]) + + rails = LLMRails(config=config, llm=llm_with_tools) + + # Conversation with tool calls + messages = [ + {"role": "user", "content": "What's the weather in London and what's 15 + 27?"} + ] + + result = await rails.generate_async(messages=messages) + print(f"Assistant: {result['content']}") + + if "tool_calls" in result: + print(f"Tool calls to execute: {len(result['tool_calls'])}") + + # Add assistant message with tool calls + messages.append(result) + + # Execute tools and add tool messages + for tool_call in result["tool_calls"]: + tool_name = tool_call["name"] + tool_args = tool_call["args"] # Already a dict + + if tool_name == "get_weather": + tool_result = get_weather.invoke(tool_args) + elif tool_name == "calculate": + tool_result = calculate.invoke(tool_args) + else: + tool_result = "Unknown tool" + + messages.append( + { + "role": "tool", + "content": tool_result, + "tool_call_id": tool_call["id"], + } + ) + print(f"Tool {tool_name} result: {tool_result}") + + # Get final response + final_result = await rails.generate_async(messages=messages) + print(f"Final response: {final_result['content']}") + + +async def main(): + """Run all demos.""" + print("🚀 Tool Calling with NeMo Guardrails Examples") + print("=" * 50) + + try: + await demo_direct_llmrails_with_tools() + await demo_runnable_rails_chain() + await demo_multi_turn_conversation() + + except Exception as e: + print(f"\n⚠️ Error: {e}") + print("Make sure you have:") + print("1. Set OPENAI_API_KEY environment variable") + print("2. Installed langchain-openai: pip install langchain-openai") + + # Show what the flow would look like + print("\n💡 Expected Flow:") + print("1. User: 'What's the weather in San Francisco?'") + print("2. LLM generates tool_call for get_weather") + print("3. Tool executes → 'Foggy, 60°F'") + print("4. Final response: 'The weather in San Francisco is foggy and 60°F.'") + + +if __name__ == "__main__": + asyncio.run(main())