Skip to content

Commit e163aef

Browse files
committed
copyb
1 parent b51e817 commit e163aef

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.cpp

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,15 @@
1515
#include <unordered_set>
1616
#include <vector>
1717

18+
#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
1819
#include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
1920
#include "emitters/snippets/aarch64/utils.hpp"
2021
#include "emitters/snippets/utils/utils.hpp"
2122
#include "emitters/utils.hpp"
2223
#include "openvino/core/node.hpp"
2324
#include "openvino/core/type.hpp"
2425
#include "openvino/core/type/element_type.hpp"
26+
#include "snippets/emitter.hpp"
2527
#include "snippets/kernel_executor_table.hpp"
2628
#include "snippets/lowered/expression.hpp"
2729
#include "snippets/utils/utils.hpp"
@@ -38,7 +40,7 @@ jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
3840
cpu_isa_t isa,
3941
const ExpressionPtr& expr,
4042
const snippets::KernelExecutorTablePtr& kernel_table)
41-
: jit_emitter(h, isa) {
43+
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
4244
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
4345
const auto gemm_repack = ov::as_type_ptr<GemmCopyB>(expr->get_node());
4446
OV_CPU_JIT_EMITTER_ASSERT(gemm_repack, "expects GemmCopyB node");
@@ -78,17 +80,28 @@ void jit_gemm_copy_b_emitter::validate_arguments(const std::vector<size_t>& in,
7880

7981
void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
8082
validate_arguments(in, out);
81-
// todo: use optimized reg spill after CVS-162498
82-
std::unordered_set<size_t> exclude = {};
83-
store_context(exclude);
83+
84+
std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
85+
86+
init_binary_call_regs(2, mem_ptrs_idxs);
87+
emit_call(mem_ptrs_idxs);
88+
}
89+
90+
void jit_gemm_copy_b_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
91+
const auto& regs_to_spill = get_regs_to_spill();
92+
std::unordered_set<size_t> exclude_spill = {};
93+
for (const auto& reg : regs_to_spill) {
94+
if (reg.type == snippets::RegType::gpr) {
95+
exclude_spill.insert(reg.idx);
96+
}
97+
}
98+
store_context(exclude_spill);
8499

85100
Xbyak_aarch64::XReg x0(0);
86101
Xbyak_aarch64::XReg x1(1);
87102
Xbyak_aarch64::XReg x2(2);
88103
Xbyak_aarch64::XReg aux_reg(3);
89104

90-
// Prepare memory pointers with offsets
91-
std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};
92105
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
93106

94107
// Apply memory offsets and load adjusted pointers
@@ -111,11 +124,11 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std
111124
const auto& compiled_kernel = get_compiled_kernel_ptr();
112125
h->mov(x0, compiled_kernel);
113126

114-
Xbyak_aarch64::XReg func_reg(9);
115-
h->mov(func_reg, get_execute_function_ptr());
116-
h->blr(func_reg);
127+
const auto& call_address_reg = get_call_address_reg();
128+
h->mov(call_address_reg, get_execute_function_ptr());
129+
h->blr(call_address_reg);
117130

118-
restore_context(exclude);
131+
restore_context(exclude_spill);
119132
}
120133

121134
uintptr_t jit_gemm_copy_b_emitter::get_compiled_kernel_ptr() const {
@@ -125,4 +138,4 @@ uintptr_t jit_gemm_copy_b_emitter::get_compiled_kernel_ptr() const {
125138
uintptr_t jit_gemm_copy_b_emitter::get_execute_function_ptr() {
126139
return reinterpret_cast<const uintptr_t>(GemmCopyBKaiKernelExecutor::execute);
127140
}
128-
} // namespace ov::intel_cpu::aarch64
141+
} // namespace ov::intel_cpu::aarch64

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
#pragma once
66

7-
#include "emitters/plugin/aarch64/jit_emitter.hpp"
7+
#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
88
#include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
9+
#include "snippets/emitter.hpp"
910

1011
namespace ov::intel_cpu::aarch64 {
11-
class jit_gemm_copy_b_emitter : public jit_emitter {
12+
class jit_gemm_copy_b_emitter : public jit_binary_call_emitter {
1213
public:
1314
jit_gemm_copy_b_emitter(dnnl::impl::cpu::aarch64::jit_generator* h,
1415
dnnl::impl::cpu::aarch64::cpu_isa_t isa,
@@ -25,6 +26,7 @@ class jit_gemm_copy_b_emitter : public jit_emitter {
2526
protected:
2627
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
2728
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
29+
void emit_call(const std::vector<size_t>& mem_ptrs_idxs) const;
2830

2931
static uintptr_t get_execute_function_ptr();
3032
uintptr_t get_compiled_kernel_ptr() const;

0 commit comments

Comments
 (0)