Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() {
config.rotary_ndims = config.head_size;
config.is_interleaved = true;
config.output_trans0213 = false;
config.use_rope_cache = false;
config.cos_sin_ndims = static_cast<size_t>(head_size.i());

OutputVector new_args;
new_args.push_back(pattern_map.at(x));
Expand Down Expand Up @@ -534,6 +536,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(result).get_node_shared_ptr()};
config.rotary_ndims = static_cast<size_t>(ndims.i());
config.use_rope_cache = true;
config.cos_sin_ndims = static_cast<size_t>(ndims_over_2.i());

// Fuse output transpose to Rope.
auto root_target_inputs = root->output(0).get_target_inputs();
Expand Down Expand Up @@ -736,7 +740,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(const bool support_2d_rope) {
config.use_rope_cache = true;
config.head_cnt = static_cast<size_t>(head_cnt.i());
config.head_size = static_cast<size_t>(head_size.i());

config.cos_sin_ndims = static_cast<size_t>(ndims_over_2.i());
const auto& qkv_proj_node = pattern_map.at(qkv_proj);
const size_t qkv_proj_output_id = qkv_proj_node.get_index();
if (qkv_proj_output_id == 0) {
Expand Down
66 changes: 47 additions & 19 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,25 +144,48 @@ void jit_rotary_kernel<isa>::rotary_interleave(size_t step) {
vshufps(src1, tmp0, tmp1, 0xdd);
}
};
const bool no_share_halves = (m_jcp.cos_sin_ndims != 0 && m_jcp.cos_sin_ndims != m_jcp.rotary_ndims / 2);
deinterlace(vmm_src0, vmm_src1, vmm_dst0, vmm_dst1);
// cos[j]
load(vmm_cos, reg_cos, ov::element::f32, step, false);
// sin[j]
if (m_jcp.mix_cos_sin) {
load(vmm_sin, reg_cos, ov::element::f32, step, false, step * sizeof(float));
deinterlace(vmm_cos, vmm_sin, vmm_dst0, vmm_dst1);
} else {
if (no_share_halves) {
// load cos
load(vmm_cos, reg_cos, ov::element::f32, step, false);
load(vmm_cos1, reg_cos, ov::element::f32, step, false, step * ov::element::f32.size());
deinterlace(vmm_cos, vmm_cos1, vmm_dst0, vmm_dst1);
// load sin
load(vmm_sin, reg_sin, ov::element::f32, step, false);
load(vmm_sin1, reg_sin, ov::element::f32, step, false, step * ov::element::f32.size());
deinterlace(vmm_sin, vmm_sin1, vmm_dst0, vmm_dst1);

// sin[i] * src1
uni_vmulps(vmm_dst0, vmm_sin, vmm_src1);
// cos[i] * src0 - sin[i] * src1
vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0);

// cos[i+1] * src1
uni_vmulps(vmm_dst1, vmm_cos1, vmm_src1);
// cos[i+1] * src1 + sin[i+1] * src0
vfmadd231ps(vmm_dst1, vmm_sin1, vmm_src0);
} else {
// cos[j]
load(vmm_cos, reg_cos, ov::element::f32, step, false);
// sin[j]
if (m_jcp.mix_cos_sin) {
load(vmm_sin, reg_cos, ov::element::f32, step, false, step * sizeof(float));
deinterlace(vmm_cos, vmm_sin, vmm_dst0, vmm_dst1);
} else {
load(vmm_sin, reg_sin, ov::element::f32, step, false);
}
// sin[j] * src1
uni_vmulps(vmm_dst0, vmm_sin, vmm_src1);
// cos[j] * src0 - sin[j] * src1
vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0);

// cos[j] * src1
uni_vmulps(vmm_dst1, vmm_cos, vmm_src1);
// cos[j] * src1 + sin[j] * src0
vfmadd231ps(vmm_dst1, vmm_sin, vmm_src0);
}
// sin[j] * src1
uni_vmulps(vmm_dst0, vmm_sin, vmm_src1);
// cos[j] * src0 - sin[j] * src1
vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0);

