From 949326f794fdf633e0c9b5f49d3bea874255f4e8 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 15 Apr 2025 12:57:21 -0700 Subject: [PATCH 1/9] impl-syrk --- CHANGELOG.md | 1 + dpnp/backend/extensions/blas/CMakeLists.txt | 1 + dpnp/backend/extensions/blas/blas_py.cpp | 22 +- dpnp/backend/extensions/blas/gemm.cpp | 8 +- dpnp/backend/extensions/blas/gemv.cpp | 46 ++- dpnp/backend/extensions/blas/gemv.hpp | 1 - dpnp/backend/extensions/blas/syrk.cpp | 297 ++++++++++++++++++ dpnp/backend/extensions/blas/syrk.hpp | 42 +++ dpnp/backend/extensions/blas/types_matrix.hpp | 25 ++ dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 55 +++- dpnp/tests/test_product.py | 45 ++- dpnp/tests/test_sycl_queue.py | 19 +- dpnp/tests/test_usm_type.py | 17 +- 13 files changed, 518 insertions(+), 61 deletions(-) create mode 100644 dpnp/backend/extensions/blas/syrk.cpp create mode 100644 dpnp/backend/extensions/blas/syrk.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c811c73ae49..f489677bd431 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478) * Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500) +* Added a new backend routine `syrk` from oneMKL to perform symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix [#2509](https://github.com/IntelPython/dpnp/pull/2509) ### Changed diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt index 24b8457ffebc..82022844c979 100644 
--- a/dpnp/backend/extensions/blas/CMakeLists.txt +++ b/dpnp/backend/extensions/blas/CMakeLists.txt @@ -30,6 +30,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp ) pybind11_add_module(${python_module_name} MODULE ${_module_src}) diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index 3393315ffe19..850c1e784c7b 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -36,6 +36,7 @@ #include "dotu.hpp" #include "gemm.hpp" #include "gemv.hpp" +#include "syrk.hpp" namespace blas_ns = dpnp::extensions::blas; namespace py = pybind11; @@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void) blas_ns::init_gemm_batch_dispatch_table(); blas_ns::init_gemm_dispatch_table(); blas_ns::init_gemv_dispatch_vector(); + blas_ns::init_syrk_dispatch_vector(); } static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types]; @@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m) }; m.def("_dot", dot_pyapi, - "Call `dot` from OneMKL BLAS library to compute " + "Call `dot` from oneMKL BLAS library to compute " "the dot product of two real-valued vectors.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), py::arg("result"), py::arg("depends") = py::list()); @@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m) }; m.def("_dotc", dotc_pyapi, - "Call `dotc` from OneMKL BLAS library to compute " + "Call `dotc` from oneMKL BLAS library to compute " "the dot product of two complex vectors, " "conjugating the first vector.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), @@ -110,7 +112,7 @@ PYBIND11_MODULE(_blas_impl, m) }; m.def("_dotu", dotu_pyapi, - "Call `dotu` from OneMKL BLAS library to compute " + "Call `dotu` from oneMKL BLAS library to compute " "the dot product of two complex vectors.", py::arg("sycl_queue"), py::arg("vectorA"), 
py::arg("vectorB"), py::arg("result"), py::arg("depends") = py::list()); @@ -118,7 +120,7 @@ PYBIND11_MODULE(_blas_impl, m) { m.def("_gemm", &blas_ns::gemm, - "Call `gemm` from OneMKL BLAS library to compute " + "Call `gemm` from oneMKL BLAS library to compute " "the matrix-matrix product with 2-D matrices.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), py::arg("resultC"), py::arg("depends") = py::list()); @@ -126,7 +128,7 @@ PYBIND11_MODULE(_blas_impl, m) { m.def("_gemm_batch", &blas_ns::gemm_batch, - "Call `gemm_batch` from OneMKL BLAS library to compute " + "Call `gemm_batch` from oneMKL BLAS library to compute " "the matrix-matrix product for a batch of 2-D matrices.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), py::arg("resultC"), py::arg("depends") = py::list()); @@ -134,13 +136,21 @@ PYBIND11_MODULE(_blas_impl, m) { m.def("_gemv", &blas_ns::gemv, - "Call `gemv` from OneMKL BLAS library to compute " + "Call `gemv` from oneMKL BLAS library to compute " "the matrix-vector product with a general matrix.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"), py::arg("vectorY"), py::arg("transpose"), py::arg("depends") = py::list()); } + { + m.def("_syrk", &blas_ns::syrk, + "Call `syrk` from oneMKL BLAS library to compute " + "the matrix-vector product with a general matrix.", + py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"), + py::arg("depends") = py::list()); + } + { m.def( "_using_onemath", diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index af18ab3002fb..6ef4c3c2c7b4 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -129,8 +129,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, Tab(1), // Scaling factor for the product of matrices A and B. a, // Pointer to matrix A. lda, // Leading dimension of matrix A, which is the - // stride between successive rows (for row major - // layout). 
+ // stride between successive rows (for row major layout). b, // Pointer to matrix B. ldb, // Leading dimension of matrix B, similar to lda. Tab(0), // Scaling factor for matrix C. @@ -168,7 +167,8 @@ std::tuple const int resultC_nd = resultC.get_ndim(); if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) { - throw py::value_error("Input matrices must be two-dimensional."); + throw py::value_error( + "Input and output matrices must be two-dimensional."); } auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); @@ -286,6 +286,8 @@ std::tuple } } else { + // both A and B are f_contig so using column-major gemm and + // no transpose is needed transA = oneapi::mkl::transpose::N; transB = oneapi::mkl::transpose::N; lda = m; diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 91057893aa5f..28993c56275d 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -118,8 +118,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q, T(1), // Scaling factor for the matrix-vector product. a, // Pointer to the input matrix A. lda, // Leading dimension of matrix A, which is the - // stride between successive rows (for row major - // layout). + // stride between successive rows (for row major layout). x, // Pointer to the input vector x. incx, // The stride of vector x. T(0), // Scaling factor for vector y. 
@@ -190,6 +189,26 @@ std::pair const py::ssize_t *a_shape = matrixA.get_shape_raw(); const py::ssize_t *x_shape = vectorX.get_shape_raw(); const py::ssize_t *y_shape = vectorY.get_shape_raw(); + if (transpose) { + if (a_shape[0] != x_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of elements in X."); + } + if (a_shape[1] != y_shape[0]) { + throw py::value_error("The number of columns in A must be equal to " + "the number of elements in Y."); + } + } + else { + if (a_shape[1] != x_shape[0]) { + throw py::value_error("The number of columns in A must be equal to " + "the number of elements in X."); + } + if (a_shape[0] != y_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of elements in Y."); + } + } oneapi::mkl::transpose transA; std::size_t src_nelems; @@ -243,27 +262,6 @@ std::pair } #endif // USE_ONEMATH_CUBLAS - if (transpose) { - if (a_shape[0] != x_shape[0]) { - throw py::value_error("The number of rows in A must be equal to " - "the number of elements in X."); - } - if (a_shape[1] != y_shape[0]) { - throw py::value_error("The number of columns in A must be equal to " - "the number of elements in Y."); - } - } - else { - if (a_shape[1] != x_shape[0]) { - throw py::value_error("The number of columns in A must be equal to " - "the number of elements in X."); - } - if (a_shape[0] != y_shape[0]) { - throw py::value_error("The number of rows in A must be equal to " - "the number of elements in Y."); - } - } - const std::int64_t lda = is_row_major ? 
n : m; dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY); dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY, @@ -287,7 +285,7 @@ std::pair "Types of input arrays and result array are mismatched."); } - char *a_typeless_ptr = matrixA.get_data(); + const char *a_typeless_ptr = matrixA.get_data(); char *x_typeless_ptr = vectorX.get_data(); char *y_typeless_ptr = vectorY.get_data(); diff --git a/dpnp/backend/extensions/blas/gemv.hpp b/dpnp/backend/extensions/blas/gemv.hpp index 88e9f9c5c6f0..094cdafdc483 100644 --- a/dpnp/backend/extensions/blas/gemv.hpp +++ b/dpnp/backend/extensions/blas/gemv.hpp @@ -41,5 +41,4 @@ extern std::pair const std::vector &depends); extern void init_gemv_dispatch_vector(void); -extern void init_gemv_batch_dispatch_vector(void); } // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/syrk.cpp b/dpnp/backend/extensions/blas/syrk.cpp new file mode 100644 index 000000000000..6974379880d5 --- /dev/null +++ b/dpnp/backend/extensions/blas/syrk.cpp @@ -0,0 +1,297 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_utils.hpp" + +#include "syrk.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp::extensions::blas +{ +namespace mkl_blas = oneapi::mkl::blas; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*syrk_impl_fn_ptr_t)(sycl::queue &, + oneapi::mkl::transpose, + const std::int64_t, + const std::int64_t, + const char *, + const std::int64_t, + char *, + const std::int64_t, +#if !defined(USE_ONEMATH_CUBLAS) + const bool, +#endif // !USE_ONEMATH_CUBLAS + const std::vector &); + +static syrk_impl_fn_ptr_t syrk_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event syrk_impl(sycl::queue &exec_q, + oneapi::mkl::transpose transA, + const std::int64_t n, + const std::int64_t k, + const char *matrixA, + const std::int64_t lda, + char *resultC, + const std::int64_t ldc, +#if !defined(USE_ONEMATH_CUBLAS) + const bool is_row_major, +#endif // !USE_ONEMATH_CUBLAS + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(matrixA); + T *res = reinterpret_cast(resultC); + + std::stringstream error_msg; + bool is_exception_caught = false; + + sycl::event syrk_event; + try { + 
auto syrk_func = + [&](sycl::queue &q, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose transA, const std::int64_t n, + const std::int64_t k, T alpha, const T *a, + const std::int64_t lda, T beta, T *c, const std::int64_t ldc, + const std::vector &deps) -> sycl::event { +#if defined(USE_ONEMATH_CUBLAS) + return mkl_blas::column_major::syrk(q, upper_lower, transA, n, k, + alpha, a, lda, beta, c, ldc, + deps); +#else + if (is_row_major) { + return mkl_blas::row_major::syrk(q, upper_lower, transA, n, k, + alpha, a, lda, beta, c, ldc, + deps); + } + else { + return mkl_blas::column_major::syrk(q, upper_lower, transA, n, + k, alpha, a, lda, beta, c, + ldc, deps); + } +#endif // USE_ONEMATH_CUBLAS + }; + + // we pass beta = 0, so passing upper or lower does not matter + oneapi::mkl::uplo uplo = oneapi::mkl::uplo::upper; + syrk_event = syrk_func( + exec_q, + uplo, // Specifies whether C’s data is stored in its upper + // or lower triangle + transA, // Defines the transpose operation for matrix A: + // 'N' indicates no transpose, 'T' for transpose, + // or 'C' for a conjugate transpose. + n, // Number of rows in op(A). + // Number of rows and columns in C. + k, // Number of columns in op(A). + T(1), // Scaling factor for the rank-k update. + a, // Pointer to the input matrix A. + lda, // Leading dimension of matrix A, which is the + // stride between successive rows (for row major layout). + T(0), // Scaling factor for matrix C. + res, // Pointer to output matrix c, where the result is stored. + ldc, // Leading dimension of matrix C. 
+ depends); + } catch (oneapi::mkl::exception const &e) { + error_msg + << "Unexpected MKL exception caught during syrk() call:\nreason: " + << e.what(); + is_exception_caught = true; + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during syrk() call:\n" + << e.what(); + is_exception_caught = true; + } + + if (is_exception_caught) // an unexpected error occurs + { + throw std::runtime_error(error_msg.str()); + } + + // kernel to copy upper triangle to lower triangle + sycl::event copy_event = exec_q.submit([&](sycl::handler &h) { + h.depends_on(syrk_event); + + h.parallel_for( + sycl::range<2>{static_cast(n), static_cast(n)}, + [=](sycl::id<2> idx) { + std::int64_t i = idx[0]; + std::int64_t j = idx[1]; + if (j > i) { + res[j * ldc + i] = res[i * ldc + j]; + } + }); + }); + + return copy_event; +} + +std::pair + syrk(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &matrixA, + const dpctl::tensor::usm_ndarray &resultC, + const std::vector &depends) +{ + const int matrixA_nd = matrixA.get_ndim(); + const int resultC_nd = resultC.get_ndim(); + + if ((matrixA_nd != 2) || (resultC_nd != 2)) { + throw py::value_error("The given arrays have incorrect dimensions."); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(matrixA, resultC)) { + throw py::value_error("Input and output matrices are overlapping " + "segments of memory"); + } + + if (!dpctl::utils::queues_are_compatible( + exec_q, {matrixA.get_queue(), resultC.get_queue()})) + { + throw py::value_error( + "USM allocations are not compatible with the execution queue."); + } + + const py::ssize_t *a_shape = matrixA.get_shape_raw(); + const py::ssize_t *c_shape = resultC.get_shape_raw(); + if (c_shape[0] != c_shape[1]) { + throw py::value_error("The output matrix should be square."); + } + if (a_shape[0] != c_shape[0]) { + throw py::value_error("The number of rows in A must be equal to " + "the number of rows in result array."); + } + + 
const bool is_matrixA_f_contig = matrixA.is_f_contiguous(); + const bool is_matrixA_c_contig = matrixA.is_c_contiguous(); + if (!is_matrixA_f_contig and !is_matrixA_c_contig) { + throw py::value_error( + "Input matrix is not c-contiguous nor f-contiguous."); + } + + oneapi::mkl::transpose transA; + std::size_t src_nelems; + +// cuBLAS supports only column-major storage +#if defined(USE_ONEMATH_CUBLAS) + const bool is_row_major = false; + std::int64_t n; + std::int64_t k; + + if (is_matrixA_f_contig) { + transA = oneapi::mkl::transpose::N; + n = a_shape[0]; + k = a_shape[1]; + src_nelems = n * n; + } + else { + transA = oneapi::mkl::transpose::T; + k = a_shape[0]; + n = a_shape[1]; + src_nelems = k * k; + } +#else + bool is_row_major = true; + if (is_matrixA_f_contig) { + is_row_major = false; + } + + transA = oneapi::mkl::transpose::N; + const std::int64_t n = a_shape[0]; + const std::int64_t k = a_shape[1]; + src_nelems = n * n; +#endif // USE_ONEMATH_CUBLAS + + const std::int64_t lda = is_row_major ? 
k : n; + const std::int64_t ldc = n; + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC); + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC, + src_nelems); + + const int matrixA_typenum = matrixA.get_typenum(); + const int resultC_typenum = resultC.get_typenum(); + if (matrixA_typenum != resultC_typenum) { + throw py::value_error("Given arrays must be of the same type."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + const int type_id = array_types.typenum_to_lookup_id(matrixA_typenum); + syrk_impl_fn_ptr_t syrk_fn = syrk_dispatch_vector[type_id]; + if (syrk_fn == nullptr) { + throw py::value_error( + "Types of input arrays and result array are mismatched."); + } + + const char *a_typeless_ptr = matrixA.get_data(); + char *r_typeless_ptr = resultC.get_data(); + +#if defined(USE_ONEMATH_CUBLAS) + sycl::event syrk_ev = syrk_fn(exec_q, transA, n, k, a_typeless_ptr, lda, + r_typeless_ptr, ldc, depends); +#else + sycl::event syrk_ev = syrk_fn(exec_q, transA, n, k, a_typeless_ptr, lda, + r_typeless_ptr, ldc, is_row_major, depends); +#endif // USE_ONEMATH_CUBLAS + + sycl::event args_ev = + dpctl::utils::keep_args_alive(exec_q, {matrixA, resultC}, {syrk_ev}); + + return std::make_pair(args_ev, syrk_ev); +} + +template +struct SyrkContigFactory +{ + fnT get() + { + if constexpr (types::SyrkTypePairSupportFactory::is_defined) { + return syrk_impl; + } + else { + return nullptr; + } + } +}; + +void init_syrk_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(syrk_dispatch_vector); +} +} // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/syrk.hpp b/dpnp/backend/extensions/blas/syrk.hpp new file mode 100644 index 000000000000..7fd38a9abdb7 --- /dev/null +++ b/dpnp/backend/extensions/blas/syrk.hpp @@ -0,0 +1,42 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// 
All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp::extensions::blas +{ +extern std::pair + syrk(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &matrixA, + const dpctl::tensor::usm_ndarray &resultC, + const std::vector &depends); + +extern void init_syrk_dispatch_vector(void); +} // namespace dpnp::extensions::blas diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp index 7590364737bf..3d70255be313 100644 --- a/dpnp/backend/extensions/blas/types_matrix.hpp +++ b/dpnp/backend/extensions/blas/types_matrix.hpp @@ -186,4 +186,29 @@ struct GemvTypePairSupportFactory // fall-through dpctl_td_ns::NotDefinedEntry>::is_defined; }; + +/** + * @brief A factory to define pairs of supported types for which + * MKL BLAS library provides support in oneapi::mkl::blas::syrk + * function. + * + * @tparam T Type of input and output arrays. + */ +template +struct SyrkTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; } // namespace dpnp::extensions::blas::types diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index c80332ea8ebd..b123b9801e86 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -50,7 +50,30 @@ ] -def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"): +def _call_syrk(x1, x2): + """ + Check to see if `syrk` can be called instead of `gemm`. + + It is assumed that x1 and x2 are usm_ndarray objects. These arrays have + already been validated to be 2-dimensional and contiguous. 
Therefore, this + function only verifies the following: Both arrays reference the same + memory. The number of rows in x1 equals the number of columns in x2. If one + array is C-contiguous, the other must be F-contiguous. + + """ + call_syrk = False + if ( + x1._pointer == x2._pointer + and x1.shape[0] == x2.shape[1] + and x1.flags.c_contiguous != x2.flags.c_contiguous + and x1.flags.f_contiguous != x2.flags.f_contiguous + ): + call_syrk = True + + return call_syrk + + +def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"): """ Determines the output array data type. If `dtype` and `out` are ``None``, the output array data type of the @@ -70,8 +93,6 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"): If not ``None``, data type of the output array. casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional Controls what kind of data casting may occur. - sycl_queue : {SyclQueue} - A SYCL queue to use for determining default floating point datat type. Returns ------- @@ -334,7 +355,7 @@ def _gemm_matmul(exec_q, x1, x2, res): def _gemm_special_case(x1, x2, res_dtype, call_flag): """ `gemm` and `gemm_batch` support these special cases of data types - while `gemv` does not. + while `gemv` or `syrk` do not. 
""" @@ -765,9 +786,7 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False): _validate_out_array(out, exec_q) # Determine the appropriate data types - res_dtype = _compute_res_dtype( - a, b, out=out, casting=casting, sycl_queue=exec_q - ) + res_dtype = _compute_res_dtype(a, b, out=out, casting=casting) result = _create_result_array( a, b, out, (), res_dtype, res_usm_type, exec_q @@ -918,7 +937,7 @@ def dpnp_multiplication( # Determine the appropriate data types res_dtype = _compute_res_dtype( - x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q + x1, x2, dtype=dtype, out=out, casting=casting ) call_flag = None @@ -1062,7 +1081,6 @@ def dpnp_multiplication( x_usm = dpnp.get_usm_ndarray(x2) _manager = dpu.SequentialOrderManager[exec_q] - ht_ev, gemv_ev = bi._gemv( exec_q, a_usm, @@ -1073,7 +1091,20 @@ def dpnp_multiplication( ) _manager.add_event_pair(ht_ev, gemv_ev) elif call_flag == "gemm": - result = _gemm_matmul(exec_q, x1, x2, result) + x1_usm = dpnp.get_usm_ndarray(x1) + x2_usm = dpnp.get_usm_ndarray(x2) + call_syrk = _call_syrk(x1_usm, x2_usm) + if call_syrk: + _manager = dpu.SequentialOrderManager[exec_q] + ht_ev, gemv_ev = bi._syrk( + exec_q, + x1_usm, + dpnp.get_usm_ndarray(result), + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, gemv_ev) + else: + result = _gemm_matmul(exec_q, x1_usm, x2_usm, result) else: assert call_flag == "gemm_batch" result = _gemm_batch_matmul(exec_q, x1, x2, result) @@ -1217,7 +1248,7 @@ def dpnp_vecdot( if axis is not None: raise TypeError("cannot specify both `axis` and `axes`.") - axes_x1, axes_x2, axes_res = _validate_axes(x1, x2, axes, "vecdot") + axes_x1, axes_x2, _ = _validate_axes(x1, x2, axes, "vecdot") # Move the axes that are going to be used in dot product, # to the end of "x1" and "x2" @@ -1241,7 +1272,7 @@ def dpnp_vecdot( # Determine the appropriate data types res_dtype = _compute_res_dtype( - x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q 
+ x1, x2, dtype=dtype, out=out, casting=casting ) _, x1_is_1D, _ = _define_dim_flags(x1, axis=-1) diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index b2a4514a57e8..ebfbba9bbc01 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -12,6 +12,7 @@ assert_dtype_allclose, generate_random_numpy_array, get_all_dtypes, + get_float_complex_dtypes, numpy_version, ) from .third_party.cupy import testing @@ -1059,17 +1060,19 @@ def test_strided_vec_mat(self, dtype, func, incx, incy, transpose): @pytest.mark.parametrize("dtype", _selected_dtypes) def test_out_order1(self, order1, order2, out_order, dtype): # test gemm with out keyword - a = generate_random_numpy_array((5, 4), dtype, low=-5, high=5) - b = generate_random_numpy_array((4, 7), dtype, low=-5, high=5) - a = numpy.array(a, order=order1) - b = numpy.array(b, order=order2) + a = generate_random_numpy_array( + (5, 4), dtype, order=order1, low=-5, high=5 + ) + b = generate_random_numpy_array( + (4, 7), dtype, order=order2, low=-5, high=5 + ) ia, ib = dpnp.array(a), dpnp.array(b) - iout = dpnp.empty((5, 7), dtype=dtype, order=out_order) + out = numpy.empty((5, 7), dtype=dtype, order=out_order) + iout = dpnp.array(out) result = dpnp.matmul(ia, ib, out=iout) assert result is iout - out = numpy.empty((5, 7), dtype=dtype, order=out_order) expected = numpy.matmul(a, b, out=out) assert result.flags.c_contiguous == expected.flags.c_contiguous assert result.flags.f_contiguous == expected.flags.f_contiguous @@ -1181,6 +1184,36 @@ def test_special_case(self, dt_out, shape1, shape2): result = dpnp.matmul(ia, ib, out=iout) assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dt", get_float_complex_dtypes()) + def test_syrk(self, dt): + a = generate_random_numpy_array((6, 9), dtype=dt) + ia = dpnp.array(a) + + result = dpnp.matmul(ia, ia.mT) + expected = numpy.matmul(a, a.T) + assert_dtype_allclose(result, expected) + + iout = dpnp.empty(result.shape, dtype=dt) + result = 
dpnp.matmul(ia, ia.mT, out=iout) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "order, out_order", + [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")], + ) + def test_syrk_out_order(self, order, out_order): + # test syrk with out keyword + a = generate_random_numpy_array((5, 4), order=order, low=-5, high=5) + out = numpy.empty((5, 5), dtype=a.dtype, order=out_order) + ia, iout = dpnp.array(a), dpnp.array(out) + + expected = numpy.matmul(a, a.T, out=out) + result = dpnp.matmul(ia, ia.mT, out=iout) + assert result is iout + assert result.flags.c_contiguous == expected.flags.c_contiguous + assert result.flags.f_contiguous == expected.flags.f_contiguous + assert_dtype_allclose(result, expected) + def test_bool(self): a = generate_random_numpy_array((3, 4), dtype=dpnp.bool) b = generate_random_numpy_array((4, 5), dtype=dpnp.bool) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index c501bcb169e6..61a9560a3e42 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -415,9 +415,6 @@ def test_1in_1out(func, data, device): pytest.param("ldexp", [5, 5, 5, 5, 5], [0, 1, 2, 3, 4]), pytest.param("logaddexp", [-1, 2, 5, 9], [4, -3, 2, -8]), pytest.param("logaddexp2", [-1, 2, 5, 9], [4, -3, 2, -8]), - pytest.param( - "matmul", [[1.0, 0.0], [0.0, 1.0]], [[4.0, 1.0], [1.0, 2.0]] - ), pytest.param("maximum", [2.0, 3.0, 4.0], [1.0, 5.0, 2.0]), pytest.param("minimum", [2.0, 3.0, 4.0], [1.0, 5.0, 2.0]), pytest.param( @@ -633,6 +630,7 @@ def test_bitwise_op_2in(op, device): @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) +@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) @pytest.mark.parametrize( "shape1, shape2", [ @@ -658,9 +656,11 @@ def test_bitwise_op_2in(op, device): "((6, 7, 4, 3), (6, 7, 3, 5))", ], ) -def test_matmul(device, shape1, shape2): - a = dpnp.arange(numpy.prod(shape1), device=device).reshape(shape1) - b = dpnp.arange(numpy.prod(shape2), device=device).reshape(shape2) 
+def test_matmul(device, dtype, shape1, shape2): + # int32 checks dpctl implementation and float32 checks oneMKL + a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device) + b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device) + a, b = a.reshape(shape1), b.reshape(shape2) result = dpnp.matmul(a, b) result_queue = result.sycl_queue @@ -668,6 +668,13 @@ def test_matmul(device, shape1, shape2): assert_sycl_queue_equal(result_queue, b.sycl_queue) +@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) +def test_matmul_syrk(device): + a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5) + result = dpnp.matmul(a, a.mT) + assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue) + + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) @pytest.mark.parametrize( "shape1, shape2", diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index df88071e39e5..3c593096f85d 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -405,6 +405,7 @@ def test_bitwise_op_2in(op, usm_type_x, usm_type_y): @pytest.mark.parametrize("usm_type_x", list_of_usm_types) @pytest.mark.parametrize("usm_type_y", list_of_usm_types) +@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) @pytest.mark.parametrize( "shape1, shape2", [ @@ -430,9 +431,11 @@ def test_bitwise_op_2in(op, usm_type_x, usm_type_y): "((6, 7, 4, 3), (6, 7, 3, 5))", ], ) -def test_matmul(usm_type_x, usm_type_y, shape1, shape2): - x = dpnp.arange(numpy.prod(shape1), usm_type=usm_type_x).reshape(shape1) - y = dpnp.arange(numpy.prod(shape2), usm_type=usm_type_y).reshape(shape2) +def test_matmul(usm_type_x, usm_type_y, dtype, shape1, shape2): + # int32 checks dpctl implementation and float32 checks oneMKL + x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x) + y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y) + x, y = x.reshape(shape1), y.reshape(shape2) z = dpnp.matmul(x, y) assert x.usm_type == usm_type_x @@ 
-440,6 +443,14 @@ def test_matmul(usm_type_x, usm_type_y, shape1, shape2): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize("usm_type", list_of_usm_types) +def test_matmul_syrk(usm_type): + x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5) + y = dpnp.matmul(x, x.mT) + + assert y.usm_type == usm_type + + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) @pytest.mark.parametrize("usm_type_y", list_of_usm_types) @pytest.mark.parametrize( From 8a39e552e907300cbab8b69ce82bcf50c770b928 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 1 Jul 2025 09:58:17 -0700 Subject: [PATCH 2/9] address comments --- dpnp/backend/extensions/blas/syrk.cpp | 8 +- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 2 - dpnp/tests/test_product.py | 1 + dpnp/tests/test_sycl_queue.py | 84 ++++++++++---------- dpnp/tests/test_usm_type.py | 86 ++++++++++----------- 5 files changed, 90 insertions(+), 91 deletions(-) diff --git a/dpnp/backend/extensions/blas/syrk.cpp b/dpnp/backend/extensions/blas/syrk.cpp index 6974379880d5..986b2a28dcba 100644 --- a/dpnp/backend/extensions/blas/syrk.cpp +++ b/dpnp/backend/extensions/blas/syrk.cpp @@ -44,7 +44,7 @@ namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*syrk_impl_fn_ptr_t)(sycl::queue &, - oneapi::mkl::transpose, + const oneapi::mkl::transpose, const std::int64_t, const std::int64_t, const char *, @@ -60,7 +60,7 @@ static syrk_impl_fn_ptr_t syrk_dispatch_vector[dpctl_td_ns::num_types]; template static sycl::event syrk_impl(sycl::queue &exec_q, - oneapi::mkl::transpose transA, + const oneapi::mkl::transpose transA, const std::int64_t n, const std::int64_t k, const char *matrixA, @@ -107,7 +107,7 @@ static sycl::event syrk_impl(sycl::queue &exec_q, }; // we pass beta = 0, so passing upper or lower does not matter - oneapi::mkl::uplo uplo = oneapi::mkl::uplo::upper; + static constexpr auto uplo = oneapi::mkl::uplo::upper; 
syrk_event = syrk_func( exec_q, uplo, // Specifies whether C’s data is stored in its upper @@ -198,7 +198,7 @@ std::pair const bool is_matrixA_f_contig = matrixA.is_f_contiguous(); const bool is_matrixA_c_contig = matrixA.is_c_contiguous(); - if (!is_matrixA_f_contig and !is_matrixA_c_contig) { + if (!is_matrixA_f_contig && !is_matrixA_c_contig) { throw py::value_error( "Input matrix is not c-contiguous nor f-contiguous."); } diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index b123b9801e86..70d5bea0f16c 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -947,8 +947,6 @@ def dpnp_multiplication( x1_is_2D, x1_is_1D, x1_base_is_1D = _define_dim_flags(x1, axis=-1) x2_is_2D, x2_is_1D, x2_base_is_1D = _define_dim_flags(x2, axis=-2) - # TODO: investigate usage of syrk function from BLAS in - # case of a.T @ a and a @ a.T to gain performance. if numpy.prod(result_shape) == 0: res_shape = result_shape elif x1_shape[-1] == 1: diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index ebfbba9bbc01..17e4636d7164 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -1195,6 +1195,7 @@ def test_syrk(self, dt): iout = dpnp.empty(result.shape, dtype=dt) result = dpnp.matmul(ia, ia.mT, out=iout) + assert result is iout assert_dtype_allclose(result, expected) @pytest.mark.parametrize( diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 61a9560a3e42..0316c8a7510f 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -629,50 +629,50 @@ def test_bitwise_op_2in(op, device): assert_sycl_queue_equal(zy.sycl_queue, y.sycl_queue) -@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) -@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) -@pytest.mark.parametrize( - "shape1, shape2", - [ - ((2, 4), (4,)), - ((4,), (4, 3)), - ((2, 4), (4, 3)), - ((2, 0), (0, 3)), - ((2, 
4), (4, 0)), - ((4, 2, 3), (4, 3, 5)), - ((4, 2, 3), (4, 3, 1)), - ((4, 1, 3), (4, 3, 5)), - ((6, 7, 4, 3), (6, 7, 3, 5)), - ], - ids=[ - "((2, 4), (4,))", - "((4,), (4, 3))", - "((2, 4), (4, 3))", - "((2, 0), (0, 3))", - "((2, 4), (4, 0))", - "((4, 2, 3), (4, 3, 5))", - "((4, 2, 3), (4, 3, 1))", - "((4, 1, 3), (4, 3, 5))", - "((6, 7, 4, 3), (6, 7, 3, 5))", - ], -) -def test_matmul(device, dtype, shape1, shape2): - # int32 checks dpctl implementation and float32 checks oneMKL - a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device) - b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device) - a, b = a.reshape(shape1), b.reshape(shape2) - result = dpnp.matmul(a, b) - - result_queue = result.sycl_queue - assert_sycl_queue_equal(result_queue, a.sycl_queue) - assert_sycl_queue_equal(result_queue, b.sycl_queue) +class TestMatmul: + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) + @pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) + @pytest.mark.parametrize( + "shape1, shape2", + [ + ((2, 4), (4,)), + ((4,), (4, 3)), + ((2, 4), (4, 3)), + ((2, 0), (0, 3)), + ((2, 4), (4, 0)), + ((4, 2, 3), (4, 3, 5)), + ((4, 2, 3), (4, 3, 1)), + ((4, 1, 3), (4, 3, 5)), + ((6, 7, 4, 3), (6, 7, 3, 5)), + ], + ids=[ + "((2, 4), (4,))", + "((4,), (4, 3))", + "((2, 4), (4, 3))", + "((2, 0), (0, 3))", + "((2, 4), (4, 0))", + "((4, 2, 3), (4, 3, 5))", + "((4, 2, 3), (4, 3, 1))", + "((4, 1, 3), (4, 3, 5))", + "((6, 7, 4, 3), (6, 7, 3, 5))", + ], + ) + def test_matmul(self, device, dtype, shape1, shape2): + # int32 checks dpctl implementation and float32 checks oneMKL + a = dpnp.arange(numpy.prod(shape1), dtype=dtype, device=device) + b = dpnp.arange(numpy.prod(shape2), dtype=dtype, device=device) + a, b = a.reshape(shape1), b.reshape(shape2) + result = dpnp.matmul(a, b) + result_queue = result.sycl_queue + assert_sycl_queue_equal(result_queue, a.sycl_queue) + assert_sycl_queue_equal(result_queue, b.sycl_queue) -@pytest.mark.parametrize("device", valid_dev, 
ids=dev_ids) -def test_matmul_syrk(device): - a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5) - result = dpnp.matmul(a, a.mT) - assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue) + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) + def test_matmul_syrk(self, device): + a = dpnp.arange(20, dtype=dpnp.float32, device=device).reshape(4, 5) + result = dpnp.matmul(a, a.mT) + assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue) @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index 3c593096f85d..aed316eca533 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -403,52 +403,52 @@ def test_bitwise_op_2in(op, usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) -@pytest.mark.parametrize("usm_type_x", list_of_usm_types) -@pytest.mark.parametrize("usm_type_y", list_of_usm_types) -@pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) -@pytest.mark.parametrize( - "shape1, shape2", - [ - ((2, 4), (4,)), - ((4,), (4, 3)), - ((2, 4), (4, 3)), - ((2, 0), (0, 3)), - ((2, 4), (4, 0)), - ((4, 2, 3), (4, 3, 5)), - ((4, 2, 3), (4, 3, 1)), - ((4, 1, 3), (4, 3, 5)), - ((6, 7, 4, 3), (6, 7, 3, 5)), - ], - ids=[ - "((2, 4), (4,))", - "((4,), (4, 3))", - "((2, 4), (4, 3))", - "((2, 0), (0, 3))", - "((2, 4), (4, 0))", - "((4, 2, 3), (4, 3, 5))", - "((4, 2, 3), (4, 3, 1))", - "((4, 1, 3), (4, 3, 5))", - "((6, 7, 4, 3), (6, 7, 3, 5))", - ], -) -def test_matmul(usm_type_x, usm_type_y, dtype, shape1, shape2): - # int32 checks dpctl implementation and float32 checks oneMKL - x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x) - y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y) - x, y = x.reshape(shape1), y.reshape(shape2) - z = dpnp.matmul(x, y) - - assert x.usm_type == usm_type_x - assert y.usm_type == usm_type_y - assert z.usm_type == 
du.get_coerced_usm_type([usm_type_x, usm_type_y]) +class TestMatmul: + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) + @pytest.mark.parametrize("usm_type_y", list_of_usm_types) + @pytest.mark.parametrize("dtype", [dpnp.int32, dpnp.float32]) + @pytest.mark.parametrize( + "shape1, shape2", + [ + ((2, 4), (4,)), + ((4,), (4, 3)), + ((2, 4), (4, 3)), + ((2, 0), (0, 3)), + ((2, 4), (4, 0)), + ((4, 2, 3), (4, 3, 5)), + ((4, 2, 3), (4, 3, 1)), + ((4, 1, 3), (4, 3, 5)), + ((6, 7, 4, 3), (6, 7, 3, 5)), + ], + ids=[ + "((2, 4), (4,))", + "((4,), (4, 3))", + "((2, 4), (4, 3))", + "((2, 0), (0, 3))", + "((2, 4), (4, 0))", + "((4, 2, 3), (4, 3, 5))", + "((4, 2, 3), (4, 3, 1))", + "((4, 1, 3), (4, 3, 5))", + "((6, 7, 4, 3), (6, 7, 3, 5))", + ], + ) + def test_basic(self, usm_type_x, usm_type_y, dtype, shape1, shape2): + # int32 checks dpctl implementation and float32 checks oneMKL + x = dpnp.arange(numpy.prod(shape1), dtype=dtype, usm_type=usm_type_x) + y = dpnp.arange(numpy.prod(shape2), dtype=dtype, usm_type=usm_type_y) + x, y = x.reshape(shape1), y.reshape(shape2) + z = dpnp.matmul(x, y) + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) -@pytest.mark.parametrize("usm_type", list_of_usm_types) -def test_matmul_syrk(usm_type): - x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5) - y = dpnp.matmul(x, x.mT) + @pytest.mark.parametrize("usm_type", list_of_usm_types) + def test_syrk(self, usm_type): + x = dpnp.arange(20, dtype=dpnp.float32, usm_type=usm_type).reshape(4, 5) + y = dpnp.matmul(x, x.mT) - assert y.usm_type == usm_type + assert y.usm_type == usm_type @pytest.mark.parametrize("usm_type_x", list_of_usm_types) From 6057823677029b48f9e85688a01188c7301075d0 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 1 Jul 2025 16:49:45 -0700 Subject: [PATCH 3/9] fix an issue for F-contiguous arrays --- dpnp/backend/extensions/blas/syrk.cpp | 13 
+++++++++++-- dpnp/tests/test_product.py | 13 ++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/dpnp/backend/extensions/blas/syrk.cpp b/dpnp/backend/extensions/blas/syrk.cpp index 986b2a28dcba..a26baaacfc59 100644 --- a/dpnp/backend/extensions/blas/syrk.cpp +++ b/dpnp/backend/extensions/blas/syrk.cpp @@ -152,11 +152,20 @@ static sycl::event syrk_impl(sycl::queue &exec_q, std::int64_t i = idx[0]; std::int64_t j = idx[1]; if (j > i) { - res[j * ldc + i] = res[i * ldc + j]; + // result from row_major::syrk is row major and result from + // column_major::syrk is column major, so copying upper + // triangle to lower triangle is different for each case + if (is_row_major) { + // row-major: res[i][j] = res[i * ldc + j] + res[j * ldc + i] = res[i * ldc + j]; + } + else { + // column-major: res[i][j] = res[i + j * ldc] + res[i * ldc + j] = res[j * ldc + i]; + } } }); }); - return copy_event; } diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 17e4636d7164..227a7bb5fd93 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -1198,12 +1198,15 @@ def test_syrk(self, dt): assert result is iout assert_dtype_allclose(result, expected) + result = ia.mT @ ia + expected = a.T @ a + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize( "order, out_order", [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")], ) def test_syrk_out_order(self, order, out_order): - # test syrk with out keyword a = generate_random_numpy_array((5, 4), order=order, low=-5, high=5) out = numpy.empty((5, 5), dtype=a.dtype, order=out_order) ia, iout = dpnp.array(a), dpnp.array(out) @@ -1215,6 +1218,14 @@ def test_syrk_out_order(self, order, out_order): assert result.flags.c_contiguous == expected.flags.c_contiguous assert result.flags.f_contiguous == expected.flags.f_contiguous assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("order", ["F", "C"]) + def test_syrk_order(self, order): + a = generate_random_numpy_array((4, 6), order=order, low=-5, high=5) + ia = dpnp.array(a) +
expected = numpy.matmul(a, a.T) + result = dpnp.matmul(ia, ia.mT) + assert_dtype_allclose(result, expected) + def test_bool(self): a = generate_random_numpy_array((3, 4), dtype=dpnp.bool) b = generate_random_numpy_array((4, 5), dtype=dpnp.bool) From dc0081ad53940e3ceb8c2f9ad7408cd597e4f49e Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 2 Jul 2025 07:19:16 -0700 Subject: [PATCH 4/9] update error message when backend function is not available for specific dtype --- dpnp/backend/extensions/blas/dot_common.hpp | 3 ++- dpnp/backend/extensions/blas/gemm.cpp | 3 ++- dpnp/backend/extensions/blas/gemm_batch.cpp | 3 ++- dpnp/backend/extensions/blas/gemv.cpp | 3 ++- dpnp/backend/extensions/blas/syrk.cpp | 3 ++- dpnp/backend/extensions/lapack/evd_batch_common.hpp | 3 ++- dpnp/backend/extensions/lapack/evd_common.hpp | 3 ++- 7 files changed, 14 insertions(+), 7 deletions(-) diff --git a/dpnp/backend/extensions/blas/dot_common.hpp b/dpnp/backend/extensions/blas/dot_common.hpp index fb9a1f078c53..169421a2464c 100644 --- a/dpnp/backend/extensions/blas/dot_common.hpp +++ b/dpnp/backend/extensions/blas/dot_common.hpp @@ -128,7 +128,8 @@ std::pair dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id]; if (dot_fn == nullptr) { throw py::value_error( - "Types of input vectors and result array are mismatched."); + "No dot implementation is available for the specified data type " + "of the input and output arrays."); } char *x_typeless_ptr = vectorX.get_data(); diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index 6ef4c3c2c7b4..b1cfb10859d2 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -315,7 +315,8 @@ std::tuple gemm_dispatch_table[matrixAB_type_id][resultC_type_id]; if (gemm_fn == nullptr) { throw py::value_error( - "Types of input matrices and result matrix are mismatched."); + "No gemm implementation is available for the specified data type " + "of the input and output 
arrays."); } const char *a_typeless_ptr = matrixA.get_data(); diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 1e210aede9fa..926dd8720272 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -389,7 +389,8 @@ std::tuple gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id]; if (gemm_batch_fn == nullptr) { throw py::value_error( - "Types of input matrices and result matrix are mismatched."); + "No gemm_batch implementation is available for the specified data " + "type of the input and output arrays."); } const char *a_typeless_ptr = matrixA.get_data(); diff --git a/dpnp/backend/extensions/blas/gemv.cpp b/dpnp/backend/extensions/blas/gemv.cpp index 28993c56275d..e06af6577920 100644 --- a/dpnp/backend/extensions/blas/gemv.cpp +++ b/dpnp/backend/extensions/blas/gemv.cpp @@ -282,7 +282,8 @@ std::pair gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_vector[type_id]; if (gemv_fn == nullptr) { throw py::value_error( - "Types of input arrays and result array are mismatched."); + "No gemv implementation is available for the specified data type " + "of the input and output arrays."); } const char *a_typeless_ptr = matrixA.get_data(); diff --git a/dpnp/backend/extensions/blas/syrk.cpp b/dpnp/backend/extensions/blas/syrk.cpp index a26baaacfc59..90e5e89c7325 100644 --- a/dpnp/backend/extensions/blas/syrk.cpp +++ b/dpnp/backend/extensions/blas/syrk.cpp @@ -262,7 +262,8 @@ std::pair syrk_impl_fn_ptr_t syrk_fn = syrk_dispatch_vector[type_id]; if (syrk_fn == nullptr) { throw py::value_error( - "Types of input arrays and result array are mismatched."); + "No syrk implementation is available for the specified data type " + "of the input and output arrays."); } const char *a_typeless_ptr = matrixA.get_data(); diff --git a/dpnp/backend/extensions/lapack/evd_batch_common.hpp b/dpnp/backend/extensions/lapack/evd_batch_common.hpp index 9610d6aa568a..3545db01458c 100644 --- 
a/dpnp/backend/extensions/lapack/evd_batch_common.hpp +++ b/dpnp/backend/extensions/lapack/evd_batch_common.hpp @@ -97,7 +97,8 @@ std::pair evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id]; if (evd_batch_fn == nullptr) { throw py::value_error( - "Types of input vectors and result array are mismatched."); + "No evd_batch implementation is available for the specified data " + "type of the input and output arrays."); } char *eig_vecs_data = eig_vecs.get_data(); diff --git a/dpnp/backend/extensions/lapack/evd_common.hpp b/dpnp/backend/extensions/lapack/evd_common.hpp index 5503d8f82052..3964943c5305 100644 --- a/dpnp/backend/extensions/lapack/evd_common.hpp +++ b/dpnp/backend/extensions/lapack/evd_common.hpp @@ -91,7 +91,8 @@ std::pair evd_dispatch_table[eig_vecs_type_id][eig_vals_type_id]; if (evd_fn == nullptr) { throw py::value_error( - "Types of input vectors and result array are mismatched."); + "No evd implementation is available for the specified data type " + "of the input and output arrays."); } char *eig_vecs_data = eig_vecs.get_data(); From 04d72f80a2ae0b8e6b1357897c71d5cba9bff28b Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 2 Jul 2025 07:29:37 -0700 Subject: [PATCH 5/9] use syrk for int dtypes when possible --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 78 +++++++++++---------- dpnp/tests/test_product.py | 30 +++++++- 2 files changed, 69 insertions(+), 39 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 70d5bea0f16c..3201b64b7265 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -50,29 +50,6 @@ ] -def _call_syrk(x1, x2): - """ - Check to see if `syrk` can be called instead of `gemm`. - - It is assumed that x1 and x2 are usm_ndarray objects. These arrays have - already been validated to be 2-dimensional and contiguous. 
Therefore, this - function only verifies the following: Both arrays reference the same - memory. The number of rows in x1 equals the number of columns in x2. If one - array is C-contiguous, the other must be F-contiguous. - - """ - call_syrk = False - if ( - x1._pointer == x2._pointer - and x1.shape[0] == x2.shape[1] - and x1.flags.c_contiguous != x2.flags.c_contiguous - and x1.flags.f_contiguous != x2.flags.f_contiguous - ): - call_syrk = True - - return call_syrk - - def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"): """ Determines the output array data type. @@ -541,6 +518,29 @@ def _get_signature(func): return signature, distinct_core +def _is_syrk_compatible(x1, x2): + """ + Check to see if `syrk` can be called instead of `gemm`. + Input arrays have already been validated to be 2-dimensional. + + """ + # Must share data (same base buffer) + if dpnp.get_usm_ndarray(x1)._pointer != dpnp.get_usm_ndarray(x2)._pointer: + return False + + # Result must be square + if x1.shape[0] != x2.shape[1]: + return False + + # Strides must match transpose pattern + x1_strides = x1.strides + x2_strides = x2.strides + if x1_strides[0] != x2_strides[1] or x1_strides[1] != x2_strides[0]: + return False + + return True + + def _shape_error(shape1, shape2, func, err_msg): """Validate the shapes of input and output arrays.""" @@ -983,6 +983,11 @@ def dpnp_multiplication( x1 = dpnp.reshape(x1, x1_shape[-2:]) x2 = dpnp.reshape(x2, x2_shape[-2:]) res_shape = (x1_shape[-2], x2_shape[-1]) + if _is_syrk_compatible(x1, x2): + call_flag = "syrk" + res_dtype_orig = res_dtype + if dpnp.issubdtype(res_dtype, dpnp.integer): + res_dtype = dpnp.default_float_type(x1.device) elif x1_base_is_1D: # TODO: implement gemv_batch to use it here with transpose call_flag = "gemm_batch" @@ -1088,21 +1093,17 @@ def dpnp_multiplication( depends=_manager.submitted_events, ) _manager.add_event_pair(ht_ev, gemv_ev) + elif call_flag == "syrk": + _manager = dpu.SequentialOrderManager[exec_q] + 
ht_ev, gemv_ev = bi._syrk( + exec_q, + dpnp.get_usm_ndarray(x1), + dpnp.get_usm_ndarray(result), + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, gemv_ev) elif call_flag == "gemm": - x1_usm = dpnp.get_usm_ndarray(x1) - x2_usm = dpnp.get_usm_ndarray(x2) - call_syrk = _call_syrk(x1_usm, x2_usm) - if call_syrk: - _manager = dpu.SequentialOrderManager[exec_q] - ht_ev, gemv_ev = bi._syrk( - exec_q, - x1_usm, - dpnp.get_usm_ndarray(result), - depends=_manager.submitted_events, - ) - _manager.add_event_pair(ht_ev, gemv_ev) - else: - result = _gemm_matmul(exec_q, x1_usm, x2_usm, result) + result = _gemm_matmul(exec_q, x1, x2, result) else: assert call_flag == "gemm_batch" result = _gemm_batch_matmul(exec_q, x1, x2, result) @@ -1130,6 +1131,9 @@ def dpnp_multiplication( elif res_shape != result_shape: result = dpnp.reshape(result, result_shape) + if call_flag == "syrk" and res_dtype_orig != res_dtype: + result = result.astype(res_dtype_orig) + if out is None: if axes is not None: # Move the data back to the appropriate axes of the result array diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 227a7bb5fd93..383af34bf3b7 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -12,7 +12,6 @@ assert_dtype_allclose, generate_random_numpy_array, get_all_dtypes, - get_float_complex_dtypes, numpy_version, ) from .third_party.cupy import testing @@ -1184,7 +1183,7 @@ def test_special_case(self, dt_out, shape1, shape2): result = dpnp.matmul(ia, ib, out=iout) assert_dtype_allclose(result, expected) - @pytest.mark.parametrize("dt", get_float_complex_dtypes()) + @pytest.mark.parametrize("dt", get_all_dtypes()) def test_syrk(self, dt): a = generate_random_numpy_array((6, 9), dtype=dt) ia = dpnp.array(a) @@ -1202,6 +1201,21 @@ def test_syrk(self, dt): expected = a.T @ a assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dt", [dpnp.int32, dpnp.float32]) + def test_syrk_strided(self, dt): + a = 
generate_random_numpy_array((20, 30), dtype=dt) + ia = dpnp.array(a) + a = a[::2, ::2] + ia = ia[::2, ::2] + + result = dpnp.matmul(ia, ia.mT) + expected = numpy.matmul(a, a.T) + assert_dtype_allclose(result, expected) + + result = ia.mT @ ia + expected = a.T @ a + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize( "order, out_order", [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")], @@ -1226,6 +1240,18 @@ def test_syrk_order(self, order): result = dpnp.matmul(ia, ia.mT) assert_dtype_allclose(result, expected) + # added for coverage + def test_not_syrk(self): + a = generate_random_numpy_array((20, 20), low=-5, high=5) + ia = dpnp.array(a) + + # Result must be square + b = a.mT[:, ::2] + ib = ia.mT[:, ::2] + expected = numpy.matmul(a, b) + result = dpnp.matmul(ia, ib) + assert_dtype_allclose(result, expected) + def test_bool(self): a = generate_random_numpy_array((3, 4), dtype=dpnp.bool) b = generate_random_numpy_array((4, 5), dtype=dpnp.bool) From 3d4c17110a59ea72d314b4199543e83d4b33f6dc Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 3 Jul 2025 12:26:08 -0700 Subject: [PATCH 6/9] use .T for numpy array transpose --- dpnp/tests/test_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 383af34bf3b7..97fc60afd828 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -1246,7 +1246,7 @@ def test_not_syrk(self): ia = dpnp.array(a) # Result must be square - b = a.mT[:, ::2] + b = a.T[:, ::2] ib = ia.mT[:, ::2] expected = numpy.matmul(a, b) result = dpnp.matmul(ia, ib) From f5dc9ab2e91573eb5a00306fe51d8ffe8f52037b Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 3 Jul 2025 13:53:12 -0700 Subject: [PATCH 7/9] remove unnecessary setup_method --- dpnp/tests/test_product.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index
97fc60afd828..70ec90ea5728 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -22,9 +22,6 @@ class TestCross: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("axis", [None, 0]) @pytest.mark.parametrize("axisc", [-1, 0]) @pytest.mark.parametrize("axisb", [-1, 0]) @@ -179,9 +176,6 @@ def test_linalg_error(self): class TestDot: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_ones(self, dtype): n = 10**5 @@ -430,9 +424,6 @@ def test_out_error(self, shape1, shape2, out_shape): class TestInner: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_scalar(self, dtype): a = 2 @@ -596,9 +587,6 @@ def test_order(self, order): class TestMatmul: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", _selected_dtypes) @pytest.mark.parametrize( "order1, order2", [("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")] @@ -1185,7 +1173,7 @@ def test_special_case(self, dt_out, shape1, shape2): @pytest.mark.parametrize("dt", get_all_dtypes()) def test_syrk(self, dt): - a = generate_random_numpy_array((6, 9), dtype=dt) + a = generate_random_numpy_array((6, 9), dtype=dt, low=-5, high=5) ia = dpnp.array(a) result = dpnp.matmul(ia, ia.mT) @@ -1507,9 +1495,6 @@ def test_invalid_axes(self, xp): @testing.with_requires("numpy>=2.2") class TestMatvec: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shape1, shape2", @@ -1569,9 +1554,6 @@ def test_error(self, xp): class TestMultiDot: - def setup_method(self): - numpy.random.seed(70) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shapes", @@ -1726,9 +1708,6 @@ def test_error(self): class TestTensordot: - def setup_method(self): - numpy.random.seed(87) - 
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_scalar(self, dtype): a = 2 @@ -1860,9 +1839,6 @@ def test_error(self): class TestVdot: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_scalar(self, dtype): a = numpy.array([3.5], dtype=dtype) @@ -1953,9 +1929,6 @@ def test_error(self): @testing.with_requires("numpy>=2.0") class TestVecdot: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shape1, shape2", @@ -2233,9 +2206,6 @@ def test_error(self, xp): @testing.with_requires("numpy>=2.2") class TestVecmat: - def setup_method(self): - numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize( "shape1, shape2", From 03d08c5ef6ecb132e32af4f57a03128edf6ac2e8 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 3 Jul 2025 13:58:55 -0700 Subject: [PATCH 8/9] avoid unnecessary copy of x2 in syrk --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 3201b64b7265..69fbd8f40632 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -1068,12 +1068,13 @@ def dpnp_multiplication( dtype=res_dtype, order=res_order, ) - x2 = _copy_array( - x2, - copy_flag=not x2_contig_flag, - dtype=res_dtype, - order=res_order, - ) + if call_flag != "syrk": + x2 = _copy_array( + x2, + copy_flag=not x2_contig_flag, + dtype=res_dtype, + order=res_order, + ) if call_flag == "gemv": if transpose: From 5898507458d089546bcff2672f908faacc4d95e5 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 3 Jul 2025 14:07:09 -0700 Subject: [PATCH 9/9] use syrk for boolean dtypes when possible --- 
dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 5 ++++- dpnp/tests/test_product.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 69fbd8f40632..8e21e1ca4dac 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -986,7 +986,10 @@ def dpnp_multiplication( if _is_syrk_compatible(x1, x2): call_flag = "syrk" res_dtype_orig = res_dtype - if dpnp.issubdtype(res_dtype, dpnp.integer): + # for exact dtypes, use syrk implementation unlike general approach + # where dpctl implementation is used for exact dtypes for better + # performance + if not dpnp.issubdtype(res_dtype, dpnp.inexact): res_dtype = dpnp.default_float_type(x1.device) elif x1_base_is_1D: # TODO: implement gemv_batch to use it here with transpose diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index 70ec90ea5728..4f50b222caca 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -1171,7 +1171,7 @@ def test_special_case(self, dt_out, shape1, shape2): result = dpnp.matmul(ia, ib, out=iout) assert_dtype_allclose(result, expected) - @pytest.mark.parametrize("dt", get_all_dtypes()) + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True)) def test_syrk(self, dt): a = generate_random_numpy_array((6, 9), dtype=dt, low=-5, high=5) ia = dpnp.array(a)