Skip to content

Commit e02ae65

Browse files
committed
Address Alexandra's comments
1 parent 38128f2 commit e02ae65

File tree

6 files changed

+35
-51
lines changed

6 files changed

+35
-51
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -252,27 +252,20 @@ void jit_emitter::store_context(const std::vector<size_t>& gpr_regs,
252252
// 1. General-purpose Registers - optimized to allocate stack space once
253253
const auto store_gpr_regs_size = gpr_regs.size();
254254
if (store_gpr_regs_size > 0) {
255-
// Calculate total stack space needed for all GPR registers
256-
const auto last = store_gpr_regs_size % 2;
257-
auto total_gpr_shift = 0U;
258-
for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) {
259-
total_gpr_shift += ov::intel_cpu::rnd_up(get_gpr_length() * 2, sp_alignment);
260-
}
261-
if (last != 0) {
262-
total_gpr_shift += ov::intel_cpu::rnd_up(get_gpr_length(), sp_alignment);
263-
}
255+
// Calculate total stack space needed for all GPR registers (align once)
256+
const auto total_gpr_shift = ov::intel_cpu::rnd_up(get_gpr_length() * store_gpr_regs_size, sp_alignment);
264257

265258
// Single stack allocation for all GPR registers
266259
h->sub(h->sp, h->sp, total_gpr_shift);
267260

268261
// Store GPR registers using stack offset (preserving original order)
262+
const auto last = store_gpr_regs_size % 2;
269263
auto current_offset = 0U;
270264
for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) {
271-
const auto shift = ov::intel_cpu::rnd_up(get_gpr_length() * 2, sp_alignment);
272265
h->stp(Xbyak_aarch64::XReg(gpr_regs[i]),
273266
Xbyak_aarch64::XReg(gpr_regs[i + 1]),
274267
Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));
275-
current_offset += shift;
268+
current_offset += get_gpr_length() * 2;
276269
}
277270
if (last != 0) {
278271
h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]),
@@ -362,27 +355,19 @@ void jit_emitter::restore_context(const std::vector<size_t>& gpr_regs,
362355
const auto save_gpr_regs_size = gpr_regs.size();
363356
if (save_gpr_regs_size > 0) {
364357
// Calculate total stack space (must match store_context calculation)
365-
const auto last = save_gpr_regs_size % 2;
366-
auto total_gpr_shift = 0U;
367-
for (size_t i = 0; i < (save_gpr_regs_size - last); i += 2) {
368-
total_gpr_shift += ov::intel_cpu::rnd_up(get_gpr_length() * 2, sp_alignment);
369-
}
370-
if (last != 0) {
371-
total_gpr_shift += ov::intel_cpu::rnd_up(get_gpr_length(), sp_alignment);
372-
}
358+
const auto total_gpr_shift = ov::intel_cpu::rnd_up(get_gpr_length() * save_gpr_regs_size, sp_alignment);
373359

374360
// Restore GPR registers using stack offset (reverse order to match original behavior)
375-
auto current_offset = total_gpr_shift;
361+
const auto last = save_gpr_regs_size % 2;
362+
auto current_offset = get_gpr_length() * save_gpr_regs_size;
376363
if (last != 0) {
377-
const auto shift = ov::intel_cpu::rnd_up(get_gpr_length(), sp_alignment);
378-
current_offset -= shift;
364+
current_offset -= get_gpr_length();
379365
h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]),
380366
Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));
381367
}
382368

