@@ -218,6 +218,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
218
218
"role" : "context" ,
219
219
"content" : {
220
220
"passthrough_input" : _input ,
221
+ # We also set all the input variables as top level context variables
221
222
** (_input if isinstance (_input , dict ) else {}),
222
223
},
223
224
},
@@ -838,55 +839,83 @@ async def astream(
838
839
streaming_enabled = True
839
840
840
841
try :
841
- async for chunk in self .rails .stream_async (messages = input_messages ):
842
+ from nemoguardrails .streaming import END_OF_STREAM
843
+
844
+ async for chunk in self .rails .stream_async (
845
+ messages = input_messages , include_generation_metadata = True
846
+ ):
847
+ # Skip END_OF_STREAM markers
848
+ chunk_text = (
849
+ chunk ["text" ]
850
+ if isinstance (chunk , dict ) and "text" in chunk
851
+ else chunk
852
+ )
853
+ if chunk_text is END_OF_STREAM :
854
+ continue
855
+
842
856
# Format the chunk based on the input type for streaming
843
857
formatted_chunk = self ._format_streaming_chunk (input , chunk )
844
858
yield formatted_chunk
845
859
finally :
846
860
if streaming_enabled and hasattr (self .rails .llm , "streaming" ):
847
861
self .rails .llm .streaming = original_streaming
848
862
849
- def _format_streaming_chunk (self , input : Any , chunk : str ) -> Any :
863
+ def _format_streaming_chunk (self , input : Any , chunk ) -> Any :
850
864
"""Format a streaming chunk based on the input type.
851
865
852
866
Args:
853
867
input: The original input
854
- chunk: The current text chunk
868
+ chunk: The current chunk (string or dict with text/generation_info)
855
869
856
870
Returns:
857
871
The formatted streaming chunk (using AIMessageChunk for LangChain compatibility)
858
872
"""
873
+ text_content = chunk
874
+ metadata = {}
875
+
876
+ if isinstance (chunk , dict ) and "text" in chunk :
877
+ text_content = chunk ["text" ]
878
+ generation_info = chunk .get ("generation_info" , {})
879
+
880
+ if generation_info :
881
+ metadata = generation_info .copy ()
859
882
if isinstance (input , ChatPromptValue ):
860
- return AIMessageChunk (content = chunk )
883
+ return AIMessageChunk (content = text_content , ** metadata )
861
884
elif isinstance (input , StringPromptValue ):
862
- return chunk
885
+ return text_content # String outputs don't support metadata
863
886
elif isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
864
- return AIMessageChunk (content = chunk )
887
+ return AIMessageChunk (content = text_content , ** metadata )
865
888
elif isinstance (input , list ) and all (
866
889
isinstance (msg , BaseMessage ) for msg in input
867
890
):
868
- return AIMessageChunk (content = chunk )
891
+ return AIMessageChunk (content = text_content , ** metadata )
869
892
elif isinstance (input , dict ):
870
893
output_key = self .passthrough_bot_output_key
871
894
if self .passthrough_user_input_key in input or "input" in input :
872
895
user_input = input .get (
873
896
self .passthrough_user_input_key , input .get ("input" )
874
897
)
875
898
if isinstance (user_input , str ):
876
- return {output_key : chunk }
899
+ return {output_key : text_content }
877
900
elif isinstance (user_input , list ):
878
901
if all (
879
902
isinstance (msg , dict ) and "role" in msg for msg in user_input
880
903
):
881
- return {output_key : {"role" : "assistant" , "content" : chunk }}
904
+ return {
905
+ output_key : {"role" : "assistant" , "content" : text_content }
906
+ }
882
907
elif all (isinstance (msg , BaseMessage ) for msg in user_input ):
883
- return {output_key : AIMessageChunk (content = chunk )}
884
- return {output_key : chunk }
908
+ return {
909
+ output_key : AIMessageChunk (content = text_content , ** metadata )
910
+ }
911
+ return {output_key : text_content }
885
912
elif isinstance (user_input , BaseMessage ):
886
- return {output_key : AIMessageChunk (content = chunk )}
887
- return {output_key : chunk }
913
+ return {
914
+ output_key : AIMessageChunk (content = text_content , ** metadata )
915
+ }
916
+ return {output_key : text_content }
888
917
elif isinstance (input , str ):
889
- return AIMessageChunk (content = chunk )
918
+ return AIMessageChunk (content = text_content , ** metadata )
890
919
else :
891
920
raise ValueError (f"Unexpected input type: { type (input )} " )
892
921
0 commit comments