19 changes: 19 additions & 0 deletions src/ATen/native/xpu/BatchLinearAlgebra.cpp
@@ -2,6 +2,8 @@
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/ops/linalg_qr_native.h>
#include <ATen/ops/linalg_qr_cpu_dispatch.h>
#if defined(USE_ONEMKL_XPU)
#include <ATen/native/xpu/mkl/BatchLinearAlgebra.h>
#endif // USE_ONEMKL_XPU
@@ -64,4 +66,21 @@ void lu_factor_kernel_xpu(

REGISTER_XPU_DISPATCH(lu_factor_stub, &lu_factor_kernel_xpu);

TORCH_IMPL_FUNC(linalg_qr_xpu_out)
(const Tensor& A,
 std::string_view mode,
 const Tensor& Q,
 const Tensor& R) {
#if defined(USE_ONEMKL_XPU)
  xpu::linalg_qr_kernel(A, mode, Q, R);
#else
  // Without oneMKL support, fall back to the CPU implementation and copy the
  // results into the XPU output tensors.
  auto A_cpu = A.to(at::kCPU);
  auto Q_cpu = at::empty_like(Q, at::kCPU);
  auto R_cpu = at::empty_like(R, at::kCPU);
  at::cpu::linalg_qr_out(Q_cpu, R_cpu, A_cpu, mode);
  Q.copy_(Q_cpu);
  R.copy_(R_cpu);
#endif // USE_ONEMKL_XPU
}
Comment on lines 69 to 92
Contributor
My suggestion is to register geqrf_kernel_xpu/orgqr_kernel_xpu to geqrf_stub/orgqr_stub, which allows us to reuse op-level code in stock PyTorch and reuse these two kernels in the future.
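
For reference, a minimal sketch of what that registration could look like, assuming the stub signatures declared in ATen/native/BatchLinearAlgebra.h (the kernel bodies are placeholders, and the exact signatures should be checked against the stub declarations; this is not part of the PR):

// Hypothetical sketch only.
void geqrf_kernel_xpu(const Tensor& input, const Tensor& tau) {
  // Would call oneapi::mkl::lapack::geqrf for each batch item of `input`,
  // writing the Householder scalar factors into `tau`.
}

Tensor& orgqr_kernel_xpu(Tensor& result, const Tensor& tau) {
  // Would call oneapi::mkl::lapack::orgqr to form Q in-place from the
  // reflectors produced by geqrf.
  return result;
}

REGISTER_XPU_DISPATCH(geqrf_stub, &geqrf_kernel_xpu);
REGISTER_XPU_DISPATCH(orgqr_stub, &orgqr_kernel_xpu);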



Copilot AI Dec 3, 2025

Empty comment line should be removed.
} // namespace at::native
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -209,7 +209,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"linalg_lstsq.out",
"linalg_lu.out",
"linalg_matrix_exp",
"linalg_qr.out",
"linalg_solve_triangular",
"linalg_solve_triangular.out",
"_linalg_svd.U",
104 changes: 104 additions & 0 deletions src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
@@ -561,4 +561,108 @@ void lu_factor_mkl(
  pivots.copy_(pivots_);
}

void linalg_qr_kernel(
    const at::Tensor& A,
    std::string_view mode,
    const at::Tensor& Q,
    const at::Tensor& R) {
  // TORCH_CHECK(A.device().is_xpu(), "a must be an XPU tensor");
  // TORCH_CHECK(A.dtype() == at::kFloat, "a must be float");

  at::Tensor a_contig = A.contiguous();
  at::Tensor result_r = at::clone(a_contig);

  auto options = at::TensorOptions().dtype(at::kFloat).device(kXPU);
  auto dimensions = A.sizes();

  // Materialize the input in column-major order, as expected by the LAPACK routines.
  result_r = result_r.transpose(-2, -1).contiguous();

  int64_t numel = a_contig.numel();
  int range = a_contig.dim();
  int64_t n = a_contig.sizes().at(range - 2);
  int64_t m = a_contig.sizes().at(range - 1);
  int64_t mn = int64_t(m * n);
  int64_t b = numel / mn;

  int out_q_columns = m > n ? n : m;
  if (n > m && mode == "complete") {
    out_q_columns = n;
  }

  // Shape of Q: same batch dimensions as A, with the trailing matrix shape
  // determined by the requested mode ("r" produces an empty Q).
  std::vector<int64_t> v(dimensions.begin(), dimensions.end());
  if (mode != "r") {
    v[range - 1] = v[range - 2];
    v[range - 2] = out_q_columns;
  } else {
    v = std::vector<int64_t>({0, 0});
  }
  auto q_dimensions = at::IntArrayRef(v);

  at::Tensor result_q = at::empty(q_dimensions, options);

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();

  int64_t bufsize1 =
      oneapi::mkl::lapack::geqrf_scratchpad_size<float>(queue, n, m, n);
  int64_t bufsize2 =
      oneapi::mkl::lapack::orgqr_scratchpad_size<float>(queue, n, m, m, n);

  int64_t bufsize = bufsize2 > bufsize1 ? bufsize2 : bufsize1;
  int64_t tau_len = m > n ? n : m;
  float* sbuffer = sycl::malloc_device<float>(bufsize, queue);
  float* tau_buf = sycl::malloc_device<float>(tau_len, queue);
  float* r_buf = result_r.data_ptr<float>();

  float* q_buf = nullptr;
  if (mode != "r") {
    q_buf = result_q.data_ptr<float>();
  }

  for (int64_t batch_item = 0; batch_item < b; batch_item++) {
    oneapi::mkl::lapack::geqrf(queue, n, m, r_buf, n, tau_buf, sbuffer, bufsize)
        .wait();

    if (mode != "r") {
      // copy relevant part of R matrix to Q matrix
      int copy_columns = out_q_columns > m ? m : out_q_columns;
      queue.memcpy(q_buf, r_buf, n * copy_columns * sizeof(float)).wait();

      oneapi::mkl::lapack::orgqr(
          queue,
          n,
          out_q_columns,
          tau_len,
          q_buf,
          n,
          tau_buf,
          sbuffer,
          bufsize)
          .wait();

      q_buf += n * out_q_columns;
    }

    r_buf += mn;
  } // batch

  sycl::free(sbuffer, queue);
  sycl::free(tau_buf, queue);

  if ((mode == "reduced" || mode == "r") && n > m) {
    result_r =
        result_r
            .index(
                {"...", at::indexing::Slice(0, n), at::indexing::Slice(0, m)})
            .contiguous();
  }

  Q.set_(result_q.transpose(-2, -1));
  R.set_(result_r.transpose(-2, -1).triu_());
  queue.wait();
}

} // namespace at::native::xpu
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
@@ -16,4 +16,10 @@ TORCH_XPU_API void lu_factor_mkl(
    const Tensor& info,
    bool pivot);

TORCH_XPU_API void linalg_qr_kernel(
    const at::Tensor& A,
    std::string_view mode,
    const at::Tensor& Q,
    const at::Tensor& R);

} // namespace at::native::xpu
53 changes: 53 additions & 0 deletions test/xpu/test_linalg_xpu.py
@@ -473,6 +473,57 @@ def __tunableop_ctx(self):
    pass


@parametrize("batch", [1, 3])
@parametrize("m", [0, 1, 12])
@parametrize("n", [0, 1, 17])
@dtypes(torch.float32)
def qr_mode_r(self, device, dtype, batch, m, n):
if batch > 1:
A_cpu = torch.randn(batch, m, n, dtype=dtype, device="cpu")
else:
A_cpu = torch.randn(m, n, dtype=dtype, device="cpu")
A_xpu = A_cpu.to(device)

R_cpu = torch.linalg.qr(A_cpu, mode="r").R
R_xpu = torch.linalg.qr(A_xpu, mode="r").R
self.assertEqual(R_xpu, R_cpu, atol=1e-5, rtol=1e-5)

# Verify that R is upper triangular
lower_triangle = torch.tril(R_xpu, diagonal=-1)
self.assertEqual(lower_triangle.sum(), 0.0, atol=0.0, rtol=0.0)


@parametrize("batch", [1, 3])
@parametrize("m", [0, 1, 12])
@parametrize("n", [0, 1, 17])
@parametrize("mode", ["reduced", "complete"])
@dtypes(torch.float32)
def qr_modes_reduced_complete(self, device, dtype, batch, m, n, mode):
if batch > 1:
A_cpu = torch.randn(batch, m, n, dtype=dtype, device="cpu")
else:
A_cpu = torch.randn(m, n, dtype=dtype, device="cpu")
A_xpu = A_cpu.to(device)

Q_cpu, R_cpu = torch.linalg.qr(A_cpu, mode=mode)
Q_xpu, R_xpu = torch.linalg.qr(A_xpu, mode=mode)

self.assertEqual(Q_xpu, Q_cpu, atol=1e-5, rtol=1e-5)
self.assertEqual(R_xpu, R_cpu, atol=1e-5, rtol=1e-5)

# Verify Q is orthogonal: Q^T @ Q should be identity
QTQ_xpu = torch.matmul(Q_xpu.mT, Q_xpu)
k = min(m, n) if mode == "reduced" else m
identity = torch.eye(k, dtype=dtype, device=device)
if batch > 1:
identity = identity.expand(batch, k, k)
self.assertEqual(QTQ_xpu, identity, atol=1e-5, rtol=1e-5)

# Verify that R is upper triangular
lower_triangle = torch.tril(R_xpu, diagonal=-1)
self.assertEqual(lower_triangle.sum(), 0.0, atol=0.0, rtol=0.0)


with XPUPatchForImport(False):
    from test_linalg import TestLinalg

@@ -493,6 +544,8 @@ def __tunableop_ctx(self):
    TestLinalg.test_ck_blas_library = ck_blas_library
    TestLinalg.test_addmm_relu_tunableop_rocm = addmm_relu_tunableop_rocm
    TestLinalg._tunableop_ctx = __tunableop_ctx
    TestLinalg.test_qr_mode_r = qr_mode_r
    TestLinalg.test_qr_modes_reduced_complete = qr_modes_reduced_complete

TestLinalg._default_dtype_check_enabled = True
instantiate_device_type_tests(TestLinalg, globals(), only_for=("xpu"), allow_xpu=True)
11 changes: 11 additions & 0 deletions yaml/native/native_functions.yaml
@@ -9443,6 +9443,17 @@
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
  python_module: linalg

- func: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
  python_module: linalg
  variants: function
  structured_delegate: linalg_qr.out

- func: linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
  python_module: linalg
  structured: True
  dispatch:
    XPU: linalg_qr_xpu_out

Comment on lines +9446 to +9456
Author
@mwiktor-intel Dec 1, 2025

Using the existing PT stubs for implementing QR has both pros and cons. On the pros side, we make the most of the existing infrastructure with a simplified flow on the sycl/xpu side. The implementation proposed here, however, also has several strengths. Namely:

  1. Although QR is a composition of geqrf and orgqr, the initial request was for QR, not for wrappers of the mentioned MKL functions intended to fit the QR meta-operator.
  2. MKL functions are matrix oriented and can operate only on matrices. Thus, every subsequent call would require slicing a tensor and sending its parts to the kernel. It would require forcing a contiguous copy separately for each slice, instead of once at the beginning of the proposed kernel. Since the proposed kernel uses direct pointers, it can walk through the batch simply by advancing pointers; no slicing or memory reorganization is needed.
  3. The biggest possible efficiency trap lies in the matrix-layout differences between PT and LAPACK. The former stores data column wise, the latter row wise. In order to properly call MKL functions, the data must be transposed.
    In the fused version, as proposed in this PR, the transposition is done twice, at entry and at exit. Splitting this kernel into two separate calls would require fitting what users expect when calling geqrf and orgqr, which means the data would have to be transposed twice per call. By transposed I mean actually reordering the data in memory, since unlike PT functions, MKL does not support strides.

Thus, although providing these two kernels separately is a natural development path, I would prioritize it for later, after QR and two pending requests are ready.

Contributor
@CuiYifeng Dec 2, 2025
@mwiktor-intel For layout differences, could you point to where the MKL documentation says LAPACK is row-major? https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-dpcpp/2025-2/geqrf-usm-version.html Please pay attention to the leading dimension lda of the input matrix. Suppose matrix A is stored col-major and its shape is (m, n); we can set lda=m to indicate the correct layout.

Author
LDA cannot be a replacement for stride. LDA can be used to skip some elements in a matrix, but with LDA we cannot force a different data orientation. Specifically: say we have a 3x2 matrix. LAPACK expects this 6-element array in memory to represent the data as follows:

[ 0 2 4
  1 3 5 ]

LDA here is 2; setting LDA to 3 does not change anything here. The purpose of LDA is when, say, we have for some reason a bigger structure in memory, say,

[ 0 3 6
  1 4 7
  2 5 8 ]

and we want to take only the 2 upper rows to process. LDA here should be 3, telling the function that, apart from the logical number of rows being equal to 2, the procedure should skip 3 values to get to the data for the 2nd column.

Some BLAS functions have a row_major or col_major order option, but that is not the case within LAPACK.
Note that, per the documentation, LDA cannot be set to less than the number of rows in the processed matrix.

Contributor
@CuiYifeng Dec 2, 2025
Thanks for the elaboration. Let me rephrase your example 1: given a 3x2 matrix stored as [0, 2, 4, 1, 3, 5] in memory, the logical order is [[0, 1], [2, 3], [4, 5]] if this matrix is col-major, right? You mentioned "LDA here is 2", but why can LDA here be smaller than the number of rows?
In your example 2: given the memory buffer [0, 3, 6, 1, 4, 7, 2, 5, 8], the logical order of a 3x3 col-major matrix is [[0, 1, 2], [3, 4, 5], [6, 7, 8]]. We can use shape 2x3 and LDA=3 to get the col-major matrix [[0, 1, 2], [3, 4, 5]] from this buffer. In this example, I agree with your understanding.
LDA is the distance in memory between the start of consecutive columns (for column-major storage) or consecutive rows (for row-major storage). In other words, LDA means stride[0] for row-major or stride[1] for col-major in the buffer (not the stride in the matrix, as in the bigger memory you mentioned).
Let's return to the API of geqrf. Given a buffer, the shape (m, n) and LDA tell the API how to get data from the buffer. Furthermore, since you have noticed that "The leading dimension of a, at least max(1, m)", why not at least max(1, n)? In my understanding, this implies that the API geqrf requires a col-major input matrix.
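
To make the col-major convention concrete, here is a minimal, self-contained sketch (not part of this PR; the values, helper name, and header paths are assumptions) of calling geqrf on a single 3x2 col-major matrix with lda = m, using the same oneMKL USM calls the kernel above uses:

#include <algorithm>
#include <cstdint>
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

void qr_colmajor_example(sycl::queue& queue) {
  const std::int64_t m = 3, n = 2, lda = m;  // col-major: lda >= number of rows
  // Column-major storage of the 3x2 matrix [[1, 4], [2, 5], [3, 6]]:
  // each column is contiguous in memory.
  const float host_a[6] = {1, 2, 3, 4, 5, 6};

  float* a = sycl::malloc_device<float>(lda * n, queue);
  float* tau = sycl::malloc_device<float>(std::min(m, n), queue);
  queue.memcpy(a, host_a, sizeof(host_a)).wait();

  std::int64_t scratch_size =
      oneapi::mkl::lapack::geqrf_scratchpad_size<float>(queue, m, n, lda);
  float* scratch = sycl::malloc_device<float>(scratch_size, queue);

  // A = QR: on exit, R is in the upper triangle of `a`, the Householder
  // reflectors are below the diagonal, and their scalar factors are in `tau`.
  oneapi::mkl::lapack::geqrf(queue, m, n, a, lda, tau, scratch, scratch_size)
      .wait();

  sycl::free(scratch, queue);
  sycl::free(tau, queue);
  sycl::free(a, queue);
}

This mirrors what the kernel in this PR does after transposing and re-materializing the PyTorch tensor, which leaves the data in col-major form with the leading dimension equal to the number of rows.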

- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
  python_module: linalg
  structured_delegate: linalg_inv_ex.inverse