10 changes: 6 additions & 4 deletions paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
@@ -707,8 +707,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
const paddle::optional<DenseTensor> &seqused_k,
const DenseTensor &out_grad,
float const softmax_scale,
int const max_seqlen_q,
int const max_seqlen_k,
const Scalar &max_seqlen_q,
const Scalar &max_seqlen_k,
bool const causal,
int const window_size_left,
int const window_size_right,
@@ -756,6 +756,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
DenseTensor dq_accum;
DenseTensor dk_accum;
DenseTensor dv_accum;
const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();
const int64_t max_seqlen_k_ = max_seqlen_k.to<int64_t>();
FlashAttnV3GradBaseKernel<T, Context>(dev_ctx,
out_grad,
q,
@@ -770,8 +772,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
max_seqlen_q_,
max_seqlen_k_,
softmax_scale,
causal,
window_size_left,
12 changes: 7 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
@@ -1082,8 +1082,8 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
const paddle::optional<DenseTensor> &q_descale,
const paddle::optional<DenseTensor> &k_descale,
const paddle::optional<DenseTensor> &v_descale,
const int max_seqlen_q,
const int max_seqlen_k,
const Scalar &max_seqlen_q,
const Scalar &max_seqlen_k,
const float softmax_scale,
const bool causal,
const int window_size_left,
@@ -1150,6 +1150,8 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,

DenseTensor out_accum;
DenseTensor softmax_lse_accum;
const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();
const int64_t max_seqlen_k_ = max_seqlen_k.to<int64_t>();
FlashAttnV3BaseKernel<T, Context>(dev_ctx,
q,
k,
@@ -1171,9 +1173,9 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
q_descale,
k_descale,
v_descale,
paddle::none, // scheduler_metadata
max_seqlen_q, // max_seqlen_q_
max_seqlen_k, // max_seqlen_k_
paddle::none, // scheduler_metadata
max_seqlen_q_, // max_seqlen_q_
max_seqlen_k_, // max_seqlen_k_
softmax_scale,
causal,
window_size_left,
4 changes: 2 additions & 2 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1192,8 +1192,8 @@
data_type : q

- backward_op : flash_attn_v3_varlen_grad
forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, int max_seqlen_q, int max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin)
forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin)
optional : seqused_q, seqused_k
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
10 changes: 10 additions & 0 deletions paddle/phi/ops/yaml/op_compat.yaml
@@ -1337,6 +1337,16 @@
data_type : int64_t
support_tensor : true

- op : flash_attn_v3_varlen
backward : flash_attn_v3_varlen_grad
scalar :
max_seqlen_q :
data_type : int64_t
support_tensor : true
max_seqlen_k :
data_type : int64_t
support_tensor : true

- op : flash_attn_varlen_qkvpacked
backward : flash_attn_varlen_qkvpacked_grad
scalar :
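For context: declaring `max_seqlen_q` / `max_seqlen_k` as `scalar` attributes with `support_tensor : true` (as in the op_compat.yaml entry above) is what lets callers pass either a plain Python int or a scalar integer Tensor for these arguments. A minimal, hypothetical sketch of the two accepted forms — the actual call into the varlen op is omitted, and the variable names are illustrative only:

```python
import paddle

# Either form below is a valid value for a Scalar attribute declared with
# `support_tensor : true`, e.g. max_seqlen_q / max_seqlen_k after this change.
max_seqlen_q = 128                                   # plain Python int
max_seqlen_k = paddle.to_tensor(128, dtype="int64")  # scalar int64 Tensor

# On the C++ side both arrive as phi::Scalar and are read back with
# Scalar::to<int64_t>() (see the kernel changes above).
```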
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
@@ -2113,7 +2113,7 @@
backward : flash_attn_v3_grad

- op : flash_attn_v3_varlen
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin)
output : Tensor(out), Tensor(softmax_lse)
optional : seqused_q, seqused_k, qv, q_descale, k_descale, v_descale
infer_meta :
14 changes: 7 additions & 7 deletions python/paddle/utils/decorator_utils.py
@@ -30,7 +30,7 @@
_RetT = TypeVar("_RetT")


def _is_in_or_scalar_tensor(x):
def _is_int_or_scalar_tensor(x):
if isinstance(x, int):
return True
if isinstance(x, (paddle.Tensor, paddle.pir.Value)):
@@ -420,8 +420,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
kwargs["shape_or_dtype"] = kwargs.pop("dtype")
elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs):
kwargs["shape_or_dtype"] = kwargs.pop("size")
elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
kwargs["x"] = args[0]
kwargs['shape_or_dtype'] = list(args[1:])
args = ()
@@ -552,8 +552,8 @@ def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if ("input" in kwargs) and ("x" not in kwargs):
kwargs["x"] = kwargs.pop("input")
elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
kwargs["x"] = args[0]
kwargs['shape'] = list(args[1:])
args = ()
@@ -624,8 +624,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
kwargs["x"] = kwargs.pop("input")
if ("size" in kwargs) and ("shape" not in kwargs):
kwargs["shape"] = kwargs.pop("size")
elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
kwargs["x"] = args[0]
kwargs['shape'] = list(args[1:])
args = ()
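The helper renamed above (`_is_in_or_scalar_tensor` → `_is_int_or_scalar_tensor`) is only partially visible in this diff. Below is a minimal sketch of the predicate it implements, assuming the elided branch checks for a scalar (0-D) integer tensor; this is not the actual Paddle source:

```python
import paddle

def _is_int_or_scalar_tensor_sketch(x):
    # Plain Python ints qualify directly.
    if isinstance(x, int):
        return True
    # Assumed check for the elided branch: a 0-D tensor with an integer dtype.
    if isinstance(x, (paddle.Tensor, paddle.pir.Value)):
        return len(x.shape) == 0 and x.dtype in (paddle.int32, paddle.int64)
    return False

print(_is_int_or_scalar_tensor_sketch(3))                    # expected: True
print(_is_int_or_scalar_tensor_sketch(paddle.to_tensor(3)))  # expected: True (0-D int64)
print(_is_int_or_scalar_tensor_sketch(3.5))                  # expected: False
```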