Code review fixes - change to cuda_impl instead of python_impl
oelayan7 committed Jan 28, 2025
1 parent 7399fc7 commit abd494b
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 2 additions & 0 deletions csrc/fp_quantizer/fp_quantize.cpp
@@ -24,6 +24,7 @@
 
 at::Tensor quantize(torch::Tensor& out,
                     torch::Tensor& val,
+                    torch::Tensor& scale,
                     int group_size,
                     int stochastic_rounding,
                     int q_bits,
@@ -59,6 +60,7 @@ at::Tensor quantize(torch::Tensor& out,
 
 void dequantize(torch::Tensor& val,
                 torch::Tensor& val_q,
+                torch::Tensor& scale,
                 int group_size,
                 int q_mantisa_bits,
                 int q_exponent_bits)
12 changes: 6 additions & 6 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -47,7 +47,7 @@ def __init__(self, quantization_config) -> None:
         super().__init__(group_size=quantization_config.group_size)
         if fp_quant_module is None:
             fp_quant_module = FPQuantizerBuilder().load()
-        self.is_python_impl = getattr(fp_quant_module, "PYTHON_IMPL", False)
+        self.cuda_impl = getattr(fp_quant_module, "CUDA_IMPL", True)
         self.q_config = quantization_config
 
         self.orig_dtype = None
@@ -85,7 +85,7 @@ def quantize(self,
         # group_size should be the minimal number between the defined group size and number of elements in tensor.
         group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
         # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
-        if not self.is_python_impl:
+        if self.cuda_impl:
             group_size += 4
         # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
         self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
@@ -95,7 +95,7 @@
         out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
                                        q_mantisa_bits)
         if return_meta_tensor:
-            if not self.is_python_impl:
+            if self.cuda_impl:
                 data, self.scale = out.split(group_size, dim=-1)
                 data = data.contiguous().reshape(input.shape)
             else:
@@ -136,11 +136,11 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
 
-        if scale is not None and not self.is_python_impl:
+        if scale is not None and self.cuda_impl:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
             input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
-        elif scale is not None and self.is_python_impl:
+        elif scale is not None and not self.cuda_impl:
             group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
             input_q = input_q.reshape(-1, group_size)
 
@@ -174,7 +174,7 @@ def selective_dequantize(self,
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
 
-        if scale is not None and not self.is_python_impl:
+        if scale is not None and self.cuda_impl:
             assert input_q.numel() == fp_out.numel(), \
                 f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
             input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
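
Note (not part of the commit): a minimal sketch of what the cuda_impl flag means for buffer sizing in quantize() above, derived from the group_size arithmetic in this file's diff. The helper name packed_group_bytes is made up for illustration; only the arithmetic mirrors the changed code.

# Illustrative sketch only (assumed helper, not in the diff): bytes occupied by one quantized group.
def packed_group_bytes(group_size: int, q_bits: int, cuda_impl: bool) -> int:
    nbytes = (group_size * q_bits) // 8  # quantized payload, e.g. 128 fp8 values -> 128 bytes
    if cuda_impl:
        nbytes += 4  # the CUDA kernel stores the group's fp32 scale inline after the payload
    return nbytes

assert packed_group_bytes(128, 8, cuda_impl=True) == 132   # payload plus inline scale
assert packed_group_bytes(128, 8, cuda_impl=False) == 128  # scale kept in a separate tensor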
2 changes: 1 addition & 1 deletion op_builder/hpu/fp_quantizer.py
@@ -46,7 +46,7 @@ def get_quant_range(q_bits=None):
 
 
 class FPQuantizer:
-    PYTHON_IMPL = True
+    CUDA_IMPL = False
 
     @classmethod
     def selective_dequantize(cls, val_q, scales, indexes, group_size, q_mantisa_bits, q_exponent_bits):
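
Note (not part of the commit): how the renamed flag is consumed on the Python side. quantize.py reads it with getattr(fp_quant_module, "CUDA_IMPL", True), so a backend module that does not define the attribute (such as the compiled CUDA extension) defaults to True, while this HPU implementation opts out explicitly. The two stand-in classes below are hypothetical.

# Hypothetical stand-ins, for illustration only.
class _CudaExtensionStub:   # does not define CUDA_IMPL, so it is treated as the CUDA path
    pass

class _HpuQuantizerStub:    # mirrors the FPQuantizer class above
    CUDA_IMPL = False

assert getattr(_CudaExtensionStub, "CUDA_IMPL", True) is True
assert getattr(_HpuQuantizerStub, "CUDA_IMPL", True) is False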
6 changes: 2 additions & 4 deletions tests/unit/ops/fp_quantizer/test_fp8_gemm.py
@@ -26,10 +26,8 @@
 def test_fp_quant(dtype, q_bits, M):
     device_name = get_accelerator().device_name()
     quantization_group_size = 128
-
-    quant_config = QuantizationConfig()
-    quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
-    quant_config.group_size = quantization_group_size
+    quant_config = QuantizationConfig(q_dtype=FPQuantizerBuilder.get_default_quant_dtype(),
+                                      group_size=quantization_group_size)
     fpq = FP_Quantize(quantization_config=quant_config)
 
     N = 8192
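
Note (not part of the commit): a hedged usage sketch mirroring the updated test, constructing QuantizationConfig with keyword arguments and running a quantize/dequantize round trip. Import paths, tensor shape, and dtype are assumptions, and it assumes quantize(..., return_meta_tensor=True) returns the quantized data together with its scale, as the return_meta_tensor branch above suggests.

import torch
from deepspeed.accelerator import get_accelerator                        # path assumed
from deepspeed.ops.fp_quantizer import FP_Quantize, QuantizationConfig   # path assumed
from deepspeed.ops.op_builder import FPQuantizerBuilder                  # path assumed

quant_config = QuantizationConfig(q_dtype=FPQuantizerBuilder.get_default_quant_dtype(),
                                  group_size=128)
fpq = FP_Quantize(quantization_config=quant_config)

x = torch.randn(1024, 1024, dtype=torch.bfloat16, device=get_accelerator().device_name())
x_q, scale = fpq.quantize(x, q_bits=8, return_meta_tensor=True)  # assumed (data, scale) return
x_dq = fpq.dequantize(x_q, q_bits=8, scale=scale)                # signature as shown in the diff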
