@@ -11,25 +11,21 @@ using namespace ::onnxruntime::common;
11
11
namespace onnxruntime {
12
12
13
13
// Appends every element of `weight` (an N x blocks x block_size slab laid out
// contiguously) to the end of `result`.
template <typename T>
void MergeRows(const T* weight, std::vector<T>& result, int64_t N, int64_t blocks, int64_t block_size) {
  // Name the half-open element range before inserting, for readability.
  const T* const src_begin = weight;
  const T* const src_end = weight + N * blocks * block_size;
  result.insert(result.end(), src_begin, src_end);
}
18
17
19
18
// Concatenates the Q, K and V weight tensors into `result` in the order
// [N rows of Q, N rows of K, N rows of V]. Each tensor contributes
// N * blocks * block_size contiguous elements.
template <typename T>
void MergeMatMulWeights(const T* q_weight, const T* k_weight, const T* v_weight,
                        std::vector<T>& result, int64_t N, int64_t blocks, int64_t block_size) {
  // Each of the three tensors has the same element count.
  const int64_t per_tensor = N * blocks * block_size;
  const T* sources[3] = {q_weight, k_weight, v_weight};
  for (const T* src : sources) {
    result.insert(result.end(), src, src + per_tensor);
  }
}
31
27
32
- void AttackDoRotaryAttribute (Node& node) {
28
+ void AttachDoRotaryAttribute (Node& node) {
33
29
NodeAttributes& node_attributes = node.GetMutableAttributes ();
34
30
ONNX_NAMESPACE::AttributeProto attr;
35
31
attr.set_name (" do_rotary" );
@@ -38,10 +34,23 @@ void AttackDoRotaryAttribute(Node& node) {
38
34
node_attributes[" do_rotary" ] = attr;
39
35
}
40
36
37
+ NodeAttributes AttachMmnbAttributes (Node* node) {
38
+ NodeAttributes mmnb_node_atributes = node->GetAttributes ();
39
+ ONNX_NAMESPACE::AttributeProto mmnb_N_attr_proto;
40
+
41
+ // N needs to be multiplied by 3.
42
+ mmnb_N_attr_proto.set_name (" N" );
43
+ mmnb_N_attr_proto.set_type (mmnb_node_atributes[" N" ].type ());
44
+ mmnb_N_attr_proto.set_i (3 * mmnb_node_atributes[" N" ].i ());
45
+ mmnb_node_atributes[" N" ] = mmnb_N_attr_proto;
46
+
47
+ return mmnb_node_atributes;
48
+ }
49
+
41
50
std::array<NodeArg*, 3 > MergeQkvWeights (Graph& graph,
42
- int64_t input_dim, // For example, 3072
43
- int64_t qkv_second_dim, // Number of query heads, e.g., 24
44
- int64_t qkv_third_dim, // Size per head, e.g., 64
51
+ int64_t input_dim, // For example, 3072
52
+ int64_t blocks, // Number of query heads, e.g., 24
53
+ int64_t block_size, // Size per head, e.g., 64
45
54
const ONNX_NAMESPACE::TensorProto* q_tensor,
46
55
const ONNX_NAMESPACE::TensorProto* k_tensor,
47
56
const ONNX_NAMESPACE::TensorProto* v_tensor,
@@ -71,19 +80,25 @@ std::array<NodeArg*, 3> MergeQkvWeights(Graph& graph,
71
80
ONNX_NAMESPACE::TensorProto qkv_b_initializer;
72
81
qkv_b_initializer.set_name (graph.GenerateNodeArgName (" qkv_B" ));
73
82
qkv_b_initializer.add_dims (3 * input_dim);
74
- qkv_b_initializer.add_dims (qkv_second_dim );
75
- qkv_b_initializer.add_dims (qkv_third_dim );
83
+ qkv_b_initializer.add_dims (blocks );
84
+ qkv_b_initializer.add_dims (block_size );
76
85
qkv_b_initializer.set_data_type (q_tensor->data_type ());
77
86
78
87
ONNX_NAMESPACE::TensorProto qkv_scale_initializer;
79
88
qkv_scale_initializer.set_name (graph.GenerateNodeArgName (" qkv_scale" ));
80
- qkv_scale_initializer.add_dims (3 * input_dim);
81
- qkv_scale_initializer.add_dims (qkv_second_dim);
89
+ // Preserve the original tensor dimensionality.
90
+ if (q_scale_tensor->dims ().size () > 1 ) {
91
+ qkv_scale_initializer.add_dims (3 * input_dim);
92
+ qkv_scale_initializer.add_dims (blocks);
93
+ } else {
94
+ qkv_scale_initializer.add_dims (3 * input_dim * blocks);
95
+ }
96
+
82
97
qkv_scale_initializer.set_data_type (q_scale_tensor->data_type ());
83
98
84
99
ONNX_NAMESPACE::TensorProto qkv_zp_initializer;
85
100
qkv_zp_initializer.set_name (graph.GenerateNodeArgName (" qkv_zp" ));
86
- qkv_zp_initializer.add_dims (3 * input_dim * qkv_second_dim / 2 );
101
+ qkv_zp_initializer.add_dims (3 * input_dim * blocks / 2 );
87
102
qkv_zp_initializer.set_data_type (q_zero_point_tensor->data_type ());
88
103
89
104
size_t q_elements = q_initializer.size ();
@@ -102,21 +117,21 @@ std::array<NodeArg*, 3> MergeQkvWeights(Graph& graph,
102
117
const MLFloat16* v_scale_data = v_scale_initializer.data <MLFloat16>();
103
118
const uint8_t * v_zero_points_data = v_zp_initializer.data <uint8_t >();
104
119
105
- size_t scale_elements_count = 3 * input_dim * qkv_second_dim ;
120
+ size_t scale_elements_count = 3 * input_dim * blocks ;
106
121
std::vector<MLFloat16> merged_qkv_scale;
107
122
merged_qkv_scale.reserve (scale_elements_count);
108
123
109
- size_t zp_elements_count = 3 * input_dim * qkv_second_dim / 2 ;
124
+ size_t zp_elements_count = 3 * input_dim * blocks / 2 ;
110
125
std::vector<uint8_t > merged_qkv_zp;
111
126
merged_qkv_zp.reserve (zp_elements_count);
112
127
113
128
size_t element_count = q_elements + k_elements + v_elements;
114
129
std::vector<uint8_t > merged_qkv_B;
115
130
merged_qkv_B.reserve (element_count);
116
131
117
- MergeMatMulWeights (q_data, k_data, v_data, merged_qkv_B, input_dim, qkv_second_dim, qkv_third_dim );
118
- MergeMatMulWeights (q_scale_data, k_scale_data, v_scale_data, merged_qkv_scale, input_dim, qkv_second_dim , 1 );
119
- MergeMatMulWeights (q_zero_points_data, k_zero_points_data, v_zero_points_data, merged_qkv_zp, input_dim, qkv_second_dim / 2 , 1 );
132
+ MergeMatMulWeights (q_data, k_data, v_data, merged_qkv_B, input_dim, blocks, block_size );
133
+ MergeMatMulWeights (q_scale_data, k_scale_data, v_scale_data, merged_qkv_scale, input_dim, blocks , 1 );
134
+ MergeMatMulWeights (q_zero_points_data, k_zero_points_data, v_zero_points_data, merged_qkv_zp, input_dim, blocks / 2 , 1 );
120
135
121
136
utils::SetRawDataInTensorProto (qkv_b_initializer, merged_qkv_B.data (), gsl::narrow<size_t >(element_count) * sizeof (uint8_t ));
122
137
utils::SetRawDataInTensorProto (qkv_scale_initializer, merged_qkv_scale.data (), gsl::narrow<size_t >(scale_elements_count) * sizeof (MLFloat16));
@@ -344,12 +359,7 @@ Status GroupQueryAttentionFusion::ApplyImpl(
344
359
345
360
const std::array mmnb_output_defs{&matmul_output};
346
361
347
- NodeAttributes mmnb_node_atributes = q_node->GetAttributes ();
348
- ONNX_NAMESPACE::AttributeProto mmnb_N_attr_proto;
349
- mmnb_N_attr_proto.set_name (" N" );
350
- mmnb_N_attr_proto.set_type (mmnb_node_atributes[" N" ].type ());
351
- mmnb_N_attr_proto.set_i (3 * mmnb_node_atributes[" N" ].i ());
352
- mmnb_node_atributes[" N" ] = mmnb_N_attr_proto;
362
+ NodeAttributes mmnb_node_atributes = AttachMmnbAttributes (q_node);
353
363
354
364
// Add MatMulNBits
355
365
Node& mat_mul_n_bits_new_node = graph.AddNode (graph.GenerateNodeName (" MatMulNBits" ),
@@ -376,32 +386,18 @@ Status GroupQueryAttentionFusion::ApplyImpl(
376
386
cos_cache_arg,
377
387
sin_cache_arg,
378
388
pos_ids_arg};
379
- /*
380
- // Now add the fused GroupQueryAttention node.
381
- Node& gqa_node = graph.AddNode(graph.GenerateNodeName("GroupQueryAttention"),
382
- "GroupQueryAttention",
383
- "Fused GroupQueryAttention subgraphs",
384
- gqa_input_defs,
385
- {},
386
- &node_attributes,
387
- kMSDomain);
388
-
389
- */
390
389
391
390
// TODO: screws up definition for output
392
391
// ORT_RETURN_IF_ERROR(graph.Resolve());
393
392
394
- [[maybe_unused]] auto producer2 = graph.GetConsumerNodes (matmul_output.Name ());
395
- [[maybe_unused]] auto producer3 = graph.GetProducerNode (matmul_output.Name ());
396
-
397
393
graph_utils::FinalizeNodeFusion (graph, {*q_node, *k_node, *v_node, *rotary_node_1, *rotary_node_2}, mat_mul_n_bits_new_node);
398
394
399
395
auto & mat_mut_output_defs = mat_mul_n_bits_new_node.MutableOutputDefs ();
400
396
mat_mut_output_defs.assign (mmnb_output_defs.begin (), mmnb_output_defs.end ());
401
397
402
398
[[maybe_unused]] const onnxruntime::Node* producer = graph.GetProducerNode (matmul_output.Name ());
403
399
404
- AttackDoRotaryAttribute (node);
400
+ AttachDoRotaryAttribute (node);
405
401
406
402
auto & gqaInputArgs = node.MutableInputArgsCount ();
407
403
gqaInputArgs[7 ] = 1 ;
@@ -415,6 +411,8 @@ Status GroupQueryAttentionFusion::ApplyImpl(
415
411
416
412
ORT_RETURN_IF_ERROR (graph.Resolve ());
417
413
414
+ modified = true ;
415
+
418
416
[[maybe_unused]] const onnxruntime::Node* producer44 = graph.GetProducerNode (matmul_output.Name ());
419
417
420
418
// graph_utils::FinalizeNodeFusion(graph, {node}, node);
0 commit comments