diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp
index 1a887b50e1a3..903d84270d32 100644
--- a/csrc/fp_quantizer/fp_quantize.cpp
+++ b/csrc/fp_quantizer/fp_quantize.cpp
@@ -24,6 +24,7 @@
 at::Tensor quantize(torch::Tensor& out,
                     torch::Tensor& val,
+                    torch::Tensor& scale,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
@@ -59,6 +60,7 @@ at::Tensor quantize(torch::Tensor& out,
 void dequantize(torch::Tensor& val,
                 torch::Tensor& val_q,
+                torch::Tensor& scale,
                 int group_size,
                 int q_mantisa_bits,
                 int q_exponent_bits)
diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py
index 1586f220907e..69c21eaf693b 100644
--- a/deepspeed/ops/fp_quantizer/quantize.py
+++ b/deepspeed/ops/fp_quantizer/quantize.py
@@ -47,7 +47,7 @@ def __init__(self, quantization_config) -> None:
         super().__init__(group_size=quantization_config.group_size)
         if fp_quant_module is None:
             fp_quant_module = FPQuantizerBuilder().load()
-        self.is_python_impl = getattr(fp_quant_module, "PYTHON_IMPL", False)
+        self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True)
         self.q_config = quantization_config
         self.orig_dtype = None
@@ -85,7 +85,7 @@ def quantize(self,
         # group_size should be the minimal number between the defined group size and number of elements in tensor.
         group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
         # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
-        if not self.is_python_impl:
+        if self.cuda_impl:
             group_size += 4
         # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
         self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
@@ -95,7 +95,7 @@ def quantize(self,
         out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
                                        q_mantisa_bits)
         if return_meta_tensor:
-            if not self.is_python_impl:
+            if self.cuda_impl:
                 data, self.scale = out.split(group_size, dim=-1)
                 data = data.contiguous().reshape(input.shape)
             else:
@@ -136,11 +136,11 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
-        if scale is not None and not self.is_python_impl:
+        if scale is not None and self.cuda_impl:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
             input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
-        elif scale is not None and self.is_python_impl:
+        elif scale is not None and not self.cuda_impl:
             group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
             input_q = input_q.reshape(-1, group_size)
@@ -174,7 +174,7 @@ def selective_dequantize(self,
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
-        if scale is not None and not self.is_python_impl:
+        if scale is not None and self.cuda_impl:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
             input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
diff --git a/op_builder/hpu/fp_quantizer.py b/op_builder/hpu/fp_quantizer.py
index b00cb0cc43cd..c74affb55045 100644
--- a/op_builder/hpu/fp_quantizer.py
+++ b/op_builder/hpu/fp_quantizer.py
@@ -46,7 +46,7 @@ def get_quant_range(q_bits=None):
 class FPQuantizer:
-    PYTHON_IMPL = True
+    CUDA_IMPL = False
     @classmethod
     def selective_dequantize(cls, val_q, scales, indexes, group_size, q_mantisa_bits, q_exponent_bits):
diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
index ee7c5bc2d7f1..d068a05b77bb 100644
--- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
+++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
@@ -26,10 +26,8 @@ def test_fp_quant(dtype, q_bits, M):
     device_name = get_accelerator().device_name()
     quantization_group_size = 128
-
-    quant_config = QuantizationConfig()
-    quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
-    quant_config.group_size = quantization_group_size
+    quant_config = QuantizationConfig(q_dtype=FPQuantizerBuilder.get_default_quant_dtype(),
+                                      group_size=quantization_group_size)
     fpq = FP_Quantize(quantization_config=quant_config)
     N = 8192
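
Usage note (not part of the patch): the sketch below mirrors the constructor-kwarg `QuantizationConfig` style the updated test uses and round-trips a tensor through `FP_Quantize` on the CUDA path. Import paths and the exact return shape of `quantize(..., return_meta_tensor=True)` are assumptions inferred from this diff and the surrounding test, not a verified reference.

```python
# Hypothetical round-trip sketch; imports and call signatures are inferred from
# the diff above and the updated test, so treat them as assumptions.
import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.fp_quantizer import FP_Quantize, QuantizationConfig
from deepspeed.ops.op_builder import FPQuantizerBuilder

quant_config = QuantizationConfig(q_dtype=FPQuantizerBuilder.get_default_quant_dtype(),
                                  group_size=128)
fpq = FP_Quantize(quantization_config=quant_config)

x = torch.rand(32, 8192, dtype=torch.bfloat16, device=get_accelerator().device_name())

# Quantize to FP8; with return_meta_tensor=True the per-group scales come back
# as a separate tensor (on the CUDA path they are split off the packed buffer).
x_q, scale = fpq.quantize(x, q_bits=8, return_meta_tensor=True)

# Dequantize back to the original dtype, passing the scale explicitly
# (exercises the `scale is not None and self.cuda_impl` branch above).
x_dq = fpq.dequantize(x_q, q_bits=8, scale=scale)
```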