Skip to content

Commit 70354b6

Browse files
committed
Variable name changes
1 parent f76b3ad commit 70354b6

File tree

1 file changed

+43
-45
lines changed

1 file changed

+43
-45
lines changed

onnxruntime/core/optimizer/group_query_attention_fusion.cc

+43-45
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,21 @@ using namespace ::onnxruntime::common;
1111
namespace onnxruntime {
1212

1313
template <typename T>
14-
void MergeRows(const T* weight, std::vector<T>& result, int64_t first_dim, int64_t second_dim, int64_t third_dim) {
15-
// Insert row_count * col_count elements starting at weight.
16-
result.insert(result.end(), weight, weight + first_dim * second_dim * third_dim);
14+
void MergeRows(const T* weight, std::vector<T>& result, int64_t N, int64_t blocks, int64_t block_size) {
15+
result.insert(result.end(), weight, weight + N * blocks * block_size);
1716
}
1817

1918
// Merge Q, K, V tensors into a single vector in the order:
2019
// [N rows of Q, N rows of K, N rows of V].
2120
template <typename T>
2221
void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight,
23-
std::vector<T>& result, int64_t first_dim, int64_t second_dim, int64_t third_dim) {
24-
// Merge all rows of Q.
25-
MergeRows(q_weight, result, first_dim, second_dim, third_dim);
26-
// Merge all rows of K.
27-
MergeRows(k_weight, result, first_dim, second_dim, third_dim);
28-
// Merge all rows of V.
29-
MergeRows(v_weight, result, first_dim, second_dim, third_dim);
22+
std::vector<T>& result, int64_t N, int64_t blocks, int64_t block_size) {
23+
MergeRows(q_weight, result, N, blocks, block_size);
24+
MergeRows(k_weight, result, N, blocks, block_size);
25+
MergeRows(v_weight, result, N, blocks, block_size);
3026
}
3127

32-
void AttackDoRotaryAttribute(Node& node) {
28+
void AttachDoRotaryAttribute(Node& node) {
3329
NodeAttributes& node_attributes = node.GetMutableAttributes();
3430
ONNX_NAMESPACE::AttributeProto attr;
3531
attr.set_name("do_rotary");
@@ -38,10 +34,23 @@ void AttackDoRotaryAttribute(Node& node) {
3834
node_attributes["do_rotary"] = attr;
3935
}
4036

37+
NodeAttributes AttachMmnbAttributes(Node* node) {
38+
NodeAttributes mmnb_node_atributes = node->GetAttributes();
39+
ONNX_NAMESPACE::AttributeProto mmnb_N_attr_proto;
40+
41+
// N needs to be multiplied by 3.
42+
mmnb_N_attr_proto.set_name("N");
43+
mmnb_N_attr_proto.set_type(mmnb_node_atributes["N"].type());
44+
mmnb_N_attr_proto.set_i(3 * mmnb_node_atributes["N"].i());
45+
mmnb_node_atributes["N"] = mmnb_N_attr_proto;
46+
47+
return mmnb_node_atributes;
48+
}
49+
4150
std::array<NodeArg*, 3> MergeQkvWeights(Graph& graph,
42-
int64_t input_dim, // For example, 3072
43-
int64_t qkv_second_dim, // Number of query heads, e.g., 24
44-
int64_t qkv_third_dim, // Size per head, e.g., 64
51+
int64_t input_dim, // For example, 3072
52+
int64_t blocks, // Number of query heads, e.g., 24
53+
int64_t block_size, // Size per head, e.g., 64
4554
const ONNX_NAMESPACE::TensorProto* q_tensor,
4655
const ONNX_NAMESPACE::TensorProto* k_tensor,
4756
const ONNX_NAMESPACE::TensorProto* v_tensor,
@@ -71,19 +80,25 @@ std::array<NodeArg*, 3> MergeQkvWeights(Graph& graph,
7180
ONNX_NAMESPACE::TensorProto qkv_b_initializer;
7281
qkv_b_initializer.set_name(graph.GenerateNodeArgName("qkv_B"));
7382
qkv_b_initializer.add_dims(3 * input_dim);
74-
qkv_b_initializer.add_dims(qkv_second_dim);
75-
qkv_b_initializer.add_dims(qkv_third_dim);
83+
qkv_b_initializer.add_dims(blocks);
84+
qkv_b_initializer.add_dims(block_size);
7685
qkv_b_initializer.set_data_type(q_tensor->data_type());
7786

7887
ONNX_NAMESPACE::TensorProto qkv_scale_initializer;
7988
qkv_scale_initializer.set_name(graph.GenerateNodeArgName("qkv_scale"));
80-
qkv_scale_initializer.add_dims(3 * input_dim);
81-
qkv_scale_initializer.add_dims(qkv_second_dim);
89+
// Preserve the original tensor dimensionality.
90+
if (q_scale_tensor->dims().size() > 1) {
91+
qkv_scale_initializer.add_dims(3 * input_dim);
92+
qkv_scale_initializer.add_dims(blocks);
93+
} else {
94+
qkv_scale_initializer.add_dims(3 * input_dim * blocks);
95+
}
96+
8297
qkv_scale_initializer.set_data_type(q_scale_tensor->data_type());
8398

8499
ONNX_NAMESPACE::TensorProto qkv_zp_initializer;
85100
qkv_zp_initializer.set_name(graph.GenerateNodeArgName("qkv_zp"));
86-
qkv_zp_initializer.add_dims(3 * input_dim * qkv_second_dim / 2);
101+
qkv_zp_initializer.add_dims(3 * input_dim * blocks / 2);
87102
qkv_zp_initializer.set_data_type(q_zero_point_tensor->data_type());
88103

89104
size_t q_elements = q_initializer.size();
@@ -102,21 +117,21 @@ std::array<NodeArg*, 3> MergeQkvWeights(Graph& graph,
102117
const MLFloat16* v_scale_data = v_scale_initializer.data<MLFloat16>();
103118
const uint8_t* v_zero_points_data = v_zp_initializer.data<uint8_t>();
104119

105-
size_t scale_elements_count = 3 * input_dim * qkv_second_dim;
120+
size_t scale_elements_count = 3 * input_dim * blocks;
106121
std::vector<MLFloat16> merged_qkv_scale;
107122
merged_qkv_scale.reserve(scale_elements_count);
108123

109-
size_t zp_elements_count = 3 * input_dim * qkv_second_dim / 2;
124+
size_t zp_elements_count = 3 * input_dim * blocks / 2;
110125
std::vector<uint8_t> merged_qkv_zp;
111126
merged_qkv_zp.reserve(zp_elements_count);
112127

113128
size_t element_count = q_elements + k_elements + v_elements;
114129
std::vector<uint8_t> merged_qkv_B;
115130
merged_qkv_B.reserve(element_count);
116131

117-
MergeMatMulWeights(q_data, k_data, v_data, merged_qkv_B, input_dim, qkv_second_dim, qkv_third_dim);
118-
MergeMatMulWeights(q_scale_data, k_scale_data, v_scale_data, merged_qkv_scale, input_dim, qkv_second_dim, 1);
119-
MergeMatMulWeights(q_zero_points_data, k_zero_points_data, v_zero_points_data, merged_qkv_zp, input_dim, qkv_second_dim / 2, 1);
132+
MergeMatMulWeights(q_data, k_data, v_data, merged_qkv_B, input_dim, blocks, block_size);
133+
MergeMatMulWeights(q_scale_data, k_scale_data, v_scale_data, merged_qkv_scale, input_dim, blocks, 1);
134+
MergeMatMulWeights(q_zero_points_data, k_zero_points_data, v_zero_points_data, merged_qkv_zp, input_dim, blocks / 2, 1);
120135

121136
utils::SetRawDataInTensorProto(qkv_b_initializer, merged_qkv_B.data(), gsl::narrow<size_t>(element_count) * sizeof(uint8_t));
122137
utils::SetRawDataInTensorProto(qkv_scale_initializer, merged_qkv_scale.data(), gsl::narrow<size_t>(scale_elements_count) * sizeof(MLFloat16));
@@ -344,12 +359,7 @@ Status GroupQueryAttentionFusion::ApplyImpl(
344359

345360
const std::array mmnb_output_defs{&matmul_output};
346361

347-
NodeAttributes mmnb_node_atributes = q_node->GetAttributes();
348-
ONNX_NAMESPACE::AttributeProto mmnb_N_attr_proto;
349-
mmnb_N_attr_proto.set_name("N");
350-
mmnb_N_attr_proto.set_type(mmnb_node_atributes["N"].type());
351-
mmnb_N_attr_proto.set_i(3 * mmnb_node_atributes["N"].i());
352-
mmnb_node_atributes["N"] = mmnb_N_attr_proto;
362+
NodeAttributes mmnb_node_atributes = AttachMmnbAttributes(q_node);
353363

354364
// Add MatMulNBits
355365
Node& mat_mul_n_bits_new_node = graph.AddNode(graph.GenerateNodeName("MatMulNBits"),
@@ -376,32 +386,18 @@ Status GroupQueryAttentionFusion::ApplyImpl(
376386
cos_cache_arg,
377387
sin_cache_arg,
378388
pos_ids_arg};
379-
/*
380-
// Now add the fused GroupQueryAttention node.
381-
Node& gqa_node = graph.AddNode(graph.GenerateNodeName("GroupQueryAttention"),
382-
"GroupQueryAttention",
383-
"Fused GroupQueryAttention subgraphs",
384-
gqa_input_defs,
385-
{},
386-
&node_attributes,
387-
kMSDomain);
388-
389-
*/
390389

391390
// TODO: screws up definition for output
392391
// ORT_RETURN_IF_ERROR(graph.Resolve());
393392

394-
[[maybe_unused]] auto producer2 = graph.GetConsumerNodes(matmul_output.Name());
395-
[[maybe_unused]] auto producer3 = graph.GetProducerNode(matmul_output.Name());
396-
397393
graph_utils::FinalizeNodeFusion(graph, {*q_node, *k_node, *v_node, *rotary_node_1, *rotary_node_2}, mat_mul_n_bits_new_node);
398394

399395
auto& mat_mut_output_defs = mat_mul_n_bits_new_node.MutableOutputDefs();
400396
mat_mut_output_defs.assign(mmnb_output_defs.begin(), mmnb_output_defs.end());
401397

402398
[[maybe_unused]] const onnxruntime::Node* producer = graph.GetProducerNode(matmul_output.Name());
403399

404-
AttackDoRotaryAttribute(node);
400+
AttachDoRotaryAttribute(node);
405401

406402
auto& gqaInputArgs = node.MutableInputArgsCount();
407403
gqaInputArgs[7] = 1;
@@ -415,6 +411,8 @@ Status GroupQueryAttentionFusion::ApplyImpl(
415411

416412
ORT_RETURN_IF_ERROR(graph.Resolve());
417413

414+
modified = true;
415+
418416
[[maybe_unused]] const onnxruntime::Node* producer44 = graph.GetProducerNode(matmul_output.Name());
419417

420418
// graph_utils::FinalizeNodeFusion(graph, {node}, node);

0 commit comments

Comments
 (0)