Skip to content

Commit

Permalink
optional caching of L-moment weights
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Jul 3, 2023
1 parent b676e8e commit d7cb375
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
46 changes: 40 additions & 6 deletions lmo/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
'l_kurtosis',
)

from typing import Any, TypeVar, cast
from typing import Any, Final, TypeVar, cast

import numpy as np
import numpy.typing as npt
Expand All @@ -29,9 +29,16 @@
T = TypeVar('T', bound=np.floating[Any])


# Low-level methods
# Low-level weight methods

def l0_weights(
_L_WEIGHTS_CACHE: Final[
dict[
tuple[int, int, int], # (n, t_1, t_2)
npt.NDArray[np.floating[Any]]
]
] = {}

def _l0_weights(
r: int,
n: int,
/,
Expand Down Expand Up @@ -125,6 +132,8 @@ def l_weights(
/,
trim: tuple[int, int] = (0, 0),
dtype: np.dtype[T] | type[T] = np.float_,
*,
cache: bool = False
) -> npt.NDArray[T]:
"""
Projection matrix of the first $r$ (T)L-moments for $n$ samples.
Expand Down Expand Up @@ -174,8 +183,22 @@ def l_weights(
L-moments](https://doi.org/10.1016/j.jspi.2006.12.002)
"""
cache_key = n, *trim
if (
cache_key in _L_WEIGHTS_CACHE
and (P_r := _L_WEIGHTS_CACHE[cache_key]).shape[0] <= r
):
if P_r.dtype is not np.dtype(dtype):
P_r = P_r.view(dtype)
if P_r.shape[0] < r:
P_r = P_r[:r]

assert P_r.shape == (r, n)
return cast(npt.NDArray[T], P_r)


if sum(trim) == 0:
return l0_weights(r, n, dtype)
return _l0_weights(r, n, dtype)

P_r = np.empty((r, n), dtype)

Expand All @@ -187,7 +210,7 @@ def l_weights(

np.matmul(
hosking_jacobi(r, trim),
l0_weights(r + sum(trim), n),
_l0_weights(r + sum(trim), n),
out=P_r
)

Expand All @@ -197,6 +220,11 @@ def l_weights(
P_r[:, :t1] = P_r[:, n - t2:] = 0
P_r[1:, t1:n - t2] -= P_r[1:, t1:n - t2].mean(1, keepdims=True)

if cache:
# memoize, and mark as readonly to avoid corruping the cache
P_r.setflags(write=False)
_L_WEIGHTS_CACHE[cache_key] = P_r

return P_r


Expand All @@ -213,6 +241,7 @@ def l_moment(
fweights: IntVector | None = None,
aweights: npt.ArrayLike | None = None,
sort: SortKind | None = 'stable',
cache: bool = False,
) -> T | npt.NDArray[T]:
"""
Estimates the generalized trimmed L-moment $\\lambda^{(t_1, t_2)}_r$ from
Expand Down Expand Up @@ -277,6 +306,11 @@ def l_moment(
sort ('quick' | 'stable' | 'heap'):
Sorting algorithm, see [`numpy.sort`][numpy.sort].
cache:
Set to `True` to speed up future L-moment calculations that have
the same number of observations in `a`, equal `trim`, and equal or
smaller `r`.
Returns:
l:
The L-moment(s) of the input This is a scalar iff a is 1-d and
Expand Down Expand Up @@ -322,7 +356,7 @@ def l_moment(
n = x_k.shape[-1]

# projection matrix
P_r = l_weights(r_max, n, trim, dtype=dtype)
P_r = l_weights(r_max, n, trim, dtype=dtype, cache=cache)

l_r = np.inner(P_r, x_k)

Expand Down
11 changes: 8 additions & 3 deletions lmo/_lm_co.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def l_comoment(
dtype: np.dtype[T] | type[T] = np.float_,
*,
sort: SortKind | None = 'stable',
cache: bool = False,
) -> npt.NDArray[T]:
"""
Multivariate extension of [`lmo.l_moment`][lmo.l_moment]. Estimates the
Expand Down Expand Up @@ -102,6 +103,11 @@ def l_comoment(
sort ('quick' | 'stable' | 'heap'):
Sorting algorithm, see [`numpy.sort`][numpy.sort].
cache:
Set to `True` to speed up future L-moment calculations that have
the same number of observations in `a`, equal `trim`, and equal or
smaller `r`.
Returns:
L: Array of shape `(*r.shape, m, m)` with r-th L-comoments.
Expand Down Expand Up @@ -145,7 +151,7 @@ def _clean_array(arr: npt.ArrayLike) -> npt.NDArray[T]:
return np.empty(np.shape(_r) + (0, 0), dtype=dtype)

# projection matrix of shape (r, n)
P_r = l_weights(r_max, n, trim, dtype=dtype)
P_r = l_weights(r_max, n, trim, dtype=dtype, cache=cache)

# L-comoment matrices for r = 0, ..., r_max
L_ij = np.empty((r_max + 1, m, m), dtype=dtype)
Expand All @@ -154,10 +160,9 @@ def _clean_array(arr: npt.ArrayLike) -> npt.NDArray[T]:
# matrix is the identity matrix
L_ij[0] = np.eye(m, dtype=dtype)

kwargs = {'axis': -1, 'dtype': dtype, 'sort': sort}
for j in range(m):
# concomitants of x[i] w.r.t. x[j] for all i
x_k_ij = ordered(x, x[j], **kwargs)
x_k_ij = ordered(x, x[j], axis=-1, dtype=dtype, sort=sort)

L_ij[1:, :, j] = np.inner(P_r, x_k_ij)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from lmo._lm import l0_weights, l_weights
from lmo._lm import l_weights

st_n = st.integers(32, 1024)
st_r = st.integers(1, 8)
Expand All @@ -26,7 +26,7 @@

@given(n=st_n, r=st_r, trim0=st.just((0, 0)))
def test_l_weights_alias(n, r, trim0):
w_l = l0_weights(r, n)
w_l = l_weights(r, n)
w_tl = l_weights(r, n, trim0)

assert np.array_equal(w_l, w_tl)
Expand Down

0 comments on commit d7cb375

Please sign in to comment.