Skip to content

Commit b51e817

Browse files
committed
Unify GemmKaiKernelExecutor::execute
1 parent e5f74e4 commit b51e817

File tree

3 files changed

+8
-24
lines changed

3 files changed

+8
-24
lines changed

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,7 @@ void jit_gemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const
133133
Xbyak_aarch64::XReg x0(0);
134134
Xbyak_aarch64::XReg x1(1);
135135

136-
auto execute_func_ptr =
137-
static_cast<void (*)(const GemmKaiKernelExecutor*, const GemmKaiKernelExecutor::call_args*)>(
138-
GemmKaiKernelExecutor::execute);
139-
h->mov(call_address_reg, reinterpret_cast<uintptr_t>(execute_func_ptr));
136+
h->mov(call_address_reg, reinterpret_cast<uintptr_t>(GemmKaiKernelExecutor::execute));
140137

141138
h->mov(x0, reinterpret_cast<uintptr_t>(m_kernel_executor_kai.get()));
142139
h->mov(x1, h->sp);
@@ -153,10 +150,7 @@ uintptr_t jit_gemm_emitter::get_compiled_kernel_ptr() const {
153150
}
154151

155152
uintptr_t jit_gemm_emitter::get_execute_function_ptr() {
156-
auto execute_func_ptr =
157-
static_cast<void (*)(const GemmKaiKernelExecutor*, const GemmKaiKernelExecutor::call_args*)>(
158-
GemmKaiKernelExecutor::execute);
159-
return reinterpret_cast<const uintptr_t>(execute_func_ptr);
153+
return reinterpret_cast<const uintptr_t>(GemmKaiKernelExecutor::execute);
160154
}
161155

162156
} // namespace ov::intel_cpu::aarch64

src/plugins/intel_cpu/src/emitters/snippets/aarch64/kernel_executors/gemm.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ void GemmKaiKernelExecutor::update_config(const ov::snippets::lowered::Expressio
5454
config.update(M, N, K, LDA, LDB, LDC, beta);
5555
}
5656

57-
void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, void* in0, void* in1, void* out0) {
57+
void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, const call_args* args) {
5858
OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor");
59-
// matmul for input1 and slices of repacked input2
59+
OV_CPU_JIT_EMITTER_ASSERT(args, "has nullptr args");
60+
6061
const auto& config = static_cast<const GemmKernelKaiConfig&>(executor->get_config());
6162
const auto& kernel = executor->get_kernel();
6263
const auto& ukernel = *kernel->gemm_ukernel;
@@ -80,12 +81,12 @@ void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, void*
8081
const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride_row);
8182
// in0, in1, out is point to current block memory, based on block loop info, and shift done in loop begin and
8283
// end emitters(adjusted copyb loop info as repack outside block loops).
83-
float* rhs_ptr = static_cast<float*>(in1) + rhs_packed_offset / sizeof(float);
84-
float* dst_ptr = (static_cast<float*>(out0) + dst_offset / (sizeof(float)));
84+
float* rhs_ptr = const_cast<float*>(static_cast<const float*>(args->B)) + rhs_packed_offset / sizeof(float);
85+
float* dst_ptr = const_cast<float*>(static_cast<const float*>(args->C)) + dst_offset / (sizeof(float));
8586
ukernel.run_matmul(M,
8687
n_block_size,
8788
K,
88-
in0,
89+
const_cast<void*>(args->A),
8990
lhs_stride,
9091
rhs_ptr,
9192
dst_ptr,
@@ -96,11 +97,4 @@ void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, void*
9697
}
9798
}
9899

99-
void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, const call_args* args) {
100-
OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor");
101-
OV_CPU_JIT_EMITTER_ASSERT(args, "has nullptr args");
102-
103-
execute(executor, const_cast<void*>(args->A), const_cast<void*>(args->B), args->C);
104-
}
105-
106100
} // namespace ov::intel_cpu::aarch64

src/plugins/intel_cpu/src/emitters/snippets/aarch64/kernel_executors/gemm.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ class GemmKaiKernelExecutor : public snippets::KernelExecutor<GemmKernelKaiConfi
6565
// No need kernel update, just update config is enough for update. The universal ukernel is reused with any config.
6666
void update_kernel(const GemmKernelKaiConfig& config,
6767
std::shared_ptr<GemmCompiledKernel>& kernel) const override final;
68-
69-
// Function that will be called in runtime to execute the kernel
70-
static void execute(const GemmKaiKernelExecutor* executor, void* in0, void* in1, void* out0);
71-
7268
// ABI-compliant execute function that takes call_args structure
7369
static void execute(const GemmKaiKernelExecutor* executor, const call_args* args);
7470

0 commit comments

Comments
 (0)