DP: dispatch tensor in FusedMoEMethod #680
base: main
First changed file (the HPU fused-MoE method implementation; its path is not shown in this excerpt):
```diff
@@ -4,6 +4,7 @@
 import vllm
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod)
 from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp)
+from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata


 @UnquantizedFusedMoEMethod.register_oot
```
```diff
@@ -61,16 +62,30 @@ def forward_oot(
         topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
         topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
         topk_weights = topk_weights.to(x.dtype)

         topk_ids = topk_ids.to(torch.int64)
         topk_weights = topk_weights.to(x.dtype)
```
Suggested change (see the review comment on the redundant dtype casts further below):

```diff
-        topk_weights = topk_weights.to(x.dtype)
+        if topk_weights.dtype != x.dtype:
+            topk_weights = topk_weights.to(x.dtype)
```
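The lines this PR actually adds inside `forward_oot` are collapsed out of this excerpt. Purely as an illustration of the idea in the PR title, here is a minimal sketch of dispatching the routing tensors across data-parallel ranks; the signatures and semantics of `dispatch_tensor` and `get_hpu_dp_metadata` are assumptions here, not taken from the PR.

```python
# Hypothetical sketch only: the real added code is not visible in this excerpt.
# Assumes get_hpu_dp_metadata() returns the HPUDPMetadata for the current step and
# dispatch_tensor(local, buffer) all-gathers `local` from every DP rank into `buffer`.
dp_meta = get_hpu_dp_metadata()
if dp_meta is not None:
    x = dispatch_tensor(x, dp_meta.hidden_states_across_dp)
    topk_ids = dispatch_tensor(topk_ids, dp_meta.topk_ids_across_dp)
    topk_weights = dispatch_tensor(topk_weights, dp_meta.topk_weights_across_dp)
```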
The second changed file holds the `HPUDPMetadata` helpers, presumably `vllm_gaudi/v1/worker/hpu_dp_utils.py` given the import added above:
```diff
@@ -3,14 +3,16 @@
 from vllm.config import VllmConfig
 from dataclasses import dataclass
 from typing import Optional
 from vllm.distributed import get_dp_group, get_ep_group
 from vllm.platforms import current_platform
 import habana_frameworks.torch as htorch

 @dataclass
 class HPUDPMetadata:
     hidden_states_across_dp: torch.Tensor
     router_logits_across_dp: torch.Tensor
+    topk_ids_across_dp: torch.Tensor
+    topk_weights_across_dp: torch.Tensor
     local_hidden_states: torch.Tensor

     @staticmethod
```
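The two new fields mirror the existing across-DP gather buffers, extending them to the top-k routing outputs. As a rough sketch of how they could be allocated in `make` (the buffer size name `max_tokens_across_dp` is an assumption; only `dtype`, `device`, and `num_experts_per_tok` appear in the hunk below):

```python
# Hypothetical allocation sketch; not taken from the PR.
# `max_tokens_across_dp` stands in for the padded token capacity across all DP ranks.
topk_ids_across_dp = torch.zeros(
    max_tokens_across_dp, num_experts_per_tok, dtype=torch.int64, device=device)
topk_weights_across_dp = torch.zeros(
    max_tokens_across_dp, num_experts_per_tok, dtype=dtype, device=device)
```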
```diff
@@ -27,35 +29,30 @@ def make(
         dtype = vllm_config.model_config.dtype
         device = current_platform.device_type

         num_expert_names = [
             "moe_num_experts",  # Dbrx
             "num_experts",  # Jamba
             "n_routed_experts",  # DeepSeek
             "num_local_experts",  # Mixtral
         ]
         num_experts = 0
         for name in num_expert_names:
             num_experts = getattr(vllm_config.model_config.hf_text_config, name, 0)
             if num_experts > 0:
                 break
         assert num_experts > 0, \
             "No expert found in the model config. Please check the model config."
         num_experts_per_tok = 0
```
Copilot AI commented on Dec 4, 2025:
The error message 'No expert found in the model config' is misleading since the assertion checks num_experts_per_tok (number of experts per token), not the existence of experts. Update the message to reflect what is actually being validated, e.g., 'num_experts_per_tok must be greater than 0 in model config'.
Suggested change:

```diff
-        assert num_experts_per_tok > 0, ("No expert found in the model config. Please check the model config.")
+        assert num_experts_per_tok > 0, ("num_experts_per_tok must be greater than 0 in model config. Please check the model config.")
```
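For context, the per-token expert count this assertion guards is normally read straight from the HF text config, in the same way the expert count is looked up above. A minimal sketch, assuming a Mixtral-style config that exposes `num_experts_per_tok` (the exact lookup in this PR is not visible in the excerpt):

```python
# Sketch: mirrors the getattr-based lookup shown in the hunk above.
# Assumes hf_text_config exposes `num_experts_per_tok` (as Mixtral-style configs do).
hf_text_config = vllm_config.model_config.hf_text_config
num_experts_per_tok = getattr(hf_text_config, "num_experts_per_tok", 0)
assert num_experts_per_tok > 0, (
    "num_experts_per_tok must be greater than 0 in model config. "
    "Please check the model config.")
```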
A second review comment, on the unconditional dtype casts in forward_oot:
These type conversions are performed unconditionally but may be redundant if the tensors are already in the correct dtype. Consider adding guards to skip the conversion when unnecessary, e.g. `if topk_ids.dtype != torch.int64: topk_ids = topk_ids.to(torch.int64)`.
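Applied to both routing tensors, the guarded form the reviewer suggests would look roughly like this (a sketch, not code from the PR):

```python
# Skip each cast when the tensor is already in the target dtype, per the review suggestion.
if topk_ids.dtype != torch.int64:
    topk_ids = topk_ids.to(torch.int64)
if topk_weights.dtype != x.dtype:
    topk_weights = topk_weights.to(x.dtype)
```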