
Conversation

@yiliu30
Contributor

@yiliu30 yiliu30 commented Dec 29, 2025

  • lazy mode w/ scale=1.0
# vllm (pretrained=/mnt/disk5/hf_models/DeepSeek-V2-Lite-Chat-FP8_STATIC-fp8-kv-2,tensor_parallel_size=8,enable_expert_parallel=True,max_model_len=4096,max_num_seqs=64,gpu_memory_utilization=0.85,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,max_num_batched_tokens=32768,kv_cache_dtype=fp8_inc), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 64
# |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
# |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
# |gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6422|±  |0.0132|
# |     |       |strict-match    |     5|exact_match|↑  |0.6399|±  |0.0132|
  • Lazy mode with the scale loaded from the checkpoint, which was captured during the calibration process.
vllm (pretrained=/mnt/disk5/hf_models/DeepSeek-V2-Lite-Chat-FP8_STATIC-fp8-kv-2,tensor_parallel_size=8,enable_expert_parallel=True,max_model_len=4096,max_num_seqs=64,gpu_memory_utilization=0.85,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,max_num_batched_tokens=32768,kv_cache_dtype=fp8_inc), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6513|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.6505|±  |0.0131|

Test model: https://huggingface.co/INC4AI/DeepSeek-V2-Lite-Chat-BF16-FP8-STATIC-FP8-KV-TEST-ONLY

cc @hshen14 @thuang6

Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds support for loading KV (Key-Value) cache scales for FP8 Multi-Head Latent Attention (MLA) models. The implementation introduces specialized KV cache quantization methods for both MLA and MHA (Multi-Head Attention) architectures, along with custom parameter name remapping logic to handle various checkpoint formats.

Key Changes:

  • Added specialized HPUCompressedTensorsKVCacheMethod implementations for MLA and MHA attention mechanisms
  • Implemented custom parameter name remapping for FP8 k/v_scale parameters from different checkpoint formats
  • Updated MoE method to handle additional return value from select_experts()
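
For illustration only, here is a minimal sketch of the kind of k/v_scale name remapping described in the second bullet above; the helper name and the substitution patterns are assumptions, not the PR's actual code:

```python
# Hypothetical sketch (not the PR's implementation): map checkpoint-specific
# FP8 k/v scale names onto the attention module's k_scale/v_scale parameters.
_KV_SCALE_SUFFIX_MAP = {
    ".self_attn.k_proj.output_scale": ".self_attn.attn.k_scale",
    ".self_attn.v_proj.output_scale": ".self_attn.attn.v_scale",
    ".self_attn.k_scale": ".self_attn.attn.k_scale",
    ".self_attn.v_scale": ".self_attn.attn.v_scale",
}


def remap_kv_scale_name(name: str) -> str:
    """Return the serving-side parameter name for a checkpoint k/v scale entry."""
    for old_suffix, new_suffix in _KV_SCALE_SUFFIX_MAP.items():
        if name.endswith(old_suffix):
            return name[: -len(old_suffix)] + new_suffix
    return name  # not a k/v scale entry; leave the name untouched


# e.g. "model.layers.0.self_attn.k_scale" -> "model.layers.0.self_attn.attn.k_scale"
```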


Signed-off-by: yiliu30 <[email protected]>
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@yiliu30 yiliu30 changed the title Load KV scales for FP8 MLA → [WIP]Load KV scales for FP8 MLA on Dec 30, 2025
@github-actions

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

@github-actions

github-actions bot commented Jan 4, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
@github-actions

github-actions bot commented Jan 4, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

@xuechendi
Collaborator

Overall, I think the PR is functionally ready. @yiliu30, my advice before merging is to

  1. add a new FP8 document to better explain the current approach vs. the previous INC approach and other OOB Red Hat compressed-tensors FP8 KV scaling
  2. try to avoid requiring users to explicitly set VLLM_SCALE_ADJUSTMENT=false on Gaudi2 so we don't confuse users

@hshen14

hshen14 commented Jan 5, 2026

Overall, I think the PR is functionally ready. @yiliu30, my advice before merging is to

  1. add a new FP8 document to better explain the current approach vs. the previous INC approach and other OOB Red Hat compressed-tensors FP8 KV scaling
  2. try to avoid requiring users to explicitly set VLLM_SCALE_ADJUSTMENT=false on Gaudi2 so we don't confuse users

We do need to expose such a flag since Gaudi2 FP8 E4M3 is not standard. The flag is to let users be aware of whether the quantized model they want to deploy was quantized on Gaudi2 or Gaudi3.

@xuechendi
Collaborator

We do need to expose such a flag since Gaudi2 FP8 E4M3 is not standard. The flag is to let users be aware of whether the quantized model they want to deploy was quantized on Gaudi2 or Gaudi3.

I would assume all FP8 models running on Gaudi2 would need VLLM_SCALE_ADJUSTMENT=True. In this case, however, the UT runs on Gaudi2 with VLLM_SCALE_ADJUSTMENT=false set, so I am a little confused.

@github-actions

github-actions bot commented Jan 6, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Comment on lines 755 to 763
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Process KV cache scales for cross-platform FP8 quantization compatibility.
    Scale Adjustment Scenarios:
    | No | Quant Plat   | Serve Plat | Scale Adj | FP8 Max Quant | FP8 Max Deploy |
    |----|--------------|------------|-----------|---------------|----------------|
    | 1  | G2           | G2         | OFF       | 240           | 240            |
    | 2  | G3/Other GPU | G2         | ON        | 448           | 240            |
    | 3  | G3/Other GPU | G3         | OFF       | 448           | 448            |
Contributor Author

@yiliu30 yiliu30 Jan 6, 2026


Hi @xuechendi, I’ve added more detailed docs for VLLM_SCALE_ADJUSTMENT here.
Whether we enable or disable VLLM_SCALE_ADJUSTMENT depends on the platforms used for quantization and deployment.
If we quantize the model on G2 with AutoRound, we use 240 as the FP8 max. When deploying it on G2, no adjustment is required. The model used in this PR follows scenario (1).

cc @hshen14 @thuang6

Collaborator


OK, is there existing info in the quantized model name or config indicating whether the weights were quantized on G2 or not?
How does a user tell whether explicitly setting VLLM_SCALE_ADJUSTMENT=true/false is needed for this model on G2?

@yiliu30
Contributor Author

yiliu30 commented Jan 6, 2026

Overall, I think the PR is functionally ready. @yiliu30, my advice before merging is to

  1. add a new FP8 document to better explain the current approach vs. the previous INC approach and other OOB Red Hat compressed-tensors FP8 KV scaling
  2. try to avoid requiring users to explicitly set VLLM_SCALE_ADJUSTMENT=false on Gaudi2 so we don't confuse users

Hi @xuechendi ,
For 1, how about proceeding with merging this PR? I’ll prepare the documentation in a separate PR later.

For 2, given scenario 1 there is no need to adjust the scale, which can save considerable model loading time for large models like DS R1. I'd prefer to keep it as is, and we could clarify VLLM_SCALE_ADJUSTMENT like this: enable it when quantizing a model with the standard FP8 range and deploying it on G2. What do you think?

| No | Quant Plat   | Serve Plat | Scale Adj | FP8 Max Quant | FP8 Max Deploy |
|----|--------------|------------|-----------|---------------|----------------|
| 1  | G2           | G2         | OFF       | 240           | 240            |
| 2  | G3/Other GPU | G2         | ON        | 448           | 240            |
| 3  | G3/Other GPU | G3         | OFF       | 448           | 448            |
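
For reference, a minimal sketch of how the scenario table above could translate into code; the function name, constants, and exact rescale direction are illustrative assumptions, not the PR's implementation:

```python
import torch

GAUDI2_FP8_MAX = 240.0    # Gaudi2 FP8 E4M3 max (non-standard)
STANDARD_FP8_MAX = 448.0  # OCP FP8 E4M3 max (Gaudi3 / other GPUs)


def maybe_adjust_kv_scale(scale: torch.Tensor, scale_adjustment: bool) -> torch.Tensor:
    """Sketch of scenario 2: scales calibrated against a 448 FP8 max are
    stretched so the quantized values fit the 240 max at deployment time."""
    if scale_adjustment:
        return scale * (STANDARD_FP8_MAX / GAUDI2_FP8_MAX)
    return scale  # scenarios 1 and 3: serve with the checkpoint scales as-is
```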

cc @hshen14 @thuang6

Signed-off-by: yiliu30 <[email protected]>
@github-actions

github-actions bot commented Jan 6, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

@xuechendi
Collaborator

Thanks, @yiliu30. One last comment: it would be great if we could check in the model config whether the model was quantized on G2 or G3; then we could add a check in the code to decide whether VLLM_SCALE_ADJUSTMENT is needed. Functionally, I think the code works, but with so many possible scenarios, it might hurt the user experience.

Signed-off-by: yiliu30 <[email protected]>
@yiliu30
Contributor Author

yiliu30 commented Jan 7, 2026

Thanks, @yiliu30. One last comment: it would be great if we could check in the model config whether the model was quantized on G2 or G3; then we could add a check in the code to decide whether VLLM_SCALE_ADJUSTMENT is needed. Functionally, I think the code works, but with so many possible scenarios, it might hurt the user experience.

Hi @xuechendi, good idea!
I added a new attribute to the quant config and now update VLLM_SCALE_ADJUSTMENT automatically:

 "fp8_dtype_flavor": "float8_e4m3fnuz",

https://huggingface.co/INC4AI/DeepSeek-V2-Lite-Chat-FP8-STATIC-ATTN-TEST-ONLY/blob/main/config.json#L77
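
A rough sketch of how such a config attribute could drive the automatic decision; the key location under quantization_config, the function name, and the env-var handling below are assumptions, not the PR's exact code:

```python
import json
import os


def auto_set_scale_adjustment(config_path: str, running_on_gaudi2: bool) -> None:
    """Enable scale adjustment only when a standard-FP8 (448-max) checkpoint
    is deployed on Gaudi2 (240-max); leave any explicit user setting alone."""
    with open(config_path) as f:
        # Assumption: the flavor is stored under "quantization_config" in config.json.
        quant_cfg = json.load(f).get("quantization_config", {})
    quantized_for_240_max = quant_cfg.get("fp8_dtype_flavor") == "float8_e4m3fnuz"
    needs_adjustment = running_on_gaudi2 and not quantized_for_240_max
    os.environ.setdefault("VLLM_SCALE_ADJUSTMENT", "true" if needs_adjustment else "false")
```

Using setdefault here means an explicit VLLM_SCALE_ADJUSTMENT set by the user would still take precedence over the automatic choice.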

Signed-off-by: yiliu30 <[email protected]>
@github-actions

github-actions bot commented Jan 7, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

@yiliu30
Contributor Author

yiliu30 commented Jan 7, 2026

Closed by accident, reopening.

@github-actions

github-actions bot commented Jan 7, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
b3a2bdf1ac90748d58bf8c05f8d0095ede5c7eca

@xuechendi xuechendi merged commit 8699778 into vllm-project:main Jan 7, 2026
101 checks passed
jinyouzhi pushed a commit to jinyouzhi/vllm-gaudi that referenced this pull request Jan 14, 2026