@@ -15,13 +15,15 @@
 #include <unordered_set>
 #include <vector>
 
+#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
 #include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
 #include "emitters/snippets/aarch64/utils.hpp"
 #include "emitters/snippets/utils/utils.hpp"
 #include "emitters/utils.hpp"
 #include "openvino/core/node.hpp"
 #include "openvino/core/type.hpp"
 #include "openvino/core/type/element_type.hpp"
+#include "snippets/emitter.hpp"
 #include "snippets/kernel_executor_table.hpp"
 #include "snippets/lowered/expression.hpp"
 #include "snippets/utils/utils.hpp"
@@ -38,7 +40,7 @@ jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
                                                  cpu_isa_t isa,
                                                  const ExpressionPtr& expr,
                                                  const snippets::KernelExecutorTablePtr& kernel_table)
-    : jit_emitter(h, isa) {
+    : jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
     in_out_type_ = emitter_in_out_map::gpr_to_gpr;
     const auto gemm_repack = ov::as_type_ptr<GemmCopyB>(expr->get_node());
     OV_CPU_JIT_EMITTER_ASSERT(gemm_repack, "expects GemmCopyB node");
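
Note: the hunks below lean on the new jit_binary_call_emitter base class. As a rough orientation, this is the surface the diff appears to use, inferred purely from the call sites in this file; parameter names and exact signatures are assumptions, and the authoritative declarations live in emitters/snippets/aarch64/jit_binary_call_emitter.hpp.

    // Inferred sketch, not the actual header (see jit_binary_call_emitter.hpp).
    class jit_binary_call_emitter : public jit_emitter {
    public:
        // Receives the registers that are live across the emitted call,
        // as supplied via expr->get_live_regs() in the constructor above.
        jit_binary_call_emitter(jit_generator* h, cpu_isa_t isa, std::set<snippets::Reg> live_regs);

    protected:
        // Picks scratch/ABI registers for a call taking num_args pointer
        // arguments, avoiding the given memory-pointer register indices.
        void init_binary_call_regs(size_t num_args, const std::vector<size_t>& mem_ptr_idxs) const;
        // Registers that must be preserved around the call.
        const std::set<snippets::Reg>& get_regs_to_spill() const;
        // GPR designated to hold the callee address for blr.
        const Xbyak_aarch64::XReg& get_call_address_reg() const;
    };
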
@@ -78,17 +80,28 @@ void jit_gemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in,
 
 void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
     validate_arguments(in, out);
-    // todo: use optimized reg spill after CVS-162498
-    std::unordered_set<size_t> exclude = {};
-    store_context(exclude);
+
+    std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
+
+    init_binary_call_regs(2, mem_ptrs_idxs);
+    emit_call(mem_ptrs_idxs);
+}
+
+void jit_gemm_copy_b_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
+    const auto& regs_to_spill = get_regs_to_spill();
+    std::unordered_set<size_t> exclude_spill = {};
+    for (const auto& reg : regs_to_spill) {
+        if (reg.type == snippets::RegType::gpr) {
+            exclude_spill.insert(reg.idx);
+        }
+    }
+    store_context(exclude_spill);
 
     Xbyak_aarch64::XReg x0(0);
     Xbyak_aarch64::XReg x1(1);
     Xbyak_aarch64::XReg x2(2);
     Xbyak_aarch64::XReg aux_reg(3);
 
-    // Prepare memory pointers with offsets
-    std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
     const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
 
     // Apply memory offsets and load adjusted pointers
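
The filtering loop above walks the spill set by register class; the descriptor it iterates comes from the newly included snippets/emitter.hpp. A minimal sketch of the shape assumed here, with only the fields actually exercised in this diff; any further enum values are guesses.

    // Sketch of the register descriptor as used above; only `type` and `idx`
    // are confirmed by this diff, the rest is assumed.
    namespace ov::snippets {
    enum class RegType { gpr, vec /* other kinds elided */ };
    struct Reg {
        RegType type;  // compared against RegType::gpr in emit_call()
        size_t idx;    // physical index inserted into the exclude set
    };
    }  // namespace ov::snippets
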
@@ -111,11 +124,11 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std
     const auto& compiled_kernel = get_compiled_kernel_ptr();
     h->mov(x0, compiled_kernel);
 
-    Xbyak_aarch64::XReg func_reg(9);
-    h->mov(func_reg, get_execute_function_ptr());
-    h->blr(func_reg);
+    const auto& call_address_reg = get_call_address_reg();
+    h->mov(call_address_reg, get_execute_function_ptr());
+    h->blr(call_address_reg);
 
-    restore_context(exclude);
+    restore_context(exclude_spill);
 }
 
 uintptr_t jit_gemm_copy_b_emitter::get_compiled_kernel_ptr() const {
@@ -125,4 +138,4 @@ uintptr_t jit_gemm_copy_b_emitter::get_compiled_kernel_ptr() const {
 uintptr_t jit_gemm_copy_b_emitter::get_execute_function_ptr() {
     return reinterpret_cast<const uintptr_t>(GemmCopyBKaiKernelExecutor::execute);
 }
-}  // namespace ov::intel_cpu::aarch64
+}  // namespace ov::intel_cpu::aarch64
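
Taken together, the emitted sequence amounts to an indirect call with the executor pointer in x0 and the two offset-adjusted buffer pointers in the following argument registers. A minimal C++ equivalent, assuming the conventional executor calling pattern; the exact parameter types of execute() are not shown in this diff.

    // Hypothetical illustration of what the emitted aarch64 code performs:
    // x0 = executor, x1 = adjusted src (in[0]), x2 = adjusted dst (out[0]),
    // blr = GemmCopyBKaiKernelExecutor::execute. Signature assumed.
    static void emitted_call_equivalent(const GemmCopyBKaiKernelExecutor* executor,
                                        void* src, void* dst) {
        GemmCopyBKaiKernelExecutor::execute(executor, src, dst);
    }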