From 980872bfeae5c80b33dd5239f2b8cc2f33392b52 Mon Sep 17 00:00:00 2001
From: yucai-intel <108388355+yucai-intel@users.noreply.github.com>
Date: Tue, 2 Dec 2025 17:09:40 +0800
Subject: [PATCH] Update Attention.cpp

---
 src/ATen/native/transformers/Attention.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/ATen/native/transformers/Attention.cpp b/src/ATen/native/transformers/Attention.cpp
index df0a2c9bc0..3090dfbeec 100644
--- a/src/ATen/native/transformers/Attention.cpp
+++ b/src/ATen/native/transformers/Attention.cpp
@@ -38,6 +38,16 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_xpu(
   auto T = qkv.is_nested()
       ? native::NestedTensor_get_max_size(
             *native::get_nested_tensor_impl(qkv))[0]
       : qkv.size(1);
+  if (qkv.is_nested()) {
+    // Don't mess with non-nested case for now since it's not set up to fiddle
+    // with mask size.
+
+    // Round T up to next multiple of 8 so as to be able to utilize Tensor
+    // cores. Otherwise, sometimes with padding, *no* row will have the maximum
+    // sequence length and so we'll have a non-divisible-by-8 dimension even if
+    // the model author chose a multiple of 8.
+    T = T + (8 - (T % 8)) % 8;
+  }
   auto _3D = qkv_bias.size(0);
   auto D = _3D / 3;
   TORCH_CHECK(D % num_head == 0);
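
Note (not part of the patch above): the inserted expression pads T up to the next multiple of 8 while leaving exact multiples unchanged. A minimal standalone sketch of that behavior, using the hypothetical helper name round_up_to_multiple_of_8:

    #include <cassert>
    #include <cstdint>

    // Same expression as in the patch: round T up to the next multiple of 8.
    // Exact multiples of 8 are unchanged because (8 - 0) % 8 == 0.
    static int64_t round_up_to_multiple_of_8(int64_t T) {
      return T + (8 - (T % 8)) % 8;
    }

    int main() {
      assert(round_up_to_multiple_of_8(13) == 16); // padded: 13 -> 16
      assert(round_up_to_multiple_of_8(16) == 16); // already a multiple of 8
      assert(round_up_to_multiple_of_8(1) == 8);   // smallest padded case
      return 0;
    }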