Skip to content

Commit fa386f0

Browse files
committed
disable scheduler barriers
1 parent de179ff commit fa386f0

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

aiter/ops/triton/gluon/mla_decode_fp8.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _fwd_grouped_kernel_stage1_n16x2_prefetch_k(
320320
cur_k2 = smem_kv2.load(layout=dot_k_layout)
321321

322322
smem_k_rope.store(k_pe.T)
323-
gl.amd.cdna3.sched_barrier(0x0)
323+
# gl.amd.cdna3.sched_barrier(0x0)
324324
split_kv_start += BLOCK_N
325325

326326
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
@@ -794,7 +794,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
794794
K_Buffer.type.element_ty, [kv_lora_rank // 2, BLOCK_N], layout=shared_k
795795
)
796796

797-
gl.amd.cdna3.sched_barrier(0x0)
797+
# gl.amd.cdna3.sched_barrier(0x0)
798798

799799
smem_kv1.store(kv1.T)
800800
smem_kv2.store(kv2.T)
@@ -825,22 +825,22 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
825825
cur_k1 = smem_kv1.load(layout=dot_k_layout)
826826
cur_k2 = smem_kv2.load(layout=dot_k_layout)
827827

828-
gl.amd.cdna3.sched_barrier(0x0)
828+
# gl.amd.cdna3.sched_barrier(0x0)
829829
smem_kv1 = smem_kv1._reinterpret(
830830
K_Buffer.type.element_ty, [BLOCK_N, kv_lora_rank // 2], layout=shared_v)
831831
kv1_transpose = gl.convert_layout(kv1, kv_itt_layout)
832-
gl.amd.cdna3.sched_barrier(0x0)
832+
# gl.amd.cdna3.sched_barrier(0x0)
833833

834834
smem_kv1.store(kv1_transpose)
835835
smem_kv2 = smem_kv2._reinterpret(
836836
K_Buffer.type.element_ty, [BLOCK_N, kv_lora_rank // 2], layout=shared_v)
837837
kv2_transpose = gl.convert_layout(kv2, kv_itt_layout)
838-
gl.amd.cdna3.sched_barrier(0x0)
838+
# gl.amd.cdna3.sched_barrier(0x0)
839839

840840
smem_kv2.store(kv2_transpose)
841841

842842
smem_k_rope.store(k_pe.T)
843-
gl.amd.cdna3.sched_barrier(0x0)
843+
# gl.amd.cdna3.sched_barrier(0x0)
844844
split_kv_start += 1
845845

846846
mask_qk_h = gl.arange(0, BLOCK_H, gl.SliceLayout(1, mfma_layout_qk))
@@ -855,12 +855,12 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
855855

856856
cur_k_pe = smem_k_rope.load(layout=dot_k_layout)
857857

858-
gl.amd.cdna3.sched_barrier(0x0)
858+
# gl.amd.cdna3.sched_barrier(0x0)
859859
k_id = kv_loc * PAGE_BLOCK_SIZE + cur_N
860860
offs_buf_kv = k_id[:, None] * stride_buf_kh + offs_k_c[None, :]
861861
mask_k_id = start_n * PAGE_BLOCK_SIZE + cur_N
862862
mask_k = mask_k_id < cur_batch_seq_len
863-
gl.amd.cdna3.sched_barrier(0x0)
863+
# gl.amd.cdna3.sched_barrier(0x0)
864864

865865
qk = gl.amd.cdna3.mfma(q0, cur_k1, zeros)
866866
kv1 = gl.amd.cdna3.buffer_load(
@@ -869,7 +869,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
869869
mask=mask_k[:, None] & mask_k_c[None, :]
870870
)
871871

872-
gl.amd.cdna3.sched_barrier(0x0)
872+
# gl.amd.cdna3.sched_barrier(0x0)
873873

874874
qk = gl.amd.cdna3.mfma(q1, cur_k2, qk)
875875
kv2 = gl.amd.cdna3.buffer_load(
@@ -902,7 +902,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
902902
mask_k_id = start_n * PAGE_BLOCK_SIZE + cur_N_pe
903903
mask_k_pe = mask_k_id < cur_batch_seq_len
904904

905-
gl.amd.cdna3.sched_barrier(0x0)
905+
# gl.amd.cdna3.sched_barrier(0x0)
906906
k_pe = gl.amd.cdna3.buffer_load(
907907
ptr=K_Buffer,
908908
offsets=offs_buf_k_pe,
@@ -919,7 +919,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
919919
re_scale = tl.math.exp2((e_max - n_e_max) * log2e)
920920
p = tl.math.exp2((qk - n_e_max[:, None]) * log2e)
921921
smem_p.store(p.to(q0.dtype))
922-
gl.amd.cdna3.sched_barrier(0x0)
922+
# gl.amd.cdna3.sched_barrier(0x0)
923923

924924
cur_p = smem_p.load(layout=dot_p_layout)
925925
smem_kv1 = smem_kv1._reinterpret(
@@ -940,15 +940,15 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
940940

941941
cur_k1 = smem_kv1.load(layout=dot_k_layout)
942942
kv1_transpose = gl.convert_layout(kv1, kv_itt_layout)
943-
gl.amd.cdna3.sched_barrier(0x0)
943+
# gl.amd.cdna3.sched_barrier(0x0)
944944
smem_kv1 = smem_kv1._reinterpret(
945945
K_Buffer.type.element_ty, [BLOCK_N, kv_lora_rank // 2], layout=shared_v)
946946

947947
smem_kv1.store(kv1_transpose)
948948
cur_k2 = smem_kv2.load(layout=dot_k_layout)
949949

950950
kv2_transpose = gl.convert_layout(kv2, kv_itt_layout)
951-
gl.amd.cdna3.sched_barrier(0x0)
951+
# gl.amd.cdna3.sched_barrier(0x0)
952952
smem_kv2 = smem_kv2._reinterpret(
953953
K_Buffer.type.element_ty, [BLOCK_N, kv_lora_rank // 2], layout=shared_v)
954954
smem_kv2.store(kv2_transpose)
@@ -1001,7 +1001,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
10011001
acc1 = acc1 * re_scale[:, None]
10021002
acc2 = acc2 * re_scale[:, None]
10031003
e_sum = e_sum * re_scale + gl.sum(p, 1)
1004-
gl.amd.cdna3.sched_barrier(0x0)
1004+
# gl.amd.cdna3.sched_barrier(0x0)
10051005
cur_p = smem_p.load(layout=dot_p_layout)
10061006
e_max = n_e_max
10071007

@@ -1352,7 +1352,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k(
13521352
cur_k2 = smem_kv2.load(layout=dot_k_layout)
13531353

13541354
smem_k_rope.store(k_pe.T)
1355-
gl.amd.cdna3.sched_barrier(0x0)
1355+
# gl.amd.cdna3.sched_barrier(0x0)
13561356
split_kv_start += BLOCK_N
13571357

13581358
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):

0 commit comments

Comments
 (0)