Skip to content

Commit e7031da

Browse files
committed
Address Alexandra's review comments 3
1 parent a8fec5c commit e7031da

File tree

5 files changed

+80
-88
lines changed

5 files changed

+80
-88
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp

Lines changed: 62 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -272,38 +272,40 @@ void jit_emitter::store_context(const std::vector<size_t>& gpr_regs,
272272
}
273273
}
274274

275-
// 2. SIMD and Floating-Point registers
276-
// 2.1. store pair registers
277-
int prev_reg_idx = -1;
278-
size_t ignore_registers_count = 0;
279-
for (const auto reg_idx : vec_regs) {
280-
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
281-
ignore_registers_count++;
282-
continue;
275+
// 2. SIMD and Floating-Point registers - optimized to allocate stack space once
276+
const auto store_vec_regs_size = vec_regs.size() - ignore_vec_regs.size();
277+
if (store_vec_regs_size > 0) {
278+
// Calculate total stack space needed for all vector registers (align once)
279+
const auto total_vec_shift = ov::intel_cpu::rnd_up(get_vec_length() * store_vec_regs_size, sp_alignment);
280+
281+
// Single stack allocation for all vector registers
282+
h->sub(h->sp, h->sp, total_vec_shift);
283+
284+
// Store vector registers using stack offset (preserving original order)
285+
const auto last = store_vec_regs_size % 2;
286+
int32_t current_offset = 0;
287+
288+
// Collect non-ignored registers
289+
std::vector<size_t> active_regs;
290+
for (const auto reg_idx : vec_regs) {
291+
if (ignore_vec_regs.find(reg_idx) == ignore_vec_regs.end()) {
292+
active_regs.push_back(reg_idx);
293+
}
283294
}
284-
if (prev_reg_idx == -1) {
285-
prev_reg_idx = static_cast<int>(reg_idx);
286-
continue;
295+
296+
// Store pairs
297+
for (size_t i = 0; i < (active_regs.size() - last); i += 2) {
298+
h->stp(Xbyak_aarch64::QReg(active_regs[i]),
299+
Xbyak_aarch64::QReg(active_regs[i + 1]),
300+
Xbyak_aarch64::ptr(h->sp, current_offset));
301+
current_offset += static_cast<int32_t>(get_vec_length() * 2);
287302
}
288-
const auto shift = ov::intel_cpu::rnd_up(get_vec_length() * 2, sp_alignment);
289-
h->stp(Xbyak_aarch64::QReg(prev_reg_idx),
290-
Xbyak_aarch64::QReg(reg_idx),
291-
pre_ptr(h->sp, -static_cast<int32_t>(shift)));
292-
prev_reg_idx = -1;
293-
}
294303

295-
// 2.1. store the remaining register
296-
if (prev_reg_idx != -1) {
297-
if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) {
298-
const auto shift = ov::intel_cpu::rnd_up(get_vec_length(), sp_alignment);
299-
h->str(Xbyak_aarch64::QReg(prev_reg_idx), pre_ptr(h->sp, -static_cast<int32_t>(shift)));
300-
} else {
301-
ignore_registers_count++;
304+
// Store the remaining register
305+
if (last != 0) {
306+
h->str(Xbyak_aarch64::QReg(active_regs[active_regs.size() - 1]), Xbyak_aarch64::ptr(h->sp, current_offset));
302307
}
303308
}
304-
305-
OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
306-
"ignored registers size is not equal actual ignored registers count");
307309
}
308310

