Skip to content

Commit 8309931

Browse files
committed
feat(runnable_rails): implement AIMessage metadata parity in RunnableRails
Ensure AIMessage responses from RunnableRails contain the same metadata fields (response_metadata, usage_metadata, additional_kwargs, id) as direct LLM calls, enabling consistent LangChain integration behavior.
1 parent cbfc294 commit 8309931

File tree

7 files changed

+472
-17
lines changed

7 files changed

+472
-17
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
2727
from nemoguardrails.context import (
2828
llm_call_info_var,
29+
llm_response_metadata_var,
2930
reasoning_trace_var,
3031
tool_calls_var,
3132
)
@@ -85,6 +86,7 @@ async def llm_call(
8586
response = await _invoke_with_message_list(llm, prompt, all_callbacks, stop)
8687

8788
_store_tool_calls(response)
89+
_store_response_metadata(response)
8890
return _extract_content(response)
8991

9092

@@ -173,6 +175,20 @@ def _store_tool_calls(response) -> None:
173175
tool_calls_var.set(tool_calls)
174176

175177

178+
def _store_response_metadata(response) -> None:
    """Store response metadata excluding content for metadata preservation.

    Captures every Pydantic field of the LLM response except ``content``
    into ``llm_response_metadata_var`` so it can be re-attached to the
    final AIMessage later (rails may modify the content, so it is handled
    separately). Non-Pydantic responses clear the variable so stale
    metadata never leaks into a subsequent call.
    """
    # Read model_fields from the class, not the instance: instance access
    # is deprecated in Pydantic >= 2.11 and scheduled for removal.
    fields = getattr(type(response), "model_fields", None)
    if fields is not None:
        metadata = {
            field_name: getattr(response, field_name)
            for field_name in fields
            # Exclude content since it may be modified by rails
            if field_name != "content"
        }
        llm_response_metadata_var.set(metadata)
    else:
        llm_response_metadata_var.set(None)
190+
191+
176192
def _extract_content(response) -> str:
177193
"""Extract text content from response."""
178194
if hasattr(response, "content"):
@@ -655,3 +671,15 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]:
655671
tool_calls_var.set(None)
656672
return tool_calls
657673
return None
674+
675+
676+
def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
    """Get the current response metadata and clear it from the context.

    Returns:
        Optional[dict]: The response metadata if it exists, None otherwise.
    """
    metadata = llm_response_metadata_var.get()
    if not metadata:
        return None
    # Clear so the metadata is consumed exactly once per generation.
    llm_response_metadata_var.set(None)
    return metadata

nemoguardrails/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,8 @@
4242
tool_calls_var: contextvars.ContextVar[Optional[list]] = contextvars.ContextVar(
4343
"tool_calls", default=None
4444
)
45+
46+
# The response metadata from the current LLM response (all fields except
# content), used to restore AIMessage metadata parity after rails run.
llm_response_metadata_var: contextvars.ContextVar[Optional[dict]] = (
    contextvars.ContextVar("llm_response_metadata", default=None)
)

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> An
396396
return passthrough_output
397397

398398
def _format_chat_prompt_output(
399-
self, result: Any, tool_calls: Optional[list] = None
399+
self,
400+
result: Any,
401+
tool_calls: Optional[list] = None,
402+
metadata: Optional[dict] = None,
400403
) -> AIMessage:
401404
"""Format output for ChatPromptValue input."""
402405
content = self._extract_content_from_result(result)
403-
if tool_calls:
406+
407+
if metadata and isinstance(metadata, dict):
408+
metadata_copy = metadata.copy()
409+
metadata_copy.pop("content", None)
410+
if tool_calls:
411+
metadata_copy["tool_calls"] = tool_calls
412+
return AIMessage(content=content, **metadata_copy)
413+
elif tool_calls:
404414
return AIMessage(content=content, tool_calls=tool_calls)
405415
return AIMessage(content=content)
406416

@@ -409,11 +419,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
409419
return self._extract_content_from_result(result)
410420

411421
def _format_message_output(
412-
self, result: Any, tool_calls: Optional[list] = None
422+
self,
423+
result: Any,
424+
tool_calls: Optional[list] = None,
425+
metadata: Optional[dict] = None,
413426
) -> AIMessage:
414427
"""Format output for BaseMessage input types."""
415428
content = self._extract_content_from_result(result)
416-
if tool_calls:
429+
430+
if metadata and isinstance(metadata, dict):
431+
metadata_copy = metadata.copy()
432+
metadata_copy.pop("content", None)
433+
if tool_calls:
434+
metadata_copy["tool_calls"] = tool_calls
435+
return AIMessage(content=content, **metadata_copy)
436+
elif tool_calls:
417437
return AIMessage(content=content, tool_calls=tool_calls)
418438
return AIMessage(content=content)
419439

@@ -437,25 +457,50 @@ def _format_dict_output_for_dict_message_list(
437457
}
438458

439459
def _format_dict_output_for_base_message_list(
440-
self, result: Any, output_key: str, tool_calls: Optional[list] = None
460+
self,
461+
result: Any,
462+
output_key: str,
463+
tool_calls: Optional[list] = None,
464+
metadata: Optional[dict] = None,
441465
) -> Dict[str, Any]:
442466
"""Format dict output when user input was a list of BaseMessage objects."""
443467
content = self._extract_content_from_result(result)
444-
if tool_calls:
468+
469+
if metadata and isinstance(metadata, dict):
470+
metadata_copy = metadata.copy()
471+
metadata_copy.pop("content", None)
472+
if tool_calls:
473+
metadata_copy["tool_calls"] = tool_calls
474+
return {output_key: AIMessage(content=content, **metadata_copy)}
475+
elif tool_calls:
445476
return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
446477
return {output_key: AIMessage(content=content)}
447478

