Skip to content

Commit ab0be62

Browse files
authored
Fix: warnings in softplus and inverse_softplus #270 (#275)
1 parent 6449a92 commit ab0be62

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

bayesflow/utils/numpy_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def inverse_shifted_softplus(
1616

1717
def inverse_softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
1818
"""Numerically stabilized inverse softplus function."""
19-
return np.where(beta * x > threshold, x, np.log(beta * np.expm1(x)) / beta)
19+
with np.errstate(over="ignore"):
20+
expm1_x = np.expm1(x)
21+
return np.where(beta * x > threshold, x, np.log(beta * expm1_x) / beta)
2022

2123

2224
def one_hot(indices: np.ndarray, num_classes: int, dtype: str = "float32") -> np.ndarray:
@@ -37,4 +39,6 @@ def shifted_softplus(
3739

3840
def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
3941
"""Numerically stabilized softplus function."""
40-
return np.where(beta * x > threshold, x, np.log1p(np.exp(beta * x)) / beta)
42+
with np.errstate(over="ignore"):
43+
exp_beta_x = np.exp(beta * x)
44+
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)

0 commit comments

Comments
 (0)