Commit a0c0e24

feat(runnable-rails): stream metadata in RunnableRails output
Enhance streaming in RunnableRails to include generation metadata in streamed chunks. END_OF_STREAM markers are skipped, and chunk formatting is updated to attach metadata to AIMessageChunk outputs. This improves compatibility with consumers that expect metadata in streaming responses.
1 parent 52a42c5 commit a0c0e24
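
For context, a minimal consumer-side sketch (not part of the commit) of what this change enables; the config path and prompt are illustrative assumptions:

import asyncio

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

async def main():
    # Hypothetical guardrails config directory; adjust to your setup.
    config = RailsConfig.from_path("./config")
    guarded = RunnableRails(config)

    # After this commit, streamed AIMessageChunk outputs can carry
    # generation metadata alongside the text content.
    async for chunk in guarded.astream("Hello!"):
        print(chunk.content, getattr(chunk, "additional_kwargs", {}))

asyncio.run(main())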

File tree

4 files changed: +395 −16 lines changed

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 43 additions & 14 deletions
@@ -218,6 +218,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
                 "role": "context",
                 "content": {
                     "passthrough_input": _input,
+                    # We also set all the input variables as top level context variables
                     **(_input if isinstance(_input, dict) else {}),
                 },
             },
@@ -838,55 +839,83 @@ async def astream(
         streaming_enabled = True

         try:
-            async for chunk in self.rails.stream_async(messages=input_messages):
+            from nemoguardrails.streaming import END_OF_STREAM
+
+            async for chunk in self.rails.stream_async(
+                messages=input_messages, include_generation_metadata=True
+            ):
+                # Skip END_OF_STREAM markers
+                chunk_text = (
+                    chunk["text"]
+                    if isinstance(chunk, dict) and "text" in chunk
+                    else chunk
+                )
+                if chunk_text is END_OF_STREAM:
+                    continue
+
                 # Format the chunk based on the input type for streaming
                 formatted_chunk = self._format_streaming_chunk(input, chunk)
                 yield formatted_chunk
         finally:
             if streaming_enabled and hasattr(self.rails.llm, "streaming"):
                 self.rails.llm.streaming = original_streaming

-    def _format_streaming_chunk(self, input: Any, chunk: str) -> Any:
+    def _format_streaming_chunk(self, input: Any, chunk) -> Any:
         """Format a streaming chunk based on the input type.

         Args:
             input: The original input
-            chunk: The current text chunk
+            chunk: The current chunk (string or dict with text/generation_info)

         Returns:
             The formatted streaming chunk (using AIMessageChunk for LangChain compatibility)
         """
+        text_content = chunk
+        metadata = {}
+
+        if isinstance(chunk, dict) and "text" in chunk:
+            text_content = chunk["text"]
+            generation_info = chunk.get("generation_info", {})
+
+            if generation_info:
+                metadata = generation_info.copy()
         if isinstance(input, ChatPromptValue):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, StringPromptValue):
-            return chunk
+            return text_content  # String outputs don't support metadata
         elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, list) and all(
             isinstance(msg, BaseMessage) for msg in input
         ):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, dict):
             output_key = self.passthrough_bot_output_key
             if self.passthrough_user_input_key in input or "input" in input:
                 user_input = input.get(
                     self.passthrough_user_input_key, input.get("input")
                 )
                 if isinstance(user_input, str):
-                    return {output_key: chunk}
+                    return {output_key: text_content}
                 elif isinstance(user_input, list):
                     if all(
                         isinstance(msg, dict) and "role" in msg for msg in user_input
                     ):
-                        return {output_key: {"role": "assistant", "content": chunk}}
+                        return {
+                            output_key: {"role": "assistant", "content": text_content}
+                        }
                     elif all(isinstance(msg, BaseMessage) for msg in user_input):
-                        return {output_key: AIMessageChunk(content=chunk)}
-                    return {output_key: chunk}
+                        return {
+                            output_key: AIMessageChunk(content=text_content, **metadata)
+                        }
+                    return {output_key: text_content}
                 elif isinstance(user_input, BaseMessage):
-                    return {output_key: AIMessageChunk(content=chunk)}
-            return {output_key: chunk}
+                    return {
+                        output_key: AIMessageChunk(content=text_content, **metadata)
+                    }
+            return {output_key: text_content}
         elif isinstance(input, str):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         else:
             raise ValueError(f"Unexpected input type: {type(input)}")

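For reference, the two chunk shapes _format_streaming_chunk now handles; only the keys come from the diff above, the values are illustrative:

# Pre-existing behavior: a bare text chunk.
plain_chunk = "Hello "

# New shape when include_generation_metadata=True: generation_info is
# copied and expanded into the AIMessageChunk keyword arguments.
metadata_chunk = {
    "text": "Hello ",
    "generation_info": {"model": "some-model", "finish_reason": "stop"},
}
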
tests/runnable_rails/test_metadata.py

Lines changed: 72 additions & 0 deletions
@@ -373,3 +373,75 @@ def test_partial_metadata(self, mock_rails_config):
         assert result.content == "Test response"
         assert result.additional_kwargs == {"custom_field": "value"}
         assert result.response_metadata is None or result.response_metadata == {}
+
+    def test_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+        mock_generation_response = Mock()
+        mock_generation_response.response = "Streaming response"
+        mock_generation_response.output_data = {}
+        mock_generation_response.tool_calls = None
+        mock_generation_response.llm_metadata = {
+            "additional_kwargs": {"finish_reason": "stop"},
+            "response_metadata": {"model_name": "test-model"},
+        }
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {
+                    "text": "Hello ",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+                {
+                    "text": "world!",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = list(runnable_rails.stream("Test input"))
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")
+            assert hasattr(chunk, "response_metadata") or hasattr(
+                chunk, "finish_reason"
+            )
+
+    @pytest.mark.asyncio
+    async def test_async_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that async streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {"text": "Async ", "generation_info": {"model": "test-model"}},
+                {"text": "stream!", "generation_info": {"model": "test-model"}},
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = []
+        async for chunk in runnable_rails.astream("Test input"):
+            chunks.append(chunk)
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")

tests/runnable_rails/test_streaming.py

Lines changed: 53 additions & 2 deletions
@@ -60,12 +60,10 @@ async def _astream(self, messages, stop=None, run_manager=None, **kwargs):

 def test_runnable_rails_basic_streaming():
     """Test basic synchronous streaming functionality."""
-    # Configure a streaming LLM with a response
     llm = StreamingFakeLLM(responses=["Hello there! How can I help you today?"])
     config = RailsConfig.from_content(config={"models": []})
     rails = RunnableRails(config, llm=llm)

-    # Collect chunks from the stream
     chunks = []
     for chunk in rails.stream("Hi there"):
         chunks.append(chunk)
@@ -451,3 +449,56 @@ def test_streaming_with_different_input_types():
     ), f"Failed for {input_type}: {full_content}"

     assert llm.streaming == False
+
+
+def test_streaming_metadata_preservation():
+    """Test that streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    for chunk in model_with_rails.stream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+@pytest.mark.asyncio
+async def test_async_streaming_metadata_preservation():
+    """Test that async streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test async response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    async for chunk in model_with_rails.astream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+def test_streaming_chunk_types():
+    """Test that streaming returns proper AIMessageChunk types."""
+    llm = FakeLLM(responses=["Hello world"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = list(model_with_rails.stream("Hi"))
+
+    for chunk in chunks:
+        assert chunk.__class__.__name__ == "AIMessageChunk"
