
Conversation

@xinyu-intel (Contributor) commented Nov 27, 2025

Copilot AI left a comment

Pull request overview

This PR adds an optimized grouped top-k operation for the Gaudi platform. The implementation performs expert-group selection for mixture-of-experts (MoE) models, chooses between two selection strategies based on batch size, and supports an optional score correction bias (a simplified sketch of the routing flow follows the list below).

  • Adds has_optimized_grouped_topk() method returning True to indicate platform support
  • Implements grouped_topk() method with scoring functions (softmax/sigmoid), group-based expert selection, and optional bias correction
  • Includes adaptive algorithm selection based on token count threshold (1024)
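
For orientation, here is a minimal, self-contained sketch of the grouped top-k routing flow summarized above. Names, shapes, and the absence of bias correction and batch-size branching are simplifications for illustration; this is not the PR's exact implementation.

import torch

def grouped_topk_sketch(scores, num_expert_group, topk_group, topk):
    # scores: [num_token, num_expert] router scores (already softmax/sigmoid-normalized)
    num_token, num_expert = scores.shape
    # 1. Score each expert group by its best expert.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values   # [n, n_group]
    # 2. Keep only the topk_group highest-scoring groups.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]                    # [n, topk_group]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)            # [n, n_group]
    # 3. Mask out experts belonging to discarded groups.
    experts_per_group = num_expert // num_expert_group
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group, experts_per_group).reshape(num_token, -1)
    masked_scores = scores.masked_fill(expert_mask == 0, float("-inf"))
    # 4. Final top-k over the surviving experts.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    return topk_weights, topk_ids

# Example: 64 experts in 8 groups, keep the best 4 groups, route each token to 8 experts.
weights, ids = grouped_topk_sketch(torch.softmax(torch.randn(16, 64), dim=-1), 8, 4, 8)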

Comment on lines 305 to 306
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant, )

Copilot AI Nov 27, 2025

Import statements should be placed at the top of the file, not within a method. Move this import to the module level to follow Python conventions and improve code clarity.
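
For illustration, the module-level placement would look roughly like this; the surrounding function signature is a hypothetical sketch, not the PR's actual code:

# Near the other imports at the top of the module:
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)

def grouped_topk(scores, topk_group, topk):
    # Hypothetical signature for illustration; the body can now call
    # vllm_is_batch_invariant() directly, with no per-call import.
    use_sorted = vllm_is_batch_invariant()
    ...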

Comment on lines 316 to 331
    top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
    scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
    group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
    group_scores.add_(top1_val)
else:
    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
if num_token > 1024:
    group_mask = torch.zeros_like(group_scores)
    for i in range(topk_group):
        _, group_idx = torch.max(group_scores, dim=-1)
        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
else:
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]

Copilot AI Nov 27, 2025

[nitpick] The variable names top1_val, top1_idx, and top2_idx are inconsistent with their actual usage: top1_val stores the per-group maximum values, but group_scores holds the second-largest values (before top1_val is added in). Consider renaming to first_max_val, first_max_idx, and second_max_idx for clarity, or using top2_val instead of group_scores to maintain consistency.

Suggested change
-    top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
-    scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
-    group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
-    group_scores.add_(top1_val)
-else:
-    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
-if num_token > 1024:
-    group_mask = torch.zeros_like(group_scores)
-    for i in range(topk_group):
-        _, group_idx = torch.max(group_scores, dim=-1)
-        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
-        if i < topk_group - 1:
-            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
-else:
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    first_max_val, first_max_idx = torch.max(scores_tmp, dim=-1)
+    scores_tmp.scatter_(-1, first_max_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+    second_max_val, second_max_idx = torch.max(scores_tmp, dim=-1)
+    second_max_val.add_(first_max_val)
+else:
+    second_max_val = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+if num_token > 1024:
+    group_mask = torch.zeros_like(second_max_val)
+    for i in range(topk_group):
+        _, group_idx = torch.max(second_max_val, dim=-1)
+        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
+        if i < topk_group - 1:
+            second_max_val.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+else:
+    group_idx = torch.topk(second_max_val, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(second_max_val)  # [n, n_group]

    group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
    group_scores.add_(top1_val)
else:
    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]

Copilot AI Nov 27, 2025

The threshold value 1024 for algorithm selection lacks explanation. Add a comment explaining why this specific threshold was chosen and what trade-offs exist between the two approaches.

Suggested change
-    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+    # For large batches (num_token > 1024), use a loop-based approach to avoid potential
+    # performance and memory issues with torch.topk on large tensors.
+    # The threshold 1024 was chosen based on empirical benchmarks: for smaller batches, the
+    # vectorized torch.topk is faster and simpler, while for larger batches, the loop-based
+    # approach is more memory-efficient and avoids slowdowns. Adjust as needed for your hardware.

Comment on lines 322 to 332
if num_token > 1024:
    group_mask = torch.zeros_like(group_scores)
    for i in range(topk_group):
        _, group_idx = torch.max(group_scores, dim=-1)
        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
else:
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

Copilot AI Nov 27, 2025

The loop-based approach for token counts > 1024 performs multiple sequential torch.max operations, which may be inefficient. Consider profiling both branches to ensure the loop approach actually provides better performance for large token counts, or document the performance characteristics that justify this implementation.

Suggested change
-if num_token > 1024:
-    group_mask = torch.zeros_like(group_scores)
-    for i in range(topk_group):
-        _, group_idx = torch.max(group_scores, dim=-1)
-        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
-        if i < topk_group - 1:
-            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
-else:
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
+group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
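
As a rough illustration of the comparison being requested, here is a hypothetical micro-benchmark of the two branches. The tensor shapes and topk_group value are placeholder assumptions, and on an accelerator (Gaudi/CUDA) the appropriate device synchronization should be added before reading the timers.

import time
import torch

def bench(fn, iters=50):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

num_token, num_expert_group, topk_group = 4096, 8, 4
scores = torch.randn(num_token, num_expert_group)

def loop_branch():
    # Sequential max + scatter, as in the num_token > 1024 path.
    group_scores = scores.clone()
    group_mask = torch.zeros_like(group_scores)
    for i in range(topk_group):
        _, group_idx = torch.max(group_scores, dim=-1)
        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            group_scores.scatter_(1, group_idx.unsqueeze(-1),
                                  torch.finfo(scores.dtype).min)
    return group_mask

def topk_branch():
    # Single vectorized topk, as in the else path.
    group_idx = torch.topk(scores, k=topk_group, dim=-1)[1]
    return torch.zeros_like(scores).scatter_(1, group_idx, 1)

print(f"loop branch: {bench(loop_branch) * 1e3:.3f} ms/iter")
print(f"topk branch: {bench(topk_branch) * 1e3:.3f} ms/iter")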

@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
0353d2e162cbda776d9dbfe026e65303204a7f1f

@xinyu-intel force-pushed the dev/xinyu/grouped_topk branch from e191a3d to 0cb8f4f on December 4, 2025 01:54
@xinyu-intel changed the title from "platform: optimize grouped topk op" to "CustomOp: grouped topk" on Dec 4, 2025
@xinyu-intel force-pushed the dev/xinyu/grouped_topk branch from 0cb8f4f to 2b24404 on December 4, 2025 02:18