From 6797cf30c3dabcca6ee848aa8f7c84dadb807d67 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 14 Aug 2025 03:47:26 -0700 Subject: [PATCH 1/3] Use gesv instead of getrf/getrs in dpnp_solve --- dpnp/linalg/dpnp_utils_linalg.py | 49 ++++++-------------------------- 1 file changed, 9 insertions(+), 40 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 55d140c5c88..8a542eddc6f 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2619,15 +2619,9 @@ def dpnp_solve(a, b): a_usm_arr = dpnp.get_usm_ndarray(a) b_usm_arr = dpnp.get_usm_ndarray(b) - # Due to MKLD-17226 (bug with incorrect checking ldb parameter - # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error - # `invalid argument` when nrhs > n) we can not use _gesv directly. - # This w/a uses _getrf and _getrs instead - # to handle cases where nrhs > n for a.shape = (n x n) - # and b.shape = (n x nrhs). - - # oneMKL LAPACK getrf overwrites `a`. - a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type) + # oneMKL LAPACK gesv overwrites `a` and assumes fortran-like array as + # input + a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type) _manager = dpu.SequentialOrderManager[exec_q] dev_evs = _manager.submitted_events @@ -2658,39 +2652,14 @@ def dpnp_solve(a, b): ) _manager.add_event_pair(ht_ev, b_copy_ev) - n = a.shape[0] - - ipiv_h = dpnp.empty_like( - a, - shape=(n,), - dtype=dpnp.int64, - ) - dev_info_h = [0] - - # Call the LAPACK extension function _getrf - # to perform LU decomposition of the input matrix - ht_ev, getrf_ev = li._getrf( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - dev_info_h, - depends=[a_copy_ev], + # Call the LAPACK extension function _gesv to solve the system of linear + # equations with the coefficient square matrix and + # the dependent variables array. 
+ ht_lapack_ev, gesv_ev = li._gesv( + exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev] ) - _manager.add_event_pair(ht_ev, getrf_ev) - _check_lapack_dev_info(dev_info_h) - - # Call the LAPACK extension function _getrs - # to solve the system of linear equations with an LU-factored - # coefficient square matrix, with multiple right-hand sides. - ht_ev, getrs_ev = li._getrs( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - b_h.get_array(), - depends=[b_copy_ev, getrf_ev], - ) - _manager.add_event_pair(ht_ev, getrs_ev) + _manager.add_event_pair(ht_lapack_ev, gesv_ev) return b_h From 3d737dc20a4153573db9d8ebda6334a59a8513d1 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 14 Aug 2025 04:08:00 -0700 Subject: [PATCH 2/3] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb31d3cd325..168cfb5b19d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Replaced `ci` section in `.pre-commit-config.yaml` with a new GitHub workflow with scheduled run to autoupdate the `pre-commit` configuration [#2542](https://github.com/IntelPython/dpnp/pull/2542) * FFT module is updated to perform in-place FFT in intermediate steps of ND FFT [#2543](https://github.com/IntelPython/dpnp/pull/2543) * Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546) +* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558) ### Deprecated From 452dd74174535ce531133b16a4a3d0bdab9f4d60 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 22 Aug 2025 05:24:14 -0700 Subject: [PATCH 3/3] Use assert_dtype_allclose to TestSolve --- dpnp/tests/test_linalg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index f4dc1f74add..61800756593 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2565,7 +2565,7 @@ def test_solve(self, dtype): expected = numpy.linalg.solve(a_np, a_np) result = dpnp.linalg.solve(a_dp, a_dp) - assert_allclose(result, expected) + assert_dtype_allclose(result, expected) @testing.with_requires("numpy>=2.0") @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) @@ -2638,12 +2638,12 @@ def test_solve_strides(self): # positive strides expected = numpy.linalg.solve(a_np[::2, ::2], b_np[::2]) result = dpnp.linalg.solve(a_dp[::2, ::2], b_dp[::2]) - assert_allclose(result, expected, rtol=1e-6) + assert_dtype_allclose(result, expected) # negative strides expected = numpy.linalg.solve(a_np[::-2, ::-2], b_np[::-2]) result = dpnp.linalg.solve(a_dp[::-2, ::-2], b_dp[::-2]) - assert_allclose(result, expected) + assert_dtype_allclose(result, expected) @pytest.mark.parametrize( "matrix, vector",