Skip to content

[GPU][POC] clDNN gemv optimization for LLM second token #28976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c6a8d8d
[GPU][MTL]support gemv
riverlijunjie Feb 10, 2025
c4d9e84
Merge branch 'master' into river/cldnn_gemv_opt
riverlijunjie Feb 10, 2025
1b03afd
Fix accuracy issue
riverlijunjie Feb 13, 2025
a646496
Fix unit test failures
riverlijunjie Feb 15, 2025
0e0786e
Merge branch 'master' into river/cldnn_gemv_opt
riverlijunjie Feb 16, 2025
8d02fe5
Kernel refine
riverlijunjie Feb 19, 2025
668b480
Merge branch 'master' into river/cldnn_gemv_opt
riverlijunjie Feb 19, 2025
8c3909d
Merge branch 'master' into river/cldnn_gemv_opt
riverlijunjie Feb 19, 2025
d1d7c53
Merge branch 'master' into river/cldnn_gemv_opt
peterchen-intel Feb 21, 2025
17d1e29
Fix zp issue
riverlijunjie Feb 21, 2025
5bbe1f9
Fixed INT4_CW issue
riverlijunjie Feb 23, 2025
d9db6bf
Fix USS issue caused by kernel cache
riverlijunjie Feb 25, 2025
870ae43
Fix activation issue
riverlijunjie Mar 2, 2025
509ccb0
Update kernel data type
riverlijunjie Mar 6, 2025
f5b40ec
Merge branch 'master' into river/cldnn_gemv_opt
riverlijunjie Mar 6, 2025
80357ed
Fix gws mismatch lws if output size less than 16
riverlijunjie Mar 7, 2025
2e8998e
update: move judgement logic to fully_connected_inst
zhaixuejun1993 Mar 18, 2025
00103d8
Merge branch 'master' into river/cldnn_gemv_opt
zhaixuejun1993 Mar 18, 2025
30198d0
Fix zp out of memory access issue
riverlijunjie Mar 19, 2025
4f86428
Resolve review comments
riverlijunjie Mar 23, 2025
25e55c2
Update missed file
riverlijunjie Mar 24, 2025
1f178c0
Remove unnecessary kernel switch logic
riverlijunjie Mar 26, 2025
3bbd10b
Revert for dynamic shape
riverlijunjie Mar 27, 2025
d9a70bb
Solve dynamic shape kernel selector issue and remove debug code
riverlijunjie Mar 28, 2025
3b2973e
Update KernelsPriority
riverlijunjie Apr 3, 2025
9e6cf5a
Merge branch 'master' into river/cldnn_gemv_opt
riverlijunjie Apr 3, 2025
7383356
Replace single_batch_optimized field with priority configuration
riverlijunjie Apr 6, 2025
3060393
Merge branch 'master' into river/cldnn_gemv_opt
peterchen-intel Apr 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/plugins/intel_gpu/src/graph/fully_connected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ std::vector<layout> fully_connected_inst::calc_output_layouts(fully_connected_no
}

kernel_impl_params fully_connected_inst::get_fake_aligned_params(kernel_impl_params const& orig_impl_param) {
if (can_apply_single_batch_optimization(orig_impl_param)) {
return std::move(orig_impl_param);
}

// fc_tiled_opt kernel is optimized for row shape aligned by 8.
// Thus, use fake aligned shape at kernel execution for better performance.
const auto& orig_input_layout = orig_impl_param.get_input_layout();
Expand Down Expand Up @@ -326,6 +330,32 @@ std::string fully_connected_inst::to_string(fully_connected_node const& node) {
return primitive_description.str();
}

// Decides whether the single-batch (gemv-style) optimized path can be used for
// this fully-connected call.
//
// All of the following must hold:
//  - the output layout is known and static (we must be able to inspect its shape);
//  - the weights, if present, are i4/u4 (only compressed int4 weights are
//    supported by the optimized kernel so far);
//  - no swiglu op is fused into this primitive;
//  - the output is effectively single-batch: the total element count equals one
//    of the dims (i.e. every other dim is 1), and that count is a multiple of 16
//    — presumably a gws/lws alignment requirement of the kernel; confirm against
//    the gemv kernel's local work size.
//
// @param impl_param  kernel implementation parameters of the FC primitive
// @return true when the optimized single-batch path is applicable
bool fully_connected_inst::can_apply_single_batch_optimization(const kernel_impl_params& impl_param) {
    // A dynamic (or missing) output layout cannot be checked for single-batch shape.
    if (impl_param.output_layouts.empty() || impl_param.output_layouts[0].is_dynamic())
        return false;

    // Only i4/u4 weights are supported so far.
    if (impl_param.weights_layout) {
        const auto weights_dt = impl_param.weights_layout.value().data_type;
        if (weights_dt != data_types::i4 && weights_dt != data_types::u4)
            return false;
    }

    // The optimized path doesn't support a fused swiglu.
    // (Iterating an empty fused_desc is a no-op, so no explicit size() guard is needed.)
    for (const auto& fused_op : impl_param.fused_desc) {
        if (fused_op.is_type<swiglu>())
            return false;
    }

    // Single batch: the total element count must coincide with one of the dims
    // (all remaining dims are 1) and be divisible by 16.
    const auto out_shape = impl_param.output_layouts[0].get_partial_shape().to_shape();
    const auto total_elements = ov::shape_size(out_shape);
    return one_of(total_elements, out_shape) && (total_elements % 16 == 0);
}

// Constructs the fully-connected primitive instance; all initialization is
// delegated to the base (parent) primitive-instance class.
fully_connected_inst::typed_primitive_inst(network& network, fully_connected_node const& node)
: parent(network, node) { }
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
if (with_zp) {
params.has_decompression_zp = true;
params.decompression_zero_point = convert_data_tensor(updated_impl_param.input_layouts[3]);
if (updated_impl_param.input_layouts[3].get_linear_size() == 1 &&
primitive->decompression_zero_point_scalar.has_value()) {
params.scalar_zp = true;
params.zp_value = primitive->decompression_zero_point_scalar.value();
}
} else if (primitive->decompression_zero_point_scalar.has_value()) {
params.has_decompression_zp = true;
params.scalar_zp = true;
Expand All @@ -203,7 +208,8 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
params.quantization = kernel_selector::QuantizationType::NONE;
}

params.dynamic_quantization_group_size = impl_param.get_program().get_config().get_dynamic_quantization_group_size();
params.dynamic_quantization_group_size =
impl_param.get_program().get_config().get_dynamic_quantization_group_size();

return params;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class typed_primitive_inst<fully_connected> : public typed_primitive_inst_base<f
static layout calc_output_layout(fully_connected_node const& node, kernel_impl_params const& impl_param);
static kernel_impl_params get_fake_aligned_params(kernel_impl_params const& orig_impl_param);
static std::string to_string(fully_connected_node const& node);
static bool can_apply_single_batch_optimization(const kernel_impl_params& impl_param);

typed_primitive_inst(network& network, fully_connected_node const& node);

Expand Down
Loading
Loading