@@ -221,6 +221,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
221
221
"role" : "context" ,
222
222
"content" : {
223
223
"passthrough_input" : _input ,
224
+ # We also set all the input variables as top level context variables
224
225
** (_input if isinstance (_input , dict ) else {}),
225
226
},
226
227
},
@@ -841,55 +842,87 @@ async def astream(
841
842
streaming_enabled = True
842
843
843
844
try :
844
- async for chunk in self .rails .stream_async (messages = input_messages ):
845
+ from nemoguardrails .streaming import END_OF_STREAM
846
+
847
+ async for chunk in self .rails .stream_async (
848
+ messages = input_messages , include_generation_metadata = True
849
+ ):
850
+ # Skip END_OF_STREAM markers
851
+ chunk_text = (
852
+ chunk ["text" ]
853
+ if isinstance (chunk , dict ) and "text" in chunk
854
+ else chunk
855
+ )
856
+ if chunk_text is END_OF_STREAM :
857
+ continue
858
+
845
859
# Format the chunk based on the input type for streaming
846
860
formatted_chunk = self ._format_streaming_chunk (input , chunk )
847
861
yield formatted_chunk
848
862
finally :
849
863
if streaming_enabled and hasattr (self .rails .llm , "streaming" ):
850
864
self .rails .llm .streaming = original_streaming
851
865
852
def _format_streaming_chunk(self, input: Any, chunk) -> Any:
    """Shape one streamed chunk so that it mirrors the type of the input.

    Args:
        input: The original input given to the runnable; its type selects
            the shape of the returned value.
        chunk: Either the chunk content itself, or a dict carrying a
            "text" entry and, optionally, a "generation_info" entry.

    Returns:
        An AIMessageChunk for prompt-value, message, message-list and plain
        string inputs; the bare content for StringPromptValue inputs; a dict
        keyed by ``self.passthrough_bot_output_key`` for dict inputs.

    Raises:
        ValueError: If ``input`` is of none of the supported types.
    """
    # A dict chunk with a "text" entry is unpacked; anything else is used
    # verbatim as the content.
    carries_text = isinstance(chunk, dict) and "text" in chunk
    content = chunk["text"] if carries_text else chunk

    # Only a non-empty generation_info contributes extra keyword arguments.
    extra_kwargs = {}
    if carries_text:
        info = chunk.get("generation_info", {})
        if info:
            extra_kwargs = info.copy()

    def as_message_chunk():
        # Built lazily so that string/dict outputs never construct a message.
        return AIMessageChunk(content=content, **extra_kwargs)

    if isinstance(input, ChatPromptValue):
        return as_message_chunk()
    if isinstance(input, StringPromptValue):
        # String outputs carry no metadata.
        return content
    if isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
        return as_message_chunk()
    if isinstance(input, list) and all(
        isinstance(item, BaseMessage) for item in input
    ):
        return as_message_chunk()
    if isinstance(input, str):
        return as_message_chunk()
    if not isinstance(input, dict):
        raise ValueError(f"Unexpected input type: {type(input)} ")

    # Dict input: the output is always wrapped under the bot output key.
    output_key = self.passthrough_bot_output_key
    if self.passthrough_user_input_key in input or "input" in input:
        user_input = input.get(
            self.passthrough_user_input_key, input.get("input")
        )
        if isinstance(user_input, str):
            # Dict outputs with plain text carry no metadata.
            return {output_key: content}
        if isinstance(user_input, list):
            if all(
                isinstance(item, dict) and "role" in item for item in user_input
            ):
                return {output_key: {"role": "assistant", "content": content}}
            if all(isinstance(item, BaseMessage) for item in user_input):
                return {output_key: as_message_chunk()}
            return {output_key: content}
        if isinstance(user_input, BaseMessage):
            return {output_key: as_message_chunk()}
    return {output_key: content}
895
928
0 commit comments