[Spec][MOE][Internal Op] Specification of MOE internal operation #32255
Open
mitruska wants to merge 10 commits into openvinotoolkit:master from mitruska:mitruska/moe_internal_spec
+152 −0

Commits (10)
1fc8f7b  Internal MOE spec init (mitruska)
9a66628  Merge remote-tracking branch 'upstream/master' into mitruska/moe_inte… (mitruska)
95359d3  Minor spelling refactor (mitruska)
1bb24d8  Switch beta with alpha to match the beta for swish naming (mitruska)
e5c0009  Refactor formatting (mitruska)
8a9e4d1  Update identation (mitruska)
bce1465  Fix x_proj -> x_proj2 in GEMM3 mode in moe.rst (mitruska)
b9b12ff  Update docs/articles_en/documentation/openvino-ir-format/operation-se… (mitruska)
4cb21b4  Apply suggestions from code review (mitruska)
6609269  Update routing weights shape description moe.rst (mitruska)
152 changes: 152 additions & 0 deletions
...ocumentation/openvino-ir-format/operation-sets/operation-specs/internal/moe.rst
MOE
===

.. meta::
   :description: Learn about MOE - a Mixture of Experts block, receiving routing weights and active experts indices as inputs, and performing expert computation according to the selected expert_type.

**Versioned name**: *MOE*

**Category**: *Internal*

**Short description**:
The *MOE* (Mixture of Experts) operation fuses the computation of multiple experts, using routing weights and indices to select and combine expert outputs.

**Detailed description**:
The MOE operation receives hidden states, routing weights, and indices of the selected experts, along with expert weights and (optionally) biases. It performs the expert computation specified by the ``expert_type`` attribute, applies the routing weights, and combines the results. This enables efficient, fused computation of Mixture of Experts architectures, excluding the router part (the computation of routing weights).

**Pseudocode for expert_type**

The ``router_topk_output_indices`` input is used to select the top-k experts for optimized computation; this selection is not included in the pseudocode below.

* ``GEMM2_BIAS_SWIGLU_CLAMP``:

  .. code-block:: py
     :force:

     # Common part: Reshape hidden states and prepare for expert computation
     reshaped_hidden_states = reshape(hidden_states, [-1, 0], special_zero=True)
     tiled_hidden_states = tile(reshaped_hidden_states, [num_experts, 1])
     reshaped_hidden_states = reshape(tiled_hidden_states, [num_experts, -1, 0], special_zero=True)

     # Experts computation part (GEMM2_BIAS_SWIGLU_CLAMP)
     # Fused gate_up computation
     gate_up = matmul(reshaped_hidden_states, weight_0, transpose_a=False, transpose_b=False) + bias_0
     # Slice gate_up into two halves along the last dimension, taking every second element
     slice_1 = gate_up[..., ::2]   # every second element starting from index 0
     slice_2 = gate_up[..., 1::2]  # every second element starting from index 1
     # Branch 1: Minimum and Swish
     minimum_1 = minimum(slice_2, expert_alpha)
     swish_1 = swish(minimum_1, beta=expert_beta)
     # Branch 2: Clamp and Add
     clamp_1 = clamp(slice_1, -expert_alpha, expert_alpha)
     add_1 = clamp_1 + 1
     # Multiply branches
     fused = add_1 * swish_1
     # Down projection
     down_proj = matmul(fused, weight_1, transpose_a=False, transpose_b=False) + bias_1

     # Common part: Routing and summation
     routed_experts = reshape(down_proj, [num_experts, batch_size, -1, hidden_size]) * routing_weights
     output = reduce_sum(routed_experts, axis=0, keep_dims=False)
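
  As a non-normative illustration, the following NumPy sketch mirrors the pseudocode above for ``GEMM2_BIAS_SWIGLU_CLAMP``. The helper names and the example layout (``hidden_states`` as ``[batch, seq, hidden_size]``, ``routing_weights`` as ``[num_experts, batch, seq, 1]``) are assumptions made for the sketch, not part of the operation definition.

  .. code-block:: py

     import numpy as np

     def swish(x, beta):
         # Swish with beta: x * sigmoid(beta * x)
         return x / (1.0 + np.exp(-beta * x))

     def moe_gemm2_bias_swiglu_clamp(hidden_states, routing_weights,
                                     weight_0, bias_0, weight_1, bias_1,
                                     expert_alpha, expert_beta):
         # Assumed shapes (illustrative):
         # hidden_states: [batch, seq, hidden], routing_weights: [E, batch, seq, 1]
         # weight_0: [E, hidden, 2*inter], bias_0: [E, 1, 2*inter]
         # weight_1: [E, inter, hidden],   bias_1: [E, 1, hidden]
         batch, seq, hidden = hidden_states.shape
         num_experts = weight_0.shape[0]
         # "Tile" the flattened tokens once per expert (broadcast view, no copy)
         x = np.broadcast_to(hidden_states.reshape(1, -1, hidden),
                             (num_experts, batch * seq, hidden))
         gate_up = x @ weight_0 + bias_0              # [E, batch*seq, 2*inter]
         slice_1 = gate_up[..., ::2]
         slice_2 = gate_up[..., 1::2]
         branch_1 = swish(np.minimum(slice_2, expert_alpha), expert_beta)
         branch_2 = np.clip(slice_1, -expert_alpha, expert_alpha) + 1.0
         down_proj = (branch_2 * branch_1) @ weight_1 + bias_1
         routed = down_proj.reshape(num_experts, batch, seq, hidden) * routing_weights
         return routed.sum(axis=0)                    # [batch, seq, hidden]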

* ``GEMM3_SWIGLU``:

  .. code-block:: py
     :force:

     # Common part: Reshape hidden states and prepare for expert computation
     reshaped_hidden_states = reshape(hidden_states, [-1, 0], special_zero=True)
     tiled_hidden_states = tile(reshaped_hidden_states, [num_experts, 1])
     reshaped_hidden_states = reshape(tiled_hidden_states, [num_experts, -1, 0], special_zero=True)

     # Experts computation part (GEMM3_SWIGLU)
     x_proj = matmul(reshaped_hidden_states, weight_0, transpose_a=False, transpose_b=False)
     x_proj2 = matmul(reshaped_hidden_states, weight_1, transpose_a=False, transpose_b=False)
     swiglu = swish(x_proj, beta=expert_beta)
     x_proj = x_proj2 * swiglu
     down_proj = matmul(x_proj, weight_2, transpose_a=False, transpose_b=False)

     # Common part: Routing and summation
     routed_experts = reshape(down_proj, [num_experts, batch_size, -1, hidden_size]) * routing_weights
     output = reduce_sum(routed_experts, axis=0, keep_dims=False)
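
  A matching NumPy sketch for ``GEMM3_SWIGLU``, under the same assumed layout; ``weight_2`` is taken here as the ``[num_experts, inter_size, hidden_size]`` down projection, consistent with the ``transpose_b=False`` matmul in the pseudocode.

  .. code-block:: py

     import numpy as np

     def moe_gemm3_swiglu(hidden_states, routing_weights,
                          weight_0, weight_1, weight_2, expert_beta):
         # Assumed shapes (illustrative):
         # weight_0, weight_1: [E, hidden, inter]; weight_2: [E, inter, hidden]
         batch, seq, hidden = hidden_states.shape
         num_experts = weight_0.shape[0]
         x = np.broadcast_to(hidden_states.reshape(1, -1, hidden),
                             (num_experts, batch * seq, hidden))
         x_proj = x @ weight_0                        # [E, batch*seq, inter]
         x_proj2 = x @ weight_1                       # [E, batch*seq, inter]
         # swish(x_proj, beta) = x_proj * sigmoid(beta * x_proj)
         swiglu = x_proj / (1.0 + np.exp(-expert_beta * x_proj))
         down_proj = (x_proj2 * swiglu) @ weight_2    # [E, batch*seq, hidden]
         routed = down_proj.reshape(num_experts, batch, seq, hidden) * routing_weights
         return routed.sum(axis=0)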

**Attributes**

* *expert_type*

  * **Description**: Specifies the computation performed by each expert. Determines the sequence of operations (e.g., GEMM, activation, bias, clamp).
  * **Type**: ``enum`` (see below)
  * **Required**: *yes*
  * **Supported values**:

    * ``GEMM2_BIAS_SWIGLU_CLAMP``: Two GEMMs with bias, SwiGLU activation, and clamp.
    * ``GEMM3_SWIGLU``: Three GEMMs with SwiGLU activation.

* *expert_alpha*

  * **Description**: Alpha attribute, used as the value for the clamp min/max bounds (used with ``GEMM2_BIAS_SWIGLU_CLAMP``).
  * **Type**: ``float``
  * **Default value**: ``0.0``
  * **Required**: *no*

* *expert_beta*

  * **Description**: Beta attribute for activation functions (used for Swish, e.g. with ``GEMM2_BIAS_SWIGLU_CLAMP``).
  * **Type**: ``float``
  * **Default value**: ``1.0``
  * **Required**: *no*

**Inputs**

* **0**: ``hidden_states``
  *Tensor* of type *T* with shape ``[batch, ..., hidden_size]``.
  The input hidden representations.

* **1**: ``routing_weights``
  *Tensor* of type *T* with shape ``[num_experts, ..., 1]``, for example ``[num_experts, batch, seq_len, 1]``.
  The normalized weights for all of the experts, with non-zero values for the selected top-k experts (after routing/normalization). Used to multiply the expert subgraph results. A sketch of one common way to construct this tensor follows the inputs list.

* **2**: ``router_topk_output_indices``
  *Tensor* of type *T_IND* with shape ``[..., topk]``, for example ``[batch, topk]``.
  Indices of the selected top-k ("active") experts for each input.

* **3**: ``weight_0``
  *Tensor* of type *T* with shape ``[num_experts, hidden_size, inter_size]``,
  or ``[num_experts, hidden_size, 2 * inter_size]`` if fused (e.g. with expert_type ``GEMM2_BIAS_SWIGLU_CLAMP``).
  Weights for the first MatMul.

* **4**: ``bias_0`` *(required only for GEMM2_BIAS_SWIGLU_CLAMP)*
  *Tensor* of type *T* with shape ``[num_experts, ...]`` broadcastable to the output of the first MatMul, for example ``[num_experts, 1, 2 * inter_size]`` if fused (e.g. with expert_type ``GEMM2_BIAS_SWIGLU_CLAMP``), or an empty tensor.
  Bias to be added after the first MatMul.

* **5**: ``weight_1``
  *Tensor* of type *T* with shape ``[num_experts, inter_size, hidden_size]`` when it is the down projection (e.g. with expert_type ``GEMM2_BIAS_SWIGLU_CLAMP``), or ``[num_experts, hidden_size, inter_size]`` when it is the second up-projection (e.g. with expert_type ``GEMM3_SWIGLU``).
  Weights for the second MatMul.

* **6**: ``bias_1`` *(optional)*
  *Tensor* of type *T* with shape ``[num_experts, ...]`` broadcastable to the output of the second MatMul, or an empty tensor.
  Bias to be added after the second MatMul.

* **7**: ``weight_2`` *(optional)*
  *Tensor* of type *T* with shape ``[num_experts, inter_size, hidden_size]``.
  Weights for the third MatMul (the down projection with expert_type ``GEMM3_SWIGLU``).

* **8**: ``bias_2`` *(optional, currently not used with any of the supported expert_types)*
  *Tensor* of type *T* with shape ``[num_experts, ...]`` broadcastable to the output of the third MatMul, or an empty tensor.
  Bias to be added after the third MatMul.

.. note::

   Bias inputs are optional and can be omitted if no bias is used, for example with the ``GEMM3_SWIGLU`` expert_type. In that case, the number of weight inputs should match the number of GEMMs.
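
As a non-normative illustration of how a ``routing_weights`` input of this form is commonly produced (the router itself is outside this operation), the sketch below builds a dense ``[num_experts, batch, seq_len, 1]`` tensor from assumed router logits of shape ``[batch, seq_len, num_experts]``. Whether the top-k weights are renormalized to sum to 1 is model-specific; the sketch assumes they are.

.. code-block:: py

   import numpy as np

   def dense_routing_weights(router_logits, topk):
       # Softmax over the expert axis
       probs = np.exp(router_logits - router_logits.max(axis=-1, keepdims=True))
       probs /= probs.sum(axis=-1, keepdims=True)
       # Keep only the top-k experts per token, zero out the rest
       topk_idx = np.argsort(-router_logits, axis=-1)[..., :topk]
       mask = np.zeros_like(probs)
       np.put_along_axis(mask, topk_idx, 1.0, axis=-1)
       weights = probs * mask
       # Renormalize so the selected experts' weights sum to 1 per token
       weights /= weights.sum(axis=-1, keepdims=True)
       # [batch, seq_len, num_experts] -> [num_experts, batch, seq_len, 1]
       return np.moveaxis(weights, -1, 0)[..., np.newaxis]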

**Outputs**

* **0**: Output tensor of type *T* with the same shape as the ``hidden_states`` input, ``[batch, ..., hidden_size]``.
  The fused output of the selected experts, weighted by the routing weights.

**Types**

* *T*: any supported floating-point type.
* *T_IND*: ``int64`` or ``int32``.
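
Continuing from the sketches above, a quick shape check ties the pieces together; all sizes are arbitrary example values, and ``expert_alpha``/``expert_beta`` are placeholder constants, not values prescribed by the specification.

.. code-block:: py

   num_experts, batch, seq, hidden, inter = 4, 2, 3, 8, 16
   rng = np.random.default_rng(0)
   hidden_states = rng.standard_normal((batch, seq, hidden), dtype=np.float32)
   logits = rng.standard_normal((batch, seq, num_experts), dtype=np.float32)
   routing_weights = dense_routing_weights(logits, topk=2)   # [4, 2, 3, 1]
   weight_0 = rng.standard_normal((num_experts, hidden, 2 * inter), dtype=np.float32)
   bias_0 = rng.standard_normal((num_experts, 1, 2 * inter), dtype=np.float32)
   weight_1 = rng.standard_normal((num_experts, inter, hidden), dtype=np.float32)
   bias_1 = rng.standard_normal((num_experts, 1, hidden), dtype=np.float32)
   output = moe_gemm2_bias_swiglu_clamp(hidden_states, routing_weights,
                                        weight_0, bias_0, weight_1, bias_1,
                                        expert_alpha=7.0, expert_beta=1.702)
   assert output.shape == (batch, seq, hidden)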

Conversations

Review comment: let us not use the MoE name, because we can use it for an external operation and for a real MoE operation. Now it is a sort of FusedExperts.

Reply: the routing weights and indices are provided as inputs, so the core MOE idea is preserved; the final multiplication and ReduceSum are included. I would keep the name as is, to make the current purpose clear. The MOE internal op can be refactored as needed in the future, and possibly extended with a Router.