Skip to content

Commit 21af558

Browse files
committed
Fixes variable initialization in flash attention kernels
Initializes `any_active_next` to `false` instead of to `any_active`, so that the active state of the current iteration is not unintentionally carried over into the next iteration of the kernel loops. The change is applied to both the forward and the backward kernel implementations so that they behave consistently.
1 parent 5057680 commit 21af558

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

csrc/src/flash_bwd_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -568,7 +568,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
568 568   #pragma unroll
569 569   for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
570 570   bool any_active = __syncthreads_or(any_active_local);
571     - bool any_active_next = any_active; // to be updated later for next iteration
    571 + bool any_active_next = false; // to be updated later for next iteration
572 572
573 573   FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
574 574   gmem_tiled_copy_QKV,

csrc/src/flash_fwd_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
362 362   #pragma unroll
363 363   for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
364 364   bool any_active = __syncthreads_or(any_active_local);
365     - bool any_active_next = any_active; // to be updated later for next iteration
    365 + bool any_active_next = false; // to be updated later for next iteration
366 366
367 367   // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
368 368   if (any_active) {
@@ -1016,7 +1016,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
1016 1016   #pragma unroll
1017 1017   for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
1018 1018   bool any_active = __syncthreads_or(any_active_local);
1019      - bool any_active_next = any_active; // to be updated later for next iteration
     1019 + bool any_active_next = false; // to be updated later for next iteration
1020 1020
1021 1021   // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
1022 1022   if (any_active) {

0 commit comments

Comments
 (0)