@@ -221,6 +221,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
221
221
"role" : "context" ,
222
222
"content" : {
223
223
"passthrough_input" : _input ,
224
+ # We also set all the input variables as top level context variables
224
225
** (_input if isinstance (_input , dict ) else {}),
225
226
},
226
227
},
@@ -841,55 +842,83 @@ async def astream(
841
842
streaming_enabled = True
842
843
843
844
try :
844
- async for chunk in self .rails .stream_async (messages = input_messages ):
845
+ from nemoguardrails .streaming import END_OF_STREAM
846
+
847
+ async for chunk in self .rails .stream_async (
848
+ messages = input_messages , include_generation_metadata = True
849
+ ):
850
+ # Skip END_OF_STREAM markers
851
+ chunk_text = (
852
+ chunk ["text" ]
853
+ if isinstance (chunk , dict ) and "text" in chunk
854
+ else chunk
855
+ )
856
+ if chunk_text is END_OF_STREAM :
857
+ continue
858
+
845
859
# Format the chunk based on the input type for streaming
846
860
formatted_chunk = self ._format_streaming_chunk (input , chunk )
847
861
yield formatted_chunk
848
862
finally :
849
863
if streaming_enabled and hasattr (self .rails .llm , "streaming" ):
850
864
self .rails .llm .streaming = original_streaming
851
865
852
- def _format_streaming_chunk (self , input : Any , chunk : str ) -> Any :
866
+ def _format_streaming_chunk (self , input : Any , chunk ) -> Any :
853
867
"""Format a streaming chunk based on the input type.
854
868
855
869
Args:
856
870
input: The original input
857
- chunk: The current text chunk
871
+ chunk: The current chunk (string or dict with text/generation_info)
858
872
859
873
Returns:
860
874
The formatted streaming chunk (using AIMessageChunk for LangChain compatibility)
861
875
"""
876
+ text_content = chunk
877
+ metadata = {}
878
+
879
+ if isinstance (chunk , dict ) and "text" in chunk :
880
+ text_content = chunk ["text" ]
881
+ generation_info = chunk .get ("generation_info" , {})
882
+
883
+ if generation_info :
884
+ metadata = generation_info .copy ()
862
885
if isinstance (input , ChatPromptValue ):
863
- return AIMessageChunk (content = chunk )
886
+ return AIMessageChunk (content = text_content , ** metadata )
864
887
elif isinstance (input , StringPromptValue ):
865
- return chunk
888
+ return text_content # String outputs don't support metadata
866
889
elif isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
867
- return AIMessageChunk (content = chunk )
890
+ return AIMessageChunk (content = text_content , ** metadata )
868
891
elif isinstance (input , list ) and all (
869
892
isinstance (msg , BaseMessage ) for msg in input
870
893
):
871
- return AIMessageChunk (content = chunk )
894
+ return AIMessageChunk (content = text_content , ** metadata )
872
895
elif isinstance (input , dict ):
873
896
output_key = self .passthrough_bot_output_key
874
897
if self .passthrough_user_input_key in input or "input" in input :
875
898
user_input = input .get (
876
899
self .passthrough_user_input_key , input .get ("input" )
877
900
)
878
901
if isinstance (user_input , str ):
879
- return {output_key : chunk }
902
+ return {output_key : text_content }
880
903
elif isinstance (user_input , list ):
881
904
if all (
882
905
isinstance (msg , dict ) and "role" in msg for msg in user_input
883
906
):
884
- return {output_key : {"role" : "assistant" , "content" : chunk }}
907
+ return {
908
+ output_key : {"role" : "assistant" , "content" : text_content }
909
+ }
885
910
elif all (isinstance (msg , BaseMessage ) for msg in user_input ):
886
- return {output_key : AIMessageChunk (content = chunk )}
887
- return {output_key : chunk }
911
+ return {
912
+ output_key : AIMessageChunk (content = text_content , ** metadata )
913
+ }
914
+ return {output_key : text_content }
888
915
elif isinstance (user_input , BaseMessage ):
889
- return {output_key : AIMessageChunk (content = chunk )}
890
- return {output_key : chunk }
916
+ return {
917
+ output_key : AIMessageChunk (content = text_content , ** metadata )
918
+ }
919
+ return {output_key : text_content }
891
920
elif isinstance (input , str ):
892
- return AIMessageChunk (content = chunk )
921
+ return AIMessageChunk (content = text_content , ** metadata )
893
922
else :
894
923
raise ValueError (f"Unexpected input type: { type (input )} " )
895
924
0 commit comments