1414#include " openvino/core/type.hpp"
1515#include " openvino/op/clamp.hpp"
1616#include " openvino/op/constant.hpp"
17+ #include " openvino/op/equal.hpp"
1718#include " openvino/op/fake_quantize.hpp"
1819#include " openvino/pass/pattern/op/wrap_type.hpp"
1920#include " openvino/util/log.hpp"
21+ #include " transformations/utils/utils.hpp"
2022
2123namespace ov {
2224namespace pass {
2325namespace low_precision {
2426
25- FQStrippingTransformation::FQStrippingTransformation (const std::set<size_t >& levels_to_strip) {
27+ FQStrippingTransformation::FQStrippingTransformation (const std::set<size_t >& levels_to_strip, bool replace_with_clamp ) {
2628 MATCHER_SCOPE (FQStrippingTransformation);
2729 auto is_scalar = [](const Output<Node>& output) -> bool {
2830 return ov::shape_size (output.get_shape ()) == 1 ;
@@ -43,29 +45,38 @@ FQStrippingTransformation::FQStrippingTransformation(const std::set<size_t>& lev
4345
4446 const size_t levels = node->get_levels ();
4547 if (!levels_to_strip.count (levels)) {
46- std::cout << " [QDQStripping] Levels " << levels << " not in strip set, skipping" << std::endl;
4748 return false ;
4849 }
4950
50- std::cout << " [QDQStripping] Levels " << levels << " found in strip set, proceeding with transformation"
51- << std::endl;
52-
5351 auto input = node->get_input_node_shared_ptr (0 );
52+ auto input_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (input_low_m).get_node_shared_ptr ());
53+ auto input_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (input_high_m).get_node_shared_ptr ());
5454 auto output_low = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (output_low_m).get_node_shared_ptr ());
5555 auto output_high = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (output_high_m).get_node_shared_ptr ());
5656
5757 // TODO: need to check that input and output intervals are equal
58- if (!output_low || !output_high) {
59- std::cout << " [QDQStripping] Failed to get constant output_low or output_high nodes" << std::endl;
58+ if (!input_low || !input_high || !output_low || !output_high) {
59+ return false ;
60+ }
61+ auto constants_are_equal = [](const std::shared_ptr<ov::op::v0::Constant>& lhs,
62+ const std::shared_ptr<ov::op::v0::Constant>& rhs) {
63+ auto equal = ov::as_type_ptr<ov::op::v0::Constant>(ov::op::util::make_try_fold<ov::op::v1::Equal>(lhs, rhs));
64+ OPENVINO_ASSERT (equal && ov::shape_size (equal->get_shape ()) == 1 ,
65+ " constants_are_equal expects scalar constant as a comparison result" );
66+ return equal->get_vector <bool >()[0 ] == true ;
67+ };
68+ if (!constants_are_equal (input_low, output_low) || !constants_are_equal (input_high, output_high)) {
6069 return false ;
6170 }
6271
63- auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output (0 ),
64- output_low->cast_vector <double >()[0 ],
65- output_high->cast_vector <double >()[0 ]);
66- std::cout << " [ INFO ] clamp low = " << clamp->get_min () << " , high = " << clamp->get_max () << std::endl;
67-
68- return replace_node_update_name (node, clamp);
72+ if (replace_with_clamp) {
73+ auto clamp = std::make_shared<ov::op::v0::Clamp>(input->output (0 ),
74+ output_low->cast_vector <double >()[0 ],
75+ output_high->cast_vector <double >()[0 ]);
76+ return replace_node_update_name (node, clamp);
77+ } else {
78+ return replace_output_update_name (node->output (0 ), node->input_value (0 ));
79+ }
6980 };
7081
7182 auto m = std::make_shared<ov::pass::pattern::Matcher>(fq_m, matcher_name);
0 commit comments