@@ -352,3 +352,142 @@ def capture_run(r: Run) -> None:
352352 assert "metadata" in captured_run .extra
353353 assert captured_run .extra ["metadata" ]["usage_metadata" ] == usage_metadata
354354 assert captured_run .extra ["metadata" ]["existing_key" ] == "existing_value"
355+
356+
357+ def test_on_llm_end_flattens_single_generation () -> None :
358+ """Test that outputs are flattened when generations is a 1x1 matrix."""
359+ client = unittest .mock .MagicMock (spec = Client )
360+ client .tracing_queue = None
361+ tracer = LangChainTracer (client = client )
362+
363+ run_id = UUID ("9d878ab3-e5ca-4218-aef6-44cbdc90160a" )
364+ tracer .on_llm_start ({"name" : "test_llm" }, ["foo" ], run_id = run_id )
365+
366+ run = tracer .run_map [str (run_id )]
367+ generation = {"text" : "Hello!" , "message" : {"content" : "Hello!" }}
368+ run .outputs = {"generations" : [[generation ]]}
369+
370+ captured_run = None
371+
372+ def capture_run (r : Run ) -> None :
373+ nonlocal captured_run
374+ captured_run = r
375+
376+ with unittest .mock .patch .object (tracer , "_update_run_single" , capture_run ):
377+ tracer ._on_llm_end (run )
378+
379+ assert captured_run is not None
380+ # Should be flattened to just the generation object
381+ assert captured_run .outputs == generation
382+
383+
384+ def test_on_llm_end_does_not_flatten_multiple_generations_in_batch () -> None :
385+ """Test that outputs are not flattened when there are multiple generations."""
386+ client = unittest .mock .MagicMock (spec = Client )
387+ client .tracing_queue = None
388+ tracer = LangChainTracer (client = client )
389+
390+ run_id = UUID ("9d878ab3-e5ca-4218-aef6-44cbdc90160a" )
391+ tracer .on_llm_start ({"name" : "test_llm" }, ["foo" ], run_id = run_id )
392+
393+ run = tracer .run_map [str (run_id )]
394+ generation1 = {"text" : "Hello!" , "message" : {"content" : "Hello!" }}
395+ generation2 = {"text" : "Hi there!" , "message" : {"content" : "Hi there!" }}
396+ run .outputs = {"generations" : [[generation1 , generation2 ]]}
397+
398+ captured_run = None
399+
400+ def capture_run (r : Run ) -> None :
401+ nonlocal captured_run
402+ captured_run = r
403+
404+ with unittest .mock .patch .object (tracer , "_update_run_single" , capture_run ):
405+ tracer ._on_llm_end (run )
406+
407+ assert captured_run is not None
408+ # Should NOT be flattened - keep original structure
409+ assert captured_run .outputs == {"generations" : [[generation1 , generation2 ]]}
410+
411+
412+ def test_on_llm_end_does_not_flatten_multiple_batches () -> None :
413+ """Test that outputs are not flattened when there are multiple batches."""
414+ client = unittest .mock .MagicMock (spec = Client )
415+ client .tracing_queue = None
416+ tracer = LangChainTracer (client = client )
417+
418+ run_id = UUID ("9d878ab3-e5ca-4218-aef6-44cbdc90160a" )
419+ tracer .on_llm_start ({"name" : "test_llm" }, ["foo" , "bar" ], run_id = run_id )
420+
421+ run = tracer .run_map [str (run_id )]
422+ generation1 = {"text" : "Response 1" , "message" : {"content" : "Response 1" }}
423+ generation2 = {"text" : "Response 2" , "message" : {"content" : "Response 2" }}
424+ run .outputs = {"generations" : [[generation1 ], [generation2 ]]}
425+
426+ captured_run = None
427+
428+ def capture_run (r : Run ) -> None :
429+ nonlocal captured_run
430+ captured_run = r
431+
432+ with unittest .mock .patch .object (tracer , "_update_run_single" , capture_run ):
433+ tracer ._on_llm_end (run )
434+
435+ assert captured_run is not None
436+ # Should NOT be flattened - keep original structure
437+ assert captured_run .outputs == {"generations" : [[generation1 ], [generation2 ]]}
438+
439+
440+ def test_on_llm_end_does_not_flatten_multiple_batches_multiple_generations () -> None :
441+ """Test outputs not flattened with multiple batches and multiple generations."""
442+ client = unittest .mock .MagicMock (spec = Client )
443+ client .tracing_queue = None
444+ tracer = LangChainTracer (client = client )
445+
446+ run_id = UUID ("9d878ab3-e5ca-4218-aef6-44cbdc90160a" )
447+ tracer .on_llm_start ({"name" : "test_llm" }, ["foo" , "bar" ], run_id = run_id )
448+
449+ run = tracer .run_map [str (run_id )]
450+ gen1a = {"text" : "1a" , "message" : {"content" : "1a" }}
451+ gen1b = {"text" : "1b" , "message" : {"content" : "1b" }}
452+ gen2a = {"text" : "2a" , "message" : {"content" : "2a" }}
453+ gen2b = {"text" : "2b" , "message" : {"content" : "2b" }}
454+ run .outputs = {"generations" : [[gen1a , gen1b ], [gen2a , gen2b ]]}
455+
456+ captured_run = None
457+
458+ def capture_run (r : Run ) -> None :
459+ nonlocal captured_run
460+ captured_run = r
461+
462+ with unittest .mock .patch .object (tracer , "_update_run_single" , capture_run ):
463+ tracer ._on_llm_end (run )
464+
465+ assert captured_run is not None
466+ # Should NOT be flattened - keep original structure
467+ assert captured_run .outputs == {"generations" : [[gen1a , gen1b ], [gen2a , gen2b ]]}
468+
469+
470+ def test_on_llm_end_handles_empty_generations () -> None :
471+ """Test that empty generations are handled without error."""
472+ client = unittest .mock .MagicMock (spec = Client )
473+ client .tracing_queue = None
474+ tracer = LangChainTracer (client = client )
475+
476+ run_id = UUID ("9d878ab3-e5ca-4218-aef6-44cbdc90160a" )
477+ tracer .on_llm_start ({"name" : "test_llm" }, ["foo" ], run_id = run_id )
478+
479+ run = tracer .run_map [str (run_id )]
480+ run .outputs = {"generations" : []}
481+
482+ captured_run = None
483+
484+ def capture_run (r : Run ) -> None :
485+ nonlocal captured_run
486+ captured_run = r
487+
488+ with unittest .mock .patch .object (tracer , "_update_run_single" , capture_run ):
489+ tracer ._on_llm_end (run )
490+
491+ assert captured_run is not None
492+ # Should keep original structure when empty
493+ assert captured_run .outputs == {"generations" : []}
0 commit comments