Skip to content

Commit 1e75852

Browse files
committed
ConvertQuantizeDequantize: ignore consumers_count check
1 parent 5624454 commit 1e75852

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

src/common/transformations/include/transformations/common_optimizations/convert_quantize_dequantize.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@ class ov::pass::ConvertQuantizeDequantize : public ov::pass::MatcherPass {
3737
ov::element::u8,
3838
ov::element::i16,
3939
ov::element::u16},
40-
const ov::element::TypeVector& supported_original_precisions = {ov::element::f32});
40+
const ov::element::TypeVector& supported_original_precisions = {ov::element::f32},
41+
const bool ignore_consumers_count_check = false);
4142
};

src/common/transformations/src/transformations/common_optimizations/convert_quantize_dequantize.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@
7272

7373
ov::pass::ConvertQuantizeDequantize::ConvertQuantizeDequantize(
7474
const ov::element::TypeVector& supported_low_precisions,
75-
const ov::element::TypeVector& supported_original_precisions) {
75+
const ov::element::TypeVector& supported_original_precisions,
76+
const bool ignore_consumers_count_check) {
7677
MATCHER_SCOPE(ConvertQuantizeDequantize);
7778

7879
using namespace ov::pass::pattern;
@@ -85,13 +86,18 @@ ov::pass::ConvertQuantizeDequantize::ConvertQuantizeDequantize(
8586
auto output_high_pattern = wrap_type<v0::Constant>();
8687
auto fq_pattern = wrap_type<v0::FakeQuantize>(
8788
{data_pattern, input_low_pattern, input_high_pattern, output_low_pattern, output_high_pattern});
88-
auto convert1_pattern =
89-
wrap_type<v0::Convert>({fq_pattern}, type_matches_any(supported_low_precisions) && consumers_count(1));
90-
auto convert2_pattern =
91-
wrap_type<v0::Convert>({convert1_pattern},
92-
type_matches_any(supported_original_precisions) && consumers_count(1));
89+
op::Predicate convert1_predicate = ignore_consumers_count_check
90+
? type_matches_any(supported_low_precisions)
91+
: type_matches_any(supported_low_precisions) && consumers_count(1);
92+
auto convert1_pattern = wrap_type<v0::Convert>({fq_pattern}, convert1_predicate);
93+
op::Predicate convert2_predicate = ignore_consumers_count_check
94+
? type_matches_any(supported_original_precisions)
95+
: type_matches_any(supported_original_precisions) && consumers_count(1);
96+
auto convert2_pattern = wrap_type<v0::Convert>({convert1_pattern}, convert2_predicate);
97+
9398
auto zero_point_pattern = any_input();
94-
auto sub_pattern = optional<v1::Subtract>({convert2_pattern, zero_point_pattern}, consumers_count(1));
99+
op::Predicate sub_predicate = ignore_consumers_count_check ? op::Predicate() : consumers_count(1);
100+
auto sub_pattern = optional<v1::Subtract>({convert2_pattern, zero_point_pattern}, sub_predicate);
95101
auto scale_pattern = any_input();
96102
auto mul_pattern = wrap_type<v1::Multiply>({sub_pattern, scale_pattern});
97103

src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
407407
manager.register_pass<AddTransformation>();
408408
SERIALIZE_GRAPHS("add_transformation");
409409
// 2. Fuse FQ->Convert->DQ to a single FQ
410-
manager.register_pass<ov::pass::ConvertQuantizeDequantize>(TypeVector{i16, u16, i32},
411-
TypeVector{f16, f32});
410+
manager.register_pass<ov::pass::ConvertQuantizeDequantize>(TypeVector{i16, u16, i32}, TypeVector{f16, f32}, true);
412411
SERIALIZE_GRAPHS("convert_qdq");
413412
// 3. Strip FQ layers with unsupported levels
414413
bool replace_with_clamp = true;

0 commit comments

Comments
 (0)