383369
for (size_t i = last; i < save_gpr_regs_size; i += 2) {
384-
const auto shift = ov::intel_cpu::rnd_up(get_gpr_length() * 2, sp_alignment);
385-
current_offset -= shift;
370+
current_offset -= get_gpr_length() * 2;
386371
h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
387372
Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]),
388373
Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ static void load_with_offset_check(jit_generator* h, const RegType& dst, const X
4141
}
4242
} else {
4343
// Manual offset handling for other register types
44+
// Note: see "LDR (immediate)" in the Arm A-profile A64 Instruction Set Architecture manual for the offset limits used below
4445
int max_offset = 4095; // Default fallback
4546
int alignment = 1; // Default fallback
4647
if constexpr (std::is_same_v<RegType, SReg>) {
@@ -58,7 +59,7 @@ static void load_with_offset_check(jit_generator* h, const RegType& dst, const X
5859
}
5960

6061
if (offset >= 0 && offset <= max_offset && (offset % alignment) == 0) {
61-
h->ldr(dst, ptr(src, offset));
62+
h->ldr(dst, ptr(src, static_cast<uint32_t>(offset)));
6263
} else {
6364
h->add_imm(h->X_DEFAULT_ADDR, src, offset, h->X_TMP_0);
6465
h->ldr(dst, ptr(h->X_DEFAULT_ADDR));

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_binary_call_emitter.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void jit_binary_call_emitter::init_binary_call_regs(size_t num_binary_args,
7070
}
7171

7272
// Add special registers that should not be allocated
73-
std::vector<size_t> reserved_regs = {
73+
static const std::vector<size_t> reserved_regs = {
7474
18, // Platform register (should not be used)
7575
29, // Frame pointer (FP)
7676
30, // Link register (LR)
@@ -129,8 +129,7 @@ void jit_binary_call_emitter::emit_stack_restore(size_t stack_size) const {
129129
OV_CPU_JIT_EMITTER_ASSERT(m_stack_preserved, "emit_stack_restore called without corresponding emit_stack_preserve");
130130

131131
// ARM64 requires 16-byte stack alignment
132-
const size_t alignment = 16;
133-
stack_size = ov::intel_cpu::rnd_up(stack_size, alignment);
132+
stack_size = ov::intel_cpu::rnd_up(stack_size, sp_alignment);
134133

135134
if (stack_size > 0) {
136135
h->add(h->sp, h->sp, stack_size);

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "openvino/core/node.hpp"
2424
#include "openvino/core/type.hpp"
2525
#include "openvino/core/type/element_type.hpp"
26-
#include "snippets/emitter.hpp"
2726
#include "snippets/kernel_executor_table.hpp"
2827
#include "snippets/lowered/expression.hpp"
2928
#include "snippets/utils/utils.hpp"
@@ -88,19 +87,12 @@ void jit_gemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std
8887
}
8988

9089
void jit_gemm_copy_b_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
91-
const auto& regs_to_spill = get_regs_to_spill();
9290
std::unordered_set<size_t> exclude_spill = {};
93-
for (const auto& reg : regs_to_spill) {
94-
if (reg.type == snippets::RegType::gpr) {
95-
exclude_spill.insert(reg.idx);
96-
}
97-
}
9891
store_context(exclude_spill);
9992

10093
Xbyak_aarch64::XReg x0(0);
10194
Xbyak_aarch64::XReg x1(1);
10295
Xbyak_aarch64::XReg x2(2);
103-
Xbyak_aarch64::XReg aux_reg(3);
10496

10597
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
10698

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ void jit_gemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const
8989
std::unordered_set<size_t> exclude_spill = {};
9090
store_context(exclude_spill);
9191

92-
auto reserved_stack_size = sizeof(GemmKaiKernelExecutor::call_args);
93-
reserved_stack_size = ov::intel_cpu::rnd_up(reserved_stack_size, sp_alignment);
92+
auto reserved_stack_size = ov::intel_cpu::rnd_up(sizeof(GemmKaiKernelExecutor::call_args), sp_alignment);
9493
emit_stack_preserve(reserved_stack_size);
9594

9695
const size_t A_offset = offsetof(GemmKaiKernelExecutor::call_args, A);
@@ -103,30 +102,38 @@ void jit_gemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const
103102

104103
for (size_t i = 0; i < mem_ptrs.size(); i++) {
105104
const bool is_dynamic_offset = ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]);
106-
const bool is_valid_buffer_id = !ov::snippets::utils::is_dynamic_value(m_buffer_ids[i]);
105+
const bool is_valid_buffer_id = m_buffer_ids[i] != SIZE_MAX;
107106

108-
std::vector<Xbyak_aarch64::XReg> aux_regs = {call_address_reg, callee_saved_reg, h->X_TMP_1};
107+
// Collect used register indices to avoid conflicts
108+
std::vector<size_t> used_gpr_idxs = {call_address_reg.getIdx(),
109+
callee_saved_reg.getIdx(),
110+
mem_ptrs[i].getIdx()};
109111

110112
if (is_dynamic_offset && is_valid_buffer_id) {
111-
aux_regs.emplace_back(h->X_TMP_0);
113+
// Get 3 auxiliary registers for dynamic offset handling (runtime offset needs at least 3)
114+
auto aux_gprs = ov::intel_cpu::aarch64::utils::get_aux_gprs(used_gpr_idxs, 3);
115+
std::vector<Xbyak_aarch64::XReg> aux_regs = {aux_gprs[0], aux_gprs[1], aux_gprs[2]};
112116
size_t runtime_offset = GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t);
113117
utils::push_ptr_with_runtime_offset_on_stack(h,
114118
gemm_args_offsets[i],
115119
mem_ptrs[i],
116120
aux_regs,
117121
runtime_offset);
118122
} else {
119-
size_t offset = is_dynamic_offset ? 0 : m_memory_offsets[i];
120-
utils::push_ptr_with_static_offset_on_stack(h, gemm_args_offsets[i], mem_ptrs[i], aux_regs, offset);
123+
// Static offset case (dynamic offsets with valid buffer IDs are handled above)
124+
// Get 2 auxiliary registers for static offset handling (static offset needs at least 2)
125+
auto aux_gprs = ov::intel_cpu::aarch64::utils::get_aux_gprs(used_gpr_idxs, 2);
126+
std::vector<Xbyak_aarch64::XReg> aux_regs = {aux_gprs[0], aux_gprs[1]};
127+
utils::push_ptr_with_static_offset_on_stack(h,
128+
gemm_args_offsets[i],
129+
mem_ptrs[i],
130+
aux_regs,
131+
m_memory_offsets[i]);
121132
}
122133
}
123134

124-
if (mem_ptrs.size() < 4) {
125-
h->mov(call_address_reg, 0);
126-
h->str(
127-
call_address_reg,
128-
Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(ov::intel_cpu::rnd_up(reserved_stack_size, sp_alignment))));
129-
}
135+
h->mov(call_address_reg, 0);
136+
h->str(call_address_reg, Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(reserved_stack_size)));
130137

131138
Xbyak_aarch64::XReg x0(0);
132139
Xbyak_aarch64::XReg x1(1);

src/plugins/intel_cpu/src/emitters/snippets/aarch64/kernel_executors/gemm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, const
8181
const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride_row);
8282
// in0, in1 and out point to the current block memory, based on the block loop info; the shift is done in the loop begin and
8383
// end emitters (the copyb loop info is adjusted because the repack happens outside the block loops).
84-
float* rhs_ptr = const_cast<float*>(static_cast<const float*>(args->B)) + rhs_packed_offset / sizeof(float);
85-
float* dst_ptr = const_cast<float*>(static_cast<const float*>(args->C)) + dst_offset / (sizeof(float));
84+
const float* rhs_ptr = static_cast<const float*>(args->B) + rhs_packed_offset / sizeof(float);
85+
float* dst_ptr = static_cast<float*>(args->C) + dst_offset / (sizeof(float));
8686
ukernel.run_matmul(M,
8787
n_block_size,
8888
K,

0 commit comments

Comments
 (0)