// cos[j] * src1
uni_vmulps(vmm_dst1, vmm_cos, vmm_src1);
// cos[j] * src1 + sin[j] * src0
vfmadd231ps(vmm_dst1, vmm_sin, vmm_src0);
if (isa == cpu_isa_t::avx2) {
// dst0: 0 2 4 6 8 10 12 14
// dst1: 1 3 5 7 9 11 13 15
Expand Down Expand Up @@ -190,11 +213,16 @@ void jit_rotary_kernel<isa>::rotary_interleave(size_t step) {
store(reg_dst, vmm_dst1, m_jcp.dst_prc, step, step * m_jcp.dst_prc.size());
add(reg_src, m_jcp.src_prc.size() * step * 2);
add(reg_dst, m_jcp.dst_prc.size() * step * 2);
if (m_jcp.mix_cos_sin) {
add(reg_cos, 2 * sizeof(float) * step);
if (no_share_halves) {
add(reg_cos, sizeof(float) * step * 2);
add(reg_sin, sizeof(float) * step * 2);
} else {
add(reg_cos, sizeof(float) * step);
add(reg_sin, sizeof(float) * step);
if (m_jcp.mix_cos_sin) {
add(reg_cos, 2 * sizeof(float) * step);
} else {
add(reg_cos, sizeof(float) * step);
add(reg_sin, sizeof(float) * step);
}
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ struct jit_rotary_kernel : public JitKernel<jit_rotary_compile_params, jit_rotar
const Vmm vmm_src1 = Vmm(1);
const Vmm vmm_cos = Vmm(2);
const Vmm vmm_sin = Vmm(3);
const Vmm vmm_dst0 = Vmm(4);
const Vmm vmm_dst1 = Vmm(5);
const Vmm vmm_idx = Vmm(7);
const Vmm vmm_cos1 = Vmm(4);
const Vmm vmm_sin1 = Vmm(5);
const Vmm vmm_dst0 = Vmm(6);
const Vmm vmm_dst1 = Vmm(7);
const Vmm vmm_idx = Vmm(8);
const Xbyak::Reg64 reg_src = r8;
const Xbyak::Reg64 reg_cos = r10;
const Xbyak::Reg64 reg_sin = r11;
Expand Down
67 changes: 45 additions & 22 deletions src/plugins/intel_cpu/src/nodes/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,41 +204,64 @@ struct RoPE::RoPEExecutorInterleaved : public RoPE::Executor {
jcp.rotary_ndims = config.rotary_ndims;
jcp.interleave = true;
jcp.mix_cos_sin = false;
jcp.cos_sin_ndims = config.cos_sin_ndims;
m_rotaryKernel = createJitKernel(jcp, true);
}

void execute([[maybe_unused]] const dnnl::stream& strm,
const std::vector<MemoryPtr>& inputs,
const std::vector<MemoryPtr>& outputs) override {
ov::intel_cpu::PlainTensor t_src(inputs[0]);
ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]);
ov::intel_cpu::PlainTensor t_dst(outputs[0]);

auto batch_size = t_src.size(0);
auto seq_len = t_src.size(1);
auto head_cnt = t_src.size(2);
auto head_dims = t_src.size(3);

auto rotary_dims = m_config.rotary_ndims;
auto half_rotary_dims = rotary_dims / 2;
if (m_config.use_rope_cache) {
ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]);
const auto batch_size = t_src.size(0);
const auto seq_len = t_src.size(1);
const auto head_cnt = t_src.size(2);
const auto head_dims = t_src.size(3);
parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) {
auto* x = t_src.ptr<T>(b, p, h);
float* sin = &t_sin_cos.at<float>({b, p, 0}, true);
float* cos = &t_sin_cos.at<float>({b, p, half_rotary_dims}, true);
auto* dst = m_config.output_trans0213 ? t_dst.ptr<T>(b, h, p) : t_dst.ptr<T>(b, p, h);

parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) {
auto* x = t_src.ptr<T>(b, p, h);
float* sin = &t_sin_cos.at<float>({b, p, 0}, true);
float* cos = &t_sin_cos.at<float>({b, p, half_rotary_dims}, true);
auto* dst = m_config.output_trans0213 ? t_dst.ptr<T>(b, h, p) : t_dst.ptr<T>(b, p, h);

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, x, dst, cos, sin);
} else {
size_t i = 0;
for (size_t j = 0; i < rotary_dims; i += 2, j++) {
dst[i] = cos[j] * x[i] - sin[j] * x[i + 1];
dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i];
if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, x, dst, cos, sin);
} else {
size_t i = 0;
for (size_t j = 0; i < rotary_dims; i += 2, j++) {
dst[i] = cos[j] * x[i] - sin[j] * x[i + 1];
dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i];
}
}
}
memcpy(dst + rotary_dims, x + rotary_dims, (head_dims - rotary_dims) * sizeof(T));
});
memcpy(dst + rotary_dims, x + rotary_dims, (head_dims - rotary_dims) * sizeof(T));
});
} else {
const auto batch_size = t_src.size(0);
const auto dim_1 = t_src.size(1);
const auto dim_2 = t_src.size(2);
const auto head_dims = t_src.size(3);
ov::intel_cpu::PlainTensor t_cos(inputs[1]);
ov::intel_cpu::PlainTensor t_sin(inputs[2]);
parallel_for3d(batch_size, dim_1, dim_2, [&](size_t b, size_t d_1, size_t d_2) {
auto* x = t_src.ptr<T>(b, d_1, d_2);
float* sin = &t_sin.at<float>({b, d_1, d_2}, true);
float* cos = &t_cos.at<float>({b, d_1, d_2}, true);
auto* dst = m_config.output_trans0213 ? t_dst.ptr<T>(b, d_2, d_1) : t_dst.ptr<T>(b, d_1, d_2);
if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, x, dst, cos, sin);
} else {
for (size_t i = 0; i < rotary_dims; i += 2) {
dst[i] = cos[i] * x[i] - sin[i] * x[i + 1];
dst[i + 1] = cos[i + 1] * x[i + 1] + sin[i + 1] * x[i];
}
}
memcpy(dst + rotary_dims, x + rotary_dims, (head_dims - rotary_dims) * sizeof(T));
});
}
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
#include "transformations/common_optimizations/convert_pagedattn_inputs.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
#include "transformations/common_optimizations/lora_subgraph_fusion.hpp"
#include "transformations/common_optimizations/lstm_cell_fusion.hpp"
#include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp"
Expand Down Expand Up @@ -235,6 +234,7 @@
# include "snippets/pass/common_optimizations.hpp"
# include "snippets/pass/split_dimension_m.hpp"
# include "snippets/utils/tokenization_utils.hpp"
# include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
# include "transformations/common_optimizations/rms_fusion.hpp"
# include "transformations/cpu_opset/common/op/sdpa.hpp"
# include "transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.hpp"
Expand Down Expand Up @@ -288,6 +288,7 @@
#endif

#if defined(OPENVINO_ARCH_ARM64)
# include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
# include "transformations/op_conversions/hard_sigmoid_decomposition.hpp"
# include "transformations/op_conversions/hsigmoid_decomposition.hpp"
#endif
Expand Down Expand Up @@ -1100,7 +1101,6 @@ void Transformations::PostLpt() {

CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true);
CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true);
CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionFlux);
CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion);

#if defined(OPENVINO_ARCH_X86_64)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTOSS,
::testing::Values(ov::test::utils::DEVICE_CPU)),
RoPETestGPTOSS::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestFlux,
RoPETestFlux,
::testing::Combine(
::testing::Values(ov::element::f32),
::testing::Values(ov::test::utils::DEVICE_CPU)),
RoPETestFlux::getTestCaseName);

} // namespace test
} // namespace ov
Loading