
Sub-channel quantized type implementation #120172

Merged
merged 7 commits into llvm:main on Mar 23, 2025

Conversation

sdasgup3
Contributor

This is an implementation for RFC: Supporting Sub-Channel Quantization in MLIR.

In order to make the review process easier, the PR has been divided into the following commit labels:

  1. Add implementation for sub-channel type: Includes the class design for UniformQuantizedSubChannelType, printer/parser and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered.
  2. Add implementation for sub-channel type: Lowering of quant.qcast and quant.dcast operations to Linalg operations.
  3. Adding C/Python APIs: We first define the C-APIs and build the Python-APIs on top of those.
  4. Add pass to normalize generic ....: This pass normalizes sub-channel quantized types to per-tensor or per-axis types, if possible.

A design note:

  • Explicitly storing the quantized_dimensions, even when they can be derived for ranked tensors.
    While it's possible to infer quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked
    data tensors (ref for background), there are cases where this can lead to ambiguity and issues with round-tripping.
Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>

The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. That inference is technically correct, since the block size for axis 0 is a degenerate case (equal to the dimension size), but it can cause problems with round-tripping. Therefore, even for ranked tensors, we are explicitly storing the quantized dimensions. Suggestions welcome!
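As an illustration (reusing the placeholder scale/zero-point names from the example above), the following two types carry scales of the same shape [1, 2] yet are spelled differently; if the quantized dimensions were inferred from that shape rather than stored, the printer could not reliably round-trip them:

tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:2}, {{s00:z00, s01:z01}}>>
tensor<2x4x!quant.uniform<i8:f32:{1:2}, {{s00:z00, s01:z01}}>>

In the second spelling, the block size for axis 0 is unspecified and defaults to the full dimension size (2), so both types describe the same blocking; only the explicitly stored quantized dimensions keep the two spellings distinct.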

PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush.

@llvmbot
Member

llvmbot commented Dec 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-quant

Author: Sandeep Dasgupta (sdasgup3)

Patch is 127.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120172.diff

25 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Quant.h (+41)
  • (modified) mlir/include/mlir/Dialect/Quant/IR/QuantBase.td (+183-9)
  • (modified) mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td (+21-9)
  • (modified) mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h (+131)
  • (modified) mlir/include/mlir/Dialect/Quant/Transforms/Passes.td (+33)
  • (modified) mlir/lib/Bindings/Python/DialectQuant.cpp (+74)
  • (modified) mlir/lib/CAPI/Dialect/Quant.cpp (+56)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp (+1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantOps.cpp (+123-24)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+119-2)
  • (modified) mlir/lib/Dialect/Quant/IR/TypeDetail.h (+122)
  • (modified) mlir/lib/Dialect/Quant/IR/TypeParser.cpp (+278-40)
  • (modified) mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp (+197-84)
  • (added) mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp (+179)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi (+21-1)
  • (modified) mlir/test/CAPI/quant.c (+126)
  • (modified) mlir/test/Dialect/Quant/Bytecode/types.mlir (+9)
  • (modified) mlir/test/Dialect/Quant/invalid.mlir (+68)
  • (modified) mlir/test/Dialect/Quant/lower-quant-ops.mlir (+64)
  • (added) mlir/test/Dialect/Quant/normalize-quant-types.mlir (+51)
  • (modified) mlir/test/Dialect/Quant/ops.mlir (+19)
  • (modified) mlir/test/Dialect/Quant/parse-uniform-invalid.mlir (+95-5)
  • (modified) mlir/test/Dialect/Quant/parse-uniform.mlir (+18)
  • (modified) mlir/test/python/dialects/quant.py (+46)
diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
index a7d98dc3c1a775..dc0989e53344ea 100644
--- a/mlir/include/mlir-c/Dialect/Quant.h
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
 MLIR_CAPI_EXPORTED bool
 mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// UniformQuantizedSubChannelType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedSubChannel.
+MLIR_CAPI_EXPORTED bool
+mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
+
+/// Creates a UniformQuantizedSubChannelType with the given parameters.
+///
+/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
+/// DenseElementsAttrs.  `quantizedDimensions` and `blockSizes`
+/// point to `blockSizeInfoLength` number of elements, describing respectively
+/// the quantization axis and corresponding block size.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
+    intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
+    int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the number of block sizes provided in type.
+MLIR_CAPI_EXPORTED intptr_t
+mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
+
+/// Returns the quantized dimension at the given position.
+MLIR_CAPI_EXPORTED int32_t
+mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
+                                                        intptr_t pos);
+
+/// Returns the block size at the given position.
+MLIR_CAPI_EXPORTED int64_t
+mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);
+
+/// Returns the scales of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);
+
+/// Returns the zero-points of the quantized type.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
+
 //===---------------------------------------------------------------------===//
 // CalibratedQuantizedType
 //===---------------------------------------------------------------------===//
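The snippet below is a minimal usage sketch of these entry points, not part of the patch. The builtin helpers (mlirIntegerTypeGet, mlirF32TypeGet, mlirRankedTensorTypeGet, mlirDenseElementsAttrFloatGet, mlirDenseElementsAttrInt8Get, mlirAttributeGetNull) come from the MLIR builtin C-API headers; the flags value 1 is assumed to be the signed-storage flag, and the i8 zero-point element type is likewise an assumption (see the mlir/test/CAPI/quant.c changes in this PR for authoritative usage).

```c
#include <stdint.h>
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Dialect/Quant.h"

// Builds a sub-channel type for a 2x4 data tensor quantized along axis 0
// and axis 1, each with block size 2, giving a 1x2 grid of blocks.
static MlirType makeSubChannelType(MlirContext ctx) {
  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType f32 = mlirF32TypeGet(ctx);

  // One scale and one zero point per block.
  int64_t paramShape[] = {1, 2};
  MlirType scalesType =
      mlirRankedTensorTypeGet(2, paramShape, f32, mlirAttributeGetNull());
  MlirType zeroPointsType =
      mlirRankedTensorTypeGet(2, paramShape, i8, mlirAttributeGetNull());
  float scales[] = {2.0f, 3.0f};
  int8_t zeroPoints[] = {0, 0};
  MlirAttribute scalesAttr =
      mlirDenseElementsAttrFloatGet(scalesType, 2, scales);
  MlirAttribute zeroPointsAttr =
      mlirDenseElementsAttrInt8Get(zeroPointsType, 2, zeroPoints);

  // Axis/block-size pairs: axis 0 -> block 2, axis 1 -> block 2.
  int32_t quantizedDimensions[] = {0, 1};
  int64_t blockSizes[] = {2, 2};
  return mlirUniformQuantizedSubChannelTypeGet(
      /*flags=*/1, i8, f32, scalesAttr, zeroPointsAttr,
      /*blockSizeInfoLength=*/2, quantizedDimensions, blockSizes,
      /*storageTypeMin=*/-128, /*storageTypeMax=*/127);
}
```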
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
index 791cb9de48d058..0d97889960019c 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
     encodes the necessary information for (lossy) round-trip conversion between
     an expressed and a stored value.
 
-    The `quant.uniform` type has two variants: per-layer quantization and
-    per-channel (or per-axis) quantization. In per-layer quantization, the
-    quantization information affects an entire tensor uniformly. Conversely, in
-    per-channel quantization, the data type encodes the specific tensor axis
-    that serves as the channel and includes quantization information for each
-    individual channel within the tensor. Below are the specific syntactic and
-    semantic considerations for each modality.
+    The `quant.uniform` type has three variants: per-layer quantization,
+    per-channel (or per-axis) quantization, and sub-channel (or blockwise)
+    quantization.  In per-layer quantization, the quantization information
+    affects an entire tensor uniformly. Conversely, in per-channel
+    quantization, the data type encodes the specific tensor axis that serves
+    as the channel and includes quantization information for each individual
+    channel within the tensor. Sub-channel quantization is a generalization
+    of per-tensor and per-channel quantization, where the quantization
+    parameters are defined for blocks of elements along one or more
+    dimensions of the tensor. Below are the specific syntactic and semantic
+    considerations for each modality.
 
 
     ### Per-layer quantization
@@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
     ```
     // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
     // floats. Dimension 1 of the tensor acts as the channel dimension. Its
-    // size 3 matches the number of provided scale values. Tensor elemenets at
+    // size 3 matches the number of provided scale values. Tensor elements at
     // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
     // 5.0, respectively.
     tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
     tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
     ```
 
+    ### Sub-channel quantization
+
+    Sub-channel quantization, also known as blockwise quantization, provides
+    finer-grained control than per-tensor or per-channel quantization. It
+    divides a tensor into blocks of elements, each with its own quantization
+    parameters (scale and zero point). This is particularly useful when
+    different regions of a tensor exhibit distinct value ranges.
+
+    The `!quant.uniform` type represents sub-channel quantization with the
+    following syntax:
+
+    ```
+    `!quant.uniform` `<`
+      storedType (`<` storageMin `:` storageMax `>`)? `:`
+      expressedType `:` blockSizeInfo
+      scaleZeroTensor `>`
+
+    blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
+    axisBlock ::= axis `:` blockSize
+    scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
+    scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
+    scaleZeroList  ::= scaleZero (`,` scaleZero)*
+    scaleZero ::= scale (`:` zeroPoint)?
+    ```
+
+    The `blockSize` field specifies the size of the blocks along dimension
+    `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
+    quantization parameters for a particular block. Specifically, the tensor
+    element at position [i0...iN] uses
+    `scaleZeroTensor[i/blockSize0...i/blockSizeN].scale` and
+    `scaleZeroTensor[i/blockSize0...i/blockSizeN].zeroPoint` as scale
+    and zeroPoint respectively.
+
+    Here are some examples:
+
+    ```
+    // A 3x4 tensor of i8 values representing f32 values, quantized 
+    // along axis-0 and axis-1 with block sizes 1 and 2,
+    // respectively. As a result, the shape of the scales (or zero-points) will
+    // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
+    // blocks along each axis. Tensor elements at positions 
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+    // [2][0] and [2][1] use scale `s20` and zero point `z20`,
+    // [2][2] and [2][3] use scale `s21` and zero point `z21`,
+    tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10,s11:z11}, {s20:z20,s21:z21}}>>
+
+    // A 2D dynamically sized tensor contains u16 values
+    // representing f32 values. Since the shape of the quantization
+    // parameters (i.e. scales and zero-points) is given as [2,2] and
+    // the blocks-sizes are given as [1,2], the shape of the tensor is expected
+    // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
+    // [0][0] and [0][1] use scale `s00` and zero point `z00`,
+    // [0][2] and [0][3] use scale `s01` and zero point `z01`,
+    // [1][0] and [1][1] use scale `s10` and zero point `z10`,
+    // [1][2] and [1][3] use scale `s11` and zero point `z11`,
+    tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
+      {{s00:z00, s01:z01}, {s10:z10,s11:z11}}>>
+    ```
 
     ## Per-axis quantization integrity
 
@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
     respected in any context in which the `!quant.uniform` data type is used,
     such as the header of a `func.func` op, or the input of an arithmetic
     operation.
- 
+
     - A quantized type with per-channel quantization information must be the
       element type of a tensor container type, and may not occur directly as
       the data type of a scalar value.
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
     // Correct. The quantized type now includes 3 scale values, matching the
     // size of dimension 1 of the result tensor.
     %result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>
+
+    ## Sub-channel quantization integrity
+
+    When type `!quant.uniform` contains sub-channel quantization information,
+    the following rules are enforced.  For efficiency, these rules are actively
+    enforced by the verifiers of `quant` dialect ops, but they must be
+    respected in any context in which the `!quant.uniform` data type is used,
+    such as the header of a `func.func` op, or the input of an arithmetic
+    operation.
+
+    - A quantized type with sub-channel quantization information must be the
+      element type of a tensor container type, and may not occur directly as
+      the data type of a scalar value.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // scalar type.
+    %result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>
+
+    // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
+    // in a `tensor` type.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The tensor containing the sub-channel quantized type must be ranked.
+
+    ```
+    // Incorrect. Type !quant.uniform specifies sub-channel quantization for a
+    // unranked tensor type.
+    %result = quant.qcast %input : tensor<*xf32> to
+                tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
+    ```
+
+    - The axis for which a block size is specified should be valid for a tensor
+    of a given rank. Block sizes can be specified for a subset of axes. 
+    Any unspecified block size for an axis i defaults to the tensor dimension
+    size of that axis (shape(tensor)[i]).
+
+    ```
+    // Incorrect. The block-size is specified for axis 2 which is greater than
+    // the rank of the tensor.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Incorrect. The block-size is specified for a negative axis.
+    %result = quant.qcast %input : tensor<2x2xf32> to
+                tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>
+
+    // Correct. The block size for axis 1 is skipped which should be assumed as
+    // 2, the dim-size of tensor at axis 1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>
+
+    // Correct. The block size for all the axes are skipped making the
+    // sub-channel type essentially a per-tensor type.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
+    ```
+
+    - Block size for a particular axis should be a positive integer and should
+      be less than the dimension size of the tensor along that axis.
+
+    ```
+    // Incorrect. The block size for axis 0 is -1.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>
+
+    // Incorrect. The block size for axis 0 is 8 which is greater than the
+    // dimension size of tensor at axis 0 (which is 6).
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(tensor) % blockSizes = 0 where blockSizes = [block sizes for
+      axis i in [0, 1, ..., rank(tensor)-1]].
+
+    ```
+    // Incorrect. The block size for axis 0 is 4 and the corresponding
+    // dimension size is 6 and 6 % 4 != 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>
+
+    // Correct. The block size for axis 0 is now 3 making 6 % 3 = 0.
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
+    ```
+
+    - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.
+
+    ```
+    // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
+    // shape(scales) is [1,2] which is not equal to [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>
+
+    // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
+    // shape(scales) equals [6,2]/[3,2].
+    %result = quant.qcast %input : tensor<6x2xf32> to
+                tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
     ```
   }];
   let cppNamespace = "::mlir::quant";
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
index bd9cdf82382275..8c74dbef5d94a3 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
@@ -13,6 +13,7 @@
 #ifndef QUANT_BYTECODE
 #define QUANT_BYTECODE
 
+include "mlir/IR/BuiltinDialectBytecode.td"
 include "mlir/IR/BytecodeBase.td"
 
 def DoubleAPFloat:
@@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type
   }];
 }
 
+def UniformQuantizedSubChannelType
+    : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
+          SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
+          Array<SignedVarIntList>:$quantizedDimensions,
+          Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
+          DenseElementsAttr:$zeroPoints)> {
+  // Note: builder order differs from bytecode.
+  let cBuilder = [{
+      get<$_resultType>(context, flags, storageType, expressedType, scales,
+        zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
+        [](int64_t dim) { return static_cast<int32_t>(dim);})), blockSizes,
+        storageTypeMin, storageTypeMax)
+  }];
+}
+
 /// This enum contains marker codes used to indicate which attribute is
 /// currently being decoded, and how it should be decoded. The order of these
 /// codes should generally be unchanged, as any changes will inevitably break
 /// compatibility with older bytecode.
 
 def QuantDialectTypes : DialectTypes<"Quant"> {
-  let elems = [
-    ReservedOrDead,
-    AnyQuantizedType,
-    AnyQuantizedTypeWithExpressedType,
-    CalibratedQuantizedType,
-    UniformQuantizedType,
-    UniformQuantizedPerAxisType
-  ];
+  let elems = [ReservedOrDead, AnyQuantizedType,
+               AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
+               UniformQuantizedType, UniformQuantizedPerAxisType,
+               UniformQuantizedSubChannelType];
 }
 
-#endif // QUANT_BYTECODE
\ No newline at end of file
+#endif // QUANT_BYTECODE
diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
index 43440ba623b9c1..44062fe376ec0d 100644
--- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
@@ -23,6 +23,7 @@ namespace detail {
 
 struct QuantizedTypeStorage;
 struct AnyQuantizedTypeStorage;
+struct UniformQuantizedSubChannelTypeStorage;
 struct UniformQuantizedTypeStorage;
 struct UniformQuantizedPerAxisTypeStorage;
 struct CalibratedQuantizedTypeStorage;
@@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType
   }
 };
 
+/// Represents sub-channel (also known as blockwise quantization).
+///
+/// Syntax synopsis:
+///   UniformQuantizedSubChannelType ::= '!quant.uniform' '<'
+///       storageType ('<' storageMin ':' storageMax '>')? ':'
+///       expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>'
+///   BlockSizeInfo: '{' '}' | '{' AxisBlock (',' AxisBlock)* '}'
+///   AxisBlock ::= AxisSpec ':' BlockSizeSpec
+///   ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList
+///   ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}'
+///   ScaleZeroList  ::= ScaleZero (',' ScaleZero)*
+///   ScaleZero ::= Scale (':' ZeroPoint)?
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   AxisSpec: An integer value
+///   BlockSizeSpec: An integer value
+///   Scale: An attribute (usually floating-point value)
+///   ZeroPoint: An attribute (usually integer value)
+class UniformQuantizedSubChannelType
+    : public Type::TypeBase<UniformQuantizedSubChannelType, QuantizedType,
+                            detail::UniformQuantizedSubChannelTypeStorage> {
+public:
+  using Base::Base;
+  using Base::getChecked;
+
+  static constexpr StringLiteral name = "quant.uniform_sub_channel";
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedSubChannelType
+  get(unsigned flags, Type storageType, Type expressedType,
+      DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+      ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedSubChannelType
+  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+             Type storageType, Type expressedType, DenseElementsAttr scales,
+             DenseElementsAttr zeroPoints,
+             ArrayRef<int32_t> quantizedDimensions,
+             ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+             int64_t storageTypeMax);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   DenseElementsAttr scales, DenseElementsAttr zeroPoints,
+                   ArrayRef<int32_t> quantizedDimensions,
+                   ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
+
+  /// Gets the quantization scales. The scales are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the scales tensor
+  /// is determined by the number of blocks along the corresponding dimension in
+  /// the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will
+  /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The scale value for a specific element in the quantized data tensor at
+  /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding
+  /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1].
+  DenseElementsAttr getScales() const;
+
+  /// Gets the quantization zero-points. The zero-points are organized in a
+  /// multi-dimensional tensor. The size of each dimension in the zero-point
+  /// tensor is determined by the number of blocks along the corresponding
+  /// dimension in the quantized data tensor.
+  ///
+  /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1]
+  /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor
+  /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1].
+  ///
+  /// The zero-point value for a specific element in the quantized data tensor
+  /// at position [i0, i1, ..., iR-1] is determined by accessing the
+  /// c...
[truncated]
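To make the indexing rule in the getScales/getZeroPoints documentation above concrete, here is a small illustrative helper (not from the patch) that computes which block, and hence which scale/zero-point, an element maps to:

```c
#include <stdint.h>

// An element at elemIndex in the quantized data tensor reads its scale and
// zero point from position elemIndex[d] / blockSizes[d] in each dimension d
// of the parameter tensors, per the documentation above.
static void blockIndexFor(const int64_t *elemIndex, const int64_t *blockSizes,
                          int64_t rank, int64_t *blockIndex) {
  for (int64_t d = 0; d < rank; ++d)
    blockIndex[d] = elemIndex[d] / blockSizes[d];
}
```

For example, with block sizes [3, 2], the element at [4, 3] reads its parameters from block [1, 1].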

@sdasgup3 sdasgup3 changed the title from "Subchannel quant impl" to "Sub-channel quantized type implementation" on Dec 17, 2024

github-actions bot commented Dec 17, 2024

✅ With the latest revision this PR passed the Python code formatter.

@sdasgup3 sdasgup3 force-pushed the subchannel-quant-impl branch from 3a6a0a7 to 0f4147e on December 17, 2024 18:18
@sdasgup3 sdasgup3 force-pushed the subchannel-quant-impl branch from 0f4147e to ce179a5 on January 3, 2025 18:19

github-actions bot commented Jan 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@sdasgup3 sdasgup3 force-pushed the subchannel-quant-impl branch 2 times, most recently from e3af4df to fc23a72 on January 4, 2025 02:24
@sdasgup3
Contributor Author

Hello @stellaraccident @sjarus @ftynse @makslevental @rafaelauler
Can you please take a look at the PR and let me know your feedback?

@sdasgup3 sdasgup3 force-pushed the subchannel-quant-impl branch from fc23a72 to cb7d7ff on February 5, 2025 21:38
@makslevental
Contributor

Unfortunately I don't know much about quant stuff. @qedawkins I vaguely remember you working on related stuff a while ago? Otherwise, do you happen to know who might be able to review?

@qedawkins
Contributor

I worked on something similar to the Linalg lowering this bottoms out at, but I don't have much of an opinion on the quant dialect itself. I can try to do a surface-level review but am sparse on cycles right now.

@sdasgup3
Contributor Author

sdasgup3 commented Feb 6, 2025

@qedawkins It would be great to have your review (once you have time)!

Contributor

@GleasonK GleasonK left a comment


LGTM. Lots of minor bits of feedback. Let's wait for another quantization maintainer approval before merge.

@sdasgup3
Contributor Author

@GleasonK Thanks for all the feedback. I have addressed all comments but one (waiting on @stellaraccident's opinion on #120172 (comment)). Feel free to take a look.

@stellaraccident
Contributor

If it's quant dialect stuff, you don't need me. I hereby bequeath you the baton.

@GleasonK
Contributor

Baton accepted!

I think the comment linked was just to see if you had any historical knowledge on the is_signed flag (#120172 (comment)); it looks like you submitted the original change years ago. My best guess is that these types existed before uint types.

Regardless, happy to try and delete it and see what breaks. Thanks!

@sdasgup3 sdasgup3 force-pushed the subchannel-quant-impl branch from b09432f to 4e47107 on March 15, 2025 00:27
@GleasonK
Contributor

GleasonK commented Mar 17, 2025

Current feedback addressed; hoping to merge this by ~EoW. I believe there's been ample time to get feedback in, but I want to make sure no one is surprised, so please let us know if more review time is needed!

(cc @rafaelubalmw / @sjarus who were active on the initial discourse rfc)

@sdasgup3
Contributor Author

Regardless, happy to try and delete it and see what breaks. Thanks!

Agreed! Let's try that as a follow-up.

Contributor

@sjarus sjarus left a comment


Took a pass over the code; looks fine and thank you for this work! Posting an explicit approval.

@GleasonK GleasonK merged commit 81d7eef into llvm:main Mar 23, 2025
11 checks passed
@rafaelubalmw
Contributor

Hi all, thanks for the contribution. Just noticed that the documentation page got garbled, probably due to some misplaced backtick.

https://mlir.llvm.org/docs/Dialects/QuantDialect/

@mgorny
Member

mgorny commented Mar 29, 2025

I'm seeing a test regression on 32-bit x86 with this change:

FAIL: MLIR :: CAPI/quant.c (38 of 2812)
******************** TEST 'MLIR :: CAPI/quant.c' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 10
/var/tmp/portage/llvm-core/mlir-21.0.0.9999/work/mlir_build-abi_x86_32.x86/bin/mlir-capi-quant-test 2>&1 | /usr/lib/llvm/21/bin/FileCheck /var/tmp/portage/llvm-core/mlir-21.0.0.9999/work/mlir/test/CAPI/quant.c
# executed command: /var/tmp/portage/llvm-core/mlir-21.0.0.9999/work/mlir_build-abi_x86_32.x86/bin/mlir-capi-quant-test
# executed command: /usr/lib/llvm/21/bin/FileCheck /var/tmp/portage/llvm-core/mlir-21.0.0.9999/work/mlir/test/CAPI/quant.c
# .---command stderr------------
# | /var/tmp/portage/llvm-core/mlir-21.0.0.9999/work/mlir/test/CAPI/quant.c:270:12: error: CHECK: expected string not found in input
# |  // CHECK: num dims: 2
# |            ^
# | <stdin>:51:28: note: scanning from here
# | is null sub-channel type: 1
# |                            ^
# | <stdin>:52:1: note: possible intended match here
# | num dims: 6292612763541831682
# | ^
# | 
# | Input file: <stdin>
# | Check file: /var/tmp/portage/llvm-core/mlir-21.0.0.9999/work/mlir/test/CAPI/quant.c
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |              .
# |              .
# |              .
# |             46: equal: 1 
# |             47: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}> 
# |             48:  
# |             49:  
# |             50: testUniformSubChannelType 
# |             51: is null sub-channel type: 1 
# | check:270'0                                X error: no match found
# |             52: num dims: 6292612763541831682 
# | check:270'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# | check:270'1     ?                              possible intended match
# |             53: axis-block-size-pair[0]: 0:1 
# | check:270'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# |             54: axis-block-size-pair[1]: 1:2 
# | check:270'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# |             55: tensor<2x2xf32> 
# | check:270'0     ~~~~~~~~~~~~~~~~
# |             56: tensor<2x2xi8> 
# | check:270'0     ~~~~~~~~~~~~~~~
# |             57: number of quantization parameters: 4 
# | check:270'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# |              .
# |              .
# |              .
# | >>>>>>
# `-----------------------------
# error: command failed with exit status: 1

--

********************

@sdasgup3
Contributor Author

sdasgup3 commented Mar 31, 2025

I'm seeing a test regression on 32-bit x86 with this change:

@mgorny Please find the fix @ #133763 and let me know if this looks good.
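(For context: a "num dims" value like 6292612763541831682 on a 32-bit target is characteristic of a printf format specifier whose width does not match the intptr_t argument, so extra garbage bytes get read. That is only a guess at the failure mode, not a description of the actual fix in #133763; a portable pattern looks like the sketch below.)

```c
#include <inttypes.h>
#include <stdio.h>
#include "mlir-c/Dialect/Quant.h"

static void printNumDims(MlirType type) {
  intptr_t numDims = mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
  // Cast to a fixed-width type so the specifier is correct on both
  // 32-bit and 64-bit targets.
  printf("num dims: %" PRId64 "\n", (int64_t)numDims);
}
```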

GleasonK pushed a commit that referenced this pull request Mar 31, 2025
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Mar 31, 2025