Commit b82b902
[gpu] Recognize parameters as valid inputs for compressed weights

This change enables quantized LoRA weights, passed as parameters during execution, to be recognized by the transformations that produce FullyConnectedCompressed nodes for QGEMM execution.
1 parent 9352ef0 commit b82b902
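For context, the graph shape this commit targets can be sketched with the public OpenVINO C++ API. The snippet below is a minimal illustration, not code from the commit: a u8 weight enters the model as a Parameter (e.g. a LoRA adapter swapped at runtime) and feeds the usual Convert/Multiply decompression chain into a MatMul, which the passes touched here can now fold into a FullyConnectedCompressed node. All names and shapes are hypothetical.

    #include <openvino/core/model.hpp>
    #include <openvino/op/ops.hpp>

    // Hypothetical model: the quantized weight is a Parameter, not a Constant.
    std::shared_ptr<ov::Model> make_param_weight_model() {
        auto data  = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 32});
        auto w_q   = std::make_shared<ov::op::v0::Parameter>(ov::element::u8, ov::Shape{16, 32});
        auto w_f   = std::make_shared<ov::op::v0::Convert>(w_q, ov::element::f16);
        auto scale = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{16, 1}, {0.01f});
        auto w_dq  = std::make_shared<ov::op::v1::Multiply>(w_f, scale);
        // transpose_b = true: weights are stored as [N, K]
        auto mm    = std::make_shared<ov::op::v0::MatMul>(data, w_dq, false, true);
        return std::make_shared<ov::Model>(ov::OutputVector{mm->output(0)},
                                           ov::ParameterVector{data, w_q});
    }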

5 files changed: +144 -52 lines changed

src/common/transformations/include/transformations/utils/utils.hpp

Lines changed: 2 additions & 0 deletions
@@ -289,6 +289,8 @@ TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>
 
 TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output);
 
+TRANSFORMATIONS_API bool is_on_constant_or_param_path(const ov::Output<ov::Node>& output);
+
 TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
 
 TRANSFORMATIONS_API std::tuple<std::shared_ptr<ov::Node>, // result

src/common/transformations/src/transformations/utils/utils.cpp

Lines changed: 38 additions & 0 deletions
@@ -536,6 +536,44 @@ bool is_on_constant_path(const ov::Output<ov::Node>& output) {
     return status;
 }
 
+bool is_on_constant_or_param_path(const ov::Output<ov::Node>& output) {
+    auto status = true;
+
+    auto root_node = output.get_node();
+    if (!root_node || root_node->get_output_size() == 0) {
+        return false;
+    }
+    std::deque<ov::Node*> nodes_to_calculate = {root_node};
+
+    std::unordered_set<ov::Node*> visited;
+    while (status && !nodes_to_calculate.empty()) {
+        auto current_node = nodes_to_calculate.front();
+        nodes_to_calculate.pop_front();
+        if (visited.count(current_node)) {
+            continue;
+        }
+        visited.insert(current_node);
+        // RandomUniform output changes during runtime, so we should not consider it as a constant
+        if (current_node->get_type_info() == ov::op::v8::RandomUniform::get_type_info_static()) {
+            return false;
+        }
+
+        if (current_node->get_input_size() == 0 &&
+            !(ov::is_type<ov::op::v0::Constant>(current_node) || ov::is_type<ov::op::v0::Parameter>(current_node))) {
+            status = false;
+        } else {
+            // not a leaf - continue to search
+            for (const auto& input_value : current_node->input_values()) {
+                const auto& input_node = input_value.get_node();
+                if (!visited.count(input_node)) {
+                    nodes_to_calculate.push_front(input_node);
+                }
+            }
+        }
+    }
+    return status;
+}
+
 bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node) {
     bool changed = false;
 
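The new helper reuses the BFS from is_on_constant_path and differs only in which leaf nodes qualify: a path may now terminate in either a Constant or a Parameter. A hypothetical check, with illustrative names:

    #include <openvino/op/ops.hpp>
    #include <transformations/utils/utils.hpp>

    auto w_q = std::make_shared<ov::op::v0::Parameter>(ov::element::u8, ov::Shape{16, 32});
    auto w_f = std::make_shared<ov::op::v0::Convert>(w_q, ov::element::f16);

    // A Parameter leaf now qualifies; the stricter constant-only check still rejects it.
    bool on_const_or_param = ov::op::util::is_on_constant_or_param_path(w_f->output(0));  // true
    bool on_const_only     = ov::op::util::is_on_constant_path(w_f->output(0));           // false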
src/plugins/intel_gpu/src/plugin/transformations/compressed_weights_pattern.hpp

Lines changed: 4 additions & 1 deletion
@@ -14,7 +14,10 @@ using namespace ov::pass::pattern;
     ((in_ps.size() == 3 && out_ps.size() == 2) || (in_ps.size() == 4 && out_ps.size() == 3));\
 };\
 \
-auto compressed_weights_m = wrap_type<ov::op::v0::Constant>(compressed_constant);\
+auto weights_const_m = wrap_type<ov::op::v0::Constant>(compressed_constant);\
+auto weights_param_m = wrap_type<ov::op::v0::Parameter>(compressed_constant);\
+auto weights_param_reshape_m = wrap_type<ov::op::v1::Reshape>({weights_param_m, any_input()});\
+auto compressed_weights_m = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{weights_const_m, weights_param_m, weights_param_reshape_m});\
 auto convert_m = wrap_type<ov::op::v0::Convert>({compressed_weights_m});\
 \
 auto sub_const_m = wrap_type<ov::op::v0::Constant>();\
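Since compressed_weights_m is now an Or over three alternatives, downstream callbacks can test which branch actually matched via pattern_map.count on the sub-patterns. A stripped-down sketch of the same idea outside the macro, with illustrative names:

    using namespace ov::pass::pattern;

    auto weights_const_m = wrap_type<ov::op::v0::Constant>();
    auto weights_param_m = wrap_type<ov::op::v0::Parameter>();
    auto weights_m = std::make_shared<op::Or>(ov::OutputVector{weights_const_m, weights_param_m});
    auto convert_m = wrap_type<ov::op::v0::Convert>({weights_m});

    // In the matcher callback, only the alternative that fired is present:
    // if (pattern_map.count(weights_const_m)) { /* constant weights */ }
    // else                                    { /* parameter weights */ }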

src/plugins/intel_gpu/src/plugin/transformations/convert_fc_to_compressed.cpp

Lines changed: 99 additions & 50 deletions
@@ -57,8 +57,15 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
 
         auto weight_ptr = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(compressed_weights_m).get_node_shared_ptr());
         bool weight_u8 = false;
-        if (weight_ptr->get_element_type() == ov::element::u8 || weight_ptr->get_element_type() == ov::element::i8)
-            weight_u8 = true;
+        if (pattern_map.count(weights_const_m)) {
+            auto weight_ptr = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(weights_const_m).get_node_shared_ptr());
+            if (weight_ptr->get_element_type() == ov::element::u8 || weight_ptr->get_element_type() == ov::element::i8)
+                weight_u8 = true;
+        } else {
+            auto weight_ptr = ov::as_type_ptr<ov::op::v0::Parameter>(pattern_map.at(weights_param_m).get_node_shared_ptr());
+            if (weight_ptr->get_element_type() == ov::element::u8 || weight_ptr->get_element_type() == ov::element::i8)
+                weight_u8 = true;
+        }
 
         auto reshape_const = [has_transpose, grouped, is_weight_3d](std::shared_ptr<ov::Node> node) {
             auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);

@@ -73,7 +80,7 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
                     return constant;
                 else
                     new_shape = (has_transpose || !grouped) ? ov::Shape{current_shape[0] * current_shape[1], current_shape[2]}
-                                                            : ov::Shape{current_shape[0], current_shape[1] * current_shape[2]};
+                                                              : ov::Shape{current_shape[0], current_shape[1] * current_shape[2]};
             } else {
                 OPENVINO_ASSERT(current_shape.size() == 4 && is_weight_3d);
                 new_shape = (has_transpose || !grouped) ? ov::Shape{current_shape[0], current_shape[1] * current_shape[2], current_shape[3]}

@@ -102,7 +109,6 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
             return result;
         };
 
-
         const ov::Output<Node>& fc_input_a = fc->input(0).get_source_output();
         const auto& scale = reshape_const(pattern_map.at(mul_const_m).get_node_shared_ptr());
         std::shared_ptr<ov::Node> optional_zero_point = nullptr;

@@ -112,61 +118,104 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
             optional_zero_point = convert_const_to_u8(reshape_const(pattern_map.at(sub_const_m).get_node_shared_ptr()));
         }
 
-        std::shared_ptr<ov::Node> fc_input_b = reshape_const(pattern_map.at(compressed_weights_m).get_node_shared_ptr());
-        std::shared_ptr<ov::Node> fc_input_scale = scale;
-        std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
-        std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
-        std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
-
-        if (has_transpose) {
-            const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
-            std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
-            if (ov::shape_size(transpose_const->get_shape()) != fc_input_b->get_output_partial_shape(0).size()) {
-                std::vector<int32_t> new_order(fc_input_b->get_output_partial_shape(0).size());
-                std::iota(new_order.begin(), new_order.end(), 0);
-                std::swap(new_order[new_order.size() - 1], new_order[new_order.size() - 2]);
-                transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_order.size()}, new_order);
+        if (pattern_map.count(weights_const_m)) {
+            std::shared_ptr<ov::Node> fc_input_b = reshape_const(pattern_map.at(weights_const_m).get_node_shared_ptr());
+            std::shared_ptr<ov::Node> fc_input_scale = scale;
+            std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
+            std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
+            std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
+
+            if (has_transpose) {
+                const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
+                std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
+                if (ov::shape_size(transpose_const->get_shape()) != fc_input_b->get_output_partial_shape(0).size()) {
+                    std::vector<int32_t> new_order(fc_input_b->get_output_partial_shape(0).size());
+                    std::iota(new_order.begin(), new_order.end(), 0);
+                    std::swap(new_order[new_order.size() - 1], new_order[new_order.size() - 2]);
+                    transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_order.size()}, new_order);
+                }
+
+                fc_input_b = transpose->clone_with_new_inputs({fc_input_b->output(0), transpose_const});
+                result_nodes.push_back(fc_input_b);
+
+                if (ov::shape_size(scale->output(0).get_shape()) > 1) {
+                    fc_input_scale = transpose->clone_with_new_inputs({scale->output(0), transpose_const});
+                    result_nodes.push_back(fc_input_scale);
+                }
+
+                if (with_zero_point && ov::shape_size(optional_zero_point->output(0).get_shape()) > 1) {
+                    fc_input_zp = transpose->clone_with_new_inputs({optional_zero_point->output(0), transpose_const});
+                    result_nodes.push_back(fc_input_zp);
+                }
             }
 
-            fc_input_b = transpose->clone_with_new_inputs({ fc_input_b->output(0), transpose_const });
-            result_nodes.push_back(fc_input_b);
+            if (pattern_map.count(mul2_m)) {
+                auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
+                fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
+            }
 
-            if (ov::shape_size(scale->output(0).get_shape()) > 1) {
-                fc_input_scale = transpose->clone_with_new_inputs({ scale->output(0), transpose_const });
-                result_nodes.push_back(fc_input_scale);
+            std::shared_ptr<ov::Node> new_fc = nullptr;
+            if (with_zero_point) {
+                new_fc =
+                    std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc_input_zp, fc->get_output_type());
+            } else {
+                new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc->get_output_type());
             }
 
-            if (with_zero_point && ov::shape_size(optional_zero_point->output(0).get_shape()) > 1) {
-                fc_input_zp = transpose->clone_with_new_inputs({ optional_zero_point->output(0), transpose_const });
-                result_nodes.push_back(fc_input_zp);
+            result_nodes.push_back(new_fc);
+            new_fc->set_friendly_name(fc->get_friendly_name());
+            ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
+            ov::replace_node(fc, new_fc);
+        } else {
+            std::shared_ptr<ov::Node> fc_input_b = pattern_map.count(weights_param_reshape_m) ? pattern_map.at(weights_param_reshape_m).get_node_shared_ptr()
+                                                                                              : pattern_map.at(weights_param_m).get_node_shared_ptr();
+            std::shared_ptr<ov::Node> fc_input_scale = scale;
+            std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
+            std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
+            std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
+
+            if (has_transpose) {
+                const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
+                std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
+                if (ov::shape_size(transpose_const->get_shape()) != fc_input_b->get_output_partial_shape(0).size()) {
+                    std::vector<int32_t> new_order(fc_input_b->get_output_partial_shape(0).size());
+                    std::iota(new_order.begin(), new_order.end(), 0);
+                    std::swap(new_order[new_order.size() - 1], new_order[new_order.size() - 2]);
+                    transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_order.size()}, new_order);
+                }
+
+                fc_input_b = transpose->clone_with_new_inputs({fc_input_b->output(0), transpose_const});
+                result_nodes.push_back(fc_input_b);
+
+                if (ov::shape_size(scale->output(0).get_shape()) > 1) {
+                    fc_input_scale = transpose->clone_with_new_inputs({scale->output(0), transpose_const});
+                    result_nodes.push_back(fc_input_scale);
+                }
+
+                if (with_zero_point && ov::shape_size(optional_zero_point->output(0).get_shape()) > 1) {
+                    fc_input_zp = transpose->clone_with_new_inputs({optional_zero_point->output(0), transpose_const});
+                    result_nodes.push_back(fc_input_zp);
+                }
             }
-        }
 
-        if (pattern_map.count(mul2_m)) {
-            auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
-            fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
-        }
+            if (pattern_map.count(mul2_m)) {
+                auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
+                fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
+            }
 
-        std::shared_ptr<ov::Node> new_fc = nullptr;
-        if (with_zero_point) {
-            new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
-                                                                    fc_input_b,
-                                                                    fc_input_bias,
-                                                                    fc_input_scale,
-                                                                    fc_input_zp,
-                                                                    fc->get_output_type());
-        } else {
-            new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
-                                                                    fc_input_b,
-                                                                    fc_input_bias,
-                                                                    fc_input_scale,
-                                                                    fc->get_output_type());
-        }
+            std::shared_ptr<ov::Node> new_fc = nullptr;
+            if (with_zero_point) {
+                new_fc =
+                    std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc_input_zp, fc->get_output_type());
+            } else {
+                new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc->get_output_type());
+            }
 
-        result_nodes.push_back(new_fc);
-        new_fc->set_friendly_name(fc->get_friendly_name());
-        ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
-        ov::replace_node(fc, new_fc);
+            result_nodes.push_back(new_fc);
+            new_fc->set_friendly_name(fc->get_friendly_name());
+            ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
+            ov::replace_node(fc, new_fc);
+        }
 
         return true;
     };
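For reference, the scale and optional zero point wired into FullyConnectedCompressed encode the standard affine dequantization applied when the compressed weights are unpacked. A scalar sketch of that math, illustrative only and not the plugin kernel:

    // Per-channel (or per-group) affine dequantization of one u8 weight:
    //     w_fp = (w_q - zero_point) * scale
    inline float dequantize_weight(uint8_t w_q, float zero_point, float scale) {
        return (static_cast<float>(w_q) - zero_point) * scale;
    }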

src/plugins/intel_gpu/src/plugin/transformations/convert_matmul_to_fc.cpp

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ ConvertMatMulToFullyConnected::ConvertMatMulToFullyConnected(bool supports_immad
     };
     auto weights_path = [&static_rank_gt_1](const ov::Output<ov::Node>& output) {
         const auto& pshape = output.get_partial_shape();
-        return ov::op::util::is_on_constant_path(output) &&
+        return ov::op::util::is_on_constant_or_param_path(output) &&
               static_rank_gt_1(output) &&
               pshape.is_static();
    };
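A hypothetical pattern input built from the relaxed predicate, mirroring the lambda above (illustrative only):

    auto weights_m = ov::pass::pattern::any_input([](const ov::Output<ov::Node>& output) {
        // Accept weights reachable from Constants or Parameters, with a static shape.
        return ov::op::util::is_on_constant_or_param_path(output) &&
               output.get_partial_shape().is_static();
    });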
