Skip to content

Commit 3b2fdc7

Browse files
committed
QDQStripping initial implementation
1 parent 57ade3f commit 3b2fdc7

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <memory>
8+
#include <set>
9+
10+
#include "lpt_visibility.hpp"
11+
#include "openvino/pass/matcher_pass.hpp"
12+
#include "quantization_details.hpp"
13+
14+
namespace ov {
15+
namespace pass {
16+
namespace low_precision {
17+
18+
/**
19+
* @ingroup ov_transformation_common_api
20+
* @brief FQStrippingTransformation strips FakeQuantize operations with specified levels
21+
* by replacing them with Clamp operations.
22+
*/
23+
class LP_TRANSFORMATIONS_API FQStrippingTransformation : public ov::pass::MatcherPass {
24+
public:
25+
OPENVINO_RTTI("FQStrippingTransformation", "0", MatcherPass);
26+
FQStrippingTransformation(const std::set<size_t>& levels_to_strip);
27+
};
28+
29+
} // namespace low_precision
30+
} // namespace pass
31+
} // namespace ov
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "low_precision/qdq_stripping.hpp"
6+
7+
#include <memory>
8+
9+
#include "itt.hpp"
10+
#include "low_precision/common/ie_lpt_exception.hpp"
11+
#include "low_precision/lpt_itt.hpp"
12+
#include "low_precision/network_helper.hpp"
13+
#include "openvino/core/except.hpp"
14+
#include "openvino/core/type.hpp"
15+
#include "openvino/op/clamp.hpp"
16+
#include "openvino/op/constant.hpp"
17+
#include "openvino/op/fake_quantize.hpp"
18+
#include "openvino/pass/pattern/op/wrap_type.hpp"
19+
#include "openvino/util/log.hpp"
20+
21+
namespace ov {
22+
namespace pass {
23+
namespace low_precision {
24+
25+
FQStrippingTransformation::FQStrippingTransformation(const std::set<size_t>& levels_to_strip) {
26+
MATCHER_SCOPE(FQStrippingTransformation);
27+
auto is_scalar = [](const Output<Node>& output) -> bool {
28+
return ov::shape_size(output.get_shape()) == 1;
29+
};
30+
auto input_low_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
31+
auto input_high_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
32+
auto output_low_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
33+
auto output_high_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
34+
auto fq_m = pattern::wrap_type<ov::op::v0::FakeQuantize>(
35+
{pattern::any_input(), input_low_m, input_high_m, output_low_m, output_high_m});
36+
37+
ov::graph_rewrite_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
38+
const auto& pattern_map = m.get_pattern_value_map();
39+
auto node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(pattern_map.at(fq_m).get_node_shared_ptr());
40+
if (!node) {
41+
return false;
42+
}
43+
44+
const size_t levels = node->get_levels();
45+
if (!levels_to_strip.count(levels)) {
46+
std::cout << "[QDQStripping] Levels " << levels << " not in strip set, skipping" << std::endl;
47+
return false;
48+
}
49+
50+
std::cout << "[QDQStripping] Levels " << levels << " found in strip set, proceeding with transformation"
51+
<< std::endl;
52+
53+
auto input = node->get_input_node_shared_ptr(0);
54+
auto output_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(output_low_m).get_node_shared_ptr());
55+
auto output_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(output_high_m).get_node_shared_ptr());
56+
57+
// TODO: need to check that input and output intervals are equal
58+
if (!output_low || !output_high) {
59+
std::cout << "[QDQStripping] Failed to get constant output_low or output_high nodes" << std::endl;
60+
return false;
61+
}
62+
63+
auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output(0),
64+
output_low->cast_vector<double>()[0],
65+
output_high->cast_vector<double>()[0]);
66+
std::cout << "[ INFO ] clamp low = " << clamp->get_min() << ", high = " << clamp->get_max() << std::endl;
67+
68+
return replace_node_update_name(node, clamp);
69+
};
70+
71+
auto m = std::make_shared<ov::pass::pattern::Matcher>(fq_m, matcher_name);
72+
this->register_matcher(m, callback);
73+
}
74+
75+
} // namespace low_precision
76+
} // namespace pass
77+
} // namespace ov

0 commit comments

Comments
 (0)