-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[LPT] QDQ stripping #32266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[LPT] QDQ stripping #32266
Changes from 20 commits
bbe32a5
1955d55
17d77ab
e4f6d6a
1640fdb
d6130ae
5824c91
b4c320a
e771761
8d46d49
dbba54b
95de481
8faa050
b181306
0f7aa0a
cbc6318
0e403c2
5211740
0bb9870
6a6d9b8
5032aea
68a9e3b
060e17d
4736a76
5b25b47
fd48bf4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| // Copyright (C) 2018-2025 Intel Corporation | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <memory> | ||
| #include <set> | ||
|
|
||
| #include "lpt_visibility.hpp" | ||
| #include "openvino/pass/matcher_pass.hpp" | ||
| #include "quantization_details.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace pass { | ||
| namespace low_precision { | ||
|
|
||
| /** | ||
| * @ingroup ov_transformation_common_api | ||
| * @brief FQStrippingTransformation strips FakeQuantize operations with specified levels | ||
| * by replacing them with Clamp operations. | ||
| */ | ||
| class LP_TRANSFORMATIONS_API FQStrippingTransformation : public ov::pass::MatcherPass { | ||
| public: | ||
| OPENVINO_RTTI("FQStrippingTransformation", "0", MatcherPass); | ||
| FQStrippingTransformation(const std::set<size_t>& levels_to_strip, bool replace_with_clamp); | ||
| }; | ||
|
|
||
| } // namespace low_precision | ||
| } // namespace pass | ||
| } // namespace ov |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| // Copyright (C) 2018-2025 Intel Corporation | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // | ||
|
|
||
| #include "low_precision/qdq_stripping.hpp" | ||
|
|
||
| #include <memory> | ||
|
|
||
| #include "itt.hpp" | ||
| #include "low_precision/common/ie_lpt_exception.hpp" | ||
| #include "low_precision/lpt_itt.hpp" | ||
| #include "low_precision/network_helper.hpp" | ||
| #include "openvino/core/except.hpp" | ||
| #include "openvino/core/type.hpp" | ||
| #include "openvino/op/clamp.hpp" | ||
| #include "openvino/op/constant.hpp" | ||
| #include "openvino/op/equal.hpp" | ||
| #include "openvino/op/fake_quantize.hpp" | ||
| #include "openvino/pass/pattern/op/wrap_type.hpp" | ||
| #include "openvino/util/log.hpp" | ||
| #include "transformations/utils/utils.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace pass { | ||
| namespace low_precision { | ||
|
|
||
| FQStrippingTransformation::FQStrippingTransformation(const std::set<size_t>& levels_to_strip, bool replace_with_clamp) { | ||
| MATCHER_SCOPE(FQStrippingTransformation); | ||
| auto is_scalar = [](const Output<Node>& output) -> bool { | ||
| return ov::shape_size(output.get_shape()) == 1; | ||
| }; | ||
| auto input_low_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar); | ||
| auto input_high_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar); | ||
| auto output_low_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar); | ||
| auto output_high_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar); | ||
| auto fq_m = pattern::wrap_type<ov::op::v0::FakeQuantize>( | ||
| {pattern::any_input(), input_low_m, input_high_m, output_low_m, output_high_m}); | ||
|
|
||
| ov::graph_rewrite_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) { | ||
| const auto& pattern_map = m.get_pattern_value_map(); | ||
| auto node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(pattern_map.at(fq_m).get_node_shared_ptr()); | ||
| if (!node) { | ||
| return false; | ||
| } | ||
|
|
||
| const size_t levels = node->get_levels(); | ||
| if (!levels_to_strip.count(levels)) { | ||
| return false; | ||
| } | ||
|
|
||
| auto input = node->get_input_node_shared_ptr(0); | ||
| auto input_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(input_low_m).get_node_shared_ptr()); | ||
| auto input_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(input_high_m).get_node_shared_ptr()); | ||
| auto output_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(output_low_m).get_node_shared_ptr()); | ||
| auto output_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(output_high_m).get_node_shared_ptr()); | ||
|
|
||
| if (!input_low || !input_high || !output_low || !output_high) { | ||
| return false; | ||
| } | ||
| auto constants_are_equal = [](const std::shared_ptr<ov::op::v0::Constant>& lhs, | ||
| const std::shared_ptr<ov::op::v0::Constant>& rhs) { | ||
| auto equal = | ||
| ov::as_type_ptr<ov::op::v0::Constant>(ov::op::util::make_try_fold<ov::op::v1::Equal>(lhs, rhs)); | ||
| OPENVINO_ASSERT(equal && ov::shape_size(equal->get_shape()) == 1, | ||
| "constants_are_equal expects scalar constant as a comparison result"); | ||
| return equal->get_vector<bool>()[0] == true; | ||
| }; | ||
|
|
||
| if (!constants_are_equal(input_low, output_low) || !constants_are_equal(input_high, output_high)) { | ||
| return false; | ||
| } | ||
|
|
||
| bool res = false; | ||
| if (replace_with_clamp) { | ||
| auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output(0), | ||
| output_low->cast_vector<double>()[0], | ||
| output_high->cast_vector<double>()[0]); | ||
| res = replace_node_update_name(node, clamp); | ||
| } else { | ||
| res = replace_output_update_name(node->output(0), node->input_value(0)); | ||
| } | ||
| OPENVINO_ASSERT(res, "FQ stripping failed"); | ||
| return res; | ||
| }; | ||
|
|
||
| auto m = std::make_shared<ov::pass::pattern::Matcher>(fq_m, matcher_name); | ||
| this->register_matcher(m, callback); | ||
| } | ||
|
|
||
| } // namespace low_precision | ||
| } // namespace pass | ||
| } // namespace ov | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| #include "low_precision/fold_convert.hpp" | ||
| #include "low_precision/fuse_convert.hpp" | ||
| #include "low_precision/group_convolution.hpp" | ||
| #include "low_precision/qdq_stripping.hpp" | ||
| #include "low_precision/low_precision.hpp" | ||
| #include "low_precision/mat_mul.hpp" | ||
| #include "low_precision/multiply_to_group_convolution.hpp" | ||
|
|
@@ -387,8 +388,22 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) { | |
| ov::disable_keep_const_precision(node); | ||
| } | ||
|
|
||
| auto is_model_quantized = ov::pass::low_precision::LowPrecision::isFunctionQuantized(func); | ||
| using namespace ov::pass::low_precision; | ||
| auto is_model_quantized = LowPrecision::isFunctionQuantized(func, std::set<levels>{levels::int8, levels::int8_narrow_range}); | ||
| enableInt8 = config.get_enable_lp_transformations() && is_model_quantized; | ||
| { | ||
| using namespace ov::element; | ||
| // QDQ stripping pipeline | ||
| // 1. Transform DQ part to canonicalized form: Multiply->Add => Subtract->Multiply | ||
| manager.register_pass<AddTransformation>(); | ||
| // 2. Fuse FQ->Convert->DQ to a single FQ | ||
| manager.register_pass<ov::pass::ConvertQuantizeDequantize>(TypeVector{i16, u16}, TypeVector{f32}, true); | ||
| // 3. Strip FQ layers with unsupported levels | ||
| bool replace_with_clamp = ov::util::getenv_bool("REPLACE_QDQ_WITH_CLAMP", true); | ||
| std::cout << "[ QDQ STRIPPING INFO ] replace_with_clamp = " << replace_with_clamp << std::endl; | ||
|
||
| manager.register_pass<FQStrippingTransformation>(std::set<size_t>{levels::int16}, replace_with_clamp); | ||
| manager.register_pass<ov::pass::Validate>(); | ||
aobolensk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| manager.register_pass<ov::pass::MarkDequantization>( | ||
| std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 }, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.