Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dependency updates and typing simplifications #335

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ repos:
- id: markdownlint

- repo: https://github.com/adamchainz/blacken-docs
rev: 1.19.0
rev: 1.19.1
hooks:
- id: blacken-docs
additional_dependencies: [black==24.*]
Expand All @@ -78,7 +78,7 @@ repos:
- id: codespell

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.1
hooks:
- id: ruff
args: [--fix, --show-fixes]
Expand Down
42 changes: 22 additions & 20 deletions lmo/contrib/scipy_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from __future__ import annotations

import contextlib
from collections.abc import Callable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
Expand All @@ -23,10 +23,6 @@

import numpy as np
import numpy.typing as npt
from scipy.stats import (
fit as scipy_fit,
rv_discrete,
)
from scipy.stats.distributions import rv_continuous, rv_frozen

import lmo.typing as lmt
Expand All @@ -51,11 +47,21 @@
l_stats_cov_from_cdf,
)

if TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence

from scipy.stats import rv_discrete

__all__ = "install", "l_rv_frozen", "l_rv_generic"


_T = TypeVar("_T")
_T_x = TypeVar("_T_x", bound=float | npt.NDArray[np.float64])
_T_x = TypeVar("_T_x", float, npt.NDArray[np.float64])


class _Fn1(Protocol):
    """Positional-only unary callable that preserves its argument type.

    ``_T_x`` is constrained to ``float`` or ``npt.NDArray[np.float64]``, so an
    implementation maps ``float -> float`` and ``ndarray -> ndarray``.
    """

    def __call__(self, x: _T_x, /) -> _T_x: ...


_Tuple2: TypeAlias = tuple[_T, _T]
_Tuple4: TypeAlias = tuple[_T, _T, _T, _T]
Expand Down Expand Up @@ -141,23 +147,23 @@ class l_rv_generic(PatchClass):
cdf: lspt.RVFunction[...]
fit: Callable[..., tuple[float, ...]]
mean: Callable[..., float]
ppf: lspt.RVFunction[...]
ppf: _Fn1
std: Callable[..., float]

def _get_xxf(
self,
*args: Any,
loc: float = 0,
scale: float = 1,
) -> _Tuple2[Callable[[float], float]]:
) -> _Tuple2[_Fn1]:
assert scale > 0

_cdf, _ppf = self._cdf, self._ppf

def cdf(x: float, /) -> float:
def cdf(x: _T_x, /) -> _T_x:
return _cdf(np.array([(x - loc) / scale], dtype=float), *args)[0]

def ppf(q: float, /) -> float:
def ppf(q: _T_x, /) -> _T_x:
return _ppf(np.array([q], dtype=float), *args)[0] * scale + loc

