Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,17 @@ def moe_topk_select(
probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
) # [seq_len, n_group]
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group]
group_mask = paddle.zeros_like(group_scores).put_along_axis(
group_idx, paddle.to_tensor(1.0, dtype=group_scores.dtype), axis=-1
group_mask = paddle.sum(
paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
axis=1, # Sum over topk_group dimension -> [seq_len, n_group]
)
Comment on lines +248 to 251
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Building group_mask via one_hot(group_idx, num_classes=n_group) materializes a dense tensor of shape [seq_len, topk_group, n_group], which can significantly increase memory traffic (and potentially hurt latency) for large seq_len/n_group. Consider using a scatter/put-along-axis style update that writes ones directly into a [seq_len, n_group] buffer (e.g., initialize zeros then scatter indices with a ones tensor shaped like group_idx), which avoids the large intermediate.

Suggested change
group_mask = paddle.sum(
paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
axis=1, # Sum over topk_group dimension -> [seq_len, n_group]
)
# Build group_mask of shape [seq_len, n_group] without materializing a large one-hot tensor
group_mask = paddle.zeros([seq_length, n_group], dtype=group_scores.dtype)
updates = paddle.ones_like(group_idx, dtype=group_scores.dtype)
group_mask = paddle.put_along_axis(group_mask, group_idx, updates, axis=1, reduce="assign")

Copilot uses AI. Check for mistakes.
score_mask = (
group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1])
) # [seq_len, n_experts]
probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))

_, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1)
topk_weights = paddle.take_along_axis(gate_probs, topk_ids, axis=-1)
topk_weights = paddle.index_sample(gate_probs, topk_ids)
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle.index_sample has a narrower contract than take_along_axis: it samples along a fixed dimension (commonly axis=1) and (in many Paddle versions) assumes x is 2-D [N, M] and index is [N, K]. If gate_probs can ever be higher-rank or if the intended gather axis is not the second dimension, this will return incorrect results or error at runtime. If you need general axis=-1 semantics, consider switching to an index construction using gather_nd (batch indices + topk_ids) or explicitly reshape/validate gate_probs to 2-D before calling index_sample.

Suggested change
topk_weights = paddle.index_sample(gate_probs, topk_ids)
# Use paddle.index_sample with its 2-D [N, M] / [N, K] contract by flattening
# all leading dimensions into a single batch dimension, and gather along the
# last axis. This is equivalent to take_along_axis(..., axis=-1) but robust
# to higher-rank gate_probs/topk_ids.
last_expert_dim = gate_probs.shape[-1]
flat_batch = paddle.numel(gate_probs) // last_expert_dim
gate_probs_2d = gate_probs.reshape([flat_batch, last_expert_dim])
topk_last_dim = topk_ids.shape[-1]
topk_ids_2d = topk_ids.reshape([flat_batch, topk_last_dim])
topk_weights_2d = paddle.index_sample(gate_probs_2d, topk_ids_2d)
topk_weights = topk_weights_2d.reshape(topk_ids.shape)

Copilot uses AI. Check for mistakes.

# normalize combine weights
if renormalize:
Expand Down
Loading