 #include <unordered_set>
 #include <vector>
 
+#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
 #include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
 #include "emitters/snippets/aarch64/utils.hpp"
 #include "emitters/snippets/utils/utils.hpp"
 #include "emitters/utils.hpp"
 #include "openvino/core/node.hpp"
 #include "openvino/core/type.hpp"
 #include "openvino/core/type/element_type.hpp"
+#include "snippets/emitter.hpp"
 #include "snippets/kernel_executor_table.hpp"
 #include "snippets/lowered/expression.hpp"
 #include "snippets/utils/utils.hpp"
@@ -38,7 +40,7 @@ jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
                                                  cpu_isa_t isa,
                                                  const ExpressionPtr& expr,
                                                  const snippets::KernelExecutorTablePtr& kernel_table)
-    : jit_emitter(h, isa) {
+    : jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
     in_out_type_ = emitter_in_out_map::gpr_to_gpr;
     const auto gemm_repack = ov::as_type_ptr<GemmCopyB>(expr->get_node());
     OV_CPU_JIT_EMITTER_ASSERT(gemm_repack, "expects GemmCopyB node");
@@ -78,17 +80,28 @@ void jit_gemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in,
 
 void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     validate_arguments(in, out);
-    // todo: use optimized reg spill after CVS-162498
-    std::unordered_set<size_t> exclude = {};
-    store_context(exclude);
+
+    std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
+
+    init_binary_call_regs(2, mem_ptrs_idxs);
+    emit_call(mem_ptrs_idxs);
+}
+
+void jit_gemm_copy_b_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
+    const auto& regs_to_spill = get_regs_to_spill();
+    std::unordered_set<size_t> exclude_spill = {};
+    for (const auto& reg : regs_to_spill) {
+        if (reg.type == snippets::RegType::gpr) {
+            exclude_spill.insert(reg.idx);
+        }
+    }
+    store_context(exclude_spill);
 
     Xbyak_aarch64::XReg x0(0);
     Xbyak_aarch64::XReg x1(1);
     Xbyak_aarch64::XReg x2(2);
     Xbyak_aarch64::XReg aux_reg(3);
 
-    // Prepare memory pointers with offsets
-    std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
     const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
 
     // Apply memory offsets and load adjusted pointers
@@ -111,11 +124,11 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     const auto& compiled_kernel = get_compiled_kernel_ptr();
     h->mov(x0, compiled_kernel);
 
-    Xbyak_aarch64::XReg func_reg(9);
-    h->mov(func_reg, get_execute_function_ptr());
-    h->blr(func_reg);
+    const auto& call_address_reg = get_call_address_reg();
+    h->mov(call_address_reg, get_execute_function_ptr());
+    h->blr(call_address_reg);
 
-    restore_context(exclude);
+    restore_context(exclude_spill);
 }
 
 uintptr_t jit_gemm_copy_b_emitter::get_compiled_kernel_ptr() const {
@@ -125,4 +138,4 @@ uintptr_t jit_gemm_copy_b_emitter::get_compiled_kernel_ptr() const {
 uintptr_t jit_gemm_copy_b_emitter::get_execute_function_ptr() {
     return reinterpret_cast<const uintptr_t>(GemmCopyBKaiKernelExecutor::execute);
 }
-}  // namespace ov::intel_cpu::aarch64
+}  // namespace ov::intel_cpu::aarch64
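
For readers unfamiliar with the spill-exclusion pattern used in the new emit_call() above, here is a small standalone sketch of the same idea: collecting the indices of general-purpose registers from a list of register descriptors into the exclude set that is handed to store_context()/restore_context(). The Reg struct and RegType enum below are simplified stand-ins for the snippets types, not the actual API.

#include <cstddef>
#include <unordered_set>
#include <vector>

// Simplified stand-ins for snippets::Reg / snippets::RegType (assumption, not the real types).
enum class RegType { gpr, vec };
struct Reg {
    RegType type;
    size_t idx;
};

// Collect the indices of general-purpose registers so they can be passed to
// store_context()/restore_context() as the exclude set, mirroring emit_call() in the diff.
std::unordered_set<size_t> collect_gpr_indices(const std::vector<Reg>& regs_to_spill) {
    std::unordered_set<size_t> exclude_spill;
    for (const auto& reg : regs_to_spill) {
        if (reg.type == RegType::gpr) {
            exclude_spill.insert(reg.idx);
        }
    }
    return exclude_spill;
}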