|
21 | 21 |
|
22 | 22 | from langchain_core.env import get_runtime_environment |
23 | 23 | from langchain_core.load import dumpd |
| 24 | +from langchain_core.messages.ai import add_usage |
24 | 25 | from langchain_core.tracers.base import BaseTracer |
25 | 26 | from langchain_core.tracers.schemas import Run |
26 | 27 |
|
@@ -72,26 +73,27 @@ def _get_executor() -> ThreadPoolExecutor: |
def _get_usage_metadata_from_generations(
    generations: list[list[dict[str, Any]]],
) -> dict[str, Any] | None:
    """Extract and aggregate usage_metadata from generations.

    Walks every generation in every batch and merges each ``usage_metadata``
    found on a message dict into a single running total via ``add_usage``.
    This metadata is typically present in chat model outputs.

    Args:
        generations: List of generation batches, where each batch is a list
            of generation dicts that may contain a "message" key with
            "usage_metadata".

    Returns:
        The aggregated usage_metadata dict if found, otherwise None.
    """
    aggregated: dict[str, Any] | None = None
    for batch in generations:
        for gen in batch:
            # Guard clauses: skip anything that isn't a dict carrying a
            # message dict with usage_metadata attached.
            if not isinstance(gen, dict):
                continue
            message = gen.get("message")
            if not isinstance(message, dict):
                continue
            if "usage_metadata" in message:
                aggregated = add_usage(aggregated, message["usage_metadata"])
    return aggregated
95 | 97 |
|
96 | 98 |
|
97 | 99 | class LangChainTracer(BaseTracer): |
|
0 commit comments