diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 43ffd9853299ab..05868ba8197e2f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -137,6 +137,8 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() { config.rotary_ndims = config.head_size; config.is_interleaved = true; config.output_trans0213 = false; + config.use_rope_cache = false; + config.cos_sin_ndims = static_cast(head_size.i()); OutputVector new_args; new_args.push_back(pattern_map.at(x)); @@ -534,6 +536,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { pattern_map.at(rotary_emb).get_node_shared_ptr(), pattern_map.at(result).get_node_shared_ptr()}; config.rotary_ndims = static_cast(ndims.i()); + config.use_rope_cache = true; + config.cos_sin_ndims = static_cast(ndims_over_2.i()); // Fuse output transpose to Rope. auto root_target_inputs = root->output(0).get_target_inputs(); @@ -736,7 +740,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(const bool support_2d_rope) { config.use_rope_cache = true; config.head_cnt = static_cast(head_cnt.i()); config.head_size = static_cast(head_size.i()); - + config.cos_sin_ndims = static_cast(ndims_over_2.i()); const auto& qkv_proj_node = pattern_map.at(qkv_proj); const size_t qkv_proj_output_id = qkv_proj_node.get_index(); if (qkv_proj_output_id == 0) { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp index 25324982a33963..9bd6d9f40fe937 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.cpp @@ -144,25 +144,48 @@ void jit_rotary_kernel::rotary_interleave(size_t step) { vshufps(src1, tmp0, tmp1, 0xdd); } }; + const bool no_share_halves = (m_jcp.cos_sin_ndims != 0 && m_jcp.cos_sin_ndims != m_jcp.rotary_ndims / 2); deinterlace(vmm_src0, vmm_src1, vmm_dst0, vmm_dst1); - // cos[j] - load(vmm_cos, reg_cos, ov::element::f32, step, false); - // sin[j] - if (m_jcp.mix_cos_sin) { - load(vmm_sin, reg_cos, ov::element::f32, step, false, step * sizeof(float)); - deinterlace(vmm_cos, vmm_sin, vmm_dst0, vmm_dst1); - } else { + if (no_share_halves) { + // load cos + load(vmm_cos, reg_cos, ov::element::f32, step, false); + load(vmm_cos1, reg_cos, ov::element::f32, step, false, step * ov::element::f32.size()); + deinterlace(vmm_cos, vmm_cos1, vmm_dst0, vmm_dst1); + // load sin load(vmm_sin, reg_sin, ov::element::f32, step, false); + load(vmm_sin1, reg_sin, ov::element::f32, step, false, step * ov::element::f32.size()); + deinterlace(vmm_sin, vmm_sin1, vmm_dst0, vmm_dst1); + + // sin[i] * src1 + uni_vmulps(vmm_dst0, vmm_sin, vmm_src1); + // cos[i] * src0 - sin[i] * src1 + vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0); + + // cos[i+1] * src1 + uni_vmulps(vmm_dst1, vmm_cos1, vmm_src1); + // cos[i+1] * src1 + sin[i+1] * src0 + vfmadd231ps(vmm_dst1, vmm_sin1, vmm_src0); + } else { + // cos[j] + load(vmm_cos, reg_cos, ov::element::f32, step, false); + // sin[j] + if (m_jcp.mix_cos_sin) { + load(vmm_sin, reg_cos, ov::element::f32, step, false, step * sizeof(float)); + deinterlace(vmm_cos, vmm_sin, vmm_dst0, vmm_dst1); + } else { + load(vmm_sin, reg_sin, ov::element::f32, step, false); + } + // sin[j] * src1 + uni_vmulps(vmm_dst0, vmm_sin, vmm_src1); + // cos[j] * src0 - sin[j] * src1 + vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0); + + // cos[j] * src1 + uni_vmulps(vmm_dst1, vmm_cos, vmm_src1); + // cos[j] * src1 + sin[j] * src0 + vfmadd231ps(vmm_dst1, vmm_sin, vmm_src0); } - // sin[j] * src1 - uni_vmulps(vmm_dst0, vmm_sin, vmm_src1); - // cos[j] * src0 - sin[j] * src1 - vfmsub231ps(vmm_dst0, vmm_cos, vmm_src0); - // cos[j] * src1 - uni_vmulps(vmm_dst1, vmm_cos, vmm_src1); - // cos[j] * src1 + sin[j] * src0 - vfmadd231ps(vmm_dst1, vmm_sin, vmm_src0); if (isa == cpu_isa_t::avx2) { // dst0: 0 2 4 6 8 10 12 14 // dst1: 1 3 5 7 9 11 13 15 @@ -190,11 +213,16 @@ void jit_rotary_kernel::rotary_interleave(size_t step) { store(reg_dst, vmm_dst1, m_jcp.dst_prc, step, step * m_jcp.dst_prc.size()); add(reg_src, m_jcp.src_prc.size() * step * 2); add(reg_dst, m_jcp.dst_prc.size() * step * 2); - if (m_jcp.mix_cos_sin) { - add(reg_cos, 2 * sizeof(float) * step); + if (no_share_halves) { + add(reg_cos, sizeof(float) * step * 2); + add(reg_sin, sizeof(float) * step * 2); } else { - add(reg_cos, sizeof(float) * step); - add(reg_sin, sizeof(float) * step); + if (m_jcp.mix_cos_sin) { + add(reg_cos, 2 * sizeof(float) * step); + } else { + add(reg_cos, sizeof(float) * step); + add(reg_sin, sizeof(float) * step); + } } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp index 7fbeb56c83ac83..88a2e06ae3c9d3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/rope_kernel.hpp @@ -74,9 +74,11 @@ struct jit_rotary_kernel : public JitKernel& inputs, const std::vector& outputs) override { ov::intel_cpu::PlainTensor t_src(inputs[0]); - ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]); ov::intel_cpu::PlainTensor t_dst(outputs[0]); - auto batch_size = t_src.size(0); - auto seq_len = t_src.size(1); - auto head_cnt = t_src.size(2); - auto head_dims = t_src.size(3); - auto rotary_dims = m_config.rotary_ndims; auto half_rotary_dims = rotary_dims / 2; + if (m_config.use_rope_cache) { + ov::intel_cpu::PlainTensor t_sin_cos(inputs[1]); + const auto batch_size = t_src.size(0); + const auto seq_len = t_src.size(1); + const auto head_cnt = t_src.size(2); + const auto head_dims = t_src.size(3); + parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) { + auto* x = t_src.ptr(b, p, h); + float* sin = &t_sin_cos.at({b, p, 0}, true); + float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true); + auto* dst = m_config.output_trans0213 ? t_dst.ptr(b, h, p) : t_dst.ptr(b, p, h); - parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) { - auto* x = t_src.ptr(b, p, h); - float* sin = &t_sin_cos.at({b, p, 0}, true); - float* cos = &t_sin_cos.at({b, p, half_rotary_dims}, true); - auto* dst = m_config.output_trans0213 ? t_dst.ptr(b, h, p) : t_dst.ptr(b, p, h); - - if (m_rotaryKernel) { - execJitKernel(m_rotaryKernel, x, dst, cos, sin); - } else { - size_t i = 0; - for (size_t j = 0; i < rotary_dims; i += 2, j++) { - dst[i] = cos[j] * x[i] - sin[j] * x[i + 1]; - dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i]; + if (m_rotaryKernel) { + execJitKernel(m_rotaryKernel, x, dst, cos, sin); + } else { + size_t i = 0; + for (size_t j = 0; i < rotary_dims; i += 2, j++) { + dst[i] = cos[j] * x[i] - sin[j] * x[i + 1]; + dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i]; + } } - } - memcpy(dst + rotary_dims, x + rotary_dims, (head_dims - rotary_dims) * sizeof(T)); - }); + memcpy(dst + rotary_dims, x + rotary_dims, (head_dims - rotary_dims) * sizeof(T)); + }); + } else { + const auto batch_size = t_src.size(0); + const auto dim_1 = t_src.size(1); + const auto dim_2 = t_src.size(2); + const auto head_dims = t_src.size(3); + ov::intel_cpu::PlainTensor t_cos(inputs[1]); + ov::intel_cpu::PlainTensor t_sin(inputs[2]); + parallel_for3d(batch_size, dim_1, dim_2, [&](size_t b, size_t d_1, size_t d_2) { + auto* x = t_src.ptr(b, d_1, d_2); + float* sin = &t_sin.at({b, d_1, d_2}, true); + float* cos = &t_cos.at({b, d_1, d_2}, true); + auto* dst = m_config.output_trans0213 ? t_dst.ptr(b, d_2, d_1) : t_dst.ptr(b, d_1, d_2); + if (m_rotaryKernel) { + execJitKernel(m_rotaryKernel, x, dst, cos, sin); + } else { + for (size_t i = 0; i < rotary_dims; i += 2) { + dst[i] = cos[i] * x[i] - sin[i] * x[i + 1]; + dst[i + 1] = cos[i + 1] * x[i + 1] + sin[i + 1] * x[i]; + } + } + memcpy(dst + rotary_dims, x + rotary_dims, (head_dims - rotary_dims) * sizeof(T)); + }); + } } }; diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 35329228973004..ca8b59a95605fc 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -61,7 +61,6 @@ #include "transformations/common_optimizations/convert_pagedattn_inputs.hpp" #include "transformations/common_optimizations/convert_quantize_dequantize.hpp" #include "transformations/common_optimizations/fq_mul_fusion.hpp" -#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" #include "transformations/common_optimizations/lora_subgraph_fusion.hpp" #include "transformations/common_optimizations/lstm_cell_fusion.hpp" #include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp" @@ -235,6 +234,7 @@ # include "snippets/pass/common_optimizations.hpp" # include "snippets/pass/split_dimension_m.hpp" # include "snippets/utils/tokenization_utils.hpp" +# include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" # include "transformations/common_optimizations/rms_fusion.hpp" # include "transformations/cpu_opset/common/op/sdpa.hpp" # include "transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.hpp" @@ -288,6 +288,7 @@ #endif #if defined(OPENVINO_ARCH_ARM64) +# include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" # include "transformations/op_conversions/hard_sigmoid_decomposition.hpp" # include "transformations/op_conversions/hsigmoid_decomposition.hpp" #endif @@ -1100,7 +1101,6 @@ void Transformations::PostLpt() { CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true); CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true); - CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionFlux); CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion); #if defined(OPENVINO_ARCH_X86_64) diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index f35df5056e3ebd..69a791b1d078b2 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -93,5 +93,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTOSS, ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestGPTOSS::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestFlux, + RoPETestFlux, + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + RoPETestFlux::getTestCaseName); + } // namespace test } // namespace ov