1414
1515#ifdef PADDLE_WITH_XPU_FFT
1616#include " paddle/phi/kernels/complex_kernel.h"
17-
1817#include " fft/cuComplex.h"
1918#include " paddle/phi/backends/xpu/enforce_xpu.h"
2019#include " paddle/phi/common/type_traits.h"
@@ -68,6 +67,13 @@ void ConjKernel(const Context& dev_ctx,
6867 reinterpret_cast <cuFloatComplex*>(out->data <T>()));
6968 PADDLE_ENFORCE_XPU_SUCCESS (r);
7069 PADDLE_ENFORCE_XPU_SUCCESS (xpu_wait ());
70+ } else if (std::is_same_v<T, phi::complex128>) {
71+ int r = xfft_internal::xpu::Conj (
72+ x.numel (),
73+ reinterpret_cast <const cuDoubleComplex*>(x.data <T>()),
74+ reinterpret_cast <cuDoubleComplex*>(out->data <T>()));
75+ PADDLE_ENFORCE_XPU_SUCCESS (r);
76+ PADDLE_ENFORCE_XPU_SUCCESS (xpu_wait ());
7177 } else {
7278 using XPUType = typename XPUCopyTypeTrait<T>::Type;
7379 const auto * input_data = x.data <T>();
@@ -83,6 +89,8 @@ template <typename T, typename Context>
8389void RealKernel (const Context& dev_ctx,
8490 const DenseTensor& x,
8591 DenseTensor* out) {
92+ using XPUComplexType =
93+ typename XPUComplexTypeTrait<phi::dtype::Real<T>>::Type;
8694 if (out->numel () == 0 ) {
8795 dev_ctx.template Alloc <phi::dtype::Real<T>>(out);
8896 return ;
@@ -96,7 +104,7 @@ void RealKernel(const Context& dev_ctx,
96104 PADDLE_ENFORCE_XPU_SUCCESS (xpu_wait (dev_ctx.x_context ()->xpu_stream ));
97105 int r = xfft_internal::xpu::complex_spilt (
98106 out->numel (),
99- reinterpret_cast <const cuFloatComplex *>(x.data <T>()),
107+ reinterpret_cast <const XPUComplexType *>(x.data <T>()),
100108 out->data <phi::dtype::Real<T>>(),
101109 imag.data <phi::dtype::Real<T>>());
102110 PADDLE_ENFORCE_XPU_SUCCESS (r);
@@ -107,6 +115,8 @@ template <typename T, typename Context>
107115void ImagKernel (const Context& dev_ctx,
108116 const DenseTensor& x,
109117 DenseTensor* out) {
118+ using XPUComplexType =
119+ typename XPUComplexTypeTrait<phi::dtype::Real<T>>::Type;
110120 if (out->numel () == 0 ) {
111121 dev_ctx.template Alloc <phi::dtype::Real<T>>(out);
112122 return ;
@@ -120,7 +130,7 @@ void ImagKernel(const Context& dev_ctx,
120130 PADDLE_ENFORCE_XPU_SUCCESS (xpu_wait (dev_ctx.x_context ()->xpu_stream ));
121131 int r = xfft_internal::xpu::complex_spilt (
122132 out->numel (),
123- reinterpret_cast <const cuFloatComplex *>(x.data <T>()),
133+ reinterpret_cast <const XPUComplexType *>(x.data <T>()),
124134 real.data <phi::dtype::Real<T>>(),
125135 out->data <phi::dtype::Real<T>>());
126136 PADDLE_ENFORCE_XPU_SUCCESS (r);
@@ -133,6 +143,7 @@ void ComplexKernel(const Context& dev_ctx,
133143 const DenseTensor& y,
134144 DenseTensor* out) {
135145 using C = phi::dtype::complex <T>;
146+ using XPUComplexType = typename XPUComplexTypeTrait<T>::Type;
136147 if (out->numel () == 0 ) {
137148 dev_ctx.template Alloc <C>(out);
138149 return ;
@@ -173,7 +184,7 @@ void ComplexKernel(const Context& dev_ctx,
173184 out->numel (),
174185 x_data,
175186 y_data,
176- reinterpret_cast <cuFloatComplex *>(out->data <C>()));
187+ reinterpret_cast <XPUComplexType *>(out->data <C>()));
177188 PADDLE_ENFORCE_XPU_SUCCESS (r);
178189 PADDLE_ENFORCE_XPU_SUCCESS (xpu_wait ());
179190}
@@ -190,17 +201,21 @@ PD_REGISTER_KERNEL(conj,
190201 double ,
191202 phi::float16,
192203 phi::bfloat16,
193- phi::complex64) {}
204+ phi::complex64,
205+ phi::complex128) {}
194206
195- PD_REGISTER_KERNEL (real, XPU, ALL_LAYOUT, phi::RealKernel, phi::complex64) {
207+ PD_REGISTER_KERNEL (
208+ real, XPU, ALL_LAYOUT, phi::RealKernel, phi::complex64, phi::complex128) {
196209 kernel->OutputAt (0 ).SetDataType (phi::dtype::ToReal (kernel_key.dtype ()));
197210}
198211
199- PD_REGISTER_KERNEL (imag, XPU, ALL_LAYOUT, phi::ImagKernel, phi::complex64) {
212+ PD_REGISTER_KERNEL (
213+ imag, XPU, ALL_LAYOUT, phi::ImagKernel, phi::complex64, phi::complex128) {
200214 kernel->OutputAt (0 ).SetDataType (phi::dtype::ToReal (kernel_key.dtype ()));
201215}
202216
203- PD_REGISTER_KERNEL (complex , XPU, ALL_LAYOUT, phi::ComplexKernel, float ) {
217+ PD_REGISTER_KERNEL (
218+ complex , XPU, ALL_LAYOUT, phi::ComplexKernel, float , double ) {
204219 kernel->OutputAt (0 ).SetDataType (phi::dtype::ToComplex (kernel_key.dtype ()));
205220}
206221#endif // PADDLE_WITH_XPU_FFT
0 commit comments