Skip to content

Commit 390e903

Browse files
committed
Add .hypothesis/ directory to .gitignore
and ppf and cdf to scipy.stats.uniform
1 parent c0d51e7 commit 390e903

File tree

5 files changed

+55
-0
lines changed

5 files changed

+55
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
.envrc
2222
jax.iml
2323
.bazelrc.user
24+
.hypothesis/
2425

2526
# virtualenv/venv directories
2627
/venv/

docs/jax.scipy.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,8 @@ jax.scipy.stats.uniform
425425

426426
logpdf
427427
pdf
428+
cdf
429+
ppf
428430

429431
jax.scipy.stats.gaussian_kde
430432
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

jax/_src/scipy/stats/uniform.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import scipy.stats as osp_stats
1717

1818
from jax import lax
19+
from jax import numpy as jnp
1920
from jax.numpy import where, inf, logical_or
2021
from jax._src.typing import Array, ArrayLike
2122
from jax._src.numpy.util import _wraps, promote_args_inexact
@@ -32,3 +33,21 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
3233
@_wraps(osp_stats.uniform.pdf, update_doc=False)
3334
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
3435
return lax.exp(logpdf(x, loc, scale))
36+
37+
@_wraps(osp_stats.uniform.cdf, update_doc=False)
38+
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
39+
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
40+
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
41+
conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))]
42+
vals = [zero, one, lax.div(lax.sub(x, loc), scale)]
43+
44+
return jnp.select(conds, vals)
45+
46+
@_wraps(osp_stats.uniform.ppf, update_doc=False)
47+
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
48+
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
49+
return where(
50+
jnp.isnan(q) | (q < 0) | (q > 1),
51+
jnp.nan,
52+
lax.add(loc, lax.mul(scale, q))
53+
)

jax/scipy/stats/uniform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@
1818
from jax._src.scipy.stats.uniform import (
1919
logpdf as logpdf,
2020
pdf as pdf,
21+
cdf as cdf,
22+
ppf as ppf,
2123
)

tests/scipy_stats_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,36 @@ def args_maker():
10431043
tol=1e-4)
10441044
self._CompileAndCheck(lax_fun, args_maker)
10451045

1046+
@genNamedParametersNArgs(3)
1047+
def testUniformCdf(self, shapes, dtypes):
1048+
rng = jtu.rand_default(self.rng())
1049+
scipy_fun = osp_stats.uniform.cdf
1050+
lax_fun = lsp_stats.uniform.cdf
1051+
1052+
def args_maker():
1053+
x, loc, scale = map(rng, shapes, dtypes)
1054+
return [x, loc, np.abs(scale)]
1055+
1056+
with jtu.strict_promotion_if_dtypes_match(dtypes):
1057+
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
1058+
tol=1e-5)
1059+
self._CompileAndCheck(lax_fun, args_maker)
1060+
1061+
@genNamedParametersNArgs(3)
1062+
def testUniformPpf(self, shapes, dtypes):
1063+
rng = jtu.rand_default(self.rng())
1064+
scipy_fun = osp_stats.uniform.ppf
1065+
lax_fun = lsp_stats.uniform.ppf
1066+
1067+
def args_maker():
1068+
q, loc, scale = map(rng, shapes, dtypes)
1069+
return [q, loc, np.abs(scale)]
1070+
1071+
with jtu.strict_promotion_if_dtypes_match(dtypes):
1072+
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
1073+
tol=1e-5)
1074+
self._CompileAndCheck(lax_fun, args_maker)
1075+
10461076
@genNamedParametersNArgs(4)
10471077
def testChi2LogPdf(self, shapes, dtypes):
10481078
rng = jtu.rand_positive(self.rng())
@@ -1058,6 +1088,7 @@ def args_maker():
10581088
tol=5e-4)
10591089
self._CompileAndCheck(lax_fun, args_maker)
10601090

1091+
10611092
@genNamedParametersNArgs(4)
10621093
def testChi2LogCdf(self, shapes, dtypes):
10631094
rng = jtu.rand_positive(self.rng())

0 commit comments

Comments
 (0)