
Commit a20f240

Apply Vladislav's comments
1 parent 954c63f commit a20f240

7 files changed: +118 −109 lines


src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp

Lines changed: 9 additions & 14 deletions
@@ -260,16 +260,15 @@ void jit_emitter::store_context(const std::vector<size_t>& gpr_regs,
 
     // Store GPR registers using stack offset (preserving original order)
     const auto last = store_gpr_regs_size % 2;
-    auto current_offset = 0U;
+    int32_t current_offset = 0;
     for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) {
         h->stp(Xbyak_aarch64::XReg(gpr_regs[i]),
                Xbyak_aarch64::XReg(gpr_regs[i + 1]),
-               Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));
-        current_offset += get_gpr_length() * 2;
+               Xbyak_aarch64::ptr(h->sp, current_offset));
+        current_offset += static_cast<int32_t>(get_gpr_length() * 2);
     }
     if (last != 0) {
-        h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]),
-               Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));
+        h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]), Xbyak_aarch64::ptr(h->sp, current_offset));
     }
 }
 
@@ -354,25 +353,21 @@ void jit_emitter::restore_context(const std::vector<size_t>& gpr_regs,
     // 2. General-purpose Registers - optimized to deallocate stack space once
     const auto save_gpr_regs_size = gpr_regs.size();
     if (save_gpr_regs_size > 0) {
-        // Calculate total stack space (must match store_context calculation)
-        const auto total_gpr_shift = ov::intel_cpu::rnd_up(get_gpr_length() * save_gpr_regs_size, sp_alignment);
-
         // Restore GPR registers using stack offset (reverse order to match original behavior)
         const auto last = save_gpr_regs_size % 2;
-        auto current_offset = get_gpr_length() * save_gpr_regs_size;
         if (last != 0) {
-            current_offset -= get_gpr_length();
-            h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]),
-                   Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));
+            int32_t current_offset = get_gpr_length() * save_gpr_regs_size - get_gpr_length();
+            h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]), Xbyak_aarch64::ptr(h->sp, current_offset));
        }
 
        for (size_t i = last; i < save_gpr_regs_size; i += 2) {
-            current_offset -= get_gpr_length() * 2;
+            int32_t current_offset = get_gpr_length() * (save_gpr_regs_size - (i + 2));
            h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
                   Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]),
-                   Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(current_offset)));
+                   Xbyak_aarch64::ptr(h->sp, current_offset));
        }
 
+        const auto total_gpr_shift = ov::intel_cpu::rnd_up(get_gpr_length() * save_gpr_regs_size, sp_alignment);
        // Single stack deallocation for all GPR registers
        h->add(h->sp, h->sp, total_gpr_shift);
    }
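Note: after this change the store and restore loops compute each slot's offset independently instead of sharing a running counter, so correctness rests on the arithmetic being symmetric. Below is a standalone sanity-check sketch of that symmetry (not OpenVINO code; gpr_length stands in for get_gpr_length() and is assumed to be 8 bytes per X register). For any register count, register i ends up stored at and restored from offset i * gpr_length.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        const int32_t gpr_length = 8;  // assumed size of one AArch64 X register in bytes
        for (size_t n = 1; n <= 8; ++n) {
            std::vector<int32_t> store_offsets(n);
            std::vector<int32_t> restore_offsets(n);
            const size_t last = n % 2;

            // Store side: pairs at 0, 16, 32, ..., plus an optional trailing single slot.
            int32_t current_offset = 0;
            for (size_t i = 0; i < n - last; i += 2) {
                store_offsets[i] = current_offset;
                store_offsets[i + 1] = current_offset + gpr_length;
                current_offset += static_cast<int32_t>(gpr_length * 2);
            }
            if (last != 0) {
                store_offsets[n - 1] = current_offset;
            }

            // Restore side: offsets recomputed per iteration, mirroring the patched loop.
            if (last != 0) {
                restore_offsets[n - 1] = gpr_length * static_cast<int32_t>(n) - gpr_length;
            }
            for (size_t i = last; i < n; i += 2) {
                const int32_t cur = gpr_length * static_cast<int32_t>(n - (i + 2));
                restore_offsets[n - 1 - (i + 1)] = cur;
                restore_offsets[n - 1 - i] = cur + gpr_length;
            }

            assert(store_offsets == restore_offsets);
        }
        return 0;
    }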

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp

Lines changed: 43 additions & 55 deletions
@@ -25,43 +25,50 @@ namespace ov::intel_cpu::aarch64 {
 using jit_generator = dnnl::impl::cpu::aarch64::jit_generator;
 using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
 
+// Helper function to get max_offset and alignment for different register types
+template <typename RegType>
+static std::pair<int, int> get_load_store_limits() {
+    // Default fallback
+    int max_offset = 4095;
+    int alignment = 1;
+
+    if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
+        max_offset = 65520;  // 4095 * 16
+        alignment = 16;
+    } else if constexpr (std::is_same_v<RegType, SReg>) {
+        max_offset = 16380;
+        alignment = 4;
+    } else if constexpr (std::is_same_v<RegType, DReg>) {
+        max_offset = 32760;
+        alignment = 8;
+    } else if constexpr (std::is_same_v<RegType, HReg>) {
+        max_offset = 8190;
+        alignment = 2;
+    } else if constexpr (std::is_same_v<RegType, BReg>) {
+        max_offset = 4095;
+        alignment = 1;
+    }
+
+    return {max_offset, alignment};
+}
+
 // Helper function to load with large offset handling
 template <typename RegType>
 static void load_with_offset_check(jit_generator* h, const RegType& dst, const XReg& src, int offset) {
-    if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
-        // Manual offset handling for VReg/QReg due to uni_ldr limitations
-        const int off_mod = offset % 16;
-        const int off_mul_vl = offset / 16;
+    const auto [max_offset, alignment] = get_load_store_limits<RegType>();
 
-        if (off_mod == 0 && off_mul_vl >= 0 && off_mul_vl <= 4095) {
+    if (offset >= 0 && offset <= max_offset && (offset % alignment) == 0) {
+        if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
             h->ldr(QReg(dst.getIdx()), ptr(src, static_cast<uint32_t>(offset)));
         } else {
-            h->add_imm(h->X_DEFAULT_ADDR, src, offset, h->X_TMP_0);
-            h->ldr(QReg(dst.getIdx()), ptr(h->X_DEFAULT_ADDR));
+            h->ldr(dst, ptr(src, static_cast<uint32_t>(offset)));
         }
     } else {
-        // Manual offset handling for other register types
-        // Note: read about LDR (immediate) in the manual Arm A-profile A64 Instruction Set Architecture
-        int max_offset = 4095;  // Default fallback
-        int alignment = 1;      // Default fallback
-        if constexpr (std::is_same_v<RegType, SReg>) {
-            max_offset = 16380;
-            alignment = 4;
-        } else if constexpr (std::is_same_v<RegType, DReg>) {
-            max_offset = 32760;
-            alignment = 8;
-        } else if constexpr (std::is_same_v<RegType, HReg>) {
-            max_offset = 8190;
-            alignment = 2;
-        } else if constexpr (std::is_same_v<RegType, BReg>) {
-            max_offset = 4095;
-            alignment = 1;
-        }
-
-        if (offset >= 0 && offset <= max_offset && (offset % alignment) == 0) {
-            h->ldr(dst, ptr(src, static_cast<uint32_t>(offset)));
+        // Use add_imm which handles register allocation internally
+        h->add_imm(h->X_DEFAULT_ADDR, src, offset, h->X_TMP_0);
+        if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
+            h->ldr(QReg(dst.getIdx()), ptr(h->X_DEFAULT_ADDR));
        } else {
-            h->add_imm(h->X_DEFAULT_ADDR, src, offset, h->X_TMP_0);
            h->ldr(dst, ptr(h->X_DEFAULT_ADDR));
        }
    }
@@ -70,39 +77,20 @@ static void load_with_offset_check(jit_generator* h, const RegType& dst, const X
 // Helper function to store with large offset handling
 template <typename RegType>
 static void store_with_offset_check(jit_generator* h, const RegType& src, const XReg& dst, int offset) {
-    if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
-        // Manual offset handling for VReg/QReg due to uni_str limitations
-        const int off_mod = offset % 16;
-        const int off_mul_vl = offset / 16;
+    const auto [max_offset, alignment] = get_load_store_limits<RegType>();
 
-        if (off_mod == 0 && off_mul_vl >= 0 && off_mul_vl <= 4095) {
+    if (offset >= 0 && offset <= max_offset && (offset % alignment) == 0) {
+        if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
             h->str(QReg(src.getIdx()), ptr(dst, static_cast<uint32_t>(offset)));
         } else {
-            h->add_imm(h->X_DEFAULT_ADDR, dst, offset, h->X_TMP_0);
-            h->str(QReg(src.getIdx()), ptr(h->X_DEFAULT_ADDR));
+            h->str(src, ptr(dst, static_cast<uint32_t>(offset)));
         }
     } else {
-        // Manual offset handling for other register types
-        int max_offset = 4095;  // Default fallback
-        int alignment = 1;      // Default fallback
-        if constexpr (std::is_same_v<RegType, SReg>) {
-            max_offset = 16380;
-            alignment = 4;
-        } else if constexpr (std::is_same_v<RegType, DReg>) {
-            max_offset = 32760;
-            alignment = 8;
-        } else if constexpr (std::is_same_v<RegType, HReg>) {
-            max_offset = 8190;
-            alignment = 2;
-        } else if constexpr (std::is_same_v<RegType, BReg>) {
-            max_offset = 4095;
-            alignment = 1;
-        }
-
-        if (offset >= 0 && offset <= max_offset && (offset % alignment) == 0) {
-            h->str(src, ptr(dst, static_cast<uint32_t>(offset)));
+        // Use add_imm which handles register allocation internally
+        h->add_imm(h->X_DEFAULT_ADDR, dst, offset, h->X_TMP_0);
+        if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
+            h->str(QReg(src.getIdx()), ptr(h->X_DEFAULT_ADDR));
        } else {
-            h->add_imm(h->X_DEFAULT_ADDR, dst, offset, h->X_TMP_0);
            h->str(src, ptr(h->X_DEFAULT_ADDR));
        }
    }
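Note: the constants in get_load_store_limits follow from the A64 LDR/STR (immediate, unsigned offset) encoding: the 12-bit immediate is scaled by the access size, so the largest reachable offset is 4095 * access_size and the offset must be a multiple of that size. A standalone sketch of that rule, matching the values above (background on the encoding, not code from the patch):

    #include <cassert>
    #include <utility>

    // imm12 unsigned-offset form: offset = imm12 * access_size, imm12 in [0, 4095].
    constexpr std::pair<int, int> limits_for_access_size(int access_size) {
        return {4095 * access_size, access_size};
    }

    static_assert(limits_for_access_size(16).first == 65520, "Q/V registers");
    static_assert(limits_for_access_size(8).first == 32760, "D registers");
    static_assert(limits_for_access_size(4).first == 16380, "S registers");
    static_assert(limits_for_access_size(2).first == 8190, "H registers");
    static_assert(limits_for_access_size(1).first == 4095, "B registers");

    int main() {
        // Offsets beyond the encodable range, or misaligned ones, take the add_imm fallback path.
        const auto [max_offset, alignment] = limits_for_access_size(16);
        assert(65536 > max_offset);        // too large for the immediate form
        assert((65528 % alignment) != 0);  // not 16-byte aligned, also needs the fallback
        return 0;
    }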

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_binary_call_emitter.cpp

Lines changed: 4 additions & 6 deletions
@@ -121,10 +121,9 @@ void jit_binary_call_emitter::init_binary_call_regs(size_t num_binary_args,
 void jit_binary_call_emitter::emit_stack_preserve(size_t stack_size) const {
     OV_CPU_JIT_EMITTER_ASSERT(!m_stack_preserved, "emit_stack_preserve called twice without emit_stack_restore");
 
-    // ARM64 requires 16-byte stack alignment
-    stack_size = ov::intel_cpu::rnd_up(stack_size, sp_alignment);
-
     if (stack_size > 0) {
+        // ARM64 requires 16-byte stack alignment
+        stack_size = ov::intel_cpu::rnd_up(stack_size, sp_alignment);
         h->sub(h->sp, h->sp, stack_size);
     }
 
@@ -134,10 +133,9 @@ void jit_binary_call_emitter::emit_stack_preserve(size_t stack_size) const {
 void jit_binary_call_emitter::emit_stack_restore(size_t stack_size) const {
     OV_CPU_JIT_EMITTER_ASSERT(m_stack_preserved, "emit_stack_restore called without corresponding emit_stack_preserve");
 
-    // ARM64 requires 16-byte stack alignment
-    stack_size = ov::intel_cpu::rnd_up(stack_size, sp_alignment);
-
     if (stack_size > 0) {
+        // ARM64 requires 16-byte stack alignment
+        stack_size = ov::intel_cpu::rnd_up(stack_size, sp_alignment);
         h->add(h->sp, h->sp, stack_size);
     }
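Note: moving the rnd_up call inside the stack_size > 0 check keeps the rounding but skips it when nothing is reserved. A minimal sketch of the rounding itself, assuming ov::intel_cpu::rnd_up rounds up to the next multiple of its second argument and sp_alignment is 16, as the comment states:

    #include <cassert>
    #include <cstddef>

    // Stand-in for ov::intel_cpu::rnd_up: round n up to the next multiple of align.
    constexpr size_t rnd_up(size_t n, size_t align) {
        return (n + align - 1) / align * align;
    }

    int main() {
        constexpr size_t sp_alignment = 16;  // AArch64 requires sp to stay 16-byte aligned
        assert(rnd_up(0, sp_alignment) == 0);   // empty reservation leaves sp untouched
        assert(rnd_up(8, sp_alignment) == 16);  // small reservations are padded up
        assert(rnd_up(16, sp_alignment) == 16);
        assert(rnd_up(40, sp_alignment) == 48);
        return 0;
    }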

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.cpp

Lines changed: 10 additions & 33 deletions
@@ -101,42 +101,19 @@ void jit_gemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const
 
     const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
 
-    // Collect used register indices to avoid conflicts
+    // Collect used register indices to avoid conflicts with auxiliary registers
     std::vector<size_t> used_gpr_idxs = {call_address_reg.getIdx(), callee_saved_reg.getIdx()};
-
-    for (size_t i = 0; i < mem_ptrs.size(); i++) {
-        const bool is_dynamic_offset = ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]);
-        const bool is_valid_buffer_id = m_buffer_ids[i] != SIZE_MAX;
-
-        // Add current register to avoid conflicts in auxiliary register allocation
-        used_gpr_idxs.push_back(mem_ptrs[i].getIdx());
-
-        if (is_dynamic_offset && is_valid_buffer_id) {
-            OPENVINO_ASSERT(is_valid_buffer_id, "In dynamic case Buffer ID must be defined");
-            // Get 3 auxiliary registers for dynamic offset handling (runtime offset needs at least 3)
-            auto aux_gprs = ov::intel_cpu::aarch64::utils::get_aux_gprs(used_gpr_idxs, 3);
-            std::vector<Xbyak_aarch64::XReg> aux_regs = {aux_gprs[0], aux_gprs[1], aux_gprs[2]};
-            size_t runtime_offset = GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t);
-            utils::push_ptr_with_runtime_offset_on_stack(h,
-                                                         gemm_args_offsets[i],
-                                                         mem_ptrs[i],
-                                                         aux_regs,
-                                                         runtime_offset);
-        } else {
-            // Static offset case (dynamic offsets with valid buffer IDs are handled above)
-            // Get 2 auxiliary registers for static offset handling (static offset needs at least 2)
-            auto aux_gprs = ov::intel_cpu::aarch64::utils::get_aux_gprs(used_gpr_idxs, 2);
-            std::vector<Xbyak_aarch64::XReg> aux_regs = {aux_gprs[0], aux_gprs[1]};
-            utils::push_ptr_with_static_offset_on_stack(h,
-                                                        gemm_args_offsets[i],
-                                                        mem_ptrs[i],
-                                                        aux_regs,
-                                                        m_memory_offsets[i]);
-        }
+    for (const auto& reg : mem_ptrs) {
+        used_gpr_idxs.push_back(reg.getIdx());
     }
 
-    h->mov(call_address_reg, 0);
-    h->str(call_address_reg, Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(reserved_stack_size)));
+    // Get auxiliary registers for the helper function (needs at least 3 for dynamic offsets)
+    auto aux_gprs = ov::intel_cpu::aarch64::utils::get_aux_gprs(used_gpr_idxs, 3);
+
+    // Use the new helper function to push all pointers with offsets to their stack locations
+    utils::push_ptrs_with_offsets_to_stack(h, mem_ptrs, m_memory_offsets, m_buffer_ids, aux_gprs, gemm_args_offsets);
+
+    // Note: scratch field was removed per earlier review feedback, so we don't need to zero it
 
     Xbyak_aarch64::XReg x0(0);
     Xbyak_aarch64::XReg x1(1);

src/plugins/intel_cpu/src/emitters/snippets/aarch64/kernel_executors/gemm.cpp

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ void GemmKaiKernelExecutor::execute(const GemmKaiKernelExecutor* executor, const
     ukernel.run_matmul(M,
                        n_block_size,
                        K,
-                       static_cast<const void*>(args->A),
+                       args->A,
                        lhs_stride,
                        rhs_ptr,
                        dst_ptr,

src/plugins/intel_cpu/src/emitters/snippets/aarch64/utils.cpp

Lines changed: 33 additions & 0 deletions
@@ -178,4 +178,37 @@ void push_and_load_ptrs_with_offsets(dnnl::impl::cpu::aarch64::jit_generator* h,
     h->add(h->sp, h->sp, sp_size);
 }
 
+void push_ptrs_with_offsets_to_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
+                                     const std::vector<Xbyak_aarch64::XReg>& mem_ptrs,
+                                     const std::vector<size_t>& memory_offsets,
+                                     const std::vector<size_t>& buffer_ids,
+                                     const std::vector<Xbyak_aarch64::XReg>& aux_regs,
+                                     const std::vector<size_t>& stack_offsets) {
+    OV_CPU_JIT_EMITTER_ASSERT(mem_ptrs.size() == memory_offsets.size(), "mem_ptrs and memory_offsets size mismatch");
+    OV_CPU_JIT_EMITTER_ASSERT(mem_ptrs.size() == buffer_ids.size(), "mem_ptrs and buffer_ids size mismatch");
+    OV_CPU_JIT_EMITTER_ASSERT(mem_ptrs.size() == stack_offsets.size(), "mem_ptrs and stack_offsets size mismatch");
+    OV_CPU_JIT_EMITTER_ASSERT(aux_regs.size() >= 3, "aux_regs must contain at least 3 registers");
+
+    // Store all pointers with offsets to their specific stack locations
+    for (size_t i = 0; i < mem_ptrs.size(); i++) {
+        const auto& ptr_reg = mem_ptrs[i];
+        int32_t stack_offset = static_cast<int32_t>(stack_offsets[i]);
+
+        if (i < memory_offsets.size() && ov::snippets::utils::is_dynamic_value(memory_offsets[i])) {
+            if (i < buffer_ids.size() && !ov::snippets::utils::is_dynamic_value(buffer_ids[i]) && buffer_ids[i] < 24) {
+                // Dynamic offset: read from runtime parameters
+                size_t runtime_offset = GET_OFF(buffer_offsets) + buffer_ids[i] * sizeof(size_t);
+                push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, runtime_offset);
+            } else {
+                // Invalid buffer ID, store with zero offset
+                push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, 0);
+            }
+        } else {
+            // Static offset: add compile-time constant
+            size_t offset = (i < memory_offsets.size()) ? memory_offsets[i] : 0;
+            push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, offset);
+        }
+    }
+}
+
 }  // namespace ov::intel_cpu::aarch64::utils
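Note: in the dynamic branch, GET_OFF(buffer_offsets) + buffer_ids[i] * sizeof(size_t) is plain offset arithmetic into the runtime call-args block. A hypothetical illustration of that arithmetic (the struct below is invented for the sketch; the real runtime argument layout may differ):

    #include <cstddef>
    #include <cstdio>

    // Hypothetical runtime call-args block; stands in for the structure GET_OFF() indexes into.
    struct call_args_sketch {
        const void* src_ptrs[8];
        void* dst_ptrs[8];
        size_t buffer_offsets[16];  // per-buffer offsets resolved at runtime
    };

    int main() {
        // Where buffer_offsets[buffer_id] lives inside the block, i.e. the runtime offset
        // that the runtime-offset path reads the pointer adjustment from.
        const size_t buffer_id = 3;
        const size_t runtime_offset = offsetof(call_args_sketch, buffer_offsets) + buffer_id * sizeof(size_t);
        std::printf("adjustment loaded from [args + %zu]\n", runtime_offset);
        return 0;
    }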

src/plugins/intel_cpu/src/emitters/snippets/aarch64/utils.hpp

Lines changed: 18 additions & 0 deletions
@@ -99,4 +99,22 @@ void push_and_load_ptrs_with_offsets(dnnl::impl::cpu::aarch64::jit_generator* h,
                                      const std::vector<Xbyak_aarch64::XReg>& aux_regs,
                                      const std::vector<Xbyak_aarch64::XReg>& load_regs);
 
+/**
+ * @brief Push multiple data pointers to specific stack offsets with applied memory offsets
+ * This is designed for struct-based calling conventions where pointers need to be stored
+ * at specific stack locations rather than loaded back to registers.
+ * @param h generator
+ * @param mem_ptrs vector of registers containing data pointers
+ * @param memory_offsets vector of memory offsets (can be dynamic or static)
+ * @param buffer_ids vector of buffer IDs for runtime offset calculation
+ * @param aux_regs vector of available auxiliary registers (must contain >= 3 registers, no overlap with mem_ptrs)
+ * @param stack_offsets vector of stack offsets where adjusted pointers should be stored
+ */
+void push_ptrs_with_offsets_to_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
+                                     const std::vector<Xbyak_aarch64::XReg>& mem_ptrs,
+                                     const std::vector<size_t>& memory_offsets,
+                                     const std::vector<size_t>& buffer_ids,
+                                     const std::vector<Xbyak_aarch64::XReg>& aux_regs,
+                                     const std::vector<size_t>& stack_offsets);
+
 }  // namespace ov::intel_cpu::aarch64::utils
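Note: the "struct-based calling convention" mentioned in the doc comment means each adjusted pointer is written to the stack slot of one particular field, so the callee can read the reserved stack area as a single argument block. A hypothetical illustration of how a caller could derive the stack_offsets argument (names and layout below are invented for the sketch):

    #include <cstddef>
    #include <vector>

    // Hypothetical stack-resident argument block; the real gemm argument layout may differ.
    struct gemm_call_args_sketch {
        const void* A;
        const void* B;
        void* C;
    };

    int main() {
        // One stack offset per pointer, each matching a field of the argument block; these
        // would be passed as stack_offsets to push_ptrs_with_offsets_to_stack.
        const std::vector<size_t> stack_offsets = {offsetof(gemm_call_args_sketch, A),
                                                   offsetof(gemm_call_args_sketch, B),
                                                   offsetof(gemm_call_args_sketch, C)};
        return stack_offsets.size() == 3 ? 0 : 1;
    }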