309311
void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_vec_regs) const {
@@ -313,42 +315,44 @@ void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_vec_r
313315
void jit_emitter::restore_context(const std::vector<size_t>& gpr_regs,
314316
const std::vector<size_t>& vec_regs,
315317
const std::unordered_set<size_t>& ignore_vec_regs) const {
316-
// 1. SIMD and Floating-Point registers
317-
// 1.1. restore the remaining register
318-
auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2;
319-
if (v_last != 0) {
320-
for (size_t i = 0; i < vec_regs.size(); i++) {
321-
const auto reg_idx = vec_regs.size() - 1 - i;
322-
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
323-
v_last++;
324-
continue;
318+
// 1. SIMD and Floating-Point registers - optimized to deallocate stack space once
319+
const auto save_vec_regs_size = vec_regs.size() - ignore_vec_regs.size();
320+
if (save_vec_regs_size > 0) {
321+
// Restore vector registers using stack offset (reverse order to match original behavior)
322+
const auto last = save_vec_regs_size % 2;
323+
if (last != 0) {
324+
int32_t current_offset = get_vec_length() * save_vec_regs_size - get_vec_length();
325+
// Find the last non-ignored register
326+
for (size_t i = 0; i < vec_regs.size(); i++) {
327+
const auto reg_idx = vec_regs.size() - 1 - i;
328+
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
329+
continue;
330+
}
331+
h->ldr(Xbyak_aarch64::QReg(reg_idx), Xbyak_aarch64::ptr(h->sp, current_offset));
332+
break;
325333
}
326-
327-
const auto shift = ov::intel_cpu::rnd_up(get_vec_length(), sp_alignment);
328-
h->ldr(Xbyak_aarch64::QReg(reg_idx), post_ptr(h->sp, shift));
329-
break;
330334
}
331-
}
332-
// 1.2. restore pair registers
333-
size_t ignore_registers_count = 0;
334-
int prev_reg_idx = -1;
335-
for (size_t i = v_last; i < vec_regs.size(); i++) {
336-
const auto reg_idx = vec_regs.size() - 1 - i;
337-
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
338-
ignore_registers_count++;
339-
continue;
335+
336+
// Collect non-ignored registers
337+
std::vector<size_t> active_regs;
338+
for (const auto reg_idx : vec_regs) {
339+
if (ignore_vec_regs.find(reg_idx) == ignore_vec_regs.end()) {
340+
active_regs.push_back(reg_idx);
341+
}
340342
}
341-
if (prev_reg_idx == -1) {
342-
prev_reg_idx = static_cast<int>(reg_idx);
343-
continue;
343+
344+
// Restore pairs in reverse order
345+
for (size_t i = last; i < active_regs.size(); i += 2) {
346+
int32_t current_offset = get_vec_length() * (active_regs.size() - (i + 2));
347+
h->ldp(Xbyak_aarch64::QReg(active_regs[active_regs.size() - 1 - (i + 1)]),
348+
Xbyak_aarch64::QReg(active_regs[active_regs.size() - 1 - i]),
349+
Xbyak_aarch64::ptr(h->sp, current_offset));
344350
}
345-
const auto shift = ov::intel_cpu::rnd_up(get_vec_length() * 2, sp_alignment);
346-
h->ldp(Xbyak_aarch64::QReg(reg_idx), Xbyak_aarch64::QReg(prev_reg_idx), post_ptr(h->sp, shift));
347-
prev_reg_idx = -1;
348-
}
349351

350-
OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
351-
"ignored registers size is not equal actual ignored registers count");
352+
const auto total_vec_shift = ov::intel_cpu::rnd_up(get_vec_length() * save_vec_regs_size, sp_alignment);
353+
// Single stack deallocation for all vector registers
354+
h->add(h->sp, h->sp, total_vec_shift);
355+
}
352356

353357
// 2. General-purpose Registers - optimized to deallocate stack space once
354358
const auto save_gpr_regs_size = gpr_regs.size();

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_load_store_emitters.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,18 @@ using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
2929
// Helper function to get max_offset and alignment for different register types
3030
template <typename RegType>
3131
static std::pair<int, int> get_load_store_limits() {
32-
// Default fallback
3332
int max_offset = 4095;
3433
int alignment = 1;
3534

3635
if constexpr (std::is_same_v<RegType, VReg> || std::is_same_v<RegType, QReg>) {
3736
max_offset = 65520; // 4095 * 16
3837
alignment = 16;
39-
} else if constexpr (std::is_same_v<RegType, SReg>) {
40-
max_offset = 16380;
41-
alignment = 4;
4238
} else if constexpr (std::is_same_v<RegType, DReg>) {
4339
max_offset = 32760;
4440
alignment = 8;
41+
} else if constexpr (std::is_same_v<RegType, SReg>) {
42+
max_offset = 16380;
43+
alignment = 4;
4544
} else if constexpr (std::is_same_v<RegType, HReg>) {
4645
max_offset = 8190;
4746
alignment = 2;

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ void jit_gemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vecto
7979

8080
void jit_gemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
8181
const auto& call_address_reg = get_call_address_reg();
82-
const auto& callee_saved_reg = get_callee_saved_reg();
83-
8482
std::unordered_set<size_t> exclude_spill = {};
8583
store_context(exclude_spill);
8684

@@ -96,7 +94,7 @@ void jit_gemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const
9694
const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
9795

9896
// Collect used register indices to avoid conflicts with auxiliary registers
99-
std::vector<size_t> used_gpr_idxs = {call_address_reg.getIdx(), callee_saved_reg.getIdx()};
97+
std::vector<size_t> used_gpr_idxs = {call_address_reg.getIdx()};
10098
for (const auto& reg : mem_ptrs) {
10199
used_gpr_idxs.push_back(reg.getIdx());
102100
}

src/plugins/intel_cpu/src/emitters/snippets/aarch64/utils.cpp

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generat
8282
const Xbyak_aarch64::XReg& ptr_reg,
8383
const std::vector<Xbyak_aarch64::XReg>& aux_regs,
8484
size_t runtime_offset) {
85-
// Safety assertions as suggested
8685
OV_CPU_JIT_EMITTER_ASSERT(aux_regs.size() >= 3, "aux_regs must contain at least 3 registers");
8786

8887
// Assert that ptr_reg is not in aux_regs
@@ -154,21 +153,16 @@ void push_and_load_ptrs_with_offsets(dnnl::impl::cpu::aarch64::jit_generator* h,
154153
const auto sp_size = rnd_up(mem_ptrs.size() * gpr_length, sp_alignment);
155154
h->sub(h->sp, h->sp, sp_size);
156155

157-
// Push all pointers with offsets onto stack
156+
// Generate stack offsets for sequential storage
157+
std::vector<size_t> stack_offsets;
158+
stack_offsets.reserve(mem_ptrs.size());
158159
for (size_t i = 0; i < mem_ptrs.size(); i++) {
159-
const auto& ptr_reg = mem_ptrs[i];
160-
int32_t stack_offset = i * gpr_length;
161-
162-
if (ov::snippets::utils::is_dynamic_value(memory_offsets[i])) {
163-
// Dynamic offset: read from runtime parameters
164-
size_t runtime_offset = GET_OFF(buffer_offsets) + buffer_ids[i] * sizeof(size_t);
165-
push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, runtime_offset);
166-
} else {
167-
// Static offset: add compile-time constant
168-
push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, memory_offsets[i]);
169-
}
160+
stack_offsets.push_back(i * gpr_length);
170161
}
171162

163+
// Use the common function to push pointers with offsets to stack
164+
push_ptrs_with_offsets_to_stack(h, mem_ptrs, memory_offsets, buffer_ids, aux_regs, stack_offsets);
165+
172166
// Load back the adjusted pointers to specified registers
173167
for (size_t i = 0; i < load_regs.size() && i < mem_ptrs.size(); i++) {
174168
h->ldr(load_regs[i], Xbyak_aarch64::ptr(h->sp, static_cast<int32_t>(i * gpr_length)));
@@ -187,22 +181,17 @@ void push_ptrs_with_offsets_to_stack(dnnl::impl::cpu::aarch64::jit_generator* h,
187181
OV_CPU_JIT_EMITTER_ASSERT(mem_ptrs.size() == memory_offsets.size(), "mem_ptrs and memory_offsets size mismatch");
188182
OV_CPU_JIT_EMITTER_ASSERT(mem_ptrs.size() == buffer_ids.size(), "mem_ptrs and buffer_ids size mismatch");
189183
OV_CPU_JIT_EMITTER_ASSERT(mem_ptrs.size() == stack_offsets.size(), "mem_ptrs and stack_offsets size mismatch");
190-
OV_CPU_JIT_EMITTER_ASSERT(aux_regs.size() >= 3, "aux_regs must contain at least 3 registers");
191184

192185
// Store all pointers with offsets to their specific stack locations
193186
for (size_t i = 0; i < mem_ptrs.size(); i++) {
194187
const auto& ptr_reg = mem_ptrs[i];
195188
auto stack_offset = static_cast<int32_t>(stack_offsets[i]);
196189

197-
if (i < memory_offsets.size() && ov::snippets::utils::is_dynamic_value(memory_offsets[i])) {
198-
if (i < buffer_ids.size() && !ov::snippets::utils::is_dynamic_value(buffer_ids[i]) && buffer_ids[i] < 24) {
199-
// Dynamic offset: read from runtime parameters
200-
size_t runtime_offset = GET_OFF(buffer_offsets) + buffer_ids[i] * sizeof(size_t);
201-
push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, runtime_offset);
202-
} else {
203-
// Invalid buffer ID, store with zero offset
204-
push_ptr_with_static_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, 0);
205-
}
190+
if (ov::snippets::utils::is_dynamic_value(memory_offsets[i])) {
191+
OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(buffer_ids[i]), "In dynamic case Buffer ID must be defined");
192+
// Dynamic offset: read from runtime parameters
193+
size_t runtime_offset = GET_OFF(buffer_offsets) + buffer_ids[i] * sizeof(size_t);
194+
push_ptr_with_runtime_offset_on_stack(h, stack_offset, ptr_reg, aux_regs, runtime_offset);
206195
} else {
207196
// Static offset: add compile-time constant
208197
size_t offset = (i < memory_offsets.size()) ? memory_offsets[i] : 0;

src/plugins/intel_cpu/src/emitters/snippets/aarch64/utils.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::aarch64::jit_generato
8585

8686
/**
8787
* @brief Push multiple data pointers on stack with offsets and load them back to specified registers
88+
* Note: This helper doesn't allocate stack space - the user should guarantee allocated space on stack
8889
* @param h generator
8990
* @param mem_ptrs vector of registers containing data pointers
9091
* @param memory_offsets vector of memory offsets (can be dynamic or static)
@@ -103,6 +104,7 @@ void push_and_load_ptrs_with_offsets(dnnl::impl::cpu::aarch64::jit_generator* h,
103104
* @brief Push multiple data pointers to specific stack offsets with applied memory offsets
104105
* This is designed for struct-based calling conventions where pointers need to be stored
105106
* at specific stack locations rather than loaded back to registers.
107+
* Note: This helper doesn't allocate stack space - the user should guarantee allocated space on stack
106108
* @param h generator
107109
* @param mem_ptrs vector of registers containing data pointers
108110
* @param memory_offsets vector of memory offsets (can be dynamic or static)

0 commit comments

Comments
 (0)