@@ -221,6 +221,7 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
221221                "role" : "context" ,
222222                "content" : {
223223                    "passthrough_input" : _input ,
224+                     # We also set all the input variables as top-level context variables
224225                    ** (_input  if  isinstance (_input , dict ) else  {}),
225226                },
226227            },
@@ -841,55 +842,87 @@ async def astream(
841842            streaming_enabled  =  True 
842843
843844        try :
844-             async  for  chunk  in  self .rails .stream_async (messages = input_messages ):
845+             from  nemoguardrails .streaming  import  END_OF_STREAM 
846+ 
847+             async  for  chunk  in  self .rails .stream_async (
848+                 messages = input_messages , include_generation_metadata = True 
849+             ):
850+                 # Skip END_OF_STREAM markers 
851+                 chunk_text  =  (
852+                     chunk ["text" ]
853+                     if  isinstance (chunk , dict ) and  "text"  in  chunk 
854+                     else  chunk 
855+                 )
856+                 if  chunk_text  is  END_OF_STREAM :
857+                     continue 
858+ 
845859                # Format the chunk based on the input type for streaming 
846860                formatted_chunk  =  self ._format_streaming_chunk (input , chunk )
847861                yield  formatted_chunk 
848862        finally :
849863            if  streaming_enabled  and  hasattr (self .rails .llm , "streaming" ):
850864                self .rails .llm .streaming  =  original_streaming 
851865
852-     def  _format_streaming_chunk (self , input : Any , chunk :  str ) ->  Any :
866+     def  _format_streaming_chunk (self , input : Any , chunk ) ->  Any :
853867        """Format a streaming chunk based on the input type. 
854868
855869        Args: 
856870            input: The original input 
857-             chunk: The current text  chunk 
871+             chunk: The current chunk (string or dict with text/generation_info)  
858872
859873        Returns: 
860874            The formatted streaming chunk (using AIMessageChunk for LangChain compatibility) 
861875        """ 
876+         # Extract text and metadata from chunk if it's a dict with generation metadata 
877+         text_content  =  chunk 
878+         metadata  =  {}
879+ 
880+         if  isinstance (chunk , dict ) and  "text"  in  chunk :
881+             text_content  =  chunk ["text" ]
882+             generation_info  =  chunk .get ("generation_info" , {})
883+ 
884+             # Use generation_info as metadata for streaming chunks 
885+             if  generation_info :
886+                 metadata  =  generation_info .copy ()
862887        if  isinstance (input , ChatPromptValue ):
863-             return  AIMessageChunk (content = chunk )
888+             return  AIMessageChunk (content = text_content ,  ** metadata )
864889        elif  isinstance (input , StringPromptValue ):
865-             return  chunk 
890+             return  text_content    # String outputs don't support metadata 
866891        elif  isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
867-             return  AIMessageChunk (content = chunk )
892+             return  AIMessageChunk (content = text_content ,  ** metadata )
868893        elif  isinstance (input , list ) and  all (
869894            isinstance (msg , BaseMessage ) for  msg  in  input 
870895        ):
871-             return  AIMessageChunk (content = chunk )
896+             return  AIMessageChunk (content = text_content ,  ** metadata )
872897        elif  isinstance (input , dict ):
873898            output_key  =  self .passthrough_bot_output_key 
874899            if  self .passthrough_user_input_key  in  input  or  "input"  in  input :
875900                user_input  =  input .get (
876901                    self .passthrough_user_input_key , input .get ("input" )
877902                )
878903                if  isinstance (user_input , str ):
879-                     return  {output_key : chunk }
904+                     return  {
905+                         output_key : text_content 
906+                     }  # Dict outputs don't support metadata 
880907                elif  isinstance (user_input , list ):
881908                    if  all (
882909                        isinstance (msg , dict ) and  "role"  in  msg  for  msg  in  user_input 
883910                    ):
884-                         return  {output_key : {"role" : "assistant" , "content" : chunk }}
911+                         return  {
912+                             output_key : {"role" : "assistant" , "content" : text_content }
913+                         }
885914                    elif  all (isinstance (msg , BaseMessage ) for  msg  in  user_input ):
886-                         return  {output_key : AIMessageChunk (content = chunk )}
887-                     return  {output_key : chunk }
915+                         return  {
916+                             output_key : AIMessageChunk (content = text_content , ** metadata )
917+                         }
918+                     return  {output_key : text_content }
888919                elif  isinstance (user_input , BaseMessage ):
889-                     return  {output_key : AIMessageChunk (content = chunk )}
890-             return  {output_key : chunk }
920+                     return  {
921+                         output_key : AIMessageChunk (content = text_content , ** metadata )
922+                     }
923+             return  {output_key : text_content }
891924        elif  isinstance (input , str ):
892-             return  AIMessageChunk (content = chunk )
925+             return  AIMessageChunk (content = text_content ,  ** metadata )
893926        else :
894927            raise  ValueError (f"Unexpected input type: { type (input )}  " )
895928
0 commit comments