
[GPU] XeTLA based LoRA kernel #30624


Closed
wants to merge 46 commits into master from lora_xetla
Commits (46)
b7f8f25
XeTLA based LoRA kernel
dsawczuk-int May 19, 2025
e25836c
Merge branch 'master' into lora_xetla
dsawczuk-int May 19, 2025
a883461
Turn off HW LoRA check temporarily
dsawczuk-int May 19, 2025
15b8fbd
Add fp16 LoRA fusion test, fix fusings check, use JIT for all input s…
dsawczuk-int May 20, 2025
0bb0ba3
Fix: Don't check dynamic shapes in dynamic context
dsawczuk-int May 20, 2025
47cf72a
Use different kernel files for each stage to avoid issues with caching
dsawczuk-int May 21, 2025
91226b6
Enable fused kernel
dsawczuk-int May 21, 2025
bf08073
Add more precise silu implementation
dsawczuk-int May 21, 2025
91fa3dd
Adjust test threshold
dsawczuk-int May 21, 2025
6e723c6
Fix checking if state A should be transposed
dsawczuk-int May 22, 2025
ee8b08c
Add tiling for second token
dsawczuk-int May 23, 2025
8398ade
Add first token tilings
dsawczuk-int May 23, 2025
10e8d4b
Merge branch 'master' into lora_xetla
dsawczuk-int May 23, 2025
be2c643
Fix
dsawczuk-int May 23, 2025
e504f57
Fix warning
dsawczuk-int May 23, 2025
7886d61
Refactor
dsawczuk-int May 23, 2025
e78736f
Merge branch 'master' into lora_xetla
dsawczuk-int May 23, 2025
905dc61
Add copyright header
dsawczuk-int May 26, 2025
26634bd
Merge branch 'master' into lora_xetla
dsawczuk-int May 26, 2025
4d2b829
Merge branch 'master' into lora_xetla
dsawczuk-int May 27, 2025
138cc04
Change postop index data type to remove warning
dsawczuk-int May 28, 2025
e3d8169
Merge branch 'master' into lora_xetla
dsawczuk-int May 28, 2025
722b83f
Merge branch 'master' into lora_xetla
dsawczuk-int May 29, 2025
8bf5751
Merge branch 'master' into lora_xetla
dsawczuk-int Jun 2, 2025
30b9ceb
Reduce number of kernels
dsawczuk-int Jun 18, 2025
6ec792c
Remove unused variables
dsawczuk-int Jun 18, 2025
f856c8b
Remove empty lora check from dynamic context
dsawczuk-int Jun 18, 2025
3ddfd12
Merge branch 'master' into lora_xetla
dsawczuk-int Jun 18, 2025
da70bc0
Refactor
dsawczuk-int Jun 19, 2025
bbec023
Fix warning
dsawczuk-int Jun 19, 2025
a2bdf2f
Implement adding fused ops argument in CM kernel generator
dsawczuk-int Jun 19, 2025
168d64f
Create helper for handling xetla post ops
dsawczuk-int Jun 19, 2025
b274a03
Rename XeTLA post op macros
dsawczuk-int Jun 19, 2025
88e615d
Remove unnecessary file
dsawczuk-int Jun 19, 2025
7e3cd69
Add postop only kernel for empty lora
dsawczuk-int Jun 24, 2025
601c8a8
Divide alpha scale by lora rank
dsawczuk-int Jun 25, 2025
9424b5f
Fix typo
dsawczuk-int Jun 25, 2025
e135e29
Add FP16 test
dsawczuk-int Jun 25, 2025
003b989
Merge branch 'master' into lora_xetla
dsawczuk-int Jun 25, 2025
c741791
Temporarily disable fp16 test
dsawczuk-int Jun 26, 2025
5a02a1b
Merge branch 'master' into lora_xetla
dsawczuk-int Jun 26, 2025
c6bb3f4
Merge branch 'master' into lora_xetla
dsawczuk-int Jun 26, 2025
05b6c0e
Merge branch 'master' into lora_xetla
dsawczuk-int Jun 26, 2025
a5d17c4
Merge remote-tracking branch 'upstream/master' into lora_xetla
dsawczuk-int Jun 30, 2025
6974895
Merge remote-tracking branch 'upstream/master' into lora_xetla
dsawczuk-int Jul 1, 2025
33fd39c
Merge branch 'master' into lora_xetla
dsawczuk-int Aug 7, 2025
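
A note on the "Divide alpha scale by lora rank" commit: in the standard LoRA formulation the low-rank update is scaled by alpha over the rank, which matches the scaling these kernels appear to apply:

$$y = W_0 x + \frac{\alpha}{r} \, B A x$$

where $W_0$ is the frozen weight, $A$ and $B$ are the adapter states, $r$ is the LoRA rank, and $\alpha$ is the adapter scale.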
114 changes: 114 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora.h
@@ -0,0 +1,114 @@
/*******************************************************************************
* Copyright (c) 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

using namespace gpu::xetla;
using namespace gpu::xetla::group;
using namespace gpu::xetla::kernel;
using namespace gpu::xetla::subgroup;

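// Thin compile-time wrapper around XeTLA's gemm_universal_t: it assembles the
// workgroup/subgroup tile shape, the A/B/C memory descriptors, the compute
// policy (FPU vs. XMX engine, aligned vs. unaligned access) and a per-stage
// epilogue, and exposes a single run() entry point for the fused LoRA kernels.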
template <typename dtype_a, typename dtype_b, typename dtype_c,
typename dtype_acc, uint32_t wg_m, uint32_t wg_n, uint32_t sg_m,
uint32_t sg_n, uint32_t sg_k, mem_layout layout_a, mem_layout layout_b,
mem_layout layout_c, mem_space mem_space_a, mem_space mem_space_b,
mem_space mem_space_c, uint32_t local_kslicing,
uint32_t global_kslicing, mma_engine engine,
uint32_t periodic_sync_interval, uint32_t prefetch_distance,
gpu_arch arch_tag, uint32_t snake_swizzle = 0, bool unaligned = false>
struct gemm_universal {
using tile_shape = group::tile_shape_t<wg_n, wg_m, sg_n, sg_m>;

using mem_desc_a
= mem_desc_t<dtype_a, layout_a, mem_space_a, unaligned ? 1 : 8>;
using mem_desc_b
= mem_desc_t<dtype_b, layout_b, mem_space_b, unaligned ? 1 : 8>;
using mem_desc_c
= mem_desc_t<dtype_c, layout_c, mem_space_c, unaligned ? 1 : 8>;

using compute_attr = typename std::conditional<engine == mma_engine::fpu,
compute_attr_t<dtype_acc, dtype_acc, dtype_acc>,
compute_attr_t<dtype_a, dtype_b, dtype_acc>>::type;

using perf_tuning_knob = perf_tuning_knob_t<sg_k, prefetch_distance,
periodic_sync_interval>;

using compute_policy_0 =
typename std::conditional<engine == mma_engine::fpu,
compute_policy_default_fpu<compute_attr, perf_tuning_knob,
arch_tag>,
compute_policy_default_xmx<compute_attr, perf_tuning_knob,
arch_tag>>::type;
using compute_policy = typename std::conditional<unaligned,
compute_policy_unaligned_xmx<compute_attr, perf_tuning_knob,
arch_tag>,
compute_policy_0>::type;
using pre_processing = pre_processing_default_t<tile_shape, arch_tag>;
using gemm = gemm_t<compute_policy, tile_shape, mem_desc_a, mem_desc_b,
pre_processing>;

#ifdef LORA_GEMM_A
using scale_op_t = subgroup::scale_v_div_op_t<dtype_b, arch_tag>;
using tile_op_t = subgroup::chained_tile_op_t<scale_op_t>;
#else
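// The XETLA_POST_OP_* macros below are generated at JIT-compile time by the
// CM kernel generator and inject the fused post-ops of the surrounding graph
// (see the "xetla post ops" helper commits in this PR).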
XETLA_POST_OP_DEFINITIONS
using tile_op_t = subgroup::chained_tile_op_t<XETLA_POST_OP_LIST>;
#endif

using epilogue = epilogue_t<
epilogue_policy_tile_op<tile_op_t, arch_tag,
unaligned ? msg_type::unaligned_2d : msg_type::block_2d>,
tile_shape, mem_desc_c>;

using epilogue_args_t = typename epilogue::arguments_t;

using group_swizzle_t = kernel::group_swizzle_default<arch_tag>;

using gemm_op_t = kernel::gemm_universal_t<
kernel::dispatch_policy_kslicing<group_swizzle_t, global_kslicing,
local_kslicing>,
gemm, epilogue>;

static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();

inline static void run(sycl::nd_item<3> &item, dtype_a *a, dtype_b *b,
typename epilogue::mem_desc_c_t::base_t c, dtype_acc *acc,
uint32_t *cnt, uint32_t mat_m, uint32_t mat_n, uint32_t mat_k,
uint32_t lda, uint32_t ldb, uint32_t ldc
#ifdef LORA_GEMM_A
,
dtype_b *scale_input
#else
XETLA_POST_OP_ARGS
#endif
) {
gemm_op_t gemm_op;

#ifdef LORA_GEMM_A
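// For the A-stage GEMM, mat_n is the LoRA rank, so initialising the scale op
// with the alpha vector and an extra divisor of mat_n implements the
// alpha / rank scaling of the low-rank update.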
typename scale_op_t::scale_shape_t scale_input_shape(mat_n, 1, mat_n);
epilogue_args_t epilogue_args;
epilogue_args.init(
{{scale_input, scale_input_shape, static_cast<float>(mat_n)}});
#else
XETLA_POST_OP_SHAPE_DEFINITIONS
epilogue_args_t epilogue_args;
epilogue_args.init({XETLA_POST_OP_EPILOGUE_INIT_ARGS});
#endif

typename gemm_op_t::arguments_t arg(mat_m, mat_k, mat_n, a, lda, b, ldb,
c.base, ldc, acc, cnt, epilogue_args);
gemm_op(item, arg);
}
};
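
A minimal sketch of how this wrapper might be instantiated for one LoRA GEMM stage. The alias name, the fp16/fp32 type choices, the tile sizes, and the gpu_arch::XeHpg target below are illustrative assumptions, not values taken from this diff (the PR selects tilings per token stage):

using lora_gemm = gemm_universal<
    fp16, fp16, fp16, float,               // A, B, C element types; fp32 accumulation
    /*wg_m*/ 256, /*wg_n*/ 256,            // workgroup tile (assumed)
    /*sg_m*/ 32, /*sg_n*/ 64, /*sg_k*/ 32, // subgroup tile and k-step (assumed)
    mem_layout::row_major, mem_layout::row_major, mem_layout::row_major,
    mem_space::global, mem_space::global, mem_space::global,
    /*local_kslicing*/ 1, /*global_kslicing*/ 1,
    mma_engine::xmx,                       // systolic (XMX) compute path
    /*periodic_sync_interval*/ 8, /*prefetch_distance*/ 3,
    gpu_arch::XeHpg>;

// Inside the SYCL kernel body each work-item would then call, with the
// trailing scale pointer present only when LORA_GEMM_A is defined:
//   lora_gemm::run(item, a_ptr, b_ptr, c_desc, acc_ptr, cnt_ptr,
//                  mat_m, mat_n, mat_k, lda, ldb, ldc, alpha_ptr);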