Skip to content

Commit

Permalink
feat: svd (#66)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Aug 1, 2024
1 parent 4233306 commit 77537a5
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/quaxed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

# pylint: disable=redefined-builtin

__all__ = ["__version__", "array_api"]
__all__ = ["__version__", "array_api", "scipy"]

import sys
from typing import Any

import plum
from jaxtyping import ArrayLike

from . import _jax, array_api
from . import _jax, array_api, scipy
from ._jax import *
from ._version import version as __version__

Expand Down
6 changes: 2 additions & 4 deletions src/quaxed/scipy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Quaxed :mod:`jax.scipy`."""

__all__ = [
"special",
]
__all__ = ["linalg", "special"]

from . import special
from . import linalg, special
29 changes: 29 additions & 0 deletions src/quaxed/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# ruff:noqa: F822

"""Quaxed :mod:`jax.scipy.linalg`."""

__all__ = [
"svd",
]

import sys
from collections.abc import Callable
from typing import Any

import jax.scipy.linalg
from quax import quaxify


def __dir__() -> list[str]:
return sorted(__all__)


# TODO: better return type annotation
def __getattr__(name: str) -> Callable[..., Any]:
# Quaxify the func
func = quaxify(getattr(jax.scipy.linalg, name))

# Cache the function in this module
setattr(sys.modules[__name__], name, func)

return func
1 change: 1 addition & 0 deletions tests/scipy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test `quaxed.numpy`."""
23 changes: 23 additions & 0 deletions tests/scipy/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Test with JAX inputs."""

import jax
import jax.numpy as jnp

import quaxed


def test_dir():
"""Test the `__dir__` method."""
assert quaxed.scipy.linalg.__dir__() == quaxed.scipy.linalg.__all__


def test_svd():
"""Test `quaxed.scipy.linalg.svd`."""
assert hasattr(jax.scipy.linalg, "svd")
assert hasattr(quaxed.scipy.linalg, "svd")

x = jnp.array([[1, 2], [3, 4]])
got = quaxed.scipy.linalg.svd(x, compute_uv=False)
expected = jax.scipy.linalg.svd(x, compute_uv=False)

assert jnp.array_equal(got, expected)

0 comments on commit 77537a5

Please sign in to comment.