forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmoe_permute.py
More file actions
149 lines (137 loc) · 7.49 KB
/
moe_permute.py
File metadata and controls
149 lines (137 loc) · 7.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from paddle import _C_ops
from paddle.base.framework import in_dynamic_or_pir_mode
if TYPE_CHECKING:
from paddle import Tensor
def moe_permute(
    hidden_states: Tensor,
    scale: Tensor | None,
    expert_routemap_topk: Tensor,
    expert_prob_topk: Tensor,
    num_experts: int,
    tokens_per_expert: list,
    padding_alignment: int,
    do_gather: bool = True,
    using_ue8m0_scale: bool = False,
    name: str | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""
    Permute tokens for Mixture of Experts (MoE) computation in distributed training scenarios.

    Note:
        This function reorganizes input tokens based on expert assignments to prepare for expert computation.
        It handles both bfloat16 and float8_e4m3fn data types with proper scaling for float8 inputs.

        1. This function is typically used in pair with moe_unpermute to provide complete MoE functionality.
        2. For float8 inputs, proper scaling must be provided via the scale parameter.
        3. The padding_alignment parameter affects memory efficiency but not correctness.
        4. Any output tokens can find an exact-match in the original input tokens.
        5. This permute function has overcome the AA-diff issue, and is deterministic.
        6. If using_ue8m0_scale is True, then the data type of scale must be int32, and each int32 is packaged from 4 ue8m0 scaling factors.

    Args:
        hidden_states (Tensor): The input tensor containing tokens to be permuted, stored in row-major layout.
            Supported data types: bfloat16 or float8_e4m3fn.
            Shape: [sequence_length, token_dimension]
        scale (Tensor|None): Scaling factors required when hidden_states is of float8 type.
            For float8 inputs, this tensor provides the scaling factors for dequantization.
            Shape: [sequence_length, ceil(token_dimension / 128)]. If using_ue8m0_scale is True, the shape is [sequence_length, ceil(ceil(token_dimension / 128)/4)].
            Data type: float32 or int32 (only when using_ue8m0_scale is True). If using_ue8m0_scale is True, the data type of scale is int32 which is packed from four ue8m0 scaling factors.
        expert_routemap_topk (Tensor): Tensor indicating expert assignments for each token (top-k experts).
            Each value represents the expert index the token is assigned to (-1 indicates not assigned).
            Shape: [sequence_length, top_k_experts]
            Data type: int32
            Value range: [-1, num_experts)
        expert_prob_topk (Tensor): Tensor containing routing probabilities for top-k experts.
            Shape: [sequence_length, top_k_experts]
            Data type: float32
        num_experts (int): Total number of experts in the MoE layer, limited between 1 and 64.
        tokens_per_expert (list[int]): List where each element indicates the number of tokens
            assigned to the corresponding expert.
        padding_alignment (int): Tokens alignment requirement for expert buffers (in bytes).
            Must be a power of 2. Typical values are 16, 32 or 64 for optimal memory access.
        do_gather (bool): Decide whether to do the actual tokens gather operation or not, default is True.
        using_ue8m0_scale (bool): Whether to use the ue8m0 scaling for float8 inputs. Default is False.
        name (str|None, optional): Name prefix for the operation (optional).
            Default: None

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:

        - hidden_states_unzipped (Tensor): The permuted and broadcasted input tensor.
          Shape: [total_tokens_after_broadcast, token_dimension]
          Data type: same as input hidden_states
        - zipped_expertwise_rowmap (Tensor): Mapping tensor used to restore original order (unpermute).
          Shape: [sequence_length, num_experts]
          Data type: int32
        - token_prob_unzipped (Tensor): Flattened expert probabilities aligned with permuted tokens.
          Shape: [total_tokens_after_broadcast, 1]
          Data type: float32
        - scale_unzipped (Tensor): Broadcasted scale tensor (only valid for float8 inputs).
          Shape: [total_tokens_after_broadcast, scale.shape[-1]]
          Data type: float32 or int32. It is the same as scale.

    Raises:
        NotImplementedError: If called outside dynamic or PIR mode (the legacy
            static-graph path is not implemented for this operator).

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:GPU)
            >>> # doctest: +SKIP('This is only support in cuda 12.0+')
            >>> import paddle
            >>> import numpy as np
            >>> import paddle.nn.functional as F
            >>> hidden_states = paddle.randn([3, 128], dtype='bfloat16')
            >>> expert_routemap_topk = paddle.to_tensor([[-1, 0, -1, -1, 2, -1, -1, -1],
            ...                                          [1, -1, -1, -1, -1, -1, -1, -1],
            ...                                          [-1, -1, -1, -1, -1, -1, 1, -1]],
            ...                                         dtype='int32')
            >>> expert_prob_topk= paddle.to_tensor([[0.0, 0.6, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0],
            ...                                     [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ...                                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]],
            ...                                    dtype='float32')
            >>> num_experts = 3
            >>> tokens_per_expert = [1, 2, 1]
            >>> padding_alignment = 2
            >>> hidden_states_unzipped, zipped_expertwise_rowmap, token_prob_unzipped, scale_unzipped = F.moe_permute(
            ...     hidden_states,
            ...     None,
            ...     expert_routemap_topk,
            ...     expert_prob_topk,
            ...     num_experts,
            ...     tokens_per_expert,
            ...     padding_alignment,
            ... )
            >>> # weighted by probs.
            >>> hidden_states_unzipped = (hidden_states_unzipped.astype("float32") * token_prob_unzipped.astype("float32").unsqueeze(-1)).astype("bfloat16")
            >>> zipped_tokens, zipped_probs = F.moe_unpermute(hidden_states_unzipped, zipped_expertwise_rowmap, expert_routemap_topk, token_prob_unzipped,3,3)
            >>> np.testing.assert_allclose(zipped_tokens.numpy(), hidden_states.numpy(), rtol=1e-05, atol=1e-06)
    """
    if in_dynamic_or_pir_mode():
        # Thin wrapper: all work happens in the C++ kernel; argument order
        # must match the op signature exactly.
        (
            hidden_states_unzipped,
            zipped_expertwise_rowmap,
            token_prob_unzipped,
            scale_unzipped,
        ) = _C_ops.moe_permute(
            hidden_states,
            scale,
            expert_routemap_topk,
            expert_prob_topk,
            num_experts,
            tokens_per_expert,
            padding_alignment,
            do_gather,
            using_ue8m0_scale,
        )
        return (
            hidden_states_unzipped,
            zipped_expertwise_rowmap,
            token_prob_unzipped,
            scale_unzipped,
        )
    # Previously this fell through and implicitly returned None, which callers
    # would unpack into four variables and crash with an opaque TypeError.
    # Fail fast with an explicit, actionable error instead.
    raise NotImplementedError(
        "moe_permute is only supported in dynamic or PIR mode; the legacy "
        "static-graph mode is not implemented for this operator."
    )