Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sparse/numba_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
full,
full_like,
imag,
interp,
isinf,
isnan,
matmul,
Expand Down Expand Up @@ -252,6 +253,7 @@
"int32",
"int64",
"int8",
"interp",
"isfinite",
"isinf",
"isnan",
Expand Down
81 changes: 81 additions & 0 deletions sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3250,3 +3250,84 @@ def diff(x, axis=-1, n=1, prepend=None, append=None):
for _ in range(n):
result = result[(slice(None),) * axis + (slice(1, None),)] - result[(slice(None),) * axis + (slice(None, -1),)]
return result


def interp(x, xp, fp, left=None, right=None, period=None):
"""
An implementation of ``numpy.interp`` for sparse arrays.

Thanks to the function dispatch of numpy, this enables interpolation on sparse arrays
using the numpy universal function. This function effectively wraps ``np.interp`` by
calling it on the array data and the fill value. See the numpy documentation for
details on the parameters.

Parameters
----------
x : SparseArray
The x-coordinates at which to evaluate the interpolated values.
xp : 1-D sequence or SparseArray
The x-coordinates of the data points.
fp : 1-D sequence or SparseArray
The y-coordinates of the data points, same length as ``xp``.
left : float or complex, optional
Value to return for ``x < xp[0]``, default is ``fp[0]``.
right : float or complex, optional
Value to return for ``x > xp[-1]``, default is ``fp[-1]``.
period : None or float, optional
A period for the x-coordinates.

Returns
-------
out : SparseArray
The interpolated values, same shape as x.

See Also
--------
https://numpy.org/doc/stable/reference/generated/numpy.interp.html

Examples
--------
When interpolating a sparse array, its data and the fill value are interpolated. The
returned array is pruned. Therefore, the fill value and the number of nonzero
elements might change.

>>> import numpy as np
>>> xp = [1, 2, 3]
>>> fp = [3, 2, 0]
>>> y = np.interp(sparse.COO.from_numpy(np.array([0, 1, 1.5, 2.72, 3.14])), xp, fp)
>>> y.to_dense()
array([3. , 3. , 2.5 , 0.56, 0. ])
>>> y.fill_value
3.0
>>> y.nnz
3
"""
from ._compressed import GCXS
from ._coo import COO
from ._dok import DOK

# Densify sparse interpolants
if isinstance(xp, SparseArray):
xp = xp.todense()
if isinstance(fp, SparseArray):
fp = fp.todense()

# Define output type
out_kwargs = {}
out_type = COO
if isinstance(x, GCXS):
out_type = GCXS
out_kwargs["compressed_axes"] = x.compressed_axes
elif isinstance(x, DOK):
out_type = DOK

def interp_func(xx):
return np.interp(xx, xp, fp, left=left, right=right, period=period)

# Perform interpolation
arr = as_coo(x)
data = interp_func(arr.data)
fill_value = interp_func(arr.fill_value)
return COO(data=data, coords=arr.coords, shape=arr.shape, fill_value=fill_value, prune=True).asformat(
out_type, **out_kwargs
)
51 changes: 51 additions & 0 deletions sparse/numba_backend/tests/test_array_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,54 @@ def test_asarray(self, input, dtype, format):
expected = input.todense() if hasattr(input, "todense") else np.asarray(input)

np.testing.assert_equal(actual, expected)


class TestInterp:
x = np.array([[0, 1, 1.5, 2.72, 3.14], [0, 0, 0, -1, 3.14]])
xp = [1, 2, 3]
fp = [3, 2, 0]

@pytest.mark.parametrize(
"x",
[
x,
sparse.COO.from_numpy(x),
sparse.GCXS.from_numpy(x),
sparse.DOK.from_numpy(x),
],
)
@pytest.mark.parametrize(
"xp",
[
xp,
np.array(xp),
sparse.COO.from_numpy(np.array(xp)),
],
)
@pytest.mark.parametrize(
"fp",
[
fp,
np.array(fp),
sparse.COO.from_numpy(np.array(fp)),
],
)
@pytest.mark.parametrize("module", [sparse, np])
def test_interp(self, x, xp, fp, module):
y = module.interp(x, xp, fp)
if isinstance(x, sparse.SparseArray):
assert isinstance(y, type(x))
else:
if module is sparse or any(isinstance(obj, sparse.SparseArray) for obj in (x, xp, fp)):
assert isinstance(y, sparse.COO)
else:
assert isinstance(y, np.ndarray)

np.testing.assert_array_almost_equal(
y.todense() if hasattr(y, "todense") else y,
[[3.0, 3.0, 2.5, 0.56, 0.0], [3.0, 3.0, 3.0, 3.0, 0.0]],
decimal=13,
)
if isinstance(y, sparse.SparseArray):
assert y.nnz == 4
assert y.fill_value == 3.0
Loading