diff --git a/docs/my-website/docs/proxy/logging_spec.md b/docs/my-website/docs/proxy/logging_spec.md
index 6364b8c4444f..e0281c9c9bfb 100644
--- a/docs/my-website/docs/proxy/logging_spec.md
+++ b/docs/my-website/docs/proxy/logging_spec.md
@@ -91,7 +91,7 @@ Inherits from `StandardLoggingUserAPIKeyMetadata` and adds:
 | `applied_guardrails` | `Optional[List[str]]` | List of applied guardrail names |
 | `usage_object` | `Optional[dict]` | Raw usage object from the LLM provider |
 | `cold_storage_object_key` | `Optional[str]` | S3/GCS object key for cold storage retrieval |
-| `guardrail_information` | `Optional[StandardLoggingGuardrailInformation]` | Guardrail information |
+| `guardrail_information` | `Optional[list[StandardLoggingGuardrailInformation]]` | Guardrail information |
 
 ## StandardLoggingVectorStoreRequest
 
@@ -170,7 +170,7 @@ A literal type with two possible values:
 | `guardrail_mode` | `Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]]` | Guardrail mode |
 | `guardrail_request` | `Optional[dict]` | Guardrail request |
 | `guardrail_response` | `Optional[Union[dict, str, List[dict]]]` | Guardrail response |
-| `guardrail_status` | `Literal["success", "failure", "blocked"]` | Guardrail execution status: `success` = no violations detected, `blocked` = content blocked/modified due to policy violations, `failure` = technical error or API failure |
+| `guardrail_status` | `Literal["success", "guardrail_intervened", "guardrail_failed_to_respond"]` | Guardrail execution status: `success` = no violations detected, `guardrail_intervened` = content blocked/modified due to policy violations, `guardrail_failed_to_respond` = technical error or API failure |
 | `start_time` | `Optional[float]` | Start time of the guardrail |
 | `end_time` | `Optional[float]` | End time of the guardrail |
 | `duration` | `Optional[float]` | Duration of the guardrail in seconds |
diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
index c3e1a31c3efa..b50d05ed2ec6 100644
--- a/litellm/integrations/custom_guardrail.py
+++ b/litellm/integrations/custom_guardrail.py
@@ -59,7 +59,6 @@ def __init__(
         self.mask_response_content: bool = mask_response_content
 
         if supported_event_hooks:
-            ## validate event_hook is in supported_event_hooks
             self._validate_event_hook(event_hook, supported_event_hooks)
         super().__init__(**kwargs)
 
@@ -80,7 +79,6 @@ def _validate_event_hook(
         ],
         supported_event_hooks: List[GuardrailEventHooks],
     ) -> None:
-
         def _validate_event_hook_list_is_in_supported_event_hooks(
             event_hook: Union[List[GuardrailEventHooks], List[str]],
             supported_event_hooks: List[GuardrailEventHooks],
@@ -130,15 +128,12 @@ def _guardrail_is_in_requested_guardrails(
         self,
         requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
     ) -> bool:
-
         for _guardrail in requested_guardrails:
             if isinstance(_guardrail, dict):
                 if self.guardrail_name in _guardrail:
-
                     return True
             elif isinstance(_guardrail, str):
                 if self.guardrail_name == _guardrail:
-
                     return True
         return False
 
@@ -146,7 +141,6 @@ def _guardrail_is_in_requested_guardrails(
     async def async_pre_call_deployment_hook(
         self, kwargs: Dict[str, Any], call_type: Optional[CallTypes]
     ) -> Optional[dict]:
-
         from litellm.proxy._types import UserAPIKeyAuth
 
         # should run guardrail
@@ -385,14 +379,24 @@ def add_standard_logging_guardrail_information_to_request_data(
             duration=duration,
             masked_entity_count=masked_entity_count,
         )
+
+        def _append_guardrail_info(container: dict) -> None:
+            key = "standard_logging_guardrail_information"
+            existing =
container.get(key) + if existing is None: + container[key] = [slg] + elif isinstance(existing, list): + existing.append(slg) + else: + # should not happen + container[key] = [existing, slg] + if "metadata" in request_data: if request_data["metadata"] is None: request_data["metadata"] = {} - request_data["metadata"]["standard_logging_guardrail_information"] = slg + _append_guardrail_info(request_data["metadata"]) elif "litellm_metadata" in request_data: - request_data["litellm_metadata"][ - "standard_logging_guardrail_information" - ] = slg + _append_guardrail_info(request_data["litellm_metadata"]) else: verbose_logger.warning( "unable to log guardrail information. No metadata found in request_data" @@ -497,37 +501,46 @@ def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None """ for key, value in vars(litellm_params).items(): setattr(self, key, value) - - def get_guardrails_messages_for_call_type(self, call_type: CallTypes, data: Optional[dict] = None) -> Optional[List[AllMessageValues]]: + + def get_guardrails_messages_for_call_type( + self, call_type: CallTypes, data: Optional[dict] = None + ) -> Optional[List[AllMessageValues]]: """ Returns the messages for the given call type and data """ if call_type is None or data is None: return None - + ######################################################### - # /chat/completions - # /messages + # /chat/completions + # /messages # Both endpoints store the messages in the "messages" key ######################################################### - if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value or call_type == CallTypes.anthropic_messages.value: + if ( + call_type == CallTypes.completion.value + or call_type == CallTypes.acompletion.value + or call_type == CallTypes.anthropic_messages.value + ): return data.get("messages") - + ######################################################### - # /responses + # /responses # User/System messages are stored in the "input" key, use litellm transformation to get the messages ######################################################### - if call_type == CallTypes.responses.value or call_type == CallTypes.aresponses.value: + if ( + call_type == CallTypes.responses.value + or call_type == CallTypes.aresponses.value + ): from typing import cast from litellm.responses.litellm_completion_transformation.transformation import ( LiteLLMCompletionResponsesConfig, ) - + input_data = data.get("input") if input_data is None: return None - + messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages( input=input_data, responses_api_request=data, diff --git a/litellm/integrations/datadog/datadog_llm_obs.py b/litellm/integrations/datadog/datadog_llm_obs.py index fc3cf4b9ff27..b44762d0af88 100644 --- a/litellm/integrations/datadog/datadog_llm_obs.py +++ b/litellm/integrations/datadog/datadog_llm_obs.py @@ -498,7 +498,9 @@ def _get_dd_llm_obs_payload_metadata( "guardrail_information": standard_logging_payload.get( "guardrail_information", None ), - "is_streamed_request": self._get_stream_value_from_payload(standard_logging_payload), + "is_streamed_request": self._get_stream_value_from_payload( + standard_logging_payload + ), } ######################################################### @@ -548,21 +550,24 @@ def _get_latency_metrics( # Guardrail overhead latency guardrail_info: Optional[ - StandardLoggingGuardrailInformation + list[StandardLoggingGuardrailInformation] ] = standard_logging_payload.get("guardrail_information") if guardrail_info is 
not None: - _guardrail_duration_seconds: Optional[float] = guardrail_info.get( - "duration" - ) - if _guardrail_duration_seconds is not None: + total_duration = 0.0 + for info in guardrail_info: + _guardrail_duration_seconds: Optional[float] = info.get("duration") + if _guardrail_duration_seconds is not None: + total_duration += float(_guardrail_duration_seconds) + + if total_duration > 0: # Convert from seconds to milliseconds for consistency - latency_metrics["guardrail_overhead_time_ms"] = ( - _guardrail_duration_seconds * 1000 - ) + latency_metrics["guardrail_overhead_time_ms"] = total_duration * 1000 return latency_metrics - def _get_stream_value_from_payload(self, standard_logging_payload: StandardLoggingPayload) -> bool: + def _get_stream_value_from_payload( + self, standard_logging_payload: StandardLoggingPayload + ) -> bool: """ Extract the stream value from standard logging payload. diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py index 7f807bb8b0c9..a067d285245e 100644 --- a/litellm/integrations/langfuse/langfuse.py +++ b/litellm/integrations/langfuse/langfuse.py @@ -688,11 +688,17 @@ def _log_langfuse_v2( # noqa: PLR0915 "completion_tokens": _usage_obj.completion_tokens, "total_cost": cost if self._supports_costs() else None, } - usage_details = LangfuseUsageDetails(input=_usage_obj.prompt_tokens, - output=_usage_obj.completion_tokens, - total=_usage_obj.total_tokens, - cache_creation_input_tokens=_usage_obj.get('cache_creation_input_tokens', 0), - cache_read_input_tokens=_usage_obj.get('cache_read_input_tokens', 0)) + usage_details = LangfuseUsageDetails( + input=_usage_obj.prompt_tokens, + output=_usage_obj.completion_tokens, + total=_usage_obj.total_tokens, + cache_creation_input_tokens=_usage_obj.get( + "cache_creation_input_tokens", 0 + ), + cache_read_input_tokens=_usage_obj.get( + "cache_read_input_tokens", 0 + ), + ) generation_name = clean_metadata.pop("generation_name", None) if generation_name is None: @@ -790,7 +796,7 @@ def _get_responses_api_content_for_langfuse( """ Get the responses API content for Langfuse logging """ - if hasattr(response_obj, 'output') and response_obj.output: + if hasattr(response_obj, "output") and response_obj.output: # ResponsesAPIResponse.output is a list of strings return response_obj.output else: @@ -880,29 +886,44 @@ def _log_guardrail_information_as_span( guardrail_information = standard_logging_object.get( "guardrail_information", None ) - if guardrail_information is None: + if not guardrail_information: verbose_logger.debug( - "Not logging guardrail information as span because guardrail_information is None" + "Not logging guardrail information as span because guardrail_information is empty" ) return - span = trace.span( - name="guardrail", - input=guardrail_information.get("guardrail_request", None), - output=guardrail_information.get("guardrail_response", None), - metadata={ - "guardrail_name": guardrail_information.get("guardrail_name", None), - "guardrail_mode": guardrail_information.get("guardrail_mode", None), - "guardrail_masked_entity_count": guardrail_information.get( - "masked_entity_count", None - ), - }, - start_time=guardrail_information.get("start_time", None), # type: ignore - end_time=guardrail_information.get("end_time", None), # type: ignore - ) + if not isinstance(guardrail_information, list): + verbose_logger.debug( + "Not logging guardrail information as span because guardrail_information is not a list: %s", + type(guardrail_information), + ) + return + + for 
guardrail_entry in guardrail_information: + if not isinstance(guardrail_entry, dict): + verbose_logger.debug( + "Skipping guardrail entry with unexpected type: %s", + type(guardrail_entry), + ) + continue + + span = trace.span( + name="guardrail", + input=guardrail_entry.get("guardrail_request", None), + output=guardrail_entry.get("guardrail_response", None), + metadata={ + "guardrail_name": guardrail_entry.get("guardrail_name", None), + "guardrail_mode": guardrail_entry.get("guardrail_mode", None), + "guardrail_masked_entity_count": guardrail_entry.get( + "masked_entity_count", None + ), + }, + start_time=guardrail_entry.get("start_time", None), # type: ignore + end_time=guardrail_entry.get("end_time", None), # type: ignore + ) - verbose_logger.debug(f"Logged guardrail information as span: {span}") - span.end() + verbose_logger.debug(f"Logged guardrail information as span: {span}") + span.end() def _add_prompt_to_generation_params( diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 9315384ad964..9a17244a06d2 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -141,7 +141,6 @@ def __init__( meter_provider: Optional[Any] = None, **kwargs, ): - if config is None: config = OpenTelemetryConfig.from_env() @@ -203,13 +202,14 @@ def _init_tracing(self, tracer_provider): # Check if a TracerProvider is already set globally (e.g., by Langfuse SDK) try: from opentelemetry.trace import ProxyTracerProvider + existing_provider = trace.get_tracer_provider() # If an actual provider exists (not the default proxy), use it if not isinstance(existing_provider, ProxyTracerProvider): verbose_logger.debug( "OpenTelemetry: Using existing TracerProvider: %s", - type(existing_provider).__name__ + type(existing_provider).__name__, ) tracer_provider = existing_provider # Don't call set_tracer_provider to preserve existing context @@ -223,7 +223,7 @@ def _init_tracing(self, tracer_provider): # Fallback: create a new provider if something goes wrong verbose_logger.debug( "OpenTelemetry: Exception checking existing provider, creating new one: %s", - str(e) + str(e), ) tracer_provider = TracerProvider(resource=_get_litellm_resource()) tracer_provider.add_span_processor(self._get_span_processor()) @@ -232,7 +232,7 @@ def _init_tracing(self, tracer_provider): # Tracer provider explicitly provided (e.g., for testing) verbose_logger.debug( "OpenTelemetry: Using provided TracerProvider: %s", - type(tracer_provider).__name__ + type(tracer_provider).__name__, ) trace.set_tracer_provider(tracer_provider) @@ -514,9 +514,9 @@ def get_tracer_to_use_for_request(self, kwargs: dict) -> Tracer: def _get_dynamic_otel_headers_from_kwargs(self, kwargs) -> Optional[dict]: """Extract dynamic headers from kwargs if available.""" - standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( - kwargs.get("standard_callback_dynamic_params") - ) + standard_callback_dynamic_params: Optional[ + StandardCallbackDynamicParams + ] = kwargs.get("standard_callback_dynamic_params") if not standard_callback_dynamic_params: return None @@ -775,52 +775,63 @@ def _create_guardrail_span( if standard_logging_payload is None: return - guardrail_information = standard_logging_payload.get("guardrail_information") - if guardrail_information is None: + guardrail_information_data = standard_logging_payload.get( + "guardrail_information" + ) + if not guardrail_information_data: return - start_time_float = guardrail_information.get("start_time") - 
end_time_float = guardrail_information.get("end_time") - start_time_datetime = datetime.now() - if start_time_float is not None: - start_time_datetime = datetime.fromtimestamp(start_time_float) - end_time_datetime = datetime.now() - if end_time_float is not None: - end_time_datetime = datetime.fromtimestamp(end_time_float) + guardrail_information_list = [ + information + for information in guardrail_information_data + if isinstance(information, dict) + ] - otel_tracer: Tracer = self.get_tracer_to_use_for_request(kwargs) - guardrail_span = otel_tracer.start_span( - name="guardrail", - start_time=self._to_ns(start_time_datetime), - context=context, - ) + if not guardrail_information_list: + return - self.safe_set_attribute( - span=guardrail_span, - key="guardrail_name", - value=guardrail_information.get("guardrail_name"), - ) + otel_tracer: Tracer = self.get_tracer_to_use_for_request(kwargs) + for guardrail_information in guardrail_information_list: + start_time_float = guardrail_information.get("start_time") + end_time_float = guardrail_information.get("end_time") + start_time_datetime = datetime.now() + if start_time_float is not None: + start_time_datetime = datetime.fromtimestamp(start_time_float) + end_time_datetime = datetime.now() + if end_time_float is not None: + end_time_datetime = datetime.fromtimestamp(end_time_float) + + guardrail_span = otel_tracer.start_span( + name="guardrail", + start_time=self._to_ns(start_time_datetime), + context=context, + ) - self.safe_set_attribute( - span=guardrail_span, - key="guardrail_mode", - value=guardrail_information.get("guardrail_mode"), - ) + self.safe_set_attribute( + span=guardrail_span, + key="guardrail_name", + value=guardrail_information.get("guardrail_name"), + ) - # Set masked_entity_count directly without conversion - masked_entity_count = guardrail_information.get("masked_entity_count") - if masked_entity_count is not None: - guardrail_span.set_attribute( - "masked_entity_count", safe_dumps(masked_entity_count) + self.safe_set_attribute( + span=guardrail_span, + key="guardrail_mode", + value=guardrail_information.get("guardrail_mode"), ) - self.safe_set_attribute( - span=guardrail_span, - key="guardrail_response", - value=guardrail_information.get("guardrail_response"), - ) + masked_entity_count = guardrail_information.get("masked_entity_count") + if masked_entity_count is not None: + guardrail_span.set_attribute( + "masked_entity_count", safe_dumps(masked_entity_count) + ) - guardrail_span.end(end_time=self._to_ns(end_time_datetime)) + self.safe_set_attribute( + span=guardrail_span, + key="guardrail_response", + value=guardrail_information.get("guardrail_response"), + ) + + guardrail_span.end(end_time=self._to_ns(end_time_datetime)) def _handle_failure(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -841,10 +852,10 @@ def _handle_failure(self, kwargs, response_obj, start_time, end_time): ) span.set_status(Status(StatusCode.ERROR)) self.set_attributes(span, kwargs, response_obj) - + # Record exception information using OTEL standard method self._record_exception_on_span(span=span, kwargs=kwargs) - + span.end(end_time=self._to_ns(end_time)) # Create span for guardrail information @@ -856,7 +867,7 @@ def _handle_failure(self, kwargs, response_obj, start_time, end_time): def _record_exception_on_span(self, span: Span, kwargs: dict): """ Record exception information on the span using OTEL standard methods. - + This extracts error information from StandardLoggingPayload and: 1. 
Uses span.record_exception() for the actual exception object (OTEL standard) 2. Sets structured error attributes from StandardLoggingPayloadErrorInformation @@ -866,22 +877,22 @@ def _record_exception_on_span(self, span: Span, kwargs: dict): # Get the exception object if available exception = kwargs.get("exception") - + # Record the exception using OTEL's standard method if exception is not None: span.record_exception(exception) - + # Get StandardLoggingPayload for structured error information standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object" ) - + if standard_logging_payload is None: return - + # Extract error_information from StandardLoggingPayload error_information = standard_logging_payload.get("error_information") - + if error_information is None: # Fallback to error_str if error_information is not available error_str = standard_logging_payload.get("error_str") @@ -892,7 +903,7 @@ def _record_exception_on_span(self, span: Span, kwargs: dict): value=error_str, ) return - + # Set structured error attributes from StandardLoggingPayloadErrorInformation if error_information.get("error_code"): self.safe_set_attribute( @@ -900,35 +911,35 @@ def _record_exception_on_span(self, span: Span, kwargs: dict): key=ErrorAttributes.ERROR_CODE, value=error_information["error_code"], ) - + if error_information.get("error_class"): self.safe_set_attribute( span=span, key=ErrorAttributes.ERROR_TYPE, value=error_information["error_class"], ) - + if error_information.get("error_message"): self.safe_set_attribute( span=span, key=ErrorAttributes.ERROR_MESSAGE, value=error_information["error_message"], ) - + if error_information.get("llm_provider"): self.safe_set_attribute( span=span, key=ErrorAttributes.ERROR_LLM_PROVIDER, value=error_information["llm_provider"], ) - + if error_information.get("traceback"): self.safe_set_attribute( span=span, key=ErrorAttributes.ERROR_STACK_TRACE, value=error_information["traceback"], ) - + except Exception as e: verbose_logger.exception( "OpenTelemetry: Error recording exception on span: %s", str(e) @@ -1363,12 +1374,16 @@ def _get_span_context(self, kwargs): # Priority 1: Explicit parent span from metadata if parent_otel_span is not None: - verbose_logger.debug("OpenTelemetry: Using explicit parent span from metadata") + verbose_logger.debug( + "OpenTelemetry: Using explicit parent span from metadata" + ) return trace.set_span_in_context(parent_otel_span), parent_otel_span # Priority 2: HTTP traceparent header if traceparent is not None: - verbose_logger.debug("OpenTelemetry: Using traceparent header for context propagation") + verbose_logger.debug( + "OpenTelemetry: Using traceparent header for context propagation" + ) carrier = {"traceparent": traceparent} return TraceContextTextMapPropagator().extract(carrier=carrier), None @@ -1381,16 +1396,20 @@ def _get_span_context(self, kwargs): verbose_logger.debug( "OpenTelemetry: Using active span from global context: %s (trace_id=%s, span_id=%s, is_recording=%s)", current_span, - format(span_context.trace_id, '032x'), - format(span_context.span_id, '016x'), - current_span.is_recording() + format(span_context.trace_id, "032x"), + format(span_context.span_id, "016x"), + current_span.is_recording(), ) return context.get_current(), current_span except Exception as e: - verbose_logger.debug("OpenTelemetry: Error getting current span: %s", str(e)) + verbose_logger.debug( + "OpenTelemetry: Error getting current span: %s", str(e) + ) # Priority 4: No parent context - 
verbose_logger.debug("OpenTelemetry: No parent context found, creating root span") + verbose_logger.debug( + "OpenTelemetry: No parent context found, creating root span" + ) return None, None def _get_span_processor(self, dynamic_headers: Optional[dict] = None): diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 974dd18ccc91..fac57e038f04 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -308,9 +308,9 @@ def __init__( self.litellm_trace_id: str = litellm_trace_id or str(uuid.uuid4()) self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[Any] = ( - [] - ) # for generating complete stream response + self.sync_streaming_chunks: List[ + Any + ] = [] # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -686,9 +686,9 @@ def get_custom_logger_for_prompt_management( if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( non_default_params ): - self.model_call_details["prompt_integration"] = ( - anthropic_cache_control_logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = anthropic_cache_control_logger.__class__.__name__ return anthropic_cache_control_logger ######################################################### @@ -700,9 +700,9 @@ def get_custom_logger_for_prompt_management( internal_usage_cache=None, llm_router=None, ) - self.model_call_details["prompt_integration"] = ( - vector_store_custom_logger.__class__.__name__ - ) + self.model_call_details[ + "prompt_integration" + ] = vector_store_custom_logger.__class__.__name__ # Add to global callbacks so post-call hooks are invoked if ( vector_store_custom_logger @@ -762,9 +762,9 @@ def _pre_call(self, input, api_key, model=None, additional_args={}): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"]["api_base"] = ( - self._get_masked_api_base(additional_args.get("api_base", "")) - ) + self.model_call_details["litellm_params"][ + "api_base" + ] = self._get_masked_api_base(additional_args.get("api_base", "")) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 # Log the exact input to the LLM API @@ -793,10 +793,10 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata["raw_request"] = ( - "redacted by litellm. \ + _metadata[ + "raw_request" + ] = "redacted by litellm. 
\ 'litellm.turn_off_message_logging=True'" - ) else: curl_command = self._get_request_curl_command( api_base=additional_args.get("api_base", ""), @@ -807,32 +807,32 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ignore_sensitive_headers=True, - ), - error=None, - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ignore_sensitive_headers=True, + ), + error=None, ) except Exception as e: - self.model_call_details["raw_request_typed_dict"] = ( - RawRequestTypedDict( - error=str(e), - ) + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + error=str(e), ) - _metadata["raw_request"] = ( - "Unable to Log \ + _metadata[ + "raw_request" + ] = "Unable to Log \ raw request: {}".format( - str(e) - ) + str(e) ) if getattr(self, "logger_fn", None) and callable(self.logger_fn): try: @@ -1133,13 +1133,13 @@ async def async_post_mcp_tool_call_hook( for callback in callbacks: try: if isinstance(callback, CustomLogger): - response: Optional[MCPPostCallResponseObject] = ( - await callback.async_post_mcp_tool_call_hook( - kwargs=kwargs, - response_obj=post_mcp_tool_call_response_obj, - start_time=start_time, - end_time=end_time, - ) + response: Optional[ + MCPPostCallResponseObject + ] = await callback.async_post_mcp_tool_call_hook( + kwargs=kwargs, + response_obj=post_mcp_tool_call_response_obj, + start_time=start_time, + end_time=end_time, ) ###################################################################### # if any of the callbacks modify the response, use the modified response @@ -1302,13 +1302,12 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None try: - response_cost = litellm.response_cost_calculator( **response_cost_calculator_kwargs ) @@ -1331,9 +1330,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details["response_cost_failure_debug_information"] = ( - debug_info - ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info return None @@ -1477,9 +1476,9 @@ def _success_handler_helper_fn( end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details["completion_start_time"] = ( - self.completion_start_time - ) + self.model_call_details[ + "completion_start_time" + ] = self.completion_start_time self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time self.model_call_details["cache_hit"] = cache_hit @@ -1532,39 
+1531,39 @@ def _success_handler_helper_fn( "response_cost" ] else: - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=logging_result) - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=logging_result) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=logging_result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=logging_result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif isinstance(result, dict) or isinstance(result, list): ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) elif standard_logging_object is not None: - self.model_call_details["standard_logging_object"] = ( - standard_logging_object - ) + self.model_call_details[ + "standard_logging_object" + ] = standard_logging_object else: # streaming chunks + image gen. 
self.model_call_details["response_cost"] = None @@ -1720,23 +1719,23 @@ def success_handler( # noqa: PLR0915 verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details["complete_streaming_response"] = ( - complete_streaming_response - ) - self.model_call_details["response_cost"] = ( - self._response_cost_calculator(result=complete_streaming_response) - ) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=complete_streaming_response) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_success_callbacks, @@ -2064,10 +2063,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -2106,10 +2105,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = ( - self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} ) result = self.model_call_details["complete_response"] @@ -2247,9 +2246,9 @@ async def async_success_handler( # noqa: PLR0915 if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["async_complete_streaming_response"] = ( - complete_streaming_response - ) + self.model_call_details[ + "async_complete_streaming_response" + ] = complete_streaming_response try: if self.model_call_details.get("cache_hit", False) is True: @@ -2260,10 +2259,10 @@ async def async_success_handler( # noqa: PLR0915 model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details["response_cost"] = ( - self._response_cost_calculator( - result=complete_streaming_response - ) + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator( + result=complete_streaming_response ) verbose_logger.debug( @@ -2276,16 +2275,16 @@ async def async_success_handler( # noqa: PLR0915 self.model_call_details["response_cost"] = None ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - 
end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, ) callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_success_callbacks, @@ -2498,18 +2497,18 @@ def _failure_handler_helper_fn( ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details["standard_logging_object"] = ( - get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, ) return start_time, end_time @@ -3408,9 +3407,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 endpoint=arize_config.endpoint, ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"space_id={arize_config.space_key},api_key={arize_config.api_key}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"space_id={arize_config.space_key},api_key={arize_config.api_key}" for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -3434,9 +3433,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - arize_phoenix_config.otlp_auth_headers - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = arize_phoenix_config.otlp_auth_headers for callback in _in_memory_loggers: if ( @@ -3568,9 +3567,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( - f"api_key={os.getenv('LANGTRACE_API_KEY')}" - ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -4270,10 +4269,10 @@ def get_hidden_params( for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params["additional_headers"] = ( - StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] - ) + clean_hidden_params[ + "additional_headers" + ] = StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -4497,7 +4496,7 @@ def _get_request_tags( def _get_status_fields( status: StandardLoggingPayloadStatus, - guardrail_information: Optional[dict], + guardrail_information: Optional[list[dict]], error_str: Optional[str], ) -> "StandardLoggingPayloadStatusFields": """ @@ -4528,9 +4527,13 @@ def _get_status_fields( 
# Map - guardrail_information.guardrail_status to guardrail_status ######################################################### guardrail_status: GuardrailStatus = "not_run" - if guardrail_information and isinstance(guardrail_information, dict): - raw_status = guardrail_information.get("guardrail_status", "not_run") - guardrail_status = GUARDRAIL_STATUS_MAP.get(raw_status, "not_run") + if guardrail_information and isinstance(guardrail_information, list): + for information in guardrail_information: + if isinstance(information, dict): + raw_status = information.get("guardrail_status", "not_run") + if raw_status != "not_run": + guardrail_status = GUARDRAIL_STATUS_MAP.get(raw_status, "not_run") + break return StandardLoggingPayloadStatusFields( llm_api_status=llm_api_status, guardrail_status=guardrail_status @@ -4832,9 +4835,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[k] = ( - "scrubbed_by_litellm_for_sensitive_keys" - ) + cleaned_user_api_key_metadata[ + k + ] = "scrubbed_by_litellm_for_sensitive_keys" else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 20a78d839a31..974e25615e40 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -764,9 +764,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[dict] = ( - {} - ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[ + dict + ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -1192,12 +1192,12 @@ class NewCustomerRequest(BudgetNewRequest): blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget spend: Optional[float] = None - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model @model_validator(mode="before") @classmethod @@ -1219,12 +1219,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[AllowedModelRegion] = ( - None # require all user requests to use models in this specific region - ) - default_model: Optional[str] = ( - None # if no equivalent model in allowed region - default all requests to this model - ) + allowed_model_region: Optional[ + AllowedModelRegion + ] = None # require all user requests to use models in this specific region + default_model: Optional[ + str + ] = None # if no equivalent model in allowed region - default all requests to this model class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1308,15 +1308,15 @@ class NewTeamRequest(TeamBase): ] = None 
# raise an error if 'guaranteed_throughput' is set and we're overallocating tpm model_tpm_limit: Optional[Dict[str, int]] = None - team_member_budget: Optional[float] = ( - None # allow user to set a budget for all team members - ) - team_member_rpm_limit: Optional[int] = ( - None # allow user to set RPM limit for all team members - ) - team_member_tpm_limit: Optional[int] = ( - None # allow user to set TPM limit for all team members - ) + team_member_budget: Optional[ + float + ] = None # allow user to set a budget for all team members + team_member_rpm_limit: Optional[ + int + ] = None # allow user to set RPM limit for all team members + team_member_tpm_limit: Optional[ + int + ] = None # allow user to set TPM limit for all team members team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None @@ -1400,9 +1400,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( - "success_and_failure" - ) + callback_type: Optional[ + Literal["success", "failure", "success_and_failure"] + ] = "success_and_failure" callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1687,9 +1687,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[List[FieldDetail]] = ( - None # For nested dictionary or Pydantic fields - ) + nested_fields: Optional[ + List[FieldDetail] + ] = None # For nested dictionary or Pydantic fields class UserHeaderMapping(LiteLLMPydanticObjectBase): @@ -2069,9 +2069,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[Any] = ( - None # You might want to replace 'Any' with a more specific type if available - ) + user: Optional[ + Any + ] = None # You might want to replace 'Any' with a more specific type if available litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -2520,7 +2520,7 @@ class SpendLogsMetadata(TypedDict): applied_guardrails: Optional[List[str]] mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] vector_store_request_metadata: Optional[List[StandardLoggingVectorStoreRequest]] - guardrail_information: Optional[StandardLoggingGuardrailInformation] + guardrail_information: Optional[list[StandardLoggingGuardrailInformation]] status: StandardLoggingPayloadStatus proxy_server_request: Optional[str] batch_models: Optional[List[str]] @@ -3004,9 +3004,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[float] = ( - None # Users max budget within the organization - ) + max_budget_in_organization: Optional[ + float + ] = None # Users max budget within the organization class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -3219,9 +3219,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. 
""" - providers: Dict[str, ProviderBudgetResponseObject] = ( - {} - ) # Dictionary mapping provider names to their budget configurations + providers: Dict[ + str, ProviderBudgetResponseObject + ] = {} # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -3355,9 +3355,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[str] = ( - None # can be either user / team, inferred from the role mapping - ) + object_id_jwt_field: Optional[ + str + ] = None # can be either user / team, inferred from the role mapping scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 2e43191e2dbf..02adcac0eca2 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -51,7 +51,7 @@ def _get_spend_logs_metadata( vector_store_request_metadata: Optional[ List[StandardLoggingVectorStoreRequest] ] = None, - guardrail_information: Optional[StandardLoggingGuardrailInformation] = None, + guardrail_information: Optional[list[StandardLoggingGuardrailInformation]] = None, usage_object: Optional[dict] = None, model_map_information: Optional[StandardLoggingModelInformation] = None, cold_storage_object_key: Optional[str] = None, @@ -95,9 +95,9 @@ def _get_spend_logs_metadata( clean_metadata["applied_guardrails"] = applied_guardrails clean_metadata["batch_models"] = batch_models clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata - clean_metadata["vector_store_request_metadata"] = ( - _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata) - ) + clean_metadata[ + "vector_store_request_metadata" + ] = _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata) clean_metadata["guardrail_information"] = guardrail_information clean_metadata["usage_object"] = usage_object clean_metadata["model_map_information"] = model_map_information diff --git a/litellm/types/utils.py b/litellm/types/utils.py index f91e7d68a8d5..e62dfbb20f0d 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -9,7 +9,6 @@ Literal, Mapping, Optional, - Tuple, Union, ) @@ -1293,7 +1292,7 @@ class ModelResponse(ModelResponseBase): choices: List[Union[Choices, StreamingChoices]] """The list of completion choices the model generated for the input prompt.""" - def __init__( + def __init__( # noqa: PLR0915 self, id=None, choices=None, @@ -2201,7 +2200,7 @@ class StandardLoggingPayload(TypedDict): error_information: Optional[StandardLoggingPayloadErrorInformation] model_parameters: dict hidden_params: StandardLoggingHiddenParams - guardrail_information: Optional[StandardLoggingGuardrailInformation] + guardrail_information: Optional[list[StandardLoggingGuardrailInformation]] standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] diff --git a/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py b/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py index d7db5ef00a05..f6b8af503fb8 100644 --- a/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py +++ b/tests/test_litellm/integrations/datadog/test_datadog_llm_observability.py @@ 
-528,7 +528,7 @@ def create_standard_logging_payload_with_latency_metrics() -> StandardLoggingPay error_information=None, model_parameters={"stream": True}, hidden_params=hidden_params, - guardrail_information=guardrail_info, + guardrail_information=[ guardrail_info ], trace_id="test-trace-id-latency", custom_llm_provider="openai", ) @@ -607,11 +607,11 @@ def test_latency_metrics_edge_cases(mock_env_vars): # Test case 3: Missing guardrail duration should not crash standard_payload = create_standard_logging_payload_with_cache() - standard_payload["guardrail_information"] = StandardLoggingGuardrailInformation( + standard_payload["guardrail_information"] = [StandardLoggingGuardrailInformation( guardrail_name="test", guardrail_status="success", # duration is missing - ) + )] metadata = logger._get_dd_llm_obs_payload_metadata(standard_payload) assert "guardrail_overhead_time_ms" not in metadata @@ -644,20 +644,20 @@ def test_guardrail_information_in_metadata(mock_env_vars): # Verify the guardrail information structure guardrail_info = metadata["guardrail_information"] - assert guardrail_info["guardrail_name"] == "test_guardrail" - assert guardrail_info["guardrail_status"] == "success" - assert guardrail_info["duration"] == 0.5 + assert guardrail_info[0]["guardrail_name"] == "test_guardrail" + assert guardrail_info[0]["guardrail_status"] == "success" + assert guardrail_info[0]["duration"] == 0.5 # Verify input/output fields are present - assert "guardrail_request" in guardrail_info - assert "guardrail_response" in guardrail_info + assert "guardrail_request" in guardrail_info[0] + assert "guardrail_response" in guardrail_info[0] # Validate the input/output content - assert guardrail_info["guardrail_request"]["input"] == "test input message" - assert guardrail_info["guardrail_request"]["user_id"] == "test_user" - assert guardrail_info["guardrail_response"]["output"] == "filtered output" - assert guardrail_info["guardrail_response"]["flagged"] is False - assert guardrail_info["guardrail_response"]["score"] == 0.1 + assert guardrail_info[0]["guardrail_request"]["input"] == "test input message" + assert guardrail_info[0]["guardrail_request"]["user_id"] == "test_user" + assert guardrail_info[0]["guardrail_response"]["output"] == "filtered output" + assert guardrail_info[0]["guardrail_response"]["flagged"] is False + assert guardrail_info[0]["guardrail_response"]["score"] == 0.1 def create_standard_logging_payload_with_tool_calls() -> StandardLoggingPayload: diff --git a/tests/test_litellm/integrations/test_custom_guardrail.py b/tests/test_litellm/integrations/test_custom_guardrail.py index 1106c0b36781..601c18077a5f 100644 --- a/tests/test_litellm/integrations/test_custom_guardrail.py +++ b/tests/test_litellm/integrations/test_custom_guardrail.py @@ -237,3 +237,74 @@ async def apply_guardrail(self, text, language=None, entities=None): assert hasattr( child_with_override, "apply_guardrail" ), "All instances should have apply_guardrail via inheritance" + + +class TestGuardrailLoggingAggregation: + def _make_guardrail(self): + from litellm.types.guardrails import GuardrailEventHooks + + return CustomGuardrail( + guardrail_name="test_guardrail", + event_hook=GuardrailEventHooks.pre_call, + ) + + def _invoke_add_log(self, request_data: dict) -> None: + guardrail = self._make_guardrail() + guardrail.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response={"result": "ok"}, + request_data=request_data, + guardrail_status="success", + start_time=1.0, + end_time=2.0, + 
duration=1.0, + masked_entity_count={"EMAIL": 1}, + guardrail_provider="presidio", + ) + + def test_appends_to_existing_metadata_list(self): + request_data = { + "metadata": { + "standard_logging_guardrail_information": [ + {"guardrail_name": "existing_guardrail"} + ] + } + } + + self._invoke_add_log(request_data) + + info = request_data["metadata"]["standard_logging_guardrail_information"] + assert isinstance(info, list) + assert len(info) == 2 + assert info[0]["guardrail_name"] == "existing_guardrail" + assert info[1]["guardrail_name"] == "test_guardrail" + + def test_converts_existing_metadata_dict_to_list(self): + request_data = { + "metadata": { + "standard_logging_guardrail_information": {"guardrail_name": "legacy"} + } + } + + self._invoke_add_log(request_data) + + info = request_data["metadata"]["standard_logging_guardrail_information"] + assert isinstance(info, list) + assert len(info) == 2 + assert info[0]["guardrail_name"] == "legacy" + assert info[1]["guardrail_name"] == "test_guardrail" + + def test_appends_to_litellm_metadata(self): + request_data = { + "litellm_metadata": { + "standard_logging_guardrail_information": [ + {"guardrail_name": "litellm_existing"} + ] + } + } + + self._invoke_add_log(request_data) + + info = request_data["litellm_metadata"]["standard_logging_guardrail_information"] + assert isinstance(info, list) + assert len(info) == 2 + assert info[1]["guardrail_name"] == "test_guardrail" diff --git a/tests/test_litellm/integrations/test_opentelemetry.py b/tests/test_litellm/integrations/test_opentelemetry.py index 68ca7bb1ecb2..2774d1687011 100644 --- a/tests/test_litellm/integrations/test_opentelemetry.py +++ b/tests/test_litellm/integrations/test_opentelemetry.py @@ -40,7 +40,7 @@ def test_create_guardrail_span_with_valid_info(self, mock_datetime): } # Create a kwargs dict with standard_logging_object containing guardrail information - kwargs = {"standard_logging_object": {"guardrail_information": guardrail_info}} + kwargs = {"standard_logging_object": {"guardrail_information": [ guardrail_info ]}} # Call the method otel._create_guardrail_span(kwargs=kwargs, context=None) @@ -156,7 +156,7 @@ def test_create_guardrail_span_with_valid_info(self, mock_datetime): } # Create a kwargs dict with standard_logging_object containing guardrail information - kwargs = {"standard_logging_object": {"guardrail_information": guardrail_info}} + kwargs = {"standard_logging_object": {"guardrail_information": [ guardrail_info ]}} # Call the method otel._create_guardrail_span(kwargs=kwargs, context=None) diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx index 4145283db41e..8aa37f007de4 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx @@ -36,30 +36,150 @@ interface GuardrailInformation { } interface GuardrailViewerProps { - data: GuardrailInformation; + data: GuardrailInformation | GuardrailInformation[]; } -const GuardrailViewer = ({ data }: GuardrailViewerProps) => { - const [sectionExpanded, setSectionExpanded] = useState(true); +interface GuardrailDetailsProps { + entry: GuardrailInformation; + index: number; + total: number; +} - // Default to presidio for backwards compatibility - const guardrailProvider = data.guardrail_provider ?? 
"presidio"; +const formatTime = (timestamp: number) => { + const date = new Date(timestamp * 1000); + return date.toLocaleString(); +}; - if (!data) return null; +const GuardrailDetails = ({ entry, index, total }: GuardrailDetailsProps) => { + const guardrailProvider = entry.guardrail_provider ?? "presidio"; + const statusLabel = entry.guardrail_status ?? "unknown"; + const isSuccess = statusLabel.toLowerCase() === "success"; + const maskedEntityCount = entry.masked_entity_count || {}; + const totalMaskedEntities = Object.values(maskedEntityCount).reduce( + (sum, count) => sum + (typeof count === "number" ? count : 0), + 0, + ); - const isSuccess = typeof data.guardrail_status === "string" && data.guardrail_status.toLowerCase() === "success"; + const guardrailResponse = entry.guardrail_response; + const presidioEntities = Array.isArray(guardrailResponse) ? guardrailResponse : []; + const bedrockResponse = + guardrailProvider === "bedrock" && + guardrailResponse !== null && + typeof guardrailResponse === "object" && + !Array.isArray(guardrailResponse) + ? (guardrailResponse as BedrockGuardrailResponse) + : undefined; - const tooltipTitle = isSuccess ? null : "Guardrail failed to run."; + return ( +
+ {total > 1 && ( +
+

+ Guardrail #{index + 1} + {entry.guardrail_name} +

+ + {guardrailProvider} + +
+ )} + +
+
+
+ Guardrail Name: + {entry.guardrail_name} +
+
+ Mode: + {entry.guardrail_mode} +
+
+ Status: + + + {statusLabel} + + +
+
- // Calculate total masked entities - const totalMaskedEntities = data.masked_entity_count - ? Object.values(data.masked_entity_count).reduce((sum, count) => sum + count, 0) - : 0; +
+
+ Start Time: + {formatTime(entry.start_time)} +
+
+ End Time: + {formatTime(entry.end_time)} +
+
+ Duration: + {entry.duration.toFixed(4)}s +
+
+
- const formatTime = (timestamp: number): string => { - const date = new Date(timestamp * 1000); - return date.toLocaleString(); - }; + {totalMaskedEntities > 0 && ( +
+
Masked Entity Summary
+
+ {Object.entries(maskedEntityCount).map(([entityType, count]) => ( + + {entityType}: {count} + + ))} +
+
+ )} + + {guardrailProvider === "presidio" && presidioEntities.length > 0 && ( +
+ +
+ )} + + {guardrailProvider === "bedrock" && bedrockResponse && ( +
+ +
+ )} +
+ ); +}; + +const GuardrailViewer = ({ data }: GuardrailViewerProps) => { + const guardrailEntries = Array.isArray(data) + ? data.filter((entry): entry is GuardrailInformation => Boolean(entry)) + : data + ? [data] + : []; + + if (guardrailEntries.length === 0) { + return null; + } + + const [sectionExpanded, setSectionExpanded] = useState(true); + + const primaryName = guardrailEntries.length === 1 ? guardrailEntries[0].guardrail_name : `${guardrailEntries.length} guardrails`; + const statuses = Array.from(new Set(guardrailEntries.map((entry) => entry.guardrail_status))); + const allSucceeded = statuses.every((status) => (status ?? "").toLowerCase() === "success"); + const aggregatedStatus = allSucceeded ? "success" : "failure"; + const totalMaskedEntities = guardrailEntries.reduce((sum, entry) => { + return ( + sum + + Object.values(entry.masked_entity_count || {}).reduce((acc, count) => acc + (typeof count === "number" ? count : 0), 0) + ); + }, 0); + + const tooltipTitle = allSucceeded ? null : "Guardrail failed to run."; return (
@@ -67,9 +187,9 @@ const GuardrailViewer = ({ data }: GuardrailViewerProps) => { className="flex justify-between items-center p-4 border-b cursor-pointer hover:bg-gray-50" onClick={() => setSectionExpanded(!sectionExpanded)} > -
+
{

Guardrail Information

- {/* Header status chip with tooltip */} - {data.guardrail_status} + {aggregatedStatus} + {primaryName} + {totalMaskedEntities > 0 && ( - + {totalMaskedEntities} masked {totalMaskedEntities === 1 ? "entity" : "entities"} )} @@ -99,76 +220,15 @@ const GuardrailViewer = ({ data }: GuardrailViewerProps) => {
{sectionExpanded && ( -
-
-
-
-
- Guardrail Name: - {data.guardrail_name} -
-
- Mode: - {data.guardrail_mode} -
-
- Status: - - - {data.guardrail_status} - - -
-
- -
-
- Start Time: - {formatTime(data.start_time)} -
-
- End Time: - {formatTime(data.end_time)} -
-
- Duration: - {data.duration.toFixed(4)}s -
-
-
- - {/* Masked Entity Summary */} - {data.masked_entity_count && Object.keys(data.masked_entity_count).length > 0 && ( -
-

Masked Entity Summary

-
- {Object.entries(data.masked_entity_count).map(([entityType, count]) => ( - - {entityType}: {count} - - ))} -
-
- )} -
- - {/* Provider-specific Detected Entities */} - {guardrailProvider === "presidio" && (data.guardrail_response as GuardrailEntity[])?.length > 0 && ( - - )} - - {guardrailProvider === "bedrock" && data.guardrail_response && ( -
- -
- )} +
+ {guardrailEntries.map((entry, index) => ( + + ))}
)}
diff --git a/ui/litellm-dashboard/src/components/view_logs/index.tsx b/ui/litellm-dashboard/src/components/view_logs/index.tsx index bfc57a84f7d3..fb833bdfca30 100644 --- a/ui/litellm-dashboard/src/components/view_logs/index.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/index.tsx @@ -790,20 +790,34 @@ export function RequestViewer({ row }: { row: Row }) { metadata.vector_store_request_metadata.length > 0; // Extract guardrail information from metadata if available - const hasGuardrailData = row.original.metadata && row.original.metadata.guardrail_information; + const guardrailInfo = row.original.metadata?.guardrail_information; + const guardrailEntries = Array.isArray(guardrailInfo) + ? guardrailInfo + : guardrailInfo + ? [guardrailInfo] + : []; + const hasGuardrailData = guardrailEntries.length > 0; // Calculate total masked entities if guardrail data exists - const getTotalMaskedEntities = (): number => { - if (!hasGuardrailData || !row.original.metadata?.guardrail_information.masked_entity_count) { - return 0; + const totalMaskedEntities = guardrailEntries.reduce((sum, entry) => { + const maskedCounts = entry?.masked_entity_count; + if (!maskedCounts) { + return sum; } - return Object.values(row.original.metadata.guardrail_information.masked_entity_count).reduce( - (sum: number, count: any) => sum + (typeof count === "number" ? count : 0), - 0, + return ( + sum + + Object.values(maskedCounts).reduce( + (acc, count) => (typeof count === "number" ? acc + count : acc), + 0, + ) ); - }; + }, 0); - const totalMaskedEntities = getTotalMaskedEntities(); + const primaryGuardrailLabel = guardrailEntries.length === 1 + ? guardrailEntries[0]?.guardrail_name ?? "-" + : guardrailEntries.length > 1 + ? `${guardrailEntries.length} guardrails` + : "-"; return (
@@ -850,7 +864,7 @@ export function RequestViewer({ row }: { row: Row }) {
Guardrail:
- {row.original.metadata!.guardrail_information.guardrail_name} + {primaryGuardrailLabel} {totalMaskedEntities > 0 && ( {totalMaskedEntities} masked @@ -934,7 +948,7 @@ export function RequestViewer({ row }: { row: Row }) {
{/* Guardrail Data - Show only if present */} - {hasGuardrailData && } + {hasGuardrailData && } {/* Vector Store Request Data - Show only if present */} {hasVectorStoreData && }
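For reviewers, the end-to-end effect of this change is easiest to see on the logging payload itself. Below is a minimal, self-contained sketch (not part of the patch) of how the list-shaped `standard_logging_guardrail_information` field accumulates one entry per guardrail run; the append/convert behaviour mirrors the `_append_guardrail_info` helper added above, and the overhead calculation mirrors the updated Datadog duration summing. The guardrail names, statuses, and durations are illustrative only.

```python
# Minimal sketch (not part of the patch): accumulation of the list-shaped
# "standard_logging_guardrail_information" field, mirroring _append_guardrail_info.
# Guardrail names, statuses, and durations below are illustrative only.
from typing import Any, Dict

KEY = "standard_logging_guardrail_information"


def append_guardrail_info(container: Dict[str, Any], entry: Dict[str, Any]) -> None:
    """Append a guardrail entry, wrapping any legacy single-dict value in a list."""
    existing = container.get(KEY)
    if existing is None:
        container[KEY] = [entry]
    elif isinstance(existing, list):
        existing.append(entry)
    else:
        # legacy payloads stored a single dict; keep it and append the new entry
        container[KEY] = [existing, entry]


metadata: Dict[str, Any] = {}
append_guardrail_info(
    metadata,
    {"guardrail_name": "pii_mask", "guardrail_status": "success", "duration": 0.25},
)
append_guardrail_info(
    metadata,
    {"guardrail_name": "content_filter", "guardrail_status": "guardrail_intervened", "duration": 0.5},
)

# Guardrail overhead is now the sum of all entry durations, converted to ms,
# matching the updated Datadog latency-metric logic.
total_overhead_ms = sum(float(e.get("duration") or 0.0) for e in metadata[KEY]) * 1000
print(len(metadata[KEY]), total_overhead_ms)  # 2 750.0
```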