using syrk for performing special cases of matrix multiplication #2509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 11 commits into master
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478)
* Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500)
* Added a new backend routine `syrk` from oneMKL to perform a symmetric rank-k update, which is used for a specialized matrix multiplication where the result is a symmetric matrix [#2509](https://github.com/IntelPython/dpnp/pull/2509)

### Changed

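For context on the `syrk` changelog entry above: the routine implements the standard BLAS symmetric rank-k update

$$
C \leftarrow \alpha\,\mathrm{op}(A)\,\mathrm{op}(A)^{T} + \beta\,C,
\qquad \mathrm{op}(A) \in \{A,\ A^{T}\},
$$

where $C$ is an $n \times n$ symmetric matrix and only its upper or lower triangle is referenced and updated. This is what makes it a good fit for the special case of matrix multiplication where a matrix is multiplied by its own transpose: roughly half of the result is computed and the remaining triangle is filled in by mirroring.
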
1 change: 1 addition & 0 deletions dpnp/backend/extensions/blas/CMakeLists.txt
@@ -30,6 +30,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
22 changes: 16 additions & 6 deletions dpnp/backend/extensions/blas/blas_py.cpp
@@ -36,6 +36,7 @@
#include "dotu.hpp"
#include "gemm.hpp"
#include "gemv.hpp"
#include "syrk.hpp"

namespace blas_ns = dpnp::extensions::blas;
namespace py = pybind11;
@@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void)
blas_ns::init_gemm_batch_dispatch_table();
blas_ns::init_gemm_dispatch_table();
blas_ns::init_gemv_dispatch_vector();
blas_ns::init_syrk_dispatch_vector();
}

static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
@@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m)
};

m.def("_dot", dot_pyapi,
"Call `dot` from OneMKL BLAS library to compute "
"Call `dot` from oneMKL BLAS library to compute "
"the dot product of two real-valued vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
py::arg("result"), py::arg("depends") = py::list());
@@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m)
};

m.def("_dotc", dotc_pyapi,
"Call `dotc` from OneMKL BLAS library to compute "
"Call `dotc` from oneMKL BLAS library to compute "
"the dot product of two complex vectors, "
"conjugating the first vector.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -110,37 +112,45 @@ PYBIND11_MODULE(_blas_impl, m)
};

m.def("_dotu", dotu_pyapi,
"Call `dotu` from OneMKL BLAS library to compute "
"Call `dotu` from oneMKL BLAS library to compute "
"the dot product of two complex vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
py::arg("result"), py::arg("depends") = py::list());
}

{
m.def("_gemm", &blas_ns::gemm,
"Call `gemm` from OneMKL BLAS library to compute "
"Call `gemm` from oneMKL BLAS library to compute "
"the matrix-matrix product with 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemm_batch", &blas_ns::gemm_batch,
"Call `gemm_batch` from OneMKL BLAS library to compute "
"Call `gemm_batch` from oneMKL BLAS library to compute "
"the matrix-matrix product for a batch of 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemv", &blas_ns::gemv,
"Call `gemv` from OneMKL BLAS library to compute "
"Call `gemv` from oneMKL BLAS library to compute "
"the matrix-vector product with a general matrix.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
py::arg("vectorY"), py::arg("transpose"),
py::arg("depends") = py::list());
}

{
m.def("_syrk", &blas_ns::syrk,
"Call `syrk` from oneMKL BLAS library to compute "
"the matrix-vector product with a general matrix.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"),
py::arg("depends") = py::list());
}

{
m.def(
"_using_onemath",
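The new `syrk.cpp` itself is outside the hunks shown above. For orientation only, a minimal call through the oneMKL SYCL USM API typically looks like the sketch below; the dispatch table, layout handling, and validation that the actual implementation would carry are omitted, and the wrapper name `syrk_call_sketch` is made up for illustration.

```cpp
// Illustrative sketch of a oneMKL syrk call (not the PR's syrk.cpp).
#include <cstdint>
#include <vector>

#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

// Computes C <- A * A^T for a row-major n x k matrix A, writing only the
// upper triangle of the n x n result C; the caller mirrors it afterwards.
sycl::event syrk_call_sketch(sycl::queue &exec_q, const double *a, double *c,
                             std::int64_t n, std::int64_t k,
                             const std::vector<sycl::event> &depends = {})
{
    return oneapi::mkl::blas::row_major::syrk(
        exec_q,
        oneapi::mkl::uplo::upper,         // which triangle of C is written
        oneapi::mkl::transpose::nontrans, // use A as stored, i.e. A * A^T
        n, k,
        1.0, a, k,                        // alpha, A, lda (= k for row-major A)
        0.0, c, n,                        // beta, C, ldc (= n)
        depends);
}
```
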
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/blas/dot_common.hpp
@@ -128,7 +128,8 @@ std::pair<sycl::event, sycl::event>
dot_impl_fn_ptr_t dot_fn = dot_dispatch_vector[type_id];
if (dot_fn == nullptr) {
throw py::value_error(
"Types of input vectors and result array are mismatched.");
"No dot implementation is available for the specified data type "
"of the input and output arrays.");
}

char *x_typeless_ptr = vectorX.get_data();
11 changes: 7 additions & 4 deletions dpnp/backend/extensions/blas/gemm.cpp
@@ -129,8 +129,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
// stride between successive rows (for row major layout).
b, // Pointer to matrix B.
ldb, // Leading dimension of matrix B, similar to lda.
Tab(0), // Scaling factor for matrix C.
@@ -168,7 +167,8 @@ std::tuple<sycl::event, sycl::event, bool>
const int resultC_nd = resultC.get_ndim();

if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) {
throw py::value_error("Input matrices must be two-dimensional.");
throw py::value_error(
"Input and output matrices must be two-dimensional.");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
@@ -286,6 +286,8 @@ std::tuple<sycl::event, sycl::event, bool>
}
}
else {
// both A and B are f_contig, so column-major gemm is used and
// no transpose is needed
transA = oneapi::mkl::transpose::N;
transB = oneapi::mkl::transpose::N;
lda = m;
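
To make the new comment in the `else` branch above concrete, the sketch below (an illustration with assumed `double` data and no validation, not the PR's `gemm.cpp`) shows the column-major gemm entry point being called directly when both inputs are F-contiguous, so neither operand needs to be transposed or copied.

```cpp
// Illustrative sketch: A (m x k) and B (k x n) are both F-contiguous, so the
// column-major gemm entry point is used without transposing either operand.
#include <cstdint>
#include <vector>

#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

sycl::event gemm_f_contig_sketch(sycl::queue &exec_q, const double *a,
                                 const double *b, double *c, std::int64_t m,
                                 std::int64_t n, std::int64_t k,
                                 const std::vector<sycl::event> &depends = {})
{
    return oneapi::mkl::blas::column_major::gemm(
        exec_q,
        oneapi::mkl::transpose::N, // A used as stored
        oneapi::mkl::transpose::N, // B used as stored
        m, n, k,
        1.0,                       // alpha
        a, m,                      // lda = m for a column-major m x k matrix
        b, k,                      // ldb = k for a column-major k x n matrix
        0.0,                       // beta
        c, m,                      // ldc = m for the column-major m x n result
        depends);
}
```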
@@ -313,7 +315,8 @@ std::tuple<sycl::event, sycl::event, bool>
gemm_dispatch_table[matrixAB_type_id][resultC_type_id];
if (gemm_fn == nullptr) {
throw py::value_error(
"Types of input matrices and result matrix are mismatched.");
"No gemm implementation is available for the specified data type "
"of the input and output arrays.");
}

const char *a_typeless_ptr = matrixA.get_data();
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -389,7 +389,8 @@ std::tuple<sycl::event, sycl::event, bool>
gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
if (gemm_batch_fn == nullptr) {
throw py::value_error(
"Types of input matrices and result matrix are mismatched.");
"No gemm_batch implementation is available for the specified data "
"type of the input and output arrays.");
}

const char *a_typeless_ptr = matrixA.get_data();
49 changes: 24 additions & 25 deletions dpnp/backend/extensions/blas/gemv.cpp
@@ -118,8 +118,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
T(1), // Scaling factor for the matrix-vector product.
a, // Pointer to the input matrix A.
lda, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
// stride between successive rows (for row major layout).
x, // Pointer to the input vector x.
incx, // The stride of vector x.
T(0), // Scaling factor for vector y.
@@ -190,6 +189,26 @@ std::pair<sycl::event, sycl::event>
const py::ssize_t *a_shape = matrixA.get_shape_raw();
const py::ssize_t *x_shape = vectorX.get_shape_raw();
const py::ssize_t *y_shape = vectorY.get_shape_raw();
if (transpose) {
if (a_shape[0] != x_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in X.");
}
if (a_shape[1] != y_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in Y.");
}
}
else {
if (a_shape[1] != x_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in X.");
}
if (a_shape[0] != y_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in Y.");
}
}

oneapi::mkl::transpose transA;
std::size_t src_nelems;
@@ -243,27 +262,6 @@ std::pair<sycl::event, sycl::event>
}
#endif // USE_ONEMATH_CUBLAS

if (transpose) {
if (a_shape[0] != x_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in X.");
}
if (a_shape[1] != y_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in Y.");
}
}
else {
if (a_shape[1] != x_shape[0]) {
throw py::value_error("The number of columns in A must be equal to "
"the number of elements in X.");
}
if (a_shape[0] != y_shape[0]) {
throw py::value_error("The number of rows in A must be equal to "
"the number of elements in Y.");
}
}

const std::int64_t lda = is_row_major ? n : m;
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY);
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY,
@@ -284,10 +282,11 @@ std::pair<sycl::event, sycl::event>
gemv_impl_fn_ptr_t gemv_fn = gemv_dispatch_vector[type_id];
if (gemv_fn == nullptr) {
throw py::value_error(
"Types of input arrays and result array are mismatched.");
"No gemv implementation is available for the specified data type "
"of the input and output arrays.");
}

char *a_typeless_ptr = matrixA.get_data();
const char *a_typeless_ptr = matrixA.get_data();
char *x_typeless_ptr = vectorX.get_data();
char *y_typeless_ptr = vectorY.get_data();

1 change: 0 additions & 1 deletion dpnp/backend/extensions/blas/gemv.hpp
@@ -41,5 +41,4 @@ extern std::pair<sycl::event, sycl::event>
const std::vector<sycl::event> &depends);

extern void init_gemv_dispatch_vector(void);
extern void init_gemv_batch_dispatch_vector(void);
} // namespace dpnp::extensions::blas