Skip to content

Commit bd9ab9b

Browse files
Add a cuda kernel for dequantizing q8_0. (huggingface#1804)
1 parent 8cc0a18 commit bd9ab9b

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

candle-core/tests/quantized_tests.rs

-4
Original file line numberDiff line numberDiff line change
@@ -738,10 +738,6 @@ macro_rules! quantized_matmul {
738738
// stable. https://github.com/rust-lang/rust/issues/29599
739739
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
740740
fn $fn_name(device: &Device) -> Result<()> {
741-
if device.is_cuda() {
742-
// TODO Enable Cuda GGML sometime maybe.
743-
return Ok(());
744-
}
745741
test_matmul(device, (1, 3, 4, 256), $dtype)?;
746742
Ok(())
747743
}

candle-kernels/src/quantized.cu

+24
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,30 @@ extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, f
877877
#endif
878878
}
879879

880+
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
881+
const int i = blockIdx.x;
882+
883+
// assume 32 threads
884+
const int tid = threadIdx.x;
885+
const int il = tid/8;
886+
const int ir = tid%8;
887+
const int ib = 8*i + ir;
888+
if (ib >= nb32) {
889+
return;
890+
}
891+
892+
float * y = yy + 256*i + 32*ir + 8*il;
893+
894+
const block_q8_0 * x = (const block_q8_0 *)vx + ib;
895+
const float d = __half2float(x->d);
896+
897+
const int8_t * q = x->qs + 8*il;
898+
899+
for (int l = 0; l < 8; ++l) {
900+
y[l] = d * q[l];
901+
}
902+
}
903+
880904
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
881905
const block_q8_K * x = (const block_q8_K *) vx;
882906

0 commit comments

Comments
 (0)