
Commit 24f0da4
add docs
Parent: 1f7fdec


python/paddle/nn/functional/moe_permute.py

Lines changed: 6 additions & 4 deletions
@@ -47,15 +47,16 @@ def moe_permute(
     3. The padding_alignment parameter affects memory efficiency but not correctness.
     4. Any output token can find an exact match in the original input tokens.
     5. This permute function has overcome the aadiff issue and is deterministic.
+    6. If using_ue8m0_scale is True, the data type of scale must be int32, and each int32 packs four ue8m0 scaling factors.

     Args:
         hidden_states (Tensor): The input tensor containing tokens to be permuted, stored in row-major layout.
             Supported data types: bfloat16 or float8_e4m3fn.
             Shape: [sequence_length, token_dimension]
         scale (Tensor|None): Scaling factors required when hidden_states is of float8 type.
             For float8 inputs, this tensor provides the scaling factors for dequantization.
-            Shape: [sequence_length, ceil(token_dimension / 128)]
-            Data type: float32
+            Shape: [sequence_length, ceil(token_dimension / 128)]; if using_ue8m0_scale is True, the shape is [sequence_length, ceil(ceil(token_dimension / 128) / 4)].
+            Data type: float32, or int32 when using_ue8m0_scale is True, in which case each int32 is packed from four ue8m0 scaling factors.
         expert_routemap_topk (Tensor): Tensor indicating expert assignments for each token (top-k experts).
             Each value represents the expert index the token is assigned to (-1 indicates not assigned).
             Shape: [sequence_length, top_k_experts]
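The ue8m0 packing that note 6 describes is easiest to see in code. Below is a minimal NumPy sketch, not Paddle's kernel: it assumes the common ue8m0 convention (a bare 8-bit biased exponent encoding 2**(e - 127)) and a little-endian byte order within each int32; the commit does not pin down the kernel's actual byte layout.

import math
import numpy as np

def pack_ue8m0(exponents: np.ndarray) -> np.ndarray:
    """Pack uint8 ue8m0 exponents of shape [seq_len, num_groups] into int32 words."""
    seq_len, num_groups = exponents.shape
    padded = math.ceil(num_groups / 4) * 4
    buf = np.zeros((seq_len, padded), dtype=np.uint8)
    buf[:, :num_groups] = exponents
    # Assumed little-endian layout: factor i of each group of four occupies
    # bits [8*i, 8*i + 8) of the packed int32 word.
    return buf.reshape(seq_len, -1, 4).view(np.int32)[..., 0]

# Each ue8m0 byte is a bare biased exponent; the factor it encodes is 2**(e - 127).
exponents = np.full((8, 5), 127, dtype=np.uint8)   # every scale factor == 1.0
packed = pack_ue8m0(exponents)
print(packed.shape)   # (8, 2): [seq_len, ceil(num_groups / 4)], matching the docstring shape

The packed width ceil(num_groups / 4) is exactly the documented ceil(ceil(token_dimension / 128) / 4).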
@@ -70,6 +71,7 @@ def moe_permute(
         padding_alignment (int): Token alignment requirement for expert buffers (in bytes).
             Must be a power of 2. Typical values are 16, 32 or 64 for optimal memory access.
         do_gather (bool): Whether to perform the actual token gather operation. Default is True.
+        using_ue8m0_scale (bool): Whether to use ue8m0 scaling for float8 inputs. Default is False.
         name (str|None, optional): Name prefix for the operation (optional).
             Default: None
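Rounding a buffer up to a power-of-two boundary, as padding_alignment requires, is the usual bit trick. A generic sketch of the arithmetic, not the kernel's actual bookkeeping:

def align_up(size: int, alignment: int) -> int:
    # padding_alignment must be a power of two, per the docstring above
    assert alignment > 0 and alignment & (alignment - 1) == 0
    return (size + alignment - 1) & ~(alignment - 1)

print(align_up(100, 16))  # 112
print(align_up(96, 32))   # 96 (already aligned)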
@@ -85,8 +87,8 @@ def moe_permute(
             Shape: [total_tokens_after_broadcast, 1]
             Data type: float32
         - scale_unzipped (Tensor): Broadcasted scale tensor (only valid for float8 inputs).
-            Shape: [total_tokens_after_broadcast, ceil(token_dimension / 128)]
-            Data type: float32
+            Shape: [total_tokens_after_broadcast, scale.shape[-1]]
+            Data type: float32 or int32, the same as scale.

     Examples:
         .. code-block:: python
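The scale_unzipped change is pure shape bookkeeping: each permuted token carries the scale row of its source token, so the last dimension of scale (float32 groups and packed int32 words alike) passes through unchanged. A NumPy stand-in with a hypothetical gather index, not the kernel's actual routing:

import numpy as np

seq_len, num_groups = 8, 2
scale = np.random.rand(seq_len, num_groups).astype(np.float32)

# Hypothetical permutation: row i of the output copies the scale row of its
# source token; total_tokens_after_broadcast can exceed seq_len when top_k > 1.
gather_index = np.array([3, 3, 0, 5, 1, 1, 7, 2, 4, 4])
scale_unzipped = scale[gather_index]

assert scale_unzipped.shape == (len(gather_index), scale.shape[-1])
assert scale_unzipped.dtype == scale.dtype   # packed int32 scales stay int32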
