Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void main() {
uint qs = data_a[ib].qs[4 * ib32 + l];
const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
qs |= (qh << (8 - 2 * l)) & 0x300;
const uvec2 grid = iq2s_grid[qs & 511];
const uvec2 grid = iq2s_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
Expand Down
3 changes: 2 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ void main() {
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]];
const uint qs = data_a[ib].qs[8 * is + l];
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This, and iq3_xxs, were hitting KhronosGroup/glslang#3948

const uvec2 grid = iq2xxs_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
Expand Down
11 changes: 6 additions & 5 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ void main() {
const uint b_idx = 256 * ib + 32 * is;

const float d = float(data_a[ib].d);
const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));

// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
uint qh = data_a[ib].qh[is];
[[unroll]] for (uint l = 0; l < 8; ++l) {
uint qs = data_a[ib].qs[8 * is + l];
uint gidx = qs | ((qh << (8 - l)) & 256);
uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
u8vec4 grid = unpack8(iq3s_grid[gidx]);
const uint iqs = 8 * is + l;
const uint qs = data_a[ib].qs[iqs];
const uint gidx = qs | ((qh << (8 - l)) & 256);
const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));
const u8vec4 grid = unpack8(iq3s_grid[gidx]);
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ void main() {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
// Restore parity bit.
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
const uint qs0 = data_a[ib].qs[8 * is + 2 * l];
const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];
const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);
const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
Expand Down
Loading