(WIP) bitnet and t-mac #23540

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -182,6 +182,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
@@ -586,6 +587,7 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_bitnet_kernel_avx2.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
39 changes: 26 additions & 13 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -111,8 +111,8 @@ class MatMulNBits final : public OpKernel {
has_unquantized_zero_point_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
}

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
ORT_ENFORCE(nbits_ == 2 || nbits_ == 4,
"Only 2 and 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
const Tensor* tensor_zero_point = nullptr;
has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point);
}
@@ -436,17 +436,30 @@ Status MatMulNBits<float>::ComputeBUnpacked(const Tensor* a,
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_, true);

if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType<float>())) {
// dequantize b, only 4b quantization is supported for now
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
// dequantize b, only 2 and 4b quantization is supported for now
if (this->nbits_ == 2) {
MlasDequantizeBlockwise<float, 2>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
} else if (this->nbits_ == 4) {
MlasDequantizeBlockwise<float, 4>(
tmp_b_data_ptr.get(), // dequantized output
b_data, // quantized input
scales_data, // quantization scales
static_cast<const uint8_t*>(zero_points_data), // quantization zero points
static_cast<int32_t>(block_size_), // quantization block size
column_wise_quant_, // columnwise quantization or row-wise
static_cast<int32_t>(K_), // number of rows in quantized input
static_cast<int32_t>(N_), // number of columns in quantized input
thread_pool);
}
} else {
ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now");
// !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!!
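To make the new 2-bit path above concrete, here is a minimal scalar sketch of what blockwise 2-bit dequantization amounts to: each byte packs four 2-bit values, and every `block_size` values in a column share one scale and an optional zero point. This is a sketch only; the packing order and the default zero point are assumptions and may not match what `MlasDequantizeBlockwise<float, 2>` actually does.

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch only (not the MLAS implementation): dequantize one
// column of 2-bit blockwise-quantized data. Assumes four 2-bit values are
// packed per byte, least-significant bits first, and that a zero point of 2
// (the unsigned midpoint) applies when no zero points are provided.
std::vector<float> Dequantize2BitColumnSketch(const uint8_t* quant, const float* scales,
                                              const uint8_t* zero_points,
                                              int K, int block_size) {
  std::vector<float> out(K);
  for (int k = 0; k < K; ++k) {
    const int block = k / block_size;                     // scale / zero-point block for this row
    const int q = (quant[k / 4] >> ((k % 4) * 2)) & 0x3;  // extract one 2-bit value
    const int zp = zero_points ? zero_points[block] : 2;  // assumed default zero point
    out[k] = scales[block] * static_cast<float>(q - zp);
  }
  return out;
}
```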
35 changes: 28 additions & 7 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -16,6 +16,15 @@
namespace onnxruntime {
namespace contrib {

template <class T, class zeroT>
void Dequantize2BitsKernelReOrder(
T* /*output*/, const uint8_t* /*quant_data*/, const T* /*scale_data*/,
const zeroT* /*zero_points*/, const int32_t* /*reorder_idx*/, int /*block_size*/,
int /*groups_per_threadblock*/, int /*total_groups*/, int /*out_rows*/, int /*out_cols*/,
int /*blockIdx_x*/, int /*threadIdx_x*/) {
assert(false);
}

template <class T, class zeroT>
void Dequantize4BitsKernelReOrder(
T* output, const uint8_t* quant_data, const T* scale_data,
@@ -73,7 +82,7 @@ void Dequantize4BitsKernelReOrder(
}
}

template <typename inputT, typename zeroT>
template <typename inputT, typename zeroT, int qbits>
void DequantizeBlockwise(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
@@ -95,24 +104,36 @@ void DequantizeBlockwise(
pool, static_cast<std::ptrdiff_t>(blocks_per_grid),
[&](std::ptrdiff_t block_id) {
for (int j = 0; j < 256; j++) {
Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
if constexpr (qbits == 2) {
Dequantize2BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
} else {
Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, static_cast<int>(block_id), j);
}
}
});
}

template void DequantizeBlockwise<float, uint8_t>(
template void DequantizeBlockwise<float, uint8_t, 2>(
float* output, const uint8_t* quant_data, const float* scales_data,
const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);


template void DequantizeBlockwise<float, uint8_t, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, float>(
template void DequantizeBlockwise<float, float, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, MLFloat16>(
template void DequantizeBlockwise<float, MLFloat16, 4>(
float* output, const uint8_t* quant_data, const float* scales_data,
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
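Note that `Dequantize2BitsKernelReOrder` is introduced only as a stub that asserts, so the `reorder_idx` path still effectively handles 4-bit data only. For reference, a minimal scalar sketch of what a 2-bit reorder-aware dequantizer could look like is below; the function name, the four-values-per-byte packing, and the act-order reading of `reorder_idx` are assumptions for illustration, not code from this PR.

```cpp
#include <cstdint>

// Illustrative scalar sketch (not the PR's kernel): dequantize a column-major
// K x N matrix of 2-bit blockwise-quantized weights, where reorder_idx maps a
// row index to the quantization group that owns its scale and zero point
// (GPTQ-style act-order). Assumes four 2-bit values per byte and per-column
// groups of scales and zero points; both are assumptions, not the MLAS layout.
void Dequantize2BitsReOrderSketch(float* output, const uint8_t* quant_data,
                                  const float* scale_data, const uint8_t* zero_points,
                                  const int32_t* reorder_idx,
                                  int block_size, int K, int N) {
  const int groups_per_col = (K + block_size - 1) / block_size;
  for (int n = 0; n < N; ++n) {
    const uint8_t* col = quant_data + static_cast<int64_t>(n) * (K / 4);
    const float* col_scales = scale_data + static_cast<int64_t>(n) * groups_per_col;
    const uint8_t* col_zps = zero_points ? zero_points + static_cast<int64_t>(n) * groups_per_col
                                         : nullptr;
    for (int k = 0; k < K; ++k) {
      const int group = reorder_idx ? reorder_idx[k] : k / block_size;  // act-order remapping
      const int q = (col[k / 4] >> ((k % 4) * 2)) & 0x3;                // unpack one 2-bit value
      const float zp = col_zps ? static_cast<float>(col_zps[group]) : 2.0f;
      output[static_cast<int64_t>(n) * K + k] = col_scales[group] * (static_cast<float>(q) - zp);
    }
  }
}
```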
@@ -6,7 +6,7 @@
namespace onnxruntime {
namespace contrib {

template <typename inputT, typename zeroT>
template <typename inputT, typename zeroT, int qbits=4>
void DequantizeBlockwise(
inputT* output, // dequantized output
const uint8_t* quant_data, // quantized input
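The header declaration above (presumably matmul_nbits_impl.h) now takes `qbits` as a template parameter with a default of 4, so callers select the bit width at compile time. A small usage sketch under that assumption; the include path and wrapper function are hypothetical:

```cpp
#include <cstdint>
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"  // assumed include path

namespace onnxruntime {
namespace contrib {

// Sketch of a call site: all buffers are assumed to be allocated and packed by
// the caller in the layout DequantizeBlockwise expects.
void DequantizeB2Bit(float* dst, const uint8_t* packed_b, const float* scales,
                     const uint8_t* zero_points, int32_t block_size,
                     int32_t K, int32_t N, concurrency::ThreadPool* pool) {
  // Explicit qbits = 2 selects the new 2-bit path.
  DequantizeBlockwise<float, uint8_t, 2>(dst, packed_b, scales, zero_points,
                                         /*reorder_idx=*/nullptr, block_size,
                                         /*columnwise=*/true, K, N, pool);
  // Omitting the third template argument keeps the existing 4-bit behavior.
}

}  // namespace contrib
}  // namespace onnxruntime
```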
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/inc/mlas_q4.h
@@ -277,9 +277,9 @@ MlasBlockwiseQuantizedShape(
*
* If the qbits or block_size values are unsupported the output sizes will be zero.
*/
template<int qbits>
void MLASCALL
MlasBlockwiseQuantizedBufferSizes(
int qbits,
int block_size,
bool columnwise,
int rows,
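Finally, `MlasBlockwiseQuantizedBufferSizes` changes from taking `qbits` as a runtime argument to taking it as a template parameter. A hedged sketch of the updated call shape follows; the output parameters after `columns` mirror the existing 4-bit API and are assumed here rather than shown in the diff:

```cpp
#include <cstddef>
#include "mlas_q4.h"  // core/mlas/inc/mlas_q4.h, include path assumed

// Illustrative only: query packed-buffer sizes for a 2-bit, blockwise-quantized
// K x N weight matrix. Output-parameter names follow the existing 4-bit API
// and are an assumption, not taken from this diff.
void QueryPackedSizes2Bit(int block_size, int K, int N) {
  size_t q_data_size_in_bytes = 0;
  size_t q_scale_num_elements = 0;
  size_t q_zero_point_size_in_bytes = 0;
  MlasBlockwiseQuantizedBufferSizes<2>(block_size, /*columnwise=*/true, K, N,
                                       q_data_size_in_bytes, q_scale_num_elements,
                                       &q_zero_point_size_in_bytes);
}
```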