diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index 68980aa53ef986..d0968334aceae0 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -625,11 +625,10 @@ void FlashAttnGradBaseKernel(
   const float softmax_scale = 1.0f / std::sqrt(head_size);
   const float softmax_unscale = std::sqrt(head_size);

-  int version =
-      FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
-              (head_size == 64 || head_size == 128 || head_size == 256)
-          ? FLAGS_flash_attn_version
-          : 2;
+  int version =
+      FLAGS_flash_attn_version == 3 && FLAGS_cudnn_deterministic && head_size >= 128
+          ? 2
+          : FLAGS_flash_attn_version;
   FlashAttnBwdParamsV2 params =
       FlashAttnBwdParamsV2(dev_ctx,
                            version,
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index 02226cece30a6b..a5cdbcdb09e8d9 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -378,11 +378,10 @@ void FlashAttnBaseKernel(
   const float softmax_scale = 1.0f / std::sqrt(head_size);
   const float softmax_unscale = std::sqrt(head_size);
-  int version =
-      FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
-              (head_size == 64 || head_size == 128 || head_size == 256)
-          ? FLAGS_flash_attn_version
-          : 2;
+  int version =
+      FLAGS_flash_attn_version == 3 && FLAGS_cudnn_deterministic && head_size >= 128
+          ? 2
+          : FLAGS_flash_attn_version;
   FlashAttnFwdParamsV2 params =
       FlashAttnFwdParamsV2(dev_ctx,
                            version,
                            batch_size,
diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
index f2629f872d3d85..771b7a7fbc8807 100644
--- a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
@@ -511,13 +511,18 @@ void FlashAttnV3GradBaseKernel(
       dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
   dynload::fa3_bwd_params_set_dq_semaphore(params_handle,
                                            dq_semaphore.data());
-  if (num_heads_k != num_heads &&
-      dynload::fa3_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int>(
+  DenseTensor dk_semaphore = phi::Empty<int>(
       dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int>(
+  DenseTensor dv_semaphore = phi::Empty<int>(
       dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+  if (num_heads_k != num_heads &&
+      dynload::fa3_bwd_params_get_deterministic(params_handle)) {
+    // xiangrui: we need to zero them out
+    phi::funcs::SetConstant<Context, int> set_zero_dk;
+    set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int>(0));
+    phi::funcs::SetConstant<Context, int> set_zero_dv;
+    set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int>(0));
+
     dynload::fa3_bwd_params_set_dk_semaphore(params_handle,
                                              dk_semaphore.data());
     dynload::fa3_bwd_params_set_dv_semaphore(params_handle,
@@ -599,11 +604,11 @@ void FlashAttnV3GradKernel(const Context &dev_ctx,
       0,
       common::errors::InvalidArgument(
           "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));
+  // PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
+  //                   false,
+  //                   common::errors::InvalidArgument(
+  //                       "deterministic is not supported in flash attention 3, "
+  //                       "please set FLAGS_cudnn_deterministic to false"));
   // umiswing: fake grad tensor for FlashAttnV3GradBaseKernel
   DenseTensor softmax_d;
   DenseTensor softmax_lse_log2;
@@ -737,11 +742,11 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
       0,
       common::errors::InvalidArgument(
           "sm_margin is not supported, please set sm_margin to 0"));
-  PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
-                    false,
-                    common::errors::InvalidArgument(
-                        "deterministic is not supported in flash attention 3, "
-                        "please set FLAGS_cudnn_deterministic to false"));
+  // PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
+  //                   false,
+  //                   common::errors::InvalidArgument(
+  //                       "deterministic is not supported in flash attention 3, "
+  //                       "please set FLAGS_cudnn_deterministic to false"));

   PADDLE_ENFORCE_EQ(
       q.dims()[q.dims().size() - 1],
@@ -1391,13 +1396,18 @@ void FlashMaskV2GradBaseKernel(
       dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
   dynload::flashmaskv2_bwd_params_set_dq_semaphore(params_handle,
                                                    dq_semaphore.data());
-  if (num_heads_k != num_heads &&
-      dynload::flashmaskv2_bwd_params_get_deterministic(params_handle)) {
-    // TODO(tridao): do we need to zero them out?
-    DenseTensor dk_semaphore = phi::Empty<int>(
+  DenseTensor dk_semaphore = phi::Empty<int>(
       dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
-    DenseTensor dv_semaphore = phi::Empty<int>(
+  DenseTensor dv_semaphore = phi::Empty<int>(
       dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
+  if (num_heads_k != num_heads &&
+      dynload::flashmaskv2_bwd_params_get_deterministic(params_handle)) {
+    // xiangrui: we need to zero them out
+    phi::funcs::SetConstant<Context, int> set_zero_dk;
+    set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int>(0));
+    phi::funcs::SetConstant<Context, int> set_zero_dv;
+    set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int>(0));
+
     dynload::flashmaskv2_bwd_params_set_dk_semaphore(params_handle,
                                                      dk_semaphore.data());
     dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle,
@@ -1546,7 +1556,7 @@ void FlashMaskV2GradKernel(
       -1,     // window_size_left,
       -1,     // window_size_right,
       0,      // softcap,
-      false,  // deterministic,
+      FLAGS_cudnn_deterministic,  // deterministic,
       0,      // sm_margin,
       dq,
       dk,
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 2d1b050cdba6e7..07ed3fff183cac 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -2268,10 +2268,10 @@ def flashmask_attention(

     if "xpu" in paddle.get_device():
         fa_version = 2
-    elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
-        "FLAGS_cudnn_deterministic"
-    ]:
-        fa_version = 2
+    # elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
+    #     "FLAGS_cudnn_deterministic"
+    # ]:
+    #     fa_version = 2
     else:
         fa_version = paddle.base.framework.get_flags(
             ["FLAGS_flash_attn_version"]