diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora.h b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora.h new file mode 100644 index 00000000000000..cbfb8f4b43a90b --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora.h @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright (c) 2022-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +using namespace gpu::xetla; +using namespace gpu::xetla::group; +using namespace gpu::xetla::kernel; +using namespace gpu::xetla::subgroup; + +template +struct gemm_universal { + using tile_shape = group::tile_shape_t; + + using mem_desc_a + = mem_desc_t; + using mem_desc_b + = mem_desc_t; + using mem_desc_c + = mem_desc_t; + + using compute_attr = typename std::conditional, + compute_attr_t>::type; + + using perf_tuning_knob = perf_tuning_knob_t; + + using compute_policy_0 = + typename std::conditional, + compute_policy_default_xmx>::type; + using compute_policy = typename std::conditional, + compute_policy_0>::type; + using pre_processing = pre_processing_default_t; + using gemm = gemm_t; + +#ifdef LORA_GEMM_A + using scale_op_t = subgroup::scale_v_div_op_t; + using tile_op_t = subgroup::chained_tile_op_t; +#else + XETLA_POST_OP_DEFINITIONS + using tile_op_t = subgroup::chained_tile_op_t; +#endif + + using epilogue = epilogue_t< + epilogue_policy_tile_op, + tile_shape, mem_desc_c>; + + using epilogue_args_t = typename epilogue::arguments_t; + + using group_swizzle_t = kernel::group_swizzle_default; + + using gemm_op_t = kernel::gemm_universal_t< + kernel::dispatch_policy_kslicing, + gemm, epilogue>; + + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + + inline static void run(sycl::nd_item<3> &item, dtype_a *a, dtype_b *b, + typename epilogue::mem_desc_c_t::base_t c, dtype_acc *acc, + uint32_t *cnt, uint32_t mat_m, uint32_t mat_n, uint32_t mat_k, + uint32_t lda, uint32_t ldb, uint32_t ldc +#ifdef LORA_GEMM_A + , + dtype_b *scale_input +#else + XETLA_POST_OP_ARGS +#endif + ) { + gemm_op_t gemm_op; + +#ifdef LORA_GEMM_A + typename scale_op_t::scale_shape_t scale_input_shape(mat_n, 1, mat_n); + epilogue_args_t epilogue_args; + epilogue_args.init( + {{scale_input, scale_input_shape, static_cast(mat_n)}}); +#else + XETLA_POST_OP_SHAPE_DEFINITIONS + epilogue_args_t epilogue_args; + epilogue_args.init({XETLA_POST_OP_EPILOGUE_INIT_ARGS}); +#endif + + typename gemm_op_t::arguments_t arg(mat_m, mat_k, mat_n, a, lda, b, ldb, + c.base, ldc, acc, cnt, epilogue_args); + gemm_op(item, arg); + } +}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_a.h b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_a.h new file mode 100644 index 00000000000000..214ae797cab4fd --- /dev/null +++ 
b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_a.h @@ -0,0 +1,580 @@ +/******************************************************************************* +* Copyright (c) 2022-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +using namespace gpu::xetla; +using namespace gpu::xetla::group; +using namespace gpu::xetla::kernel; +using namespace gpu::xetla::subgroup; + +template +class epilogue_lora_gemm_a_t { +public: + using epilogue_policy = epilogue_policy_tile_op; + using tile_op_t = typename epilogue_policy::tile_op_t; + using tile_shape = tile_shape_; + using mem_desc_c_t = mem_desc_c_t_; + static constexpr gpu_arch arch_tag = arch_tag_; + static constexpr uint32_t barrier_count = 0; + static constexpr uint32_t slm_size = mem_desc_c_t::is_local + ? tile_shape::wg_tile_size_x * tile_shape::wg_tile_size_y + : 0; + + /// @brief Epilogue arguments. + struct arguments_t { + /// @brief Is tile_op arguments, could be a single + /// tile_op argument or chained_tile_op_args. + typename tile_op_t::arguments_t tile_op_args; + + /// @brief Constructs a new arguments t object. + inline arguments_t() = default; + + /// @brief Constructs a new arguments t object. + /// @param tile_op_args_ Is tile_op arguments, could be a single + /// tile_op argument or chained_tile_op_args. + inline arguments_t(typename tile_op_t::arguments_t tile_op_args_) + : tile_op_args(tile_op_args_) {} + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // inline ~arguments_t(){} + inline arguments_t(const arguments_t &args) + : tile_op_args(args.tile_op_args) {} + + inline arguments_t &operator=(const arguments_t &args) { + this->tile_op_args = args.tile_op_args; + return *this; + } + + /// @brief Explicit initialization function. + /// @param tile_op_args_ Is tile_op arguments, could be a single + /// tile_op argument or chained_tile_op_args. + inline void init(typename tile_op_t::arguments_t tile_op_args_) { + tile_op_args = tile_op_args_; + } + }; + +private: + using work_group_t = typename tile_shape::work_group_t; + static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; + static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; + static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; + static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + using dtype_c = typename mem_desc_c_t::dtype; + static constexpr mem_layout mem_layout_c = mem_desc_c_t::layout; + static constexpr mem_space mem_space_c = mem_desc_c_t::space; + + /// @brief Updates tile base descriptor based on the tid. 
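+    /// Editorial note (illustrative arithmetic, not part of the patch): with
+    /// wg_size_x == 4, sg_tile_n == 32 and sg_tile_m == 16, the subgroup with
+    /// g.get_id() == 6 resolves to sg_idx == 2 and sg_idy == 1, so the tile
+    /// coordinate passed to update_coord() is (2 * 32, 1 * 16) == (64, 16).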
+ inline static void update_sg_tile_tdesc( + work_group_t &g, mem_desc_c_t &mem_desc_c) { + int32_t sg_idx = g.get_id() % wg_size_x; + int32_t sg_idy = g.get_id() / wg_size_x; + int32_t tile_offset_n = sg_idx * sg_tile_n; + int32_t tile_offset_m = sg_idy * sg_tile_m; + mem_desc_c.update_coord(tile_offset_n, tile_offset_m); + } + +public: + static constexpr msg_type msg_type_c + = (mem_space_c == mem_space::global ? msg_type_c_ + : msg_type::scatter); + + /// @brief Default epilogue. + /// 1) Call tile_op/chained_tile_op 2) Convert dtype_acc to dtype_c + /// 3) Overwrite/reduce_sum to memory. + /// @tparam matAcc_t Is the type of the input tile. + /// @param g Is the workgroup of the current tile. + /// @param matAcc Is the input tile. + /// @param mem_desc_c Is the memory description of matC, including base, shape and coordinate. + /// @param args Is the additional arguments for epilogue. + /// @param slm_base Is the slm base address. + /// @param nbarrier_base Is the named barrier base. + template + inline void operator()(work_group_t &g, matAcc_t &matAcc, + mem_desc_c_t mem_desc_c, arguments_t args = {}, + uint32_t slm_base = 0, uint32_t nbarrier_base = 0) { + using mat_tile_desc = typename matAcc_t::tile_desc; + using matC_t = subgroup::tile_t; + using matC_payload_t = subgroup::mem_payload_t; + update_sg_tile_tdesc(g, mem_desc_c); + tile_op_t tile_op; + tile_op(matAcc, mem_desc_c.coord, args.tile_op_args, slm_base, + nbarrier_base); + if constexpr (store_result) { + matC_t matC; + matC_payload_t matC_payload(mem_desc_c); + subgroup::elemwise_cvt(matC, matAcc); + subgroup::tile_store(matC, matC_payload); + } + } +}; + +/// @addtogroup xetla_gemm_universal +/// @{ + +/// @brief Is the gemm_universal functor, specialized in kslicing dispatch policy and Xe architecture. +/// +/// @tparam num_global_kslicing_ Is the k dim split ratio between groups. +/// @tparam num_local_kslicing_ Is the k dim split ratio within a group. +/// @tparam gemm_t_ Is the gemm functor to compose a GEMM_UNIVERSAL. +/// @tparam epilogue_t_ Is the epilogue functor to compose a GEMM_UNIVERSAL. 
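+///
+/// Editorial note on resource accounting (restating get_barrier_count() and
+/// get_slm_size() below for illustration): both totals scale with the local
+/// k-slicing factor, e.g. with num_local_kslicing == 2 the kernel reserves
+/// 2 * gemm_t::barrier_count + kslicing_t::barrier_count
+/// + 2 * epilogue_t::barrier_count named barriers, plus the analogous multiple
+/// of the SLM sizes, and both sums are static_assert-checked against the
+/// 32-barrier and 128 KB limits.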
+template +class gemm_lora_a_t { + using gemm_t = gemm_t_; + using epilogue_t = epilogue_t_; + using gemm_args_t = typename gemm_t::arguments_t; + using epilogue_args_t = typename epilogue_t::arguments_t; + using tile_shape = typename gemm_t::tile_shape; + using group_swizzle_t = group_swizzle_; + + static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y; + static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x; + static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; + static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; + static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; + static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y; + static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x; + + static constexpr uint32_t k_stride = gemm_t::k_stride; + using work_group_t = typename gemm_t::work_group_t; + static constexpr uint32_t work_group_size = work_group_t::size; + + static constexpr gpu_arch arch_tag = group_swizzle_t::arch_tag; + static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same"); + static_assert( + arch_tag == epilogue_t::arch_tag, "arch_tag should be the same"); + static_assert(std::is_same::value, + "tile_shape should be the same"); + + using mem_desc_a_t = typename gemm_t::mem_desc_a_t; + using mem_desc_b_t = typename gemm_t::mem_desc_b_t; + using mem_desc_c_t = typename epilogue_t::mem_desc_c_t; + using matA_base_t = typename mem_desc_a_t::base_t; + using matB_base_t = typename mem_desc_b_t::base_t; + using matC_base_t = typename mem_desc_c_t::base_t; + using dtype_a = typename mem_desc_a_t::dtype; + using dtype_b = typename mem_desc_b_t::dtype; + using dtype_c = typename mem_desc_c_t::dtype; + using matAcc_t = typename gemm_t::matAcc_t; + using dtype_acc = typename matAcc_t::dtype; + using mem_desc_acc_t + = mem_desc_t; + using mem_desc_cnt_t + = mem_desc_t; + using acc_base_t = typename mem_desc_acc_t::base_t; + using cnt_base_t = typename mem_desc_cnt_t::base_t; + + static constexpr uint32_t num_global_kslicing = 1; + static constexpr uint32_t num_local_kslicing = num_local_kslicing_; + static_assert((num_global_kslicing > 0) && (num_local_kslicing > 0), + "min slicing ratio is 1"); + + static_assert((num_local_kslicing & (num_local_kslicing - 1)) == 0, + "num_local_kslicing should be power of 2!"); + + using kslicing_t = group::cooperative_reduce_t; + using mat_slice_t = typename kslicing_t::mat_slice_t; + static constexpr uint32_t ks_coop_num_x = kslicing_t::coop_num_x; + static constexpr uint32_t ks_coop_num_y = kslicing_t::coop_num_y; + + static constexpr uint32_t gemm_nbarr_count = gemm_t::barrier_count; + static constexpr uint32_t gemm_slm_size = gemm_t::slm_size; + + static constexpr uint32_t epilogue_nbarr_count = epilogue_t::barrier_count; + static constexpr uint32_t epilogue_slm_size = epilogue_t::slm_size; + + static constexpr uint32_t kslicing_nbarr_count = kslicing_t::barrier_count; + static constexpr uint32_t kslicing_slm_size = kslicing_t::slm_size; + + static constexpr uint32_t counter_size = 8; + + static constexpr uint32_t alignment = 8 / sizeof(dtype_acc); + + using tile_shape_cnt = group::tile_shape_t; + + using global_group_reduce_t = group::global_reduce_t; + +public: + /// @brief GEMM_UNIVERSAL arguments. + /// This is the interface for users to pass the application-related runtime variables. + struct arguments_t { + /// @brief Is the size of the m dimension of the matrix multiplication (m x k x n). 
+ uint32_t matrix_m; + /// @brief Is the size of the k dimension of the matrix multiplication (m x k x n). + uint32_t matrix_k; + /// @brief Is the size of the n dimension of the matrix multiplication (m x k x n). + uint32_t matrix_n; + /// @brief Is the base address of matrix A. + matA_base_t matA_base; + /// @brief Is the leading dimension (pitch) size of the matrix A in memory. + uint32_t matA_ld; + /// @brief Is the base address of matrix B. + matB_base_t matB_base; + /// @brief Is the leading dimension (pitch) size of the matrix B in memory. + uint32_t matB_ld; + /// @brief Is the base address of matrix C. + matC_base_t matC_base; + /// @brief Is the leading dimension (pitch) size of the matrix C in memory. + uint32_t matC_ld; + /// @brief Is the base address of accumulation buffer. + acc_base_t acc_base; + /// @brief Is the base address of counter buffer. + cnt_base_t cnt_base; + /// @brief Is the epilogue arguments. + epilogue_args_t epilogue_args; + + /// @brief Constructs arguments with default method. + inline arguments_t() = default; + + /// @brief Set for device copyable + static constexpr bool host_callable = true; + + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // ~arguments_t(){} + + /// @brief Constructs arguments with initialization list. + /// @param matrix_m_ Is the size of the m dimension of the matrix multiplication (m x k x n). + /// @param matrix_k_ Is the size of the k dimension of the matrix multiplication (m x k x n). + /// @param matrix_n_ Is the size of the n dimension of the matrix multiplication (m x k x n). + /// @param matA_base_ Is the base address of matrix A. + /// @param matA_ld_ Is the leading dimension (pitch) size of the matrix A in memory. + /// @param matB_base_ Is the base address of matrix B. + /// @param matB_ld_ Is the leading dimension (pitch) size of the matrix B in memory. + /// @param matC_base_ Is the base address of matrix C. + /// @param matC_ld_ Is the leading dimension (pitch) size of the matrix C in memory. + /// @param epilogue_args_ Is the epilogue arguments. 
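+        /// Editorial usage sketch (hypothetical host-side names, not part of
+        /// the patch):
+        /// @code
+        /// arguments_t args(m, k, n, a_ptr, lda, b_ptr, ldb, c_base, ldc,
+        ///                  acc_ptr, cnt_ptr, epilogue_args);
+        /// @endcode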
+ inline arguments_t(uint32_t matrix_m_, uint32_t matrix_k_, + uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_, + matB_base_t matB_base_, uint32_t matB_ld_, + matC_base_t matC_base_, uint32_t matC_ld_, + acc_base_t acc_base_ = {}, cnt_base_t cnt_base_ = {}, + epilogue_args_t epilogue_args_ = {}) + : matrix_m(matrix_m_) + , matrix_k(matrix_k_) + , matrix_n(matrix_n_) + , matA_base(matA_base_) + , matA_ld(matA_ld_) + , matB_base(matB_base_) + , matB_ld(matB_ld_) + , matC_base(matC_base_) + , matC_ld(matC_ld_) + , acc_base(acc_base_) + , cnt_base(cnt_base_) + , epilogue_args(epilogue_args_) {} + inline arguments_t(const arguments_t &args) + : matrix_m(args.matrix_m) + , matrix_k(args.matrix_k) + , matrix_n(args.matrix_n) + , matA_base(args.matA_base) + , matA_ld(args.matA_ld) + , matB_base(args.matB_base) + , matB_ld(args.matB_ld) + , matC_base(args.matC_base) + , matC_ld(args.matC_ld) + , acc_base(args.acc_base) + , cnt_base(args.cnt_base) + , epilogue_args(args.epilogue_args) {} + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // inline ~arguments_t(){} + inline arguments_t &operator=(const arguments_t &args) { + this->matrix_m = args.matrix_m; + this->matrix_k = args.matrix_k; + this->matrix_n = args.matrix_n; + this->matA_base = args.matA_base; + this->matA_ld = args.matA_ld; + this->matB_base = args.matB_base; + this->matB_ld = args.matB_ld; + this->matC_base = args.matC_base; + this->matC_ld = args.matC_ld; + this->acc_base = args.acc_base; + this->cnt_base = args.cnt_base; + this->epilogue_args = args.epilogue_args; + return *this; + } + }; + + /// @brief Gets named_barrier id consumption count. + /// Users query and get a named_barrier id consumption count in compile time. + /// @return The count of named barriers required. + inline static constexpr uint32_t get_barrier_count() { + constexpr uint32_t count = gemm_nbarr_count * num_local_kslicing + + kslicing_nbarr_count + + epilogue_nbarr_count * num_local_kslicing; + static_assert( + count <= 32, "The named_barrier count should be less than 32!"); + return count; + } + + /// @brief Gets local memory size consumption. + /// Users query and get a local memory consumption size in compile time. + /// @return The size of local memory required. + inline static constexpr uint32_t get_slm_size() { + constexpr uint32_t size = gemm_slm_size * num_local_kslicing + + kslicing_slm_size + epilogue_slm_size * num_local_kslicing; + static_assert(size <= (128 * 1024), + "The local memory size should be less than 128KB!"); + return size; + } + + /// @brief Main execution function for GEMM_UNIVERSAL. + /// The processing order is 1) set group-level base and boundary, split group to workgroups -> + /// 2) num_local_kslicing x gemms -> 3) local kslicing -> 4) num_local_kslicing x epilogues. + /// @param item Is the sycl::nd_item, returns execution related information, such as workgroup id, subgroup id... + /// @param args Is the GEMM_UNIVERSAL arguments for application-related runtime variables. + /// @param slm_base Is the slm base address. + /// @param nbarrier_base Is the named barrier base. 
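+    /// Editorial example of the local k split (illustrative numbers): with
+    /// matrix_k == 4096 and num_local_kslicing == 4, each slice computes
+    /// wg_tile_k == 1024 of the K dimension; the slice with wg_id == 1 runs
+    /// its gemm over start_k == 1024 .. boundary_k == 2048, and the partial
+    /// sums are combined by the cooperative-reduce stage before the epilogue.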
+ inline void operator()(sycl::nd_item<3> &item, + const arguments_t &args, matAcc_t &matAcc, uint32_t slm_base = 0, + uint32_t nbarrier_base = 0) const { + // set up workgroup level coordinates and boundaries +#if LORA_TEMP_IN_REG == 1 + work_group_t g((item.get_local_linear_id() / item.get_local_range(2)) + * wg_size_x); +#else + work_group_t g(item.get_local_linear_id() % work_group_size); +#endif + uint32_t wg_id = item.get_local_linear_id() / work_group_size; + group_swizzle_t group_swizzle; + int start_m = group_swizzle.template get_tile_idx<1>(item) * wg_tile_m; + int start_n + = 0; //group_swizzle.template get_tile_idx<2>(item) * wg_tile_n; + int start_k = 0; + uint32_t wg_tile_k = args.matrix_k; + uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n + ? args.matrix_n + : (start_n + wg_tile_n); + uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m + ? args.matrix_m + : (start_m + wg_tile_m); + uint32_t boundary_k = wg_tile_k; + if constexpr (num_global_kslicing > 1) { + wg_tile_k = (wg_tile_k + num_global_kslicing - 1) + / num_global_kslicing; + start_k = start_k + + group_swizzle.template get_tile_idx<0>(item) * wg_tile_k; + boundary_k = (start_k + wg_tile_k) > boundary_k + ? boundary_k + : (start_k + wg_tile_k); + } + if constexpr (num_local_kslicing > 1) { + wg_tile_k + = (wg_tile_k + num_local_kslicing - 1) / num_local_kslicing; + start_k = start_k + wg_id * wg_tile_k; + boundary_k = (start_k + wg_tile_k) > boundary_k + ? boundary_k + : (start_k + wg_tile_k); + } + + // set up arguments + uint32_t gemm_slm_base = slm_base; + uint32_t gemm_nbarr_base = nbarrier_base; + if constexpr (num_local_kslicing > 1) { + gemm_slm_base = slm_base + wg_id * gemm_slm_size; + gemm_nbarr_base = nbarrier_base + wg_id * gemm_nbarr_count; + } + uint32_t kslicing_slm_base + = slm_base + num_local_kslicing * gemm_slm_size; + uint32_t kslicing_nbarr_base + = nbarrier_base + num_local_kslicing * gemm_nbarr_count; + uint32_t epilogue_slm_base = kslicing_slm_base + kslicing_slm_size; + uint32_t epilogue_nbarr_base + = kslicing_nbarr_base + kslicing_nbarr_count; + + mem_desc_a_t mem_desc_a; + mem_desc_b_t mem_desc_b; + mem_desc_c_t mem_desc_c; + //setup for matA + if constexpr (mem_desc_a_t::is_local) { + mem_desc_a.init(args.matA_base, + {wg_tile_k, real_wg_tile_m, wg_tile_k}, {0, 0}); + } else { + mem_desc_a.init(args.matA_base, + {boundary_k, boundary_m, args.matA_ld}, {start_k, start_m}); + } + //setup for matB + if constexpr (mem_desc_b_t::is_local) { + mem_desc_b.init(args.matB_base, + {real_wg_tile_n, wg_tile_k, real_wg_tile_n}, {0, 0}); + } else { + mem_desc_b.init(args.matB_base, + {boundary_n, boundary_k, args.matB_ld}, {start_n, start_k}); + } + + uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; + gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count); + + matAcc.init(0); + gemm_t gemm; + gemm(g, matAcc, gemm_args, gemm_slm_base, gemm_nbarr_base); + if constexpr (num_local_kslicing == 1) { + mem_desc_c.init(args.matC_base, + {boundary_n, boundary_m, args.matC_ld}, {start_n, start_m}); + + epilogue_t epilogue; + epilogue(g, matAcc, mem_desc_c, args.epilogue_args, + epilogue_slm_base, epilogue_nbarr_base); + return; + } + kslicing_t kslicing(wg_id); + mat_slice_t mat_slice; + kslicing(g, mat_slice, matAcc, kslicing_slm_base, kslicing_nbarr_base); + if (kslicing.is_valid_post_process_wg()) { + //setup for matC + //set up cooperative offset for matC store + int32_t coop_offset_x + = kslicing.coop_id_x * mat_slice_t::tile_size_x; + int32_t coop_offset_y + = 
kslicing.coop_id_y * mat_slice_t::tile_size_y; + int32_t acc_start_x = start_n + coop_offset_x; + int32_t acc_start_y = start_m + coop_offset_y; + int32_t cnt_start_x = group_swizzle.template get_tile_idx<2>(item) + * tile_shape_cnt::wg_tile_size_x + + kslicing.coop_id_x; + int32_t cnt_start_y = group_swizzle.template get_tile_idx<1>(item) + * tile_shape_cnt::wg_tile_size_y + + kslicing.coop_id_y; + uint32_t group_range_x = item.get_group_range(2); + uint32_t group_range_y = item.get_group_range(1); + uint32_t cnt_size_x + = group_range_x * tile_shape_cnt::wg_tile_size_x; + uint32_t cnt_size_y + = group_range_y * tile_shape_cnt::wg_tile_size_y; + + uint32_t acc_aligned_n + = (args.matrix_n + alignment - 1) / alignment * alignment; + + uint32_t acc_boundary_n = (start_n + wg_tile_n) > acc_aligned_n + ? acc_aligned_n + : start_n + wg_tile_n; + + mem_desc_acc_t mem_desc_acc(args.acc_base, + {acc_boundary_n, boundary_m, acc_aligned_n}, + {acc_start_x, acc_start_y}); + mem_desc_cnt_t mem_desc_cnt(args.cnt_base, + {cnt_size_x, cnt_size_y, cnt_size_x}, + {cnt_start_x, cnt_start_y}); + + global_group_reduce_t global_group_reduce; + global_group_reduce(g, mat_slice, mem_desc_acc, mem_desc_cnt); + + if (global_group_reduce.is_last_group()) { + if constexpr (mem_desc_c_t::is_local) { + mem_desc_c.init(args.matC_base, + {real_wg_tile_n, real_wg_tile_m, real_wg_tile_n}, + {coop_offset_x, coop_offset_y}); + } else { + mem_desc_c.init(args.matC_base, + {boundary_n, boundary_m, args.matC_ld}, + {start_n + coop_offset_x, start_m + coop_offset_y}); + } + epilogue_t epilogue; + epilogue(g, mat_slice, mem_desc_c, args.epilogue_args, + epilogue_slm_base, epilogue_nbarr_base); + } + } + } +}; + +template +struct gemm_lora_a { + using tile_shape = group::tile_shape_t; + + static_assert(!((local_kslicing > 1) && (out_in_reg)), + "when using local kslicing, store_result should be true!"); + + using mem_desc_a + = mem_desc_t; + using mem_desc_b + = mem_desc_t; + using mem_desc_c + = mem_desc_t; + + using compute_attr = typename std::conditional, + compute_attr_t>::type; + + using perf_tuning_knob = perf_tuning_knob_t; + + using compute_policy_0 = + typename std::conditional, + compute_policy_default_xmx>::type; + using compute_policy = typename std::conditional, + compute_policy_0>::type; + using pre_processing = pre_processing_default_t; + using gemm = gemm_t; + + using scale_op_t = subgroup::scale_v_div_op_t; + + using tile_op_t = subgroup::chained_tile_op_t; + using epilogue = epilogue_lora_gemm_a_t; + + using epilogue_args_t = typename epilogue::arguments_t; + + using group_swizzle_t = kernel::group_swizzle_default; + + using gemm_op_t + = gemm_lora_a_t; + + using matAcc_t = typename gemm::matAcc_t; + + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + + inline static void run(sycl::nd_item<3> &item, uint32_t mat_m, + uint32_t mat_k, uint32_t mat_n, dtype_a *a, dtype_b *b, + typename epilogue::mem_desc_c_t::base_t c, dtype_b *scale_input, + matAcc_t &matAcc) { + gemm_op_t gemm_op; + + uint32_t lda = layout_a == mem_layout::col_major ? mat_m : mat_k; + uint32_t ldb = layout_b == mem_layout::col_major ? mat_k : mat_n; + uint32_t ldc = layout_c == mem_layout::col_major ? 
mat_m : mat_n; + + typename scale_op_t::scale_shape_t scale_input_shape(mat_n, 1, mat_n); + epilogue_args_t epilogue_args; + epilogue_args.init( + {{scale_input, scale_input_shape, static_cast(mat_n)}}); + + typename gemm_op_t::arguments_t arg(mat_m, mat_k, mat_n, a, lda, b, ldb, + c.base, ldc, nullptr, nullptr, epilogue_args); + gemm_op(item, arg, matAcc); + } +}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_b.h b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_b.h new file mode 100644 index 00000000000000..811e53b2d0edb1 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_b.h @@ -0,0 +1,708 @@ +/******************************************************************************* +* Copyright (c) 2022-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +using namespace gpu::xetla; +using namespace gpu::xetla::group; +using namespace gpu::xetla::kernel; +using namespace gpu::xetla::subgroup; + +template +class gemm_lora_b_aligned_t { +public: + using mem_desc_a_t = mem_desc_a_t_; + using mem_desc_b_t = mem_desc_b_t_; + using tile_shape = tile_shape_; + using pre_processing_t = pre_processing_t_; + using compute_policy = compute_policy_default_xmx; + static constexpr uint32_t k_stride = compute_policy::k_stride; + static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; + static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; + static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; + static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + using work_group_t = typename tile_shape::work_group_t; + + constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + + static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; + static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; + static constexpr bool is_col_major_a + = mem_layout_a == mem_layout::col_major; + static constexpr bool is_col_major_b + = mem_layout_b == mem_layout::col_major; + +private: + /******** set data type **********/ + using dtype_a = typename mem_desc_a_t::dtype; + using dtype_b = typename mem_desc_b_t::dtype; + using dtype_mma_acc = typename compute_policy::dtype_mma_acc; + using dtype_mma_a = typename compute_policy::dtype_mma_a; + using dtype_mma_b = typename compute_policy::dtype_mma_b; + + using check_dtype + = group::gemm::default_xmx::check_dtype_default< + dtype_a, dtype_b, dtype_mma_a, dtype_mma_b>; + + /******** set memory attribute **********/ + static constexpr mem_space mem_space_a = mem_desc_a_t::space; + static constexpr mem_space mem_space_b = mem_desc_b_t::space; + + static constexpr bool is_local_a = mem_space_a == mem_space::local; + static constexpr bool is_local_b = mem_space_b == mem_space::local; + static constexpr tdesc_update_dir update_dir_a = is_col_major_a + ? tdesc_update_dir::y_dir + : tdesc_update_dir::x_dir; + static constexpr tdesc_update_dir update_dir_b = is_col_major_b + ? 
tdesc_update_dir::x_dir + : tdesc_update_dir::y_dir; + + using check_memory + = group::gemm::default_xmx::check_memory_default< + mem_layout_a, mem_layout_b, mem_space_a, mem_space_b>; + + static constexpr uint32_t stages = compute_policy::stages; + static constexpr uint32_t sync_freq = compute_policy::sync_freq; + + /******** set tile layout && worker scope **********/ + static constexpr uint32_t tile_size_x_a = k_stride; + static constexpr uint32_t tile_size_y_a = sg_tile_m; + static constexpr uint32_t tile_size_x_b = sg_tile_n; + static constexpr uint32_t tile_size_y_b = k_stride; + static constexpr uint32_t tile_size_x_c = sg_tile_n; + static constexpr uint32_t tile_size_y_c = sg_tile_m; + static constexpr uint32_t block_size_x_a = compute_policy::block_size_x_a; + static constexpr uint32_t block_size_y_a + = (compute_policy::block_size_y_a > tile_size_y_a) + ? tile_size_y_a + : compute_policy::block_size_y_a; + static constexpr uint32_t block_size_x_b = compute_policy::block_size_x_b; + static constexpr uint32_t block_size_y_b = compute_policy::block_size_y_b; + + using check_tile_size = group::gemm< + gpu_arch::Xe>::default_xmx::check_tile_size_default; + + /******** set tile **********/ + static constexpr reg_layout reg_layout_a = reg_layout::tiled; + using matA_tile_desc_t = subgroup::tile_desc_t; + + static_assert( + matA_in_t::tile_desc::tile_size_x % matA_tile_desc_t::tile_size_x + == 0, + "matA_in_t tile size x should be the same as matAcc_tile_desc_t"); + static_assert( + matA_tile_desc_t::tile_size_y == matA_in_t::tile_desc::tile_size_y, + "matA_in_t tile size y should be the same as matAcc_tile_desc_t"); + static_assert(matA_tile_desc_t::block_size_x + == matA_in_t::tile_desc::block_size_x, + "matA_in_t block size x should be the same as matAcc_tile_desc_t"); + static_assert(matA_tile_desc_t::block_size_y + == matA_in_t::tile_desc::block_size_y, + "matA_in_t block size y should be the same as matAcc_tile_desc_t"); + + using matA_t = subgroup::tile_t; + using matA_payload_t = subgroup::mem_payload_t; + using matA_acc_t = subgroup::tile_t; + using matA_prefetch_payload_t = subgroup::prefetch_payload_t, + wg_size_x, arch_tag>; + static constexpr reg_layout reg_layout_b + = sizeof(dtype_b) < sizeof(uint32_t) ? reg_layout::vnni_tiled + : reg_layout::tiled; + using matB_tile_desc_t = subgroup::tile_desc_t; + using matB_t = subgroup::tile_t; + using matB_payload_t = subgroup::mem_payload_t; + using matB_acc_t = subgroup::tile_t; + using matB_prefetch_payload_t = subgroup::prefetch_payload_t, + wg_size_y, arch_tag>; + +public: + using matAcc_tile_desc_t = subgroup::tile_desc_t; + using matAcc_t = subgroup::tile_t; + +private: + using tile_mma = subgroup::tile_mma_t; + static constexpr bool enable_periodic_sync = (sync_freq != 0); + static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; + static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + +public: + static constexpr uint32_t barrier_count + = enable_periodic_sync ? barrier_count_x + barrier_count_y : 0; + + static constexpr uint32_t slm_size = 0; + + static constexpr msg_type msg_type_a = matA_payload_t::message_type; + static constexpr msg_type msg_type_b = matB_payload_t::message_type; + + using pre_processing_arg_t = typename pre_processing_t::arguments_t; + + /// @brief Arguments for gemm. + /// User should prepare matA_base_desc, matB_base_desc, inner_loop_count... + struct arguments_t { + /// @brief Is the memory description of matA, including base, shape and coordinate. 
+ mem_desc_a_t matA_base_desc; + /// @brief Is the memory description of matB, including base, shape and coordinate. + mem_desc_b_t matB_base_desc; + /// @brief Is the total inner loop count required to compute the entire K-dim. + uint32_t inner_loop_count; + /// @brief Is the arguments for pre-processing functor. + pre_processing_arg_t pre_processing_args; + + /// @brief Default construct. + inline arguments_t() = default; + + /// @brief Constructs a new arguments t object. + /// @param matA_desc Is the memory description of matA, including base, shape and coordinate. + /// @param matB_desc Is the memory description of matB, including base, shape and coordinate. + /// @param loop_count Is the total inner loop count required to compute the entire K-dim. + /// @param args Is the arguments for pre-processing functor. + inline arguments_t(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_count, pre_processing_arg_t args = {}) + : matA_base_desc(matA_desc) + , matB_base_desc(matB_desc) + , inner_loop_count(loop_count) + , pre_processing_args(args) {} + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // inline ~arguments_t(){} + inline arguments_t(const arguments_t &args) + : matA_base_desc(args.matA_base_desc) + , matB_base_desc(args.matB_base_desc) + , inner_loop_count(args.inner_loop_count) + , pre_processing_args(args.pre_processing_args) {} + inline arguments_t &operator=(const arguments_t &args) { + this->matA_base_desc = args.matA_base_desc; + this->matB_base_desc = args.matB_base_desc; + this->inner_loop_count = args.inner_loop_count; + this->pre_processing_args = args.pre_processing_args; + return *this; + } + + /// @brief Explicit initialization function. + /// @param matA_desc Is the memory description of matA, including base, shape and coordinate. + /// @param matB_desc Is the memory description of matB, including base, shape and coordinate. + /// @param loop_count Is the total inner loop count required to compute the entire K-dim. + /// @param args Is the arguments for pre-processing functor. + inline void init(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_count, pre_processing_arg_t args = {}) { + matA_base_desc = matA_desc; + matB_base_desc = matB_desc; + inner_loop_count = loop_count; + pre_processing_args = args; + } + }; + + /// @brief Gets the subgroup-level tile offset x. + /// @param g Is the workgroup of the current tile. + /// @return Subgroup-level tile offset x. + inline static int get_matC_offset_x(work_group_t &g) { + int32_t sg_idx = g.get_id() % wg_size_x; + return sg_idx * sg_tile_n; + } + + /// @brief Gets the subgroup-level tile offset y. + /// @param g Is the workgroup of the current tile. + /// @return Subgroup-level tile offset y. + inline static int get_matC_offset_y(work_group_t &g) { + int32_t sg_idy = g.get_id() / wg_size_x; + return sg_idy * sg_tile_m; + } + + inline static void release(uint8_t nbarrier_id = 0) { + static constexpr bool need_local_fence + = (mem_space_a == mem_space::local) + || (mem_space_b == mem_space::local); + if constexpr (need_local_fence) { + xetla_fence(); + } + xetla_fence(); + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; + if constexpr (wg_size > 1) { + xetla_nbarrier_t nbarrier; + nbarrier.init_nbarrier( + nbarrier_id, nbarrier_role::producer_consumer); + nbarrier.arrive_wait(); + } + } + + /// @brief Main execution function for gemm. 
+ /// The basic process is load data -> matrix multiply. + /// @param g Is the workgroup of the current tile. + /// @param matAcc Is the reference of the accumulation buffer. + /// @param args Is the gemm::arguments_t. + /// @param slm_base Is the slm base address. + /// @param nbarrier_base Is the named barrier base. + inline void operator()(work_group_t &g, matAcc_t &matAcc, + arguments_t args, matA_in_t &matA_in, uint32_t slm_base = 0, + uint32_t nbarrier_base = 0) { + int32_t sg_idx = g.get_id() % wg_size_x; + int32_t sg_idy = g.get_id() / wg_size_x; + + update_sg_tile_tdesc(args, sg_idx, sg_idy); + pre_processing_t pre_processing; + matA_t matA; + matB_t matB; + // >>>>>>>>>>>>>>>>>> pre_processing init + pre_processing.init(g, args.pre_processing_args); + matB_payload_t matB_payload(args.matB_base_desc); + matA_prefetch_payload_t matA_prefetch_payload( + args.matA_base_desc, sg_idx); + matB_prefetch_payload_t matB_prefetch_payload( + args.matB_base_desc, sg_idy); + xetla_nbarrier_t nbarrier_a; + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + xetla_nbarrier_t nbarrier_b; + nbarrier_b.init_nbarrier(sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + +#pragma unroll + for (int i = 0; i < stages; i++) { + subgroup::tile_prefetch( + matB_prefetch_payload); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + + for (int i = 0; i < args.inner_loop_count; i++) { + if constexpr (enable_periodic_sync) { + if ((i % sync_freq) == 0) { + if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); } + if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); } + } + } + subgroup::tile_load( + matB, matB_payload); + + static_assert(sizeof(dtype_a) == 2); + static_assert(matA_t::tile_desc::tile_size_y + % matA_in_t::tile_desc::block_size_y + == 0); + static_assert(matA_t::tile_desc::block_size_y + == matA_in_t::tile_desc::block_size_y); + static_assert(matA_t::tile_desc::block_size_x + == matA_in_t::tile_desc::block_size_x); + static constexpr uint32_t block_size + = matA_in_t::tile_desc::block_size_x + * matA_in_t::tile_desc::block_size_y; + static_assert(matA_t::tile_desc::tile_size_x + % matA_in_t::tile_desc::block_size_x + == 0); + static constexpr uint32_t copy_n_blocks + = matA_t::tile_desc::tile_size_x + / matA_in_t::tile_desc::block_size_x; + static_assert(matA_in_t::tile_desc::tile_size_x + % matA_in_t::tile_desc::block_size_x + == 0); + if constexpr (matA_t::tile_desc::tile_size_x + < matA_in_t::tile_desc::tile_size_x) { +#pragma unroll + for (int j = 0; j < matA_t::tile_desc::tile_size_y + / matA_in_t::tile_desc::block_size_y; + j++) { + + matA.reg.select(j * block_size * copy_n_blocks) + = matA_in.reg.select< + block_size * copy_n_blocks, 1>(j + * block_size + * matA_in_t::tile_desc::tile_size_x + / matA_in_t::tile_desc::block_size_x + + i * block_size * copy_n_blocks); + } + } else { + matA.reg = matA_in.reg; + } + + if constexpr (stages != 0) { + subgroup::tile_prefetch( + matB_prefetch_payload); + } + cm_fence(CM_SW_BARRIER); + matB_payload.template update_tdesc( + matB_t::tile_size_y); + if constexpr (stages != 0) { + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + cm_fence(CM_SW_BARRIER); + matA_acc_t matA_acc; + matB_acc_t matB_acc; + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::vnni_transform(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + cm_fence(CM_SW_BARRIER); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + cm_fence(CM_SW_BARRIER); + if 
constexpr (enable_periodic_sync) { + if ((i % sync_freq) == 0) { + if constexpr (wg_size_x > 1) { nbarrier_a.wait(); } + if constexpr (wg_size_y > 1) { nbarrier_b.wait(); } + } + } + } + cm_fence(CM_SW_BARRIER); + } + +private: + /// @brief Updates tile base descriptor based on the tid. + inline static void update_sg_tile_tdesc( + arguments_t &args, int32_t sg_idx, int32_t sg_idy) { + int32_t tile_offset_n = sg_idx * sg_tile_n; + int32_t tile_offset_m = sg_idy * sg_tile_m; + + args.matA_base_desc.update_coord_y(tile_offset_m); + args.matB_base_desc.update_coord_x(tile_offset_n); + } +}; + +template +class gemm_kernel_lora_b_t { + using gemm_t = gemm_t_; + using epilogue_t = epilogue_t_; + using gemm_args_t = typename gemm_t::arguments_t; + using epilogue_args_t = typename epilogue_t::arguments_t; + using tile_shape = typename gemm_t::tile_shape; + using group_swizzle_t = group_swizzle_; + + static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y; + static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x; + static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; + static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; + static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; + static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y; + static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x; + + static constexpr uint32_t k_stride = gemm_t::k_stride; + using work_group_t = typename gemm_t::work_group_t; + + static constexpr gpu_arch arch_tag = group_swizzle_t::arch_tag; + static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same"); + static_assert( + arch_tag == epilogue_t::arch_tag, "arch_tag should be the same"); + static_assert(std::is_same::value, + "tile_shape should be the same"); + + using mem_desc_a_t = typename gemm_t::mem_desc_a_t; + using mem_desc_b_t = typename gemm_t::mem_desc_b_t; + using mem_desc_c_t = typename epilogue_t::mem_desc_c_t; + using matA_base_t = typename mem_desc_a_t::base_t; + using matB_base_t = typename mem_desc_b_t::base_t; + using matC_base_t = typename mem_desc_c_t::base_t; + using dtype_a = typename mem_desc_a_t::dtype; + using dtype_b = typename mem_desc_b_t::dtype; + using dtype_c = typename mem_desc_c_t::dtype; + using matAcc_t = typename gemm_t::matAcc_t; + +public: + /// @brief GEMM_UNIVERSAL arguments. + /// This is the interface for users to pass the application-related runtime variables. + struct arguments_t { + /// @brief Is the size of the m dimension of the matrix multiplication (m x k x n). + uint32_t matrix_m; + /// @brief Is the size of the k dimension of the matrix multiplication (m x k x n). + uint32_t matrix_k; + /// @brief Is the size of the n dimension of the matrix multiplication (m x k x n). + uint32_t matrix_n; + /// @brief Is the leading dimension (pitch) size of the matrix A in memory. + uint32_t matA_ld; + /// @brief Is the leading dimension (pitch) size of the matrix B in memory. + uint32_t matB_ld; + /// @brief Is the leading dimension (pitch) size of the matrix C in memory. + uint32_t matC_ld; + /// @brief Is the base address of matrix A. + matA_base_t matA_base; + /// @brief Is the base address of matrix B. + matB_base_t matB_base; + /// @brief Is the base address of matrix C. + matC_base_t matC_base; + /// @brief Is the epilogue arguments. + epilogue_args_t epilogue_args; + + /// @brief Constructs arguments with default method. 
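+        /// Editorial note: unlike gemm_lora_a_t::arguments_t above, this
+        /// argument pack carries no accumulation or counter buffers; the
+        /// kernel performs no k-slicing reduce and, in the
+        /// LORA_TEMP_IN_REG == 1 configuration, consumes the LoRA-A result
+        /// directly from registers via matA_in_t.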
+ inline arguments_t() = default; + + /// @brief Set for device copyable + static constexpr bool host_callable = true; + + /// @brief Constructs arguments with initialization list. + /// @param matrix_m_ Is the size of the m dimension of the matrix multiplication (m x k x n). + /// @param matrix_k_ Is the size of the k dimension of the matrix multiplication (m x k x n). + /// @param matrix_n_ Is the size of the n dimension of the matrix multiplication (m x k x n). + /// @param matA_base_ Is the base address of matrix A. + /// @param matA_ld_ Is the leading dimension (pitch) size of the matrix A in memory. + /// @param matB_base_ Is the base address of matrix B. + /// @param matB_ld_ Is the leading dimension (pitch) size of the matrix B in memory. + /// @param matC_base_ Is the base address of matrix C. + /// @param matC_ld_ Is the leading dimension (pitch) size of the matrix C in memory. + /// @param epilogue_args_ Is the epilogue arguments. + inline arguments_t(uint32_t matrix_m_, uint32_t matrix_k_, + uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_, + matB_base_t matB_base_, uint32_t matB_ld_, + matC_base_t matC_base_, uint32_t matC_ld_, + epilogue_args_t epilogue_args_ = {}) + : matrix_m(matrix_m_) + , matrix_k(matrix_k_) + , matrix_n(matrix_n_) + , matA_base(matA_base_) + , matA_ld(matA_ld_) + , matB_base(matB_base_) + , matB_ld(matB_ld_) + , matC_base(matC_base_) + , matC_ld(matC_ld_) + , epilogue_args(epilogue_args_) {} + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // inline ~arguments_t(){} + inline arguments_t(const arguments_t &args) + : matrix_m(args.matrix_m) + , matrix_k(args.matrix_k) + , matrix_n(args.matrix_n) + , matA_base(args.matA_base) + , matA_ld(args.matA_ld) + , matB_base(args.matB_base) + , matB_ld(args.matB_ld) + , matC_base(args.matC_base) + , matC_ld(args.matC_ld) + , epilogue_args(args.epilogue_args) {} + inline arguments_t &operator=(const arguments_t &args) { + this->matrix_m = args.matrix_m; + this->matrix_k = args.matrix_k; + this->matrix_n = args.matrix_n; + this->matA_base = args.matA_base; + this->matA_ld = args.matA_ld; + this->matB_base = args.matB_base; + this->matB_ld = args.matB_ld; + this->matC_base = args.matC_base; + this->matC_ld = args.matC_ld; + this->epilogue_args = args.epilogue_args; + return *this; + } + }; + + /// @brief Gets named_barrier id consumption count. + /// Users query and get a named_barrier id consumption count in compile time. + /// @return The count of named barriers required. + inline static constexpr uint32_t get_barrier_count() { + constexpr uint32_t count + = gemm_t::barrier_count + epilogue_t::barrier_count; + static_assert( + count <= 32, "The named_barrier count should be less than 32!"); + return count; + } + + /// @brief Gets local memory size consumption. + /// Users query and get a local memory consumption size in compile time. + /// @return The size of local memory required. + inline static constexpr uint32_t get_slm_size() { + constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size; + static_assert(size <= (128 * 1024), + "The local memory size should be less than 128KB!"); + return size; + }; + + /// @brief Main execution function for GEMM_UNIVERSAL. + /// The processing order is 1) set group-level base and boundary -> 2) gemm -> 3) epilogue. + /// @param item Is the sycl::nd_item, returns execution related information, such as workgroup id, subgroup id... 
+ /// @param args Is the GEMM_UNIVERSAL arguments for application-related runtime variables. + /// @param slm_base Is the slm base address. + /// @param nbarrier_base Is the named barrier base. + inline void operator()(sycl::nd_item<3> &item, + const arguments_t &args, matA_in_t &matA, uint32_t iter = 0, + uint32_t slm_base = 0, uint32_t nbarrier_base = 0) const { + // set up workgroup level coordinates and boundaries + group_swizzle_t group_swizzle; + int start_m = group_swizzle.template get_tile_idx<1>(item) * wg_tile_m; + int start_n + = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n_total + + wg_tile_n * iter; + int start_k = 0; + uint32_t wg_tile_k = args.matrix_k; + uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n + ? args.matrix_n + : (start_n + wg_tile_n); + uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m + ? args.matrix_m + : (start_m + wg_tile_m); + uint32_t boundary_k = wg_tile_k; + + uint32_t gemm_slm_base = slm_base; + uint32_t gemm_nbarr_base = nbarrier_base; + uint32_t epilogue_slm_base = gemm_slm_base + gemm_t::slm_size; + uint32_t epilogue_nbarr_base = gemm_nbarr_base + gemm_t::barrier_count; + + // set up arguments + work_group_t g; + g.init(item.get_local_linear_id()); + mem_desc_a_t mem_desc_a; + mem_desc_b_t mem_desc_b; + mem_desc_c_t mem_desc_c; + //setup for matA + if constexpr (mem_desc_a_t::is_local) { + mem_desc_a.init(args.matA_base, + {wg_tile_k, real_wg_tile_m, wg_tile_k}, {0, 0}); + } else { + mem_desc_a.init(args.matA_base, + {boundary_k, boundary_m, args.matA_ld}, {start_k, start_m}); + } + //setup for matB + if constexpr (mem_desc_b_t::is_local) { + mem_desc_b.init(args.matB_base, + {real_wg_tile_n, wg_tile_k, real_wg_tile_n}, {0, 0}); + } else { + mem_desc_b.init(args.matB_base, + {boundary_n, boundary_k, args.matB_ld}, {start_n, start_k}); + } + //setup for matC + if constexpr (mem_desc_c_t::is_local) { + mem_desc_c.init(args.matC_base, + {real_wg_tile_n, real_wg_tile_m, real_wg_tile_n}, {0, 0}); + } else { + mem_desc_c.init(args.matC_base, + {boundary_n, boundary_m, args.matC_ld}, {start_n, start_m}); + } + uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; + gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count); + gemm_t gemm; + epilogue_t epilogue; + + matAcc_t matAcc(0); + gemm(g, matAcc, gemm_args, matA, gemm_slm_base, gemm_nbarr_base); + epilogue(g, matAcc, mem_desc_c, args.epilogue_args, epilogue_slm_base, + epilogue_nbarr_base); + } +}; + +template +struct gemm_lora_b { + using tile_shape = group::tile_shape_t; + + using mem_desc_a + = mem_desc_t; + using mem_desc_b + = mem_desc_t; + using mem_desc_c + = mem_desc_t; + + using compute_attr = typename std::conditional, + compute_attr_t>::type; + + using perf_tuning_knob = perf_tuning_knob_t; + + using compute_policy_0 = + typename std::conditional, + compute_policy_default_xmx>::type; + using compute_policy = typename std::conditional, + compute_policy_0>::type; + using pre_processing = pre_processing_default_t; +#if LORA_TEMP_IN_REG == 1 + using gemm = gemm_lora_b_aligned_t; +#else + using gemm = gemm_t; +#endif + + XETLA_POST_OP_DEFINITIONS + + using tile_op_t = subgroup::chained_tile_op_t; + using epilogue = epilogue_t< + epilogue_policy_tile_op, + tile_shape, mem_desc_c>; + + using epilogue_args_t = typename epilogue::arguments_t; + + using group_swizzle_t = kernel::group_swizzle_default; + +#if LORA_TEMP_IN_REG == 1 + using gemm_op_t = gemm_kernel_lora_b_t; +#else + using gemm_op_t = kernel::gemm_universal_t< + 
kernel::dispatch_policy_kslicing, + gemm, epilogue>; +#endif + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + + inline static void run(sycl::nd_item<3> &item, uint32_t mat_m, + uint32_t mat_k, uint32_t mat_n, dtype_a *a, dtype_b *b, dtype_c *c, + uint32_t iter +#if LORA_TEMP_IN_REG == 1 + , + matA_in_t &matA +#endif + XETLA_POST_OP_ARGS) { + + gemm_op_t gemm_op; + + uint32_t lda = layout_a == mem_layout::col_major ? mat_m : mat_k; + uint32_t ldb = layout_b == mem_layout::col_major ? mat_k : mat_n; + uint32_t ldc = layout_c == mem_layout::col_major ? mat_m : mat_n; + + XETLA_POST_OP_SHAPE_DEFINITIONS + epilogue_args_t epilogue_args; + epilogue_args.init({XETLA_POST_OP_EPILOGUE_INIT_ARGS}); + + typename gemm_op_t::arguments_t arg(mat_m, mat_k, mat_n, a, lda, b, ldb, + c, ldc, +#if LORA_TEMP_IN_REG == 0 + nullptr, nullptr, +#endif + epilogue_args); + + gemm_op(item, arg +#if LORA_TEMP_IN_REG == 1 + + , + matA, iter +#endif + ); + } +}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_fused.h b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_fused.h new file mode 100644 index 00000000000000..24f5dbaaee5205 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_lora_fused.h @@ -0,0 +1,99 @@ +/******************************************************************************* +* Copyright (c) 2022-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "xetla_lora_a.h" +#include "xetla_lora_b.h" + +using namespace gpu::xetla; +using namespace gpu::xetla::group; +using namespace gpu::xetla::kernel; +using namespace gpu::xetla::subgroup; + +template +struct lora_gemm_fused { + static constexpr uint32_t local_range_m = (wg_m + sg_m - 1) / sg_m; + static constexpr uint32_t local_range_nA = (wg_n_A + sg_n_A - 1) / sg_n_A; + static constexpr uint32_t local_range_nB = (wg_n_B + sg_n_B - 1) / sg_n_B; + static constexpr uint32_t local_range_n + = local_range_nA > local_range_nB ? 
local_range_nA : local_range_nB; + static constexpr uint32_t num_threads + = local_range_m * local_range_n * local_kslicing; + + using gemm_lora_a_t = gemm_lora_a; + + using matTemp_t = subgroup::tile_t; + + static constexpr uint32_t gemm_b_n_iters + = (wg_n_B_total + wg_n_B - 1) / wg_n_B; + + using gemm_lora_b_t = gemm_lora_b; + + static constexpr uint32_t barrier_count = gemm_lora_a_t::barrier_count + 1; + static_assert( + gemm_lora_b_t::barrier_count == 0, "barrier_count should be 0"); + static constexpr uint32_t slm_size = gemm_lora_a_t::slm_size; + static_assert(gemm_lora_b_t::slm_size == 0, "slm_size should be 0"); + + inline static void run(sycl::nd_item<3> &item, uint32_t m, uint32_t k, + uint32_t n, uint32_t lora_rank, dtype_a *lora_input, + dtype_b *state_a, dtype_b *state_alpha, dtype_b *state_b, + dtype_c *out, dtype_a *lora_temp XETLA_POST_OP_ARGS) { + + typename gemm_lora_a_t::matAcc_t matAcc; + gemm_lora_a_t::run(item, m, k, lora_rank, lora_input, state_a, + lora_temp, state_alpha, matAcc); + +#if LORA_TEMP_IN_REG == 1 + matTemp_t matTemp; + subgroup::elemwise_cvt(matTemp, matAcc); +#endif + + if (LORA_TEMP_IN_REG != 1) { + xetla_nbarrier_t nbarrier; + nbarrier.init_nbarrier( + barrier_count - 1, nbarrier_role::producer_consumer); + xetla_fence(); + nbarrier.arrive_wait(); + } +#pragma unroll + for (uint32_t i = 0; i < gemm_b_n_iters; i++) { + gemm_lora_b_t::run(item, m, lora_rank, n, lora_temp, state_b, out, i +#if LORA_TEMP_IN_REG == 1 + , + matTemp +#endif + XETLA_POST_OP_ARGS_PASS); + } + } +}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_postop.h b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_postop.h new file mode 100644 index 00000000000000..0989a466ad41c0 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/xetla_postop.h @@ -0,0 +1,94 @@ +/******************************************************************************* +* Copyright (c) 2022-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +using namespace gpu::xetla; +using namespace gpu::xetla::group; +using namespace gpu::xetla::kernel; +using namespace gpu::xetla::subgroup; + +template +struct postop { + static constexpr gpu_arch arch_tag = gpu_arch::Xe; + + using tile_shape = group::tile_shape_t; + + using tile_desc_t + = subgroup::tile_desc_t; + + using tile_in_t = subgroup::tile_t; + using tile_acc_t = subgroup::tile_t; + using tile_out_t = subgroup::tile_t; + + using mem_desc_in_t + = mem_desc_t; + using mem_desc_out_t + = mem_desc_t; + + using payload_input_t = subgroup::mem_payload_t; + + XETLA_POST_OP_DEFINITIONS + + using tile_op_t = subgroup::chained_tile_op_t; + using epilogue = epilogue_t, + tile_shape, mem_desc_out_t>; + using epilogue_args_t = typename epilogue::arguments_t; + + inline static void run(sycl::nd_item<3> &item, uint32_t mat_m, + uint32_t mat_n, dtype_in *in, dtype_out *out XETLA_POST_OP_ARGS) { + uint32_t ldc = mat_n; + int32_t start_m = item.get_group(1) * wg_m; + int32_t start_n = item.get_group(2) * wg_n; + + uint32_t boundary_m + = (start_m + wg_m) > mat_m ? mat_m : (start_m + wg_m); + uint32_t boundary_n + = (start_n + wg_n) > mat_n ? mat_n : (start_n + wg_n); + + typename tile_shape::work_group_t g; + g.init(item.get_local_linear_id()); + + mem_desc_in_t mem_desc_in; + mem_desc_out_t mem_desc_out; + + int sg_start_n = start_n + (g.get_id() % tile_shape::wg_size_x) * sg_n; + int sg_start_m = start_m + (g.get_id() / tile_shape::wg_size_x) * sg_m; + + mem_desc_in.init( + in, {boundary_n, boundary_m, ldc}, {sg_start_n, sg_start_m}); + mem_desc_out.init( + out, {boundary_n, boundary_m, ldc}, {start_n, start_m}); + + tile_in_t mat_in; + payload_input_t input_payload(mem_desc_in); + + subgroup::tile_load( + mat_in, input_payload); + + tile_acc_t mat_acc; + subgroup::elemwise_cvt(mat_acc, mat_in); + + XETLA_POST_OP_SHAPE_DEFINITIONS + epilogue_args_t epilogue_args; + epilogue_args.init({XETLA_POST_OP_EPILOGUE_INIT_ARGS}); + + epilogue epilogue; + epilogue(g, mat_acc, mem_desc_out, epilogue_args, 0, 0); + } +}; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.cpp index cd8902ef04f54e..6e172f8f2b1310 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.cpp @@ -83,4 +83,20 @@ Arguments KernelGenerator::get_arguments_desc(const RuntimeParams& params) const return args; } +void KernelGenerator::add_fused_ops_arguments(Arguments& args, const RuntimeParams& params) { + if (params.has_fused_primitives()) { + size_t num_fused_deps = 0; + for (const auto& fd : params.fused_desc) { + for (const auto& in_d : fd.inputs) { + if (in_d.m_type == cldnn::FusedInputType::EXTERNAL) { + num_fused_deps++; + } + } + } + for (size_t i = 0; i < num_fused_deps; i++) { + args.push_back({ArgumentDescriptor::Types::INPUT_OF_FUSED_PRIMITIVE, static_cast(i)}); + } + } +} + } // namespace ov::intel_gpu::cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.hpp index 89fdf6f77b44a8..ed91374ac16837 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/utils/kernel_generator.hpp @@ -43,6 +43,8 @@ class KernelGenerator : public KernelGeneratorBase { [[nodiscard]] static std::string 
build_code(std::string_view template_name, const JitConstants& jit_constants, const std::string& entry_point); + static void add_fused_ops_arguments(Arguments& args, const RuntimeParams& params); + private: std::string m_kernel_name; std::string m_stage_suffix; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_helpers.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_helpers.hpp new file mode 100644 index 00000000000000..701e07d30f19c8 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_helpers.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "openvino/core/except.hpp" +#include "openvino/core/type.hpp" + +namespace ov::intel_gpu::cm { + +enum class MemLayout { row_major, col_major }; + +inline std::string get_xetla_mem_layout(MemLayout layout) { + switch (layout) { + case MemLayout::row_major: + return "mem_layout::row_major"; + case MemLayout::col_major: + return "mem_layout::col_major"; + default: + OPENVINO_THROW("Unsupported XeTLA memory layout!"); + } +} + +inline std::string ov_to_xetla_dtype(ov::element::Type type) { + switch (type) { + case ov::element::Type_t::f16: + return "fp16"; + case ov::element::Type_t::bf16: + return "bf16"; + case ov::element::Type_t::f32: + return "float"; + default: + OPENVINO_THROW("Unsupported XeTLA data type!"); + } +} +} // namespace ov::intel_gpu::cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_postops.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_postops.cpp new file mode 100644 index 00000000000000..cc14e09b480079 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_postops.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "xetla_postops.hpp" + +#include "graph/common_utils/jitter.hpp" +#include "xetla_helpers.hpp" + +namespace ov::intel_gpu::cm { + +std::vector> XeTLAPostOPs::get_definitions() { + std::string post_op_kernel_args = ""; + std::string post_op_args = ""; + std::string post_op_args_pass = ""; + std::string post_op_definitions = ""; + std::string post_op_list = ""; + std::string post_op_shape_definitions = ""; + std::string post_op_epilogue_init_args = ""; + + bool first_epilogue_arg = true; + bool first_post_op_list = true; + + for (const auto& post_op : postops) { + auto kernel_arg_definition = post_op->get_kernel_arg_definition(); + if (!kernel_arg_definition.empty()) { + post_op_kernel_args += ", " + kernel_arg_definition; + } + + auto arg_definition = post_op->get_arg_definition(); + if (!arg_definition.empty()) { + post_op_args += ", " + arg_definition; + } + + auto arg_definition_pass = post_op->get_arg_name(); + if (!arg_definition_pass.empty()) { + post_op_args_pass += ", " + arg_definition_pass; + } + + post_op_definitions += post_op->get_definition(); + + if (first_post_op_list) { + post_op_list += post_op->get_definition_name(); + first_post_op_list = false; + } else { + post_op_list += ", " + post_op->get_definition_name(); + } + + post_op_shape_definitions += post_op->get_shape_definition(); + + if (first_epilogue_arg) { + post_op_epilogue_init_args += post_op->get_epilogue_init(); + first_epilogue_arg = false; + } else { + post_op_epilogue_init_args += ", " + post_op->get_epilogue_init(); + } + } + std::vector> definitions; + definitions.push_back({"XETLA_POST_OP_KERNEL_ARGS", post_op_kernel_args}); + definitions.push_back({"XETLA_POST_OP_ARGS", 
post_op_args}); + definitions.push_back({"XETLA_POST_OP_ARGS_PASS", post_op_args_pass}); + definitions.push_back({"XETLA_POST_OP_DEFINITIONS", post_op_definitions}); + definitions.push_back({"XETLA_POST_OP_LIST", post_op_list}); + definitions.push_back({"XETLA_POST_OP_SHAPE_DEFINITIONS", post_op_shape_definitions}); + definitions.push_back({"XETLA_POST_OP_EPILOGUE_INIT_ARGS", post_op_epilogue_init_args}); + + return definitions; +} + +size_t XeTLAPostOPs::add_post_ops(const RuntimeParams& params, size_t post_op_arg_index) { + for (const auto& postop : params.fused_desc) { + const bool is_eltwise = cldnn::fused_ops_are_one_of({postop}); + const bool is_activation = cldnn::fused_ops_are_one_of({postop}); + if (is_eltwise) { + auto eltwise = std::static_pointer_cast(postop.desc); + auto eltwise_layout = params.input_layouts[post_op_arg_index++]; + auto eltwise_dtype = ov_to_xetla_dtype(eltwise_layout.data_type); + + bool broadcast = false; + bool is_M_dynamic = eltwise_layout.get_partial_shape()[0].is_dynamic() || eltwise_layout.get_partial_shape()[1].is_dynamic(); + if (!is_M_dynamic) { + const auto eltwise_M = extract_channel(ChannelName::BATCH, eltwise_layout) * extract_channel(ChannelName::FEATURE, eltwise_layout); + broadcast = eltwise_M == 1; + } + assert(eltwise->broadcast_spec.m_axis == 0); + + if (broadcast) { + if (eltwise->mode == cldnn::eltwise_mode::sum) { + postops.push_back(std::make_unique(post_op_index++, eltwise_dtype)); + } else if (eltwise->mode == cldnn::eltwise_mode::prod) { + postops.push_back(std::make_unique(post_op_index++, eltwise_dtype)); + } + } else { + const auto eltwise_op = get_xetla_eltwise_op(eltwise->mode); + assert(eltwise_op != Eltwise::EltwiseOp::none); + postops.push_back(std::make_unique(post_op_index++, eltwise_dtype, eltwise_op)); + } + } else if (is_activation) { + const auto activation = std::static_pointer_cast(postop.desc); + const auto activation_dtype = ov_to_xetla_dtype(ov::element::Type_t::f32); + const auto activation_op = get_xetla_activation_op(activation->activation_function); + + assert(activation_op != Activation::ActivationOp::none); + postops.push_back(std::make_unique(post_op_index++, activation_dtype, activation_op)); + } + } + return post_op_arg_index; +} + +} // namespace ov::intel_gpu::cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_postops.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_postops.hpp new file mode 100644 index 00000000000000..4a7be5528d9572 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/utils/xetla_postops.hpp @@ -0,0 +1,204 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include +#include +#include +#include + +#include "intel_gpu/graph/kernel_impl_params.hpp" +#include "intel_gpu/primitives/activation.hpp" +#include "intel_gpu/primitives/eltwise.hpp" +#include "openvino/core/except.hpp" + +namespace ov::intel_gpu::cm { + +class XeTLAPostOP { +protected: + const size_t index; + const std::string dtype; + +public: + XeTLAPostOP(size_t index, std::string dtype) : index{index}, dtype{dtype} {}; + virtual ~XeTLAPostOP() = default; + virtual std::string get_arg_name() const = 0; + virtual std::string get_arg_definition() const { + return dtype + " *" + get_arg_name(); + } + virtual std::string get_kernel_arg_definition() const { + return get_arg_definition() + " [[type(\"svmptr_t\")]]"; + } + virtual std::string get_definition_name() const = 0; + virtual std::string get_definition() const = 0; + virtual std::string 
get_shape_name() const { + return get_arg_name() + "_shape"; + } + virtual std::string get_shape_definition() const = 0; + virtual std::string get_epilogue_init() const { + return "{" + get_arg_name() + ", " + get_shape_name() + "}"; + } +}; + +class ScaleChannels : public XeTLAPostOP { +public: + ScaleChannels(size_t index, std::string dtype) : XeTLAPostOP(index, dtype) {} + virtual std::string get_arg_name() const { + return "scale_input" + std::to_string(index); + } + virtual std::string get_definition_name() const { + return "scale_op_t" + std::to_string(index); + } + virtual std::string get_definition() const { + return "using " + get_definition_name() + " = subgroup::scale_v_op_t<" + dtype + ", arch_tag>;"; + } + virtual std::string get_shape_definition() const { + return "typename " + get_definition_name() + "::scale_shape_t " + get_shape_name() + "(mat_n, 1, mat_n);"; + } +}; + +class ShiftChannels : public XeTLAPostOP { +public: + ShiftChannels(size_t index, std::string dtype) : XeTLAPostOP(index, dtype) {} + virtual std::string get_arg_name() const { + return "shift_input" + std::to_string(index); + } + virtual std::string get_definition_name() const { + return "shift_op_t" + std::to_string(index); + } + virtual std::string get_definition() const { + return "using " + get_definition_name() + " = subgroup::bias_add_op_t<" + dtype + ", arch_tag>;"; + } + virtual std::string get_shape_definition() const { + return "typename " + get_definition_name() + "::shape_t " + get_shape_name() + "(mat_n, 1, mat_n);"; + } +}; + +class Eltwise : public XeTLAPostOP { +public: + enum class EltwiseOp { none, sum, prod }; + +private: + EltwiseOp op; + +public: + Eltwise(size_t index, std::string dtype, EltwiseOp op) : XeTLAPostOP(index, dtype), op{op} {} + virtual std::string get_arg_name() const { + return "eltwise_input" + std::to_string(index); + } + virtual std::string get_definition_name() const { + return "eltwise_op_t" + std::to_string(index); + } + virtual std::string get_definition() const { + return "using " + get_definition_name() + " = subgroup::elemwise_reduce_op_t;"; + } + virtual std::string get_shape_definition() const { + return "typename " + get_definition_name() + "::shape_t " + get_shape_name() + "(mat_n, mat_m, ldc);"; + } + std::string get_op_name() const { + switch (op) { + case EltwiseOp::sum: + return "sum"; + case EltwiseOp::prod: + return "prod"; + default: + OPENVINO_THROW("Unknown XeTLA EltwiseOp"); + } + } +}; + +class Activation : public XeTLAPostOP { +public: + enum class ActivationOp { none, ReLU, Tanh, Sigmoid, SiLU, GeLU }; + +private: + ActivationOp op; + +public: + Activation(size_t index, std::string dtype, ActivationOp op) : XeTLAPostOP(index, dtype), op{op} {} + virtual std::string get_arg_name() const { + return ""; + } + virtual std::string get_arg_definition() const { + return ""; + } + virtual std::string get_kernel_arg_definition() const { + return ""; + } + virtual std::string get_definition_name() const { + return "activation_op_t" + std::to_string(index); + } + virtual std::string get_definition() const { + return "using " + get_definition_name() + " = subgroup::" + get_op_name() + ";"; + } + virtual std::string get_shape_definition() const { + return ""; + } + virtual std::string get_shape_name() const { + return ""; + } + virtual std::string get_epilogue_init() const { + return "{}"; + } + std::string get_op_name() const { + switch (op) { + case ActivationOp::ReLU: + return "relu_op_t"; + case ActivationOp::Tanh: + return "tanh_op_t"; + case 
ActivationOp::Sigmoid: + return "sigmoid_op_t"; + case ActivationOp::SiLU: + return "silu_precise_op_t"; + case ActivationOp::GeLU: + return "gelu_fwd_op_t"; + default: + OPENVINO_THROW("Unknown XeTLA ActivationOp"); + } + } +}; + +class XeTLAPostOPs { + size_t post_op_index = 0; + std::vector> postops; + +public: + template + void add_post_op(Args&&... args) { + postops.push_back(std::make_unique(post_op_index++, std::forward(args)...)); + } + + size_t add_post_ops(const RuntimeParams& params, size_t post_op_arg_index); + std::vector> get_definitions(); +}; + +std::vector> generate_post_ops(const std::vector>& post_ops); + +inline Activation::ActivationOp get_xetla_activation_op(cldnn::activation_func func) { + switch (func) { + case cldnn::activation_func::relu: + return Activation::ActivationOp::ReLU; + case cldnn::activation_func::tan: + return Activation::ActivationOp::Tanh; + case cldnn::activation_func::swish: + return Activation::ActivationOp::SiLU; + case cldnn::activation_func::gelu: + return Activation::ActivationOp::GeLU; + default: + return Activation::ActivationOp::none; + } +} + +inline Eltwise::EltwiseOp get_xetla_eltwise_op(cldnn::eltwise_mode mode) { + switch (mode) { + case cldnn::eltwise_mode::sum: + return Eltwise::EltwiseOp::sum; + case cldnn::eltwise_mode::prod: + return Eltwise::EltwiseOp::prod; + default: + return Eltwise::EltwiseOp::none; + } +} + +} // namespace ov::intel_gpu::cm diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora.cpp new file mode 100644 index 00000000000000..e11e76b29f3570 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora.cpp @@ -0,0 +1,701 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "xetla_lora.hpp" + +#include "common_utils/kernel_generator_base.hpp" +#include "impls/ocl_v2/utils/jitter.hpp" +#include "primitive_cm_base.hpp" +#include "primitive_inst.h" +#include "registry/implementation_manager.hpp" +#include "utils/kernel_generator.hpp" +#include "utils/xetla_helpers.hpp" +#include "utils/xetla_postops.hpp" + +namespace ov::intel_gpu::cm { +namespace { + +constexpr auto get_lora_build_options() { + return " -Qxcm_jit_option=-DPASTokenReduction " + " -mllvm --vc-disable-indvars-opt=true " + " /Qxcm_jit_option=-enableBCR /Qxcm_doubleGRF " + " -DXETLA_CODE_BASE=__CM__ "; +} + +class XeTLALoraBaseGenerator : public KernelGenerator { +public: + XeTLALoraBaseGenerator(std::string_view name, std::string_view suffix = "") : KernelGenerator(name, suffix) {} + virtual ~XeTLALoraBaseGenerator() = default; + [[nodiscard]] std::string get_build_options(const RuntimeParams& params) const override { + return KernelGenerator::get_build_options(params) + get_lora_build_options(); + } + + struct Layouts { + static constexpr MemLayout mem_layout_a = MemLayout::row_major; + static constexpr MemLayout mem_layout_state_a = MemLayout::col_major; + static constexpr MemLayout mem_layout_state_b = MemLayout::row_major; + static constexpr MemLayout mem_layout_temp = MemLayout::row_major; + static constexpr MemLayout mem_layout_c = MemLayout::row_major; + }; + + static bool is_2dload_aligned(size_t size, ov::element::Type dtype) { + static constexpr size_t min_size = 64 * 8; + static constexpr size_t multiple_of = 16 * 8; + auto dtype_size = ov::element::Type(dtype).bitwidth(); + auto size_in_bits = size * dtype_size; + return size_in_bits >= min_size && size_in_bits % multiple_of == 0; + } + + struct Tiling { + 
const size_t wg_m; + const size_t wg_n; + const size_t sg_m; + const size_t sg_n; + const size_t sg_k; + const size_t num_global_kslicing; + const size_t num_local_kslicing; + }; + +public: + struct LoraShapeUtils { + static std::tuple get_lora_gemm_shape(const RuntimeParams& params) { + return {get_total_tokens(params), get_lora_rank(params), get_hidden_size_input(params), get_hidden_size_output(params)}; + } + static std::tuple get_lora_gemm_shape(const cldnn::primitive_inst& instance) { + return {get_total_tokens(instance), get_lora_rank(instance), get_hidden_size_input(instance), get_hidden_size_output(instance)}; + } + + private: + static size_t get_total_tokens(const cldnn::layout& layout) { + assert(layout.format == cldnn::format::bfyx); + assert(!(layout.get_partial_shape()[0].is_dynamic() || layout.get_partial_shape()[1].is_dynamic())); + return extract_channel(ChannelName::BATCH, layout) * extract_channel(ChannelName::FEATURE, layout); + } + + static size_t get_hidden_size_input(const cldnn::layout& layout) { + assert(layout.format == cldnn::format::bfyx); + assert(!(layout.get_partial_shape()[2].is_dynamic() || layout.get_partial_shape()[3].is_dynamic())); + return extract_channel(ChannelName::Y, layout) * extract_channel(ChannelName::X, layout); + } + + static size_t get_hidden_size_output(const cldnn::layout& layout) { + assert(layout.format == cldnn::format::bfyx); + assert(!(layout.get_partial_shape()[2].is_dynamic() || layout.get_partial_shape()[3].is_dynamic())); + return extract_channel(ChannelName::Y, layout) * extract_channel(ChannelName::X, layout); + } + + static size_t get_lora_rank(const cldnn::layout& layout) { + assert(layout.format == cldnn::format::bfyx); + assert(!(layout.get_partial_shape()[0].is_dynamic())); + return extract_channel(ChannelName::FEATURE, layout); + } + + public: + static size_t get_total_tokens(const RuntimeParams& params) { + return get_total_tokens(params.output_layouts[0]); + } + + static size_t get_hidden_size_input(const RuntimeParams& params) { + return get_hidden_size_input(params.input_layouts[1]); + } + + static size_t get_hidden_size_output(const RuntimeParams& params) { + return get_hidden_size_output(params.output_layouts[0]); + } + + static size_t get_lora_rank(const RuntimeParams& params) { + return get_lora_rank(params.input_layouts[3]); + } + + static size_t get_total_tokens(const cldnn::primitive_inst& instance) { + return get_total_tokens(instance.get_output_layout(0)); + } + + static size_t get_hidden_size_input(const cldnn::primitive_inst& instance) { + return get_hidden_size_input(instance.get_input_layout(1)); + } + + static size_t get_hidden_size_output(const cldnn::primitive_inst& instance) { + return get_hidden_size_output(instance.get_output_layout(0)); + } + + static size_t get_lora_rank(const cldnn::primitive_inst& instance) { + return get_lora_rank(instance.get_input_layout(3)); + } + + static auto get_total_tokens_jit(const RuntimeParams& params) { + ov::intel_gpu::ocl::LayoutJitter jit(params.input_layouts[1], params.in_port_to_shape_info_offset.at(1)); + const auto jit_val = "(" + jit.dim(ChannelName::BATCH) + " * " + jit.dim(ChannelName::FEATURE) + ")"; + return jit_val; + } + static auto get_lora_rank_jit(const RuntimeParams& params) { + ov::intel_gpu::ocl::LayoutJitter jit(params.input_layouts[3], params.in_port_to_shape_info_offset.at(3)); + const auto jit_val = jit.dim(ChannelName::FEATURE); + return jit_val; + } + static auto get_hidden_size_input_jit(const RuntimeParams& params) { + 
ov::intel_gpu::ocl::LayoutJitter jit(params.input_layouts[1], params.in_port_to_shape_info_offset.at(1)); + const auto jit_val = "(" + jit.dim(ChannelName::Y) + " * " + jit.dim(ChannelName::X) + ")"; + return jit_val; + } + static auto get_hidden_size_output_jit(const RuntimeParams& params) { + ov::intel_gpu::ocl::LayoutJitter jit(params.output_layouts[0], params.out_port_to_shape_info_offset.at(0)); + const auto jit_val = "(" + jit.dim(ChannelName::Y) + " * " + jit.dim(ChannelName::X) + ")"; + return jit_val; + } + }; +}; + +class XetlaLoRAFusedGenerator : public XeTLALoraBaseGenerator { + const Tiling tilingA; + const Tiling tilingB; + const size_t total_wg_n_b = 512; + +public: + XetlaLoRAFusedGenerator(Tiling tilingA, Tiling tilingB, std::string_view prefix = "") + : XeTLALoraBaseGenerator("xetla_lora_fused", prefix), + tilingA{tilingA}, + tilingB{tilingB} {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + + const auto mem_layout_a = Layouts::mem_layout_a; + const auto mem_layout_state_a = Layouts::mem_layout_state_a; + const auto mem_layout_state_b = Layouts::mem_layout_state_b; + const auto mem_layout_c = Layouts::mem_layout_c; + + const uint32_t temp_in_reg = 1; + + jit.add({make_jit_constant("KERNEL_NAME", get_entry_point(params)), + make_jit_constant("LORA_DTYPE_A", ov_to_xetla_dtype(params.input_layouts[1].data_type)), + make_jit_constant("LORA_DTYPE_B", ov_to_xetla_dtype(params.input_layouts[2].data_type)), + make_jit_constant("LORA_DTYPE_C", ov_to_xetla_dtype(params.output_layouts[0].data_type)), + make_jit_constant("LORA_DTYPE_ACC", ov_to_xetla_dtype(ov::element::Type_t::f32)), + make_jit_constant("LORA_SIZE_RANK", LoraShapeUtils::get_lora_rank_jit(params)), + make_jit_constant("LORA_WG_M", tilingA.wg_m), + make_jit_constant("LORA_WG_N_A", tilingA.wg_n), + make_jit_constant("LORA_WG_N_B", tilingB.wg_n), + make_jit_constant("LORA_SG_M", tilingA.sg_m), + make_jit_constant("LORA_SG_N_A", tilingA.sg_n), + make_jit_constant("LORA_SG_N_B", tilingB.sg_n), + make_jit_constant("LORA_SG_K_A", tilingA.sg_k), + make_jit_constant("LORA_SG_K_B", tilingB.sg_k), + make_jit_constant("LORA_WG_B_TOTAL", total_wg_n_b), + make_jit_constant("LORA_LOCAL_SLICING", tilingA.num_local_kslicing), + make_jit_constant("LORA_MMA_ENGINE", "mma_engine::xmx"), + make_jit_constant("LORA_MEM_LAYOUT_A", get_xetla_mem_layout(mem_layout_a)), + make_jit_constant("LORA_MEM_LAYOUT_STATE_A", get_xetla_mem_layout(mem_layout_state_a)), + make_jit_constant("LORA_MEM_LAYOUT_STATE_B", get_xetla_mem_layout(mem_layout_state_b)), + make_jit_constant("LORA_MEM_LAYOUT_C", get_xetla_mem_layout(mem_layout_c)), + make_jit_constant("LORA_MEM_SPACE_TEMP", "mem_space::global"), + make_jit_constant("LORA_UNALIGNED", "false"), + make_jit_constant("LORA_TEMP_IN_REG", temp_in_reg), + make_jit_constant("LORA_SIZE_M", LoraShapeUtils::get_total_tokens_jit(params)), + make_jit_constant("LORA_SIZE_K", LoraShapeUtils::get_hidden_size_input_jit(params)), + make_jit_constant("LORA_SIZE_N", LoraShapeUtils::get_hidden_size_output_jit(params))}); + + if (params.is_dynamic()) { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "int *shape_info [[type(\"svmptr_t\")]],")}); + } else { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "")}); + } + + XeTLAPostOPs xetla_postops; + xetla_postops.add_post_op(ov_to_xetla_dtype(params.input_layouts[0].data_type), Eltwise::EltwiseOp::sum); + xetla_postops.add_post_ops(params, 5); + + auto 
post_op_definitions = xetla_postops.get_definitions(); + for (const auto& [name, value] : post_op_definitions) { + jit.add({make_jit_constant(name, value)}); + } + + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + if (params.is_dynamic()) { + args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0}); + } + args.push_back({ArgumentDescriptor::Types::INPUT, 1}); // lora_input + args.push_back({ArgumentDescriptor::Types::INPUT, 2}); // state_a + args.push_back({ArgumentDescriptor::Types::INPUT, 3}); // state_alpha + args.push_back({ArgumentDescriptor::Types::INPUT, 4}); // state_b + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // out + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // temp + args.push_back({ArgumentDescriptor::Types::INPUT, 0}); // main_input + + KernelGenerator::add_fused_ops_arguments(args, params); + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{ + [tilingA = tilingA, tilingB = tilingB, total_wg_n_b = total_wg_n_b](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + size_t local_range_m = (tilingA.wg_m + tilingA.sg_m - 1) / tilingA.sg_m; + size_t local_range_nA = (tilingA.wg_n + tilingA.sg_n - 1) / tilingA.sg_n; + size_t local_range_nB = (tilingB.wg_n + tilingB.sg_n - 1) / tilingB.sg_n; + size_t local_range_n = local_range_nA > local_range_nB ? local_range_nA : local_range_nB; + + size_t group_range_m = (LoraShapeUtils::get_total_tokens(params) + tilingA.wg_m - 1) / tilingA.wg_m; + size_t group_range_n = (LoraShapeUtils::get_hidden_size_output(params) + total_wg_n_b - 1) / total_wg_n_b; + + wgs.global = {group_range_n * local_range_n, group_range_m * local_range_m, tilingA.num_global_kslicing * tilingB.num_global_kslicing}; + wgs.local = {local_range_n, local_range_m, tilingA.num_local_kslicing}; + }}; + } +}; + +class XetlaLoRAGEMMAGenerator : public XeTLALoraBaseGenerator { + const bool is_aligned; + const Tiling tiling; + +public: + XetlaLoRAGEMMAGenerator(bool is_aligned, Tiling tiling, std::string_view prefix = "A") + : XeTLALoraBaseGenerator("xetla_lora_gemmA", prefix), + is_aligned{is_aligned}, + tiling{tiling} {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + + const auto mem_layout_a = Layouts::mem_layout_a; + const auto mem_layout_b = Layouts::mem_layout_state_a; + const auto mem_layout_c = Layouts::mem_layout_c; + + jit.add({make_jit_constant("KERNEL_NAME", get_entry_point(params)), + make_jit_constant("LORA_DTYPE_A", ov_to_xetla_dtype(params.input_layouts[1].data_type)), + make_jit_constant("LORA_DTYPE_B", ov_to_xetla_dtype(params.input_layouts[2].data_type)), + make_jit_constant("LORA_DTYPE_C", ov_to_xetla_dtype(params.output_layouts[0].data_type)), + make_jit_constant("LORA_DTYPE_ACC", ov_to_xetla_dtype(ov::element::Type_t::f32)), + make_jit_constant("LORA_WG_M", tiling.wg_m), + make_jit_constant("LORA_WG_N", tiling.wg_n), + make_jit_constant("LORA_SG_M", tiling.sg_m), + make_jit_constant("LORA_SG_N", tiling.sg_n), + make_jit_constant("LORA_SG_K", tiling.sg_k), + make_jit_constant("LORA_GLOBAL_SLICING", tiling.num_global_kslicing), + make_jit_constant("LORA_LOCAL_SLICING", tiling.num_local_kslicing), + make_jit_constant("LORA_MMA_ENGINE", "mma_engine::xmx"), + 
make_jit_constant("LORA_MEM_LAYOUT_A", get_xetla_mem_layout(mem_layout_a)), + make_jit_constant("LORA_MEM_LAYOUT_B", get_xetla_mem_layout(mem_layout_b)), + make_jit_constant("LORA_MEM_LAYOUT_C", get_xetla_mem_layout(mem_layout_c)), + make_jit_constant("LORA_SIZE_M", LoraShapeUtils::get_total_tokens_jit(params)), + make_jit_constant("LORA_SIZE_K", LoraShapeUtils::get_hidden_size_input_jit(params)), + make_jit_constant("LORA_SIZE_N", LoraShapeUtils::get_lora_rank_jit(params)), + make_jit_constant("LORA_UNALIGNED", !is_aligned)}); + + if (params.is_dynamic()) { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "int *shape_info [[type(\"svmptr_t\")]],")}); + } else { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "")}); + } + + XeTLAPostOPs xetla_postops; + auto post_op_definitions = xetla_postops.get_definitions(); + for (const auto& [name, value] : post_op_definitions) { + jit.add({make_jit_constant(name, value)}); + } + + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + if (params.is_dynamic()) { + args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0}); + } + args.push_back({ArgumentDescriptor::Types::INPUT, 1}); // lora_input + args.push_back({ArgumentDescriptor::Types::INPUT, 2}); // state_a + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // temp + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); // acc + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); // cnt + args.push_back({ArgumentDescriptor::Types::INPUT, 3}); // state_alpha + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{[tiling = tiling](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + size_t group_range_m = (LoraShapeUtils::get_total_tokens(params) + tiling.wg_m - 1u) / tiling.wg_m; + size_t group_range_n = (LoraShapeUtils::get_lora_rank(params) + tiling.wg_n - 1u) / tiling.wg_n; + + size_t local_range_m = (tiling.wg_m + tiling.sg_m - 1u) / tiling.sg_m; + size_t local_range_n = (tiling.wg_n + tiling.sg_n - 1u) / tiling.sg_n; + + wgs.global = {group_range_n * local_range_n, group_range_m * local_range_m, tiling.num_global_kslicing * tiling.num_local_kslicing}; + wgs.local = {local_range_n, local_range_m, tiling.num_local_kslicing}; + }}; + } +}; + +class XetlaLoRAGEMMBGenerator : public XeTLALoraBaseGenerator { + const bool is_aligned; + const Tiling tiling; + +public: + XetlaLoRAGEMMBGenerator(bool is_aligned, Tiling tiling, std::string_view prefix = "B") + : XeTLALoraBaseGenerator("xetla_lora_gemmB", prefix), + is_aligned{is_aligned}, + tiling{tiling} {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + + const auto mem_layout_a = Layouts::mem_layout_a; + const auto mem_layout_b = Layouts::mem_layout_state_b; + const auto mem_layout_c = Layouts::mem_layout_c; + + jit.add({make_jit_constant("KERNEL_NAME", get_entry_point(params)), + make_jit_constant("LORA_DTYPE_A", ov_to_xetla_dtype(params.input_layouts[1].data_type)), + make_jit_constant("LORA_DTYPE_B", ov_to_xetla_dtype(params.input_layouts[4].data_type)), + make_jit_constant("LORA_DTYPE_C", ov_to_xetla_dtype(params.output_layouts[0].data_type)), + make_jit_constant("LORA_DTYPE_ACC", ov_to_xetla_dtype(ov::element::Type_t::f32)), + 
make_jit_constant("LORA_WG_M", tiling.wg_m), + make_jit_constant("LORA_WG_N", tiling.wg_n), + make_jit_constant("LORA_SG_M", tiling.sg_m), + make_jit_constant("LORA_SG_N", tiling.sg_n), + make_jit_constant("LORA_SG_K", tiling.sg_k), + make_jit_constant("LORA_GLOBAL_SLICING", tiling.num_global_kslicing), + make_jit_constant("LORA_LOCAL_SLICING", tiling.num_local_kslicing), + make_jit_constant("LORA_MMA_ENGINE", "mma_engine::xmx"), + make_jit_constant("LORA_MEM_LAYOUT_A", get_xetla_mem_layout(mem_layout_a)), + make_jit_constant("LORA_MEM_LAYOUT_B", get_xetla_mem_layout(mem_layout_b)), + make_jit_constant("LORA_MEM_LAYOUT_C", get_xetla_mem_layout(mem_layout_c)), + make_jit_constant("LORA_SIZE_M", LoraShapeUtils::get_total_tokens_jit(params)), + make_jit_constant("LORA_SIZE_K", LoraShapeUtils::get_lora_rank_jit(params)), + make_jit_constant("LORA_SIZE_N", LoraShapeUtils::get_hidden_size_output_jit(params)), + make_jit_constant("LORA_UNALIGNED", !is_aligned)}); + + if (params.is_dynamic()) { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "int *shape_info [[type(\"svmptr_t\")]],")}); + } else { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "")}); + } + + XeTLAPostOPs xetla_postops; + xetla_postops.add_post_op(ov_to_xetla_dtype(params.input_layouts[0].data_type), Eltwise::EltwiseOp::sum); + xetla_postops.add_post_ops(params, 5); + + auto post_op_definitions = xetla_postops.get_definitions(); + for (const auto& [name, value] : post_op_definitions) { + jit.add({make_jit_constant(name, value)}); + } + + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + if (params.is_dynamic()) { + args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0}); + } + + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); // temp + args.push_back({ArgumentDescriptor::Types::INPUT, 4}); // state_b + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // out + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); // acc + args.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); // cnt + args.push_back({ArgumentDescriptor::Types::INPUT, 0}); // main_input + + KernelGenerator::add_fused_ops_arguments(args, params); + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{[tiling = tiling](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + size_t group_range_m = (LoraShapeUtils::get_total_tokens(params) + tiling.wg_m - 1u) / tiling.wg_m; + size_t group_range_n = (LoraShapeUtils::get_hidden_size_output(params) + tiling.wg_n - 1u) / tiling.wg_n; + + size_t local_range_m = (tiling.wg_m + tiling.sg_m - 1u) / tiling.sg_m; + size_t local_range_n = (tiling.wg_n + tiling.sg_n - 1u) / tiling.sg_n; + + wgs.global = {group_range_n * local_range_n, group_range_m * local_range_m, tiling.num_global_kslicing * tiling.num_local_kslicing}; + wgs.local = {local_range_n, local_range_m, tiling.num_local_kslicing}; + }}; + } +}; + +class XetlaLoraPostopGenerator : public XeTLALoraBaseGenerator { + const Tiling tiling; + +public: + XetlaLoraPostopGenerator(Tiling tiling, std::string_view prefix = "") : XeTLALoraBaseGenerator("xetla_postop", prefix), tiling{tiling} {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + + const auto mem_layout_a = 
Layouts::mem_layout_a; + const auto mem_layout_c = Layouts::mem_layout_c; + + jit.add({make_jit_constant("KERNEL_NAME", get_entry_point(params)), + make_jit_constant("XETLA_DTYPE_IN", ov_to_xetla_dtype(params.input_layouts[1].data_type)), + make_jit_constant("XETLA_DTYPE_OUT", ov_to_xetla_dtype(params.input_layouts[2].data_type)), + make_jit_constant("XETLA_DTYPE_ACC", ov_to_xetla_dtype(ov::element::Type_t::f32)), + make_jit_constant("XETLA_WG_M", tiling.wg_m), + make_jit_constant("XETLA_WG_N", tiling.wg_n), + make_jit_constant("XETLA_SG_M", tiling.sg_m), + make_jit_constant("XETLA_SG_N", tiling.sg_n), + make_jit_constant("XETLA_MEM_LAYOUT_IN", get_xetla_mem_layout(mem_layout_a)), + make_jit_constant("XETLA_MEM_LAYOUT_OUT", get_xetla_mem_layout(mem_layout_c)), + make_jit_constant("XETLA_SIZE_M", LoraShapeUtils::get_total_tokens_jit(params)), + make_jit_constant("XETLA_SIZE_N", LoraShapeUtils::get_hidden_size_output_jit(params))}); + + if (params.is_dynamic()) { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "int *shape_info [[type(\"svmptr_t\")]],")}); + } else { + jit.add({make_jit_constant("XETLA_SHAPE_INFO_ARG", "")}); + } + + XeTLAPostOPs xetla_postops; + xetla_postops.add_post_ops(params, 5); + + auto post_op_definitions = xetla_postops.get_definitions(); + for (const auto& [name, value] : post_op_definitions) { + jit.add({make_jit_constant(name, value)}); + } + + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + if (params.is_dynamic()) { + args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0}); + } + args.push_back({ArgumentDescriptor::Types::INPUT, 0}); // main_input + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // out + + KernelGenerator::add_fused_ops_arguments(args, params); + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{[tiling = tiling](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + assert(!params.is_dynamic()); + auto& wgs = kd.params.workGroups; + + size_t local_range_m = (tiling.wg_m + tiling.sg_m - 1) / tiling.sg_m; + size_t local_range_n = (tiling.wg_n + tiling.sg_n - 1) / tiling.sg_n; + + size_t group_range_m = (LoraShapeUtils::get_total_tokens(params) + tiling.wg_m - 1) / tiling.wg_m; + size_t group_range_n = (LoraShapeUtils::get_hidden_size_output(params) + tiling.wg_n - 1) / tiling.wg_n; + + wgs.global = {group_range_n * local_range_n, group_range_m * local_range_m, 1}; + wgs.local = {local_range_n, local_range_m, 1}; + }}; + } +}; + +class LoRAImpl : public PrimitiveImplCM { +public: + DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::cm::LoRAImpl) + + using lora = XeTLALoraBaseGenerator; + + Stage::Ptr lora_fused_short_r32 = // wg_m, wg_n, sg_m, sg_n, sg_k, global, local + make_stage(lora::Tiling{8, 32, 8, 32, 32, 1, 1}, lora::Tiling{8, 128, 8, 32, 32, 1, 1}, "short_r32"); + Stage::Ptr lora_fused_short_r64 = + make_stage(lora::Tiling{8, 64, 8, 64, 32, 1, 1}, lora::Tiling{8, 128, 8, 32, 32, 1, 1}, "short_r64"); + Stage::Ptr lora_fused_short_r128 = + make_stage(lora::Tiling{8, 128, 8, 128, 32, 1, 1}, lora::Tiling{8, 128, 8, 32, 32, 1, 1}, "short_r128"); + + Stage::Ptr lora_fused_long_r32 = + make_stage(lora::Tiling{32 * 4, 32, 32, 32, 32, 1, 1}, lora::Tiling{32 * 4, 128, 32, 32, 32, 1, 1}, "long_r32"); + Stage::Ptr lora_fused_long_r64 = + make_stage(lora::Tiling{16 * 4, 64, 16, 64, 32, 1, 1}, lora::Tiling{16 * 4, 128, 16, 32, 32, 1, 1}, "long_r64"); + Stage::Ptr 
lora_fused_long_r128 = + make_stage(lora::Tiling{8 * 4, 128, 8, 128, 32, 1, 1}, lora::Tiling{8 * 4, 128, 8, 32, 32, 1, 1}, "long_128"); + + Stage::Ptr lora_gemm_a_short_slicing1 = make_stage(true, lora::Tiling{8, 32, 8, 16, 32, 1, 1}, "a_short_s1"); + Stage::Ptr lora_gemm_a_short_slicing8 = make_stage(true, lora::Tiling{8, 32, 8, 16, 32, 1, 8}, "a_short_s8"); + + Stage::Ptr lora_gemm_a_long0_slicing1 = make_stage(true, lora::Tiling{128, 32, 32, 16, 32, 1, 1}, "a_long0_s1"); + Stage::Ptr lora_gemm_a_long0_slicing2 = make_stage(true, lora::Tiling{128, 32, 32, 16, 32, 1, 2}, "a_long0_s2"); + + Stage::Ptr lora_gemm_b_short = make_stage(true, lora::Tiling{8, 128, 8, 16, 32, 1, 1}, "b_short"); + Stage::Ptr lora_gemm_b_long0 = make_stage(true, lora::Tiling{128, 256, 32, 32, 32, 1, 1}, "b_long0"); + + Stage::Ptr lora_gemm_a_unaligned = make_stage(false, lora::Tiling{8 * 8, 16 * 4, 8, 16, 32, 1, 1}, "a_unaligned"); + Stage::Ptr lora_gemm_b_unaligned = make_stage(false, lora::Tiling{8 * 4, 16 * 8, 8, 16, 32, 1, 1}, "b_unaligned"); + + Stage::Ptr lora_postops = make_stage(lora::Tiling{1, 32 * 32, 1, 32, 1, 1, 1}, "postops"); + + LoRAImpl() : PrimitiveImplOCL(LoRAImplementationManager::get_type_info_static()) {} + LoRAImpl(const program_node& node, const RuntimeParams& params) : LoRAImpl() { + add_stage(lora_fused_short_r32, params); + add_stage(lora_fused_short_r64, params); + add_stage(lora_fused_short_r128, params); + add_stage(lora_fused_long_r32, params); + add_stage(lora_fused_long_r64, params); + add_stage(lora_fused_long_r128, params); + add_stage(lora_gemm_a_short_slicing1, params); + add_stage(lora_gemm_a_short_slicing8, params); + add_stage(lora_gemm_a_long0_slicing1, params); + add_stage(lora_gemm_a_long0_slicing2, params); + add_stage(lora_gemm_b_short, params); + add_stage(lora_gemm_b_long0, params); + add_stage(lora_gemm_a_unaligned, params); + add_stage(lora_gemm_b_unaligned, params); + add_stage(lora_postops, params); + } + + [[nodiscard]] std::unique_ptr clone() const override { + return make_deep_copy(this); + } + + [[nodiscard]] std::vector get_internal_buffer_descs(const RuntimeParams& params) const override { + size_t buf_size = XeTLALoraBaseGenerator::LoraShapeUtils::get_total_tokens(params) * XeTLALoraBaseGenerator::LoraShapeUtils::get_lora_rank(params); + return {BufferDescriptor{buf_size, ov::element::f16}, BufferDescriptor{0, ov::element::f32}, BufferDescriptor{0, ov::element::u32}}; + } + + cldnn::event::ptr execute(const std::vector& events, cldnn::primitive_inst& instance) override { + cldnn::stream& stream = instance.get_network().get_stream(); + if (instance.can_be_optimized()) { + return stream.aggregate_events(events, false, instance.is_output()); + } + + update_rt_params(instance); + + std::vector tmp_events(events); + const auto exec_stages = get_stages_execution_order(instance); + for (const auto& stage_id : exec_stages) { + tmp_events = {execute_stage(tmp_events, instance, *_stages[stage_id])}; + } + + return tmp_events[0]; + } + +private: + enum KernelsTypes { + FUSED_SHORT_R32 = 0, + FUSED_SHORT_R64, + FUSED_SHORT_R128, + FUSED_LONG_R32, + FUSED_LONG_R64, + FUSED_LONG_R128, + GEMM_A_SHORT_S1, + GEMM_A_SHORT_S8, + GEMM_A_LONG0_S1, + GEMM_A_LONG0_S2, + GEMM_B_SHORT, + GEMM_B_LONG0, + GEMM_A_UNALIGNED, + GEMM_B_UNALIGNED, + POSTOPS + }; + + std::vector get_stages_execution_order(const cldnn::primitive_inst& instance) const override { + std::vector stages_order; + using lora = XeTLALoraBaseGenerator; + + bool is_empty_lora = instance.get_input_layout(2).count() == 
0; + if (is_empty_lora) { + stages_order.emplace_back(KernelsTypes::POSTOPS); + return stages_order; + } + + auto [tokens, rank, hidden_in, hidden_out] = lora::LoraShapeUtils::get_lora_gemm_shape(instance); + + const auto ld_input = lora::Layouts::mem_layout_a == MemLayout::col_major ? tokens : hidden_in; + const auto ld_state_a = lora::Layouts::mem_layout_state_a == MemLayout::col_major ? hidden_out : rank; + const auto ld_state_b = lora::Layouts::mem_layout_state_b == MemLayout::col_major ? rank : hidden_in; + const auto ld_state_temp = lora::Layouts::mem_layout_temp == MemLayout::col_major ? tokens : rank; + const auto ld_state_output = lora::Layouts::mem_layout_c == MemLayout::col_major ? tokens : hidden_out; + + const bool is_aligned_input = lora::is_2dload_aligned(ld_input, instance.get_input_layout(1).data_type); + const bool is_aligned_state_a = lora::is_2dload_aligned(ld_state_a, instance.get_input_layout(2).data_type); + const bool is_aligned_state_b = lora::is_2dload_aligned(ld_state_b, instance.get_input_layout(4).data_type); + const bool is_aligned_temp = lora::is_2dload_aligned(ld_state_temp, instance.get_input_layout(1).data_type); + const bool is_aligned_output = lora::is_2dload_aligned(ld_state_output, instance.get_output_layout(0).data_type); + + const bool can_use_fused_reg = rank <= 128 && is_aligned_input && is_aligned_state_a && is_aligned_state_b && is_aligned_output; + const bool is_gemmA_aligned = is_aligned_input && is_aligned_state_a && is_aligned_temp; + const bool is_gemmB_aligned = is_aligned_temp && is_aligned_state_b && is_aligned_output; + + if (tokens <= 32 && is_gemmA_aligned && is_gemmB_aligned) { + size_t iters = (hidden_in + 32 - 1) / 32; + if (iters > 16) { + stages_order.emplace_back(KernelsTypes::GEMM_A_SHORT_S8); + } else { + stages_order.emplace_back(KernelsTypes::GEMM_A_SHORT_S1); + } + stages_order.emplace_back(KernelsTypes::GEMM_B_SHORT); + return stages_order; + } + + if (tokens > 32 && is_gemmA_aligned && is_gemmB_aligned) { + size_t iters = (hidden_in + 32 - 1) / 32; + if (iters > 4) { + stages_order.emplace_back(KernelsTypes::GEMM_A_LONG0_S2); + } else { + stages_order.emplace_back(KernelsTypes::GEMM_A_LONG0_S1); + } + stages_order.emplace_back(KernelsTypes::GEMM_B_LONG0); + return stages_order; + } + + if (can_use_fused_reg) { + if (tokens <= 32) { + KernelsTypes kernel_type = KernelsTypes::FUSED_SHORT_R32; + if (rank <= 128) { + kernel_type = KernelsTypes::FUSED_SHORT_R128; + } + if (rank <= 64) { + kernel_type = KernelsTypes::FUSED_SHORT_R64; + } + if (rank <= 32) { + kernel_type = KernelsTypes::FUSED_SHORT_R32; + } + stages_order.emplace_back(kernel_type); + } else { + KernelsTypes kernel_type = KernelsTypes::FUSED_LONG_R32; + if (rank <= 128) { + kernel_type = KernelsTypes::FUSED_LONG_R128; + } + if (rank <= 64) { + kernel_type = KernelsTypes::FUSED_LONG_R64; + } + if (rank <= 32) { + kernel_type = KernelsTypes::FUSED_LONG_R32; + } + stages_order.emplace_back(kernel_type); + } + return stages_order; + } + + stages_order.emplace_back(KernelsTypes::GEMM_A_UNALIGNED); + stages_order.emplace_back(KernelsTypes::GEMM_B_UNALIGNED); + return stages_order; + } +}; + +} // namespace + +std::unique_ptr LoRAImplementationManager::create_impl(const program_node& node, const RuntimeParams& params) const { + assert(node.is_type()); + return std::make_unique(node, params); +} + +} // namespace ov::intel_gpu::cm + +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::cm::LoRAImpl) diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora.hpp 
b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora.hpp new file mode 100644 index 00000000000000..8d4fc371295e29 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora.hpp @@ -0,0 +1,106 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "common_utils/jitter.hpp" +#include "intel_gpu/runtime/layout.hpp" +#include "lora_inst.h" +#include "registry/implementation_manager.hpp" +#include "utils/xetla_postops.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned + +namespace ov::intel_gpu::cm { + +struct LoRAImplementationManager : public ImplementationManager { + OV_GPU_PRIMITIVE_IMPL("cm::lora") + explicit LoRAImplementationManager(shape_types shape_type, ValidateFunc vf = nullptr) : ImplementationManager(impl_types::cm, shape_type, std::move(vf)) {} + + [[nodiscard]] in_out_fmts_t query_formats(const program_node& node) const override { + assert(node.is_type()); + std::vector in_fmts(node.get_dependencies().size(), format::bfyx); + std::vector out_fmts(node.get_outputs_count(), format::bfyx); + return {in_fmts, out_fmts}; + } + + [[nodiscard]] std::unique_ptr create_impl(const program_node& node, const kernel_impl_params& params) const override; + + [[nodiscard]] bool validate_impl(const program_node& node) const override { + assert(node.is_type()); + + auto& engine = node.get_program().get_engine(); + const auto& config = node.get_program().get_config(); + const auto& info = engine.get_device_info(); + + if (!check_cm_jit_support(engine, config) || info.arch != gpu_arch::xe2 || !config.get_use_cm()) { + return false; + } + + static constexpr std::array supported_fmts = {format::bfyx}; + static constexpr std::array supported_types = {ov::element::f16, ov::element::bf16}; + + for (const auto& input_layout : node.get_input_layouts()) { + if (!one_of(input_layout.format, supported_fmts) || !one_of(input_layout.data_type, supported_types)) { + return false; + } + if (input_layout.data_padding != padding()) { + return false; + } + } + + for (const auto& output_layout : node.get_output_layouts()) { + if (!one_of(output_layout.format, supported_fmts) || !one_of(output_layout.data_type, supported_types)) { + return false; + } + if (output_layout.data_padding != padding()) { + return false; + } + } + + const auto lora_count = ((node.get_inputs_count() - 2ul) / 3ul); + if (lora_count != 1) { + return false; + } + + for (auto& prim : node.get_fused_primitives()) { + const bool is_eltwise = fused_ops_are_one_of({prim}); + const bool is_activation = fused_ops_are_one_of({prim}); + + if (is_activation) { + const auto activation_desc = std::static_pointer_cast(prim.desc); + const auto xetla_activation_func = get_xetla_activation_op(activation_desc->activation_function); + + if (Activation::ActivationOp::none == xetla_activation_func) { + return false; + } + if (!((activation_desc->additional_params.a == 1.0f) && (activation_desc->additional_params.b == 0.0f))) { + return false; + } + if (prim.deps.size() != 0) { + return false; + } + + } else if (is_eltwise) { + const auto eltwise_desc = std::static_pointer_cast(prim.desc); + const auto xetla_eltwise_mode = get_xetla_eltwise_op(eltwise_desc->mode); + if (Eltwise::EltwiseOp::none == xetla_eltwise_mode) { + return false; + } + + if (eltwise_desc->broadcast_spec.m_axis != 0) { + return false; + } + + } else { + return false; + } + } + return true; + } +}; + +} // namespace ov::intel_gpu::cm diff --git 
a/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_fused.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_fused.cm new file mode 100644 index 00000000000000..0b3716fb0dfb4a --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_fused.cm @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "include/batch_headers/cm_xetla.h" + +namespace KERNEL_NAME { + +#include "include/xetla_lora_fused.h" + +#define arch_tag gpu_arch::Xe + +_GENX_MAIN_ void KERNEL_NAME( + XETLA_SHAPE_INFO_ARG + LORA_DTYPE_A *lora_input + [[type("svmptr_t")]], + LORA_DTYPE_B *state_a [[type("svmptr_t")]], + LORA_DTYPE_B *state_alpha [[type("svmptr_t")]], + LORA_DTYPE_B *state_b [[type("svmptr_t")]], + LORA_DTYPE_C *out [[type("svmptr_t")]], + LORA_DTYPE_A *lora_temp [[type("svmptr_t")]] XETLA_POST_OP_KERNEL_ARGS) { + + sycl::nd_item<3> item; + + static constexpr uint32_t periodic_sync_interval_A = 0; + static constexpr uint32_t prefetch_distance_A + = (128 / (LORA_SG_K_A * sizeof(LORA_DTYPE_A))); + + static constexpr uint32_t periodic_sync_interval_B = 0; + static constexpr uint32_t prefetch_distance_B + = (128 / (LORA_SG_K_B * sizeof(LORA_DTYPE_B))); + + static constexpr mem_layout mem_layout_a = LORA_MEM_LAYOUT_A; + static constexpr mem_layout mem_layout_state_A = LORA_MEM_LAYOUT_STATE_A; + static constexpr mem_layout mem_layout_state_B = LORA_MEM_LAYOUT_STATE_B; + static constexpr mem_layout mem_layout_c = LORA_MEM_LAYOUT_C; + + using gemm_t = lora_gemm_fused; + + if constexpr (gemm_t::barrier_count != 0) { + cm_nbarrier_init(gemm_t::barrier_count); + } + if constexpr (gemm_t::slm_size != 0) { cm_slm_init(gemm_t::slm_size); } + + gemm_t::run(item, LORA_SIZE_M, LORA_SIZE_K, LORA_SIZE_N, LORA_SIZE_RANK, + lora_input, state_a, state_alpha, state_b, out, + lora_temp XETLA_POST_OP_ARGS_PASS); +} +} \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_gemmA.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_gemmA.cm new file mode 100644 index 00000000000000..52ecec206b368f --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_gemmA.cm @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "include/batch_headers/cm_xetla.h" + +namespace KERNEL_NAME { + +#define LORA_GEMM_A + +#include "include/xetla_lora.h" + +#define arch_tag gpu_arch::Xe + +_GENX_MAIN_ void KERNEL_NAME( + XETLA_SHAPE_INFO_ARG + LORA_DTYPE_A *mat_a [[type("svmptr_t")]], + LORA_DTYPE_B *mat_b [[type("svmptr_t")]], + LORA_DTYPE_C *out [[type("svmptr_t")]], + LORA_DTYPE_ACC *acc [[type("svmptr_t")]], + uint32_t *cnt [[type("svmptr_t")]], + LORA_DTYPE_B *scale [[type("svmptr_t")]] +) { + + sycl::nd_item<3> item; + + static constexpr uint32_t periodic_sync_interval = 8; + static constexpr uint32_t prefetch_distance = (128 / (LORA_SG_K * sizeof(LORA_DTYPE_A))); + + const uint32_t m = LORA_SIZE_M; + const uint32_t k = LORA_SIZE_K; + const uint32_t n = LORA_SIZE_N; + const uint32_t lda + = LORA_MEM_LAYOUT_A == mem_layout::col_major ? m : k; + const uint32_t ldb + = LORA_MEM_LAYOUT_B == mem_layout::col_major ? k : n; + const uint32_t ldc + = LORA_MEM_LAYOUT_C == mem_layout::col_major ? m : n; + + static constexpr bool unaligned = LORA_UNALIGNED; + + static_assert(!(unaligned && (LORA_MMA_ENGINE == mma_engine::fpu)), + "FPU engine GEMM does not support unaligned access"); + + using gemm_t = gemm_universal; + + if constexpr (gemm_t::barrier_count != 0) { + cm_nbarrier_init(gemm_t::barrier_count); + } + if constexpr (gemm_t::slm_size != 0) { cm_slm_init(gemm_t::slm_size); } + + gemm_t::run(item, mat_a, mat_b, out, acc, cnt, m, n, k, lda, ldb, ldc, scale); +} +} // namespace KERNEL_NAME diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_gemmB.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_gemmB.cm new file mode 100644 index 00000000000000..c3661e9f05fc1d --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_lora_gemmB.cm @@ -0,0 +1,69 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "include/batch_headers/cm_xetla.h" + +namespace KERNEL_NAME { + +#include "include/xetla_lora.h" + +#define arch_tag gpu_arch::Xe + +_GENX_MAIN_ void KERNEL_NAME( + XETLA_SHAPE_INFO_ARG + LORA_DTYPE_A *mat_a [[type("svmptr_t")]], + LORA_DTYPE_B *mat_b [[type("svmptr_t")]], + LORA_DTYPE_C *out [[type("svmptr_t")]], + LORA_DTYPE_ACC *acc [[type("svmptr_t")]], + uint32_t *cnt [[type("svmptr_t")]] + XETLA_POST_OP_KERNEL_ARGS) { + + sycl::nd_item<3> item; + + static constexpr uint32_t periodic_sync_interval = 8; + static constexpr uint32_t prefetch_distance = (128 / (LORA_SG_K * sizeof(LORA_DTYPE_A))); + + const uint32_t m = LORA_SIZE_M; + const uint32_t k = LORA_SIZE_K; + const uint32_t n = LORA_SIZE_N; + const uint32_t lda + = LORA_MEM_LAYOUT_A == mem_layout::col_major ? 
m : k; + const uint32_t ldb + = LORA_MEM_LAYOUT_B == mem_layout::col_major ? k : n; + const uint32_t ldc + = LORA_MEM_LAYOUT_C == mem_layout::col_major ? m : n; + + static constexpr bool unaligned = LORA_UNALIGNED; + + static_assert(!(unaligned && (LORA_MMA_ENGINE == mma_engine::fpu)), + "FPU engine GEMM does not support unaligned access"); + + using gemm_t = gemm_universal; + + if constexpr (gemm_t::barrier_count != 0) { + cm_nbarrier_init(gemm_t::barrier_count); + } + if constexpr (gemm_t::slm_size != 0) { cm_slm_init(gemm_t::slm_size); } + + gemm_t::run(item, mat_a, mat_b, out, acc, cnt, m, n, k, lda, ldb, + ldc XETLA_POST_OP_ARGS_PASS); +} +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/xetla_postop.cm b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_postop.cm new file mode 100644 index 00000000000000..4180a269dc6bd1 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cm/xetla_postop.cm @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2022-2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +namespace KERNEL_NAME { + +#include "xetla_postop.h" + +#define arch_tag gpu_arch::Xe + +_GENX_MAIN_ void KERNEL_NAME(XETLA_SHAPE_INFO_ARG XETLA_DTYPE_IN *in + [[type("svmptr_t")]], + XETLA_DTYPE_OUT *out [[type("svmptr_t")]] XETLA_POST_OP_KERNEL_ARGS) { + + sycl::nd_item<3> item; + + static constexpr mem_layout mem_layout_in = XETLA_MEM_LAYOUT_IN; + static constexpr mem_layout mem_layout_out = XETLA_MEM_LAYOUT_OUT; + + uint32_t mat_m = XETLA_SIZE_M; + uint32_t mat_n = XETLA_SIZE_N; + + postop::run(item, mat_m, mat_n, in, + out XETLA_POST_OP_ARGS_PASS); +} +} // namespace KERNEL_NAME diff --git a/src/plugins/intel_gpu/src/graph/network.cpp b/src/plugins/intel_gpu/src/graph/network.cpp index 6219a839294c63..b0f8c2924810ec 100644 --- a/src/plugins/intel_gpu/src/graph/network.cpp +++ b/src/plugins/intel_gpu/src/graph/network.cpp @@ -967,8 +967,11 @@ void network::allocate_primitive_instance(program_node const& node) { bool transpose_required = false; if (is_lora_state) { const auto& lora_prim = node.get_users().front()->as().get_primitive(); + bool is_cm_user_platfom = get_engine().get_device_info().supports_immad; for (size_t state_idx : {2, 4, 5, 7, 8, 10}) { - if (state_idx < lora_prim->input.size() && + if (is_cm_user_platfom && state_idx == 2) { + transpose_required = !lora_prim->transposed_states; + } else if (state_idx < lora_prim->input.size() && lora_prim->input[state_idx].pid == node.id()) { transpose_required = lora_prim->transposed_states; } diff --git a/src/plugins/intel_gpu/src/graph/registry/lora_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/lora_impls.cpp index 615ef90b5c2446..b8b7f4d01444d7 100644 --- a/src/plugins/intel_gpu/src/graph/registry/lora_impls.cpp +++ b/src/plugins/intel_gpu/src/graph/registry/lora_impls.cpp @@ -10,12 +10,17 @@ #include "impls/ocl_v2/lora.hpp" #endif +#if 
OV_GPU_WITH_CM + #include "impls/cm/xetla_lora.hpp" +#endif + namespace ov::intel_gpu { using namespace cldnn; const std::vector>& Registry::get_implementations() { static const std::vector> impls = { + OV_GPU_CREATE_INSTANCE_CM(cm::LoRAImplementationManager, shape_types::any) OV_GPU_CREATE_INSTANCE_OCL(ocl::Lora, shape_types::any) }; diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 1eccef693eb24d..dd4a21be678df4 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -891,7 +891,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { auto first_dep = add->get_input_node_shared_ptr(0); auto second_dep = add->get_input_node_shared_ptr(1); return !config.get_enable_lora_operation() || - device_info.supports_immad || + (device_info.supports_immad && device_info.arch != cldnn::gpu_arch::xe2) || ov::is_type(first_dep) || ov::is_type(second_dep); }); diff --git a/src/plugins/intel_gpu/tests/unit/fusions/lora_fusion_test.cpp b/src/plugins/intel_gpu/tests/unit/fusions/lora_fusion_test.cpp index 01cd6796487caa..f9283776fa4bab 100644 --- a/src/plugins/intel_gpu/tests/unit/fusions/lora_fusion_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/fusions/lora_fusion_test.cpp @@ -122,6 +122,8 @@ class LoraFusingsTest : public ::BaseFusingTest { #define CASE_LORA_F32_DEFAULT_OPT { 1, 1, 128 }, { 256, 128 }, { 1, 1, 256 }, {{ 16, 128 }, { 1, 16 }, { 256, 16 }}, data_types::f32, format::bfyx, format::oiyx #define CASE_LORA_F32_DEFAULT_REF { 1, 1, 128 }, { 256, 128 }, { 1, 1, 256 }, {{ 15, 128 }, { 1, 15 }, { 256, 15 }}, data_types::f32, format::bfyx, format::oiyx #define CASE_LORA_F32_EMPTY { 1, 1, 128 }, { 256, 128 }, { 1, 1, 256 }, {{ 0, 128 }, { 1, 0 }, { 256, 0 }}, data_types::f32, format::bfyx, format::oiyx +#define CASE_LORA_F16_DEFAULT_OPT { 1, 1, 128 }, { 256, 128 }, { 1, 1, 256 }, {{ 32, 128 }, { 1, 32 }, { 256, 32 }}, data_types::f16, format::bfyx, format::oiyx +#define CASE_LORA_F16_EMPTY { 1, 1, 128 }, { 256, 128 }, { 1, 1, 256 }, {{ 0, 128 }, { 1, 0 }, { 256, 0 }}, data_types::f16, format::bfyx, format::oiyx class lora_act_eltw : public LoraFusingsTest {}; TEST_P(lora_act_eltw, basic) { @@ -144,7 +146,7 @@ TEST_P(lora_act_eltw, basic) { read_value{"rv_b", { input_info("state_b") }, "var_b", { get_lora_state_layout(p, 2) }}, lora("lora", { input_info("fc_prim"), input_info("input"), input_info("rv_a"), input_info("rv_alpha"), input_info("rv_b") }, true), - activation("act", input_info("lora"), activation_func::swish), + activation("act", input_info("lora"), activation_func::swish, { 1.f, 0.f }), data("eltw_data", get_mem(get_per_last_dim_layout(p), 1, 9)), eltwise("eltw", { input_info("act"), input_info("eltw_data") }, eltwise_mode::sum, p.input_type), reorder("reorder_bfyx", input_info("eltw"), p.planar_format, data_types::f32) @@ -157,5 +159,7 @@ TEST_P(lora_act_eltw, basic) { INSTANTIATE_TEST_SUITE_P(fusings_gpu, lora_act_eltw, ::testing::ValuesIn(std::vector{ lora_test_params{ CASE_LORA_F32_DEFAULT_OPT, 6, 11 }, lora_test_params{ CASE_LORA_F32_DEFAULT_REF, 6, 11 }, - lora_test_params{ CASE_LORA_F32_EMPTY, 6, 10 } + lora_test_params{ CASE_LORA_F32_EMPTY, 6, 10 }, + lora_test_params{ CASE_LORA_F16_DEFAULT_OPT, 6, 11 }, + lora_test_params{ CASE_LORA_F16_EMPTY, 6, 10 } })); diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/common/core/cm/math_general.hpp 
b/src/plugins/intel_gpu/thirdparty/xetla/include/common/core/cm/math_general.hpp index 1c0fb305a6be04..423ac8e1f40a8f 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/common/core/cm/math_general.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/common/core/cm/math_general.hpp @@ -258,6 +258,54 @@ __XETLA_API T xetla_exp(T src, Sat sat = {}) { return cm_exp(T(src * log2e), Sat::value); } +/// @brief Calculate exponent value for each element of the input vector, the base is e. +/// @tparam T element type of the input and return vectors. +/// @tparam SZ size of the input and returned vectors. +/// @param src the input vector. +/// @param sat enables/disables the saturation (off by default). Possible +/// values: saturation_on/saturation_off. +/// @return vector of component-wise exponent elements. +template +__XETLA_API xetla_vector xetla_exp_precise( + xetla_vector src, Sat sat = {}) { + static_assert((std::is_same, float>::value) + || (std::is_same, fp16>::value), + "Only support fp32 and fp16"); + constexpr float log2e = 1.44269502162933349609375f; + constexpr float ln2 = 0.693145751953125f; + constexpr float coeff = -0.000001428606765330187045037746429443359375f; + + xetla_vector temp = src * log2e; + xetla_vector rounded = cm_rndz(temp); + temp = src + rounded * -ln2; + temp += rounded * coeff; + temp *= log2e; + return cm_exp(rounded, Sat::value) * cm_exp(temp, Sat::value); +} + +/// @brief Calculate exponent value of a scalar, the base is e. +/// @tparam T element type of the input and return a scalar. +/// @param src the scalar value. +/// @param sat enables/disables the saturation (off by default). Possible +/// values: saturation_on/saturation_off. +/// @return exponent value. +template +__XETLA_API T xetla_exp_precise(T src, Sat sat = {}) { + static_assert((std::is_same, float>::value) + || (std::is_same, fp16>::value), + "Only support fp32 and fp16"); + constexpr float log2e = 1.44269502162933349609375f; + constexpr float ln2 = 0.693145751953125f; + constexpr float coeff = -0.000001428606765330187045037746429443359375f; + + float temp = src * log2e; + float rounded = cm_rndz(temp); + temp = src + rounded * -ln2; + temp += rounded * coeff; + temp *= log2e; + return cm_exp(rounded, Sat::value) * cm_exp(temp, Sat::value); +} + /// @brief Calculate exponent value for each element of the input vector, the base is 2. /// @tparam T element type of the input and return vectors. /// @tparam SZ size of the input and returned vectors. @@ -485,6 +533,41 @@ __XETLA_API T xetla_sigmoid(T src) { return (src <= -10) ? 0 : ret; } +/// @brief Calculate sigmoid value for each element of the input vector. +/// @tparam T element type of the input and return vectors. +/// @tparam SZ size of the input and returned vectors. +/// @param src the input vector. +/// @return vector of sigmoid of component-wise elements. +template +__XETLA_API xetla_vector xetla_sigmoid_precise(xetla_vector src) { + static_assert((std::is_same, float>::value) + || (std::is_same, fp16>::value), + "Only support fp32 and fp16"); + xetla_mask mask_gt = src > 105.f; + xetla_mask mask_lt = src < -105.f; + xetla_vector exp = xetla_exp_precise(-src); + xetla_vector ret_sub = 1.f / (exp + 1.f); + ret_sub.xetla_merge(1, mask_gt); + ret_sub.xetla_merge(0, mask_lt); + return ret_sub; +} + +/// @brief Calculate sigmoid of a scalar. +/// @tparam T element type of the input and return a scalar. +/// @param src the scalar value. +/// @return sigmoid value. 
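// Editorial note (illustration only, not part of this patch): the *_precise
// helpers in this hunk use a Cody-Waite style range reduction, splitting ln(2)
// into a high part (0.693145751953125) and a low part (~1.4286e-6) so the
// reduced argument stays accurate, and the sigmoid variants clamp the result
// once |x| > 105, where exp(-x) would over/underflow. A scalar reference in
// portable C++, assuming std::exp2 plays the role of the base-2 cm_exp
// primitive used above:
#include <cmath>

inline float exp_precise_ref(float x) {
    const float log2e  = 1.44269502162933349609375f;
    const float ln2_hi = 0.693145751953125f;
    const float ln2_lo = 1.428606765330187045037746429443359375e-6f;
    float r = std::trunc(x * log2e);             // coarse base-2 exponent (cm_rndz analogue)
    float t = (x - r * ln2_hi) - r * ln2_lo;     // reduced argument, |t| < ln(2)
    return std::exp2(r) * std::exp2(t * log2e);  // recombine: e^x = 2^r * e^t
}

inline float sigmoid_precise_ref(float x) {
    if (x > 105.f) return 1.f;    // exp(-x) underflows; sigmoid saturates to 1
    if (x < -105.f) return 0.f;   // exp(-x) overflows; sigmoid saturates to 0
    return 1.f / (exp_precise_ref(-x) + 1.f);
}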
+template +__XETLA_API T xetla_sigmoid_precise(T src) { + static_assert((std::is_same, float>::value) + || (std::is_same, fp16>::value), + "Only support fp32 and fp16"); + float exp = xetla_exp_precise(-src); + float ret = 1.f / (exp + 1.f); + float val = src > 105.f ? 1 : ret; + val = src < -105.f ? 0 : val; + return val; +} + /// Add two unsigned integer vectors, return the result and in-place update the carry. /// @tparam T element type of the src, should be uint32_t. /// @tparam SZ element num of the vector. diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/epilogue_policy.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/epilogue_policy.hpp index 88f1e52d79a20d..78d2edac47996a 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/epilogue_policy.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/epilogue_policy.hpp @@ -29,12 +29,13 @@ namespace gpu::xetla::group { /// @brief Epilogue policy for tile_op + store C fusion. /// @tparam tile_op_t_ Is the tile_op functor. /// @tparam arch_tag_ Is the HW architecture. -template +template struct epilogue_policy_tile_op { using tile_op_t = tile_op_t_; static constexpr gpu_arch arch_tag = arch_tag_; + static constexpr msg_type msg_type_c = msg_type_c_; }; - /// @} xetla_epilogue } // namespace gpu::xetla::group diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/impl/tile_op_xe.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/impl/tile_op_xe.hpp index 5376009508ab09..2356dd087e1c9b 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/impl/tile_op_xe.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/group/epilogue/impl/tile_op_xe.hpp @@ -30,9 +30,11 @@ namespace gpu::xetla::group { /// @brief Is the epilogue functor specialized for epilogue_policy_tile_op and Xe architecture. template -class epilogue_t, tile_shape_, - mem_desc_c_t_, std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> { + gpu_arch arch_tag_, msg_type msg_type_c_> +class epilogue_t, + tile_shape_, mem_desc_c_t_, + std::enable_if_t<(arch_tag_ == gpu_arch::Xe) + >> { public: using epilogue_policy = epilogue_policy_tile_op; using tile_op_t = typename epilogue_policy::tile_op_t; @@ -99,7 +101,7 @@ class epilogue_t, tile_shape_, public: static constexpr msg_type msg_type_c - = (mem_space_c == mem_space::global ? msg_type::block_2d + = (mem_space_c == mem_space::global ? msg_type_c_ : msg_type::scatter); /// @brief Default epilogue. diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/compute_policy.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/compute_policy.hpp index 297d6eb90a72fe..de608a93682f0e 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/compute_policy.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/compute_policy.hpp @@ -66,6 +66,41 @@ struct compute_policy_default_xmx +struct compute_policy_unaligned_xmx {}; + +/// @brief Specialized for Xe architecture.
+template +struct compute_policy_unaligned_xmx { + using compute_attr = compute_attr_; + using perf_tuning_knob = perf_tuning_knob_; + static constexpr int k_stride = perf_tuning_knob::k_stride; + static constexpr int stages = perf_tuning_knob::stages; + static constexpr int sync_freq = perf_tuning_knob::sync_freq; + static constexpr gpu_arch arch_tag = gpu_arch::Xe; + using dtype_mma_acc = typename compute_attr::dtype_acc; + using dtype_mma_a = typename compute_attr::dtype_a; + using dtype_mma_b = typename compute_attr::dtype_b; + + static constexpr uint32_t block_bytes_x_a = 32; + static constexpr uint32_t block_size_x_a + = block_bytes_x_a / sizeof(dtype_mma_a); + static constexpr uint32_t block_size_y_a = 16; + + static constexpr uint32_t block_size_x_b = 16; + static constexpr uint32_t block_bytes_y_b = 32; + static constexpr uint32_t block_size_y_b + = block_bytes_y_b / sizeof(dtype_mma_b); + static_assert(block_size_x_a == block_size_y_b, + "mat_a x need to match with mat_b y"); +}; + /// @brief Compute policy for fpu engine. /// @tparam compute_attr_ Is compute-related attributes. /// @tparam perf_tuning_knob_ Is performance-related knobs. diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/gemm.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/gemm.hpp index 10a54116427135..b593b7b808d258 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/gemm.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/gemm.hpp @@ -29,4 +29,5 @@ #include "group/gemm/impl/default_xmx_xe.hpp" #include "group/gemm/impl/pre_processing_xe.hpp" #include "group/gemm/impl/selector_xe.hpp" +#include "group/gemm/impl/unaligned_xmx_xe.hpp" #include "group/tile_shape.hpp" diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/impl/unaligned_xmx_xe.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/impl/unaligned_xmx_xe.hpp new file mode 100644 index 00000000000000..7b02ce08846784 --- /dev/null +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/group/gemm/impl/unaligned_xmx_xe.hpp @@ -0,0 +1,524 @@ +/******************************************************************************* +* Copyright (c) 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// C++ API + +#pragma once + +#include "group/gemm/api.hpp" +#include "group/gemm/compute_policy.hpp" + +namespace gpu::xetla::group { + +/// @addtogroup xetla_gemm +/// @{ + +/// @brief Is the gemm functor for unaligned input, Xe architecture and matrix engine. 
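// Editorial note (illustrative sketch, not part of this patch): the unaligned
// XMX policy above derives its block shapes from a fixed 32-byte budget, so
// for a hypothetical 2-byte element type the A block is 16 (M) x 16 (K) and
// the B block is 16 (K) x 16 (N); the static_assert ties A's K extent to B's.
template <typename E>
struct unaligned_xmx_blocks_ref {
    static constexpr unsigned block_size_x_a = 32 / sizeof(E); // K extent of an A block
    static constexpr unsigned block_size_y_a = 16;             // M extent of an A block
    static constexpr unsigned block_size_x_b = 16;             // N extent of a B block
    static constexpr unsigned block_size_y_b = 32 / sizeof(E); // K extent of a B block
    static_assert(block_size_x_a == block_size_y_b,
            "A's K extent must match B's K extent");
};
static_assert(unaligned_xmx_blocks_ref<unsigned short>::block_size_x_a == 16,
        "2-byte example (e.g. fp16/bf16)");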
+template +class gemm_t, + tile_shape_, // tile shape of workgroup-level gemm + mem_desc_a_t_, // memory attribute of matA + mem_desc_b_t_, // memory attribute of matB + pre_processing_t_, // pre_processing functor + std::enable_if_t<(arch_tag_ == gpu_arch::Xe) + >> { +public: + using mem_desc_a_t = mem_desc_a_t_; + using mem_desc_b_t = mem_desc_b_t_; + using tile_shape = tile_shape_; + using pre_processing_t = pre_processing_t_; + using compute_policy = compute_policy_unaligned_xmx; + + static constexpr uint32_t num_cyclic = 3; + + static constexpr uint32_t k_stride = compute_policy::k_stride; + static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y; + static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x; + static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; + static constexpr uint32_t wg_size_y = tile_shape::wg_size_y; + using work_group_t = typename tile_shape::work_group_t; + + constexpr static gpu_arch arch_tag = compute_policy::arch_tag; + + static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout; + static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; + static constexpr bool is_col_major_a + = mem_layout_a == mem_layout::col_major; + static constexpr bool is_col_major_b + = mem_layout_b == mem_layout::col_major; + +private: + /******** set data type **********/ + using dtype_a = typename mem_desc_a_t::dtype; + using dtype_b = typename mem_desc_b_t::dtype; + using dtype_mma_acc = typename compute_policy::dtype_mma_acc; + using dtype_mma_a = typename compute_policy::dtype_mma_a; + using dtype_mma_b = typename compute_policy::dtype_mma_b; + + using check_dtype + = group::gemm::default_xmx::check_dtype_default< + dtype_a, dtype_b, dtype_mma_a, dtype_mma_b>; + + /******** set memory attribute **********/ + static constexpr mem_space mem_space_a = mem_desc_a_t::space; + static constexpr mem_space mem_space_b = mem_desc_b_t::space; + + static constexpr bool is_local_a = mem_space_a == mem_space::local; + static constexpr bool is_local_b = mem_space_b == mem_space::local; + static constexpr tdesc_update_dir update_dir_a = is_col_major_a + ? tdesc_update_dir::y_dir + : tdesc_update_dir::x_dir; + static constexpr tdesc_update_dir update_dir_b = is_col_major_b + ? tdesc_update_dir::x_dir + : tdesc_update_dir::y_dir; + + using check_memory + = group::gemm::default_xmx::check_memory_default< + mem_layout_a, mem_layout_b, mem_space_a, mem_space_b>; + + static constexpr uint32_t stages = compute_policy::stages; + static constexpr uint32_t sync_freq = compute_policy::sync_freq; + + /******** set tile layout && worker scope **********/ + static constexpr uint32_t tile_size_x_a = k_stride; + static constexpr uint32_t tile_size_y_a = sg_tile_m; + static constexpr uint32_t tile_size_x_b = sg_tile_n; + static constexpr uint32_t tile_size_y_b = k_stride; + static constexpr uint32_t tile_size_x_c = sg_tile_n; + static constexpr uint32_t tile_size_y_c = sg_tile_m; + static constexpr uint32_t block_size_x_a = compute_policy::block_size_x_a; + static constexpr uint32_t block_size_y_a + = (compute_policy::block_size_y_a > tile_size_y_a) + ? 
tile_size_y_a + : compute_policy::block_size_y_a; + static constexpr uint32_t block_size_x_b = compute_policy::block_size_x_b; + static constexpr uint32_t block_size_y_b = compute_policy::block_size_y_b; + + using check_tile_size = group::gemm< + gpu_arch::Xe>::default_xmx::check_tile_size_default; + + /******** set tile **********/ + static constexpr reg_layout reg_layout_a = reg_layout::tiled; + using matA_tile_desc_t = subgroup::tile_desc_t; + + using matA_t = subgroup::tile_t; + + using cooperative_helper_A_t = subgroup::cooperative_load_helper_t; + using cooperative_tile_desc_A_t = + typename cooperative_helper_A_t::co_tile_desc_t; + using partial_matA_t = subgroup::tile_t; + using matA_payload_t = subgroup::mem_payload_t; + + using matA_payload_local_st_t = subgroup::mem_payload_t< + mem_desc_t, + cooperative_tile_desc_A_t, msg_type::scatter, arch_tag>; + using matA_payload_local_ld_t = subgroup::mem_payload_t< + mem_desc_t, + matA_tile_desc_t, msg_type::scatter, arch_tag>; + + using matA_acc_t = subgroup::tile_t; + using matA_prefetch_payload_t = subgroup::prefetch_payload_t, + wg_size_x, arch_tag>; + static constexpr reg_layout reg_layout_b + = sizeof(dtype_b) < sizeof(uint32_t) ? reg_layout::vnni_tiled + : reg_layout::tiled; + using matB_tile_desc_t = subgroup::tile_desc_t; + using matB_t = subgroup::tile_t; + + using cooperative_helper_B_t = subgroup::cooperative_load_helper_t; + using cooperative_tile_desc_B_t = + typename cooperative_helper_B_t::co_tile_desc_t; + + using partial_matB_t = subgroup::tile_t; + + using matB_payload_t = subgroup::mem_payload_t; + + using matB_payload_local_st_t = subgroup::mem_payload_t< + mem_desc_t, + cooperative_tile_desc_B_t, msg_type::scatter, arch_tag>; + using matB_payload_local_ld_t = subgroup::mem_payload_t< + mem_desc_t, + matB_tile_desc_t, msg_type::scatter, arch_tag>; + + using matB_acc_t = subgroup::tile_t; + using matB_prefetch_payload_t = subgroup::prefetch_payload_t, + wg_size_y, arch_tag>; + +public: + using matAcc_tile_desc_t = subgroup::tile_desc_t; + using matAcc_t = subgroup::tile_t; + +private: + using tile_mma = subgroup::tile_mma_t; + // static constexpr bool enable_periodic_sync = (sync_freq != 0); + static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; + static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + static constexpr uint32_t tile_size_a + = tile_size_x_a * tile_size_y_a * sizeof(dtype_a); + static constexpr uint32_t tile_size_b + = tile_size_x_b * tile_size_y_b * sizeof(dtype_b); + static constexpr uint32_t slm_size_a = wg_size_y * tile_size_a; + static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b; + +public: + static constexpr uint32_t barrier_count = barrier_count_x + barrier_count_y; + + static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic; + static constexpr uint32_t slm_base_a = 0; + static constexpr uint32_t slm_base_b = 0 + slm_size_a * num_cyclic; + + static constexpr msg_type msg_type_a = matA_payload_t::message_type; + static constexpr msg_type msg_type_b = matB_payload_t::message_type; + + using pre_processing_arg_t = typename pre_processing_t::arguments_t; + + /// @brief Arguments for gemm. + /// User should prepare matA_base_desc, matB_base_desc, inner_loop_count... + struct arguments_t { + /// @brief Is the memory description of matA, including base, shape and coordinate. + mem_desc_a_t matA_base_desc; + /// @brief Is the memory description of matB, including base, shape and coordinate. 
+ mem_desc_b_t matB_base_desc; + /// @brief Is the total inner loop count required to compute the entire K-dim. + uint32_t inner_loop_count; + /// @brief Is the arguments for pre-processing functor. + pre_processing_arg_t pre_processing_args; + + /// @brief Default construct. + inline arguments_t() = default; + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // ~arguments_t(){} + + /// @brief Constructs a new arguments t object. + /// @param matA_desc Is the memory description of matA, including base, shape and coordinate. + /// @param matB_desc Is the memory description of matB, including base, shape and coordinate. + /// @param loop_count Is the total inner loop count required to compute the entire K-dim. + /// @param args Is the arguments for pre-processing functor. + inline arguments_t(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_count, pre_processing_arg_t args = {}) + : matA_base_desc(matA_desc) + , matB_base_desc(matB_desc) + , inner_loop_count(loop_count) + , pre_processing_args(args) {} + // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor) + // Please check if you need to add self-define destructor + // inline ~arguments_t(){} + inline arguments_t(const arguments_t &args) + : matA_base_desc(args.matA_base_desc) + , matB_base_desc(args.matB_base_desc) + , inner_loop_count(args.inner_loop_count) + , pre_processing_args(args.pre_processing_args) {} + inline arguments_t &operator=(const arguments_t &args) { + this->matA_base_desc = args.matA_base_desc; + this->matB_base_desc = args.matB_base_desc; + this->inner_loop_count = args.inner_loop_count; + this->pre_processing_args = args.pre_processing_args; + return *this; + } + + /// @brief Explicit initialization function. + /// @param matA_desc Is the memory description of matA, including base, shape and coordinate. + /// @param matB_desc Is the memory description of matB, including base, shape and coordinate. + /// @param loop_count Is the total inner loop count required to compute the entire K-dim. + /// @param args Is the arguments for pre-processing functor. + inline void init(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_count, pre_processing_arg_t args = {}) { + matA_base_desc = matA_desc; + matB_base_desc = matB_desc; + inner_loop_count = loop_count; + pre_processing_args = args; + } + }; + + /// @brief Gets the subgroup-level tile offset x. + /// @param g Is the workgroup of the current tile. + /// @return Subgroup-level tile offset x. + __XETLA_API static int get_matC_offset_x(work_group_t &g) { + int32_t sg_idx = g.get_id() % wg_size_x; + return sg_idx * sg_tile_n; + } + + /// @brief Gets the subgroup-level tile offset y. + /// @param g Is the workgroup of the current tile. + /// @return Subgroup-level tile offset y. + __XETLA_API static int get_matC_offset_y(work_group_t &g) { + int32_t sg_idy = g.get_id() / wg_size_x; + return sg_idy * sg_tile_m; + } + + XETLA_MARKER( + "This release function will wait until all the r/w and nbarrier " + "id used in this gemm have been committed. By default, it will " + "use barrier_id 0 to do the entire workgroup sync if wg_size > 1. 
" + "If you call this function, please set a free barrier id or make " + "sure barrier_id 0 is not being occupied and you need to allocate " + "one more barrier count in addition to the gemm barrier counts.") + __XETLA_API static void release(uint8_t nbarrier_id = 0) { + static constexpr bool need_local_fence + = (mem_space_a == mem_space::local) + || (mem_space_b == mem_space::local); + if constexpr (need_local_fence) { + xetla_fence(); + } + xetla_fence(); + static constexpr uint32_t wg_size = wg_size_x * wg_size_y; + if constexpr (wg_size > 1) { + xetla_nbarrier_t nbarrier; + nbarrier.init_nbarrier( + nbarrier_id, nbarrier_role::producer_consumer); + nbarrier.arrive_wait(); + } + } + + /// @brief Main execution function for gemm. + /// The basic process is load data -> matrix multiply. + /// @param g Is the workgroup of the current tile. + /// @param matAcc Is the reference of the accumulation buffer. + /// @param args Is the gemm::arguments_t. + /// @param slm_base Is the slm base address. + /// @param nbarrier_base Is the named barrier base. + __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, + arguments_t args, uint32_t slm_base = 0, + uint32_t nbarrier_base = 0) { + int32_t sg_idx = g.get_id() % wg_size_x; + int32_t sg_idy = g.get_id() / wg_size_x; + + XETLA_ASSERT(g.get_id() < (wg_size_x * wg_size_y), + "Thread id(%d) should less than wg_size(%d)", g.get_id(), + wg_size_x * wg_size_y); + + update_sg_tile_tdesc(args, sg_idx, sg_idy); + pre_processing_t pre_processing; + matA_t matA; + matB_t matB; + partial_matA_t partial_matA; + partial_matB_t partial_matB; + // >>>>>>>>>>>>>>>>>> pre_processing init + pre_processing.init(g, args.pre_processing_args); + uint32_t base_A = slm_base_a + sg_idy * tile_size_a; + uint32_t base_B = slm_base_b + sg_idx * tile_size_b; + + uint32_t store_idx = 0; + uint32_t load_idx = 0; + + matA_payload_t matA_payload(args.matA_base_desc); + matA_payload_local_st_t matA_local_st_payload(base_A, tile_size_x_a, + tile_size_y_a, tile_size_x_a, + cooperative_helper_A_t::get_offset_x(sg_idx), + cooperative_helper_A_t::get_offset_y(sg_idx)); + matA_payload_local_ld_t matA_local_ld_payload( + base_A, tile_size_x_a, tile_size_y_a, tile_size_x_a, 0, 0); + + matB_payload_t matB_payload(args.matB_base_desc); + matB_payload_local_st_t matB_local_st_payload(base_B, tile_size_x_b, + tile_size_y_b, tile_size_x_b, + cooperative_helper_B_t::get_offset_x(sg_idy), + cooperative_helper_B_t::get_offset_y(sg_idy)); + matB_payload_local_ld_t matB_local_ld_payload( + base_B, tile_size_x_b, tile_size_y_b, tile_size_x_b, 0, 0); + + matA_prefetch_payload_t matA_prefetch_payload( + args.matA_base_desc, sg_idx); + matB_prefetch_payload_t matB_prefetch_payload( + args.matB_base_desc, sg_idy); + + xetla_nbarrier_t nbarrier_a; + nbarrier_a.init_nbarrier( + sg_idy + nbarrier_base, nbarrier_role::producer_consumer); + xetla_nbarrier_t nbarrier_b; + nbarrier_b.init_nbarrier(sg_idx + barrier_count_y + nbarrier_base, + nbarrier_role::producer_consumer); + + tile_load(partial_matA, matA_payload); + tile_load(partial_matB, matB_payload); + + tile_store(partial_matA, matA_local_st_payload); + tile_store(partial_matB, matB_local_st_payload); + store_idx++; + + matA_payload.template update_tdesc(matA_t::tile_size_x); + matB_payload.template update_tdesc(matB_t::tile_size_y); + xetla_fence(); + nbarrier_a.arrive(); + nbarrier_b.arrive(); +#pragma unroll + for (int i = 1; i < num_cyclic - 1; i++) { + tile_load(partial_matA, matA_payload); + tile_load(partial_matB, 
matB_payload); + + matA_payload.template update_tdesc( + matA_t::tile_size_x); + matB_payload.template update_tdesc( + matB_t::tile_size_y); + + matA_local_st_payload + .template update_tdesc( + wg_size_y * matA_t::tile_size_y); + matB_local_st_payload + .template update_tdesc( + wg_size_x * matB_t::tile_size_y); + + tile_store(partial_matA, matA_local_st_payload); + tile_store(partial_matB, matB_local_st_payload); + store_idx++; + } + + matA_prefetch_payload.template update_tdesc( + matA_t::tile_size_x * (num_cyclic - 1)); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y * (num_cyclic - 1)); +#pragma unroll + for (int i = 0; i < stages; i++) { + subgroup::tile_prefetch( + matA_prefetch_payload); + subgroup::tile_prefetch( + matB_prefetch_payload); + matA_prefetch_payload.template update_tdesc( + matA_t::tile_size_x); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + + for (int i = 0; i < args.inner_loop_count; i++) { + tile_load(partial_matA, matA_payload); + tile_load(partial_matB, matB_payload); + + matA_payload.template update_tdesc( + matA_t::tile_size_x); + matB_payload.template update_tdesc( + matB_t::tile_size_y); + + if constexpr (stages != 0) { + subgroup::tile_prefetch( + matA_prefetch_payload); + subgroup::tile_prefetch( + matB_prefetch_payload); + } + + nbarrier_a.wait(); + nbarrier_b.wait(); + + tile_load(matA, matA_local_ld_payload); + tile_load(matB, matB_local_ld_payload); + + load_idx = (load_idx < num_cyclic - 1) ? (load_idx + 1) : 0; + + if (load_idx != 0) { + matA_local_ld_payload + .template update_tdesc( + wg_size_y * matA_t::tile_size_y); + matB_local_ld_payload + .template update_tdesc( + wg_size_x * matB_t::tile_size_y); + } else { + matA_local_ld_payload + .template update_tdesc( + (1 - num_cyclic) * wg_size_y + * matA_t::tile_size_y); + matB_local_ld_payload + .template update_tdesc( + (1 - num_cyclic) * wg_size_x + * matB_t::tile_size_y); + } + xetla_fence(); + + if constexpr (stages != 0) { + matA_prefetch_payload.template update_tdesc( + matA_t::tile_size_x); + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + } + + nbarrier_a.arrive(); + nbarrier_b.arrive(); + SW_BARRIER(); + matA_acc_t matA_acc; + matB_acc_t matB_acc; + subgroup::elemwise_cvt(matA_acc, matA); + subgroup::vnni_transform(matB_acc, matB); + pre_processing(matA_acc, matB_acc, matA, matB); + SW_BARRIER(); + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + SW_BARRIER(); + + if (store_idx != 0) { + matA_local_st_payload + .template update_tdesc( + wg_size_y * matA_t::tile_size_y); + matB_local_st_payload + .template update_tdesc( + wg_size_x * matB_t::tile_size_y); + } else { + matA_local_st_payload + .template update_tdesc( + (1 - num_cyclic) * wg_size_y + * matA_t::tile_size_y); + matB_local_st_payload + .template update_tdesc( + (1 - num_cyclic) * wg_size_x + * matB_t::tile_size_y); + } + + tile_store(partial_matA, matA_local_st_payload); + tile_store(partial_matB, matB_local_st_payload); + store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0; + } + SW_BARRIER(); + nbarrier_a.wait(); + nbarrier_b.wait(); + } + +private: + /// @brief Updates tile base descriptor based on the tid. 
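// Editorial note (illustration only, not part of this patch): the local-memory
// payloads in the loop above walk a ring of num_cyclic SLM slots (num_cyclic
// is 3 here). After each (idx + 1) % num_cyclic style update, the base
// descriptor either advances by one slot or, when the index wraps to 0, jumps
// back over the slots it advanced. The per-iteration offset delta reduces to:
inline int next_slot_offset_delta(unsigned idx_after_update, unsigned num_cyclic,
                                  int slot_stride) {
    return (idx_after_update != 0)
            ? slot_stride                                        // move on to the next slot
            : (1 - static_cast<int>(num_cyclic)) * slot_stride;  // wrap back to slot 0
}
// With num_cyclic == 3 the deltas repeat as {+s, +s, -2s}, visiting slots 0, 1, 2, 0, ...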
+ __XETLA_API static void update_sg_tile_tdesc( + arguments_t &args, int32_t sg_idx, int32_t sg_idy) { + int32_t tile_offset_n = sg_idx * sg_tile_n; + int32_t tile_offset_m = sg_idy * sg_tile_m; + + args.matA_base_desc.update_coord_y( + tile_offset_m + cooperative_helper_A_t::get_offset_y(sg_idx)); + args.matA_base_desc.update_coord_x( + cooperative_helper_A_t::get_offset_x(sg_idx)); + args.matB_base_desc.update_coord_x( + tile_offset_n + cooperative_helper_B_t::get_offset_x(sg_idy)); + args.matB_base_desc.update_coord_y( + cooperative_helper_B_t::get_offset_y(sg_idy)); + } +}; + +/// @} xetla_gemm + +} // namespace gpu::xetla::group diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/cooperative_load_helper.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/cooperative_load_helper.hpp new file mode 100644 index 00000000000000..00dc3dc7a5a3c9 --- /dev/null +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/cooperative_load_helper.hpp @@ -0,0 +1,152 @@ +/******************************************************************************* +* Copyright (c) 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @file +/// C++ API + +#pragma once + +#include "subgroup/tile/tile.hpp" + +namespace gpu::xetla::subgroup { + +/// @brief Helper to do the cooperative workgroups load. +/// @tparam matAcc_t Is the input mat type. +/// @tparam tile_shape Is the group-level tile shape. +/// @tparam mem_layout Is the memory layout of input. +/// @tparam num_cooperative_wg Is the number of workgroups to do the cooperation. +/// @tparam arch_tag Is the HW architecture. +template +class cooperative_load_helper_t {}; + +/// @brief Workgroups to do the cooperative load. Specialized for and row_major and Xe architecture. 
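// Editorial note (illustration only, not part of this patch): the row-major
// cooperative load helper that follows splits one shared tile among
// num_cooperative_wg subgroups, laying coop ids out row-major over a
// coop_remain_num_x by coop_num_y grid. The offset mapping reduces to:
inline void coop_offsets_ref(unsigned coop_id, unsigned coop_remain_num_x,
                             unsigned tile_size_x, unsigned tile_size_y,
                             unsigned &offset_x, unsigned &offset_y) {
    offset_x = coop_id % coop_remain_num_x * tile_size_x; // column of the sub-tile
    offset_y = coop_id / coop_remain_num_x * tile_size_y; // row of the sub-tile
}
// e.g. a hypothetical 8-way split with coop_remain_num_x == 4 puts coop_id 5 at
// grid cell (x = 1, y = 1), i.e. offsets (tile_size_x, tile_size_y).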
+template +class cooperative_load_helper_t> { +public: + static constexpr gpu_arch arch_tag = arch_tag_; + using matAcc_t = matAcc_t_; + using dtype = typename matAcc_t::dtype; + using tile_desc_t = typename matAcc_t::tile_desc; + static constexpr mem_layout layout = mem_layout::row_major; + +private: + // cooperative split, y dir first + static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0, + "num_cooperative_wg should be power of 2"); + +public: + static constexpr uint32_t src_block_size_x = tile_desc_t::block_size_x; + static constexpr uint32_t src_block_size_y = tile_desc_t::block_size_y; + static constexpr uint32_t src_tile_size_x = tile_desc_t::tile_size_x; + static constexpr uint32_t src_tile_size_y = tile_desc_t::tile_size_y; + + static constexpr uint32_t coop_num_y + = gpu::xetla::subgroup::detail::gcd::value; + static constexpr uint32_t coop_remain_num_x + = num_cooperative_wg / coop_num_y; + static constexpr uint32_t tile_size_y = src_tile_size_y / coop_num_y; + static constexpr uint32_t tile_size_x = src_tile_size_x / coop_remain_num_x; + static constexpr uint32_t coop_num_x = src_tile_size_x / tile_size_x; + + static_assert((tile_size_y * tile_size_x % 16) == 0, + "cooperative tile size should be a multiply of simd-16 "); + +public: + static constexpr uint32_t block_size_x + = gpu::xetla::subgroup::detail::gcd::value; + static constexpr uint32_t block_size_y + = (tile_size_y > src_block_size_y) ? src_block_size_y : tile_size_y; + + using co_tile_desc_t = subgroup::tile_desc_t; + +public: + inline cooperative_load_helper_t() = default; + + inline static int32_t get_offset_x(uint32_t coop_id) { + return coop_id % coop_remain_num_x * tile_size_x; + } + + inline static int32_t get_offset_y(uint32_t coop_id) { + return coop_id / coop_remain_num_x * tile_size_y; + } +}; + +/// @brief Workgroups to do the cooperative load. Specialized for and row_major and Xe architecture. +template +class cooperative_load_helper_t> { +public: + static constexpr gpu_arch arch_tag = arch_tag_; + using matAcc_t = matAcc_t_; + using dtype = typename matAcc_t::dtype; + using tile_desc_t = typename matAcc_t::tile_desc; + static constexpr mem_layout layout = mem_layout::col_major; + +private: + // cooperative split, y dir first + static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0, + "num_cooperative_wg should be power of 2"); + +public: + static constexpr uint32_t src_block_size_x = tile_desc_t::block_size_x; + static constexpr uint32_t src_block_size_y = tile_desc_t::block_size_y; + static constexpr uint32_t src_tile_size_x = tile_desc_t::tile_size_x; + static constexpr uint32_t src_tile_size_y = tile_desc_t::tile_size_y; + + static constexpr uint32_t coop_num_x + = gpu::xetla::subgroup::detail::gcd::value; + static constexpr uint32_t coop_remain_num_y + = num_cooperative_wg / coop_num_x; + static constexpr uint32_t tile_size_x = src_tile_size_x / coop_num_x; + static constexpr uint32_t tile_size_y = src_tile_size_y / coop_remain_num_y; + static constexpr uint32_t coop_num_y = src_tile_size_y / tile_size_y; + + static_assert((tile_size_y * tile_size_x % 16) == 0, + "cooperative tile size should be a multiply of simd-16 "); + +public: + static constexpr uint32_t block_size_y + = gpu::xetla::subgroup::detail::gcd::value; + static constexpr uint32_t block_size_x + = (tile_size_x > src_block_size_x) ? 
src_block_size_x : tile_size_x; + + using co_tile_desc_t = subgroup::tile_desc_t; + +public: + inline cooperative_load_helper_t() = default; + + inline static int32_t get_offset_x(uint32_t coop_id) { + return coop_id / coop_remain_num_y * tile_size_x; + } + + inline static int32_t get_offset_y(uint32_t coop_id) { + return coop_id % coop_remain_num_y * tile_size_y; + } +}; + +} // namespace gpu::xetla::subgroup diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/subgroup.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/subgroup.hpp index 240aa7b90a7d4a..4459b1fe8be413 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/subgroup.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/subgroup.hpp @@ -19,4 +19,6 @@ #pragma once +#include "subgroup/cooperative_load_helper.hpp" #include "subgroup/tile/tile.hpp" + diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/load_xe.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/load_xe.hpp index 16a0fb62ab09ee..94bf3d876dedfb 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/load_xe.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/load_xe.hpp @@ -724,8 +724,8 @@ tile_load(tile_t &tile, payload_t &payload, oob_check_tag tag = {}) { payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset, - pred_x && pred_y); - reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y); + pred_x & pred_y); + reg_tmp.xetla_merge(reg_tmp, 0, pred_x & pred_y); reg_sub.xetla_select( sub_block_y * tile_desc::block_size_x) @@ -772,9 +772,9 @@ tile_load(tile_t &tile, payload_t &payload, oob_check_tag tag = {}) { payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset, - pred_x && pred_y); + pred_x & pred_y); - reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y); + reg_tmp.xetla_merge(reg_tmp, 0, pred_x & pred_y); reg_sub.xetla_select( sub_block_y * tile_desc::block_size_x) diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/store_xe.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/store_xe.hpp index 084f17b59c1b7a..e2d05fc2edb21b 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/store_xe.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/store_xe.hpp @@ -365,7 +365,7 @@ tile_store(tile_t &tile, payload_t &payload, oob_check_tag tag = {}) { reg_sub.xetla_select( sub_block_y * tile_desc::block_size_x) .xetla_format(), - (pred_x && pred_y)); + (pred_x & pred_y)); } } } @@ -403,7 +403,7 @@ tile_store(tile_t &tile, payload_t &payload, oob_check_tag tag = {}) { reg_sub.xetla_select( sub_block_y * tile_desc::block_size_x) .xetla_format(), - (pred_x && pred_y)); + (pred_x & pred_y)); } } } diff --git a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/tile_op_functor.hpp b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/tile_op_functor.hpp index 4b5445a6c924c5..6a6bf31c470079 100644 --- a/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/src/plugins/intel_gpu/thirdparty/xetla/include/subgroup/tile/impl/tile_op_functor.hpp @@ -29,6 +29,122 @@ namespace gpu::xetla::subgroup { +template +struct msg_type_postop_query { + static constexpr msg_type value = memory_space == mem_space::global + ? 
msg_type::unaligned_2d + : msg_type::scatter; +}; + +template +constexpr msg_type msg_type_postop_v + = msg_type_postop_query::value; +/// @brief Is MatAcc * vector scale / div. +/// @tparam scale_dtype Is the scale data type. +/// @tparam arch_tag Is the hardware architecture tag. +template +struct scale_v_div_op_t {}; +/// @brief Is the scale_v op functor, specialized for Xe architecture. +template +struct scale_v_div_op_t> { + using scale_dtype = scale_dtype_; + + using scale_mem_desc_t + = mem_desc_t; + + using scale_shape_t = typename scale_mem_desc_t::shape_t; + using scale_base_t = typename scale_mem_desc_t::base_t; + using coord_t = typename scale_mem_desc_t::coord_t; + + struct arguments_t { + scale_base_t scale_base; + scale_shape_t scale_shape; + float div; + + inline arguments_t() = default; + inline arguments_t(scale_base_t scale_base_, scale_shape_t scale_shape_, + float div_) + : scale_base(scale_base_), scale_shape(scale_shape_), div(div_) {} + }; + template + __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc, + const coord_t &coord, const arguments_t &args, + uint32_t slm_base = 0, uint32_t nbarrier_base = 0) { + using dtype_acc = typename matAcc_t::dtype; + + static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x; + static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y; + static constexpr uint32_t block_size_x = matAcc_t::block_size_x; + static constexpr uint32_t block_size_y = matAcc_t::block_size_y; + static constexpr int32_t num_block_x = matAcc_t::num_block_x; + static constexpr int32_t num_block_y = matAcc_t::num_block_y; + static constexpr uint32_t tile_elems = matAcc_t::tile_elems; + static constexpr uint32_t block_elems = matAcc_t::block_elems; + + using scale_tile_desc_t = tile_desc_t; + using scale_tile_t = tile_t; + using scale_payload_t = mem_payload_t, + arch_tag>; + coord_t scale_coord(coord.x, 0); + scale_mem_desc_t scale_mem_desc( + args.scale_base, args.scale_shape, scale_coord); + scale_tile_t scale_tile; + scale_payload_t scale_payload(scale_mem_desc); + tile_load( + scale_tile, scale_payload); + +#pragma unroll + for (int i = 0; i < tile_size_y / block_size_y; i++) { +#pragma unroll + for (int j = 0; j < num_block_x; j++) { + auto acc_reg = matAcc.reg.xetla_select( + (i * num_block_x + j) * block_elems); + auto scale_reg = scale_tile.reg.xetla_select( + j * block_size_x) + / args.div; +#pragma unroll + for (int row_i = 0; row_i < block_size_y; row_i++) { + acc_reg.xetla_select(row_i * block_size_x) + = scale_reg + * acc_reg.xetla_select( + row_i * block_size_x); + } + } + } + // process the tail + if constexpr ((tile_size_y % block_size_y) != 0) { + constexpr uint32_t tail_start_y + = tile_size_y / block_size_y * block_size_y; + constexpr int32_t tail_size_y = tile_size_y % block_size_y; + constexpr int32_t tail_block_elems = tail_size_y * block_size_x; +#pragma unroll + for (int j = 0; j < num_block_x; j++) { + auto acc_reg = matAcc.reg.xetla_select( + tail_start_y * tile_size_x + j * tail_block_elems); + auto scale_reg = scale_tile.reg.xetla_select( + j * block_size_x) + / args.div; +#pragma unroll + for (int row_i = 0; row_i < tail_size_y; row_i++) { + acc_reg.xetla_select(row_i * block_size_x) + = scale_reg + * acc_reg.xetla_select( + row_i * block_size_x); + } + } + } + } +}; + /// @brief Is none op functor, for placeholder purpose. /// Used in epilogue::tile_op or chained_tile_op.
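// Editorial note (illustration only, not part of this patch): ignoring the
// blocked register layout, scale_v_div_op_t above loads one scale value per
// output column (from row 0 of the scale surface) and applies
// acc[row][col] *= scale[col] / div to every accumulator element. A plain
// row-major reference of the same math:
inline void scale_v_div_ref(float *acc, const float *scale, float div,
                            unsigned rows, unsigned cols) {
    for (unsigned r = 0; r < rows; ++r)
        for (unsigned c = 0; c < cols; ++c)
            acc[r * cols + c] *= scale[c] / div;
}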
struct none_op_t { @@ -105,7 +221,7 @@ struct relu_pack_mask_op_t; using mask_out_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; mem_desc_mask_t mem_desc_mask(args.mask_base, args.mask_shape, coord); @@ -189,7 +305,7 @@ struct relu_unpack_mask_op_t; using mask_in_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; mem_desc_mask_t mem_desc_mask(args.mask_base, args.mask_shape, coord); @@ -307,6 +423,37 @@ struct silu_op_t { } }; +/// @brief Is the element-wise silu op functor. +/// Get the silu input from matAcc, update the the silu output in place, +/// Used in epilogue::tile_op or chained_tile_op. +struct silu_precise_op_t { + struct arguments_t {}; + template + __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc, + const coord_t &coord, const arguments_t &args, + uint32_t slm_base = 0, uint32_t nbarrier_base = 0) { + constexpr int elems = matAcc_t::tile_desc::block_elems; + constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems; +#pragma unroll + for (int i = 0; i < rounds; ++i) { + auto sub_vec = matAcc.reg.xetla_select(elems * i); + xetla_vector sigmoid_value + = xetla_sigmoid_precise(sub_vec); + sub_vec = sub_vec * sigmoid_value; + } + constexpr int remaining_elems = matAcc_t::tile_desc::tile_elems % elems; + if constexpr (remaining_elems != 0) { + auto sub_vec = matAcc.reg.xetla_select( + elems * (matAcc_t::tile_elems / elems)); + xetla_vector + sigmoid_value + = xetla_sigmoid_precise( + sub_vec); + sub_vec = sub_vec * sigmoid_value; + } + } +}; + /// @brief Is the element-wise gelu inference forward op functor. /// Get the gelu input from matAcc, update the the gelu output in place, /// Used in epilogue::tile_op or chained_tile_op. @@ -605,7 +752,7 @@ struct bias_add_op_t; using bias_t = tile_t; using bias_payload_t = mem_payload_t, arch_tag>; + msg_type_postop_v, arch_tag>; coord_t bias_coord(coord.x, 0); mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord); bias_t bias; @@ -769,7 +916,7 @@ struct scale_v_offset_v_op_t; using scale_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; coord_t scale_coord(coord.x, 0); scale_mem_desc_t scale_mem_desc( @@ -784,7 +931,7 @@ struct scale_v_offset_v_op_t; using offset_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; coord_t offset_coord(coord.x, 0); offset_mem_desc_t offset_mem_desc( @@ -888,7 +1035,7 @@ struct scale_v_op_t; using scale_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; coord_t scale_coord(coord.x, 0); scale_mem_desc_t scale_mem_desc( @@ -987,7 +1134,7 @@ struct elemwise_reduce_op_t; using mat_in_payload_t = mem_payload_t, arch_tag>; + msg_type_postop_v, arch_tag>; using mat_in_tile_acc_t = tile_t; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); mat_in_tile_t mat_in; @@ -1024,7 +1171,7 @@ struct elemwise_reduce_op_t; using mat_tail_in_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; using mat_tail_in_tile_acc_t = tile_t; @@ -1099,7 +1246,7 @@ struct elemwise_reduce_op_stream_k_t; using mat_in_payload_t = mem_payload_t, arch_tag>; + msg_type_postop_v, arch_tag>; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); mat_in_tile_t mat_in; mat_in_tile_t mat_zero(0); @@ -1208,7 +1355,7 @@ struct dropout_op_t; using mask_in_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; mem_desc_mask_t mem_desc_mask(args.base, args.shape, coord); mask_in_tile_t mask_in; @@ -1301,7 +1448,7 @@ struct rng_dropout_op_t; using mask_out_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; if (args.prob == 0) { return; } //calculate the scale 
internally @@ -1433,7 +1580,7 @@ struct linear_op_t; using mat_in_payload_t = mem_payload_t, arch_tag>; + msg_type_postop_v, arch_tag>; using mat_in_tile_acc_t = tile_t; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); mat_in_tile_t mat_in; @@ -1475,7 +1622,7 @@ struct linear_op_t; using mat_tail_in_payload_t = mem_payload_t, + msg_type_postop_v, arch_tag>; using mat_tail_in_tile_acc_t = tile_t; diff --git a/src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/lora_pattern.hpp b/src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/lora_pattern.hpp index ee381388499905..3c26a3286fe746 100644 --- a/src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/lora_pattern.hpp +++ b/src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/lora_pattern.hpp @@ -11,6 +11,7 @@ namespace test { class LoraPatternBase : public SubgraphBaseTest { protected: + bool is_low_precision(ov::element::Type net_type) const; void run_test_empty_tensors(); void run_test_random_tensors(ov::element::Type net_type, size_t lora_rank); @@ -29,6 +30,7 @@ using LoraMatMulParams = std::tuple { public: static std::string getTestCaseName(testing::TestParamInfo obj); + void generate_inputs(const std::vector& targetInputStaticShapes) override; void SetUp() override; }; diff --git a/src/tests/functional/plugin/shared/src/subgraph/lora_pattern.cpp b/src/tests/functional/plugin/shared/src/subgraph/lora_pattern.cpp index a01593558f68bb..cd24949865fac2 100644 --- a/src/tests/functional/plugin/shared/src/subgraph/lora_pattern.cpp +++ b/src/tests/functional/plugin/shared/src/subgraph/lora_pattern.cpp @@ -8,6 +8,7 @@ #include "common_test_utils/node_builders/convolution.hpp" #include "common_test_utils/ov_tensor_utils.hpp" #include "shared_test_classes/base/ov_subgraph.hpp" +#include "shared_test_classes/base/utils/ranges.hpp" #include "template/properties.hpp" #include "openvino/op/add.hpp" #include "openvino/op/matmul.hpp" @@ -17,6 +18,10 @@ namespace ov { namespace test { +bool LoraPatternBase::is_low_precision(ov::element::Type net_type) const { + return net_type.size() <= 2; +} + void LoraPatternBase::run_test_empty_tensors() { compile_model(); inferRequest = compiledModel.create_infer_request(); @@ -34,7 +39,40 @@ void LoraPatternBase::run_test_empty_tensors() { ov::test::utils::compare(tx_result, tz_result); } +void LoraPatternMatmul::generate_inputs(const std::vector& targetInputStaticShapes) { + inputs.clear(); + ov::test::utils::ModelRange modelRange; + modelRange.find_mode_ranges(function); + + auto itTargetShape = targetInputStaticShapes.begin(); + for (const auto ¶m : function->get_parameters()) { + std::shared_ptr inputNode = param; + for (size_t i = 0; i < param->get_output_size(); i++) { + for (const auto &node : param->get_output_target_inputs(i)) { + std::shared_ptr nodePtr = node.get_node()->shared_from_this(); + for (size_t port = 0; port < nodePtr->get_input_size(); ++port) { + if (nodePtr->get_input_node_ptr(port)->shared_from_this() == inputNode->shared_from_this()) { + if (is_low_precision(nodePtr->get_input_element_type(port))) { + const auto& tensor = ov::test::utils::create_and_fill_tensor_real_distribution( + nodePtr->get_input_element_type(port), *itTargetShape, -2.f, 2.f, 0); + inputs.insert({param, tensor}); + } else { + inputs.insert({param, modelRange.generate_input(nodePtr, port, *itTargetShape)}); + } + break; + } + } + } + } + itTargetShape++; + } +} + void LoraPatternBase::run_test_random_tensors(ov::element::Type net_type, 
size_t lora_rank) { + if (net_type == ov::element::f16) { + GTEST_SKIP() << "Skipping test for f16 - accuracy issues on devices without DPAS"; + } + compile_model(); inferRequest = compiledModel.create_infer_request(); ASSERT_TRUE(inferRequest); @@ -127,6 +165,7 @@ std::string LoraPatternMatmul::getTestCaseName(testing::TestParamInfo