diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
index f2629f872d3d85..361a848349da5f 100644
--- a/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
@@ -707,8 +707,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
                                  const paddle::optional<DenseTensor> &seqused_k,
                                  const DenseTensor &out_grad,
                                  float const softmax_scale,
-                                 int const max_seqlen_q,
-                                 int const max_seqlen_k,
+                                 const Scalar &max_seqlen_q,
+                                 const Scalar &max_seqlen_k,
                                  bool const causal,
                                  int const window_size_left,
                                  int const window_size_right,
@@ -756,6 +756,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
   DenseTensor dq_accum;
   DenseTensor dk_accum;
   DenseTensor dv_accum;
+  const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();
+  const int64_t max_seqlen_k_ = max_seqlen_k.to<int64_t>();
   FlashAttnV3GradBaseKernel(dev_ctx,
                             out_grad,
                             q,
@@ -770,8 +772,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
                             cu_seqlens_k,
                             seqused_q,
                             seqused_k,
-                            max_seqlen_q,
-                            max_seqlen_k,
+                            max_seqlen_q_,
+                            max_seqlen_k_,
                             softmax_scale,
                             causal,
                             window_size_left,
diff --git a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
index afad7e8a5eefa3..7f3a10a13efea9 100644
--- a/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
@@ -1082,8 +1082,8 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
                              const paddle::optional<DenseTensor> &q_descale,
                              const paddle::optional<DenseTensor> &k_descale,
                              const paddle::optional<DenseTensor> &v_descale,
-                             const int max_seqlen_q,
-                             const int max_seqlen_k,
+                             const Scalar &max_seqlen_q,
+                             const Scalar &max_seqlen_k,
                              const float softmax_scale,
                              const bool causal,
                              const int window_size_left,
@@ -1150,6 +1150,8 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
   DenseTensor out_accum;
   DenseTensor softmax_lse_accum;
 
+  const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();
+  const int64_t max_seqlen_k_ = max_seqlen_k.to<int64_t>();
   FlashAttnV3BaseKernel(dev_ctx,
                         q,
                         k,
@@ -1171,9 +1173,9 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
                         q_descale,
                         k_descale,
                         v_descale,
-                        paddle::none,  // scheduler_metadata
-                        max_seqlen_q,  // max_seqlen_q_
-                        max_seqlen_k,  // max_seqlen_k_
+                        paddle::none,   // scheduler_metadata
+                        max_seqlen_q_,  // max_seqlen_q_
+                        max_seqlen_k_,  // max_seqlen_k_
                         softmax_scale,
                         causal,
                         window_size_left,
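Note: on the kernel side, the new Scalar arguments are converted back to plain
integers before the base kernels are called. A minimal sketch of that pattern,
assuming phi::Scalar's templated to<T>() accessor; the concrete value here is
illustrative, not taken from the patch:

    #include "paddle/phi/common/scalar.h"

    // phi::Scalar can wrap a plain integer (or a value fed in from a
    // scalar tensor at the API layer); to<int64_t>() recovers the typed
    // value either way.
    phi::Scalar max_seqlen_q(128);
    const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();  // == 128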
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index 7ba093520b531e..2ac4bf2addf5d2 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -1192,8 +1192,8 @@
     data_type : q
 
 - backward_op : flash_attn_v3_varlen_grad
-  forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse)
-  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, int max_seqlen_q, int max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin)
+  forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse)
+  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin)
   optional : seqused_q, seqused_k
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml
index 6ca22fc2440e8e..e98ab2b61ef3aa 100755
--- a/paddle/phi/ops/yaml/op_compat.yaml
+++ b/paddle/phi/ops/yaml/op_compat.yaml
@@ -1337,6 +1337,16 @@
       data_type : int64_t
       support_tensor : true
 
+- op : flash_attn_v3_varlen
+  backward : flash_attn_v3_varlen_grad
+  scalar :
+    max_seqlen_q :
+      data_type : int64_t
+      support_tensor : true
+    max_seqlen_k :
+      data_type : int64_t
+      support_tensor : true
+
 - op : flash_attn_varlen_qkvpacked
   backward : flash_attn_varlen_qkvpacked_grad
   scalar :
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 3ff346e3dbe608..7965584b7fbc61 100644
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -2113,7 +2113,7 @@
   backward : flash_attn_v3_grad
 
 - op : flash_attn_v3_varlen
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin)
   output : Tensor(out), Tensor(softmax_lse)
   optional : seqused_q, seqused_k, qv, q_descale, k_descale, v_descale
   infer_meta :
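Note: declaring the attributes as Scalar with support_tensor : true in
op_compat.yaml is what lets callers pass either a plain int or a scalar integer
tensor for these arguments. A small illustrative sketch of the two accepted
forms (variable setup is hypothetical, not from the patch):

    import paddle

    # Both forms are now acceptable values for max_seqlen_q / max_seqlen_k;
    # with the previous `int` declaration only the first form was.
    max_seqlen_q = 128                                   # plain Python int
    max_seqlen_q = paddle.to_tensor(128, dtype='int64')  # 0-D int64 tensor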
diff --git a/python/paddle/utils/decorator_utils.py b/python/paddle/utils/decorator_utils.py
index cb22ec87955d54..df546e77df39c2 100644
--- a/python/paddle/utils/decorator_utils.py
+++ b/python/paddle/utils/decorator_utils.py
@@ -30,7 +30,7 @@
 _RetT = TypeVar("_RetT")
 
 
-def _is_in_or_scalar_tensor(x):
+def _is_int_or_scalar_tensor(x):
     if isinstance(x, int):
         return True
     if isinstance(x, (paddle.Tensor, paddle.pir.Value)):
@@ -420,8 +420,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
                 kwargs["shape_or_dtype"] = kwargs.pop("dtype")
             elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs):
                 kwargs["shape_or_dtype"] = kwargs.pop("size")
-            elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
-                if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
+            elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
+                if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
                     kwargs["x"] = args[0]
                     kwargs['shape_or_dtype'] = list(args[1:])
                     args = ()
@@ -552,8 +552,8 @@ def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
         def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
             if ("input" in kwargs) and ("x" not in kwargs):
                 kwargs["x"] = kwargs.pop("input")
-            elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
-                if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
+            elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
+                if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
                     kwargs["x"] = args[0]
                     kwargs['shape'] = list(args[1:])
                     args = ()
@@ -624,8 +624,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
                 kwargs["x"] = kwargs.pop("input")
             if ("size" in kwargs) and ("shape" not in kwargs):
                 kwargs["shape"] = kwargs.pop("size")
-            elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
-                if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
+            elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
+                if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
                     kwargs["x"] = args[0]
                     kwargs['shape'] = list(args[1:])
                     args = ()
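Note: the decorator_utils.py change is a typo fix, renaming
_is_in_or_scalar_tensor to _is_int_or_scalar_tensor at the definition and every
call site; behavior is unchanged. The helper gates the vararg-to-shape dispatch
in the wrappers above. A rough usage sketch (which public APIs carry these
decorators is not visible in this patch, so paddle.reshape below is
illustrative):

    import paddle

    x = paddle.rand([6])
    # A decorated API packs trailing int/scalar-tensor positional args into
    # kwargs['shape'], so the two spellings below are equivalent:
    y1 = paddle.reshape(x, [2, 3])  # explicit shape list
    y2 = paddle.reshape(x, 2, 3)    # varargs, each satisfying
                                    # _is_int_or_scalar_tensor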