diff --git a/src/ATen/native/xpu/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/BatchLinearAlgebra.cpp
index 8036419c07..ef95e9ca44 100644
--- a/src/ATen/native/xpu/BatchLinearAlgebra.cpp
+++ b/src/ATen/native/xpu/BatchLinearAlgebra.cpp
@@ -2,6 +2,8 @@
 #include
 #include
 #include
+#include
+#include
 #if defined(USE_ONEMKL_XPU)
 #include
 #endif // USE_ONEMKL_XPU
@@ -64,4 +66,28 @@ void lu_factor_kernel_xpu(
 REGISTER_XPU_DISPATCH(lu_factor_stub, &lu_factor_kernel_xpu);
 
+TORCH_IMPL_FUNC(linalg_qr_xpu_out)(const Tensor& A,
+                                   std::string_view mode,
+                                   const Tensor& Q,
+                                   const Tensor& R) {
+#if defined(USE_ONEMKL_XPU)
+  if (!A.is_complex()) {
+    xpu::linalg_qr_kernel(A, mode, Q, R);
+    return;
+  }
+#endif // USE_ONEMKL_XPU
+  // Fallback: complex inputs (the oneMKL path covers real types only) and
+  // builds without oneMKL compute the factorization on CPU and copy back.
+  auto A_cpu = A.to(at::kCPU);
+  auto Q_cpu = at::empty_like(Q, at::kCPU);
+  auto R_cpu = at::empty_like(R, at::kCPU);
+  at::cpu::linalg_qr_out(Q_cpu, R_cpu, A_cpu, mode);
+  Q.copy_(Q_cpu);
+  R.copy_(R_cpu);
+}
+
 } // namespace at::native
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 3479967d9e..908f2feaed 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -209,7 +209,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
     "linalg_lstsq.out",
     "linalg_lu.out",
     "linalg_matrix_exp",
-    "linalg_qr.out",
     "linalg_solve_triangular",
     "linalg_solve_triangular.out",
     "_linalg_svd.U",
diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
index 26e80fa4d0..17faee757c 100644
--- a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
+++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
@@ -561,4 +561,143 @@ void lu_factor_mkl(
   pivots.copy_(pivots_);
 }
 
+template <typename scalar_t>
+void linalg_qr_kernel_impl(
+    const at::Tensor& A,
+    std::string_view mode,
+    const at::Tensor& Q,
+    const at::Tensor& R) {
+  at::Tensor a_contig = A.contiguous();
+  at::Tensor result_r = at::clone(a_contig);
+
+  auto options = at::TensorOptions().dtype(A.dtype()).device(kXPU);
+  auto dimensions = A.sizes();
+
+  int64_t numel = a_contig.numel();
+  int64_t ndim = a_contig.dim();
+  int64_t n = a_contig.sizes().at(ndim - 2); // rows of each matrix
+  int64_t m = a_contig.sizes().at(ndim - 1); // columns of each matrix
+  int64_t mn = m * n;
+  int64_t b = numel == 0 ? 0 : numel / mn;
+
+  // Correct the R matrix dimensions for empty inputs outside "complete" mode.
+  if (numel == 0 && mode != "complete") {
+    std::vector<int64_t> r(dimensions.begin(), dimensions.end());
+    if (r[ndim - 1] == 0) {
+      r[ndim - 2] = 0;
+    }
+    result_r = at::zeros(r, options);
+  }
+
+  // Work on the transposed copy so the data is column-major, as oneMKL LAPACK expects.
+  result_r = result_r.transpose(-2, -1).contiguous();
+
+  if (b == 0 && mode == "complete" && n > 0) {
+    b = native::batchCount(a_contig);
+  }
+
+  int64_t out_q_columns = m > n ? n : m;
+  if (n > m && mode == "complete") {
+    out_q_columns = n;
+  }
+
+  // Correct the Q matrix output dimensions if needed.
+  std::vector<int64_t> v_dim(dimensions.begin(), dimensions.end());
+  if (mode != "r") {
+    v_dim[ndim - 1] = v_dim[ndim - 2];
+    v_dim[ndim - 2] = out_q_columns;
+  } else {
+    // Q is not computed in "r" mode; hand back an empty tensor for it.
+    v_dim = std::vector<int64_t>({0});
+  }
+  auto q_dimensions = at::IntArrayRef(v_dim);
+
+  at::Tensor result_q = at::empty(q_dimensions, options);
+
+  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
+
+  // Add one to each size so the scratchpad queries never see a zero dimension.
+  int64_t bufsize1 = oneapi::mkl::lapack::geqrf_scratchpad_size<scalar_t>(
+      queue, n + 1, m + 1, n + 1);
+  int64_t bufsize2 = oneapi::mkl::lapack::orgqr_scratchpad_size<scalar_t>(
+      queue, n + 1, m + 1, m + 1, n + 1);
+
+  int64_t bufsize = bufsize2 > bufsize1 ? bufsize2 : bufsize1;
+  int64_t tau_len = m > n ? n : m;
+  scalar_t* sbuffer = sycl::malloc_device<scalar_t>(bufsize, queue);
+  scalar_t* tau_buf = sycl::malloc_device<scalar_t>(tau_len, queue);
+  scalar_t* r_buf = result_r.data_ptr<scalar_t>();
+
+  scalar_t* q_buf = nullptr;
+  if (mode != "r") {
+    q_buf = result_q.data_ptr<scalar_t>();
+  }
+
+  for (int64_t batch_item = 0; batch_item < b; batch_item++) {
+    // Factorize only when there is something to orthogonalize.
+    if (mn != 0) {
+      oneapi::mkl::lapack::geqrf(
+          queue, n, m, r_buf, n, tau_buf, sbuffer, bufsize)
+          .wait();
+    }
+
+    if (mode != "r") {
+      // Copy the relevant part of the factored matrix into Q before expanding it.
+      int64_t copy_columns = out_q_columns > m ? m : out_q_columns;
+      queue.memcpy(q_buf, r_buf, n * copy_columns * sizeof(scalar_t)).wait();
+
+      oneapi::mkl::lapack::orgqr(
+          queue,
+          n,
+          out_q_columns,
+          tau_len,
+          q_buf,
+          n,
+          tau_buf,
+          sbuffer,
+          bufsize)
+          .wait();
+
+      q_buf += n * out_q_columns;
+    }
+
+    r_buf += mn;
+  } // batch
+
+  sycl::free(sbuffer, queue);
+  sycl::free(tau_buf, queue);
+
+  if ((mode == "reduced" || mode == "r") && n > m) {
+    result_r =
+        result_r
+            .index(
+                {"...", at::indexing::Slice(0, n), at::indexing::Slice(0, m)})
+            .contiguous();
+  }
+
+  if (mode != "r") {
+    // The factorization was computed in column-major layout; transpose Q back.
+    result_q = result_q.transpose(-2, -1).contiguous();
+  }
+  Q.set_(result_q);
+  R.set_(result_r.transpose(-2, -1).triu_());
+}
+
+void linalg_qr_kernel(
+    const at::Tensor& A,
+    std::string_view mode,
+    const at::Tensor& Q,
+    const at::Tensor& R) {
+  AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "linalg_qr_xpu", [&] {
+    linalg_qr_kernel_impl<scalar_t>(A, mode, Q, R);
+  });
+}
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
index c1cc1da5c6..ef846c4d6b 100644
--- a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
+++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
@@ -16,4 +16,10 @@ TORCH_XPU_API void lu_factor_mkl(
     const Tensor& info,
     bool pivot);
 
+TORCH_XPU_API void linalg_qr_kernel(
+    const at::Tensor& A,
+    std::string_view mode,
+    const at::Tensor& Q,
+    const at::Tensor& R);
+
 } // namespace at::native::xpu
diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py
index c7b88ccc9e..edc70c25ce 100644
--- a/test/xpu/test_linalg_xpu.py
+++ b/test/xpu/test_linalg_xpu.py
@@ -473,6 +473,57 @@ def __tunableop_ctx(self):
     pass
 
 
+@parametrize("batch", [1, 3])
+@parametrize("m", [1, 12])
+@parametrize("n", [1, 17])
+@dtypes(torch.float32, torch.float64)
+def qr_mode_r(self, device, dtype, batch, m, n):
+    if batch > 1:
+        A_cpu = torch.randn(batch, m, n, dtype=dtype, device="cpu")
+    else:
+        A_cpu = torch.randn(m, n, dtype=dtype, device="cpu")
+    A_xpu = A_cpu.to(device)
+
+    R_cpu = torch.linalg.qr(A_cpu, mode="r").R
+    R_xpu = torch.linalg.qr(A_xpu, mode="r").R
+    self.assertEqual(R_xpu, R_cpu, atol=1e-5, rtol=1e-5)
+
+    # Verify that R is upper triangular
+    lower_triangle = torch.tril(R_xpu, diagonal=-1)
+    self.assertEqual(lower_triangle.sum(), 0.0, atol=0.0, rtol=0.0)
+
+
+@parametrize("batch", [1, 3])
+@parametrize("m", [0, 1, 12])
+@parametrize("n", [0, 1, 17])
+@parametrize("mode", ["reduced", "complete"])
+@dtypes(torch.float32, torch.float64)
+def qr_modes_reduced_complete(self, device, dtype, batch, m, n, mode):
+    if batch > 1:
+        A_cpu = torch.randn(batch, m, n, dtype=dtype, device="cpu")
+    else:
+        A_cpu = torch.randn(m, n, dtype=dtype, device="cpu")
+    A_xpu = A_cpu.to(device)
+
+    Q_cpu, R_cpu = torch.linalg.qr(A_cpu, mode=mode)
+    Q_xpu, R_xpu = torch.linalg.qr(A_xpu, mode=mode)
+
+    self.assertEqual(Q_xpu, Q_cpu, atol=1e-5, rtol=1e-5)
+    self.assertEqual(R_xpu, R_cpu, atol=1e-5, rtol=1e-5)
+
+    # Verify Q is orthogonal: Q^T @ Q should be the identity
+    QTQ_xpu = torch.matmul(Q_xpu.mT, Q_xpu)
+    k = min(m, n) if mode == "reduced" else m
+    identity = torch.eye(k, dtype=dtype, device=device)
+    if batch > 1:
+        identity = identity.expand(batch, k, k)
+    self.assertEqual(QTQ_xpu, identity, atol=1e-5, rtol=1e-5)
+
+    # Verify that R is upper triangular
+    lower_triangle = torch.tril(R_xpu, diagonal=-1)
+    self.assertEqual(lower_triangle.sum(), 0.0, atol=0.0, rtol=0.0)
+
+
 with XPUPatchForImport(False):
     from test_linalg import TestLinalg
@@ -493,6 +544,8 @@ def __tunableop_ctx(self):
 TestLinalg.test_ck_blas_library = ck_blas_library
 TestLinalg.test_addmm_relu_tunableop_rocm = addmm_relu_tunableop_rocm
 TestLinalg._tunableop_ctx = __tunableop_ctx
+TestLinalg.test_qr_mode_r = qr_mode_r
+TestLinalg.test_qr_modes_reduced_complete = qr_modes_reduced_complete
 TestLinalg._default_dtype_check_enabled = True
 
 instantiate_device_type_tests(TestLinalg, globals(), only_for=("xpu"), allow_xpu=True)
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index a3281791de..7221ebdbdc 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -9443,6 +9443,17 @@
 - func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
   python_module: linalg
 
+- func: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
+  python_module: linalg
+  variants: function
+  structured_delegate: linalg_qr.out
+
+- func: linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+  python_module: linalg
+  structured: True
+  dispatch:
+    XPU: linalg_qr_xpu_out
+
 - func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
   python_module: linalg
   structured_delegate: linalg_inv_ex.inverse
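
Illustrative usage (not part of the patch): a minimal sketch of how the new XPU dispatch is exercised from Python once a build carrying this change is available with oneMKL enabled. The shapes, tolerances, and the "xpu" device string below are example choices for the sketch, not values taken from the diff.

import torch

# Assumes a torch build that includes this patch and an available XPU device.
A = torch.randn(3, 12, 17, dtype=torch.float32, device="xpu")

# "reduced" mode: Q is (3, 12, 12), R is (3, 12, 17); dispatches to linalg_qr_xpu_out.
Q, R = torch.linalg.qr(A, mode="reduced")

# Sanity checks: A == Q @ R, Q has orthonormal columns, R is upper triangular.
torch.testing.assert_close(Q @ R, A, atol=1e-5, rtol=1e-5)
eye = torch.eye(12, dtype=A.dtype, device=A.device).expand(3, 12, 12)
torch.testing.assert_close(Q.mT @ Q, eye, atol=1e-5, rtol=1e-5)
assert torch.tril(R, diagonal=-1).abs().sum() == 0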