Skip to content

Commit 01f2bef

Browse files
committed
aggregate usage metadata
1 parent 4fa21f7 commit 01f2bef

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

libs/core/langchain_core/tracers/langchain.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from langchain_core.env import get_runtime_environment
2323
from langchain_core.load import dumpd
24+
from langchain_core.messages.ai import add_usage
2425
from langchain_core.tracers.base import BaseTracer
2526
from langchain_core.tracers.schemas import Run
2627

@@ -72,26 +73,27 @@ def _get_executor() -> ThreadPoolExecutor:
7273
def _get_usage_metadata_from_generations(
7374
generations: list[list[dict[str, Any]]],
7475
) -> dict[str, Any] | None:
75-
"""Extract usage_metadata from generations.
76+
"""Extract and aggregate usage_metadata from generations.
7677
77-
Iterates through generations to find and return the first usage_metadata
78-
found in a message. This is typically present in chat model outputs.
78+
Iterates through generations to find and aggregate all usage_metadata
79+
found in messages. This is typically present in chat model outputs.
7980
8081
Args:
8182
generations: List of generation batches, where each batch is a list
8283
of generation dicts that may contain a "message" key with
8384
"usage_metadata".
8485
8586
Returns:
86-
The usage_metadata dict if found, otherwise None.
87+
The aggregated usage_metadata dict if found, otherwise None.
8788
"""
89+
output: dict[str, Any] | None = None
8890
for generation_batch in generations:
8991
for generation in generation_batch:
9092
if isinstance(generation, dict) and "message" in generation:
9193
message = generation["message"]
9294
if isinstance(message, dict) and "usage_metadata" in message:
93-
return message["usage_metadata"]
94-
return None
95+
output = add_usage(output, message["usage_metadata"])
96+
return output
9597

9698

9799
class LangChainTracer(BaseTracer):

libs/core/tests/unit_tests/tracers/test_langchain.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def mock_create_run(**kwargs: Any) -> Any:
179179
# Returns None for empty generations
180180
([], None),
181181
([[]], None),
182-
# Returns first usage_metadata from multiple generations
182+
# Aggregates usage_metadata across multiple generations
183183
(
184184
[
185185
[
@@ -207,7 +207,7 @@ def mock_create_run(**kwargs: Any) -> Any:
207207
},
208208
]
209209
],
210-
{"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
210+
{"input_tokens": 55, "output_tokens": 110, "total_tokens": 165},
211211
),
212212
# Finds usage_metadata across multiple batches
213213
(
@@ -236,7 +236,7 @@ def mock_create_run(**kwargs: Any) -> Any:
236236
"returns_none_when_no_message",
237237
"returns_none_for_empty_list",
238238
"returns_none_for_empty_batch",
239-
"returns_first_from_multiple_generations",
239+
"aggregates_across_multiple_generations",
240240
"finds_across_multiple_batches",
241241
],
242242
)

0 commit comments

Comments
 (0)