diff --git a/vllm_gaudi/ops/hpu_compressed_tensors.py b/vllm_gaudi/ops/hpu_compressed_tensors.py
index ea983f6e4..518b9d749 100644
--- a/vllm_gaudi/ops/hpu_compressed_tensors.py
+++ b/vllm_gaudi/ops/hpu_compressed_tensors.py
@@ -14,8 +14,13 @@
     PackedvLLMParameter, RowvLLMParameter)
 from vllm.model_executor.layers.quantization.compressed_tensors import (compressed_tensors)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
-    CompressedTensorsLinearMethod as OrigCompressedTensorsLinearMethod, CompressedTensorsConfig,
-    CompressedTensorsMoEMethod, CompressedTensorsKVCacheMethod)
+    CompressedTensorsLinearMethod as OrigCompressedTensorsLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+    CompressedTensorsMoEMethod,
+    CompressedTensorsKVCacheMethod,
+    SparsityCompressionConfig,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors import (compressed_tensors_moe)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( # noqa: E501
     CompressedTensorsScheme, CompressedTensorsWNA16)
@@ -807,6 +812,37 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
 class HPUCompressedTensorsConfig(CompressedTensorsConfig):
 
+    def __init__(
+        self,
+        target_scheme_map: dict[str, Any],
+        ignore: list[str],
+        quant_format: str,
+        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
+        sparsity_ignore_list: list[str],
+        kv_cache_scheme: dict[str, Any] | None = None,
+        config: dict[str, Any] | None = None,
+        transform_config: dict[str, Any] | None = None,
+        total_num_heads: int | None = None,
+        total_num_kv_heads: int | None = None,
+    ):
+        super().__init__(
+            target_scheme_map,
+            ignore,
+            quant_format,
+            sparsity_scheme_map,
+            sparsity_ignore_list,
+            kv_cache_scheme,
+            config,
+            transform_config,
+            total_num_heads,
+            total_num_kv_heads,
+        )
+        # Fix https://github.com/vllm-project/vllm/pull/30141
+        # LLMC overrides the `kv_cache_dtype` to 'fp8', while HPU uses 'fp8_inc'.
+        if getattr(self, "kv_cache_scheme", None) is not None:
+            self.kv_cache_dtype = "fp8_inc"
+            self.kv_cache_scheme = None
+
     def get_quant_method(
         self,
         layer: torch.nn.Module,
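
For reference, below is a minimal, self-contained sketch (not the real vLLM or vllm-gaudi classes) of the behavior the new `__init__` override adds: when a `kv_cache_scheme` is present (e.g. an LLMC checkpoint forcing `kv_cache_dtype` to 'fp8'), the HPU config switches the dtype to 'fp8_inc' and clears the scheme so the INC path is used. The `_BaseConfig` and `_HPUConfig` names are illustrative placeholders, and the base class's dtype assignment is an assumption made only to show the override taking effect.

```python
from typing import Any


class _BaseConfig:
    """Stand-in for CompressedTensorsConfig (illustrative only)."""

    def __init__(self, kv_cache_scheme: dict[str, Any] | None = None):
        self.kv_cache_scheme = kv_cache_scheme
        # Assumed LLMC-style behavior: a kv-cache scheme implies 'fp8'.
        self.kv_cache_dtype = "fp8" if kv_cache_scheme is not None else "auto"


class _HPUConfig(_BaseConfig):
    """Stand-in for HPUCompressedTensorsConfig (illustrative only)."""

    def __init__(self, kv_cache_scheme: dict[str, Any] | None = None):
        super().__init__(kv_cache_scheme)
        # Same override as in the diff above: HPU expects 'fp8_inc', not 'fp8'.
        if getattr(self, "kv_cache_scheme", None) is not None:
            self.kv_cache_dtype = "fp8_inc"
            self.kv_cache_scheme = None


cfg = _HPUConfig(kv_cache_scheme={"type": "float", "num_bits": 8})
assert cfg.kv_cache_dtype == "fp8_inc" and cfg.kv_cache_scheme is None
```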