1+ // Copyright (C) 2018-2025 Intel Corporation
2+ // SPDX-License-Identifier: Apache-2.0
3+ //
4+
5+ #include " low_precision/qdq_stripping.hpp"
6+
7+ #include < memory>
8+
9+ #include " itt.hpp"
10+ #include " low_precision/common/ie_lpt_exception.hpp"
11+ #include " low_precision/lpt_itt.hpp"
12+ #include " low_precision/network_helper.hpp"
13+ #include " openvino/core/except.hpp"
14+ #include " openvino/core/type.hpp"
15+ #include " openvino/op/clamp.hpp"
16+ #include " openvino/op/constant.hpp"
17+ #include " openvino/op/fake_quantize.hpp"
18+ #include " openvino/pass/pattern/op/wrap_type.hpp"
19+ #include " openvino/util/log.hpp"
20+
21+ namespace ov {
22+ namespace pass {
23+ namespace low_precision {
24+
25+ FQStrippingTransformation::FQStrippingTransformation (const std::set<size_t >& levels_to_strip) {
26+ MATCHER_SCOPE (FQStrippingTransformation);
27+ auto is_scalar = [](const Output<Node>& output) -> bool {
28+ return ov::shape_size (output.get_shape ()) == 1 ;
29+ };
30+ auto input_low_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
31+ auto input_high_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
32+ auto output_low_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
33+ auto output_high_m = pattern::wrap_type<ov::op::v0::Constant>(is_scalar);
34+ auto fq_m = pattern::wrap_type<ov::op::v0::FakeQuantize>(
35+ {pattern::any_input (), input_low_m, input_high_m, output_low_m, output_high_m});
36+
37+ ov::graph_rewrite_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
38+ const auto & pattern_map = m.get_pattern_value_map ();
39+ auto node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(pattern_map.at (fq_m).get_node_shared_ptr ());
40+ if (!node) {
41+ return false ;
42+ }
43+
44+ const size_t levels = node->get_levels ();
45+ if (!levels_to_strip.count (levels)) {
46+ std::cout << " [QDQStripping] Levels " << levels << " not in strip set, skipping" << std::endl;
47+ return false ;
48+ }
49+
50+ std::cout << " [QDQStripping] Levels " << levels << " found in strip set, proceeding with transformation"
51+ << std::endl;
52+
53+ auto input = node->get_input_node_shared_ptr (0 );
54+ auto output_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (output_low_m).get_node_shared_ptr ());
55+ auto output_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (output_high_m).get_node_shared_ptr ());
56+
57+ // TODO: need to check that input and output intervals are equal
58+ if (!output_low || !output_high) {
59+ std::cout << " [QDQStripping] Failed to get constant output_low or output_high nodes" << std::endl;
60+ return false ;
61+ }
62+
63+ auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output (0 ),
64+ output_low->cast_vector <double >()[0 ],
65+ output_high->cast_vector <double >()[0 ]);
66+ std::cout << " [ INFO ] clamp low = " << clamp->get_min () << " , high = " << clamp->get_max () << std::endl;
67+
68+ return replace_node_update_name (node, clamp);
69+ };
70+
71+ auto m = std::make_shared<ov::pass::pattern::Matcher>(fq_m, matcher_name);
72+ this ->register_matcher (m, callback);
73+ }
74+
75+ } // namespace low_precision
76+ } // namespace pass
77+ } // namespace ov
0 commit comments