Skip to content

Conversation

@xinyu-intel
Copy link
Contributor

depends on #680

Copilot AI review requested due to automatic review settings December 4, 2025 13:12
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements FP8 hidden state dispatching in INC (Intel Neural Compressor) for data parallel (DP) execution. The main purpose is to optimize MoE (Mixture of Experts) layer communication by dispatching FP8-quantized hidden states and routing information across DP ranks, rather than full precision tensors.

Key changes:

  • Replaces router logits dispatching with topk IDs and weights dispatching for more efficient communication
  • Adds FP8 dtype support for hidden states when INC quantization is enabled
  • Introduces dispatch functions that are passed to MoE operators for flexible tensor distribution

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
vllm_gaudi/v1/worker/hpu_dp_utils.py Adds dispatch_tensor and dispatch_hidden_states functions; updates HPUDPMetadata to store topk_ids and topk_weights instead of router_logits; adds FP8 dtype detection for INC quantization
vllm_gaudi/ops/hpu_fused_moe.py Integrates dispatch functions into unquantized MoE processing; dispatches topk_ids and topk_weights when DP is enabled
vllm_gaudi/ops/hpu_fp8.py Integrates dispatch functions into FP8 MoE processing; dispatches topk_ids and topk_weights when DP is enabled
vllm_gaudi/extension/ops.py Adds dispatch_fn parameter to VllmMixtureOfExpertsOp, VllmMixtureOfExpertsOpFP8, and VllmMixtureOfExpertsOpFP8PerChannel constructors with _get_dispatch_func accessor method
vllm_gaudi/distributed/device_communicators/hpu_communicator.py Removes dispatch implementation, delegating to plugin FusedMoEMethod for better performance

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +37 to +38
quant_config = os.getenv("QUANT_CONFIG", None) is not None
is_quant_with_inc = vllm_config.model_config.quantization == "inc" or quant_config
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name quant_config is ambiguous as it suggests it holds configuration data, but it actually holds a boolean indicating whether the environment variable exists. Consider renaming to has_quant_config or quant_config_enabled for clarity.

Suggested change
quant_config = os.getenv("QUANT_CONFIG", None) is not None
is_quant_with_inc = vllm_config.model_config.quantization == "inc" or quant_config
quant_config_enabled = os.getenv("QUANT_CONFIG", None) is not None
is_quant_with_inc = vllm_config.model_config.quantization == "inc" or quant_config_enabled

Copilot uses AI. Check for mistakes.
Comment on lines +507 to +508
fn = self.dispatch_func
return fn
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intermediate variable fn is unnecessary. The method can directly return self.dispatch_func, simplifying the code and improving readability.

Suggested change
fn = self.dispatch_func
return fn
return self.dispatch_func

Copilot uses AI. Check for mistakes.
Comment on lines +958 to 960
fn = self.dispatch_func
return fn

Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intermediate variable fn is unnecessary. The method can directly return self.dispatch_func, simplifying the code and improving readability.

Suggested change
fn = self.dispatch_func
return fn
return self.dispatch_func

Copilot uses AI. Check for mistakes.
Comment on lines +1030 to +1031
fn = self.dispatch_func
return fn
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intermediate variable fn is unnecessary. The method can directly return self.dispatch_func, simplifying the code and improving readability.

Suggested change
fn = self.dispatch_func
return fn
return self.dispatch_func

Copilot uses AI. Check for mistakes.
@github-actions
Copy link

github-actions bot commented Dec 4, 2025

🚧 CI Blocked

The main CI workflow was not started for the following reason:

This is a Draft PR. Please mark it as 'Ready for Review' to trigger the CI.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant