Commit f8004f8
feat(runnable-rails): stream metadata in RunnableRails output
Enhance streaming in RunnableRails to include generation metadata in the streamed chunks. Skip END_OF_STREAM markers and update chunk formatting so that AIMessageChunk outputs carry metadata. This improves compatibility with consumers that expect metadata in streaming responses.
1 parent 0b114ba commit f8004f8
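For context, a minimal consumer sketch (not part of the commit): it assumes any LangChain-compatible chat model bound as `llm`, mirroring the tests added below.

    # Hypothetical consumer sketch; `llm` stands in for any LangChain chat model.
    import asyncio

    from nemoguardrails import RailsConfig
    from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

    async def consume(llm) -> None:
        config = RailsConfig.from_content(config={"models": []})
        guardrails = RunnableRails(config, llm=llm)
        async for chunk in guardrails.astream("Hi there"):
            # Each chunk is an AIMessageChunk; generation metadata, when present,
            # surfaces via fields such as additional_kwargs and response_metadata.
            print(chunk.content, chunk.additional_kwargs, chunk.response_metadata)

    # asyncio.run(consume(llm))  # run with a concrete chat model bound to `llm`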

3 files changed: +166 -16 lines changed
nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 43 additions & 14 deletions
@@ -221,6 +221,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
                 "role": "context",
                 "content": {
                     "passthrough_input": _input,
+                    # We also set all the input variables as top level context variables
                     **(_input if isinstance(_input, dict) else {}),
                 },
             },
@@ -841,55 +842,83 @@ async def astream(
         streaming_enabled = True
 
         try:
-            async for chunk in self.rails.stream_async(messages=input_messages):
+            from nemoguardrails.streaming import END_OF_STREAM
+
+            async for chunk in self.rails.stream_async(
+                messages=input_messages, include_generation_metadata=True
+            ):
+                # Skip END_OF_STREAM markers
+                chunk_text = (
+                    chunk["text"]
+                    if isinstance(chunk, dict) and "text" in chunk
+                    else chunk
+                )
+                if chunk_text is END_OF_STREAM:
+                    continue
+
                 # Format the chunk based on the input type for streaming
                 formatted_chunk = self._format_streaming_chunk(input, chunk)
                 yield formatted_chunk
         finally:
             if streaming_enabled and hasattr(self.rails.llm, "streaming"):
                 self.rails.llm.streaming = original_streaming
 
-    def _format_streaming_chunk(self, input: Any, chunk: str) -> Any:
+    def _format_streaming_chunk(self, input: Any, chunk) -> Any:
         """Format a streaming chunk based on the input type.
 
         Args:
             input: The original input
-            chunk: The current text chunk
+            chunk: The current chunk (string or dict with text/generation_info)
 
         Returns:
             The formatted streaming chunk (using AIMessageChunk for LangChain compatibility)
         """
+        text_content = chunk
+        metadata = {}
+
+        if isinstance(chunk, dict) and "text" in chunk:
+            text_content = chunk["text"]
+            generation_info = chunk.get("generation_info", {})
+
+            if generation_info:
+                metadata = generation_info.copy()
         if isinstance(input, ChatPromptValue):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, StringPromptValue):
-            return chunk
+            return text_content  # String outputs don't support metadata
         elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, list) and all(
             isinstance(msg, BaseMessage) for msg in input
         ):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, dict):
             output_key = self.passthrough_bot_output_key
             if self.passthrough_user_input_key in input or "input" in input:
                 user_input = input.get(
                     self.passthrough_user_input_key, input.get("input")
                 )
                 if isinstance(user_input, str):
-                    return {output_key: chunk}
+                    return {output_key: text_content}
                 elif isinstance(user_input, list):
                     if all(
                         isinstance(msg, dict) and "role" in msg for msg in user_input
                     ):
-                        return {output_key: {"role": "assistant", "content": chunk}}
+                        return {
+                            output_key: {"role": "assistant", "content": text_content}
+                        }
                     elif all(isinstance(msg, BaseMessage) for msg in user_input):
-                        return {output_key: AIMessageChunk(content=chunk)}
-                    return {output_key: chunk}
+                        return {
+                            output_key: AIMessageChunk(content=text_content, **metadata)
+                        }
+                    return {output_key: text_content}
                 elif isinstance(user_input, BaseMessage):
-                    return {output_key: AIMessageChunk(content=chunk)}
-            return {output_key: chunk}
+                    return {
+                        output_key: AIMessageChunk(content=text_content, **metadata)
+                    }
+            return {output_key: text_content}
         elif isinstance(input, str):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         else:
             raise ValueError(f"Unexpected input type: {type(input)}")
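For reference, the metadata-bearing chunks yielded by stream_async(..., include_generation_metadata=True) have the shape sketched below. The values are illustrative, taken from the tests added in this commit; the exact generation_info keys depend on the underlying LLM and are not a guaranteed schema.

    # Illustrative chunk shape handled by _format_streaming_chunk; the
    # generation_info keys come from the mocked tests below, not a fixed schema.
    chunk = {
        "text": "Hello ",
        "generation_info": {"model": "test-model", "finish_reason": "stop"},
    }
    # Plain string chunks (no metadata) are still accepted and passed through.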

tests/runnable_rails/test_metadata.py

Lines changed: 71 additions & 0 deletions
@@ -373,3 +373,74 @@ def test_partial_metadata(self, mock_rails_config):
         assert result.content == "Test response"
         assert result.additional_kwargs == {"custom_field": "value"}
         assert result.response_metadata is None or result.response_metadata == {}
+
+    def test_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+        mock_generation_response = Mock()
+        mock_generation_response.response = "Streaming response"
+        mock_generation_response.output_data = {}
+        mock_generation_response.tool_calls = None
+        mock_generation_response.llm_metadata = {
+            "additional_kwargs": {"finish_reason": "stop"},
+            "response_metadata": {"model_name": "test-model"},
+        }
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {
+                    "text": "Hello ",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+                {
+                    "text": "world!",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = list(runnable_rails.stream("Test input"))
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")
+            assert hasattr(chunk, "response_metadata") or hasattr(
+                chunk, "finish_reason"
+            )
+
+    async def test_async_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that async streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {"text": "Async ", "generation_info": {"model": "test-model"}},
+                {"text": "stream!", "generation_info": {"model": "test-model"}},
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = []
+        async for chunk in runnable_rails.astream("Test input"):
+            chunks.append(chunk)
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")

tests/runnable_rails/test_streaming.py

Lines changed: 52 additions & 2 deletions
@@ -60,12 +60,10 @@ async def _astream(self, messages, stop=None, run_manager=None, **kwargs):
 
 def test_runnable_rails_basic_streaming():
     """Test basic synchronous streaming functionality."""
-    # Configure a streaming LLM with a response
     llm = StreamingFakeLLM(responses=["Hello there! How can I help you today?"])
     config = RailsConfig.from_content(config={"models": []})
     rails = RunnableRails(config, llm=llm)
 
-    # Collect chunks from the stream
     chunks = []
     for chunk in rails.stream("Hi there"):
         chunks.append(chunk)
@@ -451,3 +449,55 @@ def test_streaming_with_different_input_types():
             ), f"Failed for {input_type}: {full_content}"
 
     assert llm.streaming == False
+
+
+def test_streaming_metadata_preservation():
+    """Test that streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    for chunk in model_with_rails.stream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+async def test_async_streaming_metadata_preservation():
+    """Test that async streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test async response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    async for chunk in model_with_rails.astream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+def test_streaming_chunk_types():
+    """Test that streaming returns proper AIMessageChunk types."""
+    llm = FakeLLM(responses=["Hello world"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = list(model_with_rails.stream("Hi"))
+
+    for chunk in chunks:
+        assert chunk.__class__.__name__ == "AIMessageChunk"
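A usage note on the chunk type asserted above: AIMessageChunk implements `+` in langchain_core, so a consumer can fold a stream back into a single message. A minimal sketch, assuming `chunks` was collected as in the tests above and the fake LLM streamed the whole response:

    # Minimal sketch: AIMessageChunk objects can be merged with `+`, so the
    # streamed pieces fold back into one message. Assumes `chunks` holds the
    # AIMessageChunk objects collected above.
    from functools import reduce
    from operator import add

    final_message = reduce(add, chunks)
    assert final_message.content == "Hello world"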
