@@ -57,8 +57,15 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
5757
5858 auto weight_ptr = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (compressed_weights_m).get_node_shared_ptr ());
5959 bool weight_u8 = false ;
60- if (weight_ptr->get_element_type () == ov::element::u8 || weight_ptr->get_element_type () == ov::element::i8 )
61- weight_u8 = true ;
60+ if (pattern_map.count (weights_const_m)) {
61+ auto weight_ptr = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (weights_const_m).get_node_shared_ptr ());
62+ if (weight_ptr->get_element_type () == ov::element::u8 || weight_ptr->get_element_type () == ov::element::i8 )
63+ weight_u8 = true ;
64+ } else {
65+ auto weight_ptr = ov::as_type_ptr<ov::op::v0::Parameter>(pattern_map.at (weights_param_m).get_node_shared_ptr ());
66+ if (weight_ptr->get_element_type () == ov::element::u8 || weight_ptr->get_element_type () == ov::element::i8 )
67+ weight_u8 = true ;
68+ }
6269
6370 auto reshape_const = [has_transpose, grouped, is_weight_3d](std::shared_ptr<ov::Node> node) {
6471 auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
@@ -73,7 +80,7 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
7380 return constant;
7481 else
7582 new_shape = (has_transpose || !grouped) ? ov::Shape{current_shape[0 ] * current_shape[1 ], current_shape[2 ]}
76- : ov::Shape{current_shape[0 ], current_shape[1 ] * current_shape[2 ]};
83+ : ov::Shape{current_shape[0 ], current_shape[1 ] * current_shape[2 ]};
7784 } else {
7885 OPENVINO_ASSERT (current_shape.size () == 4 && is_weight_3d);
7986 new_shape = (has_transpose || !grouped) ? ov::Shape{current_shape[0 ], current_shape[1 ] * current_shape[2 ], current_shape[3 ]}
@@ -102,7 +109,6 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
102109 return result;
103110 };
104111
105-
106112 const ov::Output<Node>& fc_input_a = fc->input (0 ).get_source_output ();
107113 const auto & scale = reshape_const (pattern_map.at (mul_const_m).get_node_shared_ptr ());
108114 std::shared_ptr<ov::Node> optional_zero_point = nullptr ;
@@ -112,61 +118,104 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
112118 optional_zero_point = convert_const_to_u8 (reshape_const (pattern_map.at (sub_const_m).get_node_shared_ptr ()));
113119 }
114120
115- std::shared_ptr<ov::Node> fc_input_b = reshape_const (pattern_map.at (compressed_weights_m).get_node_shared_ptr ());
116- std::shared_ptr<ov::Node> fc_input_scale = scale;
117- std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
118- std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at (bias_m).get_node_shared_ptr ();
119- std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
120-
121- if (has_transpose) {
122- const auto & transpose = pattern_map.at (transpose_m).get_node_shared_ptr ();
123- std::shared_ptr<ov::Node> transpose_const = pattern_map.at (transpose_const_m).get_node_shared_ptr ();
124- if (ov::shape_size (transpose_const->get_shape ()) != fc_input_b->get_output_partial_shape (0 ).size ()) {
125- std::vector<int32_t > new_order (fc_input_b->get_output_partial_shape (0 ).size ());
126- std::iota (new_order.begin (), new_order.end (), 0 );
127- std::swap (new_order[new_order.size () - 1 ], new_order[new_order.size () - 2 ]);
128- transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32 , ov::Shape{new_order.size ()}, new_order);
121+ if (pattern_map.count (weights_const_m)) {
122+ std::shared_ptr<ov::Node> fc_input_b = reshape_const (pattern_map.at (weights_const_m).get_node_shared_ptr ());
123+ std::shared_ptr<ov::Node> fc_input_scale = scale;
124+ std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
125+ std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at (bias_m).get_node_shared_ptr ();
126+ std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
127+
128+ if (has_transpose) {
129+ const auto & transpose = pattern_map.at (transpose_m).get_node_shared_ptr ();
130+ std::shared_ptr<ov::Node> transpose_const = pattern_map.at (transpose_const_m).get_node_shared_ptr ();
131+ if (ov::shape_size (transpose_const->get_shape ()) != fc_input_b->get_output_partial_shape (0 ).size ()) {
132+ std::vector<int32_t > new_order (fc_input_b->get_output_partial_shape (0 ).size ());
133+ std::iota (new_order.begin (), new_order.end (), 0 );
134+ std::swap (new_order[new_order.size () - 1 ], new_order[new_order.size () - 2 ]);
135+ transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32 , ov::Shape{new_order.size ()}, new_order);
136+ }
137+
138+ fc_input_b = transpose->clone_with_new_inputs ({fc_input_b->output (0 ), transpose_const});
139+ result_nodes.push_back (fc_input_b);
140+
141+ if (ov::shape_size (scale->output (0 ).get_shape ()) > 1 ) {
142+ fc_input_scale = transpose->clone_with_new_inputs ({scale->output (0 ), transpose_const});
143+ result_nodes.push_back (fc_input_scale);
144+ }
145+
146+ if (with_zero_point && ov::shape_size (optional_zero_point->output (0 ).get_shape ()) > 1 ) {
147+ fc_input_zp = transpose->clone_with_new_inputs ({optional_zero_point->output (0 ), transpose_const});
148+ result_nodes.push_back (fc_input_zp);
149+ }
129150 }
130151
131- fc_input_b = transpose->clone_with_new_inputs ({ fc_input_b->output (0 ), transpose_const });
132- result_nodes.push_back (fc_input_b);
152+ if (pattern_map.count (mul2_m)) {
153+ auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (mul2_const_m).get_node_shared_ptr ());
154+ fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
155+ }
133156
134- if (ov::shape_size (scale->output (0 ).get_shape ()) > 1 ) {
135- fc_input_scale = transpose->clone_with_new_inputs ({ scale->output (0 ), transpose_const });
136- result_nodes.push_back (fc_input_scale);
157+ std::shared_ptr<ov::Node> new_fc = nullptr ;
158+ if (with_zero_point) {
159+ new_fc =
160+ std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc_input_zp, fc->get_output_type ());
161+ } else {
162+ new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc->get_output_type ());
137163 }
138164
139- if (with_zero_point && ov::shape_size (optional_zero_point->output (0 ).get_shape ()) > 1 ) {
140- fc_input_zp = transpose->clone_with_new_inputs ({ optional_zero_point->output (0 ), transpose_const });
141- result_nodes.push_back (fc_input_zp);
165+ result_nodes.push_back (new_fc);
166+ new_fc->set_friendly_name (fc->get_friendly_name ());
167+ ov::copy_runtime_info (m.get_matched_nodes (), result_nodes);
168+ ov::replace_node (fc, new_fc);
169+ } else {
170+ std::shared_ptr<ov::Node> fc_input_b = pattern_map.count (weights_param_reshape_m) ? pattern_map.at (weights_param_reshape_m).get_node_shared_ptr ()
171+ : pattern_map.at (weights_param_m).get_node_shared_ptr ();
172+ std::shared_ptr<ov::Node> fc_input_scale = scale;
173+ std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
174+ std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at (bias_m).get_node_shared_ptr ();
175+ std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
176+
177+ if (has_transpose) {
178+ const auto & transpose = pattern_map.at (transpose_m).get_node_shared_ptr ();
179+ std::shared_ptr<ov::Node> transpose_const = pattern_map.at (transpose_const_m).get_node_shared_ptr ();
180+ if (ov::shape_size (transpose_const->get_shape ()) != fc_input_b->get_output_partial_shape (0 ).size ()) {
181+ std::vector<int32_t > new_order (fc_input_b->get_output_partial_shape (0 ).size ());
182+ std::iota (new_order.begin (), new_order.end (), 0 );
183+ std::swap (new_order[new_order.size () - 1 ], new_order[new_order.size () - 2 ]);
184+ transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32 , ov::Shape{new_order.size ()}, new_order);
185+ }
186+
187+ fc_input_b = transpose->clone_with_new_inputs ({fc_input_b->output (0 ), transpose_const});
188+ result_nodes.push_back (fc_input_b);
189+
190+ if (ov::shape_size (scale->output (0 ).get_shape ()) > 1 ) {
191+ fc_input_scale = transpose->clone_with_new_inputs ({scale->output (0 ), transpose_const});
192+ result_nodes.push_back (fc_input_scale);
193+ }
194+
195+ if (with_zero_point && ov::shape_size (optional_zero_point->output (0 ).get_shape ()) > 1 ) {
196+ fc_input_zp = transpose->clone_with_new_inputs ({optional_zero_point->output (0 ), transpose_const});
197+ result_nodes.push_back (fc_input_zp);
198+ }
142199 }
143- }
144200
145- if (pattern_map.count (mul2_m)) {
146- auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (mul2_const_m).get_node_shared_ptr ());
147- fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
148- }
201+ if (pattern_map.count (mul2_m)) {
202+ auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at (mul2_const_m).get_node_shared_ptr ());
203+ fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
204+ }
149205
150- std::shared_ptr<ov::Node> new_fc = nullptr ;
151- if (with_zero_point) {
152- new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
153- fc_input_b,
154- fc_input_bias,
155- fc_input_scale,
156- fc_input_zp,
157- fc->get_output_type ());
158- } else {
159- new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
160- fc_input_b,
161- fc_input_bias,
162- fc_input_scale,
163- fc->get_output_type ());
164- }
206+ std::shared_ptr<ov::Node> new_fc = nullptr ;
207+ if (with_zero_point) {
208+ new_fc =
209+ std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc_input_zp, fc->get_output_type ());
210+ } else {
211+ new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc->get_output_type ());
212+ }
165213
166- result_nodes.push_back (new_fc);
167- new_fc->set_friendly_name (fc->get_friendly_name ());
168- ov::copy_runtime_info (m.get_matched_nodes (), result_nodes);
169- ov::replace_node (fc, new_fc);
214+ result_nodes.push_back (new_fc);
215+ new_fc->set_friendly_name (fc->get_friendly_name ());
216+ ov::copy_runtime_info (m.get_matched_nodes (), result_nodes);
217+ ov::replace_node (fc, new_fc);
218+ }
170219
171220 return true ;
172221 };
0 commit comments