Skip to content

Commit 75d53dc

Browse files
committed
Fixes bias gradient computation and memory ordering
Corrects head stride calculation for bias gradients by removing division operation that was causing incorrect indexing. Reorders memory operations to improve synchronization efficiency by moving bias-related copies before sync barrier and consolidating related operations together.
1 parent 8f2bf8e commit 75d53dc

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
112112
const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
113113
+ (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
114114
const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
115-
+ (bidh / params.h_h_k_ratio) * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
115+
+ bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
116116
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
117117
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
118118
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
@@ -766,11 +766,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
766766
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
767767
// Convert dS from fp32 to fp16
768768
Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);
769-
770-
// Write tdSrdS to gdBias
771-
Tensor tdBiasrdS = smem_thr_copy_Bias.retile_S(tdSrdS);
772-
cute::copy(smem_tiled_copy_Bias, tdBiasrdS, tSsBias);
769+
Tensor tdBiasadS = smem_thr_copy_Bias.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
770+
cute::copy(smem_tiled_copy_Bias, tdBiasadS, tSsBias);
771+
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
772+
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
773773
__syncthreads();
774+
// Write sdBias to gdBias
774775
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
775776
gmem_tiled_copy_Bias,
776777
tBiassBias, tdBiasgdBias,
@@ -780,10 +781,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
780781
);
781782

782783
// if (cute::thread0()) { print(tPrP); }
783-
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_N, MMA_N)
784-
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
785-
__syncthreads();
786-
787784
// Layout p_l = tPrP.layout();
788785
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
789786
// FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);

0 commit comments

Comments (0)