[GPU][POC] clDNN gemv optimization for LLM second token #28976

Open. Wants to merge 24 commits into base: master.

Diff shown for changes from 21 of the 24 commits.

Commits (24)
c6a8d8d  [GPU][MTL]support gemv (riverlijunjie, Feb 10, 2025)
c4d9e84  Merge branch 'master' into river/cldnn_gemv_opt (riverlijunjie, Feb 10, 2025)
1b03afd  Fix accuracy issue (riverlijunjie, Feb 13, 2025)
a646496  Fix unit test failures (riverlijunjie, Feb 15, 2025)
0e0786e  Merge branch 'master' into river/cldnn_gemv_opt (riverlijunjie, Feb 16, 2025)
8d02fe5  Kernel refine (riverlijunjie, Feb 19, 2025)
668b480  Merge branch 'master' into river/cldnn_gemv_opt (riverlijunjie, Feb 19, 2025)
8c3909d  Merge branch 'master' into river/cldnn_gemv_opt (riverlijunjie, Feb 19, 2025)
d1d7c53  Merge branch 'master' into river/cldnn_gemv_opt (peterchen-intel, Feb 21, 2025)
17d1e29  Fix zp issue (riverlijunjie, Feb 21, 2025)
5bbe1f9  Fixed INT4_CW issue (riverlijunjie, Feb 23, 2025)
d9db6bf  Fix USS issue caused by kernel cache (riverlijunjie, Feb 25, 2025)
870ae43  Fix activation issue (riverlijunjie, Mar 2, 2025)
509ccb0  Update kernel data type (riverlijunjie, Mar 6, 2025)
f5b40ec  Merge branch 'master' into river/cldnn_gemv_opt (riverlijunjie, Mar 6, 2025)
80357ed  Fix gws mismatch lws if output size less than 16 (riverlijunjie, Mar 7, 2025)
2e8998e  update: move judgement logic to fully_connected_inst (zhaixuejun1993, Mar 18, 2025)
00103d8  Merge branch 'master' into river/cldnn_gemv_opt (zhaixuejun1993, Mar 18, 2025)
30198d0  Fix zp out of memory access issue (riverlijunjie, Mar 19, 2025)
4f86428  Resolve review comments (riverlijunjie, Mar 23, 2025)
25e55c2  Update missed file (riverlijunjie, Mar 24, 2025)
1f178c0  Remove unnecessary kernel switch logic (riverlijunjie, Mar 26, 2025)
3bbd10b  Revert for dynamic shape (riverlijunjie, Mar 27, 2025)
d9a70bb  Solve dynamic shape kernel selector issue and remove debug code (riverlijunjie, Mar 28, 2025)
30 changes: 30 additions & 0 deletions src/plugins/intel_gpu/src/graph/fully_connected.cpp
@@ -214,6 +214,10 @@ std::vector<layout> fully_connected_inst::calc_output_layouts(fully_connected_no
}

kernel_impl_params fully_connected_inst::get_fake_aligned_params(kernel_impl_params const& orig_impl_param) {
if (can_apply_single_batch_optimization(orig_impl_param)) {
return std::move(orig_impl_param);
}

// fc_tiled_opt kernel is optimized for row shape aligned by 8.
// Thus, use fake aligned shape at kernel execution for better performance.
const auto& orig_input_layout = orig_impl_param.get_input_layout();
@@ -326,6 +330,32 @@ std::string fully_connected_inst::to_string(fully_connected_node const& node) {
return primitive_description.str();
}

bool fully_connected_inst::can_apply_single_batch_optimization(const kernel_impl_params& impl_param) {
if ((impl_param.output_layouts.size() == 0) || impl_param.output_layouts[0].is_dynamic())
return false;

// Only support i4/u4 weight so far
if (impl_param.weights_layout) {
auto weights_layout_dt = impl_param.weights_layout.value().data_type;
if (weights_layout_dt != data_types::i4 && weights_layout_dt != data_types::u4) {
return false;
}
}

// Don't support swiglu fused
if (impl_param.fused_desc.size() > 0) {
for (const auto& f : impl_param.fused_desc) {
if (f.is_type<swiglu>())
return false;
}
}

// Single batch
auto shape = impl_param.output_layouts[0].get_partial_shape().to_shape();
auto shape_size = ov::shape_size(shape);
return one_of(shape_size, shape) && (shape_size % 16 == 0);
}

fully_connected_inst::typed_primitive_inst(network& network, fully_connected_node const& node)
: parent(network, node) { }
} // namespace cldnn
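For readers skimming the diff: the return expression in can_apply_single_batch_optimization() packs the whole gating condition into one line. The following standalone C++ restatement is a minimal sketch; the helper name is hypothetical, and the remark about sub-group granularity is an assumption inferred from the % 16 check and the "Fix gws mismatch lws" commit.

#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the PR: restates the shape condition from
// can_apply_single_batch_optimization() in isolation.
bool is_single_batch_gemv_shape(const std::vector<std::size_t>& shape) {
    // Total number of output elements.
    std::size_t total = 1;
    for (std::size_t d : shape)
        total *= d;

    // "Single batch": the element count equals one of the dimensions, meaning
    // every other dimension is 1 (e.g. [1, 1, 4096]); the FC output is then a
    // single row and the matmul degenerates to a gemv.
    bool single_batch = false;
    for (std::size_t d : shape)
        single_batch = single_batch || (d == total);

    // The output size must also be a multiple of 16, presumably to match the
    // kernel's work-group / sub-group granularity.
    return single_batch && (total % 16 == 0);
}

Together with the earlier checks, the optimization therefore only applies to static-shape FCs with i4/u4 weights, no fused swiglu, and an output that collapses to a single row whose length is divisible by 16.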
@@ -183,6 +183,11 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
if (with_zp) {
params.has_decompression_zp = true;
params.decompression_zero_point = convert_data_tensor(updated_impl_param.input_layouts[3]);
if (updated_impl_param.input_layouts[3].get_linear_size() == 1 &&
primitive->decompression_zero_point_scalar.has_value()) {
params.scalar_zp = true;
params.zp_value = primitive->decompression_zero_point_scalar.value();
}
} else if (primitive->decompression_zero_point_scalar.has_value()) {
params.has_decompression_zp = true;
params.scalar_zp = true;
@@ -203,7 +208,9 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
params.quantization = kernel_selector::QuantizationType::NONE;
}

params._single_batch_optimized = fully_connected_inst::can_apply_single_batch_optimization(updated_impl_param);
params.dynamic_quantization_group_size =
impl_param.get_program().get_config().get_dynamic_quantization_group_size();

return params;
}
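For context on the scalar_zp branch above: it applies when the decompression zero-point tensor holds exactly one element, so its value can be passed as a plain scalar (params.zp_value) rather than as a buffer. Assuming the usual (w - zp) * scale weight-decompression scheme for compressed FC weights, the simplification looks roughly like this host-side sketch (illustrative names only, not the kernel's actual code):

#include <cstdint>

// Dequantize one 4-bit quantized weight value given its scale and zero point.
inline float dequantize_u4(std::uint8_t w_q, float scale, float zp) {
    return (static_cast<float>(w_q) - zp) * scale;
}

// When params.scalar_zp == true, 'zp' is the single value carried in
// params.zp_value, so the kernel never has to index the zero-point tensor
// (input_layouts[3]); otherwise zp would be looked up per decompression group.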
@@ -49,6 +49,7 @@ class typed_primitive_inst<fully_connected> : public typed_primitive_inst_base<f
static layout calc_output_layout(fully_connected_node const& node, kernel_impl_params const& impl_param);
static kernel_impl_params get_fake_aligned_params(kernel_impl_params const& orig_impl_param);
static std::string to_string(fully_connected_node const& node);
static bool can_apply_single_batch_optimization(const kernel_impl_params& impl_param);

typed_primitive_inst(network& network, fully_connected_node const& node);

14 changes: 14 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -2868,9 +2868,23 @@ std::shared_ptr<primitive_impl> ImplementationsFactory::get_primitive_impl_for_p
});
}

auto need_single_batch_optimization = [&inst, &updated_params](const std::shared_ptr<primitive_impl> impl) -> bool {
auto is_cldnn_fc_impl = inst.get_node().get_preferred_impl_type() == impl_types::ocl;
auto kernel_name = impl->get_kernel_name();
// Avoid ref_kernel test issue.
auto is_ref_impl = kernel_name.find("fully_connected_gpu_bfyx_ref") != std::string::npos;
auto is_gemv_impl = kernel_name.find("gemv") != std::string::npos;
return is_cldnn_fc_impl && fully_connected_inst::can_apply_single_batch_optimization(updated_params) &&
!is_ref_impl && !is_gemv_impl;
};

std::shared_ptr<primitive_impl> dynamic_impl = nullptr;
// 2. Try to find existing dynamic impl which supports given shapes
for (auto& impl : m_dynamic_impls_cache) {
if (inst.get_node().is_type<fully_connected>() && need_single_batch_optimization(impl)) {
// Switch to single batch optimization.
continue;
}
Review thread on this block:

Contributor: Do we still need this logic? The optimized static impl should already be selected at step 1), before this point.

Contributor (author): It seems we need it to switch to the gemv kernel for the second token; let's double-check the details.

Contributor (author), @riverlijunjie, Mar 26, 2025: Confirmed that an LLM model runs fine without this logic, but in the dynamic-shape case it picks the fc_bf_tiled kernel rather than the gemv kernel for single-batch input. @sshlyapn, is there a better way to solve this?

Contributor: Please try setting the priority value in GetKernelsPriority() lower than that of the bf_tiled kernel, something like FORCE_PRIORITY_3.

Contributor (author): That doesn't seem to work. The gemv impl only applies to single-batch input, and for dynamic shapes the input batch is not known before the FC impl is chosen, so fc_bf_tiled is selected first. Once the input shape is set there is no chance to re-choose an FC impl, so we have to add the logic above to allow the impl to be re-selected.

Contributor (author): Thanks @sshlyapn for the great help in solving the dynamic-shape issue!

if (impl->m_manager->support_shapes(params)) {
dynamic_impl = impl;
break;
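The reviewer's first suggestion, tuning GetKernelsPriority(), would look roughly like the sketch below. The class name is hypothetical and the snippet presupposes the kernel_selector framework types (KernelsPriority, Params, FORCE_PRIORITY_3); per the follow-up comments, this tuning alone was not enough for dynamic shapes, because the impl is chosen before the runtime batch size is known.

// Hypothetical kernel class name; sketch of a priority override only.
KernelsPriority FullyConnected_GEMV::GetKernelsPriority(const Params& /*params*/) const {
    // FORCE_PRIORITY_3 positions the gemv kernel relative to fc_bf_tiled in
    // the kernel selector's ranking, as suggested in the review thread.
    return FORCE_PRIORITY_3;
}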