@@ -54,9 +54,10 @@ void GemmKaiKernelExecutor::update_config(const ov::snippets::lowered::Expressio
54
54
config.update (M, N, K, LDA, LDB, LDC, beta);
55
55
}
56
56
57
- void GemmKaiKernelExecutor::execute (const GemmKaiKernelExecutor* executor, void * in0, void * in1, void * out0 ) {
57
+ void GemmKaiKernelExecutor::execute (const GemmKaiKernelExecutor* executor, const call_args* args ) {
58
58
OV_CPU_JIT_EMITTER_ASSERT (executor, " has nullptr executor" );
59
- // matmul for input1 and slices of repacked input2
59
+ OV_CPU_JIT_EMITTER_ASSERT (args, " has nullptr args" );
60
+
60
61
const auto & config = static_cast <const GemmKernelKaiConfig&>(executor->get_config ());
61
62
const auto & kernel = executor->get_kernel ();
62
63
const auto & ukernel = *kernel->gemm_ukernel ;
@@ -80,12 +81,12 @@ void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, void*
80
81
const size_t dst_offset = ukernel.get_dst_offset (0 , n_start, dst_stride_row);
81
82
// in0, in1 and out point to the current block memory, based on block loop info; the shift is done in the loop begin and
82
83
// end emitters (copyb loop info is adjusted, as repacking is done outside the block loops).
83
- float * rhs_ptr = static_cast <float *>(in1 ) + rhs_packed_offset / sizeof (float );
84
- float * dst_ptr = (static_cast <float *>(out0) + dst_offset / (sizeof (float ) ));
84
+ float * rhs_ptr = const_cast < float *>( static_cast <const float *>(args-> B ) ) + rhs_packed_offset / sizeof (float );
85
+ float * dst_ptr = const_cast < float *> (static_cast <const float *>(args-> C )) + dst_offset / (sizeof (float ));
85
86
ukernel.run_matmul (M,
86
87
n_block_size,
87
88
K,
88
- in0 ,
89
+ const_cast < void *>(args-> A ) ,
89
90
lhs_stride,
90
91
rhs_ptr,
91
92
dst_ptr,
@@ -96,11 +97,4 @@ void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, void*
96
97
}
97
98
}
98
99
99
- void GemmKaiKernelExecutor::execute (const GemmKaiKernelExecutor* executor, const call_args* args) {
100
- OV_CPU_JIT_EMITTER_ASSERT (executor, " has nullptr executor" );
101
- OV_CPU_JIT_EMITTER_ASSERT (args, " has nullptr args" );
102
-
103
- execute (executor, const_cast <void *>(args->A ), const_cast <void *>(args->B ), args->C );
104
- }
105
-
106
100
} // namespace ov::intel_cpu::aarch64
0 commit comments