@@ -396,11 +396,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> An
396
396
return passthrough_output
397
397
398
398
def _format_chat_prompt_output (
399
- self , result : Any , tool_calls : Optional [list ] = None
399
+ self ,
400
+ result : Any ,
401
+ tool_calls : Optional [list ] = None ,
402
+ metadata : Optional [dict ] = None ,
400
403
) -> AIMessage :
401
404
"""Format output for ChatPromptValue input."""
402
405
content = self ._extract_content_from_result (result )
403
- if tool_calls :
406
+
407
+ if metadata and isinstance (metadata , dict ):
408
+ metadata_copy = metadata .copy ()
409
+ metadata_copy .pop ("content" , None )
410
+ if tool_calls :
411
+ metadata_copy ["tool_calls" ] = tool_calls
412
+ return AIMessage (content = content , ** metadata_copy )
413
+ elif tool_calls :
404
414
return AIMessage (content = content , tool_calls = tool_calls )
405
415
return AIMessage (content = content )
406
416
@@ -409,11 +419,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
409
419
return self ._extract_content_from_result (result )
410
420
411
421
def _format_message_output (
412
- self , result : Any , tool_calls : Optional [list ] = None
422
+ self ,
423
+ result : Any ,
424
+ tool_calls : Optional [list ] = None ,
425
+ metadata : Optional [dict ] = None ,
413
426
) -> AIMessage :
414
427
"""Format output for BaseMessage input types."""
415
428
content = self ._extract_content_from_result (result )
416
- if tool_calls :
429
+
430
+ if metadata and isinstance (metadata , dict ):
431
+ metadata_copy = metadata .copy ()
432
+ metadata_copy .pop ("content" , None )
433
+ if tool_calls :
434
+ metadata_copy ["tool_calls" ] = tool_calls
435
+ return AIMessage (content = content , ** metadata_copy )
436
+ elif tool_calls :
417
437
return AIMessage (content = content , tool_calls = tool_calls )
418
438
return AIMessage (content = content )
419
439
@@ -437,25 +457,50 @@ def _format_dict_output_for_dict_message_list(
437
457
}
438
458
439
459
def _format_dict_output_for_base_message_list (
440
- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
460
+ self ,
461
+ result : Any ,
462
+ output_key : str ,
463
+ tool_calls : Optional [list ] = None ,
464
+ metadata : Optional [dict ] = None ,
441
465
) -> Dict [str , Any ]:
442
466
"""Format dict output when user input was a list of BaseMessage objects."""
443
467
content = self ._extract_content_from_result (result )
444
- if tool_calls :
468
+
469
+ if metadata and isinstance (metadata , dict ):
470
+ metadata_copy = metadata .copy ()
471
+ metadata_copy .pop ("content" , None )
472
+ if tool_calls :
473
+ metadata_copy ["tool_calls" ] = tool_calls
474
+ return {output_key : AIMessage (content = content , ** metadata_copy )}
475
+ elif tool_calls :
445
476
return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
446
477
return {output_key : AIMessage (content = content )}
447
478
448
479
def _format_dict_output_for_base_message (
449
- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
480
+ self ,
481
+ result : Any ,
482
+ output_key : str ,
483
+ tool_calls : Optional [list ] = None ,
484
+ metadata : Optional [dict ] = None ,
450
485
) -> Dict [str , Any ]:
451
486
"""Format dict output when user input was a BaseMessage."""
452
487
content = self ._extract_content_from_result (result )
453
- if tool_calls :
488
+
489
+ if metadata :
490
+ metadata_copy = metadata .copy ()
491
+ if tool_calls :
492
+ metadata_copy ["tool_calls" ] = tool_calls
493
+ return {output_key : AIMessage (content = content , ** metadata_copy )}
494
+ elif tool_calls :
454
495
return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
455
496
return {output_key : AIMessage (content = content )}
456
497
457
498
def _format_dict_output (
458
- self , input_dict : dict , result : Any , tool_calls : Optional [list ] = None
499
+ self ,
500
+ input_dict : dict ,
501
+ result : Any ,
502
+ tool_calls : Optional [list ] = None ,
503
+ metadata : Optional [dict ] = None ,
459
504
) -> Dict [str , Any ]:
460
505
"""Format output for dictionary input."""
461
506
output_key = self .passthrough_bot_output_key
@@ -474,13 +519,13 @@ def _format_dict_output(
474
519
)
475
520
elif all (isinstance (msg , BaseMessage ) for msg in user_input ):
476
521
return self ._format_dict_output_for_base_message_list (
477
- result , output_key , tool_calls
522
+ result , output_key , tool_calls , metadata
478
523
)
479
524
else :
480
525
return {output_key : result }
481
526
elif isinstance (user_input , BaseMessage ):
482
527
return self ._format_dict_output_for_base_message (
483
- result , output_key , tool_calls
528
+ result , output_key , tool_calls , metadata
484
529
)
485
530
486
531
# Generic fallback for dictionaries
@@ -493,6 +538,7 @@ def _format_output(
493
538
result : Any ,
494
539
context : Dict [str , Any ],
495
540
tool_calls : Optional [list ] = None ,
541
+ metadata : Optional [dict ] = None ,
496
542
) -> Any :
497
543
"""Format the output based on the input type and rails result.
498
544
@@ -515,17 +561,17 @@ def _format_output(
515
561
return self ._format_passthrough_output (result , context )
516
562
517
563
if isinstance (input , ChatPromptValue ):
518
- return self ._format_chat_prompt_output (result , tool_calls )
564
+ return self ._format_chat_prompt_output (result , tool_calls , metadata )
519
565
elif isinstance (input , StringPromptValue ):
520
566
return self ._format_string_prompt_output (result )
521
567
elif isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
522
- return self ._format_message_output (result , tool_calls )
568
+ return self ._format_message_output (result , tool_calls , metadata )
523
569
elif isinstance (input , list ) and all (
524
570
isinstance (msg , BaseMessage ) for msg in input
525
571
):
526
- return self ._format_message_output (result , tool_calls )
572
+ return self ._format_message_output (result , tool_calls , metadata )
527
573
elif isinstance (input , dict ):
528
- return self ._format_dict_output (input , result , tool_calls )
574
+ return self ._format_dict_output (input , result , tool_calls , metadata )
529
575
elif isinstance (input , str ):
530
576
return self ._format_string_prompt_output (result )
531
577
else :
@@ -672,7 +718,9 @@ def _full_rails_invoke(
672
718
result = result [0 ]
673
719
674
720
# Format and return the output based in input type
675
- return self ._format_output (input , result , context , res .tool_calls )
721
+ return self ._format_output (
722
+ input , result , context , res .tool_calls , res .llm_metadata
723
+ )
676
724
677
725
async def ainvoke (
678
726
self ,
@@ -734,7 +782,9 @@ async def _full_rails_ainvoke(
734
782
result = res .response
735
783
736
784
# Format and return the output based on input type
737
- return self ._format_output (input , result , context , res .tool_calls )
785
+ return self ._format_output (
786
+ input , result , context , res .tool_calls , res .llm_metadata
787
+ )
738
788
739
789
def stream (
740
790
self ,
0 commit comments