diff --git a/src/common/transformations/include/transformations/common_optimizations/matmul_experts_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/matmul_experts_fusion.hpp new file mode 100644 index 00000000000000..482695ff3ce9ae --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/matmul_experts_fusion.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API FuseVectorizedMOE2GEMM; +class TRANSFORMATIONS_API FuseVectorizedMOE3GEMM; +class TRANSFORMATIONS_API VectorizedExpertsFusion; + +} // namespace pass +} // namespace ov + +class ov::pass::FuseVectorizedMOE2GEMM : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("FuseVectorizedMOE2GEMM"); + FuseVectorizedMOE2GEMM(); +}; + +class ov::pass::FuseVectorizedMOE3GEMM : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("FuseVectorizedMOE3GEMM"); + FuseVectorizedMOE3GEMM(); +}; + +class ov::pass::VectorizedExpertsFusion : public ov::pass::GraphRewrite { +public: + OPENVINO_GRAPH_REWRITE_RTTI("VectorizedExpertsFusion"); + VectorizedExpertsFusion() { + add_matcher(); + add_matcher(); + } +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/matmul_experts_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/matmul_experts_fusion.cpp new file mode 100644 index 00000000000000..76bbbef9abf8e0 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/matmul_experts_fusion.cpp @@ -0,0 +1,198 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/matmul_experts_fusion.hpp" + +#include "itt.hpp" +#include "openvino/core/graph_util.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/clamp.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/minimum.hpp" +#include "openvino/op/moe.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reduce_sum.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/swish.hpp" +#include "openvino/op/tile.hpp" +#include "openvino/op/topk.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/manager.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::pass; +ov::pass::FuseVectorizedMOE2GEMM::FuseVectorizedMOE2GEMM() { + MATCHER_SCOPE(FuseVectorizedMOE2GEMM); + + auto experts_input = pattern::wrap_type({pattern::any_input(), pattern::any_input()}); + auto tile = pattern::wrap_type({experts_input, pattern::any_input()}); + auto after_tile_reshape = pattern::wrap_type({tile, pattern::any_input()}); + auto gate_up_matmul = pattern::wrap_type({after_tile_reshape, pattern::any_input()}, + {{"transpose_a", false}, {"transpose_b", false}}); + auto gate_up_add = pattern::wrap_type({gate_up_matmul, pattern::any_input()}); + + // Branch 1: Slice_1 -> Clamp -> Add_1 + auto slice1 = pattern::wrap_type( + {gate_up_add, pattern::any_input(), pattern::any_input(), pattern::any_input(), 
pattern::any_input()}); + auto clamp = pattern::wrap_type({slice1}); + auto add1 = pattern::wrap_type({clamp, pattern::wrap_const()}); + + // Branch 2: Slice_2 -> Minimum_1 -> Swish + auto slice2 = pattern::wrap_type( + {gate_up_add, pattern::any_input(), pattern::any_input(), pattern::any_input(), pattern::any_input()}); + auto minimum1 = pattern::wrap_type({slice2, pattern::wrap_const()}); + auto swish_beta = pattern::wrap_const(); + auto swish = pattern::wrap_type({minimum1, swish_beta}); + + // Join: Multiply_2 + auto multiply2 = pattern::wrap_type({add1, swish}); + + // Down projection + auto down_proj_matmul = pattern::wrap_type({multiply2, pattern::any_input()}, + {{"transpose_a", false}, {"transpose_b", false}}); + auto down_proj_add = pattern::wrap_type({down_proj_matmul, pattern::wrap_const()}); + auto end_reshape = pattern::wrap_type({down_proj_add, pattern::any_input()}); + + // Routing weights/mask + auto router_topk_indices = pattern::any_input(); + auto scatter_elements_update = pattern::wrap_type( + {pattern::any_input(), router_topk_indices, pattern::any_input(), pattern::any_input()}); + + auto router_transpose = pattern::wrap_type({scatter_elements_update, pattern::any_input()}); + auto router_reshape = pattern::wrap_type({router_transpose, pattern::any_input()}); + auto unsqueeze_routing_weights = pattern::wrap_type({router_reshape, pattern::any_input()}); + + auto mul3 = pattern::wrap_type({end_reshape, unsqueeze_routing_weights}); + auto reduce_sum = pattern::wrap_type({mul3, pattern::any_input()}, {{"keep_dims", false}}); + auto moe_pattern = reduce_sum; + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto& pm = m.get_pattern_value_map(); + + auto experts_input_node = pm.at(experts_input).get_node()->input_value(0); + + auto routing_weights_node = pm.at(unsqueeze_routing_weights).get_node_shared_ptr(); + auto gate_up_weight = pm.at(gate_up_matmul).get_node()->input_value(1).get_node_shared_ptr(); + auto gate_up_bias_node = pm.at(gate_up_add).get_node()->input_value(1).get_node_shared_ptr(); + auto down_proj_weight = pm.at(down_proj_matmul).get_node()->input_value(1).get_node_shared_ptr(); + auto down_proj_bias_node = pm.at(down_proj_add).get_node()->input_value(1).get_node_shared_ptr(); + auto topk_indices_node = pm.at(scatter_elements_update).get_node()->input_value(1); + + ov::OutputVector moe_inputs = {experts_input_node, + routing_weights_node, + topk_indices_node, + gate_up_weight, + gate_up_bias_node, + down_proj_weight, + down_proj_bias_node}; + + ov::op::internal::MOE::Config config; + + // Extract expert_beta from Swish beta attribute + auto swish_beta_const = ov::as_type_ptr(pm.at(swish_beta).get_node_shared_ptr()); + auto swish_beta_const_val = swish_beta_const->cast_vector()[0]; + config.expert_beta = swish_beta_const_val; + + // Extract expert_alpha from Clamp max attribute + if (auto clamp_op = ov::as_type_ptr(pm.at(clamp).get_node_shared_ptr())) { + config.expert_alpha = static_cast(clamp_op->get_max()); + } + + // Set expert_type + config.expert_type = ov::op::internal::MOE::Expert_type::GEMM2_BIAS_SWIGLU_CLAMP; + + auto moe = std::make_shared(moe_inputs, config); + moe->set_friendly_name(m.get_match_root()->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), moe); + ov::replace_node(m.get_match_root(), moe); + + register_new_node(moe); + return true; + }; + + auto matcher = std::make_shared(moe_pattern, matcher_name); + this->register_matcher(matcher, callback); +} + 
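+// FuseVectorizedMOE3GEMM targets the variant matched below, in which the gate and up
+// projections are kept as two separate MatMuls (weights applied with transpose_b=true
+// and no bias Adds): Tile -> Reshape -> {gate MatMul -> Swish, up MatMul} ->
+// Multiply (SwiGLU) -> down MatMul -> Reshape, followed by the same
+// ScatterElementsUpdate/Transpose/Reshape/Unsqueeze routing tail and the final
+// Multiply + ReduceSum. Since this variant has no Clamp/Minimum branch, only the
+// optional Swish beta is propagated into the MOE config and expert_alpha keeps its default.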
+ov::pass::FuseVectorizedMOE3GEMM::FuseVectorizedMOE3GEMM() { + MATCHER_SCOPE(FuseVectorizedMOE3GEMM); + + auto experts_input = pattern::wrap_type({pattern::any_input(), pattern::any_input()}); + auto tile = pattern::wrap_type({experts_input, pattern::any_input()}); + auto after_tile_reshape = pattern::wrap_type({tile, pattern::any_input()}); + + // First GEMM (activation gate) + auto gate_matmul = pattern::wrap_type({after_tile_reshape, pattern::any_input()}, + {{"transpose_a", false}, {"transpose_b", true}}); + auto swish = pattern::wrap_type({gate_matmul}); + // Second GEMM (up_projection) + auto up_matmul = pattern::wrap_type({after_tile_reshape, pattern::any_input()}, + {{"transpose_a", false}, {"transpose_b", true}}); + // Join: Multiply (SwiGLU) + auto swiglu = pattern::wrap_type({swish, up_matmul}); + + // Third GEMM (down_projection) + auto down_matmul = pattern::wrap_type({swiglu, pattern::any_input()}, + {{"transpose_a", false}, {"transpose_b", true}}); + auto end_reshape = pattern::wrap_type({down_matmul, pattern::any_input()}); + + // Routing weights/mask + auto router_topk_indices = pattern::any_input(); + auto scatter_elements_update = pattern::wrap_type( + {pattern::any_input(), router_topk_indices, pattern::any_input(), pattern::any_input()}); + auto router_transpose = pattern::wrap_type({scatter_elements_update, pattern::any_input()}); + auto router_reshape = pattern::wrap_type({router_transpose, pattern::any_input()}); + auto unsqueeze_routing_weights = pattern::wrap_type({router_reshape, pattern::any_input()}); + + auto mul3 = pattern::wrap_type({end_reshape, unsqueeze_routing_weights}); + auto reduce_sum = pattern::wrap_type({mul3, pattern::any_input()}, {{"keep_dims", false}}); + auto moe_pattern = reduce_sum; + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto& pm = m.get_pattern_value_map(); + auto experts_input_node = pm.at(experts_input).get_node()->input_value(0); + auto routing_weights_node = pm.at(unsqueeze_routing_weights).get_node_shared_ptr(); + auto gate_weight = pm.at(gate_matmul).get_node()->input_value(1).get_node_shared_ptr(); + auto up_weight = pm.at(up_matmul).get_node()->input_value(1).get_node_shared_ptr(); + auto down_weight = pm.at(down_matmul).get_node()->input_value(1).get_node_shared_ptr(); + auto topk_indices_node = pm.at(scatter_elements_update).get_node()->input_value(1); + + ov::OutputVector moe_inputs = { + experts_input_node, + routing_weights_node, + topk_indices_node, + gate_weight, + up_weight, + down_weight, + }; + + ov::op::internal::MOE::Config config; + config.expert_type = ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU; + // Extract expert_beta if Swish has beta input provided + if (auto swish_op = ov::as_type_ptr(pm.at(swish).get_node_shared_ptr())) { + if (swish_op->get_input_size() > 1) { + if (auto swish_beta_const = + ov::as_type_ptr(swish_op->get_input_node_shared_ptr(1))) { + config.expert_beta = swish_beta_const->cast_vector()[0]; + } + } + } + + auto moe = std::make_shared(moe_inputs, config); + moe->set_friendly_name(m.get_match_root()->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), moe); + ov::replace_node(m.get_match_root(), moe); + + register_new_node(moe); + return true; + }; + + auto matcher = std::make_shared(moe_pattern, matcher_name); + this->register_matcher(matcher, callback); +} diff --git a/src/common/transformations/tests/common_optimizations/fuse_vectorized_moe_test.cpp b/src/common/transformations/tests/common_optimizations/fuse_vectorized_moe_test.cpp new file 
mode 100644 index 00000000000000..90fac722910d04 --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/fuse_vectorized_moe_test.cpp @@ -0,0 +1,451 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/graph_comparator.hpp" +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/core/node_vector.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/runtime/core.hpp" +#include "ov_ops/type_relaxed.hpp" +#include "transformations/common_optimizations/matmul_experts_fusion.hpp" +#include "transformations/utils/gen_pattern.hpp" + +inline std::shared_ptr build_2gemm_moe_pattern_model() { + using namespace ov; + + const size_t batch = 2; + const Dimension in_dim = Dimension::dynamic(); + const size_t hidden_size = 2048; + const size_t intermediate_size = 4096; + const size_t topk = 2; + const size_t number_of_experts = 3; + const size_t fusion_factor = 2; + const auto expert_alpha = 1.702f; + const auto expert_beta = 7.0f; + + auto input_shape = PartialShape{batch, in_dim, hidden_size}; + auto input = std::make_shared(element::f32, input_shape); + auto experts_reshape = std::make_shared( + input, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{-1, hidden_size}), + false); + + auto tile = std::make_shared( + experts_reshape, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{number_of_experts, 1})); + auto after_tile_reshape = std::make_shared( + tile, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{number_of_experts, batch, hidden_size}), + false); + + auto gate_up_matmul = std::make_shared( + after_tile_reshape, + op::v0::Constant::create(element::f32, + Shape{number_of_experts, hidden_size, intermediate_size * fusion_factor}, + {1.0f})); + auto gate_up_add = std::make_shared( + gate_up_matmul, + op::v0::Constant::create(element::f32, Shape{number_of_experts, 1, intermediate_size * fusion_factor}, {0.0f})); + + auto slice1 = std::make_shared( + gate_up_add, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{0, 0, 0}), + op::v0::Constant::create(element::i64, + Shape{3}, + std::vector{number_of_experts, batch, intermediate_size * 2}), + op::v0::Constant::create(element::i64, Shape{3}, std::vector{1, 1, 2}), + op::v0::Constant::create(element::i64, Shape{3}, std::vector{0, 1, 2})); + auto clamp = std::make_shared(slice1, -expert_beta, expert_beta); + auto add1 = std::make_shared(clamp, op::v0::Constant::create(element::f32, Shape{1}, {1.0f})); + + auto slice2 = std::make_shared( + gate_up_add, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{0, 1, 0}), + op::v0::Constant::create(element::i64, + Shape{3}, + std::vector{number_of_experts, batch, intermediate_size * 2}), + op::v0::Constant::create(element::i64, Shape{3}, std::vector{1, 1, 2}), + op::v0::Constant::create(element::i64, Shape{3}, std::vector{0, 1, 2})); + auto minimum1 = + std::make_shared(slice2, op::v0::Constant::create(element::f32, Shape{1}, {10.0f})); + auto swish_beta = op::v0::Constant::create(element::f32, Shape{}, std::vector{expert_alpha}); + auto swish = std::make_shared(minimum1, 
swish_beta); + + auto multiply2 = std::make_shared(add1, swish); + + auto down_proj_matmul = std::make_shared( + multiply2, + op::v0::Constant::create(element::f32, Shape{number_of_experts, intermediate_size, hidden_size}, {1.0f})); + + auto down_proj_add = std::make_shared( + down_proj_matmul, + op::v0::Constant::create(element::f32, Shape{number_of_experts, 1, hidden_size}, {1.0f})); + + auto end_reshape = std::make_shared( + down_proj_add, + op::v0::Constant::create(element::i64, + Shape{4}, + std::vector{number_of_experts, batch, -1, hidden_size}), + false); + + // Router subgraph used to test correctness of routing weights connection + auto reshape_2nd_consumer_router_matmul = std::make_shared( + experts_reshape, + op::v0::Constant::create(element::f32, Shape{number_of_experts, hidden_size}, {1.0f}), + false, + true); + + auto router_bias = + std::make_shared(reshape_2nd_consumer_router_matmul, + op::v0::Constant::create(element::f32, Shape{1, number_of_experts}, {1.0f})); + + auto router_topk_values_and_indices = + std::make_shared(router_bias, + op::v0::Constant::create(element::i64, Shape{}, {topk}), + -1, + op::v11::TopK::Mode::MAX, + op::v11::TopK::SortType::SORT_VALUES, + element::i64); + + auto router_topk_values = router_topk_values_and_indices->output(0); + auto router_topk_indices = router_topk_values_and_indices->output(1); + + auto scatter_elements_update = std::make_shared( + router_topk_values, + router_topk_indices, + op::v0::Constant::create(element::f32, Shape{batch, topk}, {0}), + op::v0::Constant::create(element::i64, Shape{1}, std::vector{1})); + auto router_transpose = std::make_shared( + scatter_elements_update, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{1, 0})); + auto router_reshape = std::make_shared( + router_transpose, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{number_of_experts, batch, -1}), + true); + auto unsqueeze_routing_weights = + std::make_shared(router_reshape, + op::v0::Constant::create(element::i64, Shape{1}, std::vector{-1})); + + auto mul3 = std::make_shared(end_reshape, unsqueeze_routing_weights); + + // ReduceSum - final node of the MOE pattern to be fused + auto reduce_sum = + std::make_shared(mul3, + op::v0::Constant::create(element::i64, Shape{1}, std::vector{0}), + false); + + return std::make_shared(ov::OutputVector{reduce_sum}, ov::ParameterVector{input}); +} + +inline std::shared_ptr build_fused_2gemm_moe_reference_model() { + using namespace ov; + + const size_t batch = 2; + const Dimension in_dim = Dimension::dynamic(); + const size_t hidden_size = 2048; + const size_t intermediate_size = 4096; + const size_t topk = 2; + const size_t number_of_experts = 3; + const size_t fusion_factor = 2; + const auto expert_alpha = 1.702f; + const auto expert_beta = 7.0f; + + auto input_shape = PartialShape{batch, in_dim, hidden_size}; + auto input = std::make_shared(element::f32, input_shape); + + // Begin of Router subgraph (not fused, but valuable for testing) + auto experts_reshape = std::make_shared( + input, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{-1, hidden_size}), + false); + + auto reshape_2nd_consumer_router_matmul = std::make_shared( + experts_reshape, + op::v0::Constant::create(element::f32, Shape{number_of_experts, hidden_size}, {1.0f}), + false, + true); + + auto router_bias = + std::make_shared(reshape_2nd_consumer_router_matmul, + op::v0::Constant::create(element::f32, Shape{1, number_of_experts}, {1.0f})); + + auto router_topk_values_and_indices = + 
std::make_shared(router_bias, + op::v0::Constant::create(element::i64, Shape{}, {topk}), + -1, + op::v11::TopK::Mode::MAX, + op::v11::TopK::SortType::SORT_VALUES, + element::i64); + + auto router_topk_values = router_topk_values_and_indices->output(0); + auto router_topk_indices = router_topk_values_and_indices->output(1); + + auto scatter_elements_update = std::make_shared( + router_topk_values, + router_topk_indices, + op::v0::Constant::create(element::f32, Shape{batch, topk}, {0}), + op::v0::Constant::create(element::i64, Shape{1}, std::vector{1})); + auto router_transpose = std::make_shared( + scatter_elements_update, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{1, 0})); + auto router_reshape = std::make_shared( + router_transpose, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{number_of_experts, batch, -1}), + true); + auto unsqueeze_routing_weights = + std::make_shared(router_reshape, op::v0::Constant::create(element::i64, Shape{1}, {-1})); + // End of Router subgraph + + // Expert MatMuls weights fused into MOE + auto w0_weight = op::v0::Constant::create(element::f32, + Shape{number_of_experts, hidden_size, intermediate_size * fusion_factor}, + {1.0f}); + auto w0_bias = + op::v0::Constant::create(element::f32, Shape{number_of_experts, 1, intermediate_size * fusion_factor}, {0.0f}); + auto w1_weight = + op::v0::Constant::create(element::f32, Shape{number_of_experts, intermediate_size, hidden_size}, {1.0f}); + auto w1_bias = op::v0::Constant::create(element::f32, Shape{number_of_experts, 1, hidden_size}, {1.0f}); + + ov::OutputVector moe_inputs = + {input, unsqueeze_routing_weights, router_topk_indices, w0_weight, w0_bias, w1_weight, w1_bias}; + + ov::op::internal::MOE::Config config; + config.expert_type = ov::op::internal::MOE::Expert_type::GEMM2_BIAS_SWIGLU_CLAMP; + config.expert_alpha = expert_alpha; + config.expert_beta = expert_beta; + + auto moe = std::make_shared(moe_inputs, config); + return std::make_shared(ov::OutputVector{moe}, ov::ParameterVector{input}); +} + +inline std::shared_ptr build_3gemm_moe_pattern_model() { + using namespace ov; + + const size_t batch = 2; + const Dimension in_dim = Dimension::dynamic(); + const size_t hidden_size = 2048; + const size_t intermediate_size = 4096; + const size_t number_of_experts = 3; + const size_t topk = 2; + + auto input_shape = PartialShape{batch, in_dim, hidden_size}; + auto input = std::make_shared(element::f32, input_shape); + auto experts_reshape = std::make_shared( + input, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{-1, hidden_size}), + false); + + auto tile = std::make_shared( + experts_reshape, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{number_of_experts, 1})); + auto after_tile_reshape = std::make_shared( + tile, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{number_of_experts, batch, hidden_size}), + false); + + // First GEMM (gate) + auto gate_matmul = std::make_shared( + after_tile_reshape, + op::v0::Constant::create(element::f32, Shape{number_of_experts, intermediate_size, hidden_size}, {1.0f}), + false, + true); + + auto swish = std::make_shared(gate_matmul); + + // Second GEMM (up) + auto up_matmul = std::make_shared( + after_tile_reshape, + op::v0::Constant::create(element::f32, Shape{number_of_experts, intermediate_size, hidden_size}, {1.0f}), + false, + true); + + auto swiglu = std::make_shared(swish, up_matmul); + + // Third GEMM (down) + auto down_matmul = std::make_shared( + swiglu, + 
op::v0::Constant::create(element::f32, Shape{number_of_experts, hidden_size, intermediate_size}, {1.0f}), + false, + true); + + auto experts_out_reshape = std::make_shared( + down_matmul, + op::v0::Constant::create(element::i64, + Shape{4}, + std::vector{number_of_experts, batch, -1, hidden_size}), + false); + + // Router subgraph used to test correctness of routing weights connection + auto router_matmul = std::make_shared( + experts_reshape, + op::v0::Constant::create(element::f32, Shape{number_of_experts, hidden_size}, {1.0f}), + false, + true); + + auto router_topk_values_and_indices = + std::make_shared(router_matmul, + op::v0::Constant::create(element::i64, Shape{}, {topk}), + -1, + op::v11::TopK::Mode::MAX, + op::v11::TopK::SortType::SORT_VALUES, + element::i64); + + auto router_topk_values = router_topk_values_and_indices->output(0); + auto router_topk_indices = router_topk_values_and_indices->output(1); + + auto scatter_elements_update = std::make_shared( + router_topk_values, + router_topk_indices, + op::v0::Constant::create(element::f32, Shape{batch, topk}, {0}), + op::v0::Constant::create(element::i64, Shape{1}, std::vector{1})); + auto router_transpose = std::make_shared( + scatter_elements_update, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{1, 0})); + auto router_reshape = std::make_shared( + router_transpose, + op::v0::Constant::create(element::i64, Shape{3}, std::vector{number_of_experts, batch, -1}), + true); + auto unsqueeze_routing_weights = + std::make_shared(router_reshape, + op::v0::Constant::create(element::i64, Shape{1}, std::vector{-1})); + + auto mul3 = std::make_shared(experts_out_reshape, unsqueeze_routing_weights); + + // ReduceSum - final node of the MOE pattern to be fused + auto reduce_sum = + std::make_shared(mul3, + op::v0::Constant::create(element::i64, Shape{1}, std::vector{0}), + false); + + return std::make_shared(ov::OutputVector{reduce_sum}, ov::ParameterVector{input}); +} + +inline std::shared_ptr build_fused_3gemm_moe_reference_model() { + using namespace ov; + + const size_t batch = 2; + const Dimension in_dim = Dimension::dynamic(); + const size_t hidden_size = 2048; + const size_t intermediate_size = 4096; + const size_t number_of_experts = 3; + const size_t topk = 2; + + auto input = std::make_shared(element::f32, PartialShape{batch, in_dim, hidden_size}); + + // Begin of Router subgraph (not fused, but valuable for testing) + auto experts_reshape = std::make_shared( + input, + op::v0::Constant::create(element::i64, Shape{2}, std::vector{-1, hidden_size}), + false); + + auto router_matmul = std::make_shared( + experts_reshape, + op::v0::Constant::create(element::f32, Shape{number_of_experts, hidden_size}, {1.0f}), + false, + true); + + auto router_topk = std::make_shared(router_matmul, + op::v0::Constant::create(element::i64, Shape{}, {topk}), + -1, + op::v11::TopK::Mode::MAX, + op::v11::TopK::SortType::SORT_VALUES, + element::i64); + + auto router_topk_values = router_topk->output(0); + auto router_topk_indices = router_topk->output(1); + + auto scatter_elements_update = std::make_shared( + router_topk_values, + router_topk_indices, + op::v0::Constant::create(element::f32, Shape{batch, topk}, {0}), + op::v0::Constant::create(element::i64, Shape{1}, {1})); + + auto router_transpose = + std::make_shared(scatter_elements_update, + op::v0::Constant::create(element::i64, Shape{2}, {1, 0})); + auto router_reshape = std::make_shared( + router_transpose, + op::v0::Constant::create(element::i64, Shape{3}, 
std::vector{number_of_experts, batch, -1}), + true); + + auto unsqueeze_routing_weights = + std::make_shared(router_reshape, op::v0::Constant::create(element::i64, Shape{1}, {-1})); + + // MOE fused op + auto w0_weight = + op::v0::Constant::create(element::f32, Shape{number_of_experts, intermediate_size, hidden_size}, {1.0f}); + auto w1_weight = + op::v0::Constant::create(element::f32, Shape{number_of_experts, intermediate_size, hidden_size}, {1.0f}); + auto w2_weight = + op::v0::Constant::create(element::f32, Shape{number_of_experts, hidden_size, intermediate_size}, {1.0f}); + + ov::OutputVector moe_inputs = + {input, unsqueeze_routing_weights, router_topk_indices, w0_weight, w1_weight, w2_weight}; + + ov::op::internal::MOE::Config config; + config.expert_type = ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU; + + auto moe = std::make_shared(moe_inputs, config); + return std::make_shared(ov::OutputVector{moe}, ov::ParameterVector{input}); +} + +TEST_F(TransformationTestsF, FuseVectorizedMOE2GEMM_basic) { + model = build_2gemm_moe_pattern_model(); + manager.register_pass(); + model_ref = build_fused_2gemm_moe_reference_model(); +} + +TEST_F(TransformationTestsF, FuseVectorizedMOE2GEMM_VectorizedExpertsFusion) { + model = build_2gemm_moe_pattern_model(); + manager.register_pass(); + model_ref = build_fused_2gemm_moe_reference_model(); +} + +TEST_F(TransformationTestsF, FuseVectorizedMOE2GEMM_no_fusion) { + model = build_3gemm_moe_pattern_model(); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, FuseVectorizedMOE3GEMM_basic) { + model = build_3gemm_moe_pattern_model(); + manager.register_pass(); + model_ref = build_fused_3gemm_moe_reference_model(); +} + +TEST_F(TransformationTestsF, FuseVectorizedMOE3GEMM_VectorizedExpertsFusion) { + model = build_3gemm_moe_pattern_model(); + manager.register_pass(); + model_ref = build_fused_3gemm_moe_reference_model(); +} + +TEST_F(TransformationTestsF, FuseVectorizedMOE3GEMM_no_fusion) { + model = build_2gemm_moe_pattern_model(); + manager.register_pass(); +} \ No newline at end of file diff --git a/src/core/dev_api/openvino/op/moe.hpp b/src/core/dev_api/openvino/op/moe.hpp new file mode 100644 index 00000000000000..5147f15fa8b184 --- /dev/null +++ b/src/core/dev_api/openvino/op/moe.hpp @@ -0,0 +1,78 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "openvino/core/node.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/op.hpp" + +namespace ov::op::internal { +/// +/// \brief MOE experts +class OPENVINO_API MOE : public ov::op::Op { +public: + OPENVINO_OP("MOE"); + + MOE() = default; + + enum class Expert_type { GEMM2_BIAS_SWIGLU_CLAMP, GEMM3_SWIGLU }; + + struct Config { + Expert_type expert_type{Expert_type::GEMM2_BIAS_SWIGLU_CLAMP}; + float expert_alpha{0.0f}; // Expert attribute for clamp bounds + float expert_beta{1.0f}; // Expert attribute for swish beta + }; + + /// \brief Constructs a MOE operation with config only + /// \param args The input tensors, in the following order: + /// 0: hidden_states - input tensor with hidden representations + /// 1: routing_weights - [num_experts, ...] 
normalized weights for selected experts + /// (input to final multiplication) + /// 2: router_topk_output_indices - [..., topk] indices of selected top-k experts + /// 3: w0_weight - expert weights for first projection, shape [num_experts, inter_size, hidden_size] or + /// [num_experts, hidden_size, 2 * inter_size] if fused + /// 4: w0_bias (optional) - expert bias for first projection, + /// shape [num_experts, ...] or empty tensor if not needed + /// 5: w1_weight - expert weights for second projection, + /// shape [num_experts, inter_size, hidden_size] + /// 6: w1_bias (optional) - expert bias for second projection, shape + /// [num_experts, ...] or empty tensor if not needed + /// 7: w2_weight - expert weights for final projection, shape + /// [num_experts, hidden_size, inter_size] + /// 8: w2_bias (optional) - expert bias for final projection + /// \param config Configuration for the MOE operation + MOE(const OutputVector& args, const Config& config); + + const Config& get_config() const; + void set_config(const Config& config); + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + +private: + Config m_config; +}; + +} // namespace ov::op::internal + +namespace ov { +std::ostream& operator<<(std::ostream& s, const ov::op::internal::MOE::Expert_type& type); + +template <> +class AttributeAdapter + : public EnumAttributeAdapterBase { +public: + AttributeAdapter(ov::op::internal::MOE::Expert_type& value) + : EnumAttributeAdapterBase(value) {} + + OPENVINO_RTTI("AttributeAdapter"); + ~AttributeAdapter() override = default; +}; +} // namespace ov diff --git a/src/core/src/op/moe.cpp b/src/core/src/op/moe.cpp new file mode 100644 index 00000000000000..6d28bf1bc52d6c --- /dev/null +++ b/src/core/src/op/moe.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/moe.hpp" + +#include "itt.hpp" + +namespace ov { +namespace op { +namespace internal { + +MOE::MOE(const OutputVector& args, const Config& config) : Op(args), m_config(config) { + constructor_validate_and_infer_types(); +} + +const MOE::Config& MOE::get_config() const { + return m_config; +} + +void MOE::set_config(const Config& config) { + m_config = config; +} + +std::shared_ptr MOE::clone_with_new_inputs(const ov::OutputVector& new_args) const { + OV_OP_SCOPE(internal_MOE_clone_with_new_inputs); + check_new_args_count(this, new_args); + + return std::make_shared(new_args, m_config); +} + +void MOE::validate_and_infer_types() { + OV_OP_SCOPE(internal_MOE_validate_and_infer_types); + // TODO: Add inputs validation + + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +} + +bool MOE::visit_attributes(ov::AttributeVisitor& visitor) { + OV_OP_SCOPE(internal_MOE_visit_attributes); + visitor.on_attribute("expert_type", m_config.expert_type); + visitor.on_attribute("expert_alpha", m_config.expert_alpha); + visitor.on_attribute("expert_beta", m_config.expert_beta); + + return true; +} + +} // namespace internal +} // namespace op + +std::ostream& operator<<(std::ostream& s, const ov::op::internal::MOE::Expert_type& type) { + return s << as_string(type); +} + +template <> +OPENVINO_API EnumNames& EnumNames::get() { + static auto enum_names = EnumNames( + "ov::op::internal::MOE::Expert_type", + { + {"gemm2_bias_swiglu_clamp", ov::op::internal::MOE::Expert_type::GEMM2_BIAS_SWIGLU_CLAMP}, + 
{"gemm3_swiglu", ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU}, + }); + return enum_names; +} +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 0c417c4dc29e45..3f3b63d6088231 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -557,7 +557,6 @@ void Transformations::PreLpt(const std::vector& defaultPrecis }); }, ov::pass::KeepConstAndDecompression); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion); CPU_REGISTER_PASS_COMMON(manager, SDPASubgraphFusion); ov::pass::ConvertPagedAttnInputs::KVCacheConfig cacheConfig;