Commit 2ddc4f3

feat(runnable-rails): stream metadata in RunnableRails output
Enhance streaming in RunnableRails to include generation metadata in streamed chunks. Skip END_OF_STREAM markers and update chunk formatting to support metadata for AIMessageChunk outputs. This improves compatibility with consumers that expect metadata in streaming responses.
1 parent 8e1033c commit 2ddc4f3
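
For consumers, the practical effect is that streamed chunks can now carry generation metadata. A minimal consumption sketch (the config path is hypothetical, and which generation_info keys actually show up depends on the underlying LLM provider):

import asyncio

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails


async def main():
    # Hypothetical config directory; any valid guardrails configuration works.
    config = RailsConfig.from_path("./config")
    guardrails = RunnableRails(config)

    # With this commit, string inputs yield AIMessageChunk objects whose extra
    # fields are populated from each chunk's generation_info (e.g. model name
    # or finish_reason) when the provider reports them.
    async for chunk in guardrails.astream("Hello!"):
        print(chunk.content)


asyncio.run(main())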

File tree

3 files changed: +170 −14 lines changed


nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 47 additions & 14 deletions
@@ -221,6 +221,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
                 "role": "context",
                 "content": {
                     "passthrough_input": _input,
+                    # We also set all the input variables as top level context variables
                     **(_input if isinstance(_input, dict) else {}),
                 },
             },
@@ -841,55 +842,87 @@ async def astream(
         streaming_enabled = True

         try:
-            async for chunk in self.rails.stream_async(messages=input_messages):
+            from nemoguardrails.streaming import END_OF_STREAM
+
+            async for chunk in self.rails.stream_async(
+                messages=input_messages, include_generation_metadata=True
+            ):
+                # Skip END_OF_STREAM markers
+                chunk_text = (
+                    chunk["text"]
+                    if isinstance(chunk, dict) and "text" in chunk
+                    else chunk
+                )
+                if chunk_text is END_OF_STREAM:
+                    continue
+
                 # Format the chunk based on the input type for streaming
                 formatted_chunk = self._format_streaming_chunk(input, chunk)
                 yield formatted_chunk
         finally:
             if streaming_enabled and hasattr(self.rails.llm, "streaming"):
                 self.rails.llm.streaming = original_streaming

-    def _format_streaming_chunk(self, input: Any, chunk: str) -> Any:
+    def _format_streaming_chunk(self, input: Any, chunk) -> Any:
         """Format a streaming chunk based on the input type.

         Args:
             input: The original input
-            chunk: The current text chunk
+            chunk: The current chunk (string or dict with text/generation_info)

         Returns:
             The formatted streaming chunk (using AIMessageChunk for LangChain compatibility)
         """
+        # Extract text and metadata from chunk if it's a dict with generation metadata
+        text_content = chunk
+        metadata = {}
+
+        if isinstance(chunk, dict) and "text" in chunk:
+            text_content = chunk["text"]
+            generation_info = chunk.get("generation_info", {})
+
+            # Use generation_info as metadata for streaming chunks
+            if generation_info:
+                metadata = generation_info.copy()
         if isinstance(input, ChatPromptValue):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, StringPromptValue):
-            return chunk
+            return text_content  # String outputs don't support metadata
         elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, list) and all(
             isinstance(msg, BaseMessage) for msg in input
         ):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         elif isinstance(input, dict):
             output_key = self.passthrough_bot_output_key
             if self.passthrough_user_input_key in input or "input" in input:
                 user_input = input.get(
                     self.passthrough_user_input_key, input.get("input")
                 )
                 if isinstance(user_input, str):
-                    return {output_key: chunk}
+                    return {
+                        output_key: text_content
+                    }  # Dict outputs don't support metadata
                 elif isinstance(user_input, list):
                     if all(
                         isinstance(msg, dict) and "role" in msg for msg in user_input
                     ):
-                        return {output_key: {"role": "assistant", "content": chunk}}
+                        return {
+                            output_key: {"role": "assistant", "content": text_content}
+                        }
                     elif all(isinstance(msg, BaseMessage) for msg in user_input):
-                        return {output_key: AIMessageChunk(content=chunk)}
-                    return {output_key: chunk}
+                        return {
+                            output_key: AIMessageChunk(content=text_content, **metadata)
+                        }
+                    return {output_key: text_content}
                 elif isinstance(user_input, BaseMessage):
-                    return {output_key: AIMessageChunk(content=chunk)}
-            return {output_key: chunk}
+                    return {
+                        output_key: AIMessageChunk(content=text_content, **metadata)
+                    }
+            return {output_key: text_content}
         elif isinstance(input, str):
-            return AIMessageChunk(content=chunk)
+            return AIMessageChunk(content=text_content, **metadata)
         else:
             raise ValueError(f"Unexpected input type: {type(input)}")
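
The shape of the dict chunks that stream_async emits with include_generation_metadata=True, and the way _format_streaming_chunk maps them for chat-style inputs, can be sketched as follows (illustrative values; the real generation_info keys come from the LLM provider, and the extra keyword arguments are accepted by AIMessageChunk in the same way the commit's tests exercise):

from langchain_core.messages import AIMessageChunk

# Shape of a chunk yielded by rails.stream_async(..., include_generation_metadata=True):
chunk = {
    "text": "Hello ",
    "generation_info": {"model": "test-model", "finish_reason": "stop"},
}

# For chat-style inputs, the dispatch above reduces to roughly:
text_content = chunk["text"]
metadata = chunk.get("generation_info", {}) or {}
formatted = AIMessageChunk(content=text_content, **metadata)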

tests/runnable_rails/test_metadata.py

Lines changed: 71 additions & 0 deletions
@@ -312,3 +312,74 @@ def test_partial_metadata(self, mock_rails_config):
         assert result.additional_kwargs == {"custom_field": "value"}
         # Missing fields should be handled gracefully (None or default values)
         assert result.response_metadata is None or result.response_metadata == {}
+
+    def test_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+        mock_generation_response = Mock()
+        mock_generation_response.response = "Streaming response"
+        mock_generation_response.output_data = {}
+        mock_generation_response.tool_calls = None
+        mock_generation_response.llm_metadata = {
+            "additional_kwargs": {"finish_reason": "stop"},
+            "response_metadata": {"model_name": "test-model"},
+        }
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {
+                    "text": "Hello ",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+                {
+                    "text": "world!",
+                    "generation_info": {"model": "test-model", "finish_reason": "stop"},
+                },
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = list(runnable_rails.stream("Test input"))
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")
+            assert hasattr(chunk, "response_metadata") or hasattr(
+                chunk, "finish_reason"
+            )
+
+    async def test_async_streaming_metadata_preservation(self, mock_rails_config):
+        """Test that async streaming preserves metadata in chunks."""
+        from unittest.mock import AsyncMock
+
+        mock_rails = AsyncMock()
+
+        async def mock_stream(*args, **kwargs):
+            chunks = [
+                {"text": "Async ", "generation_info": {"model": "test-model"}},
+                {"text": "stream!", "generation_info": {"model": "test-model"}},
+            ]
+            for chunk in chunks:
+                yield chunk
+
+        mock_rails.stream_async = mock_stream
+
+        runnable_rails = RunnableRails(config=mock_rails_config, passthrough=True)
+        runnable_rails.rails = mock_rails
+
+        chunks = []
+        async for chunk in runnable_rails.astream("Test input"):
+            chunks.append(chunk)
+
+        assert len(chunks) == 2
+        for chunk in chunks:
+            assert hasattr(chunk, "content")
+            assert hasattr(chunk, "additional_kwargs") or hasattr(chunk, "model")

tests/runnable_rails/test_streaming.py

Lines changed: 52 additions & 0 deletions
@@ -480,3 +480,55 @@ def test_streaming_with_different_input_types():
         ), f"Failed for {input_type}: {full_content}"

     assert llm.streaming == False
+
+
+def test_streaming_metadata_preservation():
+    """Test that streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    for chunk in model_with_rails.stream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+async def test_async_streaming_metadata_preservation():
+    """Test that async streaming chunks preserve metadata structure."""
+    llm = FakeLLM(responses=["Test async response"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = []
+    async for chunk in model_with_rails.astream("Test input"):
+        chunks.append(chunk)
+
+    assert len(chunks) > 0
+
+    for chunk in chunks:
+        assert hasattr(chunk, "content")
+        assert hasattr(chunk, "additional_kwargs")
+        assert hasattr(chunk, "response_metadata")
+        assert isinstance(chunk.additional_kwargs, dict)
+        assert isinstance(chunk.response_metadata, dict)
+
+
+def test_streaming_chunk_types():
+    """Test that streaming returns proper AIMessageChunk types."""
+    llm = FakeLLM(responses=["Hello world"])
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    chunks = list(model_with_rails.stream("Hi"))
+
+    for chunk in chunks:
+        assert chunk.__class__.__name__ == "AIMessageChunk"
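
A usage note on the chunk type these tests assert: LangChain's AIMessageChunk supports + concatenation, so a consumer can fold a guardrailed stream back into a single message, with content (and the metadata dicts) merged along the way. A minimal sketch:

from langchain_core.messages import AIMessageChunk

c1 = AIMessageChunk(content="Hello ")
c2 = AIMessageChunk(content="world!")

# AIMessageChunk.__add__ concatenates content and merges
# additional_kwargs / response_metadata from both chunks.
merged = c1 + c2
assert merged.content == "Hello world!"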
