CustomOp: grouped topk #647
base: main
Conversation
Pull request overview
This PR adds an optimized grouped top-k operation for the Gaudi platform. The optimization targets expert selection for mixture-of-experts (MoE) models, with batch-size-dependent algorithm selection and optional score correction bias.
- Adds `has_optimized_grouped_topk()` method returning True to indicate platform support
- Implements `grouped_topk()` method with scoring functions (softmax/sigmoid), group-based expert selection, and optional bias correction
- Includes adaptive algorithm selection based on token count threshold (1024)
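For context, a rough sketch of how MoE routing code might consume such a platform hook is shown below. The `current_platform` accessor exists in vLLM, but the exact `grouped_topk` signature used here and the `generic_grouped_topk` fallback are assumptions for illustration, not the interface defined by this PR or vllm-project/vllm#29575.

```python
import torch
from vllm.platforms import current_platform  # vLLM platform accessor


def select_experts(hidden_states: torch.Tensor,
                   gating_output: torch.Tensor,
                   topk: int,
                   renormalize: bool,
                   num_expert_group: int,
                   topk_group: int):
    # Hypothetical dispatch: prefer the platform's optimized grouped top-k
    # when it advertises one, otherwise fall back to a generic routine.
    if current_platform.has_optimized_grouped_topk():
        return current_platform.grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func="softmax",        # or "sigmoid"
            e_score_correction_bias=None,  # optional bias correction
        )
    # Hypothetical generic fallback, not part of this PR.
    return generic_grouped_topk(hidden_states, gating_output, topk, renormalize,
                                num_expert_group, topk_group)
```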
vllm_gaudi/platform.py
Outdated
```python
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant, )
```
Copilot AI (Nov 27, 2025)
Import statements should be placed at the top of the file, not within a method. Move this import to the module level to follow Python conventions and improve code clarity.
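A minimal sketch of the suggested fix, with the import hoisted to module level; the try/except fallback is an assumption for vLLM builds that may not ship `batch_invariant`, not something the review asks for:

```python
# vllm_gaudi/platform.py, module level
try:
    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
except ImportError:
    # Assumed fallback for vLLM versions without the batch_invariant module.
    def vllm_is_batch_invariant() -> bool:
        return False
```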
vllm_gaudi/platform.py
Outdated
```python
    top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
    scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
    group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
    group_scores.add_(top1_val)
else:
    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
if num_token > 1024:
    group_mask = torch.zeros_like(group_scores)
    for i in range(topk_group):
        _, group_idx = torch.max(group_scores, dim=-1)
        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
else:
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
```
Copilot AI (Nov 27, 2025)
[nitpick] The variable names `top1_val`, `top1_idx`, and `top2_idx` are inconsistent with their actual usage: `top1_val` holds the per-group maximum values, but the second maximum values land in `group_scores` rather than a matching `top2_val`. Consider renaming to `first_max_val`, `first_max_idx`, and `second_max_idx` for clarity, or using `top2_val` instead of `group_scores` to maintain consistency.
Suggested change:
```diff
-    top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
-    scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
-    group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
-    group_scores.add_(top1_val)
-else:
-    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
-if num_token > 1024:
-    group_mask = torch.zeros_like(group_scores)
-    for i in range(topk_group):
-        _, group_idx = torch.max(group_scores, dim=-1)
-        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
-        if i < topk_group - 1:
-            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
-else:
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    first_max_val, first_max_idx = torch.max(scores_tmp, dim=-1)
+    scores_tmp.scatter_(-1, first_max_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+    second_max_val, second_max_idx = torch.max(scores_tmp, dim=-1)
+    second_max_val.add_(first_max_val)
+else:
+    second_max_val = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+if num_token > 1024:
+    group_mask = torch.zeros_like(second_max_val)
+    for i in range(topk_group):
+        _, group_idx = torch.max(second_max_val, dim=-1)
+        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
+        if i < topk_group - 1:
+            second_max_val.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+else:
+    group_idx = torch.topk(second_max_val, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(second_max_val)  # [n, n_group]
```
vllm_gaudi/platform.py
Outdated
```python
    group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
    group_scores.add_(top1_val)
else:
    group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
```
Copilot AI (Nov 27, 2025)
The threshold value 1024 for algorithm selection lacks explanation. Add a comment explaining why this specific threshold was chosen and what trade-offs exist between the two approaches.
Suggested change:
```diff
 group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+# For large batches (num_token > 1024), use a loop-based approach to avoid potential performance and memory issues with torch.topk on large tensors.
+# The threshold 1024 was chosen based on empirical benchmarks: for smaller batches, the vectorized torch.topk is faster and simpler,
+# while for larger batches, the loop-based approach is more memory-efficient and avoids slowdowns. Adjust as needed for your hardware.
```
vllm_gaudi/platform.py
Outdated
```python
if num_token > 1024:
    group_mask = torch.zeros_like(group_scores)
    for i in range(topk_group):
        _, group_idx = torch.max(group_scores, dim=-1)
        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
else:
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
```
Copilot AI (Nov 27, 2025)
The loop-based approach for token counts > 1024 performs multiple sequential torch.max operations, which may be inefficient. Consider profiling both branches to ensure the loop approach actually provides better performance for large token counts, or document the performance characteristics that justify this implementation.
Suggested change:
```diff
-if num_token > 1024:
-    group_mask = torch.zeros_like(group_scores)
-    for i in range(topk_group):
-        _, group_idx = torch.max(group_scores, dim=-1)
-        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
-        if i < topk_group - 1:
-            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
-else:
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
+group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
```
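A standalone benchmarking sketch for comparing the two branches, along the lines the review suggests; the shapes (`num_token=4096`, `num_expert_group=8`, `topk_group=4`) are assumed for illustration, and it runs on CPU, so real conclusions would require measuring on Gaudi with representative workloads:

```python
import time

import torch


def loop_select(group_scores: torch.Tensor, topk_group: int) -> torch.Tensor:
    # Sequential torch.max + scatter_, mirroring the num_token > 1024 branch.
    group_scores = group_scores.clone()
    group_mask = torch.zeros_like(group_scores)
    for i in range(topk_group):
        _, group_idx = torch.max(group_scores, dim=-1)
        group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(group_scores.dtype).min)
    return group_mask


def topk_select(group_scores: torch.Tensor, topk_group: int) -> torch.Tensor:
    # Single vectorized torch.topk, mirroring the else branch.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    return group_mask


if __name__ == "__main__":
    num_token, num_expert_group, topk_group = 4096, 8, 4  # assumed shapes
    scores = torch.randn(num_token, num_expert_group)
    # Sanity check: both branches should pick the same groups (ties aside).
    assert torch.equal(loop_select(scores, topk_group), topk_select(scores, topk_group))
    for fn in (loop_select, topk_select):
        start = time.perf_counter()
        for _ in range(100):
            fn(scores, topk_group)
        print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")
```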
✅ CI Passed. All checks passed successfully against the following vllm commit:
e191a3d to 0cb8f4f (Compare)
0cb8f4f to 2b24404 (Compare)
Signed-off-by: Xinyu Chen <[email protected]>
Signed-off-by: Xinyu Chen <[email protected]>
depends on vllm-project/vllm#29575