@@ -112,7 +112,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
112112 const index_t row_offset_bias = binfo.bias_offset (params.bias_batch_stride , params.bias_row_stride , bidb)
113113 + (bidh / params.h_h_k_ratio ) * params.bias_head_stride + (m_block_max - 1 ) * kBlockM * params.bias_row_stride + n_block * kBlockN ;
114114 const index_t row_offset_dbias = binfo.bias_offset (params.dbias_batch_stride , params.dbias_row_stride , bidb)
115- + ( bidh / params. h_h_k_ratio ) * params.dbias_head_stride + (m_block_max - 1 ) * kBlockM * params.dbias_row_stride + n_block * kBlockN ;
115+ + bidh * params.dbias_head_stride + (m_block_max - 1 ) * kBlockM * params.dbias_row_stride + n_block * kBlockN ;
116116 const index_t row_offset_do = binfo.q_offset (params.do_batch_stride , params.do_row_stride , bidb)
117117 + (m_block_max - 1 ) * kBlockM * params.do_row_stride + bidh * params.do_head_stride ;
118118 const index_t row_offset_o = binfo.q_offset (params.o_batch_stride , params.o_row_stride , bidb)
@@ -766,11 +766,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
766766 Tensor dS_reshaped = make_tensor (dS.data (), acc_dp.layout ());
767767 // Convert dS from fp32 to fp16
768768 Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);
769-
770- // Write tdSrdS to gdBias
771- Tensor tdBiasrdS = smem_thr_copy_Bias .retile_S (tdSrdS);
772- cute::copy (smem_tiled_copy_Bias, tdBiasrdS, tSsBias );
769+ Tensor tdBiasadS = smem_thr_copy_Bias. retile_S (tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
770+ cute::copy (smem_tiled_copy_Bias, tdBiasadS, tSsBias);
771+ Tensor tdSadS = smem_thr_copy_PdS .retile_S (tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
772+ cute::copy (smem_tiled_copy_PdS, tdSadS, tdSsdS );
773773 __syncthreads ();
774+ // Write sdBias to gdBias
774775 FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ false >(
775776 gmem_tiled_copy_Bias,
776777 tBiassBias, tdBiasgdBias,
@@ -780,10 +781,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
780781 );
781782
782783 // if (cute::thread0()) { print(tPrP); }
783- Tensor tdSadS = smem_thr_copy_PdS.retile_S (tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
784- cute::copy (smem_tiled_copy_PdS, tdSadS, tdSsdS);
785- __syncthreads ();
786-
787784 // Layout p_l = tPrP.layout();
788785 // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
789786 // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
0 commit comments