Skip to content

Commit 02ced19

Browse files
committed
chore(core): flatten generations for LangChainTracer
1 parent 9511665 commit 02ced19

File tree

2 files changed

+145
-3
lines changed

2 files changed

+145
-3
lines changed

libs/core/langchain_core/tracers/langchain.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
def _on_llm_end(self, run: Run) -> None:
    """Process the LLM Run.

    When ``run.outputs`` carries a ``"generations"`` entry, usage metadata
    is extracted from it and stored under
    ``run.extra["metadata"]["usage_metadata"]``. If that entry holds exactly
    one batch with exactly one generation, ``run.outputs`` is replaced by
    that single generation. The run is then passed to
    ``_update_run_single``.

    Args:
        run: The LLM run being finalized; it is modified in place.
    """
    outputs = run.outputs
    if outputs and "generations" in outputs:
        generations = outputs["generations"]

        # Extract usage_metadata from outputs and store in extra.metadata
        usage_metadata = _get_usage_metadata_from_generations(generations)
        if usage_metadata is not None:
            run.extra.setdefault("metadata", {})["usage_metadata"] = usage_metadata

        # Flatten outputs only for a 1x1 matrix: one batch, one generation.
        if len(generations) == 1:
            only_batch = generations[0]
            if len(only_batch) == 1:
                run.outputs = only_batch[0]

    self._update_run_single(run)
306309

307310
def _on_llm_error(self, run: Run) -> None:

libs/core/tests/unit_tests/tracers/test_langchain.py

Lines changed: 139 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -352,3 +352,142 @@ def capture_run(r: Run) -> None:
352352
assert "metadata" in captured_run.extra
353353
assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
354354
assert captured_run.extra["metadata"]["existing_key"] == "existing_value"
355+
356+
357+
def test_on_llm_end_flattens_single_generation() -> None:
    """Check that a 1x1 ``generations`` matrix is collapsed to its only entry."""
    mock_client = unittest.mock.MagicMock(spec=Client)
    mock_client.tracing_queue = None
    tracer = LangChainTracer(client=mock_client)

    llm_run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=llm_run_id)

    only_generation = {"text": "Hello!", "message": {"content": "Hello!"}}
    llm_run = tracer.run_map[str(llm_run_id)]
    llm_run.outputs = {"generations": [[only_generation]]}

    # Stand-in for ``_update_run_single`` that records each run it receives.
    updated_runs: list[Run] = []
    with unittest.mock.patch.object(
        tracer, "_update_run_single", updated_runs.append
    ):
        tracer._on_llm_end(llm_run)

    assert updated_runs
    # The outputs should now be just the generation object itself.
    assert updated_runs[-1].outputs == only_generation
382+
383+
384+
def test_on_llm_end_does_not_flatten_multiple_generations_in_batch() -> None:
    """Check that one batch holding two generations keeps its nested shape."""
    mock_client = unittest.mock.MagicMock(spec=Client)
    mock_client.tracing_queue = None
    tracer = LangChainTracer(client=mock_client)

    llm_run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=llm_run_id)

    first = {"text": "Hello!", "message": {"content": "Hello!"}}
    second = {"text": "Hi there!", "message": {"content": "Hi there!"}}
    llm_run = tracer.run_map[str(llm_run_id)]
    llm_run.outputs = {"generations": [[first, second]]}

    # Stand-in for ``_update_run_single`` that records each run it receives.
    updated_runs: list[Run] = []
    with unittest.mock.patch.object(
        tracer, "_update_run_single", updated_runs.append
    ):
        tracer._on_llm_end(llm_run)

    assert updated_runs
    # Not a 1x1 matrix, so the original structure must be untouched.
    assert updated_runs[-1].outputs == {"generations": [[first, second]]}
410+
411+
412+
def test_on_llm_end_does_not_flatten_multiple_batches() -> None:
    """Check that two single-generation batches keep their nested shape."""
    mock_client = unittest.mock.MagicMock(spec=Client)
    mock_client.tracing_queue = None
    tracer = LangChainTracer(client=mock_client)

    llm_run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo", "bar"], run_id=llm_run_id)

    first = {"text": "Response 1", "message": {"content": "Response 1"}}
    second = {"text": "Response 2", "message": {"content": "Response 2"}}
    llm_run = tracer.run_map[str(llm_run_id)]
    llm_run.outputs = {"generations": [[first], [second]]}

    # Stand-in for ``_update_run_single`` that records each run it receives.
    updated_runs: list[Run] = []
    with unittest.mock.patch.object(
        tracer, "_update_run_single", updated_runs.append
    ):
        tracer._on_llm_end(llm_run)

    assert updated_runs
    # More than one batch, so the original structure must be untouched.
    assert updated_runs[-1].outputs == {"generations": [[first], [second]]}
438+
439+
440+
def test_on_llm_end_does_not_flatten_multiple_batches_multiple_generations() -> None:
    """Check that a 2x2 ``generations`` matrix keeps its nested shape."""
    mock_client = unittest.mock.MagicMock(spec=Client)
    mock_client.tracing_queue = None
    tracer = LangChainTracer(client=mock_client)

    llm_run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo", "bar"], run_id=llm_run_id)

    batch_one = [
        {"text": "1a", "message": {"content": "1a"}},
        {"text": "1b", "message": {"content": "1b"}},
    ]
    batch_two = [
        {"text": "2a", "message": {"content": "2a"}},
        {"text": "2b", "message": {"content": "2b"}},
    ]
    llm_run = tracer.run_map[str(llm_run_id)]
    llm_run.outputs = {"generations": [batch_one, batch_two]}

    # Stand-in for ``_update_run_single`` that records each run it receives.
    updated_runs: list[Run] = []
    with unittest.mock.patch.object(
        tracer, "_update_run_single", updated_runs.append
    ):
        tracer._on_llm_end(llm_run)

    assert updated_runs
    # Several batches of several generations: structure must be untouched.
    assert updated_runs[-1].outputs == {
        "generations": [
            [
                {"text": "1a", "message": {"content": "1a"}},
                {"text": "1b", "message": {"content": "1b"}},
            ],
            [
                {"text": "2a", "message": {"content": "2a"}},
                {"text": "2b", "message": {"content": "2b"}},
            ],
        ]
    }
468+
469+
470+
def test_on_llm_end_handles_empty_generations() -> None:
    """Check that an empty ``generations`` list passes through unchanged."""
    mock_client = unittest.mock.MagicMock(spec=Client)
    mock_client.tracing_queue = None
    tracer = LangChainTracer(client=mock_client)

    llm_run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=llm_run_id)

    llm_run = tracer.run_map[str(llm_run_id)]
    llm_run.outputs = {"generations": []}

    # Stand-in for ``_update_run_single`` that records each run it receives.
    updated_runs: list[Run] = []
    with unittest.mock.patch.object(
        tracer, "_update_run_single", updated_runs.append
    ):
        tracer._on_llm_end(llm_run)

    assert updated_runs
    # Nothing to flatten, so the original structure must be kept.
    assert updated_runs[-1].outputs == {"generations": []}

0 commit comments

Comments
 (0)