From 9feae38c4950939a797eb3322082663a6de64659 Mon Sep 17 00:00:00 2001
From: lilujia
Date: Mon, 5 Jan 2026 12:07:51 +0800
Subject: [PATCH] [XPU] update xpu fft version, supporting stream parameter

---
 cmake/external/xpu.cmake                      |  2 +-
 paddle/phi/kernels/funcs/fft_fill_conj_xpu.h  | 14 +++----
 paddle/phi/kernels/funcs/fft_xpu.cc           |  5 ++-
 paddle/phi/kernels/xpu/complex_grad_kernel.cc | 33 ++++++++-------
 paddle/phi/kernels/xpu/complex_kernel.cc      | 41 ++++++++++---------
 paddle/phi/kernels/xpu/elementwise_kernel.cc  | 10 +++--
 6 files changed, 55 insertions(+), 50 deletions(-)

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index d4fc8f60cc51a3..5ba3a3fef2054e 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -73,7 +73,7 @@ if(NOT DEFINED XPU_FFT_BASE_DATE)
   if(WITH_ARM)
     set(XPU_FFT_BASE_DATE "20251017")
   else()
-    set(XPU_FFT_BASE_DATE "20251226")
+    set(XPU_FFT_BASE_DATE "20251230")
   endif()
 endif()
 
diff --git a/paddle/phi/kernels/funcs/fft_fill_conj_xpu.h b/paddle/phi/kernels/funcs/fft_fill_conj_xpu.h
index 3190b96298bd5a..880bfbf12ce84d 100644
--- a/paddle/phi/kernels/funcs/fft_fill_conj_xpu.h
+++ b/paddle/phi/kernels/funcs/fft_fill_conj_xpu.h
@@ -23,7 +23,8 @@ namespace xfft_internal::xpu {
 
 template <typename T>  // T supports float2, double2
-int FFTFillConj(int64_t N,
+int FFTFillConj(const XPUStream stream,
+                int64_t N,
                 const T* src_data,
                 T* dst_data,
                 const int64_t* src_strides,
@@ -35,7 +36,8 @@ int FFTFillConj(int64_t N,
                 int64_t rank);
 
 template <typename T>  // T supports float2, double2
-int FFTFillConjGrad(int N,
+int FFTFillConjGrad(const XPUStream stream,
+                    int N,
                     T* input,
                     int64_t axis,
                     int64_t stride_second_to_last_axis,
@@ -97,9 +99,8 @@ void FFTFillConj(const DeviceContext& dev_ctx,
              _is_fft_axis.get(),
              rank * sizeof(bool),
              XPUMemcpyKind::XPU_HOST_TO_DEVICE);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::FFTFillConj(
+      dev_ctx.x_context()->xpu_stream,
       dst->numel(),
       reinterpret_cast(src_data),
       reinterpret_cast(dst_data),
@@ -111,7 +112,6 @@ void FFTFillConj(const DeviceContext& dev_ctx,
       static_cast<int64_t>(last_axis_size),
       static_cast<int64_t>(rank));
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 template
@@ -127,9 +127,8 @@ void FFTFillConjGrad(const DeviceContext& dev_ctx,
     stride_to_last_axis *= ddim[i + 1];
   }
   int64_t stride_second_to_last_axis = stride_to_last_axis * ddim[axes.back()];
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::FFTFillConjGrad(
+      dev_ctx.x_context()->xpu_stream,
       x_grad->numel(),
       reinterpret_cast(x_grad->data()),
       axes.back(),
@@ -137,7 +136,6 @@ void FFTFillConjGrad(const DeviceContext& dev_ctx,
       stride_to_last_axis,
       double_length);
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 }  // namespace funcs
diff --git a/paddle/phi/kernels/funcs/fft_xpu.cc b/paddle/phi/kernels/funcs/fft_xpu.cc
index 9ee6b1e65af4e5..293ef363db1ef9 100644
--- a/paddle/phi/kernels/funcs/fft_xpu.cc
+++ b/paddle/phi/kernels/funcs/fft_xpu.cc
@@ -154,8 +154,9 @@ void exec_fft(const phi::XPUContext& dev_ctx,
   DenseTensor workspace_tensor = Empty(dev_ctx, {workspace_size});
 
   // prepare cufft for execution
-  PADDLE_ENFORCE_FFT_SUCCESS(
-      phi::dynload::cufftSetStream(config->plan(), nullptr));
+  PADDLE_ENFORCE_FFT_SUCCESS(phi::dynload::cufftSetStream(
+      config->plan(),
+      reinterpret_cast(dev_ctx.x_context()->xpu_stream)));
   PADDLE_ENFORCE_FFT_SUCCESS(
       phi::dynload::cufftSetWorkArea(config->plan(), workspace_tensor.data()));
diff --git a/paddle/phi/kernels/xpu/complex_grad_kernel.cc b/paddle/phi/kernels/xpu/complex_grad_kernel.cc
index 4edc37db1becb4..4a98135b33d815 100644
--- a/paddle/phi/kernels/xpu/complex_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/complex_grad_kernel.cc
@@ -25,24 +25,33 @@ namespace xfft_internal::xpu {
 
 // just for declaration here, the real implementation is in libcufft.so
 template <typename T, typename TComplex>
-int combine_as_complex(int N, const T* real, const T* imag, TComplex* out);
+int combine_as_complex(
+    const XPUStream stream, int N, const T* real, const T* imag, TComplex* out);
 template <>
-int combine_as_complex(int N,
+int combine_as_complex(const XPUStream stream,
+                       int N,
                        const float* real,
                        const float* imag,
                        float2* out);
 template <>
-int combine_as_complex(int N,
+int combine_as_complex(const XPUStream stream,
+                       int N,
                        const double* real,
                        const double* imag,
                        double2* out);
 
 template <typename T, typename TComplex>
-int complex_spilt(int N, const TComplex* in, T* real, T* imag);
+int complex_spilt(
+    const XPUStream stream, int N, const TComplex* in, T* real, T* imag);
 template <>
-int complex_spilt(int N, const float2* in, float* real, float* imag);
+int complex_spilt(
+    const XPUStream stream, int N, const float2* in, float* real, float* imag);
 template <>
-int complex_spilt(int N, const double2* in, double* real, double* imag);
+int complex_spilt(const XPUStream stream,
+                  int N,
+                  const double2* in,
+                  double* real,
+                  double* imag);
 }  // namespace xfft_internal::xpu
 
 namespace phi {
@@ -73,16 +82,14 @@ void RealGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
   DenseTensor imag = Fill<phi::dtype::Real<T>, Context>(
       dev_ctx, common::vectorize(dout.dims()), phi::dtype::Real<T>(0.0));
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::combine_as_complex(
+      dev_ctx.x_context()->xpu_stream,
       numel,
       reinterpret_cast<const phi::dtype::Real<T>*>(
           dout.data<phi::dtype::Real<T>>()),
       imag.data<phi::dtype::Real<T>>(),
       reinterpret_cast(dx_data));
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 template <typename T, typename Context>
@@ -100,16 +107,14 @@ void ImagGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
   DenseTensor real = Fill<phi::dtype::Real<T>, Context>(
       dev_ctx, common::vectorize(dout.dims()), phi::dtype::Real<T>(0.0));
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::combine_as_complex(
+      dev_ctx.x_context()->xpu_stream,
       numel,
       real.data<phi::dtype::Real<T>>(),
       reinterpret_cast<const phi::dtype::Real<T>*>(
           dout.data<phi::dtype::Real<T>>()),
       reinterpret_cast(dx_data));
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 template <typename T, typename Context>
@@ -146,15 +151,13 @@ void ComplexGradKernel(const Context& dev_ctx,
   imag_dout.Resize(dout.dims());
   T* real_data = dev_ctx.template Alloc<T>(&real_dout);
   T* imag_data = dev_ctx.template Alloc<T>(&imag_dout);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::complex_spilt(
+      dev_ctx.x_context()->xpu_stream,
       numel,
       reinterpret_cast(dout.data()),
       real_data,
       imag_data);
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
   if (dx) {
     if (x.dims() == dout.dims()) {
       dx->ShareDataWith(real_dout);
diff --git a/paddle/phi/kernels/xpu/complex_kernel.cc b/paddle/phi/kernels/xpu/complex_kernel.cc
index 951ab2e140e9d8..56bdf8f3204d5e 100644
--- a/paddle/phi/kernels/xpu/complex_kernel.cc
+++ b/paddle/phi/kernels/xpu/complex_kernel.cc
@@ -25,27 +25,36 @@ namespace xfft_internal::xpu {
 
 // just for declaration here, the real implementation is in libcufft.so
 template <typename T, typename TComplex>
-int combine_as_complex(int N, const T* real, const T* imag, TComplex* out);
+int combine_as_complex(
+    const XPUStream stream, int N, const T* real, const T* imag, TComplex* out);
 template <>
-int combine_as_complex(int N,
+int combine_as_complex(const XPUStream stream,
+                       int N,
                        const float* real,
                        const float* imag,
                        float2* out);
 template <>
-int combine_as_complex(int N,
+int combine_as_complex(const XPUStream stream,
+                       int N,
                        const double* real,
                        const double* imag,
                        double2* out);
 
 template <typename T, typename TComplex>
-int complex_spilt(int N, const TComplex* in, T* real, T* imag);
+int complex_spilt(
+    const XPUStream stream, int N, const TComplex* in, T* real, T* imag);
 template <>
-int complex_spilt(int N, const float2* in, float* real, float* imag);
+int complex_spilt(
+    const XPUStream stream, int N, const float2* in, float* real, float* imag);
 template <>
-int complex_spilt(int N, const double2* in, double* real, double* imag);
+int complex_spilt(const XPUStream stream,
+                  int N,
+                  const double2* in,
+                  double* real,
+                  double* imag);
 
 template <typename T>  // T supports float2, double2
-int Conj(int N, const T* input, T* output);
+int Conj(const XPUStream stream, int N, const T* input, T* output);
 }  // namespace xfft_internal::xpu
 
 namespace phi {
@@ -59,21 +68,19 @@ void ConjKernel(const Context& dev_ctx,
   }
   dev_ctx.template Alloc<T>(out);
   if (std::is_same_v) {
-    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
     int r = xfft_internal::xpu::Conj(
+        dev_ctx.x_context()->xpu_stream,
         x.numel(),
         reinterpret_cast(x.data()),
         reinterpret_cast(out->data()));
     PADDLE_ENFORCE_XPU_SUCCESS(r);
-    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
   } else if (std::is_same_v) {
     int r = xfft_internal::xpu::Conj(
+        dev_ctx.x_context()->xpu_stream,
         x.numel(),
         reinterpret_cast(x.data()),
         reinterpret_cast(out->data()));
     PADDLE_ENFORCE_XPU_SUCCESS(r);
-    PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
   } else {
     using XPUType = typename XPUCopyTypeTrait<T>::Type;
     const auto* input_data = x.data();
@@ -100,15 +107,13 @@ void RealKernel(const Context& dev_ctx,
   DenseTensor imag;
   imag.Resize(x.dims());
   dev_ctx.template Alloc<phi::dtype::Real<T>>(&imag);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::complex_spilt(
+      dev_ctx.x_context()->xpu_stream,
       out->numel(),
       reinterpret_cast(x.data()),
       out->data<phi::dtype::Real<T>>(),
       imag.data<phi::dtype::Real<T>>());
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 template <typename T, typename Context>
@@ -126,15 +131,13 @@ void ImagKernel(const Context& dev_ctx,
   DenseTensor real;
   real.Resize(x.dims());
   dev_ctx.template Alloc<phi::dtype::Real<T>>(&real);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::complex_spilt(
+      dev_ctx.x_context()->xpu_stream,
       out->numel(),
       reinterpret_cast(x.data()),
       real.data<phi::dtype::Real<T>>(),
       out->data<phi::dtype::Real<T>>());
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 template <typename T, typename Context>
@@ -178,15 +181,13 @@ void ComplexKernel(const Context& dev_ctx,
   }
 
   dev_ctx.template Alloc<T>(out);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::combine_as_complex(
+      dev_ctx.x_context()->xpu_stream,
       out->numel(),
       x_data,
       y_data,
       reinterpret_cast(out->data()));
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/elementwise_kernel.cc b/paddle/phi/kernels/xpu/elementwise_kernel.cc
index 3f66dca98a834b..4b04975af1ae13 100644
--- a/paddle/phi/kernels/xpu/elementwise_kernel.cc
+++ b/paddle/phi/kernels/xpu/elementwise_kernel.cc
@@ -25,7 +25,11 @@
 #include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
 namespace xfft_internal::xpu {
 template <typename T>  // T supports float2, double2
-int RemainderFunctor(int N, const T* input_x, const T* input_y, T* output);
+int RemainderFunctor(const XPUStream stream,
+                     int N,
+                     const T* input_x,
+                     const T* input_y,
+                     T* output);
 }
 #endif
 
@@ -146,15 +150,13 @@ void RemainderKernel(const XPUContext& dev_ctx,
   }
 
   dev_ctx.template Alloc<T>(out);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
   int r = xfft_internal::xpu::RemainderFunctor(
+      dev_ctx.x_context()->xpu_stream,
       out->numel(),
       reinterpret_cast(x_data),
       reinterpret_cast(y_data),
       reinterpret_cast(out->data()));
   PADDLE_ENFORCE_XPU_SUCCESS(r);
-  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
 }
 #endif
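
With the 20251230 xfft drop, every xfft_internal::xpu entry point takes the XPU stream as
its first argument and launches on that stream, which is why the xpu_wait() barriers that
used to bracket each call are dropped above. A minimal sketch of the updated call pattern,
assuming a kernel body that already holds dev_ctx and float2-compatible device buffers
(the cast targets and buffer names here are illustrative, not an excerpt from one call site):

    // Launch asynchronously on the context's stream; no xpu_wait() before or after.
    int r = xfft_internal::xpu::Conj(dev_ctx.x_context()->xpu_stream,
                                     x.numel(),
                                     reinterpret_cast<const float2*>(x_data),
                                     reinterpret_cast<float2*>(out_data));
    PADDLE_ENFORCE_XPU_SUCCESS(r);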