diff --git a/litellm/integrations/datadog/datadog_llm_obs.py b/litellm/integrations/datadog/datadog_llm_obs.py index 200f2f283dea..adab0ca6593d 100644 --- a/litellm/integrations/datadog/datadog_llm_obs.py +++ b/litellm/integrations/datadog/datadog_llm_obs.py @@ -64,7 +64,7 @@ def __init__(self, **kwargs): asyncio.create_task(self.periodic_flush()) self.flush_lock = asyncio.Lock() self.log_queue: List[LLMObsPayload] = [] - + ######################################################### # Handle datadog_llm_observability_params set as litellm.datadog_llm_observability_params ######################################################### @@ -83,22 +83,25 @@ def _get_datadog_llm_obs_params(self) -> Dict: """ dict_datadog_llm_obs_params: Dict = {} if litellm.datadog_llm_observability_params is not None: - if isinstance(litellm.datadog_llm_observability_params, DatadogLLMObsInitParams): - dict_datadog_llm_obs_params = litellm.datadog_llm_observability_params.model_dump() + if isinstance( + litellm.datadog_llm_observability_params, DatadogLLMObsInitParams + ): + dict_datadog_llm_obs_params = ( + litellm.datadog_llm_observability_params.model_dump() + ) elif isinstance(litellm.datadog_llm_observability_params, Dict): # only allow params that are of DatadogLLMObsInitParams - dict_datadog_llm_obs_params = DatadogLLMObsInitParams(**litellm.datadog_llm_observability_params).model_dump() + dict_datadog_llm_obs_params = DatadogLLMObsInitParams( + **litellm.datadog_llm_observability_params + ).model_dump() return dict_datadog_llm_obs_params - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): try: verbose_logger.debug( f"DataDogLLMObs: Logging success event for model {kwargs.get('model', 'unknown')}" ) - payload = self.create_llm_obs_payload( - kwargs, start_time, end_time - ) + payload = self.create_llm_obs_payload(kwargs, start_time, end_time) verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}") self.log_queue.append(payload) @@ -108,15 +111,13 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti verbose_logger.exception( f"DataDogLLMObs: Error logging success event - {str(e)}" ) - + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: verbose_logger.debug( f"DataDogLLMObs: Logging failure event for model {kwargs.get('model', 'unknown')}" ) - payload = self.create_llm_obs_payload( - kwargs, start_time, end_time - ) + payload = self.create_llm_obs_payload(kwargs, start_time, end_time) verbose_logger.debug(f"DataDogLLMObs: Payload: {payload}") self.log_queue.append(payload) @@ -184,7 +185,6 @@ def create_llm_obs_payload( messages = standard_logging_payload["messages"] messages = self._ensure_string_content(messages=messages) - response_obj = standard_logging_payload.get("response") metadata = kwargs.get("litellm_params", {}).get("metadata", {}) @@ -193,10 +193,12 @@ def create_llm_obs_payload( messages ) ) - output_meta = OutputMeta(messages=self._get_response_messages( - response_obj=response_obj, - call_type=standard_logging_payload.get("call_type") - )) + output_meta = OutputMeta( + messages=self._get_response_messages( + standard_logging_payload=standard_logging_payload, + call_type=standard_logging_payload.get("call_type"), + ) + ) error_info = self._assemble_error_info(standard_logging_payload) @@ -214,7 +216,9 @@ def create_llm_obs_payload( output_tokens=float(standard_logging_payload.get("completion_tokens", 0)), total_tokens=float(standard_logging_payload.get("total_tokens", 0)), total_cost=float(standard_logging_payload.get("response_cost", 0)), - time_to_first_token=self._get_time_to_first_token_seconds(standard_logging_payload), + time_to_first_token=self._get_time_to_first_token_seconds( + standard_logging_payload + ), ) payload: LLMObsPayload = LLMObsPayload( @@ -251,27 +255,35 @@ def _get_apm_trace_id(self) -> Optional[str]: except Exception: pass return None - - def _assemble_error_info(self, standard_logging_payload: StandardLoggingPayload) -> Optional[DDLLMObsError]: + + def _assemble_error_info( + self, standard_logging_payload: StandardLoggingPayload + ) -> Optional[DDLLMObsError]: """ Assemble error information for failure cases according to DD LLM Obs API spec """ # Handle error information for failure cases according to DD LLM Obs API spec error_info: Optional[DDLLMObsError] = None - + if standard_logging_payload.get("status") == "failure": # Try to get structured error information first - error_information: Optional[StandardLoggingPayloadErrorInformation] = standard_logging_payload.get("error_information") - + error_information: Optional[ + StandardLoggingPayloadErrorInformation + ] = standard_logging_payload.get("error_information") + if error_information: error_info = DDLLMObsError( - message=error_information.get("error_message") or standard_logging_payload.get("error_str") or "Unknown error", + message=error_information.get("error_message") + or standard_logging_payload.get("error_str") + or "Unknown error", type=error_information.get("error_class"), - stack=error_information.get("traceback") + stack=error_information.get("traceback"), ) return error_info - def _get_time_to_first_token_seconds(self, standard_logging_payload: StandardLoggingPayload) -> float: + def _get_time_to_first_token_seconds( + self, standard_logging_payload: StandardLoggingPayload + ) -> float: """ Get the time to first token in seconds @@ -280,7 +292,9 @@ def _get_time_to_first_token_seconds(self, standard_logging_payload: StandardLog For non streaming calls, CompletionStartTime is time we get the response back """ start_time: Optional[float] = standard_logging_payload.get("startTime") - completion_start_time: Optional[float] = standard_logging_payload.get("completionStartTime") + completion_start_time: Optional[float] = standard_logging_payload.get( + "completionStartTime" + ) end_time: Optional[float] = standard_logging_payload.get("endTime") if completion_start_time is not None and start_time is not None: @@ -290,19 +304,43 @@ def _get_time_to_first_token_seconds(self, standard_logging_payload: StandardLog else: return 0.0 - def _get_response_messages( - self, response_obj: Any, call_type: Optional[str] + self, standard_logging_payload: StandardLoggingPayload, call_type: Optional[str] ) -> List[Any]: """ Get the messages from the response object for now this handles logging /chat/completions responses """ + + response_obj = standard_logging_payload.get("response") if response_obj is None: return [] - - if call_type in [CallTypes.completion.value, CallTypes.acompletion.value]: + + # edge case: handle response_obj is a string representation of a dict + if isinstance(response_obj, str): + try: + import ast + + response_obj = ast.literal_eval(response_obj) + except (ValueError, SyntaxError): + try: + # fallback to json parsing + response_obj = json.loads(str(response_obj)) + except json.JSONDecodeError: + return [] + + if call_type in [ + CallTypes.completion.value, + CallTypes.acompletion.value, + CallTypes.text_completion.value, + CallTypes.atext_completion.value, + CallTypes.generate_content.value, + CallTypes.agenerate_content.value, + CallTypes.generate_content_stream.value, + CallTypes.agenerate_content_stream.value, + CallTypes.anthropic_messages.value, + ]: try: # Safely extract message from response_obj, handle failure cases if isinstance(response_obj, dict) and "choices" in response_obj: @@ -315,102 +353,104 @@ def _get_response_messages( return [] return [] - def _get_datadog_span_kind(self, call_type: Optional[str]) -> Literal["llm", "tool", "task", "embedding", "retrieval"]: + def _get_datadog_span_kind( + self, call_type: Optional[str] + ) -> Literal["llm", "tool", "task", "embedding", "retrieval"]: """ Map liteLLM call_type to appropriate DataDog LLM Observability span kind. - + Available DataDog span kinds: "llm", "tool", "task", "embedding", "retrieval" """ if call_type is None: return "llm" - + # Embedding operations if call_type in [CallTypes.embedding.value, CallTypes.aembedding.value]: return "embedding" - - # LLM completion operations + + # LLM completion operations if call_type in [ - CallTypes.completion.value, + CallTypes.completion.value, CallTypes.acompletion.value, - CallTypes.text_completion.value, + CallTypes.text_completion.value, CallTypes.atext_completion.value, - CallTypes.generate_content.value, + CallTypes.generate_content.value, CallTypes.agenerate_content.value, - CallTypes.generate_content_stream.value, + CallTypes.generate_content_stream.value, CallTypes.agenerate_content_stream.value, - CallTypes.anthropic_messages.value + CallTypes.anthropic_messages.value, ]: return "llm" - + # Tool operations if call_type in [CallTypes.call_mcp_tool.value]: return "tool" - + # Retrieval operations if call_type in [ - CallTypes.get_assistants.value, + CallTypes.get_assistants.value, CallTypes.aget_assistants.value, - CallTypes.get_thread.value, + CallTypes.get_thread.value, CallTypes.aget_thread.value, - CallTypes.get_messages.value, + CallTypes.get_messages.value, CallTypes.aget_messages.value, - CallTypes.afile_retrieve.value, + CallTypes.afile_retrieve.value, CallTypes.file_retrieve.value, - CallTypes.afile_list.value, + CallTypes.afile_list.value, CallTypes.file_list.value, - CallTypes.afile_content.value, + CallTypes.afile_content.value, CallTypes.file_content.value, - CallTypes.retrieve_batch.value, + CallTypes.retrieve_batch.value, CallTypes.aretrieve_batch.value, - CallTypes.retrieve_fine_tuning_job.value, + CallTypes.retrieve_fine_tuning_job.value, CallTypes.aretrieve_fine_tuning_job.value, - CallTypes.responses.value, + CallTypes.responses.value, CallTypes.aresponses.value, - CallTypes.alist_input_items.value + CallTypes.alist_input_items.value, ]: return "retrieval" - + # Task operations (batch, fine-tuning, file operations, etc.) if call_type in [ - CallTypes.create_batch.value, + CallTypes.create_batch.value, CallTypes.acreate_batch.value, - CallTypes.create_fine_tuning_job.value, + CallTypes.create_fine_tuning_job.value, CallTypes.acreate_fine_tuning_job.value, - CallTypes.cancel_fine_tuning_job.value, + CallTypes.cancel_fine_tuning_job.value, CallTypes.acancel_fine_tuning_job.value, - CallTypes.list_fine_tuning_jobs.value, + CallTypes.list_fine_tuning_jobs.value, CallTypes.alist_fine_tuning_jobs.value, - CallTypes.create_assistants.value, + CallTypes.create_assistants.value, CallTypes.acreate_assistants.value, - CallTypes.delete_assistant.value, + CallTypes.delete_assistant.value, CallTypes.adelete_assistant.value, - CallTypes.create_thread.value, + CallTypes.create_thread.value, CallTypes.acreate_thread.value, - CallTypes.add_message.value, + CallTypes.add_message.value, CallTypes.a_add_message.value, - CallTypes.run_thread.value, + CallTypes.run_thread.value, CallTypes.arun_thread.value, - CallTypes.run_thread_stream.value, + CallTypes.run_thread_stream.value, CallTypes.arun_thread_stream.value, - CallTypes.file_delete.value, + CallTypes.file_delete.value, CallTypes.afile_delete.value, - CallTypes.create_file.value, + CallTypes.create_file.value, CallTypes.acreate_file.value, - CallTypes.image_generation.value, + CallTypes.image_generation.value, CallTypes.aimage_generation.value, - CallTypes.image_edit.value, + CallTypes.image_edit.value, CallTypes.aimage_edit.value, - CallTypes.moderation.value, + CallTypes.moderation.value, CallTypes.amoderation.value, - CallTypes.transcription.value, + CallTypes.transcription.value, CallTypes.atranscription.value, - CallTypes.speech.value, + CallTypes.speech.value, CallTypes.aspeech.value, - CallTypes.rerank.value, - CallTypes.arerank.value + CallTypes.rerank.value, + CallTypes.arerank.value, ]: return "task" - + # Default fallback for unknown or passthrough operations return "llm" @@ -443,7 +483,9 @@ def _get_dd_llm_obs_payload_metadata( "cache_hit": standard_logging_payload.get("cache_hit", "unknown"), "cache_key": standard_logging_payload.get("cache_key", "unknown"), "saved_cache_cost": standard_logging_payload.get("saved_cache_cost", 0), - "guardrail_information": standard_logging_payload.get("guardrail_information", None), + "guardrail_information": standard_logging_payload.get( + "guardrail_information", None + ), } ######################################################### @@ -452,22 +494,32 @@ def _get_dd_llm_obs_payload_metadata( latency_metrics = self._get_latency_metrics(standard_logging_payload) _metadata.update({"latency_metrics": dict(latency_metrics)}) + ## extract tool calls and add to metadata + tool_call_metadata = self._extract_tool_call_metadata(standard_logging_payload) + _metadata.update(tool_call_metadata) + _standard_logging_metadata: dict = ( dict(standard_logging_payload.get("metadata", {})) or {} ) _metadata.update(_standard_logging_metadata) return _metadata - def _get_latency_metrics(self, standard_logging_payload: StandardLoggingPayload) -> DDLLMObsLatencyMetrics: + def _get_latency_metrics( + self, standard_logging_payload: StandardLoggingPayload + ) -> DDLLMObsLatencyMetrics: """ Get the latency metrics from the standard logging payload """ latency_metrics: DDLLMObsLatencyMetrics = DDLLMObsLatencyMetrics() # Add latency metrics to metadata # Time to first token (convert from seconds to milliseconds for consistency) - time_to_first_token_seconds = self._get_time_to_first_token_seconds(standard_logging_payload) + time_to_first_token_seconds = self._get_time_to_first_token_seconds( + standard_logging_payload + ) if time_to_first_token_seconds > 0: - latency_metrics["time_to_first_token_ms"] = time_to_first_token_seconds * 1000 + latency_metrics["time_to_first_token_ms"] = ( + time_to_first_token_seconds * 1000 + ) # LiteLLM overhead time hidden_params = standard_logging_payload.get("hidden_params", {}) @@ -476,11 +528,143 @@ def _get_latency_metrics(self, standard_logging_payload: StandardLoggingPayload) latency_metrics["litellm_overhead_time_ms"] = litellm_overhead_ms # Guardrail overhead latency - guardrail_info: Optional[StandardLoggingGuardrailInformation] = standard_logging_payload.get("guardrail_information") + guardrail_info: Optional[ + StandardLoggingGuardrailInformation + ] = standard_logging_payload.get("guardrail_information") if guardrail_info is not None: - _guardrail_duration_seconds: Optional[float] = guardrail_info.get("duration") + _guardrail_duration_seconds: Optional[float] = guardrail_info.get( + "duration" + ) if _guardrail_duration_seconds is not None: # Convert from seconds to milliseconds for consistency - latency_metrics["guardrail_overhead_time_ms"] = _guardrail_duration_seconds * 1000 - - return latency_metrics \ No newline at end of file + latency_metrics["guardrail_overhead_time_ms"] = ( + _guardrail_duration_seconds * 1000 + ) + + return latency_metrics + + def _process_input_messages_preserving_tool_calls( + self, messages: List[Any] + ) -> List[Dict[str, Any]]: + """ + Process input messages while preserving tool_calls and tool message types. + + This bypasses the lossy string conversion when tool calls are present, + allowing complex nested tool_calls objects to be preserved for Datadog. + """ + processed = [] + for msg in messages: + if isinstance(msg, dict): + # Preserve messages with tool_calls or tool role as-is + if "tool_calls" in msg or msg.get("role") == "tool": + processed.append(msg) + else: + # For regular messages, still apply string conversion + converted = ( + handle_any_messages_to_chat_completion_str_messages_conversion( + [msg] + ) + ) + processed.extend(converted) + else: + # For non-dict messages, apply string conversion + converted = ( + handle_any_messages_to_chat_completion_str_messages_conversion( + [msg] + ) + ) + processed.extend(converted) + return processed + + @staticmethod + def _tool_calls_kv_pair(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Extract tool call information into key-value pairs for Datadog metadata. + + Similar to OpenTelemetry's implementation but adapted for Datadog's format. + """ + kv_pairs: Dict[str, Any] = {} + for idx, tool_call in enumerate(tool_calls): + try: + # Extract tool call ID + tool_id = tool_call.get("id") + if tool_id: + kv_pairs[f"tool_calls.{idx}.id"] = tool_id + + # Extract tool call type + tool_type = tool_call.get("type") + if tool_type: + kv_pairs[f"tool_calls.{idx}.type"] = tool_type + + # Extract function information + function = tool_call.get("function") + if function: + function_name = function.get("name") + if function_name: + kv_pairs[f"tool_calls.{idx}.function.name"] = function_name + + function_arguments = function.get("arguments") + if function_arguments: + # Store arguments as JSON string for Datadog + if isinstance(function_arguments, str): + kv_pairs[ + f"tool_calls.{idx}.function.arguments" + ] = function_arguments + else: + import json + + kv_pairs[ + f"tool_calls.{idx}.function.arguments" + ] = json.dumps(function_arguments) + except (KeyError, TypeError, ValueError) as e: + verbose_logger.debug( + f"DataDogLLMObs: Error processing tool call {idx}: {str(e)}" + ) + continue + + return kv_pairs + + def _extract_tool_call_metadata( + self, standard_logging_payload: StandardLoggingPayload + ) -> Dict[str, Any]: + """ + Extract tool call information from both input messages and response for Datadog metadata. + """ + tool_call_metadata: Dict[str, Any] = {} + + try: + # Extract tool calls from input messages + messages = standard_logging_payload.get("messages", []) + if messages and isinstance(messages, list): + for message in messages: + if isinstance(message, dict) and "tool_calls" in message: + tool_calls = message.get("tool_calls") + if tool_calls: + input_tool_calls_kv = self._tool_calls_kv_pair(tool_calls) + # Prefix with "input_" to distinguish from response tool calls + for key, value in input_tool_calls_kv.items(): + tool_call_metadata[f"input_{key}"] = value + + # Extract tool calls from response + response_obj = standard_logging_payload.get("response") + if response_obj and isinstance(response_obj, dict): + choices = response_obj.get("choices", []) + for choice in choices: + if isinstance(choice, dict): + message = choice.get("message") + if message and isinstance(message, dict): + tool_calls = message.get("tool_calls") + if tool_calls: + response_tool_calls_kv = self._tool_calls_kv_pair( + tool_calls + ) + # Prefix with "output_" to distinguish from input tool calls + for key, value in response_tool_calls_kv.items(): + tool_call_metadata[f"output_{key}"] = value + + except Exception as e: + verbose_logger.debug( + f"DataDogLLMObs: Error extracting tool call metadata: {str(e)}" + ) + + return tool_call_metadata diff --git a/litellm/types/integrations/datadog_llm_obs.py b/litellm/types/integrations/datadog_llm_obs.py index 75c55bcc93cd..41489ace30f5 100644 --- a/litellm/types/integrations/datadog_llm_obs.py +++ b/litellm/types/integrations/datadog_llm_obs.py @@ -10,7 +10,7 @@ class InputMeta(TypedDict): messages: List[ - Dict[str, str] + Dict[str, Any] # changed to fit with tool calls ] # Relevant Issue: https://github.com/BerriAI/litellm/issues/9494 @@ -20,6 +20,7 @@ class OutputMeta(TypedDict): class DDLLMObsError(TypedDict, total=False): """Error information on the span according to DD LLM Obs API spec""" + message: str # The error message stack: Optional[str] # The stack trace type: Optional[str] # The error type @@ -54,7 +55,7 @@ class LLMObsPayload(TypedDict, total=False): duration: int metrics: LLMMetrics tags: List - status: Literal["ok", "error"] # Error status ("ok" or "error"). Defaults to "ok". + status: Literal["ok", "error"] # Error status ("ok" or "error"). Defaults to "ok". class DDSpanAttributes(TypedDict): @@ -72,10 +73,11 @@ class DatadogLLMObsInitParams(StandardCustomLoggerInitParams): """ Params for initializing a DatadogLLMObs logger on litellm """ + pass class DDLLMObsLatencyMetrics(TypedDict, total=False): time_to_first_token_ms: float litellm_overhead_time_ms: float - guardrail_overhead_time_ms: float \ No newline at end of file + guardrail_overhead_time_ms: float diff --git a/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py b/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py index b1ce08de9e77..39b8427fcf12 100644 --- a/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py +++ b/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py @@ -1,11 +1,9 @@ import asyncio -import json import os import sys -import uuid from datetime import datetime, timedelta -from typing import Dict, Optional -from unittest.mock import MagicMock, Mock, patch +from typing import Optional +from unittest.mock import Mock, patch, MagicMock import pytest @@ -16,8 +14,6 @@ from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger from litellm.types.integrations.datadog_llm_obs import ( DatadogLLMObsInitParams, - LLMMetrics, - LLMObsPayload, ) from litellm.types.utils import ( StandardLoggingGuardrailInformation, @@ -129,7 +125,7 @@ def create_standard_logging_payload_with_failure() -> StandardLoggingPayload: error_class="RateLimitError", llm_provider="openai", traceback="Traceback (most recent call last):\n File test.py, line 1\n RateLimitError: You exceeded your current quota", - error_message="RateLimitError: You exceeded your current quota" + error_message="RateLimitError: You exceeded your current quota", ), model_parameters={"stream": False}, hidden_params=StandardLoggingHiddenParams( @@ -150,100 +146,77 @@ class TestDataDogLLMObsLogger: @pytest.fixture def mock_env_vars(self): """Mock environment variables for DataDog""" - with patch.dict(os.environ, { - "DD_API_KEY": "test_api_key", - "DD_SITE": "us5.datadoghq.com" - }): + with patch.dict( + os.environ, {"DD_API_KEY": "test_api_key", "DD_SITE": "us5.datadoghq.com"} + ): yield @pytest.fixture def mock_response_obj(self): """Create a mock response object""" mock_response = Mock() - mock_response.__getitem__ = Mock(return_value={ - "choices": [{"message": Mock(json=Mock(return_value={"role": "assistant", "content": "Hello!"}))}] - }) + mock_response.__getitem__ = Mock( + return_value={ + "choices": [ + { + "message": Mock( + json=Mock( + return_value={"role": "assistant", "content": "Hello!"} + ) + ) + } + ] + } + ) return mock_response def test_cost_and_trace_id_integration(self, mock_env_vars, mock_response_obj): """Test that total_cost is passed and trace_id from standard payload is used""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + standard_payload = create_standard_logging_payload_with_cache() - + kwargs = { "standard_logging_object": standard_payload, - "litellm_params": {"metadata": {"trace_id": "old-trace-id-should-be-ignored"}} + "litellm_params": { + "metadata": {"trace_id": "old-trace-id-should-be-ignored"} + }, } - + start_time = datetime.now() end_time = datetime.now() - + payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) - + # Test 1: Verify total_cost is correctly extracted from response_cost assert payload["metrics"].get("total_cost") == 0.05 - + # Test 2: Verify trace_id comes from standard_logging_payload, not metadata assert payload["trace_id"] == "test-trace-id-123" - - # Test 3: Verify saved_cache_cost is in metadata + + # Test 3: Verify saved_cache_cost is in metadata metadata = payload["meta"]["metadata"] assert metadata["saved_cache_cost"] == 0.02 - assert metadata["cache_hit"] == True + assert metadata["cache_hit"] is True assert metadata["cache_key"] == "test-cache-key-789" - def test_apm_id_included(self, mock_env_vars, mock_response_obj): - """Test that the current APM trace ID is attached to the payload""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): - fake_tracer = MagicMock() - fake_span = MagicMock() - fake_span.trace_id = 987654321 - fake_tracer.current_span.return_value = fake_span - - with patch('litellm.integrations.datadog.datadog_llm_obs.tracer', fake_tracer): - logger = DataDogLLMObsLogger() - - standard_payload = create_standard_logging_payload_with_cache() - - kwargs = { - "standard_logging_object": standard_payload, - "litellm_params": {"metadata": {}} - } - - start_time = datetime.now() - end_time = datetime.now() - - payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) - - assert payload["apm_id"] == str(fake_span.trace_id) - def test_cache_metadata_fields(self, mock_env_vars, mock_response_obj): """Test that cache-related metadata fields are correctly tracked""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + standard_payload = create_standard_logging_payload_with_cache() - - kwargs = { - "standard_logging_object": standard_payload, - "litellm_params": {"metadata": {}} - } - - start_time = datetime.now() - end_time = datetime.now() - - payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) - + # Test the _get_dd_llm_obs_payload_metadata method directly metadata = logger._get_dd_llm_obs_payload_metadata(standard_payload) - + # Verify all cache-related fields are present - assert metadata["cache_hit"] == True + assert metadata["cache_hit"] is True assert metadata["cache_key"] == "test-cache-key-789" assert metadata["saved_cache_cost"] == 0.02 assert metadata["id"] == "test-request-id-456" @@ -253,55 +226,66 @@ def test_cache_metadata_fields(self, mock_env_vars, mock_response_obj): def test_get_time_to_first_token_seconds(self, mock_env_vars): """Test the _get_time_to_first_token_seconds method for streaming calls""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + # Test streaming case (completion_start_time available) streaming_payload = create_standard_logging_payload_with_cache() # Modify times for testing: start=1000, completion_start=1002, end=1005 streaming_payload["startTime"] = 1000.0 streaming_payload["completionStartTime"] = 1002.0 streaming_payload["endTime"] = 1005.0 - + # Test streaming case: should use completion_start_time - start_time - time_to_first_token = logger._get_time_to_first_token_seconds(streaming_payload) + time_to_first_token = logger._get_time_to_first_token_seconds( + streaming_payload + ) assert time_to_first_token == 2.0 # 1002.0 - 1000.0 = 2.0 seconds - def test_datadog_span_kind_mapping(self, mock_env_vars): """Test that call_type values are correctly mapped to DataDog span kinds""" from litellm.types.utils import CallTypes - - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + # Test embedding operations assert logger._get_datadog_span_kind(CallTypes.embedding.value) == "embedding" assert logger._get_datadog_span_kind(CallTypes.aembedding.value) == "embedding" - + # Test LLM completion operations assert logger._get_datadog_span_kind(CallTypes.completion.value) == "llm" assert logger._get_datadog_span_kind(CallTypes.acompletion.value) == "llm" assert logger._get_datadog_span_kind(CallTypes.text_completion.value) == "llm" assert logger._get_datadog_span_kind(CallTypes.generate_content.value) == "llm" - assert logger._get_datadog_span_kind(CallTypes.anthropic_messages.value) == "llm" - + assert ( + logger._get_datadog_span_kind(CallTypes.anthropic_messages.value) == "llm" + ) + # Test tool operations assert logger._get_datadog_span_kind(CallTypes.call_mcp_tool.value) == "tool" - + # Test retrieval operations - assert logger._get_datadog_span_kind(CallTypes.get_assistants.value) == "retrieval" - assert logger._get_datadog_span_kind(CallTypes.file_retrieve.value) == "retrieval" - assert logger._get_datadog_span_kind(CallTypes.retrieve_batch.value) == "retrieval" - + assert ( + logger._get_datadog_span_kind(CallTypes.get_assistants.value) == "retrieval" + ) + assert ( + logger._get_datadog_span_kind(CallTypes.file_retrieve.value) == "retrieval" + ) + assert ( + logger._get_datadog_span_kind(CallTypes.retrieve_batch.value) == "retrieval" + ) + # Test task operations assert logger._get_datadog_span_kind(CallTypes.create_batch.value) == "task" assert logger._get_datadog_span_kind(CallTypes.image_generation.value) == "task" assert logger._get_datadog_span_kind(CallTypes.moderation.value) == "task" assert logger._get_datadog_span_kind(CallTypes.transcription.value) == "task" - + # Test default fallback assert logger._get_datadog_span_kind("unknown_call_type") == "llm" assert logger._get_datadog_span_kind(None) == "llm" @@ -309,68 +293,78 @@ def test_datadog_span_kind_mapping(self, mock_env_vars): @pytest.mark.asyncio async def test_async_log_failure_event(self, mock_env_vars): """Test that async_log_failure_event correctly processes failure payloads according to DD LLM Obs API spec""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + # Ensure log_queue starts empty logger.log_queue = [] - + standard_failure_payload = create_standard_logging_payload_with_failure() - + kwargs = { "standard_logging_object": standard_failure_payload, "model": "gpt-4", - "litellm_params": {"metadata": {}} + "litellm_params": {"metadata": {}}, } - + start_time = datetime.now() end_time = datetime.now() + timedelta(seconds=2) - + # Mock async_send_batch to prevent actual network calls - with patch.object(logger, 'async_send_batch') as mock_send_batch: + with patch.object(logger, "async_send_batch") as mock_send_batch: # Call the method under test await logger.async_log_failure_event(kwargs, None, start_time, end_time) - + # Verify payload was added to queue assert len(logger.log_queue) == 1 - + # Verify the payload has correct failure characteristics according to DD LLM Obs API spec payload = logger.log_queue[0] assert payload["trace_id"] == "test-trace-id-failure-456" - assert payload["meta"]["metadata"]["id"] == "test-request-id-failure-789" + assert ( + payload["meta"]["metadata"]["id"] == "test-request-id-failure-789" + ) assert payload["status"] == "error" - + # Verify error information follows DD LLM Obs API spec - assert payload["meta"]["error"]["message"] == "RateLimitError: You exceeded your current quota" + assert ( + payload["meta"]["error"]["message"] + == "RateLimitError: You exceeded your current quota" + ) assert payload["meta"]["error"]["type"] == "RateLimitError" - assert payload["meta"]["error"]["stack"] == "Traceback (most recent call last):\n File test.py, line 1\n RateLimitError: You exceeded your current quota" - + assert ( + payload["meta"]["error"]["stack"] + == "Traceback (most recent call last):\n File test.py, line 1\n RateLimitError: You exceeded your current quota" + ) + assert payload["metrics"]["total_cost"] == 0.0 assert payload["metrics"]["total_tokens"] == 0 assert payload["metrics"]["output_tokens"] == 0 - + # Verify batch sending not triggered (queue size < batch_size) mock_send_batch.assert_not_called() - class TestDataDogLLMObsLoggerForRedaction(DataDogLLMObsLogger): """Test suite for DataDog LLM Observability Logger""" + def __init__(self, **kwargs): super().__init__(**kwargs) self.logged_standard_logging_payload: Optional[StandardLoggingPayload] = None - + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): self.logged_standard_logging_payload = kwargs.get("standard_logging_object") class TestS3Logger(CustomLogger): """Test suite for S3 Logger""" + def __init__(self, **kwargs): super().__init__(**kwargs) self.logged_standard_logging_payload: Optional[StandardLoggingPayload] = None - + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): self.logged_standard_logging_payload = kwargs.get("standard_logging_object") @@ -380,26 +374,26 @@ async def test_dd_llms_obs_redaction(mock_env_vars): # init DD with turn_off_message_logging=True litellm._turn_on_debug() from litellm.types.utils import LiteLLMCommonStrings - litellm.datadog_llm_observability_params = DatadogLLMObsInitParams(turn_off_message_logging=True) + + litellm.datadog_llm_observability_params = DatadogLLMObsInitParams( + turn_off_message_logging=True + ) dd_llms_obs_logger = TestDataDogLLMObsLoggerForRedaction() test_s3_logger = TestS3Logger() - litellm.callbacks = [ - dd_llms_obs_logger, - test_s3_logger - ] + litellm.callbacks = [dd_llms_obs_logger, test_s3_logger] # call litellm await litellm.acompletion( model="gpt-4o", mock_response="Hi there!", - messages=[{"role": "user", "content": "Hello, world!"}] + messages=[{"role": "user", "content": "Hello, world!"}], ) # sleep 1 second for logging to complete await asyncio.sleep(1) ################# - # test validation + # test validation # 1. both loggers logged a standard_logging_payload # 2. DD LLM Obs standard_logging_payload has messages and response redacted # 3. S3 standard_logging_payload does not have messages and response redacted @@ -407,25 +401,37 @@ async def test_dd_llms_obs_redaction(mock_env_vars): assert dd_llms_obs_logger.logged_standard_logging_payload is not None assert test_s3_logger.logged_standard_logging_payload is not None - print("logged DD LLM Obs payload", json.dumps(dd_llms_obs_logger.logged_standard_logging_payload, indent=4, default=str)) - print("\n\nlogged S3 payload", json.dumps(test_s3_logger.logged_standard_logging_payload, indent=4, default=str)) + assert ( + dd_llms_obs_logger.logged_standard_logging_payload["messages"][0]["content"] + == LiteLLMCommonStrings.redacted_by_litellm.value + ) + assert ( + dd_llms_obs_logger.logged_standard_logging_payload["response"]["choices"][0][ + "message" + ]["content"] + == LiteLLMCommonStrings.redacted_by_litellm.value + ) + + assert test_s3_logger.logged_standard_logging_payload["messages"] == [ + {"role": "user", "content": "Hello, world!"} + ] + assert ( + test_s3_logger.logged_standard_logging_payload["response"]["choices"][0][ + "message" + ]["content"] + == "Hi there!" + ) - assert dd_llms_obs_logger.logged_standard_logging_payload["messages"][0]["content"] == LiteLLMCommonStrings.redacted_by_litellm.value - assert dd_llms_obs_logger.logged_standard_logging_payload["response"]["choices"][0]["message"]["content"] == LiteLLMCommonStrings.redacted_by_litellm.value - assert test_s3_logger.logged_standard_logging_payload["messages"] == [{"role": "user", "content": "Hello, world!"}] - assert test_s3_logger.logged_standard_logging_payload["response"]["choices"][0]["message"]["content"] == "Hi there!" - - @pytest.fixture def mock_env_vars(): """Mock environment variables for DataDog""" - with patch.dict(os.environ, { - "DD_API_KEY": "test_api_key", - "DD_SITE": "us5.datadoghq.com" - }): + with patch.dict( + os.environ, {"DD_API_KEY": "test_api_key", "DD_SITE": "us5.datadoghq.com"} + ): yield + @pytest.mark.asyncio async def test_create_llm_obs_payload(mock_env_vars): datadog_llm_obs_logger = DataDogLLMObsLogger() @@ -440,8 +446,6 @@ async def test_create_llm_obs_payload(mock_env_vars): end_time=datetime.now() + timedelta(seconds=1), ) - print("dd created payload", payload) - assert payload["name"] == "litellm_llm_call" assert payload["meta"]["kind"] == "llm" assert payload["meta"]["input"]["messages"] == [ @@ -462,9 +466,13 @@ def create_standard_logging_payload_with_latency_metrics() -> StandardLoggingPay end_time=1234567890.5, duration=0.5, # 500ms guardrail_request={"input": "test input message", "user_id": "test_user"}, - guardrail_response={"output": "filtered output", "flagged": False, "score": 0.1}, + guardrail_response={ + "output": "filtered output", + "flagged": False, + "score": 0.1, + }, ) - + hidden_params = StandardLoggingHiddenParams( model_id="model-123", cache_key="test-cache-key", @@ -473,7 +481,7 @@ def create_standard_logging_payload_with_latency_metrics() -> StandardLoggingPay litellm_overhead_time_ms=150.0, # 150ms additional_headers=None, ) - + return StandardLoggingPayload( id="test-request-id-latency", call_type="completion", @@ -525,40 +533,45 @@ def create_standard_logging_payload_with_latency_metrics() -> StandardLoggingPay def test_latency_metrics_in_metadata(mock_env_vars): """Test that time to first token, litellm overhead, and guardrail overhead are included in metadata""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + standard_payload = create_standard_logging_payload_with_latency_metrics() - + kwargs = { "standard_logging_object": standard_payload, - "litellm_params": {"metadata": {}} + "litellm_params": {"metadata": {}}, } - + start_time = datetime.now() end_time = datetime.now() - + # Test the metadata generation directly metadata = logger._get_dd_llm_obs_payload_metadata(standard_payload) latency_metadata = metadata.get("latency_metrics", {}) - + # Verify time to first token is included (800ms) assert "time_to_first_token_ms" in latency_metadata - assert abs(latency_metadata["time_to_first_token_ms"] - 800.0) < 0.001 # 0.8 seconds * 1000 with tolerance for floating-point precision - + assert ( + abs(latency_metadata["time_to_first_token_ms"] - 800.0) < 0.001 + ) # 0.8 seconds * 1000 with tolerance for floating-point precision + # Verify litellm overhead is included (150ms) assert "litellm_overhead_time_ms" in latency_metadata assert latency_metadata["litellm_overhead_time_ms"] == 150.0 - - # Verify guardrail overhead is included (500ms) + + # Verify guardrail overhead is included (500ms) assert "guardrail_overhead_time_ms" in latency_metadata - assert latency_metadata["guardrail_overhead_time_ms"] == 500.0 # 0.5 seconds * 1000 - + assert ( + latency_metadata["guardrail_overhead_time_ms"] == 500.0 + ) # 0.5 seconds * 1000 + # Verify these metrics are also included in the full payload payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) payload_metadata_latency = payload["meta"]["metadata"]["latency_metrics"] - + assert abs(payload_metadata_latency["time_to_first_token_ms"] - 800.0) < 0.001 assert payload_metadata_latency["litellm_overhead_time_ms"] == 150.0 assert payload_metadata_latency["guardrail_overhead_time_ms"] == 500.0 @@ -566,26 +579,29 @@ def test_latency_metrics_in_metadata(mock_env_vars): def test_latency_metrics_edge_cases(mock_env_vars): """Test latency metrics with edge cases (missing fields, zero values, etc.)""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + # Test case 1: No latency metrics present standard_payload = create_standard_logging_payload_with_cache() metadata = logger._get_dd_llm_obs_payload_metadata(standard_payload) - + # Should not have latency fields if data is missing/zero assert "time_to_first_token_ms" not in metadata # Will be 0, so not included - assert "litellm_overhead_time_ms" not in metadata # Not present in hidden_params + assert ( + "litellm_overhead_time_ms" not in metadata + ) # Not present in hidden_params assert "guardrail_overhead_time_ms" not in metadata # No guardrail_information - + # Test case 2: Zero time to first token should not be included standard_payload = create_standard_logging_payload_with_cache() standard_payload["startTime"] = 1000.0 standard_payload["completionStartTime"] = 1000.0 # Same time = 0 difference metadata = logger._get_dd_llm_obs_payload_metadata(standard_payload) assert "time_to_first_token_ms" not in metadata - + # Test case 3: Missing guardrail duration should not crash standard_payload = create_standard_logging_payload_with_cache() standard_payload["guardrail_information"] = StandardLoggingGuardrailInformation( @@ -599,42 +615,285 @@ def test_latency_metrics_edge_cases(mock_env_vars): def test_guardrail_information_in_metadata(mock_env_vars): """Test that guardrail_information is included in metadata with input/output fields""" - with patch('litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client'), \ - patch('asyncio.create_task'): + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): logger = DataDogLLMObsLogger() - + # Create a standard payload with guardrail information standard_payload = create_standard_logging_payload_with_latency_metrics() - + kwargs = { "standard_logging_object": standard_payload, - "litellm_params": {"metadata": {}} + "litellm_params": {"metadata": {}}, } - + start_time = datetime.now() end_time = datetime.now() - + # Create the payload and verify guardrail_information is in metadata payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) metadata = payload["meta"]["metadata"] - + # Verify guardrail_information is present in metadata assert "guardrail_information" in metadata assert metadata["guardrail_information"] is not None - + # Verify the guardrail information structure guardrail_info = metadata["guardrail_information"] assert guardrail_info["guardrail_name"] == "test_guardrail" assert guardrail_info["guardrail_status"] == "success" assert guardrail_info["duration"] == 0.5 - + # Verify input/output fields are present assert "guardrail_request" in guardrail_info assert "guardrail_response" in guardrail_info - + # Validate the input/output content assert guardrail_info["guardrail_request"]["input"] == "test input message" assert guardrail_info["guardrail_request"]["user_id"] == "test_user" assert guardrail_info["guardrail_response"]["output"] == "filtered output" - assert guardrail_info["guardrail_response"]["flagged"] == False + assert guardrail_info["guardrail_response"]["flagged"] is False assert guardrail_info["guardrail_response"]["score"] == 0.1 + + +def create_standard_logging_payload_with_tool_calls() -> StandardLoggingPayload: + """Create a StandardLoggingPayload object with tool calls for testing""" + return { + "id": "test-request-id-tool-calls", + "call_type": "completion", + "response_cost": 0.05, + "response_cost_failure_debug_info": None, + "status": "success", + "total_tokens": 50, + "prompt_tokens": 20, + "completion_tokens": 30, + "startTime": 1234567890.0, + "endTime": 1234567891.0, + "completionStartTime": 1234567890.5, + "model_map_information": {"model_map_key": "gpt-4", "model_map_value": None}, + "model": "gpt-4", + "model_id": "model-123", + "model_group": "openai-gpt", + "api_base": "https://api.openai.com", + "metadata": { + "user_api_key_hash": "test_hash", + "user_api_key_org_id": None, + "user_api_key_alias": "test_alias", + "user_api_key_team_id": "test_team", + "user_api_key_user_id": "test_user", + "user_api_key_team_alias": "test_team_alias", + "user_api_key_user_email": None, + "user_api_key_end_user_id": None, + "user_api_key_request_route": None, + "spend_logs_metadata": None, + "requester_ip_address": "127.0.0.1", + "requester_metadata": None, + "requester_custom_headers": None, + "prompt_management_metadata": None, + "mcp_tool_call_metadata": None, + "vector_store_request_metadata": None, + "applied_guardrails": None, + "usage_object": None, + "cold_storage_object_key": None, + }, + "cache_hit": False, + "cache_key": None, + "saved_cache_cost": 0.0, + "request_tags": [], + "end_user": None, + "requester_ip_address": "127.0.0.1", + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "I'll check the weather for you.", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": '{"temperature": 72, "condition": "sunny"}', + }, + ], + "response": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "It's 72°F and sunny in NYC!", + "tool_calls": [ + { + "id": "call_456", + "type": "function", + "function": { + "name": "format_response", + "arguments": '{"temp": 72, "condition": "sunny"}', + }, + } + ], + } + } + ] + }, + "error_str": None, + "model_parameters": {"temperature": 0.7}, + "hidden_params": { + "model_id": "model-123", + "cache_key": None, + "api_base": "https://api.openai.com", + "response_cost": "0.05", + "litellm_overhead_time_ms": None, + "additional_headers": None, + "batch_models": None, + "litellm_model_name": None, + "usage_object": None, + }, + "stream": None, + "response_time": 1.0, + "error_information": None, + "guardrail_information": None, + "standard_built_in_tools_params": None, + "trace_id": "test-trace-id-tool-calls", + "custom_llm_provider": "openai", + } + + +class TestDataDogLLMObsLoggerToolCalls: + """Simple test suite for DataDog LLM Observability Logger tool call handling""" + + @pytest.fixture + def mock_env_vars(self): + """Mock environment variables for DataDog""" + with patch.dict( + os.environ, {"DD_API_KEY": "test_api_key", "DD_SITE": "us5.datadoghq.com"} + ): + yield + + def test_tool_call_span_kind_mapping(self, mock_env_vars): + """Test that tool call operations are correctly mapped to 'tool' span kind""" + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): + logger = DataDogLLMObsLogger() + + # Test MCP tool call mapping + from litellm.types.utils import CallTypes + + assert ( + logger._get_datadog_span_kind(CallTypes.call_mcp_tool.value) == "tool" + ) + + def test_tool_call_payload_creation(self, mock_env_vars): + """Test that tool call payloads are created correctly""" + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): + logger = DataDogLLMObsLogger() + + standard_payload = create_standard_logging_payload_with_tool_calls() + + kwargs = { + "standard_logging_object": standard_payload, + "litellm_params": {"metadata": {}}, + } + + start_time = datetime.now() + end_time = datetime.now() + + payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) + + # Verify basic payload structure + assert payload.get("name") == "litellm_llm_call" + assert payload.get("status") == "ok" + assert ( + payload.get("meta", {}).get("kind") == "llm" + ) # Regular completion, not tool call + + # Verify metrics + metrics = payload.get("metrics", {}) + assert metrics.get("input_tokens") == 20 + assert metrics.get("output_tokens") == 30 + assert metrics.get("total_tokens") == 50 + + def test_tool_call_messages_preserved(self, mock_env_vars): + """Test that tool call messages are preserved in the payload""" + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): + logger = DataDogLLMObsLogger() + + standard_payload = create_standard_logging_payload_with_tool_calls() + + kwargs = { + "standard_logging_object": standard_payload, + "litellm_params": {"metadata": {}}, + } + + start_time = datetime.now() + end_time = datetime.now() + + payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) + + # Verify input messages include tool calls + meta = payload.get("meta", {}) + input_meta = meta.get("input", {}) + input_messages = input_meta.get("messages", []) + assert len(input_messages) == 3 + + # Check assistant message has tool calls + assistant_msg = input_messages[1] + assert assistant_msg.get("role") == "assistant" + assert "tool_calls" in assistant_msg + tool_calls = assistant_msg.get("tool_calls", []) + assert len(tool_calls) == 1 + tool_call = tool_calls[0] + function_info = tool_call.get("function", {}) + assert function_info.get("name") == "get_weather" + + # Check tool message + tool_msg = input_messages[2] + assert tool_msg.get("role") == "tool" + assert tool_msg.get("tool_call_id") == "call_123" + + def test_tool_call_response_handling(self, mock_env_vars): + """Test that tool calls in response are handled correctly""" + with patch( + "litellm.integrations.datadog.datadog_llm_obs.get_async_httpx_client" + ), patch("asyncio.create_task"): + logger = DataDogLLMObsLogger() + + standard_payload = create_standard_logging_payload_with_tool_calls() + + kwargs = { + "standard_logging_object": standard_payload, + "litellm_params": {"metadata": {}}, + } + + start_time = datetime.now() + end_time = datetime.now() + + payload = logger.create_llm_obs_payload(kwargs, start_time, end_time) + + # Verify output messages include tool calls + meta = payload.get("meta", {}) + output_meta = meta.get("output", {}) + output_messages = output_meta.get("messages", []) + assert len(output_messages) == 1 + + output_msg = output_messages[0] + assert output_msg.get("role") == "assistant" + assert "tool_calls" in output_msg + output_tool_calls = output_msg.get("tool_calls", []) + assert len(output_tool_calls) == 1 + output_function_info = output_tool_calls[0].get("function", {}) + assert output_function_info.get("name") == "format_response"