Load KV scales for FP8 MLA #763
Conversation
Pull request overview
This PR adds support for loading KV (Key-Value) cache scales for FP8 Multi-Head Latent Attention (MLA) models. The implementation introduces specialized KV cache quantization methods for both MLA and MHA (Multi-Head Attention) architectures, along with custom parameter name remapping logic to handle various checkpoint formats.
Key Changes:
- Added specialized `HPUCompressedTensorsKVCacheMethod` implementations for MLA and MHA attention mechanisms
- Implemented custom parameter name remapping for FP8 k/v_scale parameters from different checkpoint formats (see the sketch after this list)
- Updated MoE method to handle the additional return value from `select_experts()`
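For context, here is a minimal sketch of what such k/v_scale name remapping can look like. The function name, checkpoint suffixes, and canonical names below are illustrative assumptions, not the PR's exact implementation.

```python
# Illustrative sketch: remap KV-scale parameter names found in different FP8
# checkpoint formats onto one canonical attention-layer name. The suffixes
# below are assumed examples, not the PR's exact mapping table.
def remap_kv_scale_name(name: str) -> str:
    """Return a canonical parameter name for known KV-scale spellings."""
    remapping = {
        ".self_attn.k_proj.output_scale": ".self_attn.attn.k_scale",
        ".self_attn.v_proj.output_scale": ".self_attn.attn.v_scale",
        ".self_attn.kv_scale": ".self_attn.attn.kv_scale",
    }
    for old_suffix, new_suffix in remapping.items():
        if name.endswith(old_suffix):
            return name[: -len(old_suffix)] + new_suffix
    # Unrecognized names are passed through unchanged and handled by the caller.
    return name


# Example: a scale exported under one checkpoint convention maps onto k_scale.
assert remap_kv_scale_name(
    "model.layers.0.self_attn.k_proj.output_scale"
) == "model.layers.0.self_attn.attn.k_scale"
```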
Overall, I think the PR is functionally ready. @yiliu30, my advice before merging is to
We do need to expose such a flag, since Gaudi2 FP8 E4M3 is not standard. The flag is there to make users aware of whether the quantized model they want to deploy was quantized on Gaudi2 or Gaudi3.
I would assume all FP8 models running on Gaudi2 would need `VLLM_SCALE_ADJUSTMENT=True`. Yet in this case, the UT is running on Gaudi2 while set
```python
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Process KV cache scales for cross-platform FP8 quantization compatibility.

    Scale Adjustment Scenarios:
    | No | Quant Plat   | Serve Plat | Scale Adj | FP8 Max Quant | FP8 Max Deploy |
    |----|--------------|------------|-----------|---------------|----------------|
    | 1  | G2           | G2         | OFF       | 240           | 240            |
    | 2  | G3/Other GPU | G2         | ON        | 448           | 240            |
    | 3  | G3/Other GPU | G3         | OFF       | 448           | 448            |
    """
```
Hi @xuechendi, I’ve added more detailed docs for VLLM_SCALE_ADJUSTMENT here.
Whether we enable or disable VLLM_SCALE_ADJUSTMENT depends on the platforms used for quantization and deployment.
If we quantize the model on G2 with AutoRound, we use 240 as the FP8 max. When deploying it on G2, no adjustment is required. The model used in this PR follows this scenario (1).
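To make scenario 2 concrete, here is a minimal sketch of the kind of adjustment implied by the table above, assuming the scale is rescaled by the ratio of the two FP8 maxima. The constants, function, and environment-variable handling are illustrative, not the plugin's actual code.

```python
import os

import torch

# FP8 E4M3 maxima (illustrative constants): 448 on Gaudi3 / other GPUs,
# 240 on Gaudi2, whose E4M3 variant is not the standard float8_e4m3fn.
FP8_MAX_STANDARD = 448.0
FP8_MAX_GAUDI2 = 240.0


def adjust_kv_scale(scale: torch.Tensor) -> torch.Tensor:
    """Rescale a checkpoint KV scale when the quantization platform's FP8 max
    (448) differs from the Gaudi2 deployment max (240), i.e. scenario 2."""
    if os.environ.get("VLLM_SCALE_ADJUSTMENT", "false").lower() in ("1", "true"):
        # scale = amax / fp8_max, so moving from max 448 to max 240 grows the
        # scale by 448 / 240 to keep the dequantized range unchanged.
        return scale * (FP8_MAX_STANDARD / FP8_MAX_GAUDI2)
    # Scenarios 1 and 3: quantization and deployment maxima match, no change.
    return scale
```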
OK, is there existing info in the quantized model name or config indicating whether the weights were quantized on G2 or not?
How does a user tell whether an explicit `VLLM_SCALE_ADJUSTMENT=true/false` is needed for this model on G2?
Hi @xuechendi, for 2: given Scenario 1, there’s no need to adjust the scale, which could save considerable model loading time for large models like DS R1. I’d prefer to keep it as is, and we could clarify
Thanks, @yiliu30, one last comment: it would be great if we could check whether the model was quantized on G2 or G3 from the model config; then we could add a check in the code to decide whether VLLM_SCALE_ADJUSTMENT is needed. Functionally, I think the code works, but with so many possible scenarios it might be bad for the user experience.
Hi @xuechendi, good idea! We could record that in the model config, e.g. `"fp8_dtype_flavor": "float8_e4m3fnuz"`.
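If such a key were added, a small check at load time could pick the right behavior automatically. The `fp8_dtype_flavor` field, the helper, and the platform string below are hypothetical, following the suggestion above, not existing API.

```python
# Hypothetical sketch: derive the need for scale adjustment from a proposed
# "fp8_dtype_flavor" entry in the checkpoint's quantization config.
def needs_scale_adjustment(quant_config: dict, serving_platform: str) -> bool:
    # Assume "float8_e4m3fnuz" marks a checkpoint quantized on Gaudi2 (max 240)
    # and anything else a standard-E4M3 checkpoint (max 448).
    quantized_on_gaudi2 = quant_config.get("fp8_dtype_flavor") == "float8_e4m3fnuz"
    serving_on_gaudi2 = serving_platform == "gaudi2"
    # Only scenario 2 (quantized with max 448, served where the max is 240)
    # requires adjusting the scales.
    return serving_on_gaudi2 and not quantized_on_gaudi2


# Example: a checkpoint quantized on G3 and deployed on G2 needs adjustment.
assert needs_scale_adjustment({"fp8_dtype_flavor": "float8_e4m3fn"}, "gaudi2")
assert not needs_scale_adjustment({"fp8_dtype_flavor": "float8_e4m3fnuz"}, "gaudi2")
```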
Closed by accident, reopening.
- Lazy mode with scale=1.0

```
# vllm (pretrained=/mnt/disk5/hf_models/DeepSeek-V2-Lite-Chat-FP8_STATIC-fp8-kv-2,tensor_parallel_size=8,enable_expert_parallel=True,max_model_len=4096,max_num_seqs=64,gpu_memory_utilization=0.85,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,max_num_batched_tokens=32768,kv_cache_dtype=fp8_inc), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 64
# |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
# |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
# |gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6422|±  |0.0132|
# |     |       |strict-match    |     5|exact_match|↑  |0.6399|±  |0.0132|
```

- Lazy mode with the scales loaded from the checkpoint, which were captured during the calibration process.

```bash
vllm (pretrained=/mnt/disk5/hf_models/DeepSeek-V2-Lite-Chat-FP8_STATIC-fp8-kv-2,tensor_parallel_size=8,enable_expert_parallel=True,max_model_len=4096,max_num_seqs=64,gpu_memory_utilization=0.85,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,max_num_batched_tokens=32768,kv_cache_dtype=fp8_inc), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6513|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.6505|±  |0.0131|
```

Test model: https://huggingface.co/INC4AI/DeepSeek-V2-Lite-Chat-BF16-FP8-STATIC-FP8-KV-TEST-ONLY

cc @hshen14 @thuang6

---------

Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: Jin, Youzhi <[email protected]>