Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 189 additions & 2 deletions src/layer/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,24 @@ int Gemm::load_model(const ModelBin& mb)
if (transB == 0)
B_data = mb.load(constantN, constantK, 0);
else
// NOTE(review): the next line is the pre-change load call retained by the diff
// rendering; the braced block below it is the replacement logic.
B_data = mb.load(constantK, constantN, 0);
{
if (int8_scale_term == 5)
{
// int6 block quantize
// TODO auto type for int6 storage
// 4 int6 values are packed into 3 bytes, so one row of constantK values
// occupies (constantK + 3) / 4 * 3 bytes of storage
B_data = mb.load((constantK + 3) / 4 * 3, constantN, 0);
}
else if (int8_scale_term == 6)
{
// int4 block quantize
// TODO auto type for int4 storage
// 2 int4 values are packed into 1 byte, so one row of constantK values
// occupies (constantK + 1) / 2 bytes of storage
B_data = mb.load((constantK + 1) / 2, constantN, 0);
}
else
{
// unpacked storage, one element per weight
B_data = mb.load(constantK, constantN, 0);
}
}
if (B_data.empty())
return -100;
}
Expand Down Expand Up @@ -119,7 +136,177 @@ int Gemm::load_model(const ModelBin& mb)

if (constantB == 1)
{
// NOTE(review): the next line is the pre-change scalar-scale load retained by
// the diff rendering; the if/else chain below is the replacement logic.
B_data_int8_scale = mb.load(1, 1)[0];
if (int8_scale_term == 4)
{
// int8 block quantize
// assert transB == 1 // FIXME hardcode
const int block_size = 64; // FIXME hardcode
const int block_count = (constantK + block_size - 1) / block_size;

// one quantize scale per block, per output row (constantN rows)
B_data_quantize_scales = mb.load(block_count, constantN, 0);

// dequantize B_data to fp32
Mat B_data_fp32(constantK, constantN);
if (B_data_fp32.empty())
return -100;

for (int i = 0; i < constantN; i++)
{
const signed char* i8ptr = B_data.row<const signed char>(i);
float* ptr = B_data_fp32.row(i);
float* scale_ptr = B_data_quantize_scales.row(i);

for (int j = 0; j < block_count; j++)
{
// block quantize
const signed char* i8ptr1 = i8ptr + j * block_size;
// NOTE(review): stored value appears to be a multiplier applied at
// quantize time, hence the reciprocal here — confirm against the
// quantizer tool's scale convention
const float inv_scale = 1.f / scale_ptr[j];
float* ptr1 = ptr + j * block_size;
// last block of a row may be shorter than block_size
const int block_size1 = std::min(block_size, constantK - j * block_size);

for (int k = 0; k < block_size1; k++)
{
ptr1[k] = i8ptr1[k] * inv_scale;
}
}
}

B_data = B_data_fp32;

// reset int8_scale_term to use fp32 path
int8_scale_term = 0;
}
else if (int8_scale_term == 5)
{
// int6 block quantize
// assert transB == 1 // FIXME hardcode
const int block_size = 64; // FIXME hardcode
const int block_count = (constantK + block_size - 1) / block_size;

B_data_quantize_scales = mb.load(block_count, constantN, 0);

// dequantize B_data to fp32
Mat B_data_fp32(constantK, constantN);
if (B_data_fp32.empty())
return -100;

// four signed 6-bit values packed into 3 bytes (24 bits)
// NOTE(review): __attribute__((packed)) is gcc/clang-specific and bit-field
// layout is implementation-defined — this union will not produce a 3-byte
// layout on MSVC; verify portability or replace with explicit bit shifts
union i6x4_t
{
signed char i6[3];
struct
{
signed char i6_a : 6;
signed char i6_b : 6;
signed char i6_c : 6;
signed char i6_d : 6;
} __attribute__((packed));
};
Comment on lines +195 to +206
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __attribute__((packed)) attribute is GCC-specific and not portable. This will fail on MSVC. Consider using #pragma pack for cross-platform compatibility or conditionally compile based on compiler.

Suggested change
union i6x4_t
{
signed char i6[3];
struct
{
signed char i6_a : 6;
signed char i6_b : 6;
signed char i6_c : 6;
signed char i6_d : 6;
} __attribute__((packed));
};
#if defined(_MSC_VER)
#pragma pack(push, 1)
#endif
union i6x4_t
{
signed char i6[3];
#if defined(_MSC_VER)
struct
{
signed char i6_a : 6;
signed char i6_b : 6;
signed char i6_c : 6;
signed char i6_d : 6;
};
#else
struct
{
signed char i6_a : 6;
signed char i6_b : 6;
signed char i6_c : 6;
signed char i6_d : 6;
} __attribute__((packed));
#endif
};
#if defined(_MSC_VER)
#pragma pack(pop)
#endif

Copilot uses AI. Check for mistakes.

// unpack and dequantize each of the constantN rows of int6 weights
for (int i = 0; i < constantN; i++)
{
const i6x4_t* i6ptr = B_data.row<const i6x4_t>(i);
float* ptr = B_data_fp32.row(i);
float* scale_ptr = B_data_quantize_scales.row(i);

for (int j = 0; j < block_count; j++)
{
// block quantize
// each i6x4_t holds 4 values, so advance by block_size / 4 units per block
const i6x4_t* i6ptr1 = i6ptr + j * block_size / 4;
const float inv_scale = 1.f / scale_ptr[j];
float* ptr1 = ptr + j * block_size;
// last block of a row may be shorter than block_size
const int block_size1 = std::min(block_size, constantK - j * block_size);

// main loop consumes 4 values per i6x4_t; k stays a multiple of 4
int k = 0;
for (; k + 3 < block_size1; k += 4)
{
ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale;
ptr1[k + 1] = i6ptr1[k / 4].i6_b * inv_scale;
ptr1[k + 2] = i6ptr1[k / 4].i6_c * inv_scale;
ptr1[k + 3] = i6ptr1[k / 4].i6_d * inv_scale;
}
// tail loops: after the main loop at most 3 values remain and k is still a
// multiple of 4, so at most one of the loops below runs, at most once
for (; k + 2 < block_size1; k += 3)
{
ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale;
ptr1[k + 1] = i6ptr1[k / 4].i6_b * inv_scale;
ptr1[k + 2] = i6ptr1[k / 4].i6_c * inv_scale;
}
for (; k + 1 < block_size1; k += 2)
{
ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale;
ptr1[k + 1] = i6ptr1[k / 4].i6_b * inv_scale;
}
for (; k < block_size1; k++)
{
ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale;
}
}
}

B_data = B_data_fp32;

// reset int8_scale_term to use fp32 path
int8_scale_term = 0;
}
else if (int8_scale_term == 6)
{
// int4 block quantize
// assert transB == 1 // FIXME hardcode
const int block_size = 64; // FIXME hardcode
const int block_count = (constantK + block_size - 1) / block_size;

B_data_quantize_scales = mb.load(block_count, constantN, 0);

// dequantize B_data to fp32
Mat B_data_fp32(constantK, constantN);
if (B_data_fp32.empty())
return -100;

// two signed 4-bit values packed into 1 byte
// NOTE(review): __attribute__((packed)) is gcc/clang-specific and bit-field
// order within the byte is implementation-defined — verify on MSVC
union i4x2_t
{
signed char i4;
struct
{
signed char i4_low : 4;
signed char i4_high : 4;
} __attribute__((packed));
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __attribute__((packed)) attribute is GCC-specific and not portable. This will fail on MSVC. Consider using #pragma pack for cross-platform compatibility or conditionally compile based on compiler.

Copilot uses AI. Check for mistakes.
};

// unpack and dequantize each of the constantN rows of int4 weights
for (int i = 0; i < constantN; i++)
{
const i4x2_t* i4ptr = B_data.row<const i4x2_t>(i);
float* ptr = B_data_fp32.row(i);
float* scale_ptr = B_data_quantize_scales.row(i);

for (int j = 0; j < block_count; j++)
{
// block quantize
// each i4x2_t byte holds 2 values, so advance by block_size / 2 bytes per block
const i4x2_t* i4ptr1 = i4ptr + j * block_size / 2;
const float inv_scale = 1.f / scale_ptr[j];
float* ptr1 = ptr + j * block_size;
// last block of a row may be shorter than block_size
const int block_size1 = std::min(block_size, constantK - j * block_size);

// main loop consumes 2 values per byte; tail handles an odd final value
int k = 0;
for (; k + 2 < block_size1; k += 2)
{
ptr1[k] = i4ptr1[k / 2].i4_low * inv_scale;
ptr1[k + 1] = i4ptr1[k / 2].i4_high * inv_scale;
}
for (; k < block_size1; k++)
{
ptr1[k] = i4ptr1[k / 2].i4_low * inv_scale;
}
}
}

B_data = B_data_fp32;

// reset int8_scale_term to use fp32 path
int8_scale_term = 0;
}
else
{
// non-block quantize modes keep a single scalar scale for the whole B matrix
B_data_int8_scale = mb.load(1, 1)[0];
}
}
}
#endif // NCNN_INT8
Expand Down
1 change: 1 addition & 0 deletions src/layer/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class Gemm : public Layer
#if NCNN_INT8
// per-row quantize scales for constant A
Mat A_data_int8_scales;
// single scalar scale for constant B (non-block quantize modes)
float B_data_int8_scale;
// per-block quantize scales for constant B (block quantize modes, int8_scale_term 4/5/6)
Mat B_data_quantize_scales;
#endif
};

