Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,10 @@ void FlashAttnGradBaseKernel(
const float softmax_scale = 1.0f / std::sqrt(head_size);
const float softmax_unscale = std::sqrt(head_size);

int version =
FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
(head_size == 64 || head_size == 128 || head_size == 256)
? FLAGS_flash_attn_version
: 2;
int version =
FLAGS_flash_attn_version == 3 && FLAGS_cudnn_deterministic && head_size >= 128
? 2
: FLAGS_flash_attn_version;
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(dev_ctx,
version,
Expand Down
9 changes: 4 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,10 @@ void FlashAttnBaseKernel(
const float softmax_scale = 1.0f / std::sqrt(head_size);
const float softmax_unscale = std::sqrt(head_size);

int version =
FLAGS_flash_attn_version == 3 && !FLAGS_cudnn_deterministic &&
(head_size == 64 || head_size == 128 || head_size == 256)
? FLAGS_flash_attn_version
: 2;
int version =
FLAGS_flash_attn_version == 3 && FLAGS_cudnn_deterministic && head_size >= 128
? 2
: FLAGS_flash_attn_version;
FlashAttnFwdParamsV2<T> params = FlashAttnFwdParamsV2<T>(dev_ctx,
version,
batch_size,
Expand Down
52 changes: 31 additions & 21 deletions paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,18 @@ void FlashAttnV3GradBaseKernel(
dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
dynload::fa3_bwd_params_set_dq_semaphore(params_handle,
dq_semaphore.data<int>());
if (num_heads_k != num_heads &&
dynload::fa3_bwd_params_get_deterministic(params_handle)) {
// TODO(tridao): do we need to zero them out?
DenseTensor dk_semaphore = phi::Empty<int32_t>(
DenseTensor dk_semaphore = phi::Empty<int32_t>(
dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
DenseTensor dv_semaphore = phi::Empty<int32_t>(
DenseTensor dv_semaphore = phi::Empty<int32_t>(
dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
if (num_heads_k != num_heads &&
dynload::fa3_bwd_params_get_deterministic(params_handle)) {
// xiangrui: we need to zero them out
phi::funcs::SetConstant<Context, int32_t> set_zero_dk;
set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int32_t>(0));
phi::funcs::SetConstant<Context, int32_t> set_zero_dv;
set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int32_t>(0));

dynload::fa3_bwd_params_set_dk_semaphore(params_handle,
dk_semaphore.data<int>());
dynload::fa3_bwd_params_set_dv_semaphore(params_handle,
Expand Down Expand Up @@ -599,11 +604,11 @@ void FlashAttnV3GradKernel(const Context &dev_ctx,
0,
common::errors::InvalidArgument(
"sm_margin is not supported, please set sm_margin to 0"));
PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
false,
common::errors::InvalidArgument(
"deterministic is not supported in flash attention 3, "
"please set FLAGS_cudnn_deterministic to false"));
// PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
// false,
// common::errors::InvalidArgument(
// "deterministic is not supported in flash attention 3, "
// "please set FLAGS_cudnn_deterministic to false"));
// umiswing: fake grad tensor for FlashAttnV3GradBaseKernel
DenseTensor softmax_d;
DenseTensor softmax_lse_log2;
Expand Down Expand Up @@ -737,11 +742,11 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
0,
common::errors::InvalidArgument(
"sm_margin is not supported, please set sm_margin to 0"));
PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
false,
common::errors::InvalidArgument(
"deterministic is not supported in flash attention 3, "
"please set FLAGS_cudnn_deterministic to false"));
// PADDLE_ENFORCE_EQ(FLAGS_cudnn_deterministic,
// false,
// common::errors::InvalidArgument(
// "deterministic is not supported in flash attention 3, "
// "please set FLAGS_cudnn_deterministic to false"));

PADDLE_ENFORCE_EQ(
q.dims()[q.dims().size() - 1],
Expand Down Expand Up @@ -1391,13 +1396,18 @@ void FlashMaskV2GradBaseKernel(
dev_ctx, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads});
dynload::flashmaskv2_bwd_params_set_dq_semaphore(params_handle,
dq_semaphore.data<int>());
if (num_heads_k != num_heads &&
dynload::flashmaskv2_bwd_params_get_deterministic(params_handle)) {
// TODO(tridao): do we need to zero them out?
DenseTensor dk_semaphore = phi::Empty<int32_t>(
DenseTensor dk_semaphore = phi::Empty<int32_t>(
dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
DenseTensor dv_semaphore = phi::Empty<int32_t>(
DenseTensor dv_semaphore = phi::Empty<int32_t>(
dev_ctx, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k});
if (num_heads_k != num_heads &&
dynload::flashmaskv2_bwd_params_get_deterministic(params_handle)) {
// xiangrui: we need to zero them out
phi::funcs::SetConstant<Context, int32_t> set_zero_dk;
set_zero_dk(dev_ctx, &dk_semaphore, static_cast<int32_t>(0));
phi::funcs::SetConstant<Context, int32_t> set_zero_dv;
set_zero_dv(dev_ctx, &dv_semaphore, static_cast<int32_t>(0));

dynload::flashmaskv2_bwd_params_set_dk_semaphore(params_handle,
dk_semaphore.data<int>());
dynload::flashmaskv2_bwd_params_set_dv_semaphore(params_handle,
Expand Down Expand Up @@ -1546,7 +1556,7 @@ void FlashMaskV2GradKernel(
-1, // window_size_left,
-1, // window_size_right,
0, // softcap,
false, // deterministic,
FLAGS_cudnn_deterministic, // deterministic,
0, // sm_margin,
dq,
dk,
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/nn/functional/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2268,10 +2268,10 @@ def flashmask_attention(

if "xpu" in paddle.get_device():
fa_version = 2
elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
"FLAGS_cudnn_deterministic"
]:
fa_version = 2
# elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
# "FLAGS_cudnn_deterministic"
# ]:
# fa_version = 2
else:
fa_version = paddle.base.framework.get_flags(
["FLAGS_flash_attn_version"]
Expand Down