
Add 4-bit channel-wise quantization capability for the MatMulNBits op #631

Closed · wants to merge 2 commits
20 changes: 20 additions & 0 deletions onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -1614,6 +1614,26 @@ MlasQuantizeBlockwise(
}
break;

case 3072:
if (columnwise) {
BlockwiseQuantizer<T, 3072, qbits, true>::quantizeAndTranspose(
dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
} else {
BlockwiseQuantizer<T, 3072, qbits, false>::quantizeAndTranspose(
dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
}
break;

case 8192:
if (columnwise) {
BlockwiseQuantizer<T, 8192, qbits, true>::quantizeAndTranspose(
dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
} else {
BlockwiseQuantizer<T, 8192, qbits, false>::quantizeAndTranspose(
dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
}
break;

default:
// Only block sizes 16, 32, 64, 128, 256, 3072, and 8192 are supported.
break;
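For intuition, here is a minimal NumPy sketch (not the MLAS implementation; the function name and the simplified scale formula are illustrative only) of what a block size equal to K means: when one quantization block spans the entire K dimension, each output column gets a single scale, which is the channel-wise case the new 3072 and 8192 dispatch entries make possible.

```python
import numpy as np

def channelwise_int4_quant_sketch(w: np.ndarray):
    """Symmetric int4 quantization with one scale per column of a (K, N) weight.

    Conceptually equivalent to blockwise quantization with block_size == K
    (e.g. K == 3072 or 8192, the cases added above); the scale formula is a
    simplified illustration, not the exact MLAS algorithm.
    """
    k, n = w.shape
    scales = np.abs(w).max(axis=0) / 7.0                  # one scale per column, shape (n,)
    scales = np.maximum(scales, np.finfo(w.dtype).tiny)   # avoid division by zero
    q = np.clip(np.round(w / scales), -8, 7).astype(np.int8)  # int4 range [-8, 7]
    return q, scales

w = np.random.randn(3072, 16).astype(np.float32)
q, scales = channelwise_int4_quant_sketch(w)
assert q.shape == w.shape and scales.shape == (16,)
```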
28 changes: 25 additions & 3 deletions onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -191,6 +191,7 @@ def __init__(
quant_format=QuantFormat.QOperator,
op_types_to_quantize: tuple[str, ...] | None = None,
quant_axes: tuple[tuple[str, int], ...] | None = None,
channel_wised_quantize: bool = False,
):
"""
This is a class for weight only affine quantization configuration.
@@ -212,6 +213,8 @@ def __init__(
set of operator types to quantize.
quant_axes (dict[str, int], optional):
op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
channel_wised_quantize (bool, optional):
whether to use K (rows) as the block size, i.e. channel-wise quantization. Defaults to False.
"""
super().__init__(
algorithm="DEFAULT",
Expand All @@ -223,6 +226,7 @@ def __init__(
self.is_symmetric = is_symmetric
self.bits = 4
self.accuracy_level = accuracy_level
self.channel_wised_quantize = channel_wised_quantize


class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
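A hypothetical usage sketch of the new flag: the `channel_wised_quantize` argument comes from this diff, while the surrounding `MatMul4BitsQuantizer` calls, the save helper, and the model paths are assumptions about the existing tool rather than part of this change.

```python
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import (
    DefaultWeightOnlyQuantConfig,
    MatMul4BitsQuantizer,
)

model = onnx.load("model.onnx")  # placeholder path
config = DefaultWeightOnlyQuantConfig(
    block_size=128,               # ignored for MatMul weights when channel_wised_quantize is True
    is_symmetric=True,            # needed for the int4 [-8, 7] repacking path below
    channel_wised_quantize=True,  # new flag: use K (rows) as the block size
)
quantizer = MatMul4BitsQuantizer(model, algo_config=config)
quantizer.process()
quantizer.model.save_model_to_file("model_int4.onnx")  # assumed save helper on the wrapped model
```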
@@ -728,7 +732,8 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.nd
raise ValueError("Current int4 block quantization only supports 2D tensors!")
rows, cols = fp32weight.shape

block_size = self.config.block_size
# block size equals rows (K) when channel-wise quantization is enabled
block_size = rows if self.config.channel_wised_quantize else self.config.block_size
k_blocks = (rows + block_size - 1) // block_size

if self.config.quant_format == QuantFormat.QOperator:
@@ -745,6 +750,22 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.nd
quantize_matmul_4bits(
packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
)

# Re-quantize to int4 [-8, 7] when channel-wise and symmetric quantization are both enabled.
# The packed uint4 data is symmetrically quantized and offset by +8 into uint4 [0, 15]; shift it back to int4 [-8, 7].
# This saves a Sub op at inference time and matches the optimization pattern Intel NPUs use for better performance.
# Ref: https://github.com/openvinotoolkit/openvino.genai/tree/master/samples/python/text_generation#npu-support
keep_int4 = self.config.channel_wised_quantize and self.config.is_symmetric
if keep_int4:
# Unpack the uint4 quantized data, convert to int4 by subtracting 8, and repack as uint8
high_4bit_u = (packed >> 4) & 0x0F
low_4bit_u = packed & 0x0F
high_4bit_i = high_4bit_u.astype(np.int8) - 8
low_4bit_i = low_4bit_u.astype(np.int8) - 8
high_4bit_requantized = np.clip(high_4bit_i, -8, 7) & 0x0F
low_4bit_requantized = np.clip(low_4bit_i, -8, 7) & 0x0F
packed = (high_4bit_requantized << 4) | low_4bit_requantized
packed = packed.astype(np.uint8)
else:
packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
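A standalone NumPy illustration of the repacking above (the byte values are arbitrary examples): each uint8 holds two uint4 nibbles, and subtracting 8 from each nibble maps the offset-binary uint4 range [0, 15] back to two's-complement int4 [-8, 7] stored in the same nibble positions.

```python
import numpy as np

packed = np.array([0x8F, 0x07], dtype=np.uint8)    # nibble pairs (8, 15) and (0, 7)

high_u = (packed >> 4) & 0x0F                      # [ 8,  0]
low_u = packed & 0x0F                              # [15,  7]
high_i = high_u.astype(np.int8) - 8                # [ 0, -8]
low_i = low_u.astype(np.int8) - 8                  # [ 7, -1]

# Re-pack the int4 values into uint8 storage (two's-complement nibbles).
repacked = ((high_i & 0x0F).astype(np.uint8) << 4) | (low_i & 0x0F).astype(np.uint8)
print([hex(b) for b in repacked])                  # ['0x7', '0x8f']
```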
@@ -801,7 +822,7 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
kwargs["K"] = rows
kwargs["N"] = cols
kwargs["bits"] = 4
kwargs["block_size"] = self.config.block_size
kwargs["block_size"] = rows if self.config.channel_wised_quantize else self.config.block_size
if self.config.accuracy_level is not None:
kwargs["accuracy_level"] = self.config.accuracy_level

@@ -826,7 +847,8 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
)
dq_input_names.append(zp_tensor.name)
b_graph.initializer.extend([zp_tensor])
dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
rows, cols = b_ndarray.shape
dq_kwargs = {"axis": 0, "block_size": rows if self.config.channel_wised_quantize else self.config.block_size}
dq_node = onnx.helper.make_node(
"DequantizeLinear",
inputs=dq_input_names,
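For the QDQ path above, a small sketch of the DequantizeLinear configuration produced when channel-wise quantization is enabled; the tensor names and shapes are illustrative assumptions for a (K, N) MatMul weight, where block_size becomes K so the scales initializer carries one entry per output column.

```python
import onnx

K, N = 3072, 768  # example weight shape; 3072 is also one of the block sizes added in q4_dq.cpp
dq_node = onnx.helper.make_node(
    "DequantizeLinear",
    inputs=["B_quantized", "B_scales"],  # hypothetical tensor names
    outputs=["B_dequantized"],
    name="B_DequantizeLinear",
    axis=0,           # quantize along K, as in dq_kwargs above
    block_size=K,     # one block spans the whole K dimension (channel-wise)
)
# With axis=0 and block_size=K, the scales tensor has shape (1, N):
# a single scale per output channel of the (K, N) weight.
```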