Expand Down
13 changes: 10 additions & 3 deletions tools/modelwriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1810,9 +1810,16 @@ int ModelWriter::save(const char* parampath, const char* binpath)
}
if (op->constantB == 1)
{
// NOTE(review): the next three lines are the pre-change scalar-scale write
// retained by the diff rendering; the if/else below is the replacement logic.
ncnn::Mat B_data_int8_scales(1);
B_data_int8_scales[0] = op->B_data_int8_scale;
fwrite_weight_data(B_data_int8_scales, bp, 90, 100);
if (op->int8_scale_term == 4 || op->int8_scale_term == 5 || op->int8_scale_term == 6)
{
// block quantize modes store a per-block scale matrix instead of one scalar
fwrite_weight_tag_data(op->B_data_quantize_scales, bp);
}
else
{
// non-block modes store the single scalar B scale
ncnn::Mat B_data_int8_scales(1);
B_data_int8_scales[0] = op->B_data_int8_scale;
fwrite_weight_data(B_data_int8_scales, bp, 90, 100);
}
}
}
#endif // NCNN_INT8
Expand Down
3 changes: 3 additions & 0 deletions tools/quantize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ endif()
add_executable(ncnn2int8 ncnn2int8.cpp)
target_link_libraries(ncnn2int8 PRIVATE ncnn)

add_executable(ncnnllm2int468 ncnnllm2int468.cpp)
target_link_libraries(ncnnllm2int468 PRIVATE ncnn)
# group and install the new tool the same way the sibling quantize tools are
# handled below, so IDE project layout and `make install` stay consistent
set_property(TARGET ncnnllm2int468 PROPERTY FOLDER "tools/optimization")
ncnn_install_tool(ncnnllm2int468)
Comment on lines +40 to +41
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new ncnnllm2int468 executable is not added to the virtual project group or installed via ncnn_install_tool(), unlike ncnn2int8 above. This creates inconsistency in how tools are organized and installed.

Copilot uses AI. Check for mistakes.

# add ncnn2int8 tool to a virtual project group
set_property(TARGET ncnn2int8 PROPERTY FOLDER "tools/optimization")
# NOTE(review): the comment above mentions ncnn2int8 but the install call names
# ncnn2table — confirm against the full CMakeLists whether ncnn2int8 should be
# installed here as well
ncnn_install_tool(ncnn2table)
Expand Down
Loading
Loading