Skip to content

Commit e82beaa

Browse files
authored
vulkan: add fwht support for Intel with shmem reduction (#23964)
* vulkan: add fwht support for Intel with shmem reduction
* don't use N as workgroup size
* disable subgroup shuffle on MoltenVK AMD
* disable fwht shader on Intel Windows due to driver bug
1 parent c4a278d commit e82beaa

3 files changed

Lines changed: 76 additions & 16 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5084,6 +5084,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
50845084
}
50855085
++idx;
50865086
}
5087+
} else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) {
5088+
// Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147
5089+
int idx = 0;
5090+
for (uint32_t n : {64, 128, 256, 512}) {
5091+
const uint32_t block_size = std::min(device->subgroup_size, n);
5092+
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1);
5093+
++idx;
5094+
}
50875095
}
50885096

50895097
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
@@ -5630,6 +5638,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
56305638
#endif
56315639
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
56325640
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
5641+
#ifdef __APPLE__
5642+
if (device->vendor_id == VK_VENDOR_ID_AMD) {
5643+
device->subgroup_shuffle = false;
5644+
}
5645+
#endif
56335646
device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
56345647
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
56355648

ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
#version 450
22

33
#extension GL_EXT_control_flow_attributes : require
4+
#ifndef FWHT_SHMEM
45
#extension GL_KHR_shader_subgroup_basic : enable
56
#extension GL_KHR_shader_subgroup_shuffle : enable
7+
#endif
68

7-
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
8-
9-
layout(constant_id = 0) const uint WARP_SIZE = 32;
9+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
1010
layout(constant_id = 1) const uint N = 128;
1111

12+
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
13+
1214
layout(push_constant) uniform parameter
1315
{
1416
uint n_rows;
@@ -20,35 +22,72 @@ layout(push_constant) uniform parameter
2022
layout(binding = 0, std430) readonly buffer A { float data_a[]; };
2123
layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
2224

23-
const uint EL_W = N / WARP_SIZE;
25+
const uint EL_W = N / BLOCK_SIZE;
26+
27+
#ifdef FWHT_SHMEM
28+
shared float shmem[4 * N];
29+
#endif
2430

2531
void main() {
26-
const uint lane = gl_SubgroupInvocationID;
27-
for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
28-
row < n_rows;
29-
row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
32+
#ifdef FWHT_SHMEM
33+
const uint tid = gl_LocalInvocationID.x;
34+
const uint shmem_base = gl_LocalInvocationID.y * N;
35+
const uint row_id = gl_LocalInvocationID.y;
36+
#else
37+
const uint tid = gl_SubgroupInvocationID;
38+
const uint row_id = gl_SubgroupID;
39+
#endif
40+
41+
for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y;
42+
base_row < n_rows;
43+
base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
44+
const uint row = base_row + row_id;
3045
const uint row_offset = row * N;
3146

47+
#ifndef FWHT_SHMEM
48+
if (row >= n_rows) {
49+
continue;
50+
}
51+
#endif
52+
3253
float reg[EL_W];
3354

3455
[[unroll]]
3556
for (uint i = 0; i < EL_W; ++i) {
36-
reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale;
57+
reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0;
3758
}
3859

60+
#ifdef FWHT_SHMEM
61+
[[unroll]]
62+
for (uint h = 1; h < BLOCK_SIZE; h <<= 1) {
63+
[[unroll]]
64+
for (uint i = 0; i < EL_W; ++i) {
65+
shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i];
66+
}
67+
barrier();
68+
[[unroll]]
69+
for (uint j = 0; j < EL_W; ++j) {
70+
const float val = reg[j];
71+
const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)];
72+
reg[j] = (tid & h) == 0 ? val + other : other - val;
73+
}
74+
barrier();
75+
}
76+
#else
3977
[[unroll]]
40-
for (uint h = 1; h < WARP_SIZE; h <<= 1) {
78+
for (uint h = 1; h < BLOCK_SIZE; h <<= 1) {
4179
[[unroll]]
4280
for (uint j = 0; j < EL_W; ++j) {
4381
const float val = reg[j];
4482
const float val2 = subgroupShuffleXor(val, h);
45-
reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
83+
reg[j] = (tid & h) == 0 ? val + val2 : val2 - val;
4684
}
4785
}
86+
#endif
4887

4988
[[unroll]]
50-
for (uint h = WARP_SIZE; h < N; h <<= 1) {
51-
const uint step = h / WARP_SIZE;
89+
for (uint h = BLOCK_SIZE; h < N; h <<= 1) {
90+
const uint step = h / BLOCK_SIZE;
5291
[[unroll]]
5392
for (uint j = 0; j < EL_W; j += 2 * step) {
5493
[[unroll]]
@@ -61,9 +100,16 @@ void main() {
61100
}
62101
}
63102

64-
[[unroll]]
65-
for (uint i = 0; i < EL_W; ++i) {
66-
data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i];
103+
#ifdef FWHT_SHMEM
104+
if (row < n_rows) {
105+
#endif
106+
[[unroll]]
107+
for (uint i = 0; i < EL_W; ++i) {
108+
data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i];
109+
}
110+
#ifdef FWHT_SHMEM
67111
}
112+
barrier();
113+
#endif
68114
}
69115
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@ void process_shaders() {
957957
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
958958
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
959959
string_to_spv("fwht_f32", "fwht.comp", {});
960+
string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}});
960961
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
961962
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
962963
string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

0 commit comments

Comments
 (0)