11#version 450
22
// Required for the [[unroll]] loop attributes used throughout main().
33#extension GL_EXT_control_flow_attributes : require
// The subgroup extensions are only enabled when FWHT_SHMEM is not defined; that build reads
// gl_SubgroupInvocationID / gl_SubgroupID and calls subgroupShuffleXor. The FWHT_SHMEM build
// exchanges values through the shared array instead and uses no subgroup built-ins.
4+ #ifndef FWHT_SHMEM
45#extension GL_KHR_shader_subgroup_basic : enable
56#extension GL_KHR_shader_subgroup_shuffle : enable
7+ #endif
68
// Diff note: lines marked '-' are the pre-change text, lines marked '+' the post-change text.
// The local-size layout is moved below the specialization constants, and WARP_SIZE is renamed
// to BLOCK_SIZE (same constant_id 0, same default of 32).
7- layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
8-
9- layout(constant_id = 0) const uint WARP_SIZE = 32;
// Specialization constant 0; it is also the workgroup x-size via local_size_x_id = 0 below.
9+ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
// Specialization constant 1: number of floats per row (rows are addressed as row * N in main()).
1010layout(constant_id = 1) const uint N = 128;
1111
// Workgroup shape: x-size taken from specialization constant 0, four y-slices, one z-slice.
12+ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
13+
1214layout(push_constant) uniform parameter
1315{
1416 uint n_rows;
@@ -20,35 +22,72 @@ layout(push_constant) uniform parameter
// Input buffer (read-only) and output buffer (write-only), both arrays of floats.
2022layout(binding = 0, std430) readonly buffer A { float data_a[]; };
2123layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
2224
// EL_W is the number of row elements each invocation keeps in its private reg[] array in main().
// NOTE(review): the integer division presumably assumes N is a multiple of BLOCK_SIZE — confirm.
23- const uint EL_W = N / WARP_SIZE;
25+ const uint EL_W = N / BLOCK_SIZE;
26+
// Shared scratch used only by the FWHT_SHMEM build: N floats for each of the 4 y-slices of the
// workgroup (main() indexes it from gl_LocalInvocationID.y * N).
27+ #ifdef FWHT_SHMEM
28+ shared float shmem[4 * N];
29+ #endif
2430
// Entry point. One y-slice of the workgroup (FWHT_SHMEM build) or one subgroup (other build)
// processes one row of N floats: load from data_a, apply add/subtract butterfly stages, store
// to data_d.
// NOTE(review): "FWHT" presumably stands for fast Walsh-Hadamard transform — confirm.
// NOTE(review): n_rows is a push constant; src_offset, dst_offset and scale are declared in a
// part of the file that this view elides — presumably push constants as well.
// Diff note: lines marked '-' are the pre-change text, lines marked '+' the post-change text.
2531void main() {
26- const uint lane = gl_SubgroupInvocationID;
27- for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
28- row < n_rows;
29- row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
// tid: position of this invocation within its row group; row_id: which row of the workgroup's
// batch it works on. In the shared-memory build both come from the local invocation ID, and
// shmem_base selects this row's N-float slice of shmem.
32+ #ifdef FWHT_SHMEM
33+ const uint tid = gl_LocalInvocationID.x;
34+ const uint shmem_base = gl_LocalInvocationID.y * N;
35+ const uint row_id = gl_LocalInvocationID.y;
36+ #else
// NOTE(review): this build presumably relies on the subgroup size matching BLOCK_SIZE and on
// one subgroup per y-slice — confirm against the host-side pipeline setup.
37+ const uint tid = gl_SubgroupInvocationID;
38+ const uint row_id = gl_SubgroupID;
39+ #endif
40+
// base_row depends only on workgroup-level values, so every invocation of a workgroup runs the
// same number of loop iterations; each iteration covers gl_WorkGroupSize.y consecutive rows.
41+ for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y;
42+ base_row < n_rows;
43+ base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
44+ const uint row = base_row + row_id;
3045 const uint row_offset = row * N;
3146
// Only the subgroup build skips out-of-range rows here. The shared-memory build keeps going
// (loading zeros below), presumably so that all invocations still reach the barrier() calls.
47+ #ifndef FWHT_SHMEM
48+ if (row >= n_rows) {
49+ continue;
50+ }
51+ #endif
52+
// Per-invocation working set: elements tid, tid + BLOCK_SIZE, tid + 2 * BLOCK_SIZE, ... of the row.
3253 float reg[EL_W];
3354
3455 [[unroll]]
3556 for (uint i = 0; i < EL_W; ++i) {
36- reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane ] * scale;
// Load and pre-scale; rows at or beyond n_rows read nothing and contribute 0.0.
57+ reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid ] * scale : 0.0 ;
3758 }
3859
// Stage 1 (h < BLOCK_SIZE): each invocation pairs with partner tid ^ h. The invocation whose
// h bit is clear keeps val + other, the one whose h bit is set keeps other - val.
60+ #ifdef FWHT_SHMEM
61+ [[unroll]]
62+ for (uint h = 1; h < BLOCK_SIZE; h <<= 1) {
// Publish the current registers to this row's shmem slice so the partner can read them.
63+ [[unroll]]
64+ for (uint i = 0; i < EL_W; ++i) {
65+ shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i];
66+ }
// Wait until every invocation has written before anyone reads a partner value.
67+ barrier();
68+ [[unroll]]
69+ for (uint j = 0; j < EL_W; ++j) {
70+ const float val = reg[j];
71+ const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)];
72+ reg[j] = (tid & h) == 0 ? val + other : other - val;
73+ }
// Wait until every invocation has read before the next iteration overwrites shmem.
74+ barrier();
75+ }
76+ #else
// Subgroup build: same pairing, but the partner value is fetched with subgroupShuffleXor.
3977 [[unroll]]
40- for (uint h = 1; h < WARP_SIZE ; h <<= 1) {
78+ for (uint h = 1; h < BLOCK_SIZE ; h <<= 1) {
4179 [[unroll]]
4280 for (uint j = 0; j < EL_W; ++j) {
4381 const float val = reg[j];
4482 const float val2 = subgroupShuffleXor(val, h);
45- reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
83+ reg[j] = (tid & h) == 0 ? val + val2 : val2 - val;
4684 }
4785 }
86+ #endif
4887
// Stage 2 (BLOCK_SIZE <= h < N): iterates over this invocation's reg[] in chunks of 2 * step,
// where step = h / BLOCK_SIZE. No shuffle or shmem access is visible in this part.
4988 [[unroll]]
50- for (uint h = WARP_SIZE ; h < N; h <<= 1) {
51- const uint step = h / WARP_SIZE ;
89+ for (uint h = BLOCK_SIZE ; h < N; h <<= 1) {
90+ const uint step = h / BLOCK_SIZE ;
5291 [[unroll]]
5392 for (uint j = 0; j < EL_W; j += 2 * step) {
5493 [[unroll]]
// (diff hunk boundary: the innermost loop body of stage 2 is not shown in this view)
@@ -61,9 +100,16 @@ void main() {
61100 }
62101 }
63102
64- [[unroll]]
65- for (uint i = 0; i < EL_W; ++i) {
66- data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i];
// Store the result using the same element layout as the load. Only the shared-memory build
// needs the row < n_rows guard here; the subgroup build already skipped out-of-range rows.
103+ #ifdef FWHT_SHMEM
104+ if (row < n_rows) {
105+ #endif
106+ [[unroll]]
107+ for (uint i = 0; i < EL_W; ++i) {
108+ data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i];
109+ }
110+ #ifdef FWHT_SHMEM
67111 }
// End-of-iteration workgroup barrier (shared-memory build only), executed before the next
// base_row iteration starts writing shmem again.
112+ barrier();
113+ #endif
68114 }
69115}
0 commit comments