[CPU] Fix SDPA node attention mask precision handling for bf16/f16 inference #33132
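For context, the ScaledDotProductAttention op accepts an optional attention mask that may be boolean: a true element means the position takes part in attention, a false element means it is masked out. The sketch below is illustrative only and is not the code changed by this PR (the helper name bool_mask_to_additive is invented for the example); it shows the usual way a boolean mask is expressed as an additive mask before softmax, which is the kind of handling the bf16/f16 path needs to get right.

#include <cstddef>
#include <limits>
#include <vector>

// Illustrative only: convert a boolean "keep" mask into an additive float mask.
// true  -> 0.0f      (position participates in attention)
// false -> -infinity (position is masked out; softmax assigns it zero weight)
std::vector<float> bool_mask_to_additive(const std::vector<bool>& keep) {
    std::vector<float> additive(keep.size());
    for (std::size_t i = 0; i < keep.size(); ++i) {
        additive[i] = keep[i] ? 0.0f : -std::numeric_limits<float>::infinity();
    }
    return additive;
}

The stateful test added in this PR feeds exactly such a boolean mask to the CPU ScaledAttn node under bf16 inference.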
Contributor:
Shouldn't this test be in the classes folder and instantiated for x64 and ARM? The changes seem to be common to both platforms.
@@ -0,0 +1,152 @@
#include <gtest/gtest.h>
Contributor:
Could we reuse the existing tests to cover the boolean attn_mask? The only change here is the attn_mask type.
Contributor:
Please also add a copyright header at the top of the file.
Contributor (Author):
I previously tried to extend the existing SDPA tests to cover the changed logic, but those tests are stateless and are converted to Snippets (Subgraph) during inference, so they do not exercise the changed code path. I therefore created new tests to cover the changes.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "common_test_utils/include/common_test_utils/data_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
#include "internal_properties.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/assign.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/opsets/opset13.hpp"
#include "openvino/pass/manager.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "utils/cpu_test_utils.hpp"

using namespace ov::test;
using namespace CPUTestUtils;

namespace ov {
namespace test {

namespace {

class StatefulSdpaBoolMaskTest : public ov::test::SubgraphBaseTest, public CPUTestsBase {
protected:
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;
        configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
        configuration[ov::hint::kv_cache_precision.name()] = ov::element::bf16;
        rel_threshold = 0.02f;
        abs_threshold = 0.02f;
        selectedType = makeSelectedTypeStr(getPrimitiveType(), ov::element::bf16);

        const InputShape q_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
        const InputShape k_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
        const InputShape v_shape{{-1, 8, -1, 64}, {{1, 8, 10, 64}}};
        const InputShape mask_shape{{1, 1, -1, -1}, {{1, 1, 10, 10}}};
        const InputShape past_shape{{-1, 8, -1, 64}, {{1, 8, 0, 64}}};
        const InputShape beam_shape{{-1}, {{1}}};

        init_input_shapes({q_shape, k_shape, v_shape, mask_shape, past_shape, beam_shape});

        auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[0]);
        auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[1]);
        auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[2]);
        auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, inputDynamicShapes[3]);
        auto past_init = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, inputDynamicShapes[4]);
        auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, inputDynamicShapes[5]);

        q->set_friendly_name("q");
        k->set_friendly_name("k");
        v->set_friendly_name("v");
        mask->set_friendly_name("attention_mask");
        past_init->set_friendly_name("past_init");
        beam_idx->set_friendly_name("beam_idx");

        auto variable_k = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastk"});
        auto variable_v = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[4], ov::element::bf16, "pastv"});

        auto past_k = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_k);
        auto past_v = std::make_shared<ov::op::v6::ReadValue>(past_init, variable_v);
        past_k->set_friendly_name("pastk_read");
        past_v->set_friendly_name("pastv_read");

        auto axis = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
        auto gather_k = std::make_shared<ov::op::v8::Gather>(past_k, beam_idx, axis);
        auto gather_v = std::make_shared<ov::op::v8::Gather>(past_v, beam_idx, axis);
        gather_k->set_batch_dims(0);
        gather_v->set_batch_dims(0);

        auto concat_k = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_k, k}, 2);
        auto concat_v = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{gather_v, v}, 2);

        auto sdpa = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concat_k, concat_v, mask, false);
        sdpa->set_friendly_name("stateful_sdpa");

        auto assign_k = std::make_shared<ov::op::v6::Assign>(concat_k, variable_k);
        auto assign_v = std::make_shared<ov::op::v6::Assign>(concat_v, variable_v);
        assign_k->set_friendly_name("pastk_write");
        assign_v->set_friendly_name("pastv_write");

        ov::ResultVector results{std::make_shared<ov::op::v0::Result>(sdpa)};
        ov::SinkVector sinks{assign_k, assign_v};
        function = std::make_shared<ov::Model>(results,
                                               sinks,
                                               ov::ParameterVector{q, k, v, mask, past_init, beam_idx},
                                               "StatefulSdpaBoolMask");
    }

    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        inputs.clear();

        const auto& parameters = function->get_parameters();
        for (size_t idx = 0; idx < parameters.size(); ++idx) {
            const auto& param = parameters[idx];
            const auto& shape = targetInputStaticShapes[idx];
            if (param->get_element_type() == ov::element::bf16) {
                ov::Tensor tensor{ov::element::bf16, shape};
                utils::fill_data_random(static_cast<ov::bfloat16*>(tensor.data()), tensor.get_size(), 2, -1, 10);
                inputs.insert({param, tensor});
            } else if (param->get_element_type() == ov::element::boolean) {
                ov::Tensor tensor{ov::element::boolean, shape};
                auto* data = tensor.data<bool>();
                for (size_t i = 0; i < tensor.get_size(); ++i) {
                    data[i] = (i % 3) != 0;
                }
                inputs.insert({param, tensor});
            } else if (param->get_element_type() == ov::element::i32) {
                ov::Tensor tensor{ov::element::i32, shape};
                auto* data = tensor.data<int32_t>();
                int32_t denom = 1;
                if (!shape.empty() && shape[0] != 0) {
                    denom = static_cast<int32_t>(shape[0]);
                }
                for (size_t i = 0; i < tensor.get_size(); ++i) {
                    data[i] = static_cast<int32_t>(i % denom);
                }
                inputs.insert({param, tensor});
            } else {
                FAIL() << "Unexpected parameter precision " << param->get_element_type();
            }
        }
    }
};

TEST_F(StatefulSdpaBoolMaskTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED();
    if (!ov::with_cpu_x86_bfloat16()) {
        GTEST_SKIP();
    }
    run();
    CheckPluginRelatedResults(compiledModel, "ScaledAttn");
}

}  // namespace

}  // namespace test
}  // namespace ov
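Usage note: after building the CPU plugin functional tests, this case can be run in isolation with a gtest filter such as --gtest_filter=*StatefulSdpaBoolMaskTest*; on hosts without native bf16 support the test skips itself via the ov::with_cpu_x86_bfloat16() check above.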