From 5484560d5fe44a058652ad3523ae252c9b58dc30 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Wed, 29 Jan 2025 19:11:36 -0800 Subject: [PATCH 1/6] init code structure for matmul 2 bits Signed-off-by: Liqun Fu --- cmake/onnxruntime_mlas.cmake | 2 + .../cpu/quantization/matmul_nbits.cc | 39 +- .../cpu/quantization/matmul_nbits_impl.cc | 35 +- .../cpu/quantization/matmul_nbits_impl.h | 2 +- onnxruntime/core/mlas/inc/mlas_q4.h | 2 +- onnxruntime/core/mlas/lib/q4_dq.cpp | 363 ++++++++++++------ onnxruntime/core/mlas/lib/qnbitgemm.cpp | 124 ++++-- onnxruntime/core/mlas/lib/qnbitgemm.h | 10 + .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 86 +++++ .../mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h | 52 +++ .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 10 + .../test/contrib_ops/matmul_4bits_test.cc | 261 +++++++------ .../test/mlas/bench/bench_qnbitgemm.cpp | 4 +- .../test/mlas/unittest/test_blockq4.cpp | 2 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 37 +- .../test/optimizer/graph_transform_test.cc | 2 +- 16 files changed, 720 insertions(+), 311 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index ed3ad89247975..90667e488ffe8 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -182,6 +182,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp @@ -586,6 +587,7 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp ) if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) set(mlas_platform_srcs_avx2 diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index c3e43f897c509..f9e0795b6dbfe 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -111,8 +111,8 @@ class MatMulNBits final : public OpKernel { has_unquantized_zero_point_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; } - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(nbits_ == 2 || nbits_ == 4, + "Only 2 and 4b quantization is supported for MatMulNBits op, additional bits support is planned."); const Tensor* tensor_zero_point = nullptr; has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point); } @@ -436,17 +436,30 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or 
row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); + // dequantize b, only 2 and 4b quantization is supported for now + if (this->nbits_ == 2) { + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else if (this->nbits_ == 4) { + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } } else { ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index 6a19a741c3028..dd3d1fd9ac2cc 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -16,6 +16,15 @@ namespace onnxruntime { namespace contrib { +template +void Dequantize2BitsKernelReOrder( + T* /*output*/, const uint8_t* /*quant_data*/, const T* /*scale_data*/, + const zeroT* /*zero_points*/, const int32_t* /*reorder_idx*/, int /*block_size*/, + int /*groups_per_threadblock*/, int /*total_groups*/, int /*out_rows*/, int /*out_cols*/, + int /*blockIdx_x*/, int /*threadIdx_x*/) { + assert(false); +} + template void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, @@ -73,7 +82,7 @@ void Dequantize4BitsKernelReOrder( } } -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input @@ -95,24 +104,36 @@ void DequantizeBlockwise( pool, static_cast(blocks_per_grid), [&](std::ptrdiff_t block_id) { for (int j = 0; j < 256; j++) { - Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, - reorder_idx, block_size, groups_per_threadblock, - total_groups, N, K, static_cast(block_id), j); + if constexpr (qbits == 2) { + Dequantize2BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } else { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } } }); } -template void DequantizeBlockwise( +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + + +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, const 
int32_t* reorder_idx, int32_t block_size,
    bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
 
-template void DequantizeBlockwise<float, float>(
+template void DequantizeBlockwise<float, float, 4>(
     float* output, const uint8_t* quant_data, const float* scales_data,
     const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
     bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
 
-template void DequantizeBlockwise<float, MLFloat16>(
+template void DequantizeBlockwise<float, MLFloat16, 4>(
     float* output, const uint8_t* quant_data, const float* scales_data,
     const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
     bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
index 5061ac5c800a6..b875048cbc585 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
@@ -6,7 +6,7 @@
 namespace onnxruntime {
 namespace contrib {
 
-template <class inputT, class zeroT>
+template <class inputT, class zeroT, int qbits>
 void DequantizeBlockwise(
     inputT* output,            // dequantized output
     const uint8_t* quant_data, // quantized input
diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index aec14070ffd55..80db68750799b 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -277,9 +277,9 @@ MlasBlockwiseQuantizedShape(
  *
  * If the qbits or block_size values are unsupported the output sizes will be zero.
  */
+template <int qbits>
 void MLASCALL
 MlasBlockwiseQuantizedBufferSizes(
-    int qbits,
     int block_size,
     bool columnwise,
     int rows,
diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp
index 015d69de68766..acc3cdd651751 100644
--- a/onnxruntime/core/mlas/lib/q4_dq.cpp
+++ b/onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -402,7 +402,8 @@ template <
 struct BlockwiseQuantizer {
     // To support other qbits, need to add bit packing code for
     // storing to dst and zero points
-    static_assert(qbits == 4, "Only 4b block quantization is supported!");
+    static_assert(qbits == 4 || qbits == 2, "Only 2b and 4b block quantization is supported!");
+    // static_assert(qbits != 2 || Columnwise, "Only support Columnwise in qbits == 2 case.");
 
     using QuantBlk = std::conditional_t<Columnwise, Shape2D<block_size, 1>, Shape2D<1, block_size>>;
     using ThreadBlk = Shape2D<QuantBlk::kRow * BitsTraits<qbits>::kPackSize, QuantBlk::kColumn>;
@@ -480,7 +481,7 @@ struct BlockwiseQuantizer {
         thread_pool, total_thrd_blks,
         [&](ptrdiff_t block_idx) {
             uint8_t zp_bytes[BitsTraits<qbits>::kPackSize];
-            std::fill_n(zp_bytes, BitsTraits<qbits>::kPackSize, (uint8_t)8);
+            std::fill_n(zp_bytes, BitsTraits<qbits>::kPackSize, (uint8_t)(BitsTraits<qbits>::kMid));
 
             const int32_t r_blk_idx = static_cast<int32_t>(block_idx / thrd_col_blks);
             const int32_t c_blk_idx = static_cast<int32_t>(block_idx % thrd_col_blks);
@@ -521,40 +522,68 @@ struct BlockwiseQuantizer {
             }
         }
 
-        // !! 4b specific code as we need to pack 2 4b numbers into one byte
+        // !! 
qbits specific code as we need to pack multiple sub-byte values into one byte
         if (zero_points != nullptr) {
-            const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2;
+            const int32_t meta_idx = meta_col * ((row_blks + 1) / BitsTraits<qbits>::kPackSize) + meta_row / BitsTraits<qbits>::kPackSize;
+            if constexpr (qbits == 4) {
                 zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4);
+            } else if constexpr (qbits == 2) {
+                zero_points[meta_idx] = (zp_bytes[0] & 0x3) | ((zp_bytes[1] & 0x3) << 2) |
+                                        ((zp_bytes[2] & 0x3) << 4) | ((zp_bytes[3] & 0x3) << 6);
+            } else {
+                static_assert(qbits == 4 || qbits == 2, "only support qbits of 4 and 2");
+            }
         }
 
-        for (int32_t j = c; j < c_end; ++j) {
+        for (int32_t j = c; j < c_end; ++j) {  // this does not work if j runs more than 1 column because zp_bytes is indexed by i.
             const int32_t meta_c = j / QuantBlk::kColumn;
-            for (int32_t i = r; i < r_end; i += 2) {
+            for (int32_t i = r; i < r_end; i += BitsTraits<qbits>::kPackSize) {
                 const int32_t meta_r = i / QuantBlk::kRow;
                 const float scale = static_cast<float>(scales[meta_c * row_blks + meta_r]);
                 const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
-                const int8_t zp = zp_bytes[meta_r & 1];
-                const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1];
-
-                const float v0 = static_cast<float>(src[i * leadingDimension + j]);
-                const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp),
-                                                        0.0f, BitsTraits<qbits>::kMaxFp);
-
-                uint8_t vi1 = (uint8_t)zp;
-                if (i + 1 < r_end) {
-                    float reciprocal_scale1 = reciprocal_scale;
-                    if constexpr (QuantBlk::kRow == 1) {
-                        const float scale1 =
-                            static_cast<float>(scales[meta_c * row_blks + meta_r + 1]);
+                if constexpr (qbits == 4) {
+                    const int8_t zp = zp_bytes[meta_r & 1];
+                    const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1];
+
+                    const float v0 = static_cast<float>(src[i * leadingDimension + j]);
+                    const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), 0.0f, BitsTraits<qbits>::kMaxFp);
+
+                    uint8_t vi1 = (uint8_t)zp1;
+                    if (i + 1 < r_end) {
+                        float reciprocal_scale1 = reciprocal_scale;
+                        if constexpr (QuantBlk::kRow == 1) {
+                            const float scale1 =
+                                static_cast<float>(scales[meta_c * row_blks + meta_r + 1]);
+                            reciprocal_scale1 = scale1 ? 
1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, + BitsTraits::kMaxFp); + } + dst[j * q_rows + i / BitsTraits::kPackSize] = (vi0 & 0xf) | (vi1 << 4); + } else { + const int8_t zp0 = zp_bytes[(i / QuantBlk::kRow) & 3]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 3]; + const int8_t zp2 = zp_bytes[((i + 2) / QuantBlk::kRow) & 3]; + const int8_t zp3 = zp_bytes[((i + 3) / QuantBlk::kRow) & 3]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp0), 0.0f, BitsTraits::kMaxFp); + uint8_t vi1 = 0, vi2 = 0, vi3 = 0; + if (i + 1 < r_end) { + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale + zp1), 0.0f, BitsTraits::kMaxFp); + } + if (i + 2 < r_end) { + const float v2 = static_cast(src[(i + 2) * leadingDimension + j]); + vi2 = (uint8_t)std::clamp(roundf(v2 * reciprocal_scale + zp2), 0.0f, BitsTraits::kMaxFp); } - const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); - vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); + if (i + 3 < r_end) { + const float v3 = static_cast(src[(i + 3) * leadingDimension + j]); + vi3 = (uint8_t)std::clamp(roundf(v3 * reciprocal_scale + zp3), 0.0f, BitsTraits::kMaxFp); + } + dst[j * q_rows + i / BitsTraits::kPackSize] = (vi0 & 0x03) | ((vi1 & 0x03) << 2) | ((vi2 & 0x03) << 4) | ((vi3 & 0x03) << 6); } - - // !! 4b specific code - dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); } } }); @@ -587,6 +616,8 @@ struct BlockwiseQuantizer { const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + constexpr int pack_size = BitsTraits::kPackSize; + int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); @@ -605,37 +636,78 @@ struct BlockwiseQuantizer { for (int32_t j = c; j < c_end; ++j) { const int32_t meta_col = j / QuantBlk::kColumn; - // !! 4b specific code + // !! 2 and 4b specific code // the whole loop is 4b specific due to sub 8 bit packing // and unpacking. We can potentially make this qbits generic // by wraping the packing/unpacking code like cutlass::Array - for (int32_t i = r; i < r_end; i += 2) { + for (int32_t i = r; i < r_end; i += pack_size) { const int32_t meta_row = i / QuantBlk::kRow; const float scale0 = static_cast(scales[meta_col * row_blks + meta_row]); - const int zp_pair = - (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); - - const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - - dst[j * rows + i] = static_cast(v0); - if ((i + 1) < r_end) { - float scale1 = scale0; - int zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = - static_cast(scales[meta_col * row_blks + meta_row + 1]); - zp1 = (zp_pair >> 4) & 0xf; + if constexpr (qbits == 4) { + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / pack_size) + meta_row / pack_size]; + const int zp0 = (meta_row & 1) ? 
(zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } else { + const int zp_quad = zero_points[meta_col * ((row_blks + 3) / pack_size) + meta_row / pack_size]; + int zp = 0; + const int meta_row_mod = meta_row % 4; + switch (meta_row_mod) { + case 0: + zp = zp_quad & 0x3; + break; + case 1: + zp = (zp_quad >> 2) & 0x3; + break; + case 2: + zp = (zp_quad >> 4) & 0x3; + break; + case 3: + zp = (zp_quad >> 6) & 0x3; + break; + } + + const uint8_t& weight = weights[j * q_rows + i / pack_size]; + const uint8_t vi0 = weight & 0x3; + const float v0 = (static_cast(vi0) - zp) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + const uint8_t vi1 = (weight >> 2) & 0x3; + const float v1 = (static_cast(vi1) - zp) * scale0; + dst[j * rows + (i + 1)] = static_cast(v1); + } + if ((i + 2) < r_end) { + const uint8_t vi2 = (weight >> 4) & 0x3; + const float v2 = (static_cast(vi2) - zp) * scale0; + dst[j * rows + (i + 2)] = static_cast(v2); + } + if ((i + 3) < r_end) { + const uint8_t vi3 = (weight >> 6) & 0x3; + const float v3 = (static_cast(vi3) - zp) * scale0; + dst[j * rows + (i + 3)] = static_cast(v3); } - const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; - const float v1 = (static_cast(vi1) - zp1) * scale1; - dst[j * rows + (i + 1)] = static_cast(v1); } } } @@ -1450,8 +1522,17 @@ MlasBlockwiseQuantizedShape( int& q_cols ); -template -void +template void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols +); + +template void MlasBlockwiseQuantizedShape( int block_size, bool columnwise, @@ -1461,9 +1542,9 @@ MlasBlockwiseQuantizedShape( int& q_cols ); +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, @@ -1478,72 +1559,70 @@ MlasBlockwiseQuantizedBufferSizes( *q_zero_point_size_in_bytes = 0; } - if (qbits == 4) { - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 32: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 64: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 128: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, 
q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 256: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; - default: - // Only block size 16, 32, 64, 128, 256 are supported. - break; - } + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
+ break; } } @@ -1620,8 +1699,29 @@ MlasQuantizeBlockwise( } } -template -void +template void MLASCALL +MlasBlockwiseQuantizedBufferSizes<2>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template void MLASCALL +MlasBlockwiseQuantizedBufferSizes<4>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template void MlasQuantizeBlockwise( uint8_t* dst, float* scales, @@ -1635,6 +1735,20 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool +); + template void MlasQuantizeBlockwise( @@ -1730,6 +1844,19 @@ MlasDequantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index f064a8e1d6a78..7e7baf137f604 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -33,6 +33,7 @@ enum QNBitGemmVariant { HQNBitGemmVariant_BitWidth4_CompFp16, HQNBitGemmVariant_BitWidth4_CompInt8, + SQNBitGemmVariant_BitWidth2_CompInt8, // End of valid variants // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. 
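
For orientation, here is a minimal sketch (not part of the patch; Pack2Bit and Unpack2Bit are
illustrative names) of the 2-bit packing convention the new BitWidth2 paths in this change rely
on: four 2-bit values per byte, with element 0 in the least-significant bits.

    #include <cstdint>

    // Pack four 2-bit values (each in [0, 3]) into one byte; v0 lands in bits 0-1, v3 in bits 6-7.
    inline uint8_t Pack2Bit(uint8_t v0, uint8_t v1, uint8_t v2, uint8_t v3) {
        return static_cast<uint8_t>((v0 & 0x3) | ((v1 & 0x3) << 2) | ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6));
    }

    // Unpack element idx (0..3) by shifting its 2-bit field down to the low bits.
    inline uint8_t Unpack2Bit(uint8_t packed, int idx) {
        return static_cast<uint8_t>((packed >> (idx * 2)) & 0x3);
    }
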
@@ -47,16 +48,24 @@ GetQNBitGemmVariant( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - if (BlkBitWidth == 4 && - (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == SQNBIT_CompFp32) { - return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == HQNBIT_CompFp16) { - return HQNBitGemmVariant_BitWidth4_CompFp16; - } else if (ComputeType == SQNBIT_CompInt8) { - return SQNBitGemmVariant_BitWidth4_CompInt8; - } else if (ComputeType == HQNBIT_CompInt8) { - return HQNBitGemmVariant_BitWidth4_CompInt8; + if (BlkBitWidth == 4) { + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256) { + if (ComputeType == SQNBIT_CompFp32) { + return SQNBitGemmVariant_BitWidth4_CompFp32; + } else if (ComputeType == HQNBIT_CompFp16) { + return HQNBitGemmVariant_BitWidth4_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQNBitGemmVariant_BitWidth4_CompInt8; + } + } + } else if (BlkBitWidth == 2) { + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256) { + if (ComputeType == SQNBIT_CompInt8) + { + return SQNBitGemmVariant_BitWidth2_CompInt8; + } } } @@ -89,11 +98,14 @@ MlasIsQNBitGemmAvailable( Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + case SQNBitGemmVariant_BitWidth4_CompInt8: { return (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } + case SQNBitGemmVariant_BitWidth2_CompInt8: { + return (Dispatch->SQ2BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr); + } default: { return false; } @@ -120,14 +132,17 @@ QNBitGemmPerGemmWorkspaceSize( if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + } else if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->Q2BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); } + return 0; } size_t QNBitGemmPerGemmWorkspaceAlignment( - size_t BlkBitWidth, + size_t /*BlkBitWidth*/, size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) @@ -137,7 +152,8 @@ QNBitGemmPerGemmWorkspaceAlignment( return 1; } - if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { + // alignment is the same w.r.t. BlkBitWidth. + if (/*BlkBitWidth == 4 && */Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } @@ -204,6 +220,12 @@ MlasQNBitGemmPackQuantBDataSize( ); } + if (BlkBitWidth == 2 && Dispatch->Q2BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q2BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); + } + return 0; } @@ -269,9 +291,9 @@ MlasQNBitGemmPackQuantBData( ThreadPool ); } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. - //assert(QuantBScale == nullptr); - //assert(QuantBZeroPoint == nullptr); + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. 
+ // assert(QuantBScale == nullptr); + // assert(QuantBZeroPoint == nullptr); Dispatch->SQ4BitGemmPackQuantBData( N, K, @@ -283,6 +305,19 @@ MlasQNBitGemmPackQuantBData( ); return; } + } else if (BlkBitWidth == 2) { + if (Dispatch->SQ2BitGemmPackQuantBData != nullptr) { + Dispatch->SQ2BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; + } } } @@ -507,6 +542,20 @@ HQ4BitGemm_CompFp16( } } +void +SQ2BitGemm_CompInt8( + const size_t /*BlkLen*/, + const size_t /*K*/, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const /*DataParams*/, + void* const /*PerGemmWorkspace*/, + const size_t /*RangeStartM*/, + const size_t /*RangeCountM*/, + const size_t /*RangeStartN*/, + const size_t /*RangeCountN*/ +) +{ +} + void SQ4BitGemm_CompInt8( const size_t BlkLen, @@ -639,6 +688,7 @@ SQ4BitGemm_CompInt8( template void InitializeWorkspace_CompInt8( + size_t BlkBitWidth, size_t M, size_t N, size_t K, @@ -653,6 +703,7 @@ InitializeWorkspace_CompInt8( template <> void InitializeWorkspace_CompInt8( + size_t BlkBitWidth, size_t M, size_t N, size_t K, @@ -667,26 +718,14 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + // TODO: THIS is temporary: in case of BlkBitWidth == 2 we want to force use QuantizeARow even if + // QuantizeARowComputeBlkSum_CompInt8 is available. + const auto QuantizeARow2 = BlkBitWidth == 2 ? nullptr : GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
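
// Illustrative comment, not part of the patch: a sketch of the quantized-A workspace layout
// assumed by QuantAStride above. Each row of A is quantized into BlockCountK consecutive Q8
// blocks, where Q8BlkSize(BlkLen) is expected to cover a per-block float scale followed by
// BlkLen int8 values, so
//   QuantAStride = BlockCountK * Q8BlkSize(BlkLen)
// is the byte stride between consecutive quantized rows.
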
-  if (QuantizeARow) {
-    MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
-      const auto& data = DataParams[gemm_idx];
-
-      const float* ARowPtr = data.A;
-      std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
-      for (size_t m = 0; m < M; ++m) {
-        QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
-
-        ARowPtr += data.lda;
-        QuantARowPtr += QuantAStride;
-      }
-    });
-  } else {
+  if (QuantizeARow2) {
     MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
       const auto& data = DataParams[gemm_idx];
       const float* ARowPtr = data.A;
@@ -704,12 +743,26 @@ InitializeWorkspace_CompInt8(
         QuantARowBlkSum += BlockCountK;
       }
     });
+  } else {
+    MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
+      const auto& data = DataParams[gemm_idx];
+
+      const float* ARowPtr = data.A;
+      std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
+      for (size_t m = 0; m < M; ++m) {
+        QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr);
+
+        ARowPtr += data.lda;
+        QuantARowPtr += QuantAStride;
+      }
+    });
   }
 }
 
 template <>
 void
 InitializeWorkspace_CompInt8<MLAS_FP16>(
+    size_t BlkBitWidth,
     size_t M,
     size_t N,
     size_t K,
@@ -720,6 +773,7 @@ InitializeWorkspace_CompInt8(
     size_t PerGemmWorkspaceStride,
     MLAS_THREADPOOL* ThreadPool
 ) {
+  MLAS_UNREFERENCED_PARAMETER(BlkBitWidth);
   MLAS_UNREFERENCED_PARAMETER(M);
   MLAS_UNREFERENCED_PARAMETER(N);
   MLAS_UNREFERENCED_PARAMETER(K);
@@ -733,6 +787,7 @@
 template
 using InitializeWorkspaceFn = std::function;
     default:
       return nullptr;
@@ -797,6 +853,8 @@ GetQNBitGemm(QNBitGemmVariant variant)
         return SQ4BitGemm_CompFp32;
     case SQNBitGemmVariant_BitWidth4_CompInt8:
         return SQ4BitGemm_CompInt8;
+    case SQNBitGemmVariant_BitWidth2_CompInt8:
+        return SQ2BitGemm_CompInt8;
     default:
         return nullptr;
     }
 }
@@ -849,7 +907,7 @@ MlasQNBitGemmBatch(
     if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant);
         InitializeWorkspaceOperation != nullptr) {
         InitializeWorkspaceOperation(
-            M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
+            BlkBitWidth, M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
         );
     }
diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h
index eb3d0b44ae3de..c0dd11e2444ee 100644
--- a/onnxruntime/core/mlas/lib/qnbitgemm.h
+++ b/onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -100,6 +100,13 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
 
     Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;
 
+    // TODO: rename Q4BitGemmPackQuantBDataSize_Fn to QNBitGemmPackQuantBDataSize_Fn
+    // because its signature shall be the same regardless of bit width,
+    // or take the bit width as an argument so we only need one function.
+    // The same applies to Q4BitGemmPackQuantBData_Fn, Q4BitGemmPerGemmWorkspaceSize_Fn,
+    // SQ2BitGemmKernel_CompInt8.
+    Q4BitGemmPackQuantBDataSize_Fn* Q2BitGemmPackQuantBDataSize = nullptr;
+
     /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). 
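(With this change the same signature is also reused by the 2-bit packing entry, SQ2BitGemmPackQuantBData.)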
*/ typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, @@ -113,6 +120,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ2BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, @@ -152,6 +160,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { ); Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; + Q4BitGemmPerGemmWorkspaceSize_Fn* Q2BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -342,6 +351,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { ); SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + SQ4BitGemmKernel_CompInt8_Fn* SQ2BitGemmKernel_CompInt8 = nullptr; /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp new file mode 100644 index 0000000000000..6c1a133609f69 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -0,0 +1,86 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_avx2.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. + +--*/ + +#include +#include +#include + +#include "qnbitgemm.h" + +size_t +Q2BitGemmPackQuantBDataSize( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +void SQ2BitGemmPackQuantBData( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* /*QuantBDataBegin*/, + std::byte* /*PackedQuantBDataBegin*/, + MLAS_THREADPOOL* /*ThreadPool*/ +) +{ +} + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t /*M*/, + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t /*BlkLen*/, + const std::byte* /*QuantA*/, + const std::byte* /*QuantBData*/, + const float* /*QuantBScale*/, + const std::byte* /*QuantBZeroPoint*/, + float* /*C*/, + size_t /*CountM*/, + size_t /*CountN*/, + size_t /*CountK*/, + size_t /*BlockCountK*/, + size_t /*ldc*/, + const float* /*Bias*/ +) +{ + return 0; +} + +void +QuantizeARow_CompInt8( + size_t /*BlkLen*/, + const float* /*A*/, + size_t /*CountK*/, + std::byte* /*QuantA*/ +) +{ +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h new file mode 100644 index 0000000000000..5e8aefb792265 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.h @@ -0,0 +1,52 @@ +#pragma once +#include "qnbitgemm.h" + +size_t Q2BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +void +SQ2BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +); + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t 
BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias +); + +void QuantizeARow_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA +); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 81615da46aa2e..fe9720fd7e383 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,6 +29,8 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +#include "sqnbitgemm_bitnet_kernel_avx2.h" + void MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) { @@ -1346,6 +1348,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; + d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize; + d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData; + + d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize; + + d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + return d; }(); diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 9bf08c6350833..d6940dc2cf367 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -32,8 +32,10 @@ namespace test { namespace { -constexpr int QBits = 4; +constexpr int Q2Bits = 2; +constexpr int Q4Bits = 4; +template void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -44,7 +46,7 @@ void QuantizeDequantize(std::vector& raw_vals, auto& ortenv = **ort_env.get(); onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); - MlasQuantizeBlockwise( + MlasQuantizeBlockwise( quant_vals.data(), scales.data(), zp != nullptr ? 
zp->data() : nullptr, @@ -57,7 +59,7 @@ void QuantizeDequantize(std::vector& raw_vals, tp); // Note that raw_vals is NxK after dequant - MlasDequantizeBlockwise( + MlasDequantizeBlockwise( raw_vals.data(), // dequantized output quant_vals.data(), // quantized input scales.data(), // quantization scales @@ -95,7 +97,7 @@ std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { << ", has_bias:" << opts.has_bias; } -template +template void RunTest(const TestOptions& opts, std::vector>&& explicit_eps = {}) { SCOPED_TRACE(opts); @@ -121,12 +123,12 @@ void RunTest(const TestOptions& opts, #endif int q_rows, q_cols; - MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, + MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, static_cast(K), static_cast(N), q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(QBits, static_cast(opts.block_size), /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes(static_cast(opts.block_size), /* columnwise */ true, static_cast(K), static_cast(N), q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); @@ -134,7 +136,7 @@ void RunTest(const TestOptions& opts, std::vector scales(q_scale_size); std::vector zp(q_zp_size_in_bytes); - QuantizeDequantize(input1_f_vals, + QuantizeDequantize(input1_f_vals, input1_vals, scales, opts.has_zero_point ? &zp : nullptr, @@ -175,7 +177,7 @@ void RunTest(const TestOptions& opts, test.AddAttribute("K", K); test.AddAttribute("N", N); test.AddAttribute("block_size", opts.block_size); - test.AddAttribute("bits", QBits); + test.AddAttribute("bits", qbits); test.AddAttribute("accuracy_level", opts.accuracy_level); if constexpr (use_float16) { @@ -267,7 +269,7 @@ void RunTest(const TestOptions& opts, } // namespace -template +template void TestMatMulNBitsTyped() { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -282,24 +284,27 @@ void TestMatMulNBitsTyped() { base_opts.output_rel_error = 0.02f; } + if constexpr (qbits == 4) { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(USE_DML) && !defined(USE_WEBGPU) + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); } + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -316,80 +321,84 @@ void TestMatMulNBitsTyped() { // only enabled for CPU EP for now std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + RunTest(opts); } #endif // !defined(USE_DML) && !defined(USE_WEBGPU) } TEST(MatMulNBits, Float32_Accuracy0) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - 
TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float32_Accuracy1) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float32_Accuracy4) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + +TEST(MatMulNBits, DISABLED_Float32_Accuracy4_Q2) { + TestMatMulNBitsTyped(); } #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64) @@ -397,68 +406,68 @@ TEST(MatMulNBits, Float32_Accuracy4) { // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. 
TEST(MatMulNBits, Float16_Accuracy2) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_Accuracy0) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_Accuracy4) { - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); - TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } #endif #endif diff --git a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index 64d229889214b..a511664407af0 100644 --- a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -31,8 +31,8 @@ void RunQNBitGemmBenchmark(size_t BlkLen, } size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, + 
MlasBlockwiseQuantizedBufferSizes( + static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index f75002f715154..11e5cec1f1e69 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -53,7 +53,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, + MlasBlockwiseQuantizedBufferSizes<4>(block_size, columnwise, rows, columns, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index e22018ae2877f..365137d466256 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -142,20 +142,37 @@ class MlasSQNBitGemmTest : public MlasTestBase { const float b_scale = QuantBScale[n * BlockCountK + k_blk]; - static_assert(BlkBitWidth == 4, "only implemented for 4-bit quantized B"); + uint8_t b_zp = 0; + if constexpr (BlkBitWidth == 4) { + b_zp = 8; + } else if constexpr (BlkBitWidth == 2) { + assert(QuantBZeroPoint && "zero point input is needed for BlkBitWidth == 2"); + } else { + static_assert(false, "only implemented for 2- and 4-bit quantized B"); + } - uint8_t b_zp = 8; + int pack_size = 8 / BlkBitWidth; if (QuantBZeroPoint != nullptr) { - const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / 2) + k_blk / 2]; - b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / pack_size) + k_blk / pack_size]; + if constexpr (BlkBitWidth == 4) { + b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); + } else if constexpr (BlkBitWidth == 2) { + int shift = (k_blk & 3) * 2; + b_zp = (b_zp_byte >> shift) & 0x03; + } } int32_t qsum = 0; for (size_t kk = 0; kk < k_blk_len; ++kk) { const int8_t qa = QuantAData[m * BlockCountK * BlkLen + k + kk]; - const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / 2]; - const int8_t qb = ((kk & 1) == 1 ? (qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp; + const uint8_t qb_byte = QuantBData[(n * BlockCountK * BlkLen + k + kk) / pack_size]; + int8_t qb = 0; + if constexpr (BlkBitWidth == 4) { + qb = ((kk & 1) == 1 ? 
(qb_byte >> 4) : (qb_byte & 0x0F)) - b_zp; + } else if constexpr (BlkBitWidth == 2) { + qb = ((qb_byte >> ((kk & 3) * 2)) & 0x03) - b_zp; + } qsum += qa * qb; } @@ -246,7 +263,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { uint8_t* QuantBZeroPoint = nullptr; { size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes(BlkLen, /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); @@ -422,13 +439,17 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests(); count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests(); - return count; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e069f6ef2432a..f097521ddc21a 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -7930,7 +7930,7 @@ TEST_F(GraphTransformationTests, MatMulNBitsBiasFusion) { q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(qbits, block_size, /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes(block_size, /* columnwise */ true, K, N, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); From 8c1cfe11d3cc150db5427242ae6c27a1e5748cd4 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 30 Jan 2025 16:36:41 -0800 Subject: [PATCH 2/6] add and pass q4dq tests for q2bit - rename file and test name later Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/q4_dq.cpp | 99 ++++- .../test/mlas/unittest/test_blockq4.cpp | 387 +++++++++++++----- 2 files changed, 361 insertions(+), 125 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index acc3cdd651751..39d921a76fac4 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -481,7 +481,11 @@ struct BlockwiseQuantizer { thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)(BitsTraits::kMid)); + if constexpr (qbits == 2) + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)2); + if constexpr (qbits == 4) + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -524,14 +528,13 @@ struct BlockwiseQuantizer { // !! 
qbits specific code as we need to pack 2 4b numbers into one byte if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; if constexpr (qbits == 4) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); } else if constexpr (qbits == 2) { + const int32_t meta_idx = meta_col * ((row_blks + 3) / BitsTraits::kPackSize) + meta_row / BitsTraits::kPackSize; zero_points[meta_idx] = (zp_bytes[0] & 0x3) | ((zp_bytes[1] & 0x3) << 2) | ((zp_bytes[2] & 0x3) << 4) | ((zp_bytes[3] & 0x3) << 6); - } else { - static_assert(false && "only support qbits of 4 and 2"); } } @@ -670,7 +673,8 @@ struct BlockwiseQuantizer { dst[j * rows + (i + 1)] = static_cast(v1); } } else { - const int zp_quad = zero_points[meta_col * ((row_blks + 3) / pack_size) + meta_row / pack_size]; + const int zp_quad = (zero_points == nullptr) ? + 0xAA : zero_points[meta_col * ((row_blks + 3) / pack_size) + meta_row / pack_size]; int zp = 0; const int meta_row_mod = meta_row % 4; switch (meta_row_mod) { @@ -730,19 +734,35 @@ struct BlockwiseQuantizer { * @tparam signed_quant quantized type is signed */ template -struct BlockwiseQDQQuantizer; - -template -struct BlockwiseQDQQuantizer { +struct BlockwiseQDQQuantizer { static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) { - return (val >> (idx << 2)) & 0xF; + if constexpr (qbits == 2) { + return (val >> (idx << 1)) & 0x3; + } else if constexpr (qbits == 4) { + return (val >> (idx << 2)) & 0xF; + } } static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) { - auto shift = idx << 2; - return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + if constexpr (qbits == 2) { + auto shift = idx << 1; + return ((val & 0x3) << shift) | (dst & (~(0x3 << shift))); + } else if constexpr (qbits == 4) { + auto shift = idx << 2; + return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } + } + + template + static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1, uint8_t v2, uint8_t v3) + { + if constexpr (add2) { + return ((v0 & 0x3) ^ 2) | (((v1 & 0x3) ^ 2) << 2) | (((v2 & 0x3) ^ 2) << 4) | (((v3 & 0x3) ^ 2) << 6); + } else { + return (v0 & 0x3) | ((v1 & 0x3) << 2) | ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6); + } } template @@ -1491,7 +1511,7 @@ MlasBlockwiseQuantizedShape( template void -MlasBlockwiseQuantMetaShape( +MlasBlockwiseQuantMetaShape( int block_size, bool columnwise, int rows, @@ -1500,6 +1520,16 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); +template void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols +); + template void MlasBlockwiseQuantMetaShape( @@ -1901,6 +1931,19 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template bool +MlasQDQQuantizeBlockwise( + const float* src, + float* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( const MLAS_FP16* src, @@ -1940,6 +1983,36 @@ MlasQDQTransposeBlockwiseQuantized( } } +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int 
quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template void MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index 11e5cec1f1e69..fbe9e8b5f0d98 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -19,6 +19,9 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" +constexpr int Q2Bits = 2; +constexpr int Q4Bits = 4; + class MlasBlockwiseQdqTest : public MlasTestBase { private: MatrixGuardBuffer FpBuf; @@ -36,6 +39,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MatrixGuardBuffer QDQTransposedOutputScales; MatrixGuardBuffer QDQTransposedOutputOffsets; + template void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); float* transposed = FpBuf2.GetBuffer(rows * columns, true); @@ -46,41 +50,79 @@ class MlasBlockwiseQdqTest : public MlasTestBase { int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes<4>(block_size, columnwise, rows, columns, + MlasBlockwiseQuantizedBufferSizes(block_size, columnwise, rows, columns, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); - int v = 7; - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; + int pack_size = 8 / qbits; + int v; + if constexpr (qbits == 2) { + v = 1; + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += pack_size) { + int idx = c * q_rows + r / pack_size; + uint8_t v0 = static_cast(v); + v = (v + 1) % 4; + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 1) % 4; + if (v == 3) { + v = (v + 1) % 4; + } + } + uint8_t v2 = 0; + if (r + 2 < rows) { + v2 = static_cast(v); + v = (v + 1) % 4; + if (v == 3) { + v = (v + 1) % 4; + } + } + uint8_t v3 = 0; + if (r + 3 < rows) { + v3 = static_cast(v); + v = (v + 1) % 4; + if (v == 3) { + v = (v + 1) % 4; + } + } + elements[idx] = (v3 << 6) | (v2 << 4) | (v1 << 2) | v0; } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); + } + } else if constexpr(qbits == 4) { + v = 7; + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += 2) { + int idx = c * q_rows + r / 2; + uint8_t v0 = static_cast(v); v = (v 
+ 5) % 16; if (v == 11 || v == 7 || v == 3) { // making the cycle 13 instead of 16, avoiding same values in a row v = (v + 5) % 16; } - } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } - elements[idx] = (v1 << 4) | v0; + elements[idx] = (v1 << 4) | v0; + } } } @@ -91,30 +133,57 @@ class MlasBlockwiseQdqTest : public MlasTestBase { uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); uint8_t* qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); if (zp) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; + if constexpr (qbits == 2) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += pack_size) { + int idx = c * ((meta_rows + 3) / pack_size) + r / pack_size; + uint8_t v0 = static_cast(v); + v = (v + 1) % 4; + uint8_t v1 = 0; + if (r + 1 < meta_rows) { + v1 = static_cast(v); + v = (v + 1) % 4; + } + uint8_t v2 = 0; + if (r + 2 < meta_rows) { + v2 = static_cast(v); + v = (v + 1) % 4; + } + uint8_t v3 = 0; + if (r + 3 < meta_rows) { + v3 = static_cast(v); + v = (v + 1) % 4; + } + zp[idx] = (v3 << 6) | (v2 << 4) | (v1 << 2) | v0; } - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); + } + } + else if constexpr (qbits == 4) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += 2) { + int idx = c * ((meta_rows + 1) / 2) + r / 2; + uint8_t v0 = static_cast(v); v = (v + 5) % 16; if (v == 11 || v == 7 || v == 3) { // making the cycle 13 instead of 16, avoiding same values in a row v = (v + 5) % 16; } + uint8_t v1 = 0; + if (r + 1 < meta_rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + zp[idx] = (v1 << 4) | v0; } - zp[idx] = (v1 << 4) | v0; } } } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, + MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); MlasTranspose(dequant_buf, transposed, columns, rows); @@ -123,48 +192,79 @@ class MlasBlockwiseQdqTest : public MlasTestBase { float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); uint8_t* o_zp = symmetric ? 
nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, + MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); - if (columnwise) { - bool signed_quant = MlasQDQQuantizeBlockwise( - transposed, qdq_scales, qdq_zp, qdq_weights, - true, rows, columns, block_size, threadpool_ptr); + if constexpr (qbits == 4) { + if (columnwise) { + bool signed_quant = MlasQDQQuantizeBlockwise( + transposed, qdq_scales, qdq_zp, qdq_weights, + true, rows, columns, block_size, threadpool_ptr); - ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; - if (symmetric) { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + if (symmetric) { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); - } else { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + } else { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + } } } - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) + if constexpr (qbits == 2) { + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += pack_size) { + int idx = c * q_rows + r / pack_size; + ASSERT_EQ(o_elements[idx] & 0x3, elements[idx] & 0x3) << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (r + 1 < rows) { + ASSERT_EQ((o_elements[idx] >> 2) & 0x3, (elements[idx] >> 2) & 0x3) + << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 2 < rows) { + ASSERT_EQ((o_elements[idx] >> 4) & 0x3, (elements[idx] >> 4) & 0x3) + << ", index=[" << r + 2 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 3 < rows) { + ASSERT_EQ((o_elements[idx] >> 6) & 0x3, (elements[idx] >> 6) & 0x3) + << ", index=[" << r + 3 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } - - if (r + 1 < rows) { - ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + } + } else if constexpr (qbits == 4) { + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += 2) { + int idx = c * q_rows + r / 2; + ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) + << ", index=[" << r << "x" << c 
<< "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) + if constexpr (qbits == 4) { + if (columnwise) { + ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) + << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } + if (r + 1 < rows) { + ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if constexpr (qbits == 4) { + if (columnwise) { + ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) + << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + } } } } @@ -177,34 +277,63 @@ class MlasBlockwiseQdqTest : public MlasTestBase { << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_scales_T[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if constexpr (qbits == 4) { + if (columnwise) { + ASSERT_EQ(qdq_scales_T[idx], scales[idx]) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } } if (symmetric) return; - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) + + if constexpr (qbits == 2) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += pack_size) { + int idx = c * ((meta_rows + 3) / pack_size) + r / pack_size; + ASSERT_EQ(o_zp[idx] & 0x3, zp[idx] & 0x3) << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (r + 1 < meta_rows) { + ASSERT_EQ((o_zp[idx] >> 2) & 0x3, (zp[idx] >> 2) & 0x3) + << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 2 < meta_rows) { + ASSERT_EQ((o_zp[idx] >> 4) & 0x3, (zp[idx] >> 4) & 0x3) + << ", index=" << r + 2 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 3 < meta_rows) { + ASSERT_EQ((o_zp[idx] >> 6) & 0x3, (zp[idx] >> 6) & 0x3) + << ", index=" << r + 3 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } - if (r + 1 < meta_rows) { - ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", 
shape=[" << rows << "x" << columns + } + } else if constexpr (qbits == 4) { + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += 2) { + int idx = c * ((meta_rows + 1) / 2) + r / 2; + ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) + ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) + << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } + if (r + 1 < meta_rows) { + ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + if (columnwise) { + ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) + << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; + } } } } @@ -217,44 +346,78 @@ class MlasBlockwiseQdqTest : public MlasTestBase { return suite_name.c_str(); } - void ExecuteShort(void) override { - Test(20, 1, 32, true, false); - Test(20, 1, 32, true, true); - Test(1, 20, 32, false, false); - Test(1, 20, 32, false, true); - Test(52, 1, 32, true, false); - Test(52, 1, 32, true, true); - Test(1, 52, 32, false, false); - Test(1, 52, 32, false, true); - Test(20, 3, 32, true, false); - Test(20, 3, 32, true, true); - Test(3, 20, 32, false, false); - Test(3, 20, 32, false, true); - Test(52, 3, 32, true, false); - Test(52, 3, 32, true, true); - Test(3, 52, 32, false, false); - Test(3, 52, 32, false, true); - Test(52, 3, 64, true, false); - Test(52, 3, 64, true, true); - Test(3, 52, 64, false, false); - Test(3, 52, 64, false, true); - Test(32 * 9 + 17, 41, 32, true, false); - Test(32 * 9 + 17, 41, 32, true, true); - Test(41, 32 * 9 + 17, 32, false, false); - Test(41, 32 * 9 + 17, 32, false, true); - Test(32 * 9 + 17, 41, 64, true, false); - Test(32 * 9 + 17, 41, 64, true, true); - Test(41, 32 * 9 + 17, 64, false, false); - Test(41, 32 * 9 + 17, 64, false, true); - Test(32 * 15 + 17, 63, 128, true, false); - Test(32 * 15 + 17, 63, 128, true, true); - Test(63, 32 * 15 + 17, 128, false, false); - Test(63, 32 * 15 + 17, 128, false, true); - - Test(256, 256, 32, true, false); - Test(256, 256, 32, true, true); - Test(256, 256, 32, false, false); - Test(256, 256, 32, false, true); + void ExecuteShort(void) { + // only support columnwise = true with qbits=2 + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + //Test(1, 20, 32, false, false); + //Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + //Test(1, 52, 32, false, false); + //Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + //Test(3, 20, 32, false, false); + //Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + //Test(3, 52, 32, false, false); + //Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + //Test(3, 52, 64, false, false); + //Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + //Test(41, 32 * 9 + 17, 32, false, false); + //Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 
9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + //Test(41, 32 * 9 + 17, 64, false, false); + //Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + //Test(63, 32 * 15 + 17, 128, false, false); + //Test(63, 32 * 15 + 17, 128, false, true); + + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + Test(1, 20, 32, false, false); + Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + Test(1, 52, 32, false, false); + Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + Test(3, 20, 32, false, false); + Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + Test(3, 52, 32, false, false); + Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + Test(3, 52, 64, false, false); + Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + Test(41, 32 * 9 + 17, 32, false, false); + Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + Test(41, 32 * 9 + 17, 64, false, false); + Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + Test(63, 32 * 15 + 17, 128, false, false); + Test(63, 32 * 15 + 17, 128, false, true); + + Test(256, 256, 32, true, false); + Test(256, 256, 32, true, true); + Test(256, 256, 32, false, false); + Test(256, 256, 32, false, true); } MlasBlockwiseQdqTest() = default; From f6f22e30d5e777ccc196957e8870a64de9f476ec Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 30 Jan 2025 22:41:16 -0800 Subject: [PATCH 3/6] some fixes Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/qnbitgemm.cpp | 5 +- .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 50 +++++++++++++------ .../test/contrib_ops/matmul_4bits_test.cc | 11 ++-- .../test/mlas/unittest/test_sqnbitgemm.cpp | 12 +++-- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 7e7baf137f604..096c795b4e1c5 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -554,6 +554,7 @@ SQ2BitGemm_CompInt8( const size_t /*RangeCountN*/ ) { + // TODO: implement this to call 2bit t-mac kernel } void @@ -920,7 +921,7 @@ MlasQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; @@ -991,7 +992,7 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != 
nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp index 6c1a133609f69..1d7a1ce73e1d9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp @@ -15,45 +15,63 @@ Module Name: --*/ -#include -#include -#include - #include "qnbitgemm.h" +#include "sqnbitgemm_q8_block.h" size_t Q2BitGemmPackQuantBDataSize( - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - return 0; + // TODO: This code shall change according to T-Mac. + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 2; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; } void SQ2BitGemmPackQuantBData( size_t /*N*/, size_t /*K*/, size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/, const std::byte* /*QuantBDataBegin*/, std::byte* /*PackedQuantBDataBegin*/, MLAS_THREADPOOL* /*ThreadPool*/ ) { + // TODO: need implementation } size_t Q2BitGemmPerGemmWorkspaceSize( - size_t /*M*/, - size_t /*N*/, - size_t /*K*/, - size_t /*BlkLen*/, - MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ + size_t M, + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - return 0; + MLAS_UNREFERENCED_PARAMETER(N); + + switch (ComputeType) { + case SQNBIT_CompInt8: { + // workspace buffer is used for block quantization of A to int8 + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + // QuantData + Scale + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } + default: { + return 0; + } + } } size_t diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index d6940dc2cf367..bfd682ae3918f 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -97,7 +97,7 @@ std::ostream& operator<<(std::ostream& os, const TestOptions& opts) { << ", has_bias:" << opts.has_bias; } -template +template void RunTest(const TestOptions& opts, std::vector>&& explicit_eps = {}) { SCOPED_TRACE(opts); @@ -284,8 +284,7 @@ void TestMatMulNBitsTyped() { base_opts.output_rel_error = 0.02f; } - if constexpr (qbits == 4) - { + if constexpr (qbits == 4) { TestOptions opts = base_opts; RunTest(opts); } @@ -297,15 +296,13 @@ void TestMatMulNBitsTyped() { } #if !defined(USE_DML) && !defined(USE_WEBGPU) - if constexpr (qbits == 4) - { + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; RunTest(opts); } - if constexpr (qbits == 4) - { + if constexpr (qbits == 4) { TestOptions opts = base_opts; opts.has_g_idx = true; opts.has_bias = true; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 365137d466256..26f02466be450 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp 
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -146,17 +146,18 @@ class MlasSQNBitGemmTest : public MlasTestBase { if constexpr (BlkBitWidth == 4) { b_zp = 8; } else if constexpr (BlkBitWidth == 2) { - assert(QuantBZeroPoint && "zero point input is needed for BlkBitWidth == 2"); + b_zp = 2; } else { static_assert(false, "only implemented for 2- and 4-bit quantized B"); } int pack_size = 8 / BlkBitWidth; if (QuantBZeroPoint != nullptr) { - const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / pack_size) + k_blk / pack_size]; if constexpr (BlkBitWidth == 4) { + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 1) / pack_size) + k_blk / pack_size]; b_zp = (k_blk & 1) ? (b_zp_byte >> 4) : (b_zp_byte & 0x0F); } else if constexpr (BlkBitWidth == 2) { + const uint8_t b_zp_byte = QuantBZeroPoint[n * ((BlockCountK + 3) / pack_size) + k_blk / pack_size]; int shift = (k_blk & 3) * 2; b_zp = (b_zp_byte >> shift) & 0x03; } @@ -396,6 +397,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); From 3e1a951448fb37664a4f8d41e994b4142ea98978 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Mon, 3 Feb 2025 12:24:40 -0800 Subject: [PATCH 4/6] add apis to neon and other avxs Signed-off-by: Liqun Fu --- .../core/mlas/lib/qnbitgemm_kernel_neon.cpp | 62 +++++++++++++++++++ .../lib/sqnbitgemm_bitnet_kernel_avx2.cpp | 2 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 9 +++ .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 10 +++ .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 9 +++ .../test/mlas/unittest/test_sqnbitgemm.cpp | 2 - 6 files changed, 92 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index d05de64e68ec8..b12e2358d77bd 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -167,6 +167,61 @@ Q4BitGemmPerGemmWorkspaceAlignment( } } +size_t +Q2BitGemmPackQuantBDataSize( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +void +SQ2BitGemmPackQuantBData( + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* /*QuantBDataBegin*/, + std::byte* /*PackedQuantBDataBegin*/, + MLAS_THREADPOOL* /*ThreadPool*/ +) +{ +} + +size_t +Q2BitGemmPerGemmWorkspaceSize( + size_t /*M*/, + size_t /*N*/, + size_t /*K*/, + size_t /*BlkLen*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/ +) +{ + return 0; +} + +size_t +SQ2BitGemmKernel_CompInt8_avx2( + size_t /*BlkLen*/, + const std::byte* /*QuantA*/, + const std::byte* /*QuantBData*/, + const float* /*QuantBScale*/, + const std::byte* /*QuantBZeroPoint*/, + float* /*C*/, + size_t /*CountM*/, + size_t /*CountN*/, + size_t /*CountK*/, + size_t /*BlockCountK*/, + size_t /*ldc*/, + const float* /*Bias*/ +) +{ + return 0; +} + } // namespace } // namespace sqnbitgemm_neon @@ -197,5 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; #endif // 
MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64
 
+    d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
+    d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
+
+    d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
+
+    d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
+    d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
     return d;
 }();
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp
index 1d7a1ce73e1d9..d6d104967e3a7 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_bitnet_kernel_avx2.cpp
@@ -90,6 +90,7 @@ SQ2BitGemmKernel_CompInt8_avx2(
     const float* /*Bias*/
 )
 {
+    // Reference: SQ4BitGemmKernel_CompInt8_avx2.
     return 0;
 }
 
@@ -101,4 +102,5 @@ QuantizeARow_CompInt8(
     std::byte* /*QuantA*/
 )
 {
+    // Should be similar to QuantizeARow_CompInt8_avx2, without the blksum-related code.
 }
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
index fe9720fd7e383..56c54cf9befb4 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
@@ -1375,5 +1375,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() {
     d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni;
     d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2;
 
+    // Change these functions if the implementations differ from avx2.
+    d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
+    d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
+
+    d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
+
+    d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
+    d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
+
     return d;
 }();
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
index b4e25d4e4040a..d07ba72d1ed8b 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
@@ -32,6 +32,7 @@ Module Name:
 //
 
 #include "sqnbitgemm_kernel_avx_common_fp32.h"
+#include "sqnbitgemm_bitnet_kernel_avx2.h"
 
 MLAS_FORCEINLINE void
 SQ4BitGemmM1Kernel_CompFp32_avx512(
@@ -368,5 +369,14 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
     d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512;
     d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
 
+    // Change these functions if the implementations differ from avx2.
+    d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
+    d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
+
+    d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
+
+    d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
+    d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
+
     return d;
 }();
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
index a4468bb906bbc..83fba19c1702d 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
@@ -27,6 +27,7 @@ Module Name:
 #include "sqnbitgemm_kernel_avx512_int8_blklen32.h"
 #include "sqnbitgemm_kernel_avx512_int8_blklen64.h"
 #include "sqnbitgemm_kernel_avx512_int8_blklen128.h"
+#include "sqnbitgemm_bitnet_kernel_avx2.h"
 
 MLAS_FORCEINLINE void
 SQ4BitGemmM1Kernel_CompFp32(
@@ -353,5 +354,13 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
     d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni;
     d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
 
+    // Change these functions if the implementations differ from avx2.
+    d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
+    d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
+
+    d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
+
+    d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
+    d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
     return d;
 }();
diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
index 26f02466be450..d849118aae7ef 100644
--- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
@@ -147,8 +147,6 @@ class MlasSQNBitGemmTest : public MlasTestBase {
       b_zp = 8;
     } else if constexpr (BlkBitWidth == 2) {
       b_zp = 2;
-    } else {
-      static_assert(false, "only implemented for 2- and 4-bit quantized B");
     }
 
     int pack_size = 8 / BlkBitWidth;

From 013006100158dd5ef8f0ec662716d67008c1ecf5 Mon Sep 17 00:00:00 2001
From: Liqun Fu
Date: Mon, 3 Feb 2025 12:50:04 -0800
Subject: [PATCH 5/6] fix neon build

Signed-off-by: Liqun Fu
---
 onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
index b12e2358d77bd..6fcc530ff11a8 100644
--- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
@@ -252,12 +252,12 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
     d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
 #endif  // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64
 
-    d.Q2BitGemmPackQuantBDataSize = Q2BitGemmPackQuantBDataSize;
-    d.SQ2BitGemmPackQuantBData = SQ2BitGemmPackQuantBData;
+    d.Q2BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q2BitGemmPackQuantBDataSize;
+    d.SQ2BitGemmPackQuantBData = sqnbitgemm_neon::SQ2BitGemmPackQuantBData;
 
-    d.Q2BitGemmPerGemmWorkspaceSize = Q2BitGemmPerGemmWorkspaceSize;
+    d.Q2BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q2BitGemmPerGemmWorkspaceSize;
 
-    d.SQ2BitGemmKernel_CompInt8 = SQ2BitGemmKernel_CompInt8_avx2;
-    d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8;
+    d.SQ2BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ2BitGemmKernel_CompInt8_avx2;
+    d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8;
     return d;
 }();

From b4aad0134c3d1cb7f2e43e05fa299abfa14eb3c5 Mon Sep 17 00:00:00 2001
From: Liqun Fu
Date: Mon, 3 Feb 2025 14:05:51 -0800
Subject: [PATCH 6/6] disable 2bit test

Signed-off-by: Liqun Fu
---
 onnxruntime/test/contrib_ops/matmul_4bits_test.cc  | 1 +
 onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index bfd682ae3918f..468243791e25a 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -394,6 +394,7 @@ TEST(MatMulNBits, Float32_Accuracy4) {
   TestMatMulNBitsTyped();
 }
 
+// TODO: enable and add more tests as 2-bit development proceeds.
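The disabled test below exercises the same contract the MLAS unit tests enforce: dequantizing B with the 2-bit convention and comparing against a float reference. The sketch that follows is a minimal scalar rendering of that convention (four 2-bit values per byte, lowest bits first, zero point defaulting to the midpoint of 2); the function name and parameters are illustrative, not part of the patch:

    #include <cstddef>
    #include <cstdint>

    // Dequantize one 2-bit quantized block: out[kk] = scale * (q - zp).
    // Value kk lives in byte kk / 4 at bit offset (kk & 3) * 2, matching the
    // (qb_byte >> ((kk & 3) * 2)) & 0x03 extraction used in test_sqnbitgemm.cpp.
    void DequantizeBlk2Bit(const uint8_t* packed, size_t BlkLen,
                           float scale, int zp, float* out)
    {
        for (size_t kk = 0; kk < BlkLen; ++kk) {
            const int q = (packed[kk >> 2] >> ((kk & 3) * 2)) & 0x03;
            out[kk] = scale * static_cast<float>(q - zp);  // zp is 2 when no zero points are stored
        }
    }
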
TEST(MatMulNBits, DISABLED_Float32_Accuracy4_Q2) { TestMatMulNBitsTyped(); } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index d849118aae7ef..fee0eacc246dd 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -443,8 +443,9 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture::RegisterShortExecuteTests(); - count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + //count += SQNBitGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); //count += SQNBitGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests();
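
Taken together, the series fixes one metadata convention that the index math in q4_dq.cpp, the unit tests, and the kernels above all rely on: per-block zero points are packed four per byte down each column of quantization blocks, so a column owns (meta_rows + 3) / 4 bytes, and the byte 0xAA (binary 10 10 10 10) supplies the default zero point of 2 in every lane when no zero-point buffer is present. Below is a minimal sketch of the lookup and packing, with illustrative helper names that do not appear in the patch:

    #include <cstdint>

    // Byte index of the zero point for block (meta_row, meta_col), given
    // meta_rows blocks per column and four 2-bit zero points per byte.
    inline int ZpByteIndex(int meta_row, int meta_col, int meta_rows) {
        return meta_col * ((meta_rows + 3) / 4) + meta_row / 4;
    }

    // Read one 2-bit zero point; a null buffer means none were stored, and
    // every lane defaults to the 2-bit midpoint 2 (the 0xAA byte in q4_dq.cpp).
    inline int LoadZp2Bit(const uint8_t* zero_points,
                          int meta_row, int meta_col, int meta_rows) {
        const uint8_t quad = zero_points
                                 ? zero_points[ZpByteIndex(meta_row, meta_col, meta_rows)]
                                 : 0xAA;
        return (quad >> ((meta_row % 4) * 2)) & 0x03;
    }

    // Pack four 2-bit zero points into one byte, lowest lane first; this is
    // the unsigned (add2 == false) layout produced by Pack() and by the
    // zp[idx] writes in test_blockq4.cpp.
    inline uint8_t PackZp2Bit(uint8_t v0, uint8_t v1, uint8_t v2, uint8_t v3) {
        return static_cast<uint8_t>((v0 & 0x3) | ((v1 & 0x3) << 2) |
                                    ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6));
    }

The signed variant, Pack<true>, folds in the midpoint by XOR-ing each lane with 2; for a 2-bit value x this equals (x + 2) mod 4, which maps the signed range -2..1 onto the stored range 0..3.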