diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt index c08f399ee182..9f57627c321f 100644 --- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt @@ -1,6 +1,2 @@ -add_mlir_dialect(QuantOps quant) -add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) - -set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) -mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") -add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt new file mode 100644 index 000000000000..c08f399ee182 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(QuantOps quant) +add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) +mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") +add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h similarity index 56% rename from mlir/include/mlir/Dialect/Quant/QuantOps.h rename to mlir/include/mlir/Dialect/Quant/IR/Quant.h index 14fb3035ab0d..e6b94b10c377 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantOps.h +++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h @@ -1,4 +1,4 @@ -//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===// +//===- Quant.h - Quantization Ops -------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_QUANTOPS_H_ -#define MLIR_DIALECT_QUANT_QUANTOPS_H_ +#ifndef MLIR_DIALECT_QUANT_IR_QUANT_H_ +#define MLIR_DIALECT_QUANT_IR_QUANT_H_ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -19,9 +19,21 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/Quant/QuantOpsDialect.h.inc" +#include "mlir/Dialect/Quant/IR/QuantOpsDialect.h.inc" + +namespace mlir { +namespace quant { + +class QuantizedType; +class UniformQuantizedType; +class UniformQuantizedPerAxisType; +class QuantileQuantizedType; +class QuantileQuantizedPerAxisType; + +} // namespace quant +} // namespace mlir #define GET_OP_CLASSES -#include "mlir/Dialect/Quant/QuantOps.h.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.h.inc" -#endif // MLIR_DIALECT_QUANT_QUANTOPS_H_ +#endif // MLIR_DIALECT_QUANT_IR_QUANT_H_ diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td new file mode 100644 index 000000000000..b761003e0a2a --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantBase.td @@ -0,0 +1,307 @@ +//===- QuantBase.td - Quantization dialect base ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Quantization dialect, types, and traits. +// +//===----------------------------------------------------------------------===// + +#ifndef QUANT_BASE +#define QUANT_BASE + +include "mlir/IR/OpBase.td" + +def Quant_Dialect : Dialect { + let name = "quant"; + let description = [{ + The `quant` dialect offers a framework for defining and manipulating + quantized values. Central to this framework is the `!quant.uniform` data + type, used to represent quantized values. This dialect also provides a + suite of operations to handle and convert quantized values between their + original floating-point representations and the optimized, lower bit-width + integer representations. The `quant` dialect is instrumented with + transformation passes to lower these operations into other core MLIR + dialects, while also flattening all occurrences of quantized types into + their integer counterparts. + + + ## The `!quant.uniform` type + + The quantization process establishes a relationship between two types of + values: an *expressed value* and a *stored value*. The former refers to the + floating-point representation used in an original machine learning model, + capturing the precise numerical characteristics needed for accurate + calculations. The latter is the simplified integer representation that + resides in memory after quantization. The `!quant.uniform` data type + encodes the necessary information for (lossy) round-trip conversion between + an expressed and a stored value. + + The `quant.uniform` type has two variants: per-layer quantization and + per-channel (or per-axis) quantization. In per-layer quantization, the + quantization information affects an entire tensor uniformly. Conversely, in + per-channel quantization, the data type encodes the specific tensor axis + that serves as the channel and includes quantization information for each + individual channel within the tensor. Below are the specific syntactic and + semantic considerations for each modality. + + + ### Per-layer quantization + + This is the general syntax of the `!quant.uniform` type representing + per-layer quantization: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `,` + scale (`:` zeroPoint)? + `>` + ``` + + The type contains the following parameters: + + - `storedType`: Integer type of the value stored in memory. This type + conveys the bit width and signedness of the quantized stored value. + Signed integer types are represented as `'i' bitWidth` (e.g., `i8`), + while unsigned integer types are represented as `'u' bitWidth` (e.g., + `u8`). + + - `storageMin`, `storageMax`: Optional bounds for the stored value. If + given, they must be within the range of `storedType`. If omitted, the + entire range of `storedType` is allowed (e.g., `-128...127` for `i8` or + `0...255` for `u8`). + + - `expressedType`: Floating-point type of the value expressed by this + quantized type (e.g., `f32`, `f80`, `bf16`, or `tf32`). + + - `scale`: Floating-point value of type `expressedType` used in the + conversion between stored and expressed values. + + - `zeroPoint`: Optional integer value of type `storageType` used in the + conversion between stored and expressed values. If omitted, the default + is 0. + + Type conversions, rounding methods, and clamping actions aside, the + relationship between the expressed and stored values as encoded in a + quantized type is denoted by the following formula: + + $$ + expressedValue = (storedValue ~-~ zeroPoint) ~\times~ scale + $$ + + Operations `quant.qcast` (quantize cast) and `quant.dcast` (dequantize + cast) can be used to quantize a floating-point value and dequantize a + stored value, respectively. See the documentation for these operations for + details on how the quantization and dequantization processes are influenced + by the `!quant.uniform` type parameters. + + Here are some examples of the use of `!quant.uniform` with per-layer + quantization: + + ``` + // An 8-bit signed integer type is used to represent a 32-bit float. No + // clamping information is provided, so the full [-128, 127] range is + // available. The scale is set to 3.0, and the zero point takes its default + // 0 value. + !quant.uniform + + // A 16-bit unsigned integer type is used to represent a 32-bit float. Out + // of the 16 bits, only 10 are used, acoording to the 0..1023 clamping + // range. The type sets the scale to 1.23 and the zero point to 512. + !quant.uniform:f32, 1.23:512> + ``` + + ### Per-channel quantization + + The general syntax of the `!quant.uniform` type representing per-channel + quantization is as follows: + + ``` + `!quant.uniform` `<` + storedType (`<` storageMin `:` storageMax `>`)? `:` + expressedType `:` + channelAxis `,` + `{` + scale0 (`:` zeroPoint0)? `,` + scale1 (`:` zeroPoint1)? ... + '}' + `>` + ``` + + In this data type, there are multiple pairs of `scale` and `zeroPoint` + values. The `channelAxis` field represents the dimension of the containing + tensor acting as the channel. The size of the tensor along this dimension + is expected to match the number of provided `scale`-`zeroPoint` pairs, and + a given pair *i* applies to all elements in the tensor whose index along + dimension `channelAxis` is *i*. A quantized data type using per-channel + quantization is always expected to be contained within a tensor type. + + Here are some examples: + + ``` + // A 2x3x4 tensor contains 8-bit signed integers representing 32-bit + // floats. Dimension 1 of the tensor acts as the channel dimension. Its + // size 3 matches the number of provided scale values. Tensor elemenets at + // positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and + // 5.0, respectively. + tensor<2x3x4x!quant.uniform> + + // A 2D dynamically sized tensor contains 16-bit unsigned integers + // representing 32-bit floats. Dimension 0 of the tensor acts as the + // channel dimension. Since 2 scale and zero-point values are provided, the + // size of dimension 0 is expected to be 2 at runtime. Tensor elements + // [0][*] use scale 2.0 and zero point 10, while elements [1][*] use scale + // 3.0 and zero point 20. + tensor> + ``` + + + ## Per-axis quantization integrity + + When type `!quant.uniform` contains per-axis quantization information, the + rules below are enforced. These rules guarantee that the quantization + information encoded in the data type is applicable to the context in which + the quantized type is used. For efficiency, these rules are actively + enforced by the verifiers of `quant` dialect ops, but they must be + respected in any context in which the `!quant.uniform` data type is used, + such as the header of a `func.func` op, or the input of an arithmetic + operation. + + - A quantized type with per-channel quantization information must be the + element type of a tensor container type, and may not occur directly as + the data type of a scalar value. + + ``` + // Incorrect. Type !quant.uniform specifies per-channel quantization for a + // scalar type. + %result = quant.qcast %input : f32 to !quant.uniform + + // Correct. Type `!quant.uniform` with per-channel quantization is wrapped + // in a `tensor` type. + %result = quant.qcast %input : tensor<2xf32> to tensor<2x!quant.uniform> + ``` + + - If the tensor containing the `!quant.uniform` type is ranked, its rank + must be greater than the channel axis specified in the quantized type. + + ``` + // Incorrect. The tensor rank (2) is not greater than the channel axis in + // the quantized type (3). + %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform> + + // Correct. The tensor rank (2) is now greater than the channel axis (1): + %result = quant.qcast %input : tensor<1x2xf32> to tensor<1x2x!quant.uniform> + ``` + + - If the axis dimension in the containing tensor is static, its size must + be equal to the number of scales present in the quantized type. + + ``` + // Incorrect. The channel axis is 1, and the size of dimension 1 in the + // containing tensor is 3. However, there are 4 scale values present in the + // quantized type. + %result = quant.qcast %input : tensor to tensor> + + // Correct. The quantized type now includes 3 scale values, matching the + // size of dimension 1 of the result tensor. + %result = quant.qcast %input : tensor to tensor> + ``` + }]; + let cppNamespace = "::mlir::quant"; + let useDefaultTypePrinterParser = 1; +} + + +//===----------------------------------------------------------------------===// +// Type predicates +//===----------------------------------------------------------------------===// + +class quant_ScalarOrTensorOf : + Type.predicate]>, + "scalar or tensor of " # etype.summary>; + +def quant_QuantizedType : + Type($_self)">, "quantized type">; + +def quant_ScalarType : + Type, + "signless integer, float, or quantized scalar">; + +def quant_IntegerOrQuantizedType : + Type, + "signless integer or quantized type">; + +def quant_FloatScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_IntegerScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_QuantizedScalarOrTensor : + quant_ScalarOrTensorOf; + +def quant_IntegerOrQuantizedScalarOrTensor : + quant_ScalarOrTensorOf; + +// An implementation of QuantileQuantizedType. +def quant_QuantileQuantizedType : + DialectType($_self)">, + "QuantileQuantizedType">; + +// An implementation of QuantileQuantizedPerAxisType. +def quant_QuantileQuantizedPerAxisType : + DialectType($_self)">, + "QuantileQuantizedPerAxisType">; +//===----------------------------------------------------------------------===// +// Traits +//===----------------------------------------------------------------------===// + +def quant_SameScalarOrTensorShape : + PredOpTrait< + "input and result are both scalars or both tensors with matching shape", + Or<[ + And<[ + TypeIsPred<"input", quant_ScalarType>, + TypeIsPred<"result", quant_ScalarType> + ]>, + And<[ + TypeIsPred<"input", AnyUnrankedTensor>, + TypeIsPred<"result", AnyUnrankedTensor> + ]>, + And<[ + TypeIsPred<"input", AnyRankedTensor>, + TypeIsPred<"result", AnyRankedTensor>, + AllShapesMatch<["input", "result"]>.predicate + ]> + ]> + >; + +def quant_IntegerAndQuantizedCombination : + PredOpTrait< + "input must be integer and result must be quantized, or vice versa", + Or<[ + And<[ + TypeIsPred<"input", quant_QuantizedScalarOrTensor>, + TypeIsPred<"result", quant_IntegerScalarOrTensor> + ]>, + And<[ + TypeIsPred<"input", quant_IntegerScalarOrTensor>, + TypeIsPred<"result", quant_QuantizedScalarOrTensor> + ]> + ]> + >; + +#endif // QUANT_BASE diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td similarity index 99% rename from mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td rename to mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td index 6c1e2b01f4ca..0c7430c5f19a 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td @@ -130,4 +130,4 @@ def QuantDialectTypes : DialectTypes<"Quant"> { ]; } -#endif // QUANT_BYTECODE \ No newline at end of file +#endif // QUANT_BYTECODE diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td new file mode 100644 index 000000000000..6ef925146dce --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantOps.td @@ -0,0 +1,243 @@ +//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the operation definition file for Quantization. +// +//===----------------------------------------------------------------------===// + +#ifndef QUANT_OPS +#define QUANT_OPS + +include "mlir/Dialect/Quant/IR/QuantBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// Base classes +//===----------------------------------------------------------------------===// + +class quant_Op traits> : + Op; + +//===----------------------------------------------------------------------===// +// Quantization casts +//===----------------------------------------------------------------------===// + +def quant_DequantizeCastOp : quant_Op<"dcast", [ + Pure, + quant_SameScalarOrTensorShape]> { + let summary = "Dequantize cast operation"; + let description = [{ + Convert an input quantized value into its expressed floating-point value. + The dequantization process consists of the following steps: + + ``` + def dequantize(quantizedValue: quantizedType) -> expressedType: + storedValue = reinterpretCast(quantizedValue, storageType) + storedValueFloat = convertIntToFloat(storedValue, expressedType) + zeroPointFloat = convertIntToFloat(zeroPoint, expressedType) + expressedValue = (storedValueFloat - zeroPointFloat) * scale + return expressedValue + ``` + + Here, `storageType`, `expressedType`, `scale`, and `zeroPoint` are obtained + from the corresponding parameters encoded in `quantizedType`. For + per-channel quantization, the appropriate `scale` and `zeroPoint` values + are used for each tensor element computation according to the channel the + element belongs to. + + The numerical results produced by the algorithm above may vary depending on + the rounding methods used by `convertIntToFloat()`, subtraction (`-`), and + multiplication (`*`). This operation does not define specific rounding + methods; instead, it is the responsibility of a transform pipeline to + determine which rounding method to apply when this operation is broken down + into lower-level dialects. + + The operation must satisfy the following syntactic constraints: + + - Operand `input` must be a scalar or tensor of type `!quant.uniform`. + + - The result type must be a floating-point scalar or tensor. + + - The `expressedType` parameter of the `!quant.uniform` type of the input + must match the floating-point type of the result. + + - The operand and result types must be both scalars or both tensors. If + tensors, they must be both ranked or both unranked. If ranked, both must + have the same shape, including matching static and dynamic dimensions. + + - If the operand uses per-channel quantization, its `!quant.uniform` type + must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Dequantize a scalar quantized value + %result = quant.dcast %input : !quant.uniform to f32 + + // Dequantize a dynamically shaped tensor of quantized values + %result = quant.dcast %input : tensor> to tensor + + // Dequantize an unranked tensor using per-axis quantization information + %result = quant.dcast %input : tensor<*x!quant.uniform> to tensor<*xf32> + ``` + }]; + let arguments = (ins quant_QuantizedScalarOrTensor:$input); + let results = (outs quant_FloatScalarOrTensor:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// Return the float type of the scalar or tensor result. + FloatType getFloatType(); + + /// Return the quantized type of the scalar or tensor input. + quant::QuantizedType getQuantizedType(); + }]; +} + +def quant_QuantizeCastOp : quant_Op<"qcast", [ + Pure, + quant_SameScalarOrTensorShape]> { + let summary = "Quantize cast operation"; + let description = [{ + Convert a floating-point value to a quantized type. The quantization + process consists of the following steps: + + ``` + def quantize(expressedValue: expressedType) -> quantizedType: + zeroPointFloat = convertIntToFloat(zeroPoint, expressedType) + scaledValue = expressedValue / scale + storedValueFloat = scaledValue + zeroPointFloat + storedValue = convertFloatToInt(storedValueFloat, storageType) + storedValueClamped = clamp(storedValue, storageMin, storageMax) + quantizedValue = reinterpretCast(storedValueClamped, quantizedType) + return quantizedValue + ``` + + Here, `storageType`, `storageMin`, `storageMax`, `expressedType`, `scale`, + and `zeroPoint` are obtained from the corresponding parameters encoded in + `quantizedType`. For per-channel quantization, the appropriate `scale` and + `zeroPoint` values are used for each tensor element computation according + to the channel the element belongs to. + + The numerical results produced by the algorithm above may vary depending on + the rounding methods used by `convertIntToFloat()`, `convertFloatToInt()`, + `clamp()`, division (`/`), and addition (`+`). This operation does not + define specific rounding methods; instead, it is the responsibility of a + transform pipeline to determine which rounding method to apply when this + operation is broken down into lower-level dialects. + + The operation must satisfy the following syntactic constraints: + + - Operand `input` must be a floating-point scalar or tensor. + + - The result type must be a scalar or tensor of type `!quant.uniform`. + + - The `expressedType` parameter in the `!quant.uniform` type of the result + must match the floating-point type of the input. + + - The operand and result types must be both scalars or both tensors. If + tensors, they must be both ranked or both unranked. If ranked, both must + have the same shape, including matching static and dynamic dimensions. + + - If the result uses per-channel quantization, its `!quant.uniform` type + must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Quantize a scalar floating-point value + %result = quant.qcast %input : f32 to !quant.uniform + + // Quantize a dynamically shaped tensor of quantized values + %result = quant.qcast %input : tensor to tensor> + + // Quantize an unranked tensor using per-axis quantization information + %result = quant.qcast %input : tensor<*xf32> to tensor<*x!quant.uniform> + ``` + }]; + let arguments = (ins quant_FloatScalarOrTensor:$input); + let results = (outs quant_QuantizedScalarOrTensor:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// Return the float type of the scalar or tensor input. + FloatType getFloatType(); + + /// Return the quantized type of the scalar or tensor result. + quant::QuantizedType getQuantizedType(); + }]; +} + +def quant_StorageCastOp : quant_Op<"scast", [ + Pure, + quant_SameScalarOrTensorShape, + quant_IntegerAndQuantizedCombination]> { + let summary = "Storage cast operation"; + let description = [{ + Convert a value from a quantized type to the corresponding signless integer + storage type, or vice versa. This conversion simply involves a + reinterpretation of the input bits and does not involve any data + manipulation. + + The following syntactic restrictions must be met: + + - Operand `input` must be a scalar or tensor of a signless integer or + `!quant.uniform` type. + + - The result must be a scalar or tensor of a signless integer or + `!quant.uniform` type. + + - If the operand is a scalar or tensor of type integer, the result must be + a scalar or tensor of type `!quant.uniform`, and vice versa. + + - The operand and result must be both scalars or both tensors. If tensors, + they must be both ranked or both unranked. If ranked, both must have the + same shape, including matching static and dynamic dimensions. + + - The width of the `storageType` parameter of the quantized type of the + operand or result must match the width of the signless integer type of + the operand or result. + + - If the operand or result uses per-channel quantization, its + `!quant.uniform` type must adhere to the [Per-axis quantization + integrity](#per-axis-quantization-integrity) guidelines. + + Examples: + + ``` + // Cast a scalar quantized value into its storage type + %result = quant.scast %input : !quant.uniform to i8 + + // Cast a dynamically shaped tensor of quantized values into their storage type + %result = quant.scast %input : tensor> to tensor + + // Cast an unranked tensor of signless integers into a quantized type using + // per-channel quantization + %result = quant.scast %input : tensor<*xi8> to tensor<*x!quant.uniform> + ``` + }]; + let arguments = (ins quant_IntegerOrQuantizedScalarOrTensor:$input); + let results = (outs quant_IntegerOrQuantizedScalarOrTensor:$result); + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; + let extraClassDeclaration = [{ + /// Return the integer type used either in the input or the result. + IntegerType getIntegerType(); + + /// Return the quantized type used either in the input or the result. + quant::QuantizedType getQuantizedType(); + }]; +} + +#endif // QUANT_OPS diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h similarity index 97% rename from mlir/include/mlir/Dialect/Quant/QuantTypes.h rename to mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index 6a6d3a54891c..58452fa5ed0b 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_QUANTTYPES_H -#define MLIR_DIALECT_QUANT_QUANTTYPES_H +#ifndef MLIR_DIALECT_QUANT_IR_QUANTTYPES_H +#define MLIR_DIALECT_QUANT_IR_QUANTTYPES_H #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -128,6 +128,10 @@ class QuantizedType : public Type { /// The maximum value that storageType can take. int64_t getStorageTypeMax() const; + /// Return whether the storage type has explicit min or max boundaries + /// different from the minimum and maximum representable values. + bool hasStorageTypeBounds() const; + /// Gets the integral bit width that the underlying storage type can exactly /// represent. For integral storage types, this will just be their width. unsigned getStorageTypeIntegralWidth() const; @@ -296,8 +300,6 @@ class UniformQuantizedType int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); - static bool classof(mlir::Type type); - /// Gets the scale term. The scale designates the difference between the real /// values corresponding to consecutive quantized values differing by 1. double getScale() const; @@ -361,8 +363,6 @@ class UniformQuantizedPerAxisType int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); - static bool classof(mlir::Type type); - /// Gets the quantization scales. The scales designate the difference between /// the real values corresponding to consecutive quantized values differing /// by 1. The ith scale corresponds to the ith slice in the @@ -397,8 +397,9 @@ class UniformQuantizedPerAxisType }; /// QuantileQuantizedType derives from UniformQuantizedType and adds to it a -/// look up table array of quantile values. The type of the data in the look up table is determined by -/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. +/// look up table array of quantile values. The type of the data in the look up +/// table is determined by the quantileType member: supported quantileType types +/// are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. /// /// Syntax synopsis: /// Per-layer, all parameters expressed: @@ -464,8 +465,9 @@ class QuantileQuantizedType }; /// Represents per-axis QuantileQuantizedType (also known as per-channel -/// quantization). The type of the data in the look up table is determined by the -/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. +/// quantization). The type of the data in the look up table is determined by +/// the quantileType member: supported quantileType types are +/// integer/unsigned/hf8/bf8/f16/bf16/f32/f64. /// /// Syntax synopsis: /// Per-axis, all parameters expressed: @@ -572,4 +574,4 @@ class CalibratedQuantizedType } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_QUANTTYPES_H +#endif // MLIR_DIALECT_QUANT_IR_QUANTTYPES_H diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td deleted file mode 100644 index 7937265ce2f2..000000000000 --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ /dev/null @@ -1,103 +0,0 @@ -//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the operation definition file for Quantization. -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_QUANT_QUANT_OPS_ -#define DIALECT_QUANT_QUANT_OPS_ - -include "mlir/Dialect/Quant/QuantOpsBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -//===----------------------------------------------------------------------===// -// Base classes -//===----------------------------------------------------------------------===// - -class quant_Op traits> : - Op; - -//===----------------------------------------------------------------------===// -// Quantization casts -//===----------------------------------------------------------------------===// - -def quant_QuantizeCastOp : quant_Op<"qcast", [Pure]> { - let summary = "convert a quantizable type to a quantized type"; - let description = [{ - A QuantizeCast `qcast` represents a potential type shift from a quantizable - type to a quantized type. - - At runtime, a `qcast` will apply the transformation expressed by its - operand and result type. For flexibility during transformation, it is also - possible to have a `qcast` that performs no transformation (both its - operand and result type are quantizable). - - A `qcast` will typically originate from either: - a) An expressed or implied constraint in the source dialect which signals - that a certain level of quantization is possible or required. - b) An inference made by a quantization algorithm indicating that a - quantized representation may be acceptable. - - Especially early in transformation, it is common to have pairs of - `qcast` and `dcast` at points where a transition to a quantized type is - required. In addition, it is also common to have an identity `qcast` - (where the operand and result type are not quantized) at all points where - it is legal to use a quantized representation (but is not known to be - acceptable). - }]; - let arguments = (ins quant_RealValueType:$arg); - let results = (outs quant_RealValueType:$res); -} - -def quant_DequantizeCastOp : quant_Op<"dcast", [Pure]> { - let summary = "convert back from a quantized to quantizable (expressed) type operation"; - let description = [{ - A DequantizeCast op `dcast` represents the inverse of a `qcast`, - converting back from a quantized to quantizable (expressed) type. - - Like `qcast`s, a `dcast` is allowed to have both its operand and result - as non quantized types. This facilitates transformations and marks edges - where the computation must be carried out in the expressed type. - - Especially early in transformation, it is common to have `dcast`s on - all operands to ops that must operate with the expressed type (typically - math ops prior to lowering to target-specific, quantized kernels). - }]; - let arguments = (ins quant_RealValueType:$arg); - let results = (outs quant_RealValueType:$res); -} - -def quant_StorageCastOp : quant_Op<"scast", [Pure]> { - let summary = "cast from or to a type based on the storage type and the corresponding quantized type"; - let description = [{ - A StorageCast `scast` represents a cast from or to a type based on the - storage type and a type based on a corresponding quantized type. - - This op exists to ensure type coherency for between parts of the computation - which are operating directly on an underlying storage type and those which - operate on quantized values. - - Examples from storage to quantized type: - ``` - i8 -> !quant<"uniform[i8:f32]{1.0}"> - ``` - ``` - tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - ``` - ``` - vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> - ``` - }]; - let arguments = (ins quant_RealOrStorageValueType:$arg); - let results = (outs quant_RealOrStorageValueType:$res); - let hasFolder = 1; -} - -#endif // DIALECT_QUANT_QUANT_OPS_ diff --git a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td b/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td deleted file mode 100644 index 820219c1ed17..000000000000 --- a/mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ /dev/null @@ -1,104 +0,0 @@ -//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Predicates for types in the Quantization dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_QUANT_QUANT_OPS_BASE_ -#define DIALECT_QUANT_QUANT_OPS_BASE_ - -include "mlir/IR/OpBase.td" - -def Quantization_Dialect : Dialect { - let name = "quant"; - let cppNamespace = "::mlir::quant"; - - let useDefaultTypePrinterParser = 1; -} - -//===----------------------------------------------------------------------===// -// Quantization type definitions -//===----------------------------------------------------------------------===// - -class quant_TypedPrimitiveOrContainer : - Type.predicate, - VectorOf<[etype]>.predicate]>, - "primitive/tensor/vector of " # etype.summary>; - -// An implementation of QuantizedType. -def quant_QuantizedType : - Type($_self)">, "QuantizedType">; - -// A primitive type that can represent a real value. This is either a -// floating point value or a quantized type. -def quant_RealPrimitiveType : - Type, - "real valued primitive (float or quantized type)">; - -// A primitive type that can represent a storage value. This is either an -// integer or quantized type. -def quant_StoragePrimitiveType : - Type, - "quantized storage primitive (integer or quantized type)">; - -// A primitive or container of RealPrimitiveType. -def quant_RealValueType : - quant_TypedPrimitiveOrContainer; - -// A primitive or container of StoragePrimitiveType. -def quant_StorageValueType : - quant_TypedPrimitiveOrContainer; - -// Either a real valued or storage primitive or container type. -def quant_RealOrStorageValueType : - Type, - "real valued or storage primitive or container type">; - -// An implementation of UniformQuantizedType. -def quant_UniformQuantizedType : - DialectType($_self)">, - "UniformQuantizedType">; - -// An implementation of UniformQuantizedPerAxisType. -def quant_UniformQuantizedPerAxisType : - DialectType($_self)">, - "UniformQuantizedPerAxisType">; - -// An implementation of QuantileQuantizedType. -def quant_QuantileQuantizedType : - DialectType($_self)">, - "QuantileQuantizedType">; - -// An implementation of QuantileQuantizedPerAxisType. -def quant_QuantileQuantizedPerAxisType : - DialectType($_self)">, - "QuantileQuantizedPerAxisType">; - -// Predicate for detecting a container or primitive of UniformQuantizedType. -def quant_UniformQuantizedValueType : - quant_TypedPrimitiveOrContainer; - -// Predicate for detecting a container or primitive of UniformQuantizedPerAxisType. -def quant_UniformQuantizedPerAxisValueType : - quant_TypedPrimitiveOrContainer; - -// Predicate for detecting a container or primitive of QuantileQuantizedType. -def quant_QuantileQuantizedValueType : - quant_TypedPrimitiveOrContainer; - -// Predicate for detecting a container or primitive of QuantileQuantizedPerAxisType. -def quant_QuantileQuantizedPerAxisValueType : - quant_TypedPrimitiveOrContainer; - -#endif // DIALECT_QUANT_QUANT_OPS_BASE_ diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..30f7c1696bdb --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant) +add_public_tablegen_target(MLIRQuantTransformsIncGen) + +add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h new file mode 100644 index 000000000000..84be2a21b34e --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.h @@ -0,0 +1,29 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DECL +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +void populateLowerQuantOpsPatterns(RewritePatternSet &patterns); + +} // namespace quant +} // namespace mlir + +#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td new file mode 100644 index 000000000000..b25296d4db5a --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/Transforms/Passes.td @@ -0,0 +1,49 @@ +//===-- Passes.td - Arith pass definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_QUANT_TRANSFORMS_PASSES +#define MLIR_DIALECT_QUANT_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def LowerQuantOps : Pass<"lower-quant-ops", "func::FuncOp"> { + let summary = "Lower quant.dcast and quant.qcast ops"; + let description = [{ + Lower quantization (`quant.qcast`) and dequantization (`quant.dcast`) ops + into other core dialects. + + The lowering process generates storage type casts in the form of + `quant.scast` ops to act as an interface between the original quantized + types of operands and results and their corresponding storage types used in + the generated arithmetic computations. + }]; + let dependentDialects = [ + "arith::ArithDialect", + "linalg::LinalgDialect", + "quant::QuantDialect", + "shape::ShapeDialect", + "tensor::TensorDialect" + ]; +} + +def StripFuncQuantTypes : Pass<"strip-func-quant-types"> { + let summary = "Strip quantized types from function headers"; + let description = [{ + Identify occurrences of function arguments using a quantized type and + replace them with a new value of the corresponding storage (signless + integer) type. For each converted argument, a `quant.scast` op is introduced + at the head of the function's entry block converting the new integer + argument into the original quantized value. + }]; + let dependentDialects = [ + "func::FuncDialect", + "quant::QuantDialect" + ]; +} + +#endif // MLIR_DIALECT_QUANT_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h similarity index 93% rename from mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h rename to mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h index 367d468b2acf..6551efc6242a 100644 --- a/mlir/include/mlir/Dialect/Quant/FakeQuantSupport.h +++ b/mlir/include/mlir/Dialect/Quant/Utils/FakeQuantSupport.h @@ -34,10 +34,10 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ -#define MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ +#ifndef MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ +#define MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" namespace mlir { namespace quant { @@ -64,4 +64,4 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_FAKEQUANTSUPPORT_H_ +#endif // MLIR_DIALECT_QUANT_UTILS_FAKEQUANTSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h similarity index 97% rename from mlir/include/mlir/Dialect/Quant/UniformSupport.h rename to mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h index 4119aced4c07..6773f45069c8 100644 --- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h +++ b/mlir/include/mlir/Dialect/Quant/Utils/UniformSupport.h @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ -#define MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ +#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ +#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ #include -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "llvm/ADT/APFloat.h" @@ -218,4 +218,4 @@ class UniformQuantizedPerAxisValueConverter { } // namespace quant } // namespace mlir -#endif // MLIR_DIALECT_QUANT_UNIFORMSUPPORT_H_ +#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_ diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 1412c7a2615d..df91ba51a059 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -40,7 +40,7 @@ def Tosa_Dialect : Dialect { there will be tools to lower from the ML frameworks into TOSA. }]; - let dependentDialects = ["tensor::TensorDialect", "quant::QuantizationDialect"]; + let dependentDialects = ["tensor::TensorDialect", "quant::QuantDialect"]; let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 298c97015fe2..5e80745777b3 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -16,8 +16,8 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/Utils/UniformSupport.h" namespace mlir { namespace tosa { diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 19a62cadaa2e..7377d490a1dc 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -60,7 +60,7 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" @@ -127,7 +127,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { omp::OpenMPDialect, pdl::PDLDialect, pdl_interp::PDLInterpDialect, - quant::QuantizationDialect, + quant::QuantDialect, ROCDL::ROCDLDialect, scf::SCFDialect, shape::ShapeDialect, diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 28dc3cc23daf..2f8c3d2b471a 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Mesh/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" @@ -79,6 +80,7 @@ inline void registerAllPasses() { memref::registerMemRefPasses(); mesh::registerMeshPasses(); ml_program::registerMLProgramPasses(); + quant::registerQuantPasses(); registerSCFPasses(); registerShapePasses(); spirv::registerSPIRVPasses(); diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 0a7181d8bc17..c94dbb5692fd 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -8,12 +8,12 @@ #include "mlir-c/Dialect/Quant.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" using namespace mlir; -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) //===---------------------------------------------------------------------===// // QuantizedType diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt index 037bba8dcb5c..31167e6af908 100644 --- a/mlir/lib/Dialect/Quant/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp index ee6be2554131..fd9973e00260 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp @@ -9,8 +9,8 @@ #include "QuantDialectBytecode.h" #include "mlir/Bytecode/BytecodeImplementation.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/APFloat.h" @@ -32,7 +32,7 @@ static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader, return success(); } -#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantDialectBytecode.cpp.inc" /// This class implements the bytecode interface for the Quant dialect. struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { @@ -65,6 +65,6 @@ struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { }; } // namespace -void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) { +void quant::detail::addBytecodeInterface(QuantDialect *dialect) { dialect->addInterfaces(); } diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h index 9e9cbf66d84d..eef2b5bbefec 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h @@ -15,12 +15,12 @@ #define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H namespace mlir::quant { -class QuantizationDialect; +class QuantDialect; namespace detail { /// Add the interfaces necessary for encoding the quantization dialect /// components in bytecode. -void addBytecodeInterface(QuantizationDialect *dialect); +void addBytecodeInterface(QuantDialect *dialect); } // namespace detail } // namespace mlir::quant diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index 124733286ce8..ae4478919a49 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -6,45 +6,205 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" #include "QuantDialectBytecode.h" #include "TypeDetail.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/MathExtras.h" -#include +#include "mlir/IR/TypeUtilities.h" -using namespace mlir; -using namespace mlir::quant; -using namespace mlir::quant::detail; +#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc" -#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc" +namespace mlir { +namespace quant { -void QuantizationDialect::initialize() { +namespace { + +// Verify the integrity of per-axis quantization information, if present. +// +// - quantizedType +// Any quantized type. Any quantized type with no per-axis quantization is +// ignored. +// +// - containerType +// Original input or result type of the operation using the provided quantized +// type. Used to ensure that the quantized type appears within a tensor and +// that the tensor is compatible with per-axis quantization information. +// +LogicalResult verifyPerAxisQuantization(Operation *op, + QuantizedType quantizedType, + Type containerType) { + auto quantizedPerAxisType = + dyn_cast(quantizedType); + if (!quantizedPerAxisType) + return success(); + + auto tensorType = dyn_cast(containerType); + if (!tensorType) + return op->emitError("scalar types may not use per-axis quantization"); + + if (!tensorType.hasRank()) + return success(); + + int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension(); + if (quantizedDimension >= tensorType.getRank()) + return op->emitError("quantized dimension must be less than tensor rank"); + + int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension); + if (quantizedDimensionSize != ShapedType::kDynamic && + quantizedDimensionSize != + (int64_t)quantizedPerAxisType.getScales().size()) + return op->emitError( + "quantized dimension size does not match number of scales"); + + return success(); +} + +// Common verification logic for 'quant.dcast' and 'quant.qcast' ops. +// +// - quantizedType +// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'), +// whether as a primitive type or in a tensor. +// +// - floatType +// Float type used in the input ('quant.qcast') or result ('quant.dcast'), +// whether as a primitive type or in a tensor. +// +// - containerType +// Type of original input or result. +// +LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, + FloatType floatType, Type containerType) { + if (quantizedType.getExpressedType() != floatType) + return op->emitError( + "expressed type in quantized type expected to match float type"); + + // Veriy integrity of per-axis quantization information, if present. + return verifyPerAxisQuantization(op, quantizedType, containerType); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +void QuantDialect::initialize() { addTypes(); addOperations< #define GET_OP_LIST -#include "mlir/Dialect/Quant/QuantOps.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" >(); - addBytecodeInterface(this); + detail::addBytecodeInterface(this); +} + +//===----------------------------------------------------------------------===// +// DequantizeCastOp +//===----------------------------------------------------------------------===// + +LogicalResult DequantizeCastOp::verify() { + return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(), + getInput().getType()); +} + +OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) { + // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op + // with the value of x. Values x and y are guaranteed to be of the same type + // in this pattern. + auto srcQcastOp = getInput().getDefiningOp(); + if (!srcQcastOp) + return {}; + assert(srcQcastOp.getInput().getType() == getType()); + return srcQcastOp.getInput(); +} + +FloatType DequantizeCastOp::getFloatType() { + return cast(getElementTypeOrSelf(getResult().getType())); +} + +QuantizedType DequantizeCastOp::getQuantizedType() { + return cast(getElementTypeOrSelf(getInput().getType())); +} + +//===----------------------------------------------------------------------===// +// QuantizeCastOp +//===----------------------------------------------------------------------===// + +LogicalResult QuantizeCastOp::verify() { + return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(), + getInput().getType()); +} + +OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) { + // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op + // with the value of x if the casts invert each other. Contrary to the folding + // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values + // x and y are not guaranteed to be of the same type here, as they may use + // different quantization parameters. + auto srcDcastOp = getInput().getDefiningOp(); + if (!srcDcastOp || srcDcastOp.getInput().getType() != getType()) + return {}; + return srcDcastOp.getInput(); +} + +FloatType QuantizeCastOp::getFloatType() { + return cast(getElementTypeOrSelf(getInput().getType())); +} + +QuantizedType QuantizeCastOp::getQuantizedType() { + return cast(getElementTypeOrSelf(getResult().getType())); +} + +//===----------------------------------------------------------------------===// +// StorageCastOp +//===----------------------------------------------------------------------===// + +LogicalResult StorageCastOp::verify() { + auto quantizedType = getQuantizedType(); + auto integerType = getIntegerType(); + if (quantizedType.getStorageType() != integerType) + return emitError( + "storage type in quantized type expected to match integer type"); + + // Verify integrity of per-axis quantization information, if available. While + // the quantization type may appear in the input or the result, their tensor + // shapes are guaranteed to be identical at this point. + return verifyPerAxisQuantization(*this, quantizedType, getInput().getType()); } OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { - // Matches x -> [scast -> scast] -> y, replacing the second scast with the - // value of x if the casts invert each other. - auto srcScastOp = getArg().getDefiningOp(); - if (!srcScastOp || srcScastOp.getArg().getType() != getType()) - return OpFoldResult(); - return srcScastOp.getArg(); + // Matches x -> quant.scast -> quant.scast -> y, replacing the second + // quant.scast with the value of x if the casts invert each other. + auto srcScastOp = getInput().getDefiningOp(); + if (!srcScastOp || srcScastOp.getInput().getType() != getType()) + return {}; + return srcScastOp.getInput(); } +IntegerType StorageCastOp::getIntegerType() { + auto inputScalarType = getElementTypeOrSelf(getInput().getType()); + if (auto integerType = dyn_cast(inputScalarType)) + return integerType; + + auto resultScalarType = getElementTypeOrSelf(getResult().getType()); + return cast(resultScalarType); +} + +QuantizedType StorageCastOp::getQuantizedType() { + auto inputScalarType = getElementTypeOrSelf(getInput().getType()); + if (auto quantizedType = dyn_cast(inputScalarType)) + return quantizedType; + + auto resultScalarType = getElementTypeOrSelf(getResult().getType()); + return cast(resultScalarType); +} + +} // namespace quant +} // namespace mlir + #define GET_OP_CLASSES -#include "mlir/Dialect/Quant/QuantOps.cpp.inc" +#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 5a3500ec4278..e611b24b56c2 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -6,9 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "TypeDetail.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -20,12 +21,28 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; +namespace { + +// Return the minimum scale representable in a given float type +double getMinScale(Type expressedType) { + auto floatType = cast(expressedType); + return APFloat::getSmallest(floatType.getFloatSemantics()).convertToDouble(); +} + +// Return the maximum scale representable in a given float type +double getMaxScale(Type expressedType) { + auto floatType = cast(expressedType); + return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); +} + +} // namespace + unsigned QuantizedType::getFlags() const { return static_cast(impl)->flags; } bool QuantizedType::classof(Type type) { - return llvm::isa(type.getDialect()); + return llvm::isa(type.getDialect()); } LogicalResult @@ -36,7 +53,6 @@ QuantizedType::verify(function_ref emitError, bool isSigned = (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; - // Integral storage type width checks if (storageType.isa()) { unsigned integralWidth = llvm::dyn_cast(storageType).getWidth(); @@ -45,7 +61,8 @@ QuantizedType::verify(function_ref emitError, return emitError() << "illegal storage type size: " << integralWidth; } - int64_t defaultMin, defaultMax; + int64_t defaultMin = std::numeric_limits::min(); + int64_t defaultMax = std::numeric_limits::max(); if (storageType.isa()) { const auto width = llvm::dyn_cast(storageType).getWidth(); defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width); @@ -61,7 +78,6 @@ QuantizedType::verify(function_ref emitError, "types, Float8E4M3FNType and Float8E5M2Type "; } - // Verify storageTypeMin and storageTypeMax. if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultMin || storageTypeMax > defaultMax) { return emitError() << "illegal storage min and storage max: (" @@ -82,6 +98,17 @@ int64_t QuantizedType::getStorageTypeMax() const { return static_cast(impl)->storageTypeMax; } +bool QuantizedType::hasStorageTypeBounds() const { + unsigned int integralWidth = getStorageTypeIntegralWidth(); + bool isSignedInteger = isSigned(); + int64_t defaultIntegerMin = + getDefaultMinimumForInteger(isSignedInteger, integralWidth); + int64_t defaultIntegerMax = + getDefaultMaximumForInteger(isSignedInteger, integralWidth); + return defaultIntegerMin != getStorageTypeMin() || + defaultIntegerMax != getStorageTypeMax(); +} + unsigned QuantizedType::getStorageTypeIntegralWidth() const { // NOTE: If ever supporting non-integral storage types, some other scheme // for determining the width will be needed. @@ -268,7 +295,6 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType, return Base::get(storageType.getContext(), flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax); } - UniformQuantizedType UniformQuantizedType::getChecked( function_ref emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, @@ -299,17 +325,17 @@ LogicalResult UniformQuantizedType::verify( return emitError() << "expressed type must be floating point"; // Verify scale. + double minScale = getMinScale(expressedType); + double maxScale = getMaxScale(expressedType); if (std::isinf(scale) || std::isnan(scale)) return emitError() << "illegal scale: " << scale; + if (scale > maxScale) + return emitError() << "scale out of expressed type range [" << minScale + << ", " << maxScale << "]"; return success(); } -bool UniformQuantizedType::classof(mlir::Type type) { - return type.getTypeID() == mlir::TypeID::get() || - type.getTypeID() == mlir::TypeID::get(); -} - double UniformQuantizedType::getScale() const { return getImpl()->scale; } int64_t UniformQuantizedType::getZeroPoint() const { @@ -363,17 +389,22 @@ LogicalResult UniformQuantizedPerAxisType::verify( << scales.size() << ", " << zeroPoints.size(); // Verify scale. + double minScale = getMinScale(expressedType); + double maxScale = getMaxScale(expressedType); + for (double scale : scales) { if (std::isinf(scale) || std::isnan(scale)) return emitError() << "illegal scale: " << scale; + if (scale > maxScale) + return emitError() << "scale out of expressed type range [" << minScale + << ", " << maxScale << "]"; } - return success(); -} + // Verify quantized dimension. + if (quantizedDimension < 0) + return emitError() << "illegal quantized dimension: " << quantizedDimension; -bool UniformQuantizedPerAxisType::classof(mlir::Type type) { - return type.getTypeID() == mlir::TypeID::get() || - type.getTypeID() == mlir::TypeID::get(); + return success(); } ArrayRef UniformQuantizedPerAxisType::getScales() const { diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 1fd148dd4736..4d503c6feedf 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" @@ -72,75 +72,22 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) { return type; } -static Type parseQuantileType(DialectAsmParser &parser) { - auto typeLoc = parser.getCurrentLocation(); - Type type; - - // Parse storage type (alpha_ident, integer_literal). - StringRef identifier; - unsigned storageTypeWidth = 0; - OptionalParseResult result = parser.parseOptionalType(type); - if (result.has_value()) { - if (!succeeded(*result)) - return nullptr; - - if (!type.isa() && !type.isa()) { - parser.emitError(typeLoc, "illegal quantile type alias"); - return nullptr; - } - } else if (succeeded(parser.parseKeyword(&identifier))) { - // Otherwise, this must be an unsigned integer (`u` integer-literal) - if (identifier.consume_front("u")) { - if (identifier.getAsInteger(10, storageTypeWidth)) { - parser.emitError(typeLoc, "expected quantile type width"); - return nullptr; - } - constexpr bool isSigned = false; - type = parser.getBuilder().getIntegerType(storageTypeWidth, isSigned); - - } else { - parser.emitError(typeLoc, "illegal quantile type alias"); - return nullptr; - } - } else { - return nullptr; - } - - return type; -} - -static ParseResult -checkStorageRange(DialectAsmParser &parser, int64_t storageTypeMin, - int64_t storageTypeMax, int64_t defaultStorageTypeMin, - int64_t defaultStorageTypeMax, SMLoc minLoc, SMLoc maxLoc) { - if (storageTypeMin < defaultStorageTypeMin) { - return parser.emitError(minLoc, "illegal storage type minimum: ") - << storageTypeMin; - } - if (storageTypeMax > defaultStorageTypeMax) { - return parser.emitError(maxLoc, "illegal storage type maximum: ") - << storageTypeMax; - } - return success(); -} - static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, bool isSigned, int64_t &storageTypeMin, int64_t &storageTypeMax) { - int64_t defaultMin, defaultMax; - if (storageType.isa()) { - const auto width = llvm::dyn_cast(storageType).getWidth(); - defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width); - defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width); - } else if (storageType.isa()) { + int64_t defaultMin = std::numeric_limits::min(); + int64_t defaultMax = std::numeric_limits::max(); + if (auto integerStorageType = dyn_cast(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForInteger( + isSigned, integerStorageType.getWidth()); + defaultMax = QuantizedType::getDefaultMaximumForInteger( + isSigned, integerStorageType.getWidth()); + } else if (llvm::isa(storageType)) { defaultMin = QuantizedType::getDefaultMinimumForF8E5M2(); defaultMax = QuantizedType::getDefaultMaximumForF8E5M2(); - } else if (storageType.isa()) { + } else if (llvm::isa(storageType)) { defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN(); defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN(); - } else { - defaultMin = std::numeric_limits::max(); - defaultMax = std::numeric_limits::min(); } if (failed(parser.parseOptionalLess())) { @@ -150,15 +97,23 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, } // Explicit storage min and storage max. - // F8 min and max values are integers, so parseInteger() is used. SMLoc minLoc = parser.getCurrentLocation(), maxLoc; if (parser.parseInteger(storageTypeMin) || parser.parseColon() || parser.getCurrentLocation(&maxLoc) || parser.parseInteger(storageTypeMax) || parser.parseGreater()) return failure(); - return checkStorageRange(parser, storageTypeMin, storageTypeMax, defaultMin, - defaultMax, minLoc, maxLoc); + if (storageTypeMin < defaultMin) { + return parser.emitError(minLoc, "illegal storage type minimum: ") + << storageTypeMin; + } + if (storageTypeMax > defaultMax) { + return parser.emitError(maxLoc, "illegal storage type maximum: ") + << storageTypeMax; + } + llvm::errs() << "storage type min: " << storageTypeMin << '\n'; + llvm::errs() << "storage type max: " << storageTypeMax << '\n'; + return success(); } static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, @@ -229,13 +184,49 @@ static Type parseAnyType(DialectAsmParser &parser) { typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); } +static Type parseQuantileType(DialectAsmParser &parser) { + auto typeLoc = parser.getCurrentLocation(); + Type type; + + // Parse storage type (alpha_ident, integer_literal). + StringRef identifier; + unsigned storageTypeWidth = 0; + OptionalParseResult result = parser.parseOptionalType(type); + if (result.has_value()) { + if (!succeeded(*result)) + return nullptr; + + if (!type.isa() && !type.isa()) { + parser.emitError(typeLoc, "illegal quantile type alias"); + return nullptr; + } + } else if (succeeded(parser.parseKeyword(&identifier))) { + // Otherwise, this must be an unsigned integer (`u` integer-literal) + if (identifier.consume_front("u")) { + if (identifier.getAsInteger(10, storageTypeWidth)) { + parser.emitError(typeLoc, "expected quantile type width"); + return nullptr; + } + constexpr bool isSigned = false; + type = parser.getBuilder().getIntegerType(storageTypeWidth, isSigned); + + } else { + parser.emitError(typeLoc, "illegal quantile type alias"); + return nullptr; + } + } else { + return nullptr; + } + + return type; +} + static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, int64_t &zeroPoint) { // scale[:zeroPoint]? // scale. - if (parser.parseFloat(scale)) { + if (parser.parseFloat(scale)) return failure(); - } // zero point. zeroPoint = 0; @@ -314,7 +305,7 @@ static Type parseUniformType(DialectAsmParser &parser, bool isQuantile) { return nullptr; } - // quantile type. + // Quantile type. if (isQuantile) { if (parser.parseColon()) { return nullptr; @@ -459,7 +450,8 @@ static Type parseCalibratedType(DialectAsmParser &parser) { } /// Parse a type registered to this dialect. -Type QuantizationDialect::parseType(DialectAsmParser &parser) const { + +Type QuantDialect::parseType(DialectAsmParser &parser) const { // All types start with an identifier that we switch on. StringRef typeNameSpelling; @@ -494,24 +486,23 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { out << "u" << storageWidth; } - // storageTypeMin and storageTypeMax if not default. int64_t defaultMin = type.getStorageType().isa() ? QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth) - : type.getStorageType().isa() - ? QuantizedType::getDefaultMinimumForF8E5M2() - : type.getStorageType().isa() - ? QuantizedType::getDefaultMinimumForF8E4M3FN() - : std::numeric_limits::max(); + : type.getStorageType().isa() + ? QuantizedType::getDefaultMinimumForF8E5M2() + : type.getStorageType().isa() + ? QuantizedType::getDefaultMinimumForF8E4M3FN() + : std::numeric_limits::max(); int64_t defaultMax = type.getStorageType().isa() ? QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth) - : type.getStorageType().isa() - ? QuantizedType::getDefaultMaximumForF8E5M2() - : type.getStorageType().isa() - ? QuantizedType::getDefaultMaximumForF8E4M3FN() - : std::numeric_limits::min(); + : type.getStorageType().isa() + ? QuantizedType::getDefaultMaximumForF8E5M2() + : type.getStorageType().isa() + ? QuantizedType::getDefaultMaximumForF8E4M3FN() + : std::numeric_limits::min(); if (defaultMin != type.getStorageTypeMin() || defaultMax != type.getStorageTypeMax()) { @@ -650,7 +641,7 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type, } /// Print a type registered to this dialect. -void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { +void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = llvm::dyn_cast(type)) printAnyQuantizedType(anyType, os); else if (auto uniformType = llvm::dyn_cast(type)) diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..2fd4a41999d4 --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -0,0 +1,26 @@ +add_mlir_dialect_library(MLIRQuantTransforms + LowerQuantOps.cpp + StripFuncQuantTypes.cpp + + ADDITIONAL_HEADER_DIRS + {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms + + DEPENDS + MLIRQuantTransformsIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRFuncDialect + MLIRFuncTransforms + MLIRIndexDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgUtils + MLIRPass + MLIRQuantDialect + MLIRShapeDialect + MLIRTensorDialect + MLIRTransforms + MLIRTransformUtils + + ) diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp new file mode 100644 index 000000000000..4adeb9218ff8 --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -0,0 +1,676 @@ +//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms `quant.dcast` and `quant.qcast` into lower-level ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_LOWERQUANTOPS +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +// If 'inputType' is a tensor, return its element type. If it is a scalar, +// return it as is. +Type getScalarType(Type inputType) { + if (auto tensorType = dyn_cast(inputType)) + return tensorType.getElementType(); + return inputType; +} + +// Return the shape of an input value as a list of attributes (static dimensions) +// and values (dynamic dimensions). If 'input' is a scalar, an empty list is +// returned. If 'input' is a tensor, its shape is returned. +SmallVector +getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) { + if (isa(input.getType())) + return tensor::getMixedSizes(builder, loc, input); + return {}; +} + +// If 'referenceType' is a scalar, return 'elementType' as is. If +// 'referenceType' is a tensor, return another tensor with the same shape and +// elements of type 'elementType'. +Type getScalarOrTensorType(Type elementType, Type referenceType) { + if (auto tensorType = dyn_cast(referenceType)) + return tensorType.clone(elementType); + return elementType; +} + +// Return a constant with the given value. If 'referenceType' is a tensor, a +// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a +// scalar, 'referenceShape' is ignored and a scalar constant is returned. +Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, + Type referenceType, + ArrayRef referenceShape) { + // If the result type is a scalar, return the unmodified scalar constant. + auto tensorType = dyn_cast(referenceType); + if (!tensorType) { + assert(referenceShape.empty()); + return scalar; + } + + // Create tensor splat + auto tensorConstant = + builder.create(loc, scalar, referenceShape); + return tensorConstant; +} + +// Reshape an unranked tensor into a 1D ranked tensor. +// +// - input +// Unranked tensor. +// +// Return values: +// +// - flatInput +// 1D ranked, dynamically shaped tensor. +// +// - inputShape +// 1D extent tensor containing the shape of the original unranked input. +// +std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, + Value input) { + // Get unranked input shape and total size + auto *context = builder.getContext(); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + Value inputSize = builder.create( + loc, builder.getIndexType(), inputShape); + + // Turn input size into 1D tensor + auto flatShapeType = shape::getExtentTensorType(context, 1); + auto flatInputShape = builder.create( + loc, flatShapeType, inputSize); + + // Reshape input tensor into 1D + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto flatInputType = + RankedTensorType::get({ShapedType::kDynamic}, elementType); + auto flatInput = builder.create( + loc, flatInputType, input, flatInputShape); + return std::make_pair(flatInput, inputShape); +} + +// Reshape an unranked tensor into a 3D ranked tensor where the central +// dimension of the result tensor corresponds to dimension 'axis' of the input +// tensor. +// +// - input +// Unranked tensor. +// +// - axis +// Index of the input dimension around which other input dimiensions will be +// collapsed. +// +// - axisSize +// Size of input dimension 'axis'. +// +// Return values: +// +// - flatInput +// 3D ranked tensor of shape [?, axisSize, ?]. +// +// - inputShape +// 1D extent tensor containing the shape of the original unranked input. +// +std::pair flattenUnrankedTensorAroundAxis(OpBuilder &builder, + Location loc, + Value input, + int64_t axis, + int64_t axisSize) { + // Get full tensor shape + auto *context = builder.getContext(); + auto indexType = builder.getIndexType(); + auto shapeType = shape::getExtentTensorType(context); + auto inputShape = builder.create(loc, shapeType, input); + + // Get shape and sizes on left and right of axis + auto axisValue = builder.create(loc, axis); + auto axisNextValue = builder.create(loc, axis + 1); + auto shapeLeft = builder.create( + loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) + .getResult(0); + auto sizeLeft = builder.create( + loc, indexType, shapeLeft); + auto shapeRight = builder.create( + loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) + .getResult(1); + auto sizeRight = builder.create( + loc, indexType, shapeRight); + + // Compute flat input shape as a 3-element 1D tensor + auto axisSizeValue = builder.create(loc, axisSize); + auto flatShapeType = shape::getExtentTensorType(context, 3); + auto flatInputShape = builder.create( + loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight}); + + // Reshape input to 3D tensor + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto flatInputType = RankedTensorType::get( + {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); + auto flatInput = builder.create( + loc, flatInputType, input, flatInputShape); + + return std::make_pair(flatInput, inputShape); +} + +// Reshape an input tensor into its original unranked shape. +// +// - input +// Ranked tensor. +// +// - inputShape +// 1D extent tensor. +// +Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, + Value inputShape) { + auto inputType = cast(input.getType()); + auto elementType = inputType.getElementType(); + auto unrankedType = UnrankedTensorType::get(elementType); + return builder.create(loc, unrankedType, input, inputShape); +} + +// Create a tensor constant containing all scales in a per-channel quantized +// type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> +// +Value materializePerChannelScales(OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { + auto scales = quantizedType.getScales(); + auto expressedType = quantizedType.getExpressedType(); + auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute { + return builder.getFloatAttr(expressedType, scale); + }); + auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType); + auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); + return builder.create(loc, tensorType, scalesAttr); +} + +// Create a tensor constant containing all zero points in a per-channel +// quantized type. Example: +// +// !quant.uniform +// +// produces +// +// %cst = arith.constant dense<[10, 20]> : tensor<2xi8> +// +Value materializePerChannelZeroPoints( + OpBuilder &builder, Location loc, + UniformQuantizedPerAxisType quantizedType) { + auto zeroPoints = quantizedType.getZeroPoints(); + auto storageType = quantizedType.getStorageType(); + auto zeroPointAttrs = llvm::map_to_vector( + zeroPoints, + [&](int64_t zeroPoint) -> Attribute { + return builder.getIntegerAttr(storageType, zeroPoint); + }); + auto tensorType = + RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); + auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); + return builder.create(loc, tensorType, zeroPointsAttr); +} + +// Clamp the given scalar or tensor input using the storage bounds encoded in +// the given quantized type, if present. +// +// - input +// Scalar or ranked tensor input. The element type must match the storage type +// of 'quantizedType'. +// +// - inputShape +// If 'input' is a tensor, combination of attributes/values representing its +// static/dynamic dimensions. If 'input' is a scalar, empty list. +// +// - quantizedType +// Per-axis or per-channel quantized type. +Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, + QuantizedType quantizedType) { + // If quantized type does not narrow down the storage type range, there is + // nothing to do. + if (!quantizedType.hasStorageTypeBounds()) + return input; + + // Materialize bounds + auto inputType = input.getType(); + auto storageType = quantizedType.getStorageType(); + auto storageMinScalar = builder.create( + loc, quantizedType.getStorageTypeMin(), storageType); + auto storageMaxScalar = builder.create( + loc, quantizedType.getStorageTypeMax(), storageType); + auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, + inputType, inputShape); + auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar, + inputType, inputShape); + + // Clamp + if (quantizedType.isSigned()) { + input = builder.create(loc, input, storageMin); + input = builder.create(loc, input, storageMax); + } else { + input = builder.create(loc, input, storageMin); + input = builder.create(loc, input, storageMax); + } + return input; +} + +// Emit op 'arith.fptosi' or 'arith.fptoui'. +Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, + Type resultType, bool isSigned) { + if (isSigned) + return builder.create(loc, resultType, input); + return builder.create(loc, resultType, input); +} + +// Emit op 'arith.sitofp' or 'arith.uitofp'. +Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, + Type resultType, bool isSigned) { + if (isSigned) + return builder.create(loc, resultType, input); + return builder.create(loc, resultType, input); +} + +// Quantize a scalar or ranked tensor value. The stored value is clamped using +// the storage bounds encoded in the given quantized type. +// +// See function 'convertRanked()' below for a description of the arguments. +Value quantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + // Convert scale to tensor if necessary + auto inputType = input.getType(); + scale = getScalarOrTensorConstant( + builder, loc, scale, inputType, inputShape); + + // Scale input + auto scaledValue = builder.create(loc, input, scale); + + // Skip unnecessary computations if no zero point is given + Value storedValueFloat = scaledValue; + if (!matchPattern(zeroPoint, m_Zero())) { + // Convert zero point to tensor if necessary + zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, + inputShape); + + // Convert zero point from storage to expressed type + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + scale.getType(), + quantizedType.isSigned()); + + // Add zero point to stored value + storedValueFloat = + builder.create(loc, scaledValue, zeroPoint); + } + + // Convert stored value to storage type + auto storageScalarOrTensorType = + getScalarOrTensorType(quantizedType.getStorageType(), inputType); + auto storedValueInt = convertFloatToInteger( + builder, loc, storedValueFloat, storageScalarOrTensorType, + quantizedType.isSigned()); + + // Clamp stored value it if the storage type is bound + auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt, + inputShape, quantizedType); + return storedValueClamped; +} + +// Dequantize a scalar or ranked tensor input. +// +// See function 'convertRanked()' below for a description of the arguments. +Value dequantizeValue(OpBuilder &builder, Location loc, Value input, + ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + // Convert scale to tensor if necessary + auto inputType = input.getType(); + scale = getScalarOrTensorConstant( + builder, loc, scale, inputType, inputShape); + + // Convert stored value to float + auto result = convertIntegerToFloat( + builder, loc, input, scale.getType(), quantizedType.isSigned()); + + // Skip unnecessary computations if no zero point is given + if (!matchPattern(zeroPoint, m_Zero())) { + // Convert zero point to tensor if necessary + zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType, + inputShape); + + // Convert zero point from storage to expressed type + zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, + scale.getType(), + quantizedType.isSigned()); + + // Subtract zero point to stored value + result = builder.create(loc, result, zeroPoint); + } + + // Multiply by scale + result = builder.create(loc, result, scale); + return result; +} + +// Convert a scalar or ranked tensor input with the given scale and zero point +// values. +// +// - input +// Scalar or ranked tensor value. +// +// - inputShape +// If 'input' is a tensor, combination or attributes/values representing its +// static/dynamic dimensions. If 'input' is a scalar, empty list. +// +// - scale +// Scale as a floating-point scalar value. +// +// - zeroPoint +// Zero point as an integer scalar value. +// +// - quantizedType +// Scalar quantized type of the result ('quant.qcast') or of the input +// ('quant.dcast'). +// +Value convertRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, ArrayRef inputShape, Value scale, + Value zeroPoint, QuantizedType quantizedType) { + if (isa(op)) + return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + if (isa(op)) + return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint, + quantizedType); + llvm_unreachable("unexpected quant op"); +} + +// Convert an operation using per-layer quantization with a scalar or ranked +// tensor input. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar or ranked tensor. +// +// - quantizedType +// Per-layer quantized type. +// +Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, UniformQuantizedType quantizedType) { + // Create scale and zero point constants + auto expressedType = quantizedType.getExpressedType(); + auto storageType = quantizedType.getStorageType(); + auto scaleAttr = + builder.getFloatAttr(expressedType, quantizedType.getScale()); + auto scale = builder.create(loc, expressedType, scaleAttr); + auto zeroPointAttr = + builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); + auto zeroPoint = + builder.create(loc, storageType, zeroPointAttr); + + auto inputShape = getScalarOrTensorShape(builder, loc, input); + return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, + quantizedType); +} + +// Convert an operation using per-layer quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. +// +// - quantizedType +// Per-layer quantized type. +// +Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op, + Value input, UniformQuantizedType quantizedType) { + // Flatten input if unranked + bool isUnranked = isa(input.getType()); + Value inputShape; + if (isUnranked) + std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input); + + // Process ranked tensor + auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType); + + // Restore original shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; +} + +// Convert an operation using per-channel quantization and a scalar or ranked +// tensor as an input. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar or ranked tensor. +// +// - quantizedType +// Per-channel quantized type. +// +Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedPerAxisType quantizedType, + int64_t channelAxis) { + auto *context = builder.getContext(); + + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + auto scales = materializePerChannelScales(builder, loc, quantizedType); + auto zeroPoints = + materializePerChannelZeroPoints(builder, loc, quantizedType); + + auto elementType = isa(inputType.getElementType()) + ? quantizedType.getStorageType() + : quantizedType.getExpressedType(); + auto initShape = tensor::getMixedSizes(builder, loc, input); + Value init = builder.create(loc, initShape, elementType); + + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + auto channelAxisAffineMap = AffineMap::get( + inputRank, 0, builder.getAffineDimExpr(channelAxis), context); + SmallVector indexingMaps{ + builder.getMultiDimIdentityMap(inputRank), + channelAxisAffineMap, + channelAxisAffineMap, + builder.getMultiDimIdentityMap(inputRank) + }; + auto result = builder.create( + loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, + iteratorTypes, + [&](OpBuilder& builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + builder.create(loc, result); + }) + .getResult(0); + + return result; +} + +// Convert an operation using per-channel quantization. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. +// +// - quantizedType +// Per-channel quantized type. +// +Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op, + Value input, + UniformQuantizedPerAxisType quantizedType) { + // Flatten unranked tensor into a 3D ranked tensor if necessary + bool isUnranked = isa(input.getType()); + int64_t channelAxis = quantizedType.getQuantizedDimension(); + int64_t channelAxisSize = (int64_t) quantizedType.getScales().size(); + Value inputShape; + if (isUnranked) { + std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis( + builder, loc, input, channelAxis, channelAxisSize); + channelAxis = 1; + } + + // Work on a ranked tensor + auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType, + channelAxis); + + // Restore original tensor shape if unranked + if (isUnranked) + result = restoreUnrankedTensorShape(builder, loc, result, inputShape); + + return result; +} + +// Convert a quantization operation. +// +// - op +// 'quant.dcast' or 'quant.qcast' op. +// +// - input +// Scalar, ranked tensor, or unranked tensor. The element type matches +// the storage type (quant.dcast) or expressed type (quant.qcast) of +// 'quantizedType'. +// +// - quantizedType +// Per-layer or per-channel quantized type. +// +Value convertQuantized(OpBuilder &builder, Location loc, Operation *op, + Value input, Type quantizedType) { + if (auto uniformQuantizedType = dyn_cast(quantizedType)) + return convertPerLayer(builder, loc, op, input, uniformQuantizedType); + + if (auto uniformQuantizedPerAxisType = + dyn_cast(quantizedType)) + return convertPerChannel(builder, loc, op, input, + uniformQuantizedPerAxisType); + + llvm_unreachable("unexpected quantized type"); +} + +// Lowering pattern for 'quant.dcast' +struct DequantizeCastOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getInput(); + auto quantizedType = + cast(getScalarType(op.getInput().getType())); + + // Convert quantized input to storage type + auto storageScalarOrTensorType = + getScalarOrTensorType(quantizedType.getStorageType(), input.getType()); + input = rewriter.create( + loc, storageScalarOrTensorType, input); + + auto result = convertQuantized(rewriter, loc, op, input, quantizedType); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +// Lowering pattern for 'quant.qcast' +struct QuantizeCastOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto input = op.getInput(); + auto quantizedType = getScalarType(op.getResult().getType()); + + // Flatten unranked tensor input + auto result = convertQuantized(rewriter, loc, op, input, quantizedType); + + // Cast stored value to result quantized value + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), result); + return success(); + } +}; + +struct LowerQuantOps : public impl::LowerQuantOpsBase { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateLowerQuantOpsPatterns(patterns); + + ConversionTarget target(getContext()); + target.addLegalOp(); + target.addIllegalDialect(); + target.addLegalDialect< + arith::ArithDialect, + linalg::LinalgDialect, + shape::ShapeDialect, + tensor::TensorDialect + >(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) { + patterns.add< + DequantizeCastOpConversion, + QuantizeCastOpConversion + >(patterns.getContext()); +} + +} // namespace quant +} // namespace mlir diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp new file mode 100644 index 000000000000..8996eff61a39 --- /dev/null +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -0,0 +1,114 @@ +//===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Strips quantized types from function headers. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace quant { + +#define GEN_PASS_DEF_STRIPFUNCQUANTTYPES +#include "mlir/Dialect/Quant/Transforms/Passes.h.inc" + +namespace { + +class QuantizedTypeConverter : public TypeConverter { + + static Type convertQuantizedType(QuantizedType quantizedType) { + return quantizedType.getStorageType(); + } + + static Type convertTensorType(TensorType tensorType) { + if (auto quantizedType = dyn_cast(tensorType.getElementType())) + return tensorType.clone(convertQuantizedType(quantizedType)); + return tensorType; + } + + static Value materializeConversion(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + return builder.create(loc, type, inputs[0]); + } + +public: + + explicit QuantizedTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertQuantizedType); + addConversion(convertTensorType); + + addArgumentMaterialization(materializeConversion); + addSourceMaterialization(materializeConversion); + addTargetMaterialization(materializeConversion); + } +}; + +// Conversion pass +class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase { + + // Return whether a type is considered legal when occurring in the header of + // a function or as an operand to a 'return' op. + static bool isLegalType(Type type) { + if (auto tensorType = dyn_cast(type)) + return isLegalType(tensorType.getElementType()); + return !isa(type); + } + +public: + + void runOnOperation() override { + + auto moduleOp = cast(getOperation()); + auto* context = &getContext(); + + QuantizedTypeConverter typeConverter; + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + // Mark func.func, func.return, and func.call illegal if they contain any + // quantized types. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + // Register conversion patterns + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + + // Apply conversion + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +} // namespace quant +} // namespace mlir + diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp index 8c6972982469..308ff35e0145 100644 --- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" using namespace mlir; using namespace mlir::quant; diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp index 408701f80444..62c7a7128d63 100644 --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/Dialect/Quant/Utils/UniformSupport.h" #include "mlir/IR/BuiltinTypes.h" #include diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 4c50aaecfe94..b1749164cc60 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" @@ -492,8 +492,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) return getInput1(); @@ -517,8 +519,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (lhsAttr && lhsAttr.isSplat()) { if (llvm::isa(resultETy) && lhsAttr.getSplatValue().isZero()) @@ -586,8 +590,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); const int64_t shift = llvm::isa(resultETy) ? getShift() : 0; if (rhsTy == resultTy) { @@ -614,8 +620,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) { return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) return getInput1(); @@ -657,8 +665,10 @@ struct APIntFoldGreaterEqual { OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (!lhsAttr || !rhsAttr) return {}; @@ -669,8 +679,10 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); if (!lhsAttr || !rhsAttr) return {}; @@ -682,8 +694,10 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast(getType()); - auto lhsAttr = llvm::dyn_cast_if_present(adaptor.getInput1()); - auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getInput2()); + auto lhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput1()); + auto rhsAttr = + llvm::dyn_cast_if_present(adaptor.getInput2()); Value lhs = getInput1(); Value rhs = getInput2(); auto lhsTy = llvm::cast(lhs.getType()); @@ -806,14 +820,16 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { } // reshape(const(x)) -> const(reshape-attr(x)) - if (auto operand = llvm::dyn_cast_if_present(adaptor.getInput1())) { + if (auto operand = + llvm::dyn_cast_if_present(adaptor.getInput1())) { // Constants must have static shape. if (!outputTy.hasStaticShape()) return {}; // Okay to duplicate splat constants. if (operand.isSplat()) - return SplatElementsAttr::get(outputTy, operand.getSplatValue()); + return SplatElementsAttr::get(outputTy, + operand.getSplatValue()); // Don't duplicate other constants. if (!getInput1().hasOneUse()) @@ -873,7 +889,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); auto operandTy = llvm::cast(operand.getType()); auto axis = getAxis(); - auto operandAttr = llvm::dyn_cast_if_present(adaptor.getInput()); + auto operandAttr = + llvm::dyn_cast_if_present(adaptor.getInput()); if (operandAttr) return operandAttr; @@ -922,7 +939,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { if (getOnTrue() == getOnFalse()) return getOnTrue(); - auto predicate = llvm::dyn_cast_if_present(adaptor.getPred()); + auto predicate = + llvm::dyn_cast_if_present(adaptor.getPred()); if (!predicate) return {}; @@ -944,7 +962,8 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::cast(getType()); // Transposing splat values just means reshaping. - if (auto input = llvm::dyn_cast_if_present(adaptor.getInput1())) { + if (auto input = + llvm::dyn_cast_if_present(adaptor.getInput1())) { if (input.isSplat() && resultTy.hasStaticShape() && inputTy.getElementType() == resultTy.getElementType()) return input.reshape(resultTy); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 729116da45e4..26ceb59e0baf 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" @@ -211,7 +211,8 @@ static bool hasZeroDimension(ShapedType shapedType) { return false; } -template static LogicalResult verifyConvOp(T op) { +template +static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = llvm::dyn_cast(op.getInput().getType()); auto weightType = llvm::dyn_cast(op.getWeight().getType()); diff --git a/mlir/test/Dialect/Quant/canonicalize.mlir b/mlir/test/Dialect/Quant/canonicalize.mlir index 36c3eaf5e10d..73c57e2a4821 100644 --- a/mlir/test/Dialect/Quant/canonicalize.mlir +++ b/mlir/test/Dialect/Quant/canonicalize.mlir @@ -1,24 +1,124 @@ // RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s +// CHECK-LABEL: @dcast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @dcast_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias> + %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32> + return %1 : tensor<4xf32> +} + // ----- -// CHECK-LABEL: redundant_scast -func.func @redundant_scast() -> tensor<4xi8> { - // CHECK-NEXT: arith.constant dense<10> : tensor<4xi8> - // CHECK-NEXT: return - %cst = arith.constant dense<5> : tensor<4xi8> - %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - %2 = "quant.scast"(%1) : (tensor<4x!quant.uniform>) -> tensor<4xi8> - %3 = arith.addi %2, %2 : tensor<4xi8> - return %3 : tensor<4xi8> + +// CHECK-LABEL: @dcast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.dcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +func.func @dcast_no_fold_source(%arg0: tensor<4xi8>) -> tensor<4xf32> { + %0 = quant.scast %arg0 : tensor<4xi8> to tensor<4x!qalias> + %1 = quant.dcast %0 : tensor<4x!qalias> to tensor<4xf32> + return %1 : tensor<4xf32> } // ----- -// CHECK-LABEL: non_redundant_scast -func.func @non_redundant_scast() -> tensor<4x!quant.uniform> { - // CHECK-NEXT: arith.constant dense<5> : tensor<4xi8> - // CHECK-NEXT: scast - // CHECK-NEXT: return - %cst = arith.constant dense<5> : tensor<4xi8> - %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - return %1 : tensor<4x!quant.uniform> + +// CHECK-LABEL: @qcast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @qcast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> { + %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> } + +// ----- + +// CHECK-LABEL: @qcast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = arith.negf %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +func.func @qcast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4x!qalias> { + %0 = arith.negf %arg0 : tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_no_fold_type +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.dcast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.qcast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform +func.func @qcast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> { + %0 = quant.dcast %arg0 : tensor<4x!qalias> to tensor<4xf32> + %1 = quant.qcast %0 : tensor<4xf32> to tensor<4x!qalias1> + return %1 : tensor<4x!qalias1> +} + +// ----- + +// CHECK-LABEL: @scast_fold +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: return %[[ARG_0]] + +!qalias = !quant.uniform +func.func @scast_fold(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias> { + %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8> + %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias> + return %1 : tensor<4x!qalias> +} + +// ----- + +// CHECK-LABEL: @scast_no_fold_source +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[QCAST:.*]] = quant.qcast %[[ARG_0]] +// CHECK: %[[SCAST:.*]] = quant.scast %[[QCAST]] +// CHECK: return %[[SCAST]] + +!qalias = !quant.uniform +func.func @scast_no_fold_source(%arg0: tensor<4xf32>) -> tensor<4xi8> { + %0 = quant.qcast %arg0 : tensor<4xf32> to tensor<4x!qalias> + %1 = quant.scast %0 : tensor<4x!qalias> to tensor<4xi8> + return %1 : tensor<4xi8> +} + +// ----- + +// CHECK-LABEL: @scast_no_fold_type +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[VAL_0:.*]] = quant.scast %[[ARG_0]] +// CHECK: %[[VAL_1:.*]] = quant.scast %[[VAL_0]] +// CHECK: return %[[VAL_1]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform +func.func @scast_no_fold_type(%arg0: tensor<4x!qalias>) -> tensor<4x!qalias1> { + %0 = quant.scast %arg0 : tensor<4x!qalias> to tensor<4xi8> + %1 = quant.scast %0 : tensor<4xi8> to tensor<4x!qalias1> + return %1 : tensor<4x!qalias1> +} + diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir new file mode 100644 index 000000000000..ba3a8e312d96 --- /dev/null +++ b/mlir/test/Dialect/Quant/invalid.mlir @@ -0,0 +1,258 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +func.func @dcast_invalid_input(%arg0: f32) { + // expected-error@+1 {{operand #0 must be scalar or tensor of quantized type}} + %0 = quant.dcast %arg0 : f32 to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{result #0 must be scalar or tensor of floating-point}} + %0 = quant.dcast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_scalar_tensor(%arg0: !qalias) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : !qalias to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : tensor to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3x!qalias>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_float_type_mismatch(%arg0: !qalias) { + // expected-error@+1 {{expressed type in quantized type expected to match float type}} + %0 = quant.dcast %arg0 : !qalias to f64 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_scalar(%arg0: !qalias) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.dcast %arg0 : !qalias to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x!qalias>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.dcast %arg0 : tensor<2x3x!qalias> to tensor<2x3xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_invalid_rank(%arg0: tensor<2x3x4x!qalias>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.dcast %arg0 : tensor<2x3x4x!qalias> to tensor<2x3x4xf32> + return +} + +// ----- + +func.func @qcast_invalid_input(%arg0: f32) { + // expected-error@+1 {{result #0 must be scalar or tensor of quantized type}} + %0 = quant.qcast %arg0 : f32 to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{operand #0 must be scalar or tensor of floating-point}} + %0 = quant.qcast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_scalar_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xf32>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_float_type_mismatch(%arg0: f64) { + // expected-error@+1 {{expressed type in quantized type expected to match float type}} + %0 = quant.qcast %arg0 : f64 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_scalar(%arg0: f32) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3xf32>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.qcast %arg0 : tensor<2x3xf32> to tensor<2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_invalid_rank(%arg0: tensor<2x3x4xf32>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.qcast %arg0 : tensor<2x3x4xf32> to tensor<2x3x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_invalid_input(%arg0: si32) { + // expected-error@+1 {{operand #0 must be scalar or tensor of signless integer or quantized type}} + %0 = quant.scast %arg0 : si32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_invalid_result(%arg0: !qalias) { + // expected-error@+1 {{result #0 must be scalar or tensor of signless integer or quantized type}} + %0 = quant.scast %arg0 : !qalias to si32 + return +} + +// ----- + +func.func @scast_both_integers(%arg0: i8) { + // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}} + %0 = quant.scast %arg0 : i8 to i8 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_both_quantized(%arg0: !qalias) { + // expected-error@+1 {{input must be integer and result must be quantized, or vice versa}} + %0 = quant.scast %arg0 : !qalias to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_scalar_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_ranked_unranked_tensor(%arg0: tensor) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_mismatch_static_dynamic_tensor(%arg0: tensor<2x3xi8>) { + // expected-error@+1 {{input and result are both scalars or both tensors with matching shape}} + %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_integer_type_mismatch(%arg0: i32) { + // expected-error@+1 {{storage type in quantized type expected to match integer type}} + %0 = quant.scast %arg0 : i32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_scalar(%arg0: i8) { + // expected-error@+1 {{scalar types may not use per-axis quantization}} + %0 = quant.scast %arg0 : i8 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3xi8>) { + // expected-error@+1 {{quantized dimension must be less than tensor rank}} + %0 = quant.scast %arg0 : tensor<2x3xi8> to tensor<2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) { + // expected-error@+1 {{quantized dimension size does not match number of scales}} + %0 = quant.scast %arg0 : tensor<2x3x4xi8> to tensor<2x3x4x!qalias> + return +} + diff --git a/mlir/test/Dialect/Quant/lower-quant-ops.mlir b/mlir/test/Dialect/Quant/lower-quant-ops.mlir new file mode 100644 index 000000000000..6bba9f5c0377 --- /dev/null +++ b/mlir/test/Dialect/Quant/lower-quant-ops.mlir @@ -0,0 +1,511 @@ +// RUN: mlir-opt %s --lower-quant-ops --split-input-file | FileCheck %s + +// CHECK-LABEL: @dcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_scalar_unsigned +// CHECK-SAME: %[[ARG_0:.*]]: !quant.uniform + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : !quant.uniform to i8 + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[STORED_FLOAT:.*]] = arith.uitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.uitofp %[[ZERO_POINT]] : i8 to f32 + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: return %[[EXPRESSED]] : f32 + +!qalias = !quant.uniform +func.func @dcast_per_layer_scalar_unsigned(%arg0: !qalias) -> f32 { + %0 = quant.dcast %arg0 : !qalias to f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor> to tensor + +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor to tensor +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor +// CHECK: return %[[EXPRESSED]] : tensor + +!qalias = !quant.uniform +func.func @dcast_per_layer_0d(%arg0: tensor) -> tensor { + %0 = quant.dcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<3x?x5x!quant.uniform> to tensor<3x?x5xi8> +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[C_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_INT]], %[[C_1]] : tensor<3x?x5xi8> +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32> +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8> +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor<3x?x5xf32> +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32> +// CHECK: return %[[EXPRESSED]] : tensor<3x?x5xf32> + +!qalias = !quant.uniform +func.func @dcast_per_layer_ranked(%arg0: tensor<3x?x5x!qalias>) -> tensor<3x?x5xf32> { + %0 = quant.dcast %arg0 : tensor<3x?x5x!qalias> to tensor<3x?x5xf32> + return %0 : tensor<3x?x5xf32> +} + +// ----- + +// CHECK-LABEL: @dcast_per_layer_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_INT:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform> to tensor<*xi8> +// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[STORED_INT]] : tensor<*xi8> -> tensor +// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor -> index +// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex> +// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_INT]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<1xindex>) -> tensor +// CHECK: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_COLLAPSED]] : tensor to tensor +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor + +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE_TENSOR]] : tensor + +// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[EXPRESSED]](%[[INPUT_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32> + +!qalias = !quant.uniform +func.func @dcast_per_layer_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> + +// CHECK-LABEL: @dcast_per_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<4x?x?x5x!quant.uniform> to tensor<4x?x?x5xi8> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8> +// CHECK: %[[C_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_1]] : tensor<4x?x?x5xi8> +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_TENSOR]], %[[C_2]] : tensor<4x?x?x5xi8> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xf32> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[STORED_TENSOR]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xi8>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xf32>) { +// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32): +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: linalg.yield %[[EXPRESSED]] : f32 +// CHECK: } -> tensor<4x?x?x5xf32> +// CHECK: return %[[GENERIC]] : tensor<4x?x?x5xf32> + +!qalias = !quant.uniform +func.func @dcast_per_channel_ranked(%arg0: tensor<4x?x?x5x!qalias>) -> tensor<4x?x?x5xf32> { + %0 = quant.dcast %arg0 : tensor<4x?x?x5x!qalias> to tensor<4x?x?x5xf32> + return %0 : tensor<4x?x?x5xf32> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @dcast_per_channel_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK: %[[STORED_TENSOR:.*]] = quant.scast %[[ARG_0]] : tensor<*x!quant.uniform> to tensor<*xi8> +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[STORED_TENSOR]] : tensor<*xi8> -> tensor +// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index +// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index +// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor -> index +// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor -> index + +// CHECK: %[[NUM_CHANNELS:.*]] = arith.constant 3 : index +// CHECK: %[[COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[NUM_CHANNELS]], %[[SIZE_RIGHT]] : tensor<3xindex> +// CHECK: %[[STORED_COLLAPSED:.*]] = tensor.reshape %[[STORED_TENSOR]](%[[COLLAPSED_SHAPE]]) : (tensor<*xi8>, tensor<3xindex>) -> tensor + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8> +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_0]] : tensor +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[STORED_COLLAPSED]], %[[C_2]] : tensor +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[STORED_COLLAPSED]], %[[SCALES]], %[[ZERO_POINTS]] : tensor, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[STORED_INT:.*]]: i8, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: f32): +// CHECK: %[[STORED_FLOAT:.*]] = arith.sitofp %[[STORED_INT]] : i8 to f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[SCALED:.*]] = arith.subf %[[STORED_FLOAT]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[EXPRESSED:.*]] = arith.mulf %[[SCALED]], %[[SCALE]] : f32 +// CHECK: linalg.yield %[[EXPRESSED]] : f32 +// CHECK: } -> tensor + +// CHECK: %[[EXPRESSED_EXPANDED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[EXPRESSED_EXPANDED]] : tensor<*xf32> + +!qalias = !quant.uniform +func.func @dcast_per_channel_unranked(%arg0: tensor<*x!qalias>) -> tensor<*xf32> { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : f32 to i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : i8 to !quant.uniform +// CHECK: return %[[STORED_QUANT]] : !quant.uniform + +!qalias = !quant.uniform +func.func @qcast_per_layer_scalar(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar_bounds +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[SCALED]] : f32 to i8 + +// CHECK-DAG: %[[C_NEG_5:.*]] = arith.constant -5 : i8 +// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_5]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform:f32, 2.000000e+00> +// CHECK: return %[[STORED_QUANT]] : !quant.uniform:f32, 2.000000e+00> + +!qalias = !quant.uniform:f32, 2.0> +func.func @qcast_per_layer_scalar_bounds(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_scalar_unsigned_bounds +// CHECK-SAME: %[[ARG_0:.*]]: f32 + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 0 : i8 + +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptoui %[[SCALED]] : f32 to i8 + +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : i8 +// CHECK-DAG: %[[C_10:.*]] = arith.constant 10 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxui %[[STORED_INT]], %[[C_2]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minui %[[STORED_CLAMPED_TEMP]], %[[C_10]] : i8 + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : i8 to !quant.uniform:f32, 2.000000e+00> +// CHECK: return %[[STORED_QUANT]] : !quant.uniform:f32, 2.000000e+00> + +!qalias = !quant.uniform:f32, 2.0> +func.func @qcast_per_layer_scalar_unsigned_bounds(%arg0: f32) -> !qalias { + %0 = quant.qcast %arg0 : f32 to !qalias + return %0 : !qalias +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_0d +// CHECK-SAME: %[[ARG_0:.*]]: tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor to tensor +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor to tensor + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor to tensor> +// CHECK: return %[[STORED_QUANT]] : tensor> + +!qalias = !quant.uniform +func.func @qcast_per_layer_0d(%arg0: tensor) -> tensor { + %0 = quant.qcast %arg0 : tensor to tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x?x5xf32> + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index + +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<3x?x5xf32> +// CHECK: %[[SCALE_TENSOR:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xf32> +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_TENSOR]] : tensor<3x?x5xf32> + +// CHECK: %[[ZERO_POINT_TENSOR:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_1]]] : tensor<3x?x5xi8> +// CHECK: %[[ZERO_POINT_TENSOR_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_TENSOR]] : tensor<3x?x5xi8> to tensor<3x?x5xf32> +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_TENSOR_FLOAT]] : tensor<3x?x5xf32> +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor<3x?x5xf32> to tensor<3x?x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_INT]] : tensor<3x?x5xi8> to tensor<3x?x5x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<3x?x5x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_layer_ranked(%arg0: tensor<3x?x5xf32>) -> tensor<3x?x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<3x?x5xf32> to tensor<3x?x5x!qalias> + return %0 : tensor<3x?x5x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x5xf32> + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 + +// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]] : tensor<3x5xf32> +// CHECK: %[[SCALED:.*]] = arith.divf %[[ARG_0]], %[[SCALE_SPLAT]] : tensor<3x5xf32> + +// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]] : tensor<3x5xi8> +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor<3x5xi8> to tensor<3x5xf32> + +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor<3x5xf32> +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : tensor<3x5xf32> to tensor<3x5xi8> + +// CHECK-DAG: %[[C_NEG_8:.*]] = arith.constant -8 : i8 +// CHECK-DAG: %[[C_7:.*]] = arith.constant 7 : i8 +// CHECK-DAG: %[[SPLAT_NEG_8:.*]] = tensor.splat %[[C_NEG_8]] : tensor<3x5xi8> +// CHECK-DAG: %[[SPLAT_7:.*]] = tensor.splat %[[C_7]] : tensor<3x5xi8> +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[SPLAT_NEG_8]] : tensor<3x5xi8> +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[SPLAT_7]] : tensor<3x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_CLAMPED]] : tensor<3x5xi8> to tensor<3x5x!quant.uniform:f32, 2.000000e+00:10>> +// CHECK: return %[[STORED_QUANT]] : tensor<3x5x!quant.uniform:f32, 2.000000e+00:10>> + +!qalias = !quant.uniform:f32, 2.0:10> +func.func @qcast_per_layer_ranked_bounds(%arg0: tensor<3x5xf32>) -> tensor<3x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<3x5xf32> to tensor<3x5x!qalias> + return %0 : tensor<3x5x!qalias> +} + +// ----- + +// CHECK-LABEL: @qcast_per_layer_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32> + +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor +// CHECK: %[[SIZE:.*]] = shape.num_elements %[[SHAPE]] : tensor -> index +// CHECK: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[SIZE]] : tensor<1xindex> +// CHECK: %[[RANKED_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + +// CHECK-DAG: %[[SCALE:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[ZERO_POINT:.*]] = arith.constant 10 : i8 +// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index + +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[RANKED_INPUT]], %[[C_0]] : tensor +// CHECK: %[[SCALE_SPLAT:.*]] = tensor.splat %[[SCALE]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[SCALED:.*]] = arith.divf %[[RANKED_INPUT]], %[[SCALE_SPLAT]] : tensor + +// CHECK: %[[ZERO_POINT_SPLAT:.*]] = tensor.splat %[[ZERO_POINT]]{{\[}}%[[DIM_0]]] : tensor +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT_SPLAT]] : tensor to tensor +// CHECK: %[[STORED:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : tensor +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED]] : tensor to tensor + +// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[STORED_INT]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xi8> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_layer_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return %0 : tensor<*x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_ranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x?x?x5xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20]> : tensor<2xi8> + +// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[C_1]] : tensor<4x?x?x5xf32> +// CHECK-DAG: %[[C_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG_0]], %[[C_2]] : tensor<4x?x?x5xf32> +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_1]], %[[DIM_2]]) : tensor<4x?x?x5xi8> + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x?x?x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x?x?x5xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor<4x?x?x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x?x?x5xi8> to tensor<4x?x?x5x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<4x?x?x5x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_channel_ranked(%arg0: tensor<4x?x?x5xf32>) -> tensor<4x?x?x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<4x?x?x5xf32> to tensor<4x?x?x5x!qalias> + return %0 : tensor<4x?x?x5x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_ranked_bounds +// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x2x5xf32> + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<0> : tensor<2xi8> + +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x5xi8> +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG_0]], %[[SCALES]], %[[ZERO_POINTS]] : tensor<4x2x5xf32>, tensor<2xf32>, tensor<2xi8>) outs(%[[INIT]] : tensor<4x2x5xi8>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: %[[C_NEG_8:.*]] = arith.constant -8 : i8 +// CHECK: %[[C_7:.*]] = arith.constant 7 : i8 +// CHECK: %[[STORED_CLAMPED_TEMP:.*]] = arith.maxsi %[[STORED_INT]], %[[C_NEG_8]] : i8 +// CHECK: %[[STORED_CLAMPED:.*]] = arith.minsi %[[STORED_CLAMPED_TEMP]], %[[C_7]] : i8 +// CHECK: linalg.yield %[[STORED_CLAMPED]] : i8 +// CHECK: } -> tensor<4x2x5xi8> + +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[GENERIC]] : tensor<4x2x5xi8> to tensor<4x2x5x!quant.uniform:f32:1, {2.000000e+00,3.000000e+00}>> +// CHECK: return %[[STORED_QUANT]] : tensor<4x2x5x!quant.uniform:f32:1, {2.000000e+00,3.000000e+00}>> + +!qalias = !quant.uniform:f32:1, {2.0, 3.0}> +func.func @qcast_per_channel_ranked_bounds(%arg0: tensor<4x2x5xf32>) -> tensor<4x2x5x!qalias> { + %0 = quant.qcast %arg0 : tensor<4x2x5xf32> to tensor<4x2x5x!qalias> + return %0 : tensor<4x2x5x!qalias> +} + +// ----- + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK-LABEL: @qcast_per_channel_unranked +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32> + +// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -> tensor +// CHECK: %[[CHANNEL_AXIS:.*]] = arith.constant 2 : index +// CHECK: %[[CHANNEL_AXIS_NEXT:.*]] = arith.constant 3 : index +// CHECK: %[[SHAPE_LEFT:.*]], %[[DISCARDED_0:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_LEFT:.*]] = shape.num_elements %[[SHAPE_LEFT]] : tensor -> index +// CHECK: %[[DISCARDED_1:.*]], %[[SHAPE_RIGHT:.*]] = "shape.split_at"(%[[SHAPE]], %[[CHANNEL_AXIS_NEXT]]) : (tensor, index) -> (tensor, tensor) +// CHECK: %[[SIZE_RIGHT:.*]] = shape.num_elements %[[SHAPE_RIGHT]] : tensor -> index + +// CHECK: %[[CHANNEL_AXIS_SIZE:.*]] = arith.constant 3 : index +// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[SIZE_LEFT]], %[[CHANNEL_AXIS_SIZE]], %[[SIZE_RIGHT]] : tensor<3xindex> +// CHECK: %[[FLAT_INPUT:.*]] = tensor.reshape %[[ARG_0]](%[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor + +// CHECK: %[[SCALES:.*]] = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK: %[[ZERO_POINTS:.*]] = arith.constant dense<[10, 20, 30]> : tensor<3xi8> + +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_0]] : tensor +// CHECK: %[[C_2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[FLAT_INPUT]], %[[C_2]] : tensor +// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_2]]) : tensor + +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[FLAT_INPUT]], %[[SCALES]], %[[ZERO_POINTS]] : tensor, tensor<3xf32>, tensor<3xi8>) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[SCALE:.*]]: f32, %[[ZERO_POINT:.*]]: i8, %[[OUT:.*]]: i8): +// CHECK: %[[SCALED:.*]] = arith.divf %[[IN]], %[[SCALE]] : f32 +// CHECK: %[[ZERO_POINT_FLOAT:.*]] = arith.sitofp %[[ZERO_POINT]] : i8 to f32 +// CHECK: %[[STORED_FLOAT:.*]] = arith.addf %[[SCALED]], %[[ZERO_POINT_FLOAT]] : f32 +// CHECK: %[[STORED_INT:.*]] = arith.fptosi %[[STORED_FLOAT]] : f32 to i8 +// CHECK: linalg.yield %[[STORED_INT]] : i8 +// CHECK: } -> tensor + +// CHECK: %[[STORED_UNRANKED:.*]] = tensor.reshape %[[GENERIC]](%[[SHAPE]]) : (tensor, tensor) -> tensor<*xi8> +// CHECK: %[[STORED_QUANT:.*]] = quant.scast %[[STORED_UNRANKED]] : tensor<*xi8> to tensor<*x!quant.uniform> +// CHECK: return %[[STORED_QUANT]] : tensor<*x!quant.uniform> + +!qalias = !quant.uniform +func.func @qcast_per_channel_unranked(%arg0: tensor<*xf32>) -> tensor<*x!qalias> { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return %0 : tensor<*x!qalias> +} + diff --git a/mlir/test/Dialect/Quant/ops.mlir b/mlir/test/Dialect/Quant/ops.mlir new file mode 100644 index 000000000000..4abc5830d081 --- /dev/null +++ b/mlir/test/Dialect/Quant/ops.mlir @@ -0,0 +1,151 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +!qalias = !quant.uniform +func.func @dcast_scalar(%arg0: !qalias) { + %0 = quant.dcast %arg0 : !qalias to f32 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_ranked(%arg0: tensor<2x?x4x!qalias>) { + %0 = quant.dcast %arg0 : tensor<2x?x4x!qalias> to tensor<2x?x4xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_unranked(%arg0: tensor<*x!qalias>) { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_static(%arg0: tensor<1x2x3x!qalias>) { + %0 = quant.dcast %arg0 : tensor<1x2x3x!qalias> to tensor<1x2x3xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.dcast %arg0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @dcast_per_axis_unranked(%arg0: tensor<*x!qalias>) { + %0 = quant.dcast %arg0 : tensor<*x!qalias> to tensor<*xf32> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_scalar(%arg0: f32) { + %0 = quant.qcast %arg0 : f32 to !qalias + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_ranked(%arg0: tensor<2x?x4xf32>) { + %0 = quant.qcast %arg0 : tensor<2x?x4xf32> to tensor<2x?x4x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_unranked(%arg0: tensor<*xf32>) { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_static(%arg0: tensor<1x2x3xf32>) { + %0 = quant.qcast %arg0 : tensor<1x2x3xf32> to tensor<1x2x3x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.qcast %arg0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @qcast_per_axis_unranked(%arg0: tensor<*xf32>) { + %0 = quant.qcast %arg0 : tensor<*xf32> to tensor<*x!qalias> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_scalar(%arg0: i8) { + %0 = quant.scast %arg0 : i8 to !qalias + %1 = quant.scast %0 : !qalias to i8 + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_ranked(%arg0: tensor<2x?x4xi8>) { + %0 = quant.scast %arg0 : tensor<2x?x4xi8> to tensor<2x?x4x!qalias> + %1 = quant.scast %0 : tensor<2x?x4x!qalias> to tensor<2x?x4xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_unranked(%arg0: tensor<*xi8>) { + %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias> + %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_static(%arg0: tensor<1x2x3xi8>) { + %0 = quant.scast %arg0 : tensor<1x2x3xi8> to tensor<1x2x3x!qalias> + %1 = quant.scast %0 : tensor<1x2x3x!qalias> to tensor<1x2x3xi8> + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_dynamic(%arg0: tensor) { + %0 = quant.scast %arg0 : tensor to tensor + %1 = quant.scast %0 : tensor to tensor + return +} + +// ----- + +!qalias = !quant.uniform +func.func @scast_per_axis_unranked(%arg0: tensor<*xi8>) { + %0 = quant.scast %arg0 : tensor<*xi8> to tensor<*x!qalias> + %1 = quant.scast %0 : tensor<*x!qalias> to tensor<*xi8> + return +} + + diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir index 698f17604f80..8a5af6df73e8 100644 --- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir @@ -135,3 +135,19 @@ // provided. // expected-error@+1 {{expected floating point literal}} !qalias = !quant.uniform:f32, {2.000000e+02,-19.987200e-01:1}> + +// ----- +// Illegal negative axis in per-axis quantization +// expected-error@+1 {{illegal quantized dimension: -1}} +!qalias = !quant.uniform + +// ----- +// Scale f16 overflow +// expected-error@+1 {{scale out of expressed type range}} +!qalias = !quant.uniform + + +// ----- +// Scale f16 overflow in per-axis quantization +// expected-error@+1 {{scale out of expressed type range}} +!qalias = !quant.uniform diff --git a/mlir/test/Dialect/Quant/strip-func-quant-types.mlir b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir new file mode 100644 index 000000000000..e5f0d4921bed --- /dev/null +++ b/mlir/test/Dialect/Quant/strip-func-quant-types.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s --strip-func-quant-types --split-input-file | FileCheck %s + +// CHECK-LABEL: @strip_operands +// CHECK-SAME: %[[ARG_0:.*]]: i8 +// CHECK-SAME: %[[ARG_1:.*]]: i16 +// CHECK-SAME: %[[ARG_2:.*]]: f32 + +// CHECK: %[[ARG_0_CAST:.*]] = quant.scast %[[ARG_1]] : i16 to !quant.uniform<{{.*}}> +// CHECK: %[[ARG_1_CAST:.*]] = quant.scast %[[ARG_0]] : i8 to !quant.uniform<{{.*}}> + +// CHECK: "test.custom_op"(%[[ARG_1_CAST]]) +// CHECK: "test.custom_op"(%[[ARG_0_CAST]]) +// CHECK: "test.custom_op"(%[[ARG_2]]) + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func @strip_operands(%arg0: !qalias, %arg1: !qalias1, %arg2: f32) { + "test.custom_op"(%arg0) : (!qalias) -> tensor<4x!qalias> + "test.custom_op"(%arg1) : (!qalias1) -> tensor + "test.custom_op"(%arg2) : (f32) -> tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @strip_results +// CHECK-SAME: tensor<4xi8>, tensor, tensor<*xi8>, tensor<4xf32> + +// CHECK: %[[RESULT_0:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_0:.*]] = quant.scast %[[RESULT_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8> + +// CHECK: %[[RESULT_1:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_1:.*]] = quant.scast %[[RESULT_1]] : tensor> to tensor + +// CHECK: %[[RESULT_2:.*]] = "test.custom_op"() +// CHECK: %[[RESULT_CAST_2:.*]] = quant.scast %[[RESULT_2]] : tensor<*x!quant.uniform<{{.*}}>> to tensor<*xi8> + +// CHECK: %[[RESULT_3:.*]] = "test.custom_op"() + +// CHECK: return %[[RESULT_CAST_0]], %[[RESULT_CAST_1]], %[[RESULT_CAST_2]], %[[RESULT_3]] + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func @strip_results() -> (tensor<4x!qalias>, tensor, tensor<*x!qalias>, tensor<4xf32>) { + %0 = "test.custom_op"() : () -> tensor<4x!qalias> + %1 = "test.custom_op"() : () -> tensor + %2 = "test.custom_op"() : () -> tensor<*x!qalias> + %3 = "test.custom_op"() : () -> tensor<4xf32> + return %0, %1, %2, %3 : tensor<4x!qalias>, tensor, tensor<*x!qalias>, tensor<4xf32> +} + +// ----- + + +// CHECK-LABEL: @callee +// CHECK-SAME: (tensor<4xi8>, tensor) -> (tensor<*xi8>, tensor<4xf32>) + +// CHECK-LABEL: @strip_call + +// CHECK: %[[OPERAND_0:.*]] = "test.custom_op"() +// CHECK: %[[OPERAND_0_CAST:.*]] = quant.scast %[[OPERAND_0]] : tensor<4x!quant.uniform<{{.*}}>> to tensor<4xi8> + +// CHECK: %[[OPERAND_1:.*]] = "test.custom_op"() +// CHECK: %[[OPERAND_1_CAST:.*]] = quant.scast %[[OPERAND_1]] : tensor> to tensor + +// CHECK: %[[RESULTS:.*]]:2 = call @callee(%[[OPERAND_0_CAST]], %[[OPERAND_1_CAST]]) + +// CHECK: %[[RESULT_0_CAST:.*]] = quant.scast %[[RESULTS]]#0 : tensor<*xi8> to tensor<*x!quant.uniform<{{.*}}>> +// CHECK: "test.custom_op"(%[[RESULT_0_CAST]]) + +// CHECK: "test.custom_op"(%[[RESULTS]]#1) + +// CHECK: return + +!qalias = !quant.uniform +!qalias1 = !quant.uniform + +func.func private @callee(tensor<4x!qalias>, tensor) -> (tensor<*x!qalias>, tensor<4xf32>) + +func.func @strip_call() { + %0 = "test.custom_op"() : () -> tensor<4x!qalias> + %1 = "test.custom_op"() : () -> tensor + %2:2 = func.call @callee(%0, %1) : (tensor<4x!qalias>, tensor) -> (tensor<*x!qalias>, tensor<4xf32>) + "test.custom_op"(%2#0) : (tensor<*x!qalias>) -> () + "test.custom_op"(%2#1) : (tensor<4xf32>) -> () + return +}