From 14b4bca625c85e28c281287a94c6e7f23623d296 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 7 Sep 2025 21:33:44 -0500 Subject: [PATCH 1/2] vulkan: fix failing dequant shaders --- ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp | 3 ++- ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp | 7 ++++--- ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp | 6 ++++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index 48f6b65bc40ce..127c7b6424030 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -29,7 +29,7 @@ void main() { uint qs = data_a[ib].qs[4 * ib32 + l]; const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; qs |= (qh << (8 - 2 * l)) & 0x300; - const uvec2 grid = iq2s_grid[qs & 511]; + const uvec2 grid = iq2s_grid[qs]; const u8vec4 grid0 = unpack8(grid.x); const u8vec4 grid1 = unpack8(grid.y); data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index e370690bcb089..0ae9acd02a6ca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -33,7 +33,8 @@ void main() { [[unroll]] for (uint l = 0; l < 4; ++l) { const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit - const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const uint qs = data_a[ib].qs[8 * is + l]; + const uvec2 grid = iq2xxs_grid[qs]; const u8vec4 grid0 = unpack8(grid.x); const u8vec4 grid1 = unpack8(grid.y); data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index c3f4bca5d95e2..b8eb0e3e68160 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -22,14 +22,15 @@ void main() { const uint b_idx = 256 * ib + 32 * is; const float d = float(data_a[ib].d); - const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf)); // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. uint qh = data_a[ib].qh[is]; [[unroll]] for (uint l = 0; l < 8; ++l) { - uint qs = data_a[ib].qs[8 * is + l]; + uint iqs = 8 * is + l; + uint qs = data_a[ib].qs[iqs]; uint gidx = qs | ((qh << (8 - l)) & 256); - uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); + uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); u8vec4 grid = unpack8(iq3s_grid[gidx]); data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index a92b82961afda..19c7fdeefceed 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -35,8 +35,10 @@ void main() { const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); // Restore parity bit. const uint sign8 = sign7 | (bitCount(sign7) << 7); - const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); - const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + const uint qs0 = data_a[ib].qs[8 * is + 2 * l]; + const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1]; + const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]); data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); From 34f0e67e198a4bb77131b64396d31591ae5f00b4 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 9 Sep 2025 08:22:58 -0500 Subject: [PATCH 2/2] add missing const --- ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index b8eb0e3e68160..e4f42be94c759 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -27,11 +27,11 @@ void main() { // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. uint qh = data_a[ib].qh[is]; [[unroll]] for (uint l = 0; l < 8; ++l) { - uint iqs = 8 * is + l; - uint qs = data_a[ib].qs[iqs]; - uint gidx = qs | ((qh << (8 - l)) & 256); - uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); - u8vec4 grid = unpack8(iq3s_grid[gidx]); + const uint iqs = 8 * is + l; + const uint qs = data_a[ib].qs[iqs]; + const uint gidx = qs | ((qh << (8 - l)) & 256); + const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); + const u8vec4 grid = unpack8(iq3s_grid[gidx]); data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));