return cdf, ppf
Expand Down Expand Up @@ -814,7 +820,7 @@ def l_moment_influence(
quad_opts: lspt.QuadOptions | None = None,
tol: float = 1e-8,
**kwds: Any,
) -> Callable[[_T_x], _T_x]:
) -> _Fn1:
r"""
Returns the influence function (IF) of an L-moment.

Expand Down Expand Up @@ -883,10 +889,7 @@ def l_moment_influence(
lm = self.l_moment(r, *args, trim=trim, quad_opts=quad_opts, **kwds)

args, loc, scale = self._parse_args(*args, **kwds)
cdf = cast(
Callable[[_ArrF8], _ArrF8],
self._get_xxf(*args, loc=loc, scale=scale)[0],
)
cdf = self._get_xxf(*args, loc=loc, scale=scale)[0]

return l_moment_influence_from_cdf(
cdf,
Expand Down Expand Up @@ -980,10 +983,7 @@ def l_ratio_influence(
)

args, loc, scale = self._parse_args(*args, **kwds)
cdf = cast(
Callable[[_ArrF8], _ArrF8],
self._get_xxf(*args, loc=loc, scale=scale)[0],
)
cdf = self._get_xxf(*args, loc=loc, scale=scale)[0]

return l_ratio_influence_from_cdf(
cdf,
Expand Down Expand Up @@ -1207,7 +1207,9 @@ def l_fit(
else:
# almost never works without custom (finite and tight) bounds...
            # ... and otherwise it'll run for ±17 exa-eons
rv = cast(rv_discrete, self)
from scipy.stats import fit as scipy_fit

rv = cast("rv_discrete", self)
bounds0 = [
(-np.inf if a is None else a, np.inf if b is None else b)
for a, b in bounds
Expand Down
62 changes: 39 additions & 23 deletions lmo/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Any,
Final,
NamedTuple,
Protocol,
TypeAlias,
TypeVar,
cast,
Expand All @@ -20,13 +21,7 @@

import numpy as np
import numpy.typing as npt
from scipy.integrate import quad
from scipy.optimize import OptimizeWarning, minimize
from scipy.special import chdtrc
from scipy.stats.distributions import rv_continuous, rv_frozen

import lmo.typing.np as lnpt
import lmo.typing.scipy as lspt
from . import constants
from ._lm import l_ratio
from ._poly import extrema_jacobi
Expand All @@ -41,9 +36,10 @@


if TYPE_CHECKING:
import lmo.typing.np as lnpt
import lmo.typing.scipy as lspt
from .contrib.scipy_stats import l_rv_generic


__all__ = (
"error_sensitivity",
"l_moment_bounds",
Expand All @@ -57,6 +53,12 @@


_T = TypeVar("_T")
_T_x = TypeVar("_T_x", float, npt.NDArray[np.float64])


class _Fn1(Protocol):
    """Positional-only unary callable that preserves its argument type.

    ``_T_x`` is constrained to ``float`` or ``npt.NDArray[np.float64]``, so an
    implementation maps ``float -> float`` and ``ndarray -> ndarray``.
    """

    def __call__(self, x: _T_x, /) -> _T_x: ...


_Tuple2: TypeAlias = tuple[_T, _T]
_ArrF8: TypeAlias = npt.NDArray[np.float64]
Expand Down Expand Up @@ -190,8 +192,20 @@ def _gof_stat_single(l_obs: _ArrF8, l_exp: _ArrF8, cov: _ArrF8) -> float:
)


def _is_rv(x: object) -> TypeIs[lspt.RVFrozen | lspt.RV]:
from scipy.stats.distributions import (
rv_continuous,
rv_discrete,
rv_frozen,
rv_histogram,
)

# NOTE: this assumes that the (private) `rv_generic` class is a sealed type
return isinstance(x, rv_frozen | rv_continuous | rv_discrete | rv_histogram)


def l_moment_gof(
rv_or_cdf: lspt.RV | Callable[[float], float],
rv_or_cdf: lspt.RV | lspt.RVFrozen | _Fn1,
l_moments: _ArrF8,
n_obs: int,
/,
Expand Down Expand Up @@ -270,28 +284,26 @@ def l_moment_gof(

r = np.arange(1, 1 + n)

if isinstance(rv_or_cdf, rv_continuous.__base__ | rv_frozen):
if _is_rv(rv_or_cdf):
rv = cast("l_rv_generic", rv_or_cdf)
lambda_r = rv.l_moment(r, trim=trim, **kwargs)
lambda_rr = rv.l_moments_cov(n, trim=trim, **kwargs)
else:
from .theoretical import l_moment_cov_from_cdf, l_moment_from_cdf

cdf = cast(Callable[[float], float], rv_or_cdf)
cdf = rv_or_cdf
lambda_r = l_moment_from_cdf(cdf, r, trim, **kwargs)
lambda_rr = l_moment_cov_from_cdf(cdf, n, trim, **kwargs)

from scipy.special import chdtrc

stat = n_obs * _gof_stat(l_r.T, lambda_r, lambda_rr).T[()]
pval = chdtrc(n, stat)
return HypothesisTestResult(stat, pval)


def _is_rv(x: object) -> TypeIs[lspt.RV]:
    """Narrow *x* to `lspt.RV` via an `isinstance` check.

    NOTE(review): this relies on `lspt.RV` supporting `isinstance` checks
    (e.g. a runtime-checkable protocol or a concrete base class) — confirm
    against `lmo.typing.scipy`, which is not visible here.
    """
    return isinstance(x, lspt.RV)


def l_stats_gof(
rv_or_cdf: lspt.RV | Callable[[float], float],
rv_or_cdf: lspt.RV | lspt.RVFrozen | _Fn1,
l_stats: _ArrF8,
n_obs: int,
/,
Expand Down Expand Up @@ -319,8 +331,10 @@ def l_stats_gof(
tau_rr = l_stats_cov_from_cdf(cdf, n, trim, **kwargs)
tau_r = l_stats_from_cdf(cdf, n, trim, **kwargs)

from scipy.special import chdtrc

stat = n_obs * _gof_stat(t_r.T, tau_r, tau_rr).T[()]
pval = cast(float | _ArrF8, chdtrc(n, stat))
pval = chdtrc(n, stat)
return HypothesisTestResult(stat, pval)


Expand Down Expand Up @@ -729,19 +743,17 @@ def rejection_point(
if influence_fn(rho_max) != 0 or influence_fn(-rho_max) != 0:
return np.nan

from scipy.integrate import quad

def integrand(x: float) -> float:
return max(abs(influence_fn(-x)), abs(influence_fn(x)))

def obj(r: _ArrF8) -> float:
return quad(integrand, r[0], np.inf)[0]

# TO
res = minimize(
obj,
bounds=[(rho_min, rho_max)],
x0=[rho_min],
method="COBYLA",
)
from scipy.optimize import minimize

res = minimize(obj, bounds=[(rho_min, rho_max)], x0=[rho_min], method="COBYLA")

rho = cast(float, res.x[0])
if rho <= _MIN_RHO or influence_fn(-rho) or influence_fn(rho):
Expand Down Expand Up @@ -801,6 +813,8 @@ def obj(xs: _ArrF8) -> float:

bounds = None if np.isneginf(a) and np.isposinf(b) else [(a, b)]

from scipy.optimize import OptimizeWarning, minimize

res = minimize(obj, bounds=bounds, x0=[min(max(0, a), b)], method="COBYLA")
if not res.success:
warnings.warn(res.message, OptimizeWarning, stacklevel=1)
Expand Down Expand Up @@ -876,6 +890,8 @@ def obj(xs: _ArrF8) -> float:
a, b = domain
bounds = None if np.isneginf(a) and np.isposinf(b) else [(a, b)]

from scipy.optimize import OptimizeWarning, minimize

res = minimize(
obj,
bounds=bounds,
Expand Down
6 changes: 1 addition & 5 deletions lmo/distributions/_genlambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import functools
import math
import sys
from collections.abc import Callable
from typing import TYPE_CHECKING, Final, TypeAlias, TypeVar, cast

import numpy as np
Expand Down Expand Up @@ -260,10 +259,7 @@ def _l_moment(
lmbda_r = cast(
float | _ArrF8,
l_moment_from_ppf(
cast(
Callable[[float], float],
functools.partial(self._ppf, b=b, d=d, f=f),
),
functools.partial(self._ppf, b=b, d=d, f=f),
r,
trim=trim,
quad_opts=quad_opts,
Expand Down
6 changes: 3 additions & 3 deletions lmo/distributions/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def lm_genextreme(r: int, s: float, t: float, /, a: float) -> np.float64 | float
# - conditionals within the function are avoided through multiple functions
if a == 0:

def _ppf(q: float) -> float:
def _ppf(q: float, /) -> float:
if q <= 0:
return -float("inf")
if q >= 1:
Expand All @@ -331,7 +331,7 @@ def _ppf(q: float) -> float:

elif a < 0:

def _ppf(q: float) -> float:
def _ppf(q: float, /) -> float:
if q <= 0:
return 1 / a
if q >= 1:
Expand All @@ -340,7 +340,7 @@ def _ppf(q: float) -> float:

else: # a > 0

def _ppf(q: float) -> float:
def _ppf(q: float, /) -> float:
if q <= 0:
return -float("inf")
if q >= 1:
Expand Down
Loading