@@ -320,7 +320,7 @@ def _fwd_grouped_kernel_stage1_n16x2_prefetch_k(
320320 cur_k2 = smem_kv2 .load (layout = dot_k_layout )
321321
322322 smem_k_rope .store (k_pe .T )
323- gl .amd .cdna3 .sched_barrier (0x0 )
323+ # gl.amd.cdna3.sched_barrier(0x0)
324324 split_kv_start += BLOCK_N
325325
326326 for start_n in range (split_kv_start , split_kv_end , BLOCK_N ):
@@ -794,7 +794,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
794794 K_Buffer .type .element_ty , [kv_lora_rank // 2 , BLOCK_N ], layout = shared_k
795795 )
796796
797- gl .amd .cdna3 .sched_barrier (0x0 )
797+ # gl.amd.cdna3.sched_barrier(0x0)
798798
799799 smem_kv1 .store (kv1 .T )
800800 smem_kv2 .store (kv2 .T )
@@ -825,22 +825,22 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
825825 cur_k1 = smem_kv1 .load (layout = dot_k_layout )
826826 cur_k2 = smem_kv2 .load (layout = dot_k_layout )
827827
828- gl .amd .cdna3 .sched_barrier (0x0 )
828+ # gl.amd.cdna3.sched_barrier(0x0)
829829 smem_kv1 = smem_kv1 ._reinterpret (
830830 K_Buffer .type .element_ty , [BLOCK_N , kv_lora_rank // 2 ], layout = shared_v )
831831 kv1_transpose = gl .convert_layout (kv1 , kv_itt_layout )
832- gl .amd .cdna3 .sched_barrier (0x0 )
832+ # gl.amd.cdna3.sched_barrier(0x0)
833833
834834 smem_kv1 .store (kv1_transpose )
835835 smem_kv2 = smem_kv2 ._reinterpret (
836836 K_Buffer .type .element_ty , [BLOCK_N , kv_lora_rank // 2 ], layout = shared_v )
837837 kv2_transpose = gl .convert_layout (kv2 , kv_itt_layout )
838- gl .amd .cdna3 .sched_barrier (0x0 )
838+ # gl.amd.cdna3.sched_barrier(0x0)
839839
840840 smem_kv2 .store (kv2_transpose )
841841
842842 smem_k_rope .store (k_pe .T )
843- gl .amd .cdna3 .sched_barrier (0x0 )
843+ # gl.amd.cdna3.sched_barrier(0x0)
844844 split_kv_start += 1
845845
846846 mask_qk_h = gl .arange (0 , BLOCK_H , gl .SliceLayout (1 , mfma_layout_qk ))
@@ -855,12 +855,12 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
855855
856856 cur_k_pe = smem_k_rope .load (layout = dot_k_layout )
857857
858- gl .amd .cdna3 .sched_barrier (0x0 )
858+ # gl.amd.cdna3.sched_barrier(0x0)
859859 k_id = kv_loc * PAGE_BLOCK_SIZE + cur_N
860860 offs_buf_kv = k_id [:, None ] * stride_buf_kh + offs_k_c [None , :]
861861 mask_k_id = start_n * PAGE_BLOCK_SIZE + cur_N
862862 mask_k = mask_k_id < cur_batch_seq_len
863- gl .amd .cdna3 .sched_barrier (0x0 )
863+ # gl.amd.cdna3.sched_barrier(0x0)
864864
865865 qk = gl .amd .cdna3 .mfma (q0 , cur_k1 , zeros )
866866 kv1 = gl .amd .cdna3 .buffer_load (
@@ -869,7 +869,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
869869 mask = mask_k [:, None ] & mask_k_c [None , :]
870870 )
871871
872- gl .amd .cdna3 .sched_barrier (0x0 )
872+ # gl.amd.cdna3.sched_barrier(0x0)
873873
874874 qk = gl .amd .cdna3 .mfma (q1 , cur_k2 , qk )
875875 kv2 = gl .amd .cdna3 .buffer_load (
@@ -902,7 +902,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
902902 mask_k_id = start_n * PAGE_BLOCK_SIZE + cur_N_pe
903903 mask_k_pe = mask_k_id < cur_batch_seq_len
904904
905- gl .amd .cdna3 .sched_barrier (0x0 )
905+ # gl.amd.cdna3.sched_barrier(0x0)
906906 k_pe = gl .amd .cdna3 .buffer_load (
907907 ptr = K_Buffer ,
908908 offsets = offs_buf_k_pe ,
@@ -919,7 +919,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
919919 re_scale = tl .math .exp2 ((e_max - n_e_max ) * log2e )
920920 p = tl .math .exp2 ((qk - n_e_max [:, None ]) * log2e )
921921 smem_p .store (p .to (q0 .dtype ))
922- gl .amd .cdna3 .sched_barrier (0x0 )
922+ # gl.amd.cdna3.sched_barrier(0x0)
923923
924924 cur_p = smem_p .load (layout = dot_p_layout )
925925 smem_kv1 = smem_kv1 ._reinterpret (
@@ -940,15 +940,15 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
940940
941941 cur_k1 = smem_kv1 .load (layout = dot_k_layout )
942942 kv1_transpose = gl .convert_layout (kv1 , kv_itt_layout )
943- gl .amd .cdna3 .sched_barrier (0x0 )
943+ # gl.amd.cdna3.sched_barrier(0x0)
944944 smem_kv1 = smem_kv1 ._reinterpret (
945945 K_Buffer .type .element_ty , [BLOCK_N , kv_lora_rank // 2 ], layout = shared_v )
946946
947947 smem_kv1 .store (kv1_transpose )
948948 cur_k2 = smem_kv2 .load (layout = dot_k_layout )
949949
950950 kv2_transpose = gl .convert_layout (kv2 , kv_itt_layout )
951- gl .amd .cdna3 .sched_barrier (0x0 )
951+ # gl.amd.cdna3.sched_barrier(0x0)
952952 smem_kv2 = smem_kv2 ._reinterpret (
953953 K_Buffer .type .element_ty , [BLOCK_N , kv_lora_rank // 2 ], layout = shared_v )
954954 smem_kv2 .store (kv2_transpose )
@@ -1001,7 +1001,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k_paged_64(
10011001 acc1 = acc1 * re_scale [:, None ]
10021002 acc2 = acc2 * re_scale [:, None ]
10031003 e_sum = e_sum * re_scale + gl .sum (p , 1 )
1004- gl .amd .cdna3 .sched_barrier (0x0 )
1004+ # gl.amd.cdna3.sched_barrier(0x0)
10051005 cur_p = smem_p .load (layout = dot_p_layout )
10061006 e_max = n_e_max
10071007
@@ -1352,7 +1352,7 @@ def _fwd_grouped_kernel_stage1_n16x4_prefetch_k(
13521352 cur_k2 = smem_kv2 .load (layout = dot_k_layout )
13531353
13541354 smem_k_rope .store (k_pe .T )
1355- gl .amd .cdna3 .sched_barrier (0x0 )
1355+ # gl.amd.cdna3.sched_barrier(0x0)
13561356 split_kv_start += BLOCK_N
13571357
13581358 for start_n in range (split_kv_start , split_kv_end , BLOCK_N ):
0 commit comments