Skip to content

Commit 0ba1789

Browse files
committed
Finalize first QDQ Stripping implementation
1 parent 3b2fdc7 commit 0ba1789

File tree

4 files changed, with 37 additions and 15 deletions.

src/common/low_precision_transformations/include/low_precision/qdq_stripping.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace low_precision {
2323
class LP_TRANSFORMATIONS_API FQStrippingTransformation : public ov::pass::MatcherPass {
2424
public:
2525
OPENVINO_RTTI("FQStrippingTransformation", "0", MatcherPass);
26-
FQStrippingTransformation(const std::set<size_t>& levels_to_strip);
26+
FQStrippingTransformation(const std::set<size_t>& levels_to_strip, bool replace_with_clamp);
2727
};
2828

2929
} // namespace low_precision

src/common/low_precision_transformations/src/qdq_stripping.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414
#include "openvino/core/type.hpp"
1515
#include "openvino/op/clamp.hpp"
1616
#include "openvino/op/constant.hpp"
17+
#include "openvino/op/equal.hpp"
1718
#include "openvino/op/fake_quantize.hpp"
1819
#include "openvino/pass/pattern/op/wrap_type.hpp"
1920
#include "openvino/util/log.hpp"
21+
#include "transformations/utils/utils.hpp"
2022

2123
namespace ov {
2224
namespace pass {
2325
namespace low_precision {
2426

25-
FQStrippingTransformation::FQStrippingTransformation(const std::set<size_t>& levels_to_strip) {
27+
FQStrippingTransformation::FQStrippingTransformation(const std::set<size_t>& levels_to_strip, bool replace_with_clamp) {
2628
MATCHER_SCOPE(FQStrippingTransformation);
2729
auto is_scalar = [](const Output<Node>& output) -> bool {
2830
return ov::shape_size(output.get_shape()) == 1;
@@ -43,29 +45,38 @@ FQStrippingTransformation::FQStrippingTransformation(const std::set<size_t>& lev
4345

4446
const size_t levels = node->get_levels();
4547
if (!levels_to_strip.count(levels)) {
46-
std::cout << "[QDQStripping] Levels " << levels << " not in strip set, skipping" << std::endl;
4748
return false;
4849
}
4950

50-
std::cout << "[QDQStripping] Levels " << levels << " found in strip set, proceeding with transformation"
51-
<< std::endl;
52-
5351
auto input = node->get_input_node_shared_ptr(0);
52+
auto input_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(input_low_m).get_node_shared_ptr());
53+
auto input_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(input_high_m).get_node_shared_ptr());
5454
auto output_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(output_low_m).get_node_shared_ptr());
5555
auto output_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(output_high_m).get_node_shared_ptr());
5656

5757
// Strip only FakeQuantize nodes whose input and output intervals are equal
58-
if (!output_low || !output_high) {
59-
std::cout << "[QDQStripping] Failed to get constant output_low or output_high nodes" << std::endl;
58+
if (!input_low || !input_high || !output_low || !output_high) {
59+
return false;
60+
}
61+
auto constants_are_equal = [](const std::shared_ptr<ov::op::v0::Constant>& lhs,
62+
const std::shared_ptr<ov::op::v0::Constant>& rhs) {
63+
auto equal = ov::as_type_ptr<ov::op::v0::Constant>(ov::op::util::make_try_fold<ov::op::v1::Equal>(lhs, rhs));
64+
OPENVINO_ASSERT(equal && ov::shape_size(equal->get_shape()) == 1,
65+
"constants_are_equal expects scalar constant as a comparison result");
66+
return equal->get_vector<bool>()[0] == true;
67+
};
68+
if (!constants_are_equal(input_low, output_low) || !constants_are_equal(input_high, output_high)) {
6069
return false;
6170
}
6271

63-
auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output(0),
64-
output_low->cast_vector<double>()[0],
65-
output_high->cast_vector<double>()[0]);
66-
std::cout << "[ INFO ] clamp low = " << clamp->get_min() << ", high = " << clamp->get_max() << std::endl;
67-
68-
return replace_node_update_name(node, clamp);
72+
if (replace_with_clamp) {
73+
auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output(0),
74+
output_low->cast_vector<double>()[0],
75+
output_high->cast_vector<double>()[0]);
76+
return replace_node_update_name(node, clamp);
77+
} else {
78+
return replace_output_update_name(node->output(0), node->input_value(0));
79+
}
6980
};
7081

7182
auto m = std::make_shared<ov::pass::pattern::Matcher>(fq_m, matcher_name);

src/common/transformations/src/transformations/common_optimizations/convert_quantize_dequantize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ ov::pass::ConvertQuantizeDequantize::ConvertQuantizeDequantize(
166166

167167
copy_runtime_info({fq, convert1.get_node_shared_ptr(), convert2.get_node_shared_ptr()}, new_fq);
168168
replace_node(mul, new_fq);
169-
std::cout << "[ INFO ] ConvertQuantizeDequantize is finished for node " << new_fq->get_friendly_name() << std::endl;
170169

171170
return true;
172171
};

src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "low_precision/fold_convert.hpp"
2525
#include "low_precision/fuse_convert.hpp"
2626
#include "low_precision/group_convolution.hpp"
27+
#include "low_precision/qdq_stripping.hpp"
2728
#include "low_precision/low_precision.hpp"
2829
#include "low_precision/mat_mul.hpp"
2930
#include "low_precision/multiply_to_group_convolution.hpp"
@@ -390,6 +391,17 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
390391

391392
auto is_model_quantized = ov::pass::low_precision::LowPrecision::isFunctionQuantized(func);
392393
enableInt8 = config.get_enable_lp_transformations() && is_model_quantized;
394+
{
395+
using namespace ov::pass::low_precision;
396+
// QDQ stripping pipeline
397+
// 1. Transform DQ part to canonicalized form: Multiply->Add => Subtract->Multiply
398+
manager.register_pass<AddTransformation>();
399+
// 2. Fuse FQ->Convert->DQ to a single FQ
400+
manager.register_pass<ov::pass::ConvertQuantizeDequantize>(ov::element::TypeVector{ov::element::i16, ov::element::u16});
401+
// 3. Strip FQ layers with unsupported levels
402+
bool replace_with_clamp = false;
403+
manager.register_pass<FQStrippingTransformation>(std::set<size_t>{levels::int16}, replace_with_clamp);
404+
}
393405

394406
manager.register_pass<ov::pass::MarkDequantization>(
395407
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 },

0 commit comments

Comments
 (0)