-
Notifications
You must be signed in to change notification settings - Fork 39
WIP: add compatibility shims for {eig,eigvals} #379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
We need a wrapper because numpy currently returns `float|complex`. Implementation-wise, follow `linalg.solve` and copy-paste relevant numpy code with minimal required modifications.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds compatibility shims for the eig and eigvals linear algebra functions to support the 2025.12 array API standard. The shims ensure dtype stability for NumPy by always returning complex dtypes, as required by the Array API specification, while also handling xfails for libraries that don't yet support these functions.
Changes:
- Added
EigResultnamedtuple class to common/_linalg.py - Implemented
eig()andeigvals()shim functions in numpy/linalg.py with dtype conversion logic - Added xfail entries for dask test suite (dask doesn't have these functions yet)
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| dask-xfails.txt | Added xfail entries for test_eig and test_eigvals under 2025.12 support section |
| array_api_compat/common/_linalg.py | Added EigResult namedtuple to match the pattern of other result types |
| array_api_compat/numpy/linalg.py | Implemented eig() and eigvals() shims with dtype stability fixes and imported EigResult |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def eig(x: Array, /) -> tuple[Array, Array]: | ||
| try: | ||
| from numpy.linalg._linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| except ImportError: | ||
| from numpy.linalg.linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| from numpy.linalg import _umath_linalg | ||
|
|
||
| x, wrap = _makearray(x) | ||
| _assert_stacked_square(x) | ||
| _assert_finite(x) | ||
| t, result_t = _commonType(x) | ||
|
|
||
| signature = 'D->DD' if isComplexType(t) else 'd->DD' | ||
| with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, | ||
| invalid='call', over='ignore', divide='ignore', | ||
| under='ignore'): | ||
| w, vt = _umath_linalg.eig(x, signature=signature) | ||
|
|
||
| result_t = _complexType(result_t) | ||
| vt = vt.astype(result_t, copy=False) | ||
| return EigResult(w.astype(result_t, copy=False), wrap(vt)) |
Copilot
AI
Jan 10, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type annotation for the eig function indicates it returns a tuple[Array, Array], but the implementation returns an EigResult (a NamedTuple). This inconsistency should be corrected to return EigResult instead of tuple[Array, Array] to match the actual behavior and to be consistent with similar functions like eigh which properly returns EighResult.
| ) | ||
| from numpy.linalg import _umath_linalg | ||
|
|
||
| x, wrap = _makearray(x) |
Copilot
AI
Jan 10, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable 'wrap' returned from _makearray is not used in the eigvals function, but is assigned. This suggests either it should be used (similar to how it's used in the eig function to wrap the result) or should be replaced with an underscore to indicate it's intentionally unused.
| def eig(x: Array, /) -> tuple[Array, Array]: | ||
| try: | ||
| from numpy.linalg._linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| except ImportError: | ||
| from numpy.linalg.linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| from numpy.linalg import _umath_linalg | ||
|
|
||
| x, wrap = _makearray(x) | ||
| _assert_stacked_square(x) | ||
| _assert_finite(x) | ||
| t, result_t = _commonType(x) | ||
|
|
||
| signature = 'D->DD' if isComplexType(t) else 'd->DD' | ||
| with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, | ||
| invalid='call', over='ignore', divide='ignore', | ||
| under='ignore'): | ||
| w, vt = _umath_linalg.eig(x, signature=signature) | ||
|
|
||
| result_t = _complexType(result_t) | ||
| vt = vt.astype(result_t, copy=False) | ||
| return EigResult(w.astype(result_t, copy=False), wrap(vt)) | ||
|
|
||
|
|
||
| def eigvals(x: Array, /) -> Array: | ||
| try: | ||
| from numpy.linalg._linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| except ImportError: | ||
| from numpy.linalg.linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| from numpy.linalg import _umath_linalg | ||
|
|
||
| x, wrap = _makearray(x) | ||
| _assert_stacked_square(x) | ||
| _assert_finite(x) | ||
| t, result_t = _commonType(x) | ||
|
|
||
| signature = 'D->D' if isComplexType(t) else 'd->D' | ||
| with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, | ||
| invalid='call', over='ignore', divide='ignore', | ||
| under='ignore'): | ||
| w = _umath_linalg.eigvals(x, signature=signature) | ||
|
|
||
| result_t = _complexType(result_t) | ||
| return w.astype(result_t, copy=False) |
Copilot
AI
Jan 10, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is significant code duplication between the eig and eigvals functions. Both functions import the same set of helper functions from numpy.linalg, perform the same validation steps (_assert_stacked_square, _assert_finite), and use similar dtype handling logic. Consider extracting the common logic into a shared helper function to improve maintainability and reduce duplication.
| # 2025.12 support | ||
| array_api_tests/test_linalg.py::test_eig | ||
| array_api_tests/test_linalg.py::test_eigvals | ||
|
|
Copilot
AI
Jan 10, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to the PR description, CuPy needs xfail entries for eig and eigvals tests until the minimum CuPy version is >=13. However, these xfail entries are missing from cupy-xfails.txt. The dask-xfails.txt file includes the necessary entries (lines 133-134), but cupy-xfails.txt should also have corresponding entries under a "# 2025.12 support" section similar to the pattern used in dask-xfails.txt.
WIP until 2025.12 is released.
daskdoes not haveeig, thus xfail the testsnumpyneeds a shim for return dtype stabilitycupyneeds an xfail until the min cupy version is <=13data-apis/array-api-tests#404 is the matching test PR