42 changes: 2 additions & 40 deletions vllm_gaudi/distributed/device_communicators/hpu_communicator.py
@@ -64,46 +64,8 @@ def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
assert self.dp_group is not None
assert hidden_states.dim() == 2, "Input hidden states must be 2D"

dp_metadata = get_hpu_dp_metadata()
if dp_metadata is not None:
hidden_states_across_dp = dp_metadata.hidden_states_across_dp
router_logits_across_dp = dp_metadata.router_logits_across_dp
else:
# create hidden_states_across_dp tensor
input_size = hidden_states.size()
# Allocate output tensor.
output_size = list(input_size)
if is_sequence_parallel:
# if sequence parallel enabled, hidden states was already being chunked by sp_size
output_size[0] *= self.world_size
else:
output_size[0] *= self.dp_world_size
hidden_states_across_dp = torch.empty(output_size, dtype=hidden_states.dtype, device=hidden_states.device)

# create router_logits_across_dp tensor
router_logits_size = router_logits.size()
router_logits_output_size = list(router_logits_size)
if is_sequence_parallel:
router_logits_output_size[0] *= self.world_size
else:
router_logits_output_size[0] *= self.dp_world_size
router_logits_across_dp = torch.empty(router_logits_output_size,
dtype=router_logits.dtype,
device=router_logits.device)

torch.distributed.all_gather_into_tensor(
hidden_states_across_dp,
hidden_states,
group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)

torch.distributed.all_gather_into_tensor(
router_logits_across_dp,
router_logits,
group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group)
return hidden_states_across_dp, router_logits_across_dp
# Use dispatch_tensor in the plugin FusedMoEMethod for better performance
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False) -> torch.Tensor:
if htorch.utils.internal.is_lazy():
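For reference, the gather that used to live in dispatch() (and is now centralized as dispatch_tensor in hpu_dp_utils.py, shown further below) reduces to the following sketch. The function name and the explicit group/world-size arguments are placeholders for illustration, not the communicator's actual attributes.

import torch
import torch.distributed as dist

def allgather_across_dp(x: torch.Tensor, group, world_size: int) -> torch.Tensor:
    # all_gather_into_tensor assumes equal per-rank shapes, so the gathered
    # size is a plain multiply of the local token count.
    out_shape = list(x.size())
    out_shape[0] *= world_size
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    # Each rank contributes its local rows; the result is replicated on all ranks.
    dist.all_gather_into_tensor(out, x, group=group)
    return out

The simplified dispatch() now returns hidden_states and router_logits untouched; the plugin FusedMoEMethod decides what to gather and into which pre-allocated buffer.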
21 changes: 18 additions & 3 deletions vllm_gaudi/ops/hpu_fp8.py
@@ -10,6 +10,7 @@
Fp8Config)
import vllm_gaudi.extension.ops as hpu_ops
from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOpFP8PerChannel, VllmMixtureOfExpertsOpFP8)
from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata


class Fp8LinearMethod(OrigFp8LinearMethod):
@@ -158,16 +159,30 @@ def apply(
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(x.dtype)

topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
Copilot AI Dec 4, 2025

These type conversions are performed unconditionally but may be redundant if the tensors are already in the correct dtype. Consider adding guards to skip the conversion when unnecessary, e.g., if topk_ids.dtype != torch.int64: topk_ids = topk_ids.to(torch.int64).

Suggested change:
-    topk_weights = topk_weights.to(x.dtype)
+    if topk_weights.dtype != x.dtype:
+        topk_weights = topk_weights.to(x.dtype)

if layer.dp_size > 1:
hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp
x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)

topk_ids_across_dp = get_hpu_dp_metadata().topk_ids_across_dp
topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)

topk_weights_across_dp = get_hpu_dp_metadata().topk_weights_across_dp
topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)

topk_ids = topk_ids.view(*x.shape[:-1], -1)
topk_weights = topk_weights.view(*x.shape[:-1], -1)
output = layer.moe_op(
x,
topk_ids.to(torch.int64),
topk_weights.to(x.dtype),
topk_ids,
topk_weights,
permuted_weights=True,
activation=activation,
)
return output.view(*input_shape)
return output.view(*(x.size(0), *input_shape[1:]))


fp8.Fp8LinearMethod = Fp8LinearMethod
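Taken together, the data-parallel path added to apply() amounts to the sketch below. It assumes the per-step HPUDPMetadata buffers have already been created (see hpu_dp_utils.py further down); the helper name dispatch_moe_inputs is made up for illustration, while layer.dp_size and layer.is_sequence_parallel are used exactly as in the diff.

from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata

def dispatch_moe_inputs(layer, x, topk_ids, topk_weights):
    # Gather MoE inputs across the data-parallel group before the expert op.
    if layer.dp_size > 1:
        meta = get_hpu_dp_metadata()
        # Reusing fixed-shape, pre-allocated buffers keeps the gathered shapes static.
        x = dispatch_tensor(x, meta.hidden_states_across_dp, layer.is_sequence_parallel)
        topk_ids = dispatch_tensor(topk_ids, meta.topk_ids_across_dp, layer.is_sequence_parallel)
        topk_weights = dispatch_tensor(topk_weights, meta.topk_weights_across_dp, layer.is_sequence_parallel)
    # Re-align the routing tensors with the (possibly gathered) token dimension of x.
    topk_ids = topk_ids.view(*x.shape[:-1], -1)
    topk_weights = topk_weights.view(*x.shape[:-1], -1)
    return x, topk_ids, topk_weights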
21 changes: 18 additions & 3 deletions vllm_gaudi/ops/hpu_fused_moe.py
@@ -4,6 +4,7 @@
import vllm
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod)
from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp)
from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata


@UnquantizedFusedMoEMethod.register_oot
@@ -61,16 +62,30 @@ def forward_oot(
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(x.dtype)

topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
Copilot AI Dec 4, 2025

These type conversions are performed unconditionally but may be redundant if the tensors are already in the correct dtype. Consider adding guards to skip the conversion when unnecessary, e.g., if topk_ids.dtype != torch.int64: topk_ids = topk_ids.to(torch.int64).

Suggested change:
-    topk_weights = topk_weights.to(x.dtype)
+    if topk_weights.dtype != x.dtype:
+        topk_weights = topk_weights.to(x.dtype)

if layer.dp_size > 1:
hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp
x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)

topk_ids_across_dp = get_hpu_dp_metadata().topk_ids_across_dp
topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)

topk_weights_across_dp = get_hpu_dp_metadata().topk_weights_across_dp
topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)

topk_ids = topk_ids.view(*x.shape[:-1], -1)
topk_weights = topk_weights.view(*x.shape[:-1], -1)

return layer.moe_op(
x,
topk_ids.to(torch.int64),
topk_weights.to(x.dtype),
topk_ids,
topk_weights,
permuted_weights=True,
activation=activation,
).view(*input_shape)
).view(*(x.size(0), *input_shape[1:]))


def reduce_output(self, states: torch.Tensor) -> torch.Tensor:
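If the reviewer's dtype-guard suggestion were adopted in both MoE methods above, a small helper along these lines (hypothetical, not part of the PR) would skip the redundant casts while keeping the call sites compact:

import torch

def cast_if_needed(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # No-op when the tensor already has the requested dtype, avoiding an extra copy.
    return t if t.dtype == dtype else t.to(dtype)

# Usage at the call sites flagged by the review:
# topk_ids = cast_if_needed(topk_ids, torch.int64)
# topk_weights = cast_if_needed(topk_weights, x.dtype)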
53 changes: 36 additions & 17 deletions vllm_gaudi/v1/worker/hpu_dp_utils.py
@@ -3,14 +3,16 @@
from vllm.config import VllmConfig
from dataclasses import dataclass
from typing import Optional
from vllm.distributed import get_dp_group, get_ep_group
from vllm.platforms import current_platform
import habana_frameworks.torch as htorch


@dataclass
class HPUDPMetadata:
hidden_states_across_dp: torch.Tensor
router_logits_across_dp: torch.Tensor
topk_ids_across_dp: torch.Tensor
topk_weights_across_dp: torch.Tensor
local_hidden_states: torch.Tensor

@staticmethod
@@ -27,35 +29,30 @@ def make(
dtype = vllm_config.model_config.dtype
device = current_platform.device_type

num_expert_names = [
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(vllm_config.model_config.hf_text_config, name, 0)
if num_experts > 0:
break
assert num_experts > 0, \
"No expert found in the model config. Please check the model config."
num_experts_per_tok = 0
Copilot AI Dec 4, 2025

Line 32 initializes num_experts_per_tok to 0, which is immediately overwritten on line 33. Remove the redundant initialization on line 32.

Suggested change:
-    num_experts_per_tok = 0

num_experts_per_tok = getattr(vllm_config.model_config.hf_text_config, "num_experts_per_tok", 0)
assert num_experts_per_tok > 0, ("No expert found in the model config. Please check the model config.")
Copilot AI Dec 4, 2025

The error message 'No expert found in the model config' is misleading since the assertion checks num_experts_per_tok (number of experts per token), not the existence of experts. Update the message to reflect what is actually being validated, e.g., 'num_experts_per_tok must be greater than 0 in model config'.

Suggested change:
-    assert num_experts_per_tok > 0, ("No expert found in the model config. Please check the model config.")
+    assert num_experts_per_tok > 0, ("num_experts_per_tok must be greater than 0 in model config. Please check the model config.")

hidden_states_across_dp = torch.empty(
(num_tokens_across_dp, hidden_size),
dtype=dtype,
device=device,
)
router_logits_across_dp = torch.empty(
(num_tokens_across_dp, num_experts),
topk_ids_across_dp = torch.empty(
(num_tokens_across_dp, num_experts_per_tok),
dtype=torch.int64,
device=device,
)
topk_weights_across_dp = torch.empty(
(num_tokens_across_dp, num_experts_per_tok),
dtype=dtype,
device=device,
)
local_num_tokens = (num_tokens //
tp_size) if vllm_config.parallel_config.use_sequence_parallel_moe else num_tokens
local_hidden_states = torch.empty((local_num_tokens, hidden_size), dtype=dtype, device=device)

return HPUDPMetadata(hidden_states_across_dp, router_logits_across_dp, local_hidden_states)
return HPUDPMetadata(hidden_states_across_dp, topk_ids_across_dp, topk_weights_across_dp, local_hidden_states)


_hpu_dp_metadata: Optional[HPUDPMetadata] = None
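As a shape sanity check for the new buffers (the numbers below are hypothetical, chosen only to illustrate what make() now allocates): with num_tokens_across_dp = 1024, hidden_size = 4096, num_experts_per_tok = 2 and a bfloat16 model dtype, the metadata carries

# hidden_states_across_dp : (1024, 4096), torch.bfloat16
# topk_ids_across_dp      : (1024, 2),    torch.int64
# topk_weights_across_dp  : (1024, 2),    torch.bfloat16
# local_hidden_states     : (num_tokens // tp_size, 4096) with use_sequence_parallel_moe,
#                           otherwise (num_tokens, 4096)

replacing the previous (num_tokens_across_dp, num_experts) router-logits buffer.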
@@ -96,3 +93,25 @@ def set_hpu_dp_metadata(
def get_hpu_dp_metadata() -> Optional[HPUDPMetadata]:
"""Get the current HPU DP metadata."""
return _hpu_dp_metadata


def dispatch_tensor(input, output: torch.Tensor | None = None, is_sequence_parallel: bool = False) -> torch.Tensor:
assert get_dp_group() is not None
assert input.dim() == 2, "Input must be 2D"

if output is None:
# create output tensor
input_size = input.size()
# Allocate output tensor.
output_size = list(input_size)
if is_sequence_parallel:
# if sequence parallel is enabled, the input has already been chunked by sp_size
output_size[0] *= get_ep_group().world_size
else:
output_size[0] *= get_dp_group().world_size
output = torch.empty(output_size, dtype=input.dtype, device=input.device)

torch.distributed.all_gather_into_tensor(
output, input, group=get_ep_group().device_group if is_sequence_parallel else get_dp_group().device_group)

return output
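
A minimal usage sketch of the new helper (the wrapper function below is illustrative, not part of the PR): when a pre-allocated buffer from HPUDPMetadata is supplied, dispatch_tensor gathers into it instead of allocating a fresh tensor each step; without one, it falls back to sizing the output from the DP (or EP, for sequence-parallel MoE) world size.

import torch
from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata

def gather_hidden_states(x: torch.Tensor, is_sequence_parallel: bool = False) -> torch.Tensor:
    meta = get_hpu_dp_metadata()
    if meta is not None:
        # Reuse the per-step buffer so the gathered shape stays fixed across iterations.
        return dispatch_tensor(x, meta.hidden_states_across_dp, is_sequence_parallel)
    # Fallback: dispatch_tensor allocates an output sized by the group world size.
    return dispatch_tensor(x, None, is_sequence_parallel)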