57 changes: 43 additions & 14 deletions nemoguardrails/integrations/langchain/runnable_rails.py
@@ -218,6 +218,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
                 "role": "context",
                 "content": {
                     "passthrough_input": _input,
+                    # We also set all the input variables as top level context variables
                     **(_input if isinstance(_input, dict) else {}),
                 },
             },
@@ -838,55 +839,83 @@ async def astream(
         streaming_enabled = True
 
         try:
-            async for chunk in self.rails.stream_async(messages=input_messages):
+            from nemoguardrails.streaming import END_OF_STREAM
+
+            async for chunk in self.rails.stream_async(
+                messages=input_messages, include_generation_metadata=True
+            ):
+                # Skip END_OF_STREAM markers
+                chunk_text = (
+                    chunk["text"]
+                    if isinstance(chunk, dict) and "text" in chunk
+                    else chunk
+                )
+                if chunk_text is END_OF_STREAM:
+                    continue
+
                 # Format the chunk based on the input type for streaming
                 formatted_chunk = self._format_streaming_chunk(input, chunk)
                 yield formatted_chunk
         finally:
             if streaming_enabled and hasattr(self.rails.llm, "streaming"):
                 self.rails.llm.streaming = original_streaming
 
-    def _format_streaming_chunk(self, input: Any, chunk: str) -> Any:
+    def _format_streaming_chunk(self, input: Any, chunk) -> Any:
         """Format a streaming chunk based on the input type.
 
         Args:
             input: The original input
-            chunk: The current text chunk
+            chunk: The current chunk (string or dict with text/generation_info)
 
         Returns:
             The formatted streaming chunk (using AIMessageChunk for LangChain compatibility)
         """
+        text_content = chunk
+        metadata = {}
+
+        if isinstance(chunk, dict) and "text" in chunk:
+            text_content = chunk["text"]
+            generation_info = chunk.get("generation_info", {})
+
+            if generation_info:
+                metadata = generation_info.copy()
         if isinstance(input, ChatPromptValue):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, StringPromptValue):
-            return chunk
+            return text_content  # String outputs don't support metadata
         elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, list) and all(
             isinstance(msg, BaseMessage) for msg in input
         ):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, dict):
             output_key = self.passthrough_bot_output_key
             if self.passthrough_user_input_key in input or "input" in input:
                 user_input = input.get(
                     self.passthrough_user_input_key, input.get("input")
                 )
                 if isinstance(user_input, str):
-                    return {output_key: chunk}
+                    return {output_key: text_content}
                 elif isinstance(user_input, list):
                     if all(
                         isinstance(msg, dict) and "role" in msg for msg in user_input
                     ):
-                        return {output_key: {"role": "assistant", "content": chunk}}
+                        return {
+                            output_key: {"role": "assistant", "content": text_content}
+                        }
                     elif all(isinstance(msg, BaseMessage) for msg in user_input):
-                        return {output_key: AIMessageChunk(content=chunk)}
-                    return {output_key: chunk}
+                        return {
+                            output_key: AIMessageChunk(content=text_content, **metadata)
+                        }
+                    return {output_key: text_content}
                 elif isinstance(user_input, BaseMessage):
-                    return {output_key: AIMessageChunk(content=chunk)}
-            return {output_key: chunk}
+                    return {
+                        output_key: AIMessageChunk(content=text_content, **metadata)
+                    }
+            return {output_key: text_content}
         elif isinstance(input, str):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         else:
             raise ValueError(f"Unexpected input type: {type(input)}")
 
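For orientation, the reworked `astream` consumes the richer stream that `stream_async` produces when `include_generation_metadata=True`: each chunk is a dict carrying `text` plus optional `generation_info`, and the stream is terminated by an `END_OF_STREAM` sentinel that must be skipped. A minimal consumer-side sketch of that protocol, assuming `rails` is a configured `LLMRails` instance (the message format is the standard nemoguardrails one):

    from nemoguardrails.streaming import END_OF_STREAM

    async def consume(rails):
        async for chunk in rails.stream_async(
            messages=[{"role": "user", "content": "Hi"}],
            include_generation_metadata=True,
        ):
            # Unwrap dict chunks and skip the end-of-stream marker, as in the diff above.
            text = chunk["text"] if isinstance(chunk, dict) and "text" in chunk else chunk
            if text is END_OF_STREAM:
                continue
            info = chunk.get("generation_info", {}) if isinstance(chunk, dict) else {}
            print(text, info)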
72 changes: 72 additions & 0 deletions tests/runnable_rails/test_metadata.py
@@ -373,3 +373,75 @@ def test_partial_metadata(self, mock_rails_config):
         assert result.content == "Test response"
         assert result.additional_kwargs == {"custom_field": "value"}
         assert result.response_metadata is None or result.response_metadata == {}
+
+    def test_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+        mock_generation_response = Mock()
+        mock_generation_response.response = "Streaming response"
+        mock_generation_response.output_data = {}
+        mock_generation_response.tool_calls = None
+        mock_generation_response.llm_metadata = {
+            "additional_kwargs": {"finish_reason": "stop"},
+            "response_metadata": {"model_name": "test-model"},
+        }
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {
+                    "text": "Hello ",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+                {
+                    "text": "world!",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = list(runnable_rails.stream("Test input"))
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")
+            assert hasattr(chunk, "response_metadata") or hasattr(
+                chunk, "finish_reason"
+            )
+
+    @pytest.mark.asyncio
+    async def test_async_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that async streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {"text": "Async ", "generation_info": {"model": "test-model"}},
+                {"text": "stream!", "generation_info": {"model": "test-model"}},
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = []
+        async for chunk in runnable_rails.astream("Test input"):
+            chunks.append(chunk)
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")
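At the LangChain level, the same flow is what these tests exercise: string inputs come back as `AIMessageChunk` objects, so metadata can travel with each chunk. A minimal usage sketch, assuming `llm` is any streaming-capable LangChain chat model:

    from nemoguardrails import RailsConfig
    from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

    config = RailsConfig.from_content(config={"models": []})
    model_with_rails = RunnableRails(config, llm=llm)

    for chunk in model_with_rails.stream("Test input"):
        # Each chunk is an AIMessageChunk; which metadata fields it carries
        # depends on the generation_info the rails emit.
        print(chunk.content, chunk.additional_kwargs, chunk.response_metadata)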
55 changes: 53 additions & 2 deletions tests/runnable_rails/test_streaming.py
@@ -60,12 +60,10 @@ async def _astream(self, messages, stop=None, run_manager=None, **kwargs):
 
 def test_runnable_rails_basic_streaming():
     """Test basic synchronous streaming functionality."""
-    # Configure a streaming LLM with a response
     llm = StreamingFakeLLM(responses=["Hello there! How can I help you today?"])
     config = RailsConfig.from_content(config={"models": []})
     rails = RunnableRails(config, llm=llm)
 
-    # Collect chunks from the stream
     chunks = []
     for chunk in rails.stream("Hi there"):
         chunks.append(chunk)
@@ -451,3 +449,56 @@ def test_streaming_with_different_input_types():
         ), f"Failed for {input_type}: {full_content}"
 
     assert llm.streaming == False
+
+
+def test_streaming_metadata_preservation():
+    """Test that streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    for chunk in model_with_rails.stream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_streaming_metadata_preservation():
+    """Test that async streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test async response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    async for chunk in model_with_rails.astream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+def test_streaming_chunk_types():
+    """Test that streaming returns proper AIMessageChunk types."""
+    llm = FakeLLM(responses=["Hello world"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = list(model_with_rails.stream("Hi"))
+
+    for chunk in chunks:
+        assert chunk.__class__.__name__ == "AIMessageChunk"
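A note on downstream use: `AIMessageChunk` implements `__add__` (content concatenates, metadata dicts merge), which is how LCEL accumulates a streamed response. A short sketch, assuming `model_with_rails` is built as in the test above:

    final = None
    for chunk in model_with_rails.stream("Hi"):
        # Summing chunks rebuilds the full message from the stream.
        final = chunk if final is None else final + chunk

    print(final.content)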