diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index b7ac5311cb3..81aea21af18 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -166,6 +166,7 @@ std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, + const int trans_code, const std::vector &depends) { const int a_array_nd = a_array.get_ndim(); @@ -264,11 +265,20 @@ std::pair const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - // Use transpose::T if the LU-factorized array is passed as C-contiguous. - // For F-contiguous we use transpose::N. - oneapi::mkl::transpose trans = is_a_array_c_contig - ? oneapi::mkl::transpose::T - : oneapi::mkl::transpose::N; + oneapi::mkl::transpose trans; + switch (trans_code) { + case 0: + trans = oneapi::mkl::transpose::N; + break; + case 1: + trans = oneapi::mkl::transpose::T; + break; + case 2: + trans = oneapi::mkl::transpose::C; + break; + default: + throw py::value_error("`trans_code` must be 0 (N), 1 (T), or 2 (C)"); + } char *a_array_data = a_array.get_data(); char *b_array_data = b_array.get_data(); diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 8fa4889c99a..30db88c62fe 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -37,6 +37,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, + const int trans_code, const std::vector &depends = {}); extern void init_getrs_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 4d5adfe09e4..9dc22419e57 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -160,7 +160,8 @@ PYBIND11_MODULE(_lapack_impl, m) "the solves of linear equations with an LU-factored " "square coefficient matrix, with multiple right-hand sides", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("b_array"), py::arg("depends") = py::list()); + py::arg("b_array"), py::arg("trans_code"), + py::arg("depends") = py::list()); m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 55d140c5c88..1a7d452935a 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2632,6 +2632,12 @@ def dpnp_solve(a, b): _manager = dpu.SequentialOrderManager[exec_q] dev_evs = _manager.submitted_events + # TODO: remove after PR #2558 is merged + # Temporarily set trans_code=1 (transpose) because the LU-factorized + # array is C-contiguous. + # For F-contiguous arrays use 0 (non-transpose) + trans_code = 1 + # use DPCTL tensor function to fill the сopy of the input array # from the input array ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( @@ -2688,6 +2694,7 @@ def dpnp_solve(a, b): a_h.get_array(), ipiv_h.get_array(), b_h.get_array(), + trans_code, depends=[b_copy_ev, getrf_ev], ) _manager.add_event_pair(ht_ev, getrs_ev)