Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ if(NOT DEFINED XPU_FFT_BASE_DATE)
if(WITH_ARM)
set(XPU_FFT_BASE_DATE "20251017")
else()
set(XPU_FFT_BASE_DATE "20251226")
set(XPU_FFT_BASE_DATE "20251230")
endif()
endif()

Expand Down
14 changes: 6 additions & 8 deletions paddle/phi/kernels/funcs/fft_fill_conj_xpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

namespace xfft_internal::xpu {
template <typename T> // T supports float2, double2
int FFTFillConj(int64_t N,
int FFTFillConj(const XPUStream stream,
int64_t N,
const T* src_data,
T* dst_data,
const int64_t* src_strides,
Expand All @@ -35,7 +36,8 @@ int FFTFillConj(int64_t N,
int64_t rank);

template <typename T> // T supports float2, double2
int FFTFillConjGrad(int N,
int FFTFillConjGrad(const XPUStream stream,
int N,
T* input,
int64_t axis,
int64_t stride_second_to_last_axis,
Expand Down Expand Up @@ -97,9 +99,8 @@ void FFTFillConj(const DeviceContext& dev_ctx,
_is_fft_axis.get(),
rank * sizeof(bool),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::FFTFillConj(
dev_ctx.x_context()->xpu_stream,
dst->numel(),
reinterpret_cast<cuFloatComplex*>(src_data),
reinterpret_cast<cuFloatComplex*>(dst_data),
Expand All @@ -111,7 +112,6 @@ void FFTFillConj(const DeviceContext& dev_ctx,
static_cast<int64_t>(last_axis_size),
static_cast<int64_t>(rank));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}

template <typename DeviceContext, typename C>
Expand All @@ -127,17 +127,15 @@ void FFTFillConjGrad(const DeviceContext& dev_ctx,
stride_to_last_axis *= ddim[i + 1];
}
int64_t stride_second_to_last_axis = stride_to_last_axis * ddim[axes.back()];
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::FFTFillConjGrad(
dev_ctx.x_context()->xpu_stream,
x_grad->numel(),
reinterpret_cast<cuFloatComplex*>(x_grad->data<C>()),
axes.back(),
stride_second_to_last_axis,
stride_to_last_axis,
double_length);
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}

} // namespace funcs
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/funcs/fft_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ void exec_fft(const phi::XPUContext& dev_ctx,
DenseTensor workspace_tensor = Empty<uint8_t>(dev_ctx, {workspace_size});

// prepare cufft for execution
PADDLE_ENFORCE_FFT_SUCCESS(
phi::dynload::cufftSetStream(config->plan(), nullptr));
PADDLE_ENFORCE_FFT_SUCCESS(phi::dynload::cufftSetStream(
config->plan(),
reinterpret_cast<cudaStream_t>(dev_ctx.x_context()->xpu_stream)));
PADDLE_ENFORCE_FFT_SUCCESS(
phi::dynload::cufftSetWorkArea(config->plan(), workspace_tensor.data()));

Expand Down
33 changes: 18 additions & 15 deletions paddle/phi/kernels/xpu/complex_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,33 @@
namespace xfft_internal::xpu {
// just for declaration here, the real implementation is in libcufft.so
template <typename T, typename TComplex>
int combine_as_complex(int N, const T* real, const T* imag, TComplex* out);
int combine_as_complex(
const XPUStream stream, int N, const T* real, const T* imag, TComplex* out);
template <>
int combine_as_complex(int N,
int combine_as_complex(const XPUStream stream,
int N,
const float* real,
const float* imag,
float2* out);
template <>
int combine_as_complex(int N,
int combine_as_complex(const XPUStream stream,
int N,
const double* real,
const double* imag,
double2* out);

template <typename TComplex, typename T>
int complex_spilt(int N, const TComplex* in, T* real, T* imag);
int complex_spilt(
const XPUStream stream, int N, const TComplex* in, T* real, T* imag);
template <>
int complex_spilt(int N, const float2* in, float* real, float* imag);
int complex_spilt(
const XPUStream stream, int N, const float2* in, float* real, float* imag);
template <>
int complex_spilt(int N, const double2* in, double* real, double* imag);
int complex_spilt(const XPUStream stream,
int N,
const double2* in,
double* real,
double* imag);
} // namespace xfft_internal::xpu

namespace phi {
Expand Down Expand Up @@ -73,16 +82,14 @@ void RealGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
DenseTensor imag = Fill<phi::dtype::Real<T>, Context>(
dev_ctx, common::vectorize<int>(dout.dims()), phi::dtype::Real<T>(0.0));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::combine_as_complex(
dev_ctx.x_context()->xpu_stream,
numel,
reinterpret_cast<const phi::dtype::Real<T>*>(
dout.data<phi::dtype::Real<T>>()),
imag.data<phi::dtype::Real<T>>(),
reinterpret_cast<XPUComplexType*>(dx_data));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}

template <typename T, typename Context>
Expand All @@ -100,16 +107,14 @@ void ImagGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
DenseTensor real = Fill<phi::dtype::Real<T>, Context>(
dev_ctx, common::vectorize<int>(dout.dims()), phi::dtype::Real<T>(0.0));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::combine_as_complex(
dev_ctx.x_context()->xpu_stream,
numel,
real.data<phi::dtype::Real<T>>(),
reinterpret_cast<const phi::dtype::Real<T>*>(
dout.data<phi::dtype::Real<T>>()),
reinterpret_cast<XPUComplexType*>(dx_data));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}

template <typename T, typename Context>
Expand Down Expand Up @@ -146,15 +151,13 @@ void ComplexGradKernel(const Context& dev_ctx,
imag_dout.Resize(dout.dims());
T* real_data = dev_ctx.template Alloc<T>(&real_dout);
T* imag_data = dev_ctx.template Alloc<T>(&imag_dout);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::complex_spilt(
dev_ctx.x_context()->xpu_stream,
numel,
reinterpret_cast<const XPUComplexType*>(dout.data<C>()),
real_data,
imag_data);
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
if (dx) {
if (x.dims() == dout.dims()) {
dx->ShareDataWith(real_dout);
Expand Down
41 changes: 21 additions & 20 deletions paddle/phi/kernels/xpu/complex_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,36 @@
namespace xfft_internal::xpu {
// just for declaration here, the real implementation is in libcufft.so
template <typename T, typename TComplex>
int combine_as_complex(int N, const T* real, const T* imag, TComplex* out);
int combine_as_complex(
const XPUStream stream, int N, const T* real, const T* imag, TComplex* out);
template <>
int combine_as_complex(int N,
int combine_as_complex(const XPUStream stream,
int N,
const float* real,
const float* imag,
float2* out);
template <>
int combine_as_complex(int N,
int combine_as_complex(const XPUStream stream,
int N,
const double* real,
const double* imag,
double2* out);

template <typename TComplex, typename T>
int complex_spilt(int N, const TComplex* in, T* real, T* imag);
int complex_spilt(
const XPUStream stream, int N, const TComplex* in, T* real, T* imag);
template <>
int complex_spilt(int N, const float2* in, float* real, float* imag);
int complex_spilt(
const XPUStream stream, int N, const float2* in, float* real, float* imag);
template <>
int complex_spilt(int N, const double2* in, double* real, double* imag);
int complex_spilt(const XPUStream stream,
int N,
const double2* in,
double* real,
double* imag);

template <typename T> // T supports float2, double2
int Conj(int N, const T* input, T* output);
int Conj(const XPUStream stream, int N, const T* input, T* output);
} // namespace xfft_internal::xpu

namespace phi {
Expand All @@ -59,21 +68,19 @@ void ConjKernel(const Context& dev_ctx,
}
dev_ctx.template Alloc<T>(out);
if (std::is_same_v<T, phi::complex64>) {
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::Conj(
dev_ctx.x_context()->xpu_stream,
x.numel(),
reinterpret_cast<const cuFloatComplex*>(x.data<T>()),
reinterpret_cast<cuFloatComplex*>(out->data<T>()));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
} else if (std::is_same_v<T, phi::complex128>) {
int r = xfft_internal::xpu::Conj(
dev_ctx.x_context()->xpu_stream,
x.numel(),
reinterpret_cast<const cuDoubleComplex*>(x.data<T>()),
reinterpret_cast<cuDoubleComplex*>(out->data<T>()));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
} else {
using XPUType = typename XPUCopyTypeTrait<T>::Type;
const auto* input_data = x.data<T>();
Expand All @@ -100,15 +107,13 @@ void RealKernel(const Context& dev_ctx,
DenseTensor imag;
imag.Resize(x.dims());
dev_ctx.template Alloc<phi::dtype::Real<T>>(&imag);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::complex_spilt(
dev_ctx.x_context()->xpu_stream,
out->numel(),
reinterpret_cast<const XPUComplexType*>(x.data<T>()),
out->data<phi::dtype::Real<T>>(),
imag.data<phi::dtype::Real<T>>());
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}

template <typename T, typename Context>
Expand All @@ -126,15 +131,13 @@ void ImagKernel(const Context& dev_ctx,
DenseTensor real;
real.Resize(x.dims());
dev_ctx.template Alloc<phi::dtype::Real<T>>(&real);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::complex_spilt(
dev_ctx.x_context()->xpu_stream,
out->numel(),
reinterpret_cast<const XPUComplexType*>(x.data<T>()),
real.data<phi::dtype::Real<T>>(),
out->data<phi::dtype::Real<T>>());
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}

template <typename T, typename Context>
Expand Down Expand Up @@ -178,15 +181,13 @@ void ComplexKernel(const Context& dev_ctx,
}

dev_ctx.template Alloc<C>(out);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::combine_as_complex(
dev_ctx.x_context()->xpu_stream,
out->numel(),
x_data,
y_data,
reinterpret_cast<XPUComplexType*>(out->data<C>()));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}
} // namespace phi

Expand Down
10 changes: 6 additions & 4 deletions paddle/phi/kernels/xpu/elementwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
namespace xfft_internal::xpu {
template <typename T> // T supports float2, double2
int RemainderFunctor(int N, const T* input_x, const T* input_y, T* output);
int RemainderFunctor(const XPUStream stream,
int N,
const T* input_x,
const T* input_y,
T* output);
}
#endif

Expand Down Expand Up @@ -146,15 +150,13 @@ void RemainderKernel<phi::complex64, XPUContext>(const XPUContext& dev_ctx,
}

dev_ctx.template Alloc<T>(out);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
int r = xfft_internal::xpu::RemainderFunctor(
dev_ctx.x_context()->xpu_stream,
out->numel(),
reinterpret_cast<const cuFloatComplex*>(x_data),
reinterpret_cast<const cuFloatComplex*>(y_data),
reinterpret_cast<cuFloatComplex*>(out->data<T>()));
PADDLE_ENFORCE_XPU_SUCCESS(r);
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
}
#endif

Expand Down
Loading