Merged
24 commits
dd3da1e - Vectorized MOE fusion init (mitruska, Sep 22, 2025)
7b9572e - MOE op init (mitruska, Sep 22, 2025)
4a50118 - MOE attrs/inputs adjust (mitruska, Sep 22, 2025)
7e3230e - Adjust inputs desc (mitruska, Sep 22, 2025)
104246a - Add adapters for expert_type enum (mitruska, Sep 22, 2025)
62e054b - Merge remote-tracking branch 'upstream/master' into mitruska/moe_vect… (mitruska, Sep 23, 2025)
6df368c - Fuse Multiply output before Reshape (mitruska, Sep 23, 2025)
c6448d3 - MOE fusion unit test (mitruska, Sep 24, 2025)
f5c1c41 - Add missing header (mitruska, Sep 24, 2025)
762bc9a - Move MOE op to internal (mitruska, Sep 24, 2025)
41145cf - Apply MOE transformation for CPU (mitruska, Sep 24, 2025)
5a34684 - Revert CPU transformation pipeline change (mitruska, Sep 24, 2025)
b46f960 - Fix cast warning (mitruska, Sep 24, 2025)
e343cd8 - Remove OPENVINO_API macros (mitruska, Sep 24, 2025)
0406105 - Update input desc (mitruska, Sep 24, 2025)
eaede0d - No keep dims in Reduce (mitruska, Sep 24, 2025)
9ce1569 - Add transpose attrs to MatMul patterns (mitruska, Sep 30, 2025)
90f31a2 - Switch beta with alpha to match the beta for swish naming (mitruska, Sep 30, 2025)
9cd0170 - Merge remote-tracking branch 'upstream/master' into mitruska/moe_vect… (mitruska, Sep 30, 2025)
6c9eb7d - Merge branch 'master' into mitruska/moe_vect_experts_fuse (mitruska, Oct 1, 2025)
3a96855 - Add fusion transformation for the second expert_type (GEMM3) (mitruska, Oct 2, 2025)
df97c22 - Update GEMM3 transpose_b attr to be true (mitruska, Oct 13, 2025)
6e17476 - Merge remote-tracking branch 'upstream/master' into mitruska/moe_vect… (mitruska, Oct 13, 2025)
8be2257 - Update GEMM2 pattern to match MatMul transpose_b=True (mitruska, Oct 13, 2025)
transformations/common_optimizations/matmul_experts_fusion.hpp
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API FuseVectorizedMOE2GEMM;
class TRANSFORMATIONS_API FuseVectorizedMOE3GEMM;
class TRANSFORMATIONS_API VectorizedExpertsFusion;

} // namespace pass
} // namespace ov

class ov::pass::FuseVectorizedMOE2GEMM : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("FuseVectorizedMOE2GEMM");
    FuseVectorizedMOE2GEMM();
};

class ov::pass::FuseVectorizedMOE3GEMM : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("FuseVectorizedMOE3GEMM");
    FuseVectorizedMOE3GEMM();
};

class ov::pass::VectorizedExpertsFusion : public ov::pass::GraphRewrite {
public:
    OPENVINO_GRAPH_REWRITE_RTTI("VectorizedExpertsFusion");
    VectorizedExpertsFusion() {
        add_matcher<ov::pass::FuseVectorizedMOE2GEMM>();
        add_matcher<ov::pass::FuseVectorizedMOE3GEMM>();
    }
};
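
Not part of the diff: a minimal usage sketch of the header above, assuming a model is already available as a std::shared_ptr<ov::Model> named model. VectorizedExpertsFusion comes from this header; the function name fuse_vectorized_experts is a hypothetical wrapper, and the rest is standard ov::pass::Manager usage.

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/matmul_experts_fusion.hpp"

// Run the MOE experts fusion on a model through the regular pass pipeline.
void fuse_vectorized_experts(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // Registers both the 2-GEMM and 3-GEMM expert matchers declared above.
    manager.register_pass<ov::pass::VectorizedExpertsFusion>();
    manager.run_passes(model);
}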
transformations/common_optimizations/matmul_experts_fusion.cpp
@@ -0,0 +1,198 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/matmul_experts_fusion.hpp"

#include "itt.hpp"
#include "openvino/core/graph_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/clamp.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/minimum.hpp"
#include "openvino/op/moe.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/topk.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::pass;
ov::pass::FuseVectorizedMOE2GEMM::FuseVectorizedMOE2GEMM() {
    MATCHER_SCOPE(FuseVectorizedMOE2GEMM);

    auto experts_input = pattern::wrap_type<ov::op::v1::Reshape>({pattern::any_input(), pattern::any_input()});
    auto tile = pattern::wrap_type<ov::op::v0::Tile>({experts_input, pattern::any_input()});
    auto after_tile_reshape = pattern::wrap_type<ov::op::v1::Reshape>({tile, pattern::any_input()});
    auto gate_up_matmul = pattern::wrap_type<ov::op::v0::MatMul>({after_tile_reshape, pattern::any_input()},
                                                                 {{"transpose_a", false}, {"transpose_b", true}});
    auto gate_up_add = pattern::wrap_type<ov::op::v1::Add>({gate_up_matmul, pattern::any_input()});

    // Branch 1: Slice_1 -> Clamp -> Add_1
    auto slice1 = pattern::wrap_type<ov::op::v8::Slice>(
        {gate_up_add, pattern::any_input(), pattern::any_input(), pattern::any_input(), pattern::any_input()});
    auto clamp = pattern::wrap_type<ov::op::v0::Clamp>({slice1});
    auto add1 = pattern::wrap_type<ov::op::v1::Add>({clamp, pattern::wrap_const()});

    // Branch 2: Slice_2 -> Minimum_1 -> Swish
    auto slice2 = pattern::wrap_type<ov::op::v8::Slice>(
        {gate_up_add, pattern::any_input(), pattern::any_input(), pattern::any_input(), pattern::any_input()});
    auto minimum1 = pattern::wrap_type<ov::op::v1::Minimum>({slice2, pattern::wrap_const()});
    auto swish_beta = pattern::wrap_const();
    auto swish = pattern::wrap_type<ov::op::v4::Swish>({minimum1, swish_beta});

    // Join: Multiply_2
    auto multiply2 = pattern::wrap_type<ov::op::v1::Multiply>({add1, swish});

    // Down projection
    auto down_proj_matmul = pattern::wrap_type<ov::op::v0::MatMul>({multiply2, pattern::any_input()},
                                                                   {{"transpose_a", false}, {"transpose_b", true}});
    auto down_proj_add = pattern::wrap_type<ov::op::v1::Add>({down_proj_matmul, pattern::wrap_const()});
    auto end_reshape = pattern::wrap_type<ov::op::v1::Reshape>({down_proj_add, pattern::any_input()});

    // Routing weights/mask
    auto router_topk_indices = pattern::any_input();
    auto scatter_elements_update = pattern::wrap_type<ov::op::v12::ScatterElementsUpdate>(
        {pattern::any_input(), router_topk_indices, pattern::any_input(), pattern::any_input()});

    auto router_transpose = pattern::wrap_type<ov::op::v1::Transpose>({scatter_elements_update, pattern::any_input()});
    auto router_reshape = pattern::wrap_type<ov::op::v1::Reshape>({router_transpose, pattern::any_input()});
    auto unsqueeze_routing_weights = pattern::wrap_type<ov::op::v0::Unsqueeze>({router_reshape, pattern::any_input()});

    auto mul3 = pattern::wrap_type<ov::op::v1::Multiply>({end_reshape, unsqueeze_routing_weights});
    auto reduce_sum = pattern::wrap_type<ov::op::v1::ReduceSum>({mul3, pattern::any_input()}, {{"keep_dims", false}});
    auto moe_pattern = reduce_sum;

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto& pm = m.get_pattern_value_map();

        auto experts_input_node = pm.at(experts_input).get_node()->input_value(0);

        auto routing_weights_node = pm.at(unsqueeze_routing_weights).get_node_shared_ptr();
        auto gate_up_weight = pm.at(gate_up_matmul).get_node()->input_value(1).get_node_shared_ptr();
        auto gate_up_bias_node = pm.at(gate_up_add).get_node()->input_value(1).get_node_shared_ptr();
        auto down_proj_weight = pm.at(down_proj_matmul).get_node()->input_value(1).get_node_shared_ptr();
        auto down_proj_bias_node = pm.at(down_proj_add).get_node()->input_value(1).get_node_shared_ptr();
        auto topk_indices_node = pm.at(scatter_elements_update).get_node()->input_value(1);

        ov::OutputVector moe_inputs = {experts_input_node,
                                       routing_weights_node,
                                       topk_indices_node,
                                       gate_up_weight,
                                       gate_up_bias_node,
                                       down_proj_weight,
                                       down_proj_bias_node};

        ov::op::internal::MOE::Config config;

        // Extract expert_beta from Swish beta attribute
        auto swish_beta_const = ov::as_type_ptr<ov::op::v0::Constant>(pm.at(swish_beta).get_node_shared_ptr());
        auto swish_beta_const_val = swish_beta_const->cast_vector<float>()[0];
        config.expert_beta = swish_beta_const_val;

        // Extract expert_alpha from Clamp max attribute
        if (auto clamp_op = ov::as_type_ptr<ov::op::v0::Clamp>(pm.at(clamp).get_node_shared_ptr())) {
            config.expert_alpha = static_cast<float>(clamp_op->get_max());
        }

        // Set expert_type
        config.expert_type = ov::op::internal::MOE::Expert_type::GEMM2_BIAS_SWIGLU_CLAMP;

        auto moe = std::make_shared<ov::op::internal::MOE>(moe_inputs, config);
        moe->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(m.get_matched_nodes(), moe);
        ov::replace_node(m.get_match_root(), moe);

        register_new_node(moe);
        return true;
    };

    auto matcher = std::make_shared<pattern::Matcher>(moe_pattern, matcher_name);
    this->register_matcher(matcher, callback);
}

ov::pass::FuseVectorizedMOE3GEMM::FuseVectorizedMOE3GEMM() {
    MATCHER_SCOPE(FuseVectorizedMOE3GEMM);

    auto experts_input = pattern::wrap_type<ov::op::v1::Reshape>({pattern::any_input(), pattern::any_input()});
    auto tile = pattern::wrap_type<ov::op::v0::Tile>({experts_input, pattern::any_input()});
    auto after_tile_reshape = pattern::wrap_type<ov::op::v1::Reshape>({tile, pattern::any_input()});

    // First GEMM (activation gate)
    auto gate_matmul = pattern::wrap_type<ov::op::v0::MatMul>({after_tile_reshape, pattern::any_input()},
                                                              {{"transpose_a", false}, {"transpose_b", true}});
    auto swish = pattern::wrap_type<ov::op::v4::Swish>({gate_matmul});
    // Second GEMM (up_projection)
    auto up_matmul = pattern::wrap_type<ov::op::v0::MatMul>({after_tile_reshape, pattern::any_input()},
                                                            {{"transpose_a", false}, {"transpose_b", true}});
    // Join: Multiply (SwiGLU)
    auto swiglu = pattern::wrap_type<ov::op::v1::Multiply>({swish, up_matmul});

    // Third GEMM (down_projection)
    auto down_matmul = pattern::wrap_type<ov::op::v0::MatMul>({swiglu, pattern::any_input()},
                                                              {{"transpose_a", false}, {"transpose_b", true}});
    auto end_reshape = pattern::wrap_type<ov::op::v1::Reshape>({down_matmul, pattern::any_input()});

    // Routing weights/mask
    auto router_topk_indices = pattern::any_input();
    auto scatter_elements_update = pattern::wrap_type<ov::op::v12::ScatterElementsUpdate>(
        {pattern::any_input(), router_topk_indices, pattern::any_input(), pattern::any_input()});
    auto router_transpose = pattern::wrap_type<ov::op::v1::Transpose>({scatter_elements_update, pattern::any_input()});
    auto router_reshape = pattern::wrap_type<ov::op::v1::Reshape>({router_transpose, pattern::any_input()});
    auto unsqueeze_routing_weights = pattern::wrap_type<ov::op::v0::Unsqueeze>({router_reshape, pattern::any_input()});

    auto mul3 = pattern::wrap_type<ov::op::v1::Multiply>({end_reshape, unsqueeze_routing_weights});
    auto reduce_sum = pattern::wrap_type<ov::op::v1::ReduceSum>({mul3, pattern::any_input()}, {{"keep_dims", false}});
    auto moe_pattern = reduce_sum;

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto& pm = m.get_pattern_value_map();
        auto experts_input_node = pm.at(experts_input).get_node()->input_value(0);
        auto routing_weights_node = pm.at(unsqueeze_routing_weights).get_node_shared_ptr();
        auto gate_weight = pm.at(gate_matmul).get_node()->input_value(1).get_node_shared_ptr();
        auto up_weight = pm.at(up_matmul).get_node()->input_value(1).get_node_shared_ptr();
        auto down_weight = pm.at(down_matmul).get_node()->input_value(1).get_node_shared_ptr();
        auto topk_indices_node = pm.at(scatter_elements_update).get_node()->input_value(1);

        ov::OutputVector moe_inputs = {
            experts_input_node,
            routing_weights_node,
            topk_indices_node,
            gate_weight,
            up_weight,
            down_weight,
        };

        ov::op::internal::MOE::Config config;
        config.expert_type = ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU;
        // Extract expert_beta if Swish has beta input provided
        if (auto swish_op = ov::as_type_ptr<ov::op::v4::Swish>(pm.at(swish).get_node_shared_ptr())) {
            if (swish_op->get_input_size() > 1) {
                if (auto swish_beta_const =
                        ov::as_type_ptr<ov::op::v0::Constant>(swish_op->get_input_node_shared_ptr(1))) {
                    config.expert_beta = swish_beta_const->cast_vector<float>()[0];
                }
            }
        }

        auto moe = std::make_shared<ov::op::internal::MOE>(moe_inputs, config);
        moe->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(m.get_matched_nodes(), moe);
        ov::replace_node(m.get_match_root(), moe);

        register_new_node(moe);
        return true;
    };

    auto matcher = std::make_shared<pattern::Matcher>(moe_pattern, matcher_name);
    this->register_matcher(matcher, callback);
}
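
Again outside the diff: a small sketch of how the effect of these passes could be checked after they run, for example in a test. The helper name count_moe_ops is hypothetical; Model::get_ops, ov::as_type_ptr and ov::op::internal::MOE (from the openvino/op/moe.hpp header included above) are the same APIs the transformation itself relies on.

#include "openvino/core/model.hpp"
#include "openvino/op/moe.hpp"

// Count fused MOE nodes in a model; after a successful run of VectorizedExpertsFusion,
// each matched expert subgraph should be collapsed into one ov::op::internal::MOE op.
size_t count_moe_ops(const std::shared_ptr<ov::Model>& model) {
    size_t count = 0;
    for (const auto& node : model->get_ops()) {
        if (ov::as_type_ptr<ov::op::internal::MOE>(node)) {
            ++count;
        }
    }
    return count;
}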