35 changes: 25 additions & 10 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -415,6 +415,11 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
bool is_xf16 = any_of(precision_of<T>::value, ov::element::bf16, ov::element::f16);
// packed k, v
+ ov::element::Type attn_mask_precision = ov::element::Type(precision_of<T>::value);
+ if (attention_mask) {
+     attn_mask_precision = attention_mask.get_precision();
+ }

parallel_for2d(B, Hk, [&](size_t b, size_t h) {
T* k_ptr = &present_key.at<T>({b, h, 0, 0});
T* v_ptr = &present_value.at<T>({b, h, 0, 0});
@@ -451,11 +456,12 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
}

uint8_t* attn_mask_ptr = nullptr;
- auto attn_mask_stride = 0;
+ size_t attn_mask_stride = 0;
if (attention_mask) {
- attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<T>({b, h, 0, 0}, true));
+ const size_t mask_head = attention_mask.size(1) > 1 ? h : 0;
+ attn_mask_ptr = static_cast<uint8_t*>(attention_mask.ptr_v(b, mask_head, 0, 0));
if (attention_mask.size(2) > 1) {
- attn_mask_stride = attention_mask.stride(2) * sizeof(T);
+ attn_mask_stride = attention_mask.stride_bytes(2);
}
}
uint8_t* cmask_ptr = nullptr;
@@ -474,18 +480,20 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
if (sink_input) {
sink = &sink_input.at<float>({b, h, m, 0}, true);
}
+ uint8_t* attn_mask_row =
+     attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;

attn_softmax(reinterpret_cast<void*>(score),
reinterpret_cast<T*>(score),
d_scale,
reinterpret_cast<void*>(alibi_ptr + m * alibi_stride),
- attn_mask_ptr + m * attn_mask_stride,
+ attn_mask_row,
cmask_ptr + m * cmask_stride,
select_nfltmax_at_0,
ncausal,
kv_len,
precision_of<T>::value,
- precision_of<T>::value,
+ attn_mask_precision,
precision_of<T>::value,
sink);
}
@@ -638,6 +646,10 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
auto k_stride_s = present_key.stride(3);

auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
+ ov::element::Type attn_mask_precision = precision;
+ if (attention_mask) {
+     attn_mask_precision = attention_mask.get_precision();
+ }

parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
auto m_start = m_blk * m_block_size;
@@ -657,11 +669,12 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
}
}
uint8_t* attn_mask_ptr = nullptr;
- auto attn_mask_stride = 0;
+ size_t attn_mask_stride = 0;
if (attention_mask) {
- attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<T>({b, h, 0, 0}, true));
+ const size_t mask_head = attention_mask.size(1) > 1 ? h : 0;
+ attn_mask_ptr = static_cast<uint8_t*>(attention_mask.ptr_v(b, mask_head, 0, 0));
if (attention_mask.size(2) > 1) {
- attn_mask_stride = attention_mask.stride(2) * sizeof(T);
+ attn_mask_stride = attention_mask.stride_bytes(2);
}
}
uint8_t* cmask_ptr = nullptr;
@@ -696,17 +709,19 @@ struct MHAKernel<ScaledDotProductAttention::KT_ACL, T> {
for (size_t m = m_start; m < m_end; m++) {
// apply attention mask & softmax
auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
+ uint8_t* attn_mask_row =
+     attn_mask_ptr && attn_mask_stride ? attn_mask_ptr + m * attn_mask_stride : attn_mask_ptr;
attn_softmax(reinterpret_cast<void*>(qk + (m - m_start) * kv_len),
qk + (m - m_start) * kv_len,
d_scale,
reinterpret_cast<void*>(alibi_ptr + m * alibi_stride),
- attn_mask_ptr + m * attn_mask_stride,
+ attn_mask_row,
cmask_ptr + m * cmask_stride,
select_nfltmax_at_0,
ncausal,
kv_len,
precision,
- precision,
+ attn_mask_precision,
precision,
nullptr);
}
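Taken together, the hunks above stop assuming the attention mask shares the kernel's element type: the mask pointer is obtained via ptr_v, its row stride via stride_bytes, and attn_softmax is told the mask's own precision, so a boolean mask can flow through unchanged. Below is a minimal, standalone sketch of the underlying idea for one score row, assuming the boolean mask is stored one byte per element (non-zero = keep, zero = exclude); it is illustrative only and not the plugin's optimized attn_softmax kernel.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Apply a boolean mask (one byte per element) to a single row of attention
// scores, then normalize the row with softmax.
inline void masked_softmax_row(float* scores, const uint8_t* bool_mask, size_t n) {
    float max_val = -INFINITY;
    for (size_t j = 0; j < n; ++j) {
        if (bool_mask && bool_mask[j] == 0) {
            scores[j] = -INFINITY;  // excluded position contributes nothing
        }
        max_val = std::max(max_val, scores[j]);
    }
    if (max_val == -INFINITY) {  // every position was masked out
        std::fill(scores, scores + n, 0.0f);
        return;
    }
    float sum = 0.0f;
    for (size_t j = 0; j < n; ++j) {
        scores[j] = std::exp(scores[j] - max_val);
        sum += scores[j];
    }
    for (size_t j = 0; j < n; ++j) {
        scores[j] /= sum;  // sum >= 1, since the max position contributes exp(0)
    }
}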
Contributor
Shouldn't this test be in the classes folder and instantiated for both x64 and ARM (the changes seem to be common to both platforms)?
Also, the test builds a bf16 graph, which is odd, since the IR typically contains fp32 and the inference precision is then forced by the runtime. It is also claimed that the changes target f16, but the test covers bf16 only.
Can we do the following?

  1. Move the test body to the shared section (classes).
  2. Use an fp32 graph, but force the specific inference precision via the properties (a rough sketch follows after this list).
  3. Instantiate the test in the x64 folder for bf16 and fp16, depending on the HW capabilities.
  4. Instantiate the test in the ARM folder with fp16.
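A rough sketch of how points 1-3 could fit together, assuming the includes that the test file below already pulls in; the class name, file locations, and parameter plumbing are hypothetical, while ov::hint::inference_precision and DEVICE_CPU are taken from the existing test.

// Hypothetical shared test class, e.g. placed under the classes folder.
// The graph is built in fp32; the target precision arrives as a test parameter
// and is forced at runtime via ov::hint::inference_precision.
class StatefulSdpaBoolMaskCommonTest : public ov::test::SubgraphBaseTest,
                                       public ::testing::WithParamInterface<ov::element::Type> {
protected:
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;
        const ov::element::Type infer_precision = GetParam();  // bf16 or f16, per platform
        configuration[ov::hint::inference_precision.name()] = infer_precision;
        // ... build the same stateful SDPA subgraph as in the file below,
        //     but with ov::element::f32 parameters instead of hard-coded bf16 ...
    }
};

// Per-platform instantiations would then live in the x64 and ARM folders, e.g.:
// INSTANTIATE_TEST_SUITE_P(smoke_x64, StatefulSdpaBoolMaskCommonTest,
//                          ::testing::Values(ov::element::bf16, ov::element::f16));
// INSTANTIATE_TEST_SUITE_P(smoke_ARM, StatefulSdpaBoolMaskCommonTest,
//                          ::testing::Values(ov::element::f16));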

@@ -0,0 +1,152 @@
#include <gtest/gtest.h>
Contributor
Could we reuse the existing tests to cover the boolean attn_mask? The only change here is the attn_mask type.

Contributor
Please also add the copyright header at the top of the file.
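For reference, a minimal sketch of such a header, assuming the project's usual SPDX form (the year range is a placeholder, not taken from the repository):

// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0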

Contributor Author @liubo-intel, Dec 8, 2025
Could we reuse the existing tests to cover the boolean attn_mask? The only change here is the attn_mask type.

Previously I tried to extend the existing SDPA tests to cover the changed logic, but those SDPA tests are stateless and are converted to Snippets (Subgraph) during inference, so they do not exercise the changed logic. That is why I created new tests to cover these changes.


#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "common_test_utils/include/common_test_utils/data_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
#include "internal_properties.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/assign.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/opsets/opset13.hpp"
#include "openvino/pass/manager.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "utils/cpu_test_utils.hpp"

using namespace ov::test;
using namespace CPUTestUtils;

namespace ov {
namespace test {

namespace {

class StatefulSdpaBoolMaskTest : public ov::test::SubgraphBaseTest, public CPUTestsBase {
protected:
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
configuration[ov::hint::kv_cache_precision.name()] = ov::element::bf16;
rel_threshold = 0.02f;
abs_threshold = 0.02f;
selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::bf16);

const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}};
const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}};
const InputShape beam_shape{{-1}, {{1}}};

init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape});

auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[0]);
auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[1]);
auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[2]);
auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]);
auto past_init = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[4]);
auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]);

q->set_friendly_name("q");
k->set_friendly_name("k");
v->set_friendly_name("v");
mask->set_friendly_name("attention_mask");
past_init->set_friendly_name("past_init");
beam_idx->set_friendly_name("beam_idx");

auto variable_k = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastk"});
auto variable_v = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastv"});

auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k);
auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v);
past_k->set_friendly_name("pastk_read");
past_v->set_friendly_name("pastv_read");

auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis);
auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis);
gather_k->set_batch_dims(0);
gather_v->set_batch_dims(0);

auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2);
auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2);

auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false);
sdpa->set_friendly_name("stateful_sdpa");

auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k);
auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v);
assign_k->set_friendly_name("pastk_write");
assign_v->set_friendly_name("pastv_write");

ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)};
ov::SinkVector sinks{assign_k, assign_v};
function = std::make_shared<ov::Model>(results,
sinks,
ov::ParameterVector{q, k, v, mask, past_init, beam_idx},
"StatefulSdpaBoolMask");
}

void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
inputs.clear();

const auto& parameters = function->get_parameters();
for (size_t idx = 0; idx < parameters.size(); ++idx) {
const auto& param = parameters[idx];
const auto& shape = targetInputStaticShapes[idx];
if (param->get_element_type() == ov::element::bf16) {
ov::Tensor tensor{ov::element::bf16, shape};
utils::fill_data_random(static_cast<ov::bfloat16*>(tensor.data()), tensor.get_size(), 2, -1, 10);
inputs.insert({param, tensor});
} else if (param->get_element_type() == ov::element::boolean) {
ov::Tensor tensor{ov::element::boolean, shape};
auto* data = tensor.data<bool>();
for (size_t i = 0; i < tensor.get_size(); ++i) {
data[i] = (i % 3) != 0;
}
inputs.insert({param, tensor});
} else if (param->get_element_type() == ov::element::i32) {
ov::Tensor tensor{ov::element::i32, shape};
auto* data = tensor.data<int32_t>();
int32_t denom = 1;
if (!shape.empty() && shape[0] != 0) {
denom = static_cast<int32_t>(shape[0]);
}
for (size_t i = 0; i < tensor.get_size(); ++i) {
data[i] = static_cast<int32_t>(i % denom);
}
inputs.insert({param, tensor});
} else {
FAIL() << "Unexpected parameter precision " << param->get_element_type();
}
}
}
};

TEST_F(StatefulSdpaBoolMaskTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
if (!ov::with_cpu_x86_bfloat16()) {
GTEST_SKIP();
}
run();
CheckPluginRelatedResults(compiledModel, "ScaledAttn");
}

} // namespace

} // namespace test
} // namespace ov