From a6ddea0f8e02ed5ea7d6d47064bb89bee2283830 Mon Sep 17 00:00:00 2001
From: "evan.zhang5" <evan.zhang5@nio.com>
Date: Thu, 24 Oct 2024 18:08:23 +0800
Subject: [PATCH] fix: Added condition array for filtering 0 values

---
 machine_learning/loss_functions.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/machine_learning/loss_functions.py b/machine_learning/loss_functions.py
index 0bd9aa8b5401..8bc16e1c1c7a 100644
--- a/machine_learning/loss_functions.py
+++ b/machine_learning/loss_functions.py
@@ -645,6 +645,11 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
     - y_true: True class probabilities
     - y_pred: Predicted class probabilities
 
+    >>> true_labels = np.array([0, 0.4, 0.6])
+    >>> predicted_probs = np.array([0.3, 0.3, 0.4])
+    >>> float(kullback_leibler_divergence(true_labels, predicted_probs))
+    0.35835189384561095
+
     >>> true_labels = np.array([0.2, 0.3, 0.5])
     >>> predicted_probs = np.array([0.3, 0.3, 0.4])
     >>> float(kullback_leibler_divergence(true_labels, predicted_probs))
@@ -659,6 +664,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
     if len(y_true) != len(y_pred):
         raise ValueError("Input arrays must have the same length.")
 
+    filter_array = y_true != 0
+    y_true = y_true[filter_array]
+    y_pred = y_pred[filter_array]
     kl_loss = y_true * np.log(y_true / y_pred)
     return np.sum(kl_loss)