448479
def _format_dict_output_for_base_message(
449-
self, result: Any, output_key: str, tool_calls: Optional[list] = None
480+
self,
481+
result: Any,
482+
output_key: str,
483+
tool_calls: Optional[list] = None,
484+
metadata: Optional[dict] = None,
450485
) -> Dict[str, Any]:
451486
"""Format dict output when user input was a BaseMessage."""
452487
content = self._extract_content_from_result(result)
453-
if tool_calls:
488+
489+
if metadata:
490+
metadata_copy = metadata.copy()
491+
if tool_calls:
492+
metadata_copy["tool_calls"] = tool_calls
493+
return {output_key: AIMessage(content=content, **metadata_copy)}
494+
elif tool_calls:
454495
return {output_key: AIMessage(content=content, tool_calls=tool_calls)}
455496
return {output_key: AIMessage(content=content)}
456497

457498
def _format_dict_output(
458-
self, input_dict: dict, result: Any, tool_calls: Optional[list] = None
499+
self,
500+
input_dict: dict,
501+
result: Any,
502+
tool_calls: Optional[list] = None,
503+
metadata: Optional[dict] = None,
459504
) -> Dict[str, Any]:
460505
"""Format output for dictionary input."""
461506
output_key = self.passthrough_bot_output_key
@@ -474,13 +519,13 @@ def _format_dict_output(
474519
)
475520
elif all(isinstance(msg, BaseMessage) for msg in user_input):
476521
return self._format_dict_output_for_base_message_list(
477-
result, output_key, tool_calls
522+
result, output_key, tool_calls, metadata
478523
)
479524
else:
480525
return {output_key: result}
481526
elif isinstance(user_input, BaseMessage):
482527
return self._format_dict_output_for_base_message(
483-
result, output_key, tool_calls
528+
result, output_key, tool_calls, metadata
484529
)
485530

486531
# Generic fallback for dictionaries
@@ -493,6 +538,7 @@ def _format_output(
493538
result: Any,
494539
context: Dict[str, Any],
495540
tool_calls: Optional[list] = None,
541+
metadata: Optional[dict] = None,
496542
) -> Any:
497543
"""Format the output based on the input type and rails result.
498544
@@ -515,17 +561,17 @@ def _format_output(
515561
return self._format_passthrough_output(result, context)
516562

517563
if isinstance(input, ChatPromptValue):
518-
return self._format_chat_prompt_output(result, tool_calls)
564+
return self._format_chat_prompt_output(result, tool_calls, metadata)
519565
elif isinstance(input, StringPromptValue):
520566
return self._format_string_prompt_output(result)
521567
elif isinstance(input, (HumanMessage, AIMessage, BaseMessage)):
522-
return self._format_message_output(result, tool_calls)
568+
return self._format_message_output(result, tool_calls, metadata)
523569
elif isinstance(input, list) and all(
524570
isinstance(msg, BaseMessage) for msg in input
525571
):
526-
return self._format_message_output(result, tool_calls)
572+
return self._format_message_output(result, tool_calls, metadata)
527573
elif isinstance(input, dict):
528-
return self._format_dict_output(input, result, tool_calls)
574+
return self._format_dict_output(input, result, tool_calls, metadata)
529575
elif isinstance(input, str):
530576
return self._format_string_prompt_output(result)
531577
else:
@@ -672,7 +718,9 @@ def _full_rails_invoke(
672718
result = result[0]
673719

674720
# Format and return the output based on input type
675-
return self._format_output(input, result, context, res.tool_calls)
721+
return self._format_output(
722+
input, result, context, res.tool_calls, res.llm_metadata
723+
)
676724

677725
async def ainvoke(
678726
self,
@@ -734,7 +782,9 @@ async def _full_rails_ainvoke(
734782
result = res.response
735783

736784
# Format and return the output based on input type
737-
return self._format_output(input, result, context, res.tool_calls)
785+
return self._format_output(
786+
input, result, context, res.tool_calls, res.llm_metadata
787+
)
738788

739789
def stream(
740790
self,

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from nemoguardrails.actions.llm.generation import LLMGenerationActions
3434
from nemoguardrails.actions.llm.utils import (
3535
get_and_clear_reasoning_trace_contextvar,
36+
get_and_clear_response_metadata_contextvar,
3637
get_and_clear_tool_calls_contextvar,
3738
get_colang_history,
3839
)
@@ -1086,6 +1087,7 @@ async def generate_async(
10861087
options.log.internal_events = True
10871088

10881089
tool_calls = get_and_clear_tool_calls_contextvar()
1090+
llm_metadata = get_and_clear_response_metadata_contextvar()
10891091

10901092
# If we have generation options, we prepare a GenerationResponse instance.
10911093
if options:
@@ -1106,6 +1108,9 @@ async def generate_async(
11061108
if tool_calls:
11071109
res.tool_calls = tool_calls
11081110

1111+
if llm_metadata:
1112+
res.llm_metadata = llm_metadata
1113+
11091114
if self.config.colang_version == "1.0":
11101115
# If output variables are specified, we extract their values
11111116
if options.output_vars:

nemoguardrails/rails/llm/options.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ class GenerationResponse(BaseModel):
412412
default=None,
413413
description="Tool calls extracted from the LLM response, if any.",
414414
)
415+
llm_metadata: Optional[dict] = Field(
416+
default=None,
417+
description="Metadata from the LLM response (additional_kwargs, response_metadata, usage_metadata, etc.)",
418+
)
415419

416420

417421
if __name__ == "__main__":

0 commit comments

Comments
 (0)