Skip to content

Commit 56fdb42

Browse files
committed
Copy nn.{softmax,log_softmax} to scipy.special
1 parent 300d06a commit 56fdb42

4 files changed

Lines changed: 92 additions & 12 deletions

File tree

docs/jax.scipy.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ jax.scipy.special
174174
i0e
175175
i1
176176
i1e
177+
kl_div
177178
log_ndtr
179+
log_softmax
178180
logit
179181
logsumexp
180182
lpmn
@@ -184,13 +186,13 @@ jax.scipy.special
184186
ndtri
185187
poch
186188
polygamma
189+
rel_entr
190+
softmax
187191
spence
188192
sph_harm
189193
xlog1py
190194
xlogy
191195
zeta
192-
kl_div
193-
rel_entr
194196

195197

196198
jax.scipy.stats

jax/_src/scipy/special.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from jax._src.ops import special as ops_special
3636
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
3737
from jax._src.typing import Array, ArrayLike
38+
from jax._src.nn.functions import softmax as nn_softmax
39+
from jax._src.nn.functions import log_softmax as nn_log_softmax
3840

3941

4042
def gammaln(x: ArrayLike) -> Array:
@@ -2582,3 +2584,72 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
25822584
lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
25832585
lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
25842586
)
2587+
2588+
2589+
def softmax(x: ArrayLike,
2590+
/,
2591+
*,
2592+
axis: int | tuple[int, ...] | None = None,
2593+
) -> Array:
2594+
r"""Softmax function.
2595+
2596+
JAX implementation of :func:`scipy.special.softmax`.
2597+
2598+
Computes the function which rescales elements to the range :math:`[0, 1]`
2599+
such that the elements along :code:`axis` sum to :math:`1`.
2600+
2601+
.. math ::
2602+
\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
2603+
2604+
Args:
2605+
x : input array
2606+
axis: the axis or axes along which the softmax should be computed. The
2607+
softmax output summed across these dimensions should sum to :math:`1`.
2608+
2609+
Returns:
2610+
An array of the same shape as ``x``.
2611+
2612+
Note:
2613+
If any input values are ``+inf``, the result will be all ``NaN``: this
2614+
reflects the fact that ``inf / inf`` is not well-defined in the context of
2615+
floating-point math.
2616+
2617+
See also:
2618+
:func:`log_softmax`
2619+
"""
2620+
return nn_softmax(x, axis=axis)
2621+
2622+
2623+
def log_softmax(x: ArrayLike,
2624+
/,
2625+
*,
2626+
axis: int | tuple[int, ...] | None = None,
2627+
) -> Array:
2628+
r"""Log-Softmax function.
2629+
2630+
JAX implementation of :func:`scipy.special.log_softmax`
2631+
2632+
Computes the logarithm of the :code:`softmax` function, which rescales
2633+
elements to the range :math:`[-\infty, 0)`.
2634+
2635+
.. math ::
2636+
\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
2637+
\right)
2638+
2639+
Args:
2640+
x : input array
2641+
axis: the axis or axes along which the :code:`log_softmax` should be
2642+
computed.
2643+
2644+
Returns:
2645+
An array of the same shape as ``x``
2646+
2647+
Note:
2648+
If any input values are ``+inf``, the result will be all ``NaN``: this
2649+
reflects the fact that ``inf / inf`` is not well-defined in the context of
2650+
floating-point math.
2651+
2652+
See also:
2653+
:func:`softmax`
2654+
"""
2655+
return nn_log_softmax(x, axis=axis)

jax/scipy/special.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
from jax._src.scipy.special import (
1919
bernoulli as bernoulli,
20+
bessel_jn as bessel_jn,
21+
beta as beta,
2022
betainc as betainc,
2123
betaln as betaln,
22-
beta as beta,
23-
bessel_jn as bessel_jn,
2424
digamma as digamma,
2525
entr as entr,
2626
erf as erf,
@@ -31,31 +31,33 @@
3131
expit as expit,
3232
expn as expn,
3333
factorial as factorial,
34+
gamma as gamma,
3435
gammainc as gammainc,
3536
gammaincc as gammaincc,
3637
gammaln as gammaln,
3738
gammasgn as gammasgn,
38-
gamma as gamma,
39+
hyp1f1 as hyp1f1,
3940
i0 as i0,
4041
i0e as i0e,
4142
i1 as i1,
4243
i1e as i1e,
44+
kl_div as kl_div,
45+
log_ndtr as log_ndtr,
46+
log_softmax as log_softmax,
4347
logit as logit,
4448
logsumexp as logsumexp,
4549
lpmn as lpmn,
4650
lpmn_values as lpmn_values,
4751
multigammaln as multigammaln,
48-
log_ndtr as log_ndtr,
4952
ndtr as ndtr,
5053
ndtri as ndtri,
54+
poch as poch,
5155
polygamma as polygamma,
56+
rel_entr as rel_entr,
57+
softmax as softmax,
5258
spence as spence,
5359
sph_harm as sph_harm,
54-
xlogy as xlogy,
5560
xlog1py as xlog1py,
61+
xlogy as xlogy,
5662
zeta as zeta,
57-
kl_div as kl_div,
58-
rel_entr as rel_entr,
59-
poch as poch,
60-
hyp1f1 as hyp1f1,
6163
)

tests/lax_scipy_special_functions_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,12 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
148148
"rel_entr", 2, float_dtypes, jtu.rand_positive, True,
149149
),
150150
op_record("poch", 2, float_dtypes, jtu.rand_positive, True),
151-
op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True)
151+
op_record(
152+
"hyp1f1", 3, float_dtypes,
153+
functools.partial(jtu.rand_uniform, low=0.5, high=30), True
154+
),
155+
op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True),
156+
op_record("softmax", 1, float_dtypes, jtu.rand_default, True),
152157
]
153158

154159

0 commit comments

Comments
 (0)