diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h index a7d98dc3c1a77..dc0989e53344e 100644 --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type); MLIR_CAPI_EXPORTED bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type); +//===---------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedSubChannel. +MLIR_CAPI_EXPORTED bool +mlirTypeIsAUniformQuantizedSubChannelType(MlirType type); + +/// Creates a UniformQuantizedSubChannelType with the given parameters. +/// +/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be +/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` +/// point to `blockSizeInfoLength` number of elements, describing respectively +/// the quantization axis and corresponding block size. +MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, + intptr_t blockSizeInfoLength, int32_t *quantizedDimensions, + int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the number of block sizes provided in type. +MLIR_CAPI_EXPORTED intptr_t +mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type); + +/// Returns the quantized dimension at the given position. +MLIR_CAPI_EXPORTED int32_t +mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type, + intptr_t pos); + +/// Returns the block size at the given position. +MLIR_CAPI_EXPORTED int64_t +mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos); + +/// Returns the scales of the quantized type. +MLIR_CAPI_EXPORTED MlirAttribute +mlirUniformQuantizedSubChannelTypeGetScales(MlirType type); + +/// Returns the zero-points of the quantized type. +MLIR_CAPI_EXPORTED MlirAttribute +mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type); + //===---------------------------------------------------------------------===// // CalibratedQuantizedType //===---------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td index 791cb9de48d05..0d97889960019 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -40,13 +40,17 @@ def Quant_Dialect : Dialect { encodes the necessary information for (lossy) round-trip conversion between an expressed and a stored value. - The `quant.uniform` type has two variants: per-layer quantization and - per-channel (or per-axis) quantization. In per-layer quantization, the - quantization information affects an entire tensor uniformly. Conversely, in - per-channel quantization, the data type encodes the specific tensor axis - that serves as the channel and includes quantization information for each - individual channel within the tensor. Below are the specific syntactic and - semantic considerations for each modality. + The `quant.uniform` type has three variants: per-layer quantization, + per-channel (or per-axis) quantization, and sub-channel (or blockwize) + quantization. 
In per-layer quantization, the quantization information + affects an entire tensor uniformly. Conversely, in per-channel + quantization, the data type encodes the specific tensor axis that serves + as the channel and includes quantization information for each individual + channel within the tensor. Sub-channel quantization is a generalization + of per-tensor and per-channel quantization, where the quantization + parameters are defined for blocks of elements along one or more + dimensions of the tensor. Below are the specific syntactic and semantic + considerations for each modality. ### Per-layer quantization @@ -145,7 +149,7 @@ def Quant_Dialect : Dialect { ``` // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit // floats. Dimension 1 of the tensor acts as the channel dimension. Its - // size 3 matches the number of provided scale values. Tensor elemenets at + // size 3 matches the number of provided scale values. Tensor elements at // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and // 5.0, respectively. tensor<2x3x4x!quant.uniform> @@ -159,6 +163,72 @@ def Quant_Dialect : Dialect { tensor> ``` + ### Sub-channel quantization + + Sub-channel quantization, also known as blockwise quantization, provides + finer-grained control than per-tensor or per-channel quantization. It + divides a tensor into blocks of elements, each with its own quantization + parameters (scale and zero point). This is particularly useful when + different regions of a tensor exhibit distinct value ranges. + + The `!quant.uniform` type represents sub-channel quantization with the + following syntax: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `:` blockSizeInfo + scaleZeroTensor `>` + + blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)*)? `}` + axisBlock ::= axis `:` blockSize + scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList + scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}` + scaleZeroList ::= scaleZero (`,` scaleZero)* + scaleZero ::= scale (`:` zeroPoint)? + + scaleZeroTensor ::= scale-zero-dense-exp | scale-zero-list + scale-zero-dense-exp ::= `{` scale-zero-tensor (`,` scale-zero-tensor)* `}` + scale-zero-list ::= scale (`:` zeroPoint)? (`,` scale (`:` zeroPoint)?)* + ``` + + The `blockSize` field specifies the size of the blocks along dimension + `axis` of the tensor. The `scale` and `zeroPoint` fields specify the + quantization parameters for a particular block. Specifically, the tensor + element at position [i0...iN] uses + `scaleZeroTensor[i/blockSize0...i/blockSizeN].scale` and + `scaleZeroTensor[i/blockSize0...i/blockSizeN].zeroPoint` as scale + and zeroPoint respectively. + + Here are some examples: + + ``` + // A 3x4 tensor of i8 values representing f32 values, quantized + // along axis-0 and axis-1 with block sizes 1 and 2, + // respectively. As a result, the shape of the scales (or zero-points) will + // be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of + // blocks along each axis. 
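+    // That is, axis 0 is split into three blocks of one row each, and axis 1
+    // into two blocks of two columns each.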
Tensor elements at positions + // [0][0] and [0][1] use scale `s00` and zero point `z00`, + // [0][2] and [0][3] use scale `s01` and zero point `z01`, + // [1][0] and [1][1] use scale `s10` and zero point `z10`, + // [1][2] and [1][3] use scale `s11` and zero point `z11`, + // [2][0] and [2][1] use scale `s20` and zero point `z20`, + // [2][2] and [2][3] use scale `s21` and zero point `z21`, + tensor<3x4x!quant.uniform> + + // A 2D dynamically sized tensor contains u16 values + // representing f32 values. Since the shape of the quantization + // parameters (i.e. scales and zero-points) is given as [2,2] and + // the blocks-sizes are given as [1,2], the shape of the tensor is expected + // to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions + // [0][0] and [0][1] use scale `s00` and zero point `z00`, + // [0][2] and [0][3] use scale `s01` and zero point `z01`, + // [1][0] and [1][1] use scale `s10` and zero point `z10`, + // [1][2] and [1][3] use scale `s11` and zero point `z11`, + tensor> + ``` ## Per-axis quantization integrity @@ -170,7 +240,7 @@ def Quant_Dialect : Dialect { respected in any context in which the `!quant.uniform` data type is used, such as the header of a `func.func` op, or the input of an arithmetic operation. - + - A quantized type with per-channel quantization information must be the element type of a tensor container type, and may not occur directly as the data type of a scalar value. @@ -209,6 +279,110 @@ def Quant_Dialect : Dialect { // Correct. The quantized type now includes 3 scale values, matching the // size of dimension 1 of the result tensor. %result = quant.qcast %input : tensor to tensor> + + ## Sub-channel quantization integrity + + When type `!quant.uniform` contains sub-channel quantization information, + the following rules are enforced. For efficiency, these rules are actively + enforced by the verifiers of `quant` dialect ops, but they must be + respected in any context in which the `!quant.uniform` data type is used, + such as the header of a `func.func` op, or the input of an arithmetic + operation. + + - A quantized type with sub-channel quantization information must be the + element type of a tensor container type, and may not occur directly as + the data type of a scalar value. + + ``` + // Incorrect. Type !quant.uniform specifies sub-channel quantization for a + // scalar type. + %result = quant.qcast %input : f32 to !quant.uniform + + // Correct. Type `!quant.uniform` with sub-channel quantization is wrapped + // in a `tensor` type. + %result = quant.qcast %input : tensor<2x2xf32> to + tensor<2x2x!quant.uniform> + ``` + + - The tensor containing the sub-channel quantized type must be ranked. + + ``` + // Incorrect. Type !quant.uniform specifies sub-channel quantization for a + // unranked tensor type. + %result = quant.qcast %input : tensor<*xf32> to + tensor<*x!quant.uniform> + ``` + + - The axis for which a block size is specified should be valid for a tensor + of a given rank. Block sizes can be specified for a subset of axes. + Any unspecified block size for an axis i defaults to the tensor dimension + size of that axis (shape(tensor)[i]). + + ``` + // Incorrect. The block-size is specified for axis 2 which is greater than + // the rank of the tensor. + %result = quant.qcast %input : tensor<2x2xf32> to + tensor<2x2x!quant.uniform> + + // Incorrect. The block-size is specified for a negative axis. + %result = quant.qcast %input : tensor<2x2xf32> to + tensor<2x2x!quant.uniform> + + // Correct. 
The block size for axis 1 is skipped which should be assumed as + // 2, the dim-size of tensor at axis 1. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + + // Correct. The block size for all the axes are skipped making the + // sub-channel type essentially a per-tensor type. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + ``` + + - Block size for a particular axis should be a positive integer and should + be less than the dimension size of the tensor along that axis. + + ``` + // Incorrect. The block size for axis 0 is -1. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + + // Incorrect. The block size for axis 0 is 8 which is greater than the + // dimension size of tensor at axis 0 (which is 6). + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + + // Correct. The block size for axis 0 is now 3. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + ``` + + - shape(tensor) % blockSizes = 0 where blockSizes = [block sizes for + axis i in [0, 1, ..., rank(tensor)-1]]. + + ``` + // Incorrect. The block size for axis 0 is 4 and the corresponding + // dimension size is 6 and 6 % 4 != 0. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + + // Correct. The block size for axis 0 is now 3 making 6 % 3 = 0. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + ``` + + - shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes. + + ``` + // Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but + // shape(scales) is [1,2] which is not equal to [6,2]/[3,2]. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> + + // Correct. shape(tensor) = [6,2], blockSizes = [3,2], and + // shape(scales) equals [6,2]/[3,2]. + %result = quant.qcast %input : tensor<6x2xf32> to + tensor<6x2x!quant.uniform> ``` }]; let cppNamespace = "::mlir::quant"; diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td index bd9cdf8238227..8c74dbef5d94a 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td @@ -13,6 +13,7 @@ #ifndef QUANT_BYTECODE #define QUANT_BYTECODE +include "mlir/IR/BuiltinDialectBytecode.td" include "mlir/IR/BytecodeBase.td" def DoubleAPFloat: @@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type }]; } +def UniformQuantizedSubChannelType + : DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType, + SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax, + Array:$quantizedDimensions, + Array:$blockSizes, DenseElementsAttr:$scales, + DenseElementsAttr:$zeroPoints)> { + // Note: builder order differs from bytecode. + let cBuilder = [{ + get<$_resultType>(context, flags, storageType, expressedType, scales, + zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions, + [](int64_t dim) { return static_cast(dim);})), blockSizes, + storageTypeMin, storageTypeMax) + }]; +} + /// This enum contains marker codes used to indicate which attribute is /// currently being decoded, and how it should be decoded. The order of these /// codes should generally be unchanged, as any changes will inevitably break /// compatibility with older bytecode. 
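// UniformQuantizedSubChannelType is appended after the existing entries below
// so that the codes already assigned to the older types stay unchanged and
// previously emitted bytecode keeps decoding correctly.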
def QuantDialectTypes : DialectTypes<"Quant"> { - let elems = [ - ReservedOrDead, - AnyQuantizedType, - AnyQuantizedTypeWithExpressedType, - CalibratedQuantizedType, - UniformQuantizedType, - UniformQuantizedPerAxisType - ]; + let elems = [ReservedOrDead, AnyQuantizedType, + AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType, + UniformQuantizedType, UniformQuantizedPerAxisType, + UniformQuantizedSubChannelType]; } -#endif // QUANT_BYTECODE \ No newline at end of file +#endif // QUANT_BYTECODE diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index 43440ba623b9c..44062fe376ec0 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -23,6 +23,7 @@ namespace detail { struct QuantizedTypeStorage; struct AnyQuantizedTypeStorage; +struct UniformQuantizedSubChannelTypeStorage; struct UniformQuantizedTypeStorage; struct UniformQuantizedPerAxisTypeStorage; struct CalibratedQuantizedTypeStorage; @@ -382,6 +383,136 @@ class UniformQuantizedPerAxisType } }; +/// Represents sub-channel (also known as blockwise quantization). +/// +/// Syntax synopsis: +/// UniformQuantizedSubChannelType ::= '!quant.uniform' '<' +/// storageType ('<' storageMin ':' storageMax '>')? ':' +/// expressedType ':' BlockSizeInfo ',' ScaleZeroTensor '>' +/// BlockSizeInfo: '{' '}' | '{' AxisBlock (',' AxisBlock)* '}' +/// AxisBlock ::= AxisSpec ':' BlockSizeSpec +/// ScaleZeroTensor ::= ScaleZeroDenseExp | ScaleZeroList +/// ScaleZeroDenseExp ::= '{' ScaleZeroTensor (',' ScaleZeroTensor)* '}' +/// ScaleZeroList ::= ScaleZero (',' ScaleZero)* +/// ScaleZero ::= Scale (':' ZeroPoint)? +/// +/// StorageType: 'i'|'u' NumBits +/// ExpressedType: 'f16', 'f32', 'bf16', 'f64' +/// AxisSpec: An integer value +/// BlockSizeSpec: An integer value +/// Scale: An attribute (usually floating-point value) +/// ZeroPoint: An attribute (usually integer value) +class UniformQuantizedSubChannelType + : public Type::TypeBase { +public: + using Base::Base; + using Base::getChecked; + + static constexpr StringLiteral name = "quant.uniform_sub_channel"; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static UniformQuantizedSubChannelType + get(unsigned flags, Type storageType, Type expressedType, + DenseElementsAttr scales, DenseElementsAttr zeroPoints, + ArrayRef quantizedDimensions, ArrayRef blockSizes, + int64_t storageTypeMin, int64_t storageTypeMax); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static UniformQuantizedSubChannelType + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, DenseElementsAttr scales, + DenseElementsAttr zeroPoints, + ArrayRef quantizedDimensions, + ArrayRef blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Verifies construction invariants and issues errors/warnings. + static LogicalResult + verifyInvariants(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, + DenseElementsAttr scales, DenseElementsAttr zeroPoints, + ArrayRef quantizedDimensions, + ArrayRef blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax); + + /// Gets the quantization scales. The scales are organized in a + /// multi-dimensional tensor. The size of each dimension in the scales tensor + /// is determined by the number of blocks along the corresponding dimension in + /// the quantized data tensor. 
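+  /// Axes without an explicitly specified block size form a single block, so
+  /// the corresponding dimension of the scales tensor has size 1.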
+ /// + /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1] + /// and the block sizes are [B0, B1, ..., BR-1], then the scales tensor will + /// have shape [X0/B0, X1/B1, ..., XR-1/BR-1]. + /// + /// The scale value for a specific element in the quantized data tensor at + /// position [i0, i1, ..., iR-1] is determined by accessing the corresponding + /// element in the scales tensor at position [i0/B0, i1/B1, ..., iR-1/BR-1]. + DenseElementsAttr getScales() const; + + /// Gets the quantization zero-points. The zero-points are organized in a + /// multi-dimensional tensor. The size of each dimension in the zero-point + /// tensor is determined by the number of blocks along the corresponding + /// dimension in the quantized data tensor. + /// + /// For example, if the quantized data tensor has shape [X0, X1, ..., XR-1] + /// and the block sizes are [B0, B1, ..., BR-1], then the zero-point tensor + /// will have shape [X0/B0, X1/B1, ..., XR-1/BR-1]. + /// + /// The zero-point value for a specific element in the quantized data tensor + /// at position [i0, i1, ..., iR-1] is determined by accessing the + /// corresponding element in the zero-point tensor at position [i0/B0, i1/B1, + /// ..., iR-1/BR-1]. + DenseElementsAttr getZeroPoints() const; + + /// Gets the quantized dimensions. Each element in the returned list + /// represents an axis of the quantized data tensor that has a specified block + /// size. The order of elements corresponds to the order of block sizes + /// returned by `getBlockSizes()`. + /// + /// It means that the data tensor is quantized along the `i`-th dimension in + /// the returned list using the `i`-th block size from `getBlockSizes()`. + /// + /// Note that the type expression does not have to specify the block size for + /// all axes in the data tensor. Any unspecified block size for an axis `i` + /// defaults to the tensor dimension size of that axis. + /// + /// For example, for a quantized type: + /// `tensor<8x4x2x!quant.uniform` + /// + /// `getQuantizedDimensions()` returns [1, 0]. + /// `getBlockSizes()` returns [2, 8]. + /// + /// This indicates that: + /// * Axis 1 (second dimension) is quantized with a block size of 2. + /// * Axis 0 (first dimension) is quantized with a block size of 8. + /// Since axis 2 is not specified, it implicitly has a block size equal to + /// the size of the third dimension (which is 2 in this case). + ArrayRef getQuantizedDimensions() const; + + /// Gets the block sizes for the quantized dimensions. The `i`-th element in + /// the returned list corresponds to the block size for the `i`-th dimension + /// in the list returned by `getQuantizedDimensions()`. + /// + /// See `getQuantizedDimensions()` for more details and examples. + ArrayRef getBlockSizes() const; + + /// Gets the block size information. This returns a list of pairs, where each + /// pair represents a quantized dimension and its corresponding block size. + /// + /// For example, for the type: + /// `tensor<8x4x!quant.uniform` + /// + /// This method returns: + /// `[(1, 2), (0, 8)]` + /// + /// This list indicates that axis 1 has a block size of 2, and axis 0 has a + /// block size of 8. + const SmallVector> getBlockSizeInfo() const; +}; + /// A quantized type that infers its range from given min/max values. 
/// /// Typical syntax: diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td index b25296d4db5a9..a62315c0395f7 100644 --- a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -31,6 +31,44 @@ def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> { ]; } +def NormalizeQuantTypes : Pass<"normalize-quant-types", "func::FuncOp"> { + let summary = "Normalize generic quantized types to specific quantized types"; + let description = [{ + This pass converts generic quantized types in the `quant` dialect to more + specific types when possible. + + The following conversions are performed: + + 1. Sub-channel to per-axis: If the shape of the scales tensor of sub-channel + quantized type has all but one non-one value, it is converted to a + per-axis quantized type. + + For example: + + * `!quant.uniform` + -> `!quant.uniform` + * `tensor>` + -> `tensor>` + + 2. Sub-channel to per-tensor: If a sub-channel quantized type has only + one scale or zero-point, it is converted to a per-tensor + quantized type. + + For example: + + * `!quant.uniform` + -> `!quant.uniform` + * `tensor>` + -> `tensor>` + + The rationale for these conversions is that the decompositions / handling of + more precise quantized types tends to be more efficient than treating + everything as subchannel type. + + }]; + let dependentDialects = ["func::FuncDialect", "quant::QuantDialect"]; +} + def StripFuncQuantTypes : Pass<"strip-func-quant-types"> { let summary = "Strip quantized types from function headers"; let description = [{ diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index 29f19c9c50065..55571cd1e50a6 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -9,10 +9,11 @@ #include #include +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Dialect/Quant.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; @@ -284,6 +285,79 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { }, "Fixed point values are real numbers divided by a scale."); + //===-------------------------------------------------------------------===// + // UniformQuantizedSubChannelType + //===-------------------------------------------------------------------===// + auto uniformQuantizedSubChannelType = mlir_type_subclass( + m, "UniformQuantizedSubChannelType", + mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); + uniformQuantizedSubChannelType.def_classmethod( + "get", + [](nb::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, + std::vector quantizedDimensions, + std::vector blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + return cls(mlirUniformQuantizedSubChannelTypeGet( + flags, storageType, expressedType, scales, zeroPoints, + static_cast(blockSizes.size()), + quantizedDimensions.data(), blockSizes.data(), storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedSubChannel in the same context as " + "the provided storage type.", + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimensions"), 
nb::arg("block_sizes"), + nb::arg("storage_type_min"), nb::arg("storage_type_max")); + uniformQuantizedSubChannelType.def_property_readonly( + "quantized_dimensions", + [](MlirType type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector quantizedDimensions; + quantizedDimensions.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + quantizedDimensions.push_back( + mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); + } + return quantizedDimensions; + }, + "Gets the quantized dimensions. Each element in the returned list " + "represents an axis of the quantized data tensor that has a specified " + "block size. The order of elements corresponds to the order of block " + "sizes returned by 'block_sizes' method. It means that the data tensor " + "is quantized along the i-th dimension in the returned list using the " + "i-th block size from block_sizes method."); + uniformQuantizedSubChannelType.def_property_readonly( + "block_sizes", + [](MlirType type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector blockSizes; + blockSizes.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + blockSizes.push_back( + mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); + } + return blockSizes; + }, + "Gets the block sizes for the quantized dimensions. The i-th element in " + "the returned list corresponds to the block size for the i-th dimension " + "in the list returned by quantized_dimensions method."); + uniformQuantizedSubChannelType.def_property_readonly( + "scales", + [](MlirType type) -> MlirAttribute { + return mlirUniformQuantizedSubChannelTypeGetScales(type); + }, + "The scales of the quantized type."); + uniformQuantizedSubChannelType.def_property_readonly( + "zero_points", + [](MlirType type) -> MlirAttribute { + return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); + }, + "The zero points of the quantized type."); + //===-------------------------------------------------------------------===// // CalibratedQuantizedType //===-------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index c94dbb5692fdb..01a6a948f1dc0 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Quant.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" @@ -194,6 +195,61 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { return cast(unwrap(type)).isFixedPoint(); } +//===---------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirUniformQuantizedSubChannelTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims, + int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + auto scales = dyn_cast(unwrap(scalesAttr)); + auto zeroPoints = dyn_cast(unwrap(zeroPointsAttr)); + + if (!scales || !zeroPoints) { + return {}; + } + + return 
wrap(quant::UniformQuantizedSubChannelType::get( + flags, unwrap(storageType), unwrap(expressedType), scales, zeroPoints, + llvm::ArrayRef(quantizedDimensions, nDims), + llvm::ArrayRef(blockSizes, nDims), storageTypeMin, + storageTypeMax)); +} + +intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) { + return cast(unwrap(type)) + .getBlockSizes() + .size(); +} + +int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getQuantizedDimensions()[pos]; +} + +int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getBlockSizes()[pos]; +} + +MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) { + return wrap( + cast(unwrap(type)).getScales()); +} + +MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) { + return wrap(cast(unwrap(type)) + .getZeroPoints()); +} + //===---------------------------------------------------------------------===// // CalibratedQuantizedType //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp index 6a4ac310eb052..44ec0c517d561 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/Diagnostics.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index c584903f3a15d..94e1c8b8ba296 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" - namespace mlir { namespace quant { @@ -25,22 +24,17 @@ namespace { // Verify the integrity of per-axis quantization information, if present. // -// - quantizedType -// Any quantized type. Any quantized type with no per-axis quantization is -// ignored. +// - uniformQuantizedPerAxisType +// A quantized type with per-axis quantization. // // - containerType // Original input or result type of the operation using the provided quantized // type. Used to ensure that the quantized type appears within a tensor and // that the tensor is compatible with per-axis quantization information. 
// -LogicalResult verifyPerAxisQuantization(Operation *op, - QuantizedType quantizedType, - Type containerType) { - auto quantizedPerAxisType = dyn_cast(quantizedType); - if (!quantizedPerAxisType) - return success(); - +LogicalResult verifyPerAxisQuantization( + Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType, + Type containerType) { auto tensorType = dyn_cast(containerType); if (!tensorType) return op->emitError("scalar types may not use per-axis quantization"); @@ -48,19 +42,112 @@ LogicalResult verifyPerAxisQuantization(Operation *op, if (!tensorType.hasRank()) return success(); - int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension(); - if (quantizedDimension >= tensorType.getRank()) + int32_t quantizedDimension = + uniformQuantizedPerAxisType.getQuantizedDimension(); + if ((int64_t)quantizedDimension >= tensorType.getRank()) return op->emitError("quantized dimension must be less than tensor rank"); int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension); if (quantizedDimensionSize != ShapedType::kDynamic && - quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size()) + quantizedDimensionSize != + (int64_t)uniformQuantizedPerAxisType.getScales().size()) return op->emitError( "quantized dimension size does not match number of scales"); return success(); } +// Verifies that the sub-channel quantization parameters are consistent with +// the given container type. The function checks the following: +// +// - The container type must be a ranked tensor type. +// - Each quantized dimension must be less than the rank of the tensor. +// - The size of each dimension at the quantized dimension must be divisible +// by the corresponding block size. +// - The scale dimension size at each axis index should match the tensor +// dimension at the index divided by the corresponding block size. +// +// The `uniformQuantizedSubChannelType` argument provides the sub-channel +// quantization parameters, and the `containerType` argument specifies the +// type of the container holding the quantized data. +// +LogicalResult verifySubChannelQuantization( + Operation *op, + UniformQuantizedSubChannelType uniformQuantizedSubChannelType, + Type containerType) { + auto tensorType = dyn_cast(containerType); + if (!tensorType) + return op->emitError("scalar types may not use sub-channel quantization"); + + if (!tensorType.hasRank()) + return op->emitError( + "tensor containing the sub-channel quantized type must be ranked"); + + const SmallVector> &blockSizeInfo = + uniformQuantizedSubChannelType.getBlockSizeInfo(); + auto shape = tensorType.getShape(); + + // The dimension size of scale for an axis which is not specified as quantized + // dimension should be 1. 
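+  // For example, a 6x4x2 tensor with block-size info {0:3, 1:2} expects a
+  // scales shape of [2, 2, 1]: 6/3 blocks along axis 0, 4/2 blocks along
+  // axis 1, and a single block along the unspecified axis 2.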
+ SmallVector expectedScaleShape(tensorType.getShape().size(), 1); + for (auto [quantizedDimension, blockSize] : blockSizeInfo) { + if (quantizedDimension >= tensorType.getRank()) + return op->emitError() + << "quantized dimension " << quantizedDimension + << " must be less than tensor rank " << tensorType.getRank(); + if (!tensorType.isDynamicDim(quantizedDimension) && + tensorType.getDimSize(quantizedDimension) % blockSize != 0) + return op->emitError() + << "tensor dimension size " + << tensorType.getDimSize(quantizedDimension) << " at axis " + << quantizedDimension + << " must be divisible by the corresponding block size " + << blockSize; + if (tensorType.isDynamicDim(quantizedDimension)) + expectedScaleShape[quantizedDimension] = ShapedType::kDynamic; + else + expectedScaleShape[quantizedDimension] = + tensorType.getDimSize(quantizedDimension) / blockSize; + } + + // Block sizes must be greater than 0 and divide the corresponding dimension + // size. While a block size b must be less than or equal to the corresponding + // dimension size d, this constraint is implicitly enforced by requiring that + // d % b == 0 when d != 0. + // + // However, a problem arises when d = 0. The divisibility constraint allows b + // to be any value, potentially violating the requirement that b <= d. + // Furthermore, if b is unspecified (implicitly equal to d), it violates the + // constraint that b > 0. + // + // Therefore, we explicitly disallow the case where d = 0 to maintain + // consistency and avoid these issues. + if (llvm::find(tensorType.getShape(), 0) != tensorType.getShape().end()) { + return op->emitError() << "tensor dimension size of zero is not allowed " + "with sub-channel quantization"; + } + + auto scaleShape = + uniformQuantizedSubChannelType.getScales().getType().getShape(); + if (scaleShape.size() != shape.size()) { + return op->emitError() << "Rank of scales " << scaleShape.size() + << " must match " + << "the rank of the tensor " << shape.size(); + } + + for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) { + if (expectedScaleShape[index] != ShapedType::kDynamic && + expectedScaleShape[index] != scaleShape[index]) + return op->emitError() << "dimension size " << scaleDim + << " of scales tensor at axis " << index + << " should match (tensor dimension at axis / " + "block sizes at axis) = " + << expectedScaleShape[index]; + } + + return success(); +} + // Common verification logic for 'quant.dcast' and 'quant.qcast' ops. // // - quantizedType @@ -80,12 +167,23 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, return op->emitError( "expressed type in quantized type expected to match float type"); - // Veriy integrity of per-axis quantization information, if present. - return verifyPerAxisQuantization(op, quantizedType, containerType); -} + // Verify integrity of per-axis quantization information, if present. 
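+  // Sub-channel quantization information, if present, is verified by its own
+  // helper below; plain per-tensor types need no additional checks.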
+ if (auto quantizedPerAxisType = + dyn_cast(quantizedType)) { + return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType); + } -} // namespace + if (auto quantizedSubChannelType = + dyn_cast(quantizedType)) { + return verifySubChannelQuantization(op, quantizedSubChannelType, + containerType); + } + + // At this point the type is UniformQuantizedType + return success(); +} +} // namespace //===----------------------------------------------------------------------===// // Dialect @@ -93,7 +191,7 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, void QuantDialect::initialize() { addTypes(); + UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" @@ -101,7 +199,6 @@ void QuantDialect::initialize() { detail::addBytecodeInterface(this); } - //===----------------------------------------------------------------------===// // DequantizeCastOp //===----------------------------------------------------------------------===// @@ -130,7 +227,6 @@ QuantizedType DequantizeCastOp::getQuantizedType() { return cast(getElementTypeOrSelf(getInput().getType())); } - //===----------------------------------------------------------------------===// // QuantizeCastOp //===----------------------------------------------------------------------===// @@ -160,7 +256,6 @@ QuantizedType QuantizeCastOp::getQuantizedType() { return cast(getElementTypeOrSelf(getResult().getType())); } - //===----------------------------------------------------------------------===// // StorageCastOp //===----------------------------------------------------------------------===// @@ -175,7 +270,20 @@ LogicalResult StorageCastOp::verify() { // Verify integrity of per-axis quantization information, if available. While // the quantization type may appear in the input or the result, their tensor // shapes are guaranteed to be identical at this point. 
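  // As in verifyQuantizationOp, per-axis and sub-channel parameters are each
  // checked by their dedicated verifier; per-tensor types need no extra checks.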
- return verifyPerAxisQuantization(*this, quantizedType, getInput().getType()); + if (auto quantizedPerAxisType = + dyn_cast(quantizedType)) { + return verifyPerAxisQuantization(*this, quantizedPerAxisType, + getInput().getType()); + } + + if (auto quantizedSunChannelType = + dyn_cast(quantizedType)) { + return verifySubChannelQuantization(*this, quantizedSunChannelType, + getInput().getType()); + } + + // At this point the type is UniformQuantizedType + return success(); } OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { @@ -205,10 +313,8 @@ QuantizedType StorageCastOp::getQuantizedType() { return cast(resultScalarType); } - } // namespace quant } // namespace mlir #define GET_OP_CLASSES #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" - diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 7c0d369648651..9b8eec609b039 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "TypeDetail.h" #include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) { return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); } -} // namespace +} // namespace unsigned QuantizedType::getFlags() const { return static_cast(impl)->flags; @@ -410,6 +410,123 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { return getImpl()->quantizedDimension; } +UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get( + unsigned flags, Type storageType, Type expressedType, + DenseElementsAttr scales, DenseElementsAttr zeroPoints, + ArrayRef quantizedDimensions, ArrayRef blockSizes, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::get(storageType.getContext(), flags, storageType, expressedType, + scales, zeroPoints, quantizedDimensions, blockSizes, + storageTypeMin, storageTypeMax); +} + +UniformQuantizedSubChannelType UniformQuantizedSubChannelType::getChecked( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, DenseElementsAttr scales, + DenseElementsAttr zeroPoints, ArrayRef quantizedDimensions, + ArrayRef blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, expressedType, scales, zeroPoints, + quantizedDimensions, blockSizes, storageTypeMin, + storageTypeMax); +} + +LogicalResult UniformQuantizedSubChannelType::verifyInvariants( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, DenseElementsAttr scales, + DenseElementsAttr zeroPoints, ArrayRef quantizedDimensions, + ArrayRef blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, + expressedType, storageTypeMin, + storageTypeMax))) { + return failure(); + } + + // Uniform quantization requires fully expressed parameters, including + // expressed type. + if (!expressedType) + return emitError() << "uniform quantization requires expressed type"; + + // Verify that the expressed type is floating point. + // If this restriction is ever eliminated, the parser/printer must be + // extended. 
+ if (!llvm::isa(expressedType)) + return emitError() << "expressed type must be floating point"; + + // Verify scale type to match expressedType. + if (scales.getType().getElementType() != expressedType) { + return emitError() << "type of scale values " + << scales.getType().getElementType() + << " must match the expressed type " << expressedType; + } + + // Verify zero-point type to match storageType. + if (zeroPoints.getType().getElementType() != storageType) { + return emitError() << "type of zero point values " + << zeroPoints.getType().getElementType() + << " must match the storage type " << storageType; + } + + // Ensure that the shape of scales and zeroPoints match. + if (scales.getType().getShape() != zeroPoints.getType().getShape()) + return emitError() << "shape of scales and zeroPoints (" + << scales.getType().getShape() << " vs " + << zeroPoints.getType().getShape() << ") does not match"; + + // Ensure that the number of quantized-dimensions and block-sizes match. + if (quantizedDimensions.size() != blockSizes.size()) + return emitError() << "number of quantized dimensions and block sizes (" + << scales.size() << " vs " << zeroPoints.size() + << ") does not match"; + + // Verify quantized dimension. + for (auto quantizedDimension : quantizedDimensions) { + if (quantizedDimension < 0) + return emitError() << "illegal quantized dimension: " + << quantizedDimension; + } + + // Verify block sizes. + for (auto blockSize : blockSizes) { + if (blockSize <= 0) + return emitError() << "illegal block size: " << blockSize; + } + + return success(); +} + +DenseElementsAttr UniformQuantizedSubChannelType::getScales() const { + return getImpl()->getScales(); +} + +DenseElementsAttr UniformQuantizedSubChannelType::getZeroPoints() const { + return getImpl()->getZeroPoints(); +} + +ArrayRef +UniformQuantizedSubChannelType::getQuantizedDimensions() const { + return getImpl()->getQuantizedDimensions(); +} + +ArrayRef UniformQuantizedSubChannelType::getBlockSizes() const { + return getImpl()->getBlockSizes(); +} + +const SmallVector> +UniformQuantizedSubChannelType::getBlockSizeInfo() const { + SmallVector> result; + result.reserve(getQuantizedDimensions().size()); + + for (auto [dim, size] : + llvm::zip(getQuantizedDimensions(), getBlockSizes())) { + result.push_back({dim, size}); + } + + return result; +} + CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, double min, double max) { return Base::get(expressedType.getContext(), expressedType, min, max); diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h index ef098811927cd..bb38b1a2a91e2 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h +++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h @@ -9,6 +9,7 @@ #ifndef TYPE_DETAIL_H_ #define TYPE_DETAIL_H_ +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" @@ -253,6 +254,127 @@ struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { int32_t quantizedDimension; }; +struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage { + struct KeyTy { + KeyTy(unsigned flags, Type storageType, Type expressedType, + DenseElementsAttr scales, DenseElementsAttr zeroPoints, + ArrayRef quantizedDimensions, ArrayRef blockSizes, + int64_t storageTypeMin, int64_t storageTypeMax) + : flags(flags), storageType(storageType), expressedType(expressedType), + scales(scales), zeroPoints(zeroPoints), + quantizedDimensions(quantizedDimensions), 
blockSizes(blockSizes), + storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} + /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. + unsigned flags; + + // Integral type for the storage point representation. + Type storageType; + + // Floating point type that the quantized type approximates. + Type expressedType; + + DenseElementsAttr scales; + DenseElementsAttr zeroPoints; + ArrayRef quantizedDimensions; + ArrayRef blockSizes; + int64_t storageTypeMin; + int64_t storageTypeMax; + + DenseElementsAttr getScales() const { return scales; } + + DenseElementsAttr getZeroPoints() const { return zeroPoints; } + + // Check for equality of two structures that share KeyTy data members + // (by name). + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && + lhs.expressedType == rhs.expressedType && + lhs.scales == rhs.scales && lhs.zeroPoints == rhs.zeroPoints && + lhs.quantizedDimensions == rhs.quantizedDimensions && + lhs.blockSizes == rhs.blockSizes && + lhs.storageTypeMin == rhs.storageTypeMin && + lhs.storageTypeMax == rhs.storageTypeMax; + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + // Hash the scalar attributes. + unsigned hash = llvm::hash_combine(flags, storageType, expressedType, + storageTypeMin, storageTypeMax); + + // Hash the scales. + for (auto scaleAttr : scales.getValues()) { + hash = llvm::hash_combine( + hash, llvm::bit_cast(scaleAttr.convertToDouble())); + } + + // Hash the zero points. (Assumed to be integers, adjust if needed). + for (auto zeroPointAttr : zeroPoints.getValues()) { + hash = llvm::hash_combine(hash, zeroPointAttr.getSExtValue()); + } + + // Hash the quantized dimensions and block sizes. + hash = llvm::hash_combine( + hash, + llvm::hash_combine_range(quantizedDimensions.begin(), + quantizedDimensions.end()), + llvm::hash_combine_range(blockSizes.begin(), blockSizes.end())); + + return hash; + } + }; + + // We pass scales and zeroPoints in directly rather than relying on KeyTy + // because we have to create new reallocated versions in `construct` below. + UniformQuantizedSubChannelTypeStorage(const KeyTy &key, + DenseElementsAttr scales, + DenseElementsAttr zeroPoints, + ArrayRef quantizedDimensions, + ArrayRef blockSizes) + : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, + key.storageTypeMin, key.storageTypeMax), + scales(scales), zeroPoints(zeroPoints), + quantizedDimensions(quantizedDimensions), blockSizes(blockSizes) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. 
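+  /// The quantized-dimension and block-size arrays are copied into the type
+  /// allocator so the uniqued storage owns their memory; the scale and
+  /// zero-point attributes are already uniqued by the context and are stored
+  /// as-is.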
+ static UniformQuantizedSubChannelTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + DenseElementsAttr scales = key.scales; + DenseElementsAttr zeroPoints = key.zeroPoints; + ArrayRef quantizedDimensions = + allocator.copyInto(key.quantizedDimensions); + ArrayRef blockSizes = allocator.copyInto(key.blockSizes); + return new (allocator.allocate()) + UniformQuantizedSubChannelTypeStorage(key, scales, zeroPoints, + quantizedDimensions, blockSizes); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + DenseElementsAttr getScales() const { return scales; } + + DenseElementsAttr getZeroPoints() const { return zeroPoints; } + + ArrayRef getQuantizedDimensions() const { + return quantizedDimensions; + } + + ArrayRef getBlockSizes() const { return blockSizes; } + + DenseElementsAttr scales; + DenseElementsAttr zeroPoints; + ArrayRef quantizedDimensions; + ArrayRef blockSizes; +}; + struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { struct KeyTy { KeyTy(Type expressedType, double min, double max) diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 851763d8942e8..c6a6881b46f26 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -159,38 +159,173 @@ static Type parseAnyType(DialectAsmParser &parser) { typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); } -static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, +/// Checks if the given scale value is within the valid range of the expressed +/// type. The `expressedType` argument is the floating-point type used for +/// expressing the quantized values, and `scale` is the double value to check. +LogicalResult +isScaleInExpressedTypeRange(function_ref emitError, + Type expressedType, double scale) { + auto floatType = cast(expressedType); + double minScale = + APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble(); + double maxScale = + APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); + if (scale < minScale || scale > maxScale) + return emitError() << "scale " << scale << " out of expressed type range [" + << minScale << ", " << maxScale << "]"; + return success(); +} + +/// Parses a quantization parameter, which is either a scale value (float) or a +/// scale-zero point pair (float:integer). `expressedType`, expressing the type +/// of scale values, is used to validate the scale. The parsed scale and zero +/// point (if any) are stored in `scale` and `zeroPoint`. +static ParseResult parseQuantParams(DialectAsmParser &parser, + Type expressedType, double &scale, int64_t &zeroPoint) { - // scale[:zeroPoint]? - // scale. - if (parser.parseFloat(scale)) + + if (parser.parseFloat(scale)) { return failure(); + } + + if (failed(isScaleInExpressedTypeRange( + [&]() { return parser.emitError(parser.getCurrentLocation()); }, + expressedType, scale))) { + return failure(); + } - // zero point. zeroPoint = 0; if (failed(parser.parseOptionalColon())) { - // Default zero point. return success(); } return parser.parseInteger(zeroPoint); } +/// Parses block size information for sub-channel quantization, assuming the +/// leading '{' has already been parsed. The block size information is provided +/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'. +/// +/// The parsed axis indices are stored in `quantizedDimensions`, and the +/// corresponding block sizes are stored in `blockSizes`. 
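+///
+/// For example, parsing the remaining input "0:1, 1:2}" yields
+/// quantizedDimensions = [0, 1] and blockSizes = [1, 2].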
+static ParseResult +parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, + SmallVectorImpl &quantizedDimensions, + SmallVectorImpl &blockSizes) { + // Empty block-sizes info. + if (succeeded(parser.parseOptionalRBrace())) { + return success(); + } + + auto parseBlockSizeElements = [&]() -> ParseResult { + quantizedDimensions.resize(quantizedDimensions.size() + 1); + blockSizes.resize(blockSizes.size() + 1); + if (parser.parseInteger(quantizedDimensions.back()) || + parser.parseColon() || parser.parseInteger(blockSizes.back())) + return failure(); + return success(); + }; + + if (parser.parseCommaSeparatedList(parseBlockSizeElements) || + parser.parseRBrace()) { + return failure(); + } + + return success(); +} + +/// Parses a bracketed list of quantization parameters, returning the dimensions +/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended +/// to the dimensions of the sub-tensors. This function assumes that the initial +/// left brace has already been parsed. For example: +/// +/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success, +/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4] +/// +/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success, +/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1, +/// 9] +/// +/// This function expects all sub-tensors to have the same rank. +static ParseResult +parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, + SmallVectorImpl &scales, + SmallVectorImpl &zeroPoints, + SmallVectorImpl &dims) { + auto checkDims = [&](const SmallVectorImpl &prevDims, + const SmallVectorImpl &newDims) -> ParseResult { + if (prevDims == newDims) + return success(); + return parser.emitError(parser.getCurrentLocation()) + << "tensor literal is invalid; ranks are not consistent " + "between elements"; + }; + + bool first = true; + SmallVector newDims; + unsigned size = 0; + + auto parseOneElement = [&]() -> ParseResult { + SmallVector thisDims; + if (succeeded(parser.parseOptionalLBrace())) { + if (parseQuantParamListUntilRBrace(parser, expressedType, scales, + zeroPoints, thisDims)) + return failure(); + } else { + zeroPoints.resize(zeroPoints.size() + 1); + scales.resize(scales.size() + 1); + if (parseQuantParams(parser, expressedType, scales.back(), + zeroPoints.back())) { + return failure(); + } + } + ++size; + if (!first) + return checkDims(newDims, thisDims); + newDims = thisDims; + first = false; + return success(); + }; + + if (parser.parseCommaSeparatedList(parseOneElement) || parser.parseRBrace()) { + return failure(); + } + + // Return the sublists' dimensions with 'size' prepended. + dims.clear(); + dims.push_back(size); + dims.append(newDims.begin(), newDims.end()); + + return success(); +} + /// Parses a UniformQuantizedType. /// /// uniform_type ::= uniform_per_layer /// | uniform_per_axis +/// | uniform_sub_channel /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec /// `,` scale-zero `>` /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec -/// axis-spec `,` scale-zero-list `>` +/// axis-spec `,` `{` scale-zero-list `}` `>` +/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec +/// block-size-info `,` scale-zero-tensor `>` /// storage-spec ::= storage-type (`<` storage-range `>`)? 
/// storage-range ::= integer-literal `:` integer-literal /// storage-type ::= (`i` | `u`) integer-literal /// expressed-type-spec ::= `:` `f` integer-literal /// axis-spec ::= `:` integer-literal -/// scale-zero ::= float-literal `:` integer-literal -/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` +/// scale-zero ::= scale (`:` zero-point)? +/// scale ::= float-literal +/// zero-point ::= integer-literal +/// scale-zero-list ::= scale-zero (`,` scale-zero)* +/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}` +/// axis-block ::= axis-spec `:` block-size-spec +/// block-size-spec ::= integer-literal +/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list +/// scale-zero-dense-exp ::= `{` +/// scale-zero-tensor (`,` scale-zero-tensor)* +/// `}` static Type parseUniformType(DialectAsmParser &parser) { IntegerType storageType; FloatType expressedType; @@ -198,7 +333,9 @@ static Type parseUniformType(DialectAsmParser &parser) { int64_t storageTypeMin; int64_t storageTypeMax; bool isPerAxis = false; - int32_t quantizedDimension; + bool isSubChannel = false; + SmallVector quantizedDimensions; + SmallVector blockSizes; SmallVector scales; SmallVector zeroPoints; @@ -228,11 +365,22 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } - // Optionally parse quantized dimension for per-axis quantization. + // Optionally parse quantized dimension for per-axis or sub-channel + // quantization. if (succeeded(parser.parseOptionalColon())) { - if (parser.parseInteger(quantizedDimension)) - return nullptr; - isPerAxis = true; + if (succeeded(parser.parseOptionalLBrace())) { + isSubChannel = true; + if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions, + blockSizes)) { + return nullptr; + } + } else { + isPerAxis = true; + quantizedDimensions.resize(1); + if (parser.parseInteger(quantizedDimensions.back())) { + return nullptr; + } + } } // Comma leading into range_spec. @@ -240,26 +388,21 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } - // Parameter specification. - // For per-axis, ranges are in a {} delimitted list. - if (isPerAxis) { - if (parser.parseLBrace()) { - return nullptr; - } - } - - // Parse scales/zeroPoints. - SMLoc scaleZPLoc = parser.getCurrentLocation(); - do { - scales.resize(scales.size() + 1); + // Quantization parameter (scales/zeroPoints) specification. 
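+  // Per-tensor types take a single scale[:zeroPoint] pair; per-axis and
+  // sub-channel types take a brace-enclosed (possibly nested) list, parsed
+  // into flat scale/zero-point vectors together with its shape in `dims`.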
+ bool isPerTensor = !isPerAxis && !isSubChannel; + SmallVector dims; + if (isPerTensor) { zeroPoints.resize(zeroPoints.size() + 1); - if (parseQuantParams(parser, scales.back(), zeroPoints.back())) { + scales.resize(scales.size() + 1); + if (parseQuantParams(parser, expressedType, scales.back(), + zeroPoints.back())) { return nullptr; } - } while (isPerAxis && succeeded(parser.parseOptionalComma())); - if (isPerAxis) { - if (parser.parseRBrace()) { + } else { + if (parser.parseLBrace() || + parseQuantParamListUntilRBrace(parser, expressedType, scales, + zeroPoints, dims)) { return nullptr; } } @@ -268,19 +411,30 @@ static Type parseUniformType(DialectAsmParser &parser) { return nullptr; } - if (!isPerAxis && scales.size() > 1) { - return (parser.emitError(scaleZPLoc, - "multiple scales/zeroPoints provided, but " - "quantizedDimension wasn't specified"), - nullptr); - } - if (isPerAxis) { - ArrayRef scalesRef(scales.begin(), scales.end()); - ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); return parser.getChecked( + typeFlags, storageType, expressedType, scales, zeroPoints, + quantizedDimensions[0], storageTypeMin, storageTypeMax); + } else if (isSubChannel) { + SmallVector apFloatScales = + llvm::to_vector(llvm::map_range(scales, [&](double scale) -> APFloat { + APFloat apFloatScale(scale); + bool unused; + apFloatScale.convert(expressedType.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &unused); + return apFloatScale; + })); + SmallVector apIntZeroPoints = llvm::to_vector( + llvm::map_range(zeroPoints, [&](int64_t zeroPoint) -> APInt { + return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint); + })); + auto scalesRef = mlir::DenseElementsAttr::get( + RankedTensorType::get(dims, expressedType), apFloatScales); + auto zeroPointsRef = mlir::DenseElementsAttr::get( + RankedTensorType::get(dims, storageType), apIntZeroPoints); + return parser.getChecked( typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, - quantizedDimension, storageTypeMin, storageTypeMax); + quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax); } return parser.getChecked( @@ -360,6 +514,17 @@ static void printQuantParams(double scale, int64_t zeroPoint, } } +static void +printBlockSizeInfo(ArrayRef> blockSizeInfo, + DialectAsmPrinter &out) { + out << "{"; + llvm::interleaveComma( + llvm::seq(0, blockSizeInfo.size()), out, [&](size_t index) { + out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second; + }); + out << "}"; +} + /// Helper that prints a AnyQuantizedType. static void printAnyQuantizedType(AnyQuantizedType type, DialectAsmPrinter &out) { @@ -405,6 +570,74 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, out << "}>"; } +/// Prints quantization parameters as a nested list of `scale`[:`zero_point`] +/// elements. The nesting corresponds to the `shape` dimensions. +/// +/// Elements are delimited by commas, and the inner dimensions are enclosed in +/// braces. `zero_point` is only printed if it is non-zero. 
For example: +/// +/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0], +/// zeroPoints=[0, 0, 1, 9], +/// shape=[2, 2]) +/// +/// would print: +/// +/// {{1.0, 2.0}, {3.0:1, 4.0:9}} +void printDenseQuantizationParameters(ArrayRef scales, + ArrayRef zeroPoints, + ArrayRef shape, + DialectAsmPrinter &out) { + int64_t rank = shape.size(); + SmallVector counter(rank, 0); + unsigned openBrackets = 0; + + auto incrementCounterAndDelimit = [&]() { + ++counter[rank - 1]; + for (unsigned i = rank - 1; i > 0; --i) { + if (counter[i] >= shape[i]) { + counter[i] = 0; + ++counter[i - 1]; + --openBrackets; + out << '}'; + } + } + }; + + for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) { + if (idx != 0) + out << ", "; + while (openBrackets++ < rank) + out << '{'; + openBrackets = rank; + out << scales[idx]; + if (zeroPoints[idx] != 0) { + out << ":" << zeroPoints[idx]; + } + incrementCounterAndDelimit(); + } + while (openBrackets-- > 0) + out << '}'; +} + +/// Helper that prints a UniformQuantizedSubChannelType. +static void +printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, + DialectAsmPrinter &out) { + out << "uniform<"; + printStorageType(type, out); + out << ":" << type.getExpressedType() << ":"; + printBlockSizeInfo(type.getBlockSizeInfo(), out); + out << ", "; + + auto scalesItr = type.getScales().getValues(); + auto zeroPointsItr = type.getZeroPoints().getValues(); + SmallVector scales(scalesItr.begin(), scalesItr.end()); + SmallVector zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end()); + printDenseQuantizationParameters(scales, zeroPoints, + type.getScales().getType().getShape(), out); + out << ">"; +} + /// Helper that prints a CalibratedQuantizedType. static void printCalibratedQuantizedType(CalibratedQuantizedType type, DialectAsmPrinter &out) { @@ -421,6 +654,9 @@ void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { printUniformQuantizedType(uniformType, os); else if (auto perAxisType = llvm::dyn_cast(type)) printUniformQuantizedPerAxisType(perAxisType, os); + else if (auto perAxisType = + llvm::dyn_cast(type)) + printUniformQuantizedSubChannelType(perAxisType, os); else if (auto calibratedType = llvm::dyn_cast(type)) printCalibratedQuantizedType(calibratedType, os); else diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt index 2fd4a41999d45..825d11992d309 100644 --- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRQuantTransforms LowerQuantOps.cpp + NormalizeQuantTypes.cpp StripFuncQuantTypes.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 4adeb9218ff8e..c2dbcde1aeba6 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -38,11 +38,11 @@ Type getScalarType(Type inputType) { return inputType; } -// Return the shape of an input value as a list of attributes (static dimensions) -// and values (dynamic dimensions). If 'input' is a scalar, an empty list is -// returned. If 'input' is a tensor, its shape is returned. -SmallVector -getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) { +// Return the shape of an input value as a list of attributes (static +// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty +// list is returned. 
If 'input' is a tensor, its shape is returned. +SmallVector getScalarOrTensorShape(OpBuilder &builder, + Location loc, Value input) { if (isa(input.getType())) return tensor::getMixedSizes(builder, loc, input); return {}; @@ -100,16 +100,16 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, // Turn input size into 1D tensor auto flatShapeType = shape::getExtentTensorType(context, 1); - auto flatInputShape = builder.create( - loc, flatShapeType, inputSize); + auto flatInputShape = + builder.create(loc, flatShapeType, inputSize); // Reshape input tensor into 1D auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto flatInputType = RankedTensorType::get({ShapedType::kDynamic}, elementType); - auto flatInput = builder.create( - loc, flatInputType, input, flatInputShape); + auto flatInput = builder.create(loc, flatInputType, input, + flatInputShape); return std::make_pair(flatInput, inputShape); } @@ -135,11 +135,9 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, // - inputShape // 1D extent tensor containing the shape of the original unranked input. // -std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, - Location loc, - Value input, - int64_t axis, - int64_t axisSize) { +std::pair +flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, + int64_t axis, int64_t axisSize) { // Get full tensor shape auto *context = builder.getContext(); auto indexType = builder.getIndexType(); @@ -149,16 +147,20 @@ std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, // Get shape and sizes on left and right of axis auto axisValue = builder.create(loc, axis); auto axisNextValue = builder.create(loc, axis + 1); - auto shapeLeft = builder.create( - loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) - .getResult(0); - auto sizeLeft = builder.create( - loc, indexType, shapeLeft); - auto shapeRight = builder.create( - loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) - .getResult(1); - auto sizeRight = builder.create( - loc, indexType, shapeRight); + auto shapeLeft = + builder + .create(loc, TypeRange{shapeType, shapeType}, + inputShape, axisValue) + .getResult(0); + auto sizeLeft = + builder.create(loc, indexType, shapeLeft); + auto shapeRight = + builder + .create(loc, TypeRange{shapeType, shapeType}, + inputShape, axisNextValue) + .getResult(1); + auto sizeRight = + builder.create(loc, indexType, shapeRight); // Compute flat input shape as a 3-element 1D tensor auto axisSizeValue = builder.create(loc, axisSize); @@ -171,8 +173,8 @@ std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, auto elementType = inputType.getElementType(); auto flatInputType = RankedTensorType::get( {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); - auto flatInput = builder.create( - loc, flatInputType, input, flatInputShape); + auto flatInput = builder.create(loc, flatInputType, input, + flatInputShape); return std::make_pair(flatInput, inputShape); } @@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto unrankedType = UnrankedTensorType::get(elementType); - return builder.create(loc, unrankedType, input, inputShape); + return builder.create(loc, unrankedType, input, + inputShape); } // Create a tensor constant containing all scales in a per-channel quantized @@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, 
Location loc, auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { return builder.getFloatAttr(expressedType, scale); }); - auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType); + auto tensorType = + RankedTensorType::get({(int64_t)scales.size()}, expressedType); auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); return builder.create(loc, tensorType, scalesAttr); } @@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints( UniformQuantizedPerAxisType quantizedType) { auto zeroPoints = quantizedType.getZeroPoints(); auto storageType = quantizedType.getStorageType(); - auto zeroPointAttrs = llvm::map_to_vector( - zeroPoints, - [&](int64_t zeroPoint) -> Attribute { + auto zeroPointAttrs = + llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute { return builder.getIntegerAttr(storageType, zeroPoint); }); auto tensorType = @@ -239,6 +242,54 @@ Value materializePerChannelZeroPoints( return builder.create(loc, tensorType, zeroPointsAttr); } +// Create a tensor constant containing all scales in a sub-channel quantized +// type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32> +// +Value materializeSubChannelScales( + OpBuilder &builder, Location loc, + UniformQuantizedSubChannelType quantizedType) { + auto scales = quantizedType.getScales(); + auto expressedType = quantizedType.getExpressedType(); + auto scaleAttrs = llvm::map_to_vector( + scales.getValues(), [&](APFloat scale) -> Attribute { + return builder.getFloatAttr(expressedType, scale); + }); + auto tensorType = + RankedTensorType::get(scales.getType().getShape(), expressedType); + auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); + return builder.create(loc, tensorType, scalesAttr); +} + +// Create a tensor constant containing all zero points in a sub-channel +// quantized type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8> +// +Value materializeSubChannelZeroPoints( + OpBuilder &builder, Location loc, + UniformQuantizedSubChannelType quantizedType) { + auto zeroPoints = quantizedType.getZeroPoints(); + auto storageType = quantizedType.getStorageType(); + auto zeroPointAttrs = llvm::map_to_vector( + zeroPoints.getValues(), [&](APInt zeroPoint) -> Attribute { + return builder.getIntegerAttr(storageType, zeroPoint); + }); + auto tensorType = + RankedTensorType::get(zeroPoints.getType().getShape(), storageType); + auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); + return builder.create(loc, tensorType, zeroPointsAttr); +} + // Clamp the given scalar or tensor input using the storage bounds encoded in // the given quantized type, if present. // @@ -299,7 +350,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, return builder.create(loc, resultType, input); } -// Quantize a scalar or ranked tensor value. The stored value is clamped using +// Quantize a scalar or ranked tensor value. The stored value is clamped using // the storage bounds encoded in the given quantized type. // // See function 'convertRanked()' below for a description of the arguments. 
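// As a rough worked example (illustrative values, assuming an i8 storage
// type with the default [-128, 127] range): with scale = 2.0 and
// zeroPoint = 10, an expressed value of 7.0 is lowered as
//
//   7.0 / 2.0            = 3.5    (divide by the scale)
//   3.5 + 10.0           = 13.5   (add the zero point, converted to f32)
//   fptosi(13.5)         = 13     (convert to the i8 storage type)
//   clamp(13, -128, 127) = 13     (clamp to the storage bounds, if present)
//
// Dequantization inverts this: (13 - 10) * 2.0 = 6.0, which shows the
// round-trip is lossy.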
@@ -308,8 +359,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, Value zeroPoint, QuantizedType quantizedType) { // Convert scale to tensor if necessary auto inputType = input.getType(); - scale = getScalarOrTensorConstant( - builder, loc, scale, inputType, inputShape); + scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape); // Scale input auto scaledValue = builder.create(loc, input, scale); @@ -322,8 +372,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, inputShape); // Convert zero point from storage to expressed type - zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, - scale.getType(), + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(), quantizedType.isSigned()); // Add zero point to stored value @@ -334,9 +383,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, // Convert stored value to storage type auto storageScalarOrTensorType = getScalarOrTensorType(quantizedType.getStorageType(), inputType); - auto storedValueInt = convertFloatToInteger( - builder, loc, storedValueFloat, storageScalarOrTensorType, - quantizedType.isSigned()); + auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat, + storageScalarOrTensorType, + quantizedType.isSigned()); // Clamp stored value it if the storage type is bound auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt, @@ -352,12 +401,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input, Value zeroPoint, QuantizedType quantizedType) { // Convert scale to tensor if necessary auto inputType = input.getType(); - scale = getScalarOrTensorConstant( - builder, loc, scale, inputType, inputShape); + scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape); // Convert stored value to float - auto result = convertIntegerToFloat( - builder, loc, input, scale.getType(), quantizedType.isSigned()); + auto result = convertIntegerToFloat(builder, loc, input, scale.getType(), + quantizedType.isSigned()); // Skip unnecessary computations if no zero point is given if (!matchPattern(zeroPoint, m_Zero())) { @@ -366,8 +414,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input, inputShape); // Convert zero point from storage to expressed type - zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, - scale.getType(), + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(), quantizedType.isSigned()); // Subtract zero point to stored value @@ -501,35 +548,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, auto initShape = tensor::getMixedSizes(builder, loc, input); Value init = builder.create(loc, initShape, elementType); - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); + SmallVector iteratorTypes(inputRank, + utils::IteratorType::parallel); auto channelAxisAffineMap = AffineMap::get( inputRank, 0, builder.getAffineDimExpr(channelAxis), context); SmallVector indexingMaps{ - builder.getMultiDimIdentityMap(inputRank), - channelAxisAffineMap, - channelAxisAffineMap, - builder.getMultiDimIdentityMap(inputRank) - }; - auto result = builder.create( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, - iteratorTypes, - [&](OpBuilder& builder, Location loc, ValueRange args) { - assert(args.size() == 4); - auto input = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = 
convertRanked(builder, loc, op, input, {}, scale, - zeroPoint, quantizedType); - - builder.create(loc, result); - }) - .getResult(0); + builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap, + channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)}; + auto result = builder + .create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = + convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + builder.create(loc, result); + }) + .getResult(0); return result; } @@ -551,7 +596,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, // Flatten unranked tensor into a 3D ranked tensor if necessary bool isUnranked = isa(input.getType()); int64_t channelAxis = quantizedType.getQuantizedDimension(); - int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); + int64_t channelAxisSize = (int64_t)quantizedType.getScales().size(); Value inputShape; if (isUnranked) { std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( @@ -570,6 +615,73 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, return result; } +// Convert an operation using sub-channel quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor. +// +// - quantizedType +// Sub-channel quantized type. +// +Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedSubChannelType quantizedType) { + auto *context = builder.getContext(); + + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + auto scales = materializeSubChannelScales(builder, loc, quantizedType); + auto zeroPoints = + materializeSubChannelZeroPoints(builder, loc, quantizedType); + + auto elementType = isa(inputType.getElementType()) + ? quantizedType.getStorageType() + : quantizedType.getExpressedType(); + auto initShape = tensor::getMixedSizes(builder, loc, input); + Value init = builder.create(loc, initShape, elementType); + + SmallVector iteratorTypes(inputRank, + utils::IteratorType::parallel); + const SmallVector> &blockSizeInfo = + quantizedType.getBlockSizeInfo(); + SmallVector affineExprs(inputRank, + builder.getAffineConstantExpr(0)); + for (auto [quantizedDimension, blockSize] : blockSizeInfo) { + affineExprs[quantizedDimension] = + builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize); + } + auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context); + SmallVector indexingMaps{ + builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap, + builder.getMultiDimIdentityMap(inputRank)}; + auto result = builder + .create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = + convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + builder.create(loc, result); + }) + .getResult(0); + + return result; +} + // Convert a quantization operation. 
// // - op @@ -593,11 +705,17 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, return convertPerChannel(builder, loc, op, input, uniformQuantizedPerAxisType); + if (auto uniformQuantizedSubChannelType = + dyn_cast(quantizedType)) + return convertSubChannel(builder, loc, op, input, + uniformQuantizedSubChannelType); + llvm_unreachable("unexpected quantized type"); } // Lowering pattern for 'quant.dcast' -struct DequantizeCastOpConversion : public OpConversionPattern { +struct DequantizeCastOpConversion + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -622,7 +740,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern { +struct QuantizeCastOpConversion + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -650,12 +769,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase { ConversionTarget target(getContext()); target.addLegalOp(); target.addIllegalDialect(); - target.addLegalDialect< - arith::ArithDialect, - linalg::LinalgDialect, - shape::ShapeDialect, - tensor::TensorDialect - >(); + target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -666,10 +781,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase { } // namespace void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) { - patterns.add< - DequantizeCastOpConversion, - QuantizeCastOpConversion - >(patterns.getContext()); + patterns.add( + patterns.getContext()); } } // namespace quant diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp new file mode 100644 index 0000000000000..030cf07794377 --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp @@ -0,0 +1,179 @@ +//===- NormalizeQuantTypes.cpp - Normalize quantized types +//----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Normalize generic quantized types to specific quantized types +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_NORMALIZEQUANTTYPES +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +/// Returns true if the given sub-channel quantized type is convertible to a +/// per-tensor quantized type. This is true if the sub-channel type has only +/// one scale and one zero point. +/// +/// Assumes that `tensorType` is a tensor with element type +/// `quant::UniformQuantizedSubChannelType`. +static bool isConvertibleToPerTensor(TensorType tensorType) { + return cast(tensorType.getElementType()) + .getScales() + .getType() + .getNumElements() == 1; +} + +/// Returns true if the given sub-channel quantized type is convertible to a +/// per-axis quantized type. This is true if the shape of the scales tensor has +/// all but one non-one value. 
+/// +/// Assumes that `tensorType` is a tensor with element type +/// `quant::UniformQuantizedSubChannelType`. +static bool isConvertibleToPerAxis(TensorType tensorType) { + auto shape = cast(tensorType.getElementType()) + .getScales() + .getType() + .getShape(); + return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1; +} + +/// This class defines a type converter that converts sub-channel quantized +/// types to per-tensor or per-axis quantized types whenever possible. +class NormalizedQuantTypesConverter : public TypeConverter { + + static Type convertType(Type type) { + auto tensorType = dyn_cast(type); + if (!tensorType) { + return type; + } + + auto subChannelType = + dyn_cast(tensorType.getElementType()); + if (!subChannelType) { + return type; + } + + if (isConvertibleToPerTensor(tensorType)) { + double scale = + subChannelType.getScales().getValues()[0].convertToDouble(); + int64_t zeroPoint = + subChannelType.getZeroPoints().getValues()[0].getSExtValue(); + auto perTensorType = UniformQuantizedType::get( + subChannelType.getFlags(), subChannelType.getStorageType(), + subChannelType.getExpressedType(), scale, zeroPoint, + subChannelType.getStorageTypeMin(), + subChannelType.getStorageTypeMax()); + return tensorType.clone(perTensorType); + } + + if (isConvertibleToPerAxis(tensorType)) { + auto shape = subChannelType.getScales().getType().getShape(); + auto quantizedDimItr = + llvm::find_if(shape, [](int64_t dim) { return dim != 1; }); + auto scales = llvm::to_vector(llvm::map_range( + subChannelType.getScales().getValues(), + [](APFloat scale) { return scale.convertToDouble(); })); + auto zeroPoints = llvm::to_vector(llvm::map_range( + subChannelType.getZeroPoints().getValues(), + [](APInt zeroPoint) { return zeroPoint.getSExtValue(); })); + auto perAxisType = UniformQuantizedPerAxisType::get( + subChannelType.getFlags(), subChannelType.getStorageType(), + subChannelType.getExpressedType(), scales, zeroPoints, + quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(), + subChannelType.getStorageTypeMax()); + return tensorType.clone(perAxisType); + } + return type; + } + +public: + explicit NormalizedQuantTypesConverter() { addConversion(convertType); } +}; + +/// This class implements a conversion pattern that converts any generic +/// operation with sub-channel quantized types to an equivalent operation with +/// per-tensor or per-axis quantized types. 
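/// As a sketch of the intended effect (the literals below are illustrative,
/// assuming typical i8-in-f32 parameters), a value of type
///
///   tensor<2x4x!quant.uniform<i8:f32:{0:2, 1:4}, {{2.0:10}}>>
///
/// carries a single scale/zero-point pair and normalizes to the per-tensor
/// type
///
///   tensor<2x4x!quant.uniform<i8:f32, 2.0:10>>
///
/// whereas a scales tensor of shape 1x4 (non-unit only along axis 1, with
/// block size 1 on that axis) would normalize to a per-axis type with
/// quantizedDimension = 1.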
+class ConvertGenericOpwithSubChannelType : public ConversionPattern { +public: + ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + SmallVector resultTypes; + if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes))) + return failure(); + + auto *newOp = Operation::create( + op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(), + op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions()); + for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) { + Region &before = std::get<0>(regions); + Region &parent = std::get<1>(regions); + rewriter.inlineRegionBefore(before, parent, parent.end()); + if (failed(rewriter.convertRegionTypes(&parent, *typeConverter))) + return failure(); + } + rewriter.insert(newOp); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// Conversion pass +class NormalizeQuantTypes + : public impl::NormalizeQuantTypesBase { +public: + void runOnOperation() override { + + auto *context = &getContext(); + + NormalizedQuantTypesConverter typeConverter; + ConversionTarget target(*context); + + // Determine legal operations. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + // Register conversion patterns + RewritePatternSet patterns(context); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + patterns.add(typeConverter, context); + + // Apply conversion + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +} // namespace quant +} // namespace mlir diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi index 47168d49c5568..3f5304584edef 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from mlir.ir import Type +from mlir.ir import DenseElementsAttr, Type __all__ = [ "QuantizedType", @@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType): @property def is_fixed_point(self) -> bool: ... +class UniformQuantizedSubChannelType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: DenseElementsAttr, zero_points: DenseElementsAttr, + quantized_dimensions: list[int], block_sizes: list[int], + storage_type_min: int, storage_type_max: int): + ... + + @property + def quantized_dimensions(self) -> list[int]: ... + + @property + def block_sizes(self) -> list[int]: ... + + @property + def scales(self) -> DenseElementsAttr: ... + + @property + def zero_points(self) -> DenseElementsAttr: ... 
def CalibratedQuantizedType(QuantizedType): diff --git a/mlir/test/CAPI/quant.c b/mlir/test/CAPI/quant.c index 0a09e084119f7..30f376ebeb112 100644 --- a/mlir/test/CAPI/quant.c +++ b/mlir/test/CAPI/quant.c @@ -10,6 +10,7 @@ // RUN: mlir-capi-quant-test 2>&1 | FileCheck %s #include "mlir-c/Dialect/Quant.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/IR.h" @@ -203,6 +204,130 @@ void testUniformPerAxisType(MlirContext ctx) { fprintf(stderr, "\n\n"); } +// CHECK-LABEL: testUniformSubChannelType +void testUniformSubChannelType(MlirContext ctx) { + fprintf(stderr, "testUniformSubChannelType\n"); + + MlirType subChannelParsed = + mlirTypeParseGet(ctx, mlirStringRefCreateFromCString( + "!quant.uniform")); + + MlirType i8 = mlirIntegerTypeGet(ctx, 8); + MlirType f32 = mlirF32TypeGet(ctx); + + // block-size information + int32_t quantizedDimensions[] = {0, 1}; + int64_t blockSizes[] = {1, 2}; + int64_t numBlockSizes = 2; + + // quantization parameters + int64_t quantParamShape[] = {2, 2}; + int64_t quantParamRank = 2; + int64_t numQuantizationParams = 4; + MlirAttribute scales[] = {mlirFloatAttrDoubleGet(ctx, f32, 2.0), + mlirFloatAttrDoubleGet(ctx, f32, 3.0), + mlirFloatAttrDoubleGet(ctx, f32, 4.0), + mlirFloatAttrDoubleGet(ctx, f32, 5.0)}; + MlirAttribute zeroPoints[] = { + mlirIntegerAttrGet(i8, 10), mlirIntegerAttrGet(i8, 20), + mlirIntegerAttrGet(i8, 30), mlirIntegerAttrGet(i8, 40)}; + + MlirType scalesType = + mlirRankedTensorTypeGet(quantParamRank, quantParamShape, f32, + /*encoding=*/mlirAttributeGetNull()); + MlirType zeroPointsType = mlirRankedTensorTypeGet( + quantParamRank, quantParamShape, i8, /*encoding=*/mlirAttributeGetNull()); + MlirAttribute denseScalesAttr = + mlirDenseElementsAttrGet(scalesType, numQuantizationParams, scales); + MlirAttribute denseZeroPointsAttr = mlirDenseElementsAttrGet( + zeroPointsType, numQuantizationParams, zeroPoints); + + MlirType subChannel = mlirUniformQuantizedSubChannelTypeGet( + mlirQuantizedTypeGetSignedFlag(), i8, f32, denseScalesAttr, + denseZeroPointsAttr, numBlockSizes, quantizedDimensions, blockSizes, + mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true, + /*integralWidth=*/8), + mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true, + /*integralWidth=*/8)); + + MlirAttribute arrayScalesAttr = + mlirArrayAttrGet(ctx, numQuantizationParams, scales); + MlirAttribute arrayZeroPointsAttr = + mlirArrayAttrGet(ctx, numQuantizationParams, zeroPoints); + MlirType illegalSubChannel = mlirUniformQuantizedSubChannelTypeGet( + mlirQuantizedTypeGetSignedFlag(), i8, f32, arrayScalesAttr, + arrayZeroPointsAttr, numBlockSizes, quantizedDimensions, blockSizes, + mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true, + /*integralWidth=*/8), + mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true, + /*integralWidth=*/8)); + + // CHECK: is null sub-channel type: 1 + fprintf(stderr, "is null sub-channel type: %d\n", + mlirTypeIsNull(illegalSubChannel)); + + // CHECK: num dims: 2 + fprintf(stderr, "num dims: %" PRId64 "\n", + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(subChannel)); + + // CHECK: axis-block-size-pair[0]: 0:1 + fprintf( + stderr, "axis-block-size-pair[0]: %" PRId32 ":%" PRId64 "\n", + mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(subChannel, 0), + mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 0)); + + // CHECK: axis-block-size-pair[1]: 1:2 + fprintf( + stderr, "axis-block-size-pair[1]: %" PRId32 ":%" PRId64 "\n", + 
mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(subChannel, 1), + mlirUniformQuantizedSubChannelTypeGetBlockSize(subChannel, 1)); + + denseScalesAttr = mlirUniformQuantizedSubChannelTypeGetScales(subChannel); + denseZeroPointsAttr = + mlirUniformQuantizedSubChannelTypeGetZeroPoints(subChannel); + scalesType = mlirAttributeGetType(denseScalesAttr); + zeroPointsType = mlirAttributeGetType(denseZeroPointsAttr); + + // CHECK: tensor<2x2xf32> + mlirTypeDump(scalesType); + // CHECK: tensor<2x2xi8> + mlirTypeDump(zeroPointsType); + + // CHECK: number of quantization parameters: 4 + fprintf(stderr, "number of quantization parameters: %" PRId64 "\n", + mlirElementsAttrGetNumElements(denseScalesAttr)); + + // CHECK: quantization-parameter[0]: 2.000000:10 + fprintf(stderr, "quantization-parameter[0]: %lf:%" PRId8 "\n", + mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 0), + mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 0)); + + // CHECK: quantization-parameter[1]: 3.000000:20 + fprintf(stderr, "quantization-parameter[1]: %lf:%" PRId8 "\n", + mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 1), + mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 1)); + + // CHECK: quantization-parameter[2]: 4.000000:30 + fprintf(stderr, "quantization-parameter[2]: %lf:%" PRId8 "\n", + mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 2), + mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 2)); + + // CHECK: quantization-parameter[3]: 5.000000:40 + fprintf(stderr, "quantization-parameter[3]: %lf:%" PRId8 "\n", + mlirDenseElementsAttrGetFloatValue(denseScalesAttr, 3), + mlirDenseElementsAttrGetInt8Value(denseZeroPointsAttr, 3)); + + // CHECK: equal: 1 + fprintf(stderr, "equal: %d\n", mlirTypeEqual(subChannel, subChannelParsed)); + + // CHECK: !quant.uniform + mlirTypeDump(subChannel); + fprintf(stderr, "\n\n"); +} + // CHECK-LABEL: testCalibratedType void testCalibratedType(MlirContext ctx) { fprintf(stderr, "testCalibratedType\n"); @@ -233,6 +358,7 @@ int main(void) { testAnyQuantizedType(ctx); testUniformType(ctx); testUniformPerAxisType(ctx); + testUniformSubChannelType(ctx); testCalibratedType(ctx); mlirContextDestroy(ctx); return EXIT_SUCCESS; diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir index 359a58557087e..8c79b757eeb19 100644 --- a/mlir/test/Dialect/Quant/Bytecode/types.mlir +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -64,3 +64,12 @@ module @parseUniformPerAxisMixed attributes { bytecode.test = !quant.uniform } {} +//===----------------------------------------------------------------------===// +// UniformQuantizedSubChannel +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseUniformSubChannel +module @parseUniformSubChannel attributes { + // CHECK: !quant.uniform + bytecode.test = !quant.uniform +} {} diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir index ba3a8e312d96e..7bb50f352f938 100644 --- a/mlir/test/Dialect/Quant/invalid.mlir +++ b/mlir/test/Dialect/Quant/invalid.mlir @@ -256,3 +256,71 @@ func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) { return } +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_scalar(%arg0: f32) { + // expected-error@+1 {{scalar types may not use sub-channel quantization}} + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_unranked(%arg0: tensor<*xf32>) { + // expected-error@+1 {{tensor 
containing the sub-channel quantized type must be ranked}} + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_invalid_quantized_dimension(%arg0: tensor<2x4xf32>) { + // expected-error@+1 {{quantized dimension 3 must be less than tensor rank 2}} + %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_invalid_tensor_dim_size(%arg0: tensor<2x4xf32>) { + // expected-error@+1 {{tensor dimension size 4 at axis 1 must be divisible by the corresponding block size 3}} + %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_invalid_zero_tensor_dim_size(%arg0: tensor<0x4xf32>) { + // expected-error@+1 {{tensor dimension size of zero is not allowed with sub-channel quantization}} + %0 = quant.qcast %arg0 : tensor<0x4xf32> to tensor<0x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_invalid_scale_dim_size(%arg0: tensor<2x4xf32>) { + // expected-error@+1 {{dimension size 2 of scales tensor at axis 1 should match (tensor dimension at axis / block sizes at axis) = 2}} + %0 = quant.qcast %arg0 : tensor<2x4xf32> to tensor<2x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_sub_channel_invalid_scale_dim_size(%arg0: tensor) { + // expected-error@+1 {{Rank of scales 3 must match the rank of the tensor 2}} + %0 = quant.qcast %arg0 : tensor to tensor + return +} diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir index 6bba9f5c03772..23c34b906dd46 100644 --- a/mlir/test/Dialect/Quant/lower-quant-ops.mlir +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -509,3 +509,67 @@ func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> return %0 : tensor<*x!qalias> } +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2)> + +// CHECK-LABEL: @qcast_sub_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x?x?x4xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<{{.*}}2.000000e+00, 3.000000e+00{{.*}}, {{.*}}4.000000e+00, 5.000000e+00{{.*}}> : tensor<2x1x1x2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<{{.*}}10, 20{{.*}}, {{.*}}30, 40{{.*}}> : tensor<2x1x1x2xi8> + +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<2x?x?x4xf32> +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<2x?x?x4xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<2x?x?x4xi8> + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<2x?x?x4xf32>, tensor<2x1x1x2xf32>, tensor<2x1x1x2xi8>) outs(%[[INIT]] : tensor<2x?x?x4xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], 
%[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor<2x?x?x4xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x?x?x4xi8> to tensor<2x?x?x4x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] + +!qalias = !quant.uniform +func.func @qcast_sub_channel_ranked(%arg0: tensor<2x?x?x4xf32>) -> tensor<2x?x?x4x!qalias> { + %0 = quant.qcast %arg0 : tensor<2x?x?x4xf32> to tensor<2x?x?x4x!qalias> + return %0 : tensor<2x?x?x4x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3 floordiv 2)> + +// CHECK-LABEL: @qcast_sub_channel_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<2x3x5x4xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<{{.*}}2.000000e+00, 3.000000e+00{{.*}}, {{.*}}4.000000e+00, 5.000000e+00{{.*}}> : tensor<2x1x1x2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<{{.*}}10, 20{{.*}}, {{.*}}30, 40{{.*}}> : tensor<2x1x1x2xi8> + +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x3x5x4xi8> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<2x3x5x4xf32>, tensor<2x1x1x2xf32>, tensor<2x1x1x2xi8>) outs(%[[INIT]] : tensor<2x3x5x4xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor<2x3x5x4xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<2x3x5x4xi8> to tensor<2x3x5x4x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] + +!qalias = !quant.uniform +func.func @qcast_sub_channel_ranked_bounds(%arg0: tensor<2x3x5x4xf32>) -> tensor<2x3x5x4x!qalias> { + %0 = quant.qcast %arg0 : tensor<2x3x5x4xf32> to tensor<2x3x5x4x!qalias> + return %0 : tensor<2x3x5x4x!qalias> +} diff --git a/mlir/test/Dialect/Quant/normalize-quant-types.mlir b/mlir/test/Dialect/Quant/normalize-quant-types.mlir new file mode 100644 index 0000000000000..573781c9ecc04 --- /dev/null +++ b/mlir/test/Dialect/Quant/normalize-quant-types.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s --normalize-quant-types --split-input-file | FileCheck %s + +// CHECK-LABEL: @callee( +// CHECK-SAME: [[PER_TENSOR:tensor<\?x\?x!quant.uniform>]], +// CHECK-SAME: [[PER_TENSOR]] +// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]]) +// CHECK-LABEL: @normalize_quant_types_to_per_tensor +// CHECK-SAME: %[[ARG_0:.*]]: [[PER_TENSOR:tensor<\?x\?x!quant.uniform>]], +// CHECK-SAME: %[[ARG_1:.*]]: [[PER_TENSOR]] +// CHECK-SAME: ([[PER_TENSOR]], [[PER_TENSOR]]) +// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]] +// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_TENSOR]]) -> [[PER_TENSOR]] +// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]]) +// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_TENSOR]], [[PER_TENSOR]] + +!qalias1 = !quant.uniform +!qalias2 = !quant.uniform + +func.func private 
@callee(tensor, tensor) -> (tensor, tensor) + +func.func @normalize_quant_types_to_per_tensor(%arg0: tensor, + %arg1: tensor) -> (tensor, tensor) { + %0 = "test.custom_op"(%arg0) : (tensor) -> tensor + %1 = "test.custom_op"(%arg1) : (tensor) -> tensor + %3:2 = func.call @callee(%0, %1) : (tensor, tensor) -> (tensor, tensor) + return %3#0, %3#1 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: @normalize_quant_types_to_per_axis +// CHECK-SAME: %[[ARG_0:.*]]: [[PER_AXIS:tensor<\?x\?x!quant.uniform>]], +// CHECK-SAME: %[[ARG_1:.*]]: [[PER_AXIS]] +// CHECK-SAME: ([[PER_AXIS]], [[PER_AXIS]]) +// CHECK: %[[TEMP_0:.*]] = "test.custom_op"(%[[ARG_0]]) : ([[PER_AXIS]]) -> [[PER_AXIS]] +// CHECK: %[[TEMP_1:.*]] = "test.custom_op"(%[[ARG_1]]) : ([[PER_AXIS]]) -> [[PER_AXIS]] +// CHECK: %[[TEMP_3:.*]]:2 = call @callee(%[[TEMP_0]], %[[TEMP_1]]) +// CHECK: return %[[TEMP_3]]#0, %[[TEMP_3]]#1 : [[PER_AXIS]], [[PER_AXIS]] + +!qalias1 = !quant.uniform +!qalias2 = !quant.uniform + +func.func private @callee(tensor, tensor) -> (tensor, tensor) + +func.func @normalize_quant_types_to_per_axis(%arg0: tensor, + %arg1: tensor) -> (tensor, tensor) { + %0 = "test.custom_op"(%arg0) : (tensor) -> tensor + %1 = "test.custom_op"(%arg1) : (tensor) -> tensor + %3:2 = func.call @callee(%0, %1) : (tensor, tensor) -> (tensor, tensor) + return %3#0, %3#1 : tensor, tensor +} diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir index 4abc5830d081e..33ff93ecbc1d7 100644 --- a/mlir/test/Dialect/Quant/ops.mlir +++ b/mlir/test/Dialect/Quant/ops.mlir @@ -148,4 +148,23 @@ func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) { return } +// ----- + +!qalias = !quant.uniform +func.func @sub_channel_quantization(%arg0: tensor<2x4xi8>) -> tensor<2x4xi8> { + %0 = quant.scast %arg0 : tensor<2x4xi8> to tensor<2x4x!qalias> + %1 = quant.dcast %0 : tensor<2x4x!qalias> to tensor<2x4xf32> + %2 = quant.qcast %1 : tensor<2x4xf32> to tensor<2x4x!qalias> + %3 = quant.scast %2 : tensor<2x4x!qalias> to tensor<2x4xi8> + return %3 : tensor<2x4xi8> +} +// ----- + +!qalias = !quant.uniform +func.func @sub_channel_quantization_with_unknown_dims(%arg0: tensor<2x?xf32>) { + %0 = quant.qcast %arg0 : tensor<2x?xf32> to tensor<2x?x!qalias> + return +} diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir index 4528d2826a850..3b358443e43f2 100644 --- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir @@ -107,7 +107,7 @@ // ----- // Illegal scale: negative -// expected-error@+1 {{scale out of expressed type range}} +// expected-error@+1 {{scale -1.000000e+00 out of expressed type range}} !qalias = !quant.uniform:f32, -1.0:127> // ----- @@ -128,20 +128,110 @@ // ----- // Scale f16 underflow -// expected-error@+1 {{scale out of expressed type range}} +// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}} !qalias = !quant.uniform // ----- // Scale f16 overflow -// expected-error@+1 {{scale out of expressed type range}} +// expected-error@+1 {{scale 6.600000e+04 out of expressed type range}} !qalias = !quant.uniform // ----- // Scale f16 underflow in per-axis quantization -// expected-error@+1 {{scale out of expressed type range}} +// expected-error@+1 {{scale 5.800000e-08 out of expressed type range}} !qalias = !quant.uniform // ----- // Scale f16 overflow in per-axis quantization -// expected-error@+1 {{scale out of expressed type range}} +// expected-error@+1 {{scale 6.600000e+04 out of expressed 
type range}} !qalias = !quant.uniform + +// ----- +// Illegal negative axis in sub-channel quantization +// expected-error@+1 {{illegal quantized dimension: -1}} +!qalias = !quant.uniform + +// ----- +// Illegal zero block-size in sub-channel quantization +// expected-error@+1 {{illegal block size: 0}} +!qalias = !quant.uniform + +// ----- +// Illegal negative block-size in sub-channel quantization +// expected-error@+1 {{illegal block size: -1}} +!qalias = !quant.uniform + +// ----- +// Missing block size in sub-channel quantization +// expected-error@+1 {{expected ':'}} +!qalias = !quant.uniform + +// ----- +// Missing quantization dimension in sub-channel quantization +// expected-error@+1 {{expected integer value}} +!qalias = !quant.uniform + +// ----- +// Invalid tensor literal structure in sub-channel quantization +// expected-error@+2 {{expected '>'}} +!qalias = !quant.uniform + +// ----- +// Ragged tensor literal in sub-channel quantization +// expected-error@+2 {{ranks are not consistent between elements}} +!qalias = !quant.uniform + +// ----- +// Missing braces around block-size information in sub-channel quantization +// expected-error@+1 {{expected ','}} +!qalias = !quant.uniform + +// ----- +// Missing right-brace around block-size information in sub-channel quantization +// expected-error@+1 {{unbalanced '{' character}} +!qalias = !quant.uniform + +// ----- +// Missing left-brace around block-size information in sub-channel quantization +// expected-error@+1 {{unbalanced '<' character}} +!qalias = !quant.uniform + +// ----- +// Missing Axis:BlockSize pair +// expected-error@+1 {{expected integer value}} +!qalias = !quant.uniform + +// ----- +// Missing Scale:ZeroPoint pair +// expected-error@+2 {{expected floating point literal}} +!qalias = !quant.uniform + +// ----- +// Missing ZeroPoint in Scale:ZeroPoint pair +// expected-error@+2 {{expected integer value}} +!qalias = !quant.uniform + +// ----- +// Empty quantization paramaters in sub-channel quantization +// expected-error@+1 {{expected floating point literal}} +!qalias = !quant.uniform + +// ----- +// Scale out of expressed type range in sub-channel quantization +// expected-error@+2 {{scale 6.600000e+04 out of expressed type range}} +!qalias = !quant.uniform + diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir index 4fbe86d935ea3..80a6621ed6979 100644 --- a/mlir/test/Dialect/Quant/parse-uniform.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform.mlir @@ -154,3 +154,21 @@ func.func @parse() -> !qalias { %0 = "foo"() : () -> !qalias return %0 : !qalias } + +// ----- +// Sub-channel scales and zero points (mixed affine and fixedpoint) +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + +// ----- +// Empty block-size information in sub-channel quantization +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py index b1d6e85f519b5..57c528da7b9eb 100644 --- a/mlir/test/python/dialects/quant.py +++ b/mlir/test/python/dialects/quant.py @@ -1,5 +1,6 @@ # RUN: %PYTHON %s | FileCheck %s +import numpy as np from mlir.ir import * from mlir.dialects import quant @@ -18,21 +19,28 @@ def test_type_hierarchy(): any = Type.parse("!quant.any:f32>") uniform = Type.parse("!quant.uniform:f32, 0.99872:127>") per_axis = 
Type.parse("!quant.uniform") + sub_channel = Type.parse( + "!quant.uniform" + ) calibrated = Type.parse("!quant.calibrated>") assert not quant.QuantizedType.isinstance(i8) assert quant.QuantizedType.isinstance(any) assert quant.QuantizedType.isinstance(uniform) assert quant.QuantizedType.isinstance(per_axis) + assert quant.QuantizedType.isinstance(sub_channel) assert quant.QuantizedType.isinstance(calibrated) assert quant.AnyQuantizedType.isinstance(any) assert quant.UniformQuantizedType.isinstance(uniform) assert quant.UniformQuantizedPerAxisType.isinstance(per_axis) + assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel) assert quant.CalibratedQuantizedType.isinstance(calibrated) assert not quant.AnyQuantizedType.isinstance(uniform) assert not quant.UniformQuantizedType.isinstance(per_axis) + assert not quant.UniformQuantizedType.isinstance(sub_channel) + assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel) # CHECK-LABEL: TEST: test_any_quantized_type @@ -121,6 +129,47 @@ def test_uniform_per_axis_type(): assert per_axis == Type.parse("!quant.uniform") +# CHECK-LABEL: TEST: test_uniform_sub_channel_type +@run +def test_uniform_sub_channel_type(): + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + sub_channel = quant.UniformQuantizedSubChannelType.get( + quant.QuantizedType.FLAG_SIGNED, + i8, + f32, + DenseElementsAttr.get( + np.asarray([2.0, 3.0, 4.0, 5.0], np.float32).reshape(2, 2) + ), + DenseElementsAttr.get(np.asarray([10, 20, 30, 40], np.int8).reshape(2, 2)), + [0, 1], + [1, 2], + storage_type_min=quant.QuantizedType.default_minimum_for_integer( + is_signed=True, integral_width=8 + ), + storage_type_max=quant.QuantizedType.default_maximum_for_integer( + is_signed=True, integral_width=8 + ), + ) + + # CHECK: quantized dimensions: [0, 1] + print(f"quantized dimensions: {sub_channel.quantized_dimensions}") + # CHECK: block sizes: [1, 2] + print(f"block sizes: {sub_channel.block_sizes}") + # CHECK: scales: {{\[}}[2. 3.] + # CHECK: [4. 5.]] + print(f"scales: {np.asarray(sub_channel.scales)}") + # CHECK: zero-points: {{\[}}[10 20] + # CHECK: [30 40]] + print(f"zero-points: {np.asarray(sub_channel.zero_points)}") + # CHECK: !quant.uniform + print(sub_channel) + assert sub_channel == Type.parse( + "!quant.uniform" + ) + + # CHECK-LABEL: TEST: test_calibrated_type @run def test_calibrated_type():