File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -16,7 +16,9 @@ def inverse_shifted_softplus(
1616
1717def inverse_softplus (x : np .ndarray , beta : float = 1.0 , threshold : float = 20.0 ) -> np .ndarray :
1818 """Numerically stabilized inverse softplus function."""
19- return np .where (beta * x > threshold , x , np .log (beta * np .expm1 (x )) / beta )
19+ with np .errstate (over = "ignore" ):
20+ expm1_x = np .expm1 (x )
21+ return np .where (beta * x > threshold , x , np .log (beta * expm1_x ) / beta )
2022
2123
2224def one_hot (indices : np .ndarray , num_classes : int , dtype : str = "float32" ) -> np .ndarray :
@@ -37,4 +39,6 @@ def shifted_softplus(
3739
3840def softplus (x : np .ndarray , beta : float = 1.0 , threshold : float = 20.0 ) -> np .ndarray :
3941 """Numerically stabilized softplus function."""
40- return np .where (beta * x > threshold , x , np .log1p (np .exp (beta * x )) / beta )
42+ with np .errstate (over = "ignore" ):
43+ exp_beta_x = np .exp (beta * x )
44+ return np .where (beta * x > threshold , x , np .log1p (exp_beta_x ) / beta )
You can’t perform that action at this time.
0 commit comments