Skip to content

Commit ea557ba

Browse files
authored
[XPU] support complex128(complex double) (#74161)
1 parent 76860dc commit ea557ba

File tree

11 files changed

+199
-45
lines changed

11 files changed

+199
-45
lines changed

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,6 +1902,7 @@ XPUOpMap& get_kl3_ops() {
19021902
{"expand_v2_grad",
19031903
XPUKernelSet({phi::DataType::FLOAT32,
19041904
phi::DataType::INT64,
1905+
phi::DataType::FLOAT64,
19051906
phi::DataType::BFLOAT16,
19061907
phi::DataType::FLOAT16})},
19071908
{"eye",
@@ -2043,12 +2044,18 @@ XPUOpMap& get_kl3_ops() {
20432044
phi::DataType::INT64,
20442045
phi::DataType::INT32,
20452046
phi::DataType::COMPLEX64})},
2046-
{"real", XPUKernelSet({phi::DataType::COMPLEX64})},
2047-
{"real_grad", XPUKernelSet({phi::DataType::COMPLEX64})},
2048-
{"imag", XPUKernelSet({phi::DataType::COMPLEX64})},
2049-
{"imag_grad", XPUKernelSet({phi::DataType::COMPLEX64})},
2050-
{"complex", XPUKernelSet({phi::DataType::FLOAT32})},
2051-
{"complex_grad", XPUKernelSet({phi::DataType::FLOAT32})},
2047+
{"real",
2048+
XPUKernelSet({phi::DataType::COMPLEX64, phi::DataType::COMPLEX128})},
2049+
{"real_grad",
2050+
XPUKernelSet({phi::DataType::COMPLEX64, phi::DataType::COMPLEX128})},
2051+
{"imag",
2052+
XPUKernelSet({phi::DataType::COMPLEX64, phi::DataType::COMPLEX128})},
2053+
{"imag_grad",
2054+
XPUKernelSet({phi::DataType::COMPLEX64, phi::DataType::COMPLEX128})},
2055+
{"complex",
2056+
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
2057+
{"complex_grad",
2058+
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
20522059
{"as_complex", XPUKernelSet({phi::DataType::FLOAT32})},
20532060
{"as_real", XPUKernelSet({phi::DataType::COMPLEX64})},
20542061
{"fft_c2c", XPUKernelSet({phi::DataType::COMPLEX64})},

paddle/phi/backends/xpu/xpu_header.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#ifdef PADDLE_WITH_XPU
1817
#include <map>
1918
#include <string>
2019
#include <unordered_map>
2120

2221
#include "paddle/phi/common/bfloat16.h"
22+
#include "paddle/phi/common/complex.h"
2323
#include "paddle/phi/common/float16.h"
2424
#ifdef PADDLE_WITH_XPU_BKCL
2525
#include "xpu/bkcl.h"
@@ -30,7 +30,9 @@ limitations under the License. */
3030
#ifdef PADDLE_WITH_XPU_PLUGIN
3131
#include "xpu/plugin.h"
3232
#endif
33-
33+
#ifdef PADDLE_WITH_XPU_FFT
34+
#include "fft/cuComplex.h"
35+
#endif
3436
namespace xpu = baidu::xpu::api;
3537

3638
template <typename T>
@@ -107,4 +109,22 @@ class XPUCopyTypeTrait<uint8_t> {
107109
using Type = int8_t;
108110
};
109111

112+
#ifdef PADDLE_WITH_XPU_FFT
113+
template <typename T>
114+
class XPUComplexTypeTrait {
115+
public:
116+
using Type = T;
117+
};
118+
119+
template <>
120+
class XPUComplexTypeTrait<float> {
121+
public:
122+
using Type = cuFloatComplex;
123+
};
124+
125+
template <>
126+
class XPUComplexTypeTrait<double> {
127+
public:
128+
using Type = cuDoubleComplex;
129+
};
110130
#endif

paddle/phi/core/visit_type.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,13 @@ namespace phi {
374374
}()
375375

376376
#ifdef PADDLE_WITH_XPU_FFT
377-
#define PD_XPU_COMPLEX64_CASE(NAME, ...) \
378-
PD_PRIVATE_CASE_TYPE( \
379-
NAME, ::phi::DataType::COMPLEX64, phi::complex64, __VA_ARGS__)
377+
#define PD_XPU_COMPLEX_CASE(NAME, ...) \
378+
PD_PRIVATE_CASE_TYPE( \
379+
NAME, ::phi::DataType::COMPLEX64, phi::complex64, __VA_ARGS__) \
380+
PD_PRIVATE_CASE_TYPE( \
381+
NAME, ::phi::DataType::COMPLEX128, phi::complex128, __VA_ARGS__)
380382
#else
381-
#define PD_XPU_COMPLEX64_CASE(NAME, ...)
383+
#define PD_XPU_COMPLEX_CASE(NAME, ...)
382384
#endif
383385

384386
#if defined(PADDLE_WITH_XPU)
@@ -399,7 +401,7 @@ namespace phi {
399401
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \
400402
PD_PRIVATE_CASE_TYPE( \
401403
NAME, ::phi::DataType::FLOAT64, double, __VA_ARGS__) \
402-
PD_XPU_COMPLEX64_CASE(NAME, __VA_ARGS__) \
404+
PD_XPU_COMPLEX_CASE(NAME, __VA_ARGS__) \
403405
default: \
404406
PADDLE_THROW(common::errors::InvalidArgument( \
405407
"Invalid enum data type `%d`.", static_cast<int>(__dtype__))); \

paddle/phi/kernels/xpu/as_complex_kernel.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ void AsComplexKernel(const Context& dev_ctx,
3636

3737
} // namespace phi
3838

39-
PD_REGISTER_KERNEL(as_complex, XPU, ALL_LAYOUT, phi::AsComplexKernel, float) {
39+
PD_REGISTER_KERNEL(
40+
as_complex, XPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {
4041
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
4142
}
4243
#endif // PADDLE_WITH_XPU_FFT

paddle/phi/kernels/xpu/as_real_kernel.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ void AsRealKernel(const Context& dev_ctx,
3737

3838
} // namespace phi
3939

40-
PD_REGISTER_KERNEL(
41-
as_real, XPU, ALL_LAYOUT, phi::AsRealKernel, phi::complex64) {
40+
PD_REGISTER_KERNEL(as_real,
41+
XPU,
42+
ALL_LAYOUT,
43+
phi::AsRealKernel,
44+
phi::complex64,
45+
phi::complex128) {
4246
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
4347
}
4448
#endif // PADDLE_WITH_XPU_FFT

paddle/phi/kernels/xpu/complex_grad_kernel.cc

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#ifdef PADDLE_WITH_XPU_FFT
1616
#include "paddle/phi/kernels/complex_grad_kernel.h"
17-
1817
#include "fft/cuComplex.h"
1918
#include "paddle/phi/backends/xpu/enforce_xpu.h"
2019
#include "paddle/phi/common/type_traits.h"
@@ -63,6 +62,8 @@ template <typename T, typename Context>
6362
void RealGradKernel(const Context& dev_ctx,
6463
const DenseTensor& dout,
6564
DenseTensor* dx) {
65+
using XPUComplexType =
66+
typename XPUComplexTypeTrait<phi::dtype::Real<T>>::Type;
6667
if (dx && dx->numel() == 0) {
6768
dev_ctx.template Alloc<T>(dx);
6869
return;
@@ -79,7 +80,7 @@ void RealGradKernel(const Context& dev_ctx,
7980
reinterpret_cast<const phi::dtype::Real<T>*>(
8081
dout.data<phi::dtype::Real<T>>()),
8182
imag.data<phi::dtype::Real<T>>(),
82-
reinterpret_cast<cuFloatComplex*>(dx_data));
83+
reinterpret_cast<XPUComplexType*>(dx_data));
8384
PADDLE_ENFORCE_XPU_SUCCESS(r);
8485
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
8586
}
@@ -88,6 +89,8 @@ template <typename T, typename Context>
8889
void ImagGradKernel(const Context& dev_ctx,
8990
const DenseTensor& dout,
9091
DenseTensor* dx) {
92+
using XPUComplexType =
93+
typename XPUComplexTypeTrait<phi::dtype::Real<T>>::Type;
9194
if (dx && dx->numel() == 0) {
9295
dev_ctx.template Alloc<T>(dx);
9396
return;
@@ -104,7 +107,7 @@ void ImagGradKernel(const Context& dev_ctx,
104107
real.data<phi::dtype::Real<T>>(),
105108
reinterpret_cast<const phi::dtype::Real<T>*>(
106109
dout.data<phi::dtype::Real<T>>()),
107-
reinterpret_cast<cuFloatComplex*>(dx_data));
110+
reinterpret_cast<XPUComplexType*>(dx_data));
108111
PADDLE_ENFORCE_XPU_SUCCESS(r);
109112
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
110113
}
@@ -117,6 +120,7 @@ void ComplexGradKernel(const Context& dev_ctx,
117120
DenseTensor* dx,
118121
DenseTensor* dy) {
119122
using C = phi::dtype::complex<T>;
123+
using XPUComplexType = typename XPUComplexTypeTrait<T>::Type;
120124
if (dout.numel() == 0) {
121125
if (dx) {
122126
if (dx->numel() == 0) {
@@ -146,7 +150,7 @@ void ComplexGradKernel(const Context& dev_ctx,
146150
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
147151
int r = xfft_internal::xpu::complex_spilt(
148152
numel,
149-
reinterpret_cast<const cuFloatComplex*>(dout.data<C>()),
153+
reinterpret_cast<const XPUComplexType*>(dout.data<C>()),
150154
real_data,
151155
imag_data);
152156
PADDLE_ENFORCE_XPU_SUCCESS(r);
@@ -171,18 +175,26 @@ void ComplexGradKernel(const Context& dev_ctx,
171175
}
172176
} // namespace phi
173177

174-
PD_REGISTER_KERNEL(
175-
imag_grad, XPU, ALL_LAYOUT, phi::ImagGradKernel, phi::complex64) {
178+
PD_REGISTER_KERNEL(imag_grad,
179+
XPU,
180+
ALL_LAYOUT,
181+
phi::ImagGradKernel,
182+
phi::complex64,
183+
phi::complex128) {
176184
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
177185
}
178186

179-
PD_REGISTER_KERNEL(
180-
real_grad, XPU, ALL_LAYOUT, phi::RealGradKernel, phi::complex64) {
187+
PD_REGISTER_KERNEL(real_grad,
188+
XPU,
189+
ALL_LAYOUT,
190+
phi::RealGradKernel,
191+
phi::complex64,
192+
phi::complex128) {
181193
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
182194
}
183195

184196
PD_REGISTER_KERNEL(
185-
complex_grad, XPU, ALL_LAYOUT, phi::ComplexGradKernel, float) {
197+
complex_grad, XPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
186198
kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
187199
}
188200
#endif // PADDLE_WITH_XPU_FFT

paddle/phi/kernels/xpu/complex_kernel.cc

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#ifdef PADDLE_WITH_XPU_FFT
1616
#include "paddle/phi/kernels/complex_kernel.h"
17-
1817
#include "fft/cuComplex.h"
1918
#include "paddle/phi/backends/xpu/enforce_xpu.h"
2019
#include "paddle/phi/common/type_traits.h"
@@ -68,6 +67,13 @@ void ConjKernel(const Context& dev_ctx,
6867
reinterpret_cast<cuFloatComplex*>(out->data<T>()));
6968
PADDLE_ENFORCE_XPU_SUCCESS(r);
7069
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
70+
} else if (std::is_same_v<T, phi::complex128>) {
71+
int r = xfft_internal::xpu::Conj(
72+
x.numel(),
73+
reinterpret_cast<const cuDoubleComplex*>(x.data<T>()),
74+
reinterpret_cast<cuDoubleComplex*>(out->data<T>()));
75+
PADDLE_ENFORCE_XPU_SUCCESS(r);
76+
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
7177
} else {
7278
using XPUType = typename XPUCopyTypeTrait<T>::Type;
7379
const auto* input_data = x.data<T>();
@@ -83,6 +89,8 @@ template <typename T, typename Context>
8389
void RealKernel(const Context& dev_ctx,
8490
const DenseTensor& x,
8591
DenseTensor* out) {
92+
using XPUComplexType =
93+
typename XPUComplexTypeTrait<phi::dtype::Real<T>>::Type;
8694
if (out->numel() == 0) {
8795
dev_ctx.template Alloc<phi::dtype::Real<T>>(out);
8896
return;
@@ -96,7 +104,7 @@ void RealKernel(const Context& dev_ctx,
96104
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
97105
int r = xfft_internal::xpu::complex_spilt(
98106
out->numel(),
99-
reinterpret_cast<const cuFloatComplex*>(x.data<T>()),
107+
reinterpret_cast<const XPUComplexType*>(x.data<T>()),
100108
out->data<phi::dtype::Real<T>>(),
101109
imag.data<phi::dtype::Real<T>>());
102110
PADDLE_ENFORCE_XPU_SUCCESS(r);
@@ -107,6 +115,8 @@ template <typename T, typename Context>
107115
void ImagKernel(const Context& dev_ctx,
108116
const DenseTensor& x,
109117
DenseTensor* out) {
118+
using XPUComplexType =
119+
typename XPUComplexTypeTrait<phi::dtype::Real<T>>::Type;
110120
if (out->numel() == 0) {
111121
dev_ctx.template Alloc<phi::dtype::Real<T>>(out);
112122
return;
@@ -120,7 +130,7 @@ void ImagKernel(const Context& dev_ctx,
120130
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx.x_context()->xpu_stream));
121131
int r = xfft_internal::xpu::complex_spilt(
122132
out->numel(),
123-
reinterpret_cast<const cuFloatComplex*>(x.data<T>()),
133+
reinterpret_cast<const XPUComplexType*>(x.data<T>()),
124134
real.data<phi::dtype::Real<T>>(),
125135
out->data<phi::dtype::Real<T>>());
126136
PADDLE_ENFORCE_XPU_SUCCESS(r);
@@ -133,6 +143,7 @@ void ComplexKernel(const Context& dev_ctx,
133143
const DenseTensor& y,
134144
DenseTensor* out) {
135145
using C = phi::dtype::complex<T>;
146+
using XPUComplexType = typename XPUComplexTypeTrait<T>::Type;
136147
if (out->numel() == 0) {
137148
dev_ctx.template Alloc<C>(out);
138149
return;
@@ -173,7 +184,7 @@ void ComplexKernel(const Context& dev_ctx,
173184
out->numel(),
174185
x_data,
175186
y_data,
176-
reinterpret_cast<cuFloatComplex*>(out->data<C>()));
187+
reinterpret_cast<XPUComplexType*>(out->data<C>()));
177188
PADDLE_ENFORCE_XPU_SUCCESS(r);
178189
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait());
179190
}
@@ -190,17 +201,21 @@ PD_REGISTER_KERNEL(conj,
190201
double,
191202
phi::float16,
192203
phi::bfloat16,
193-
phi::complex64) {}
204+
phi::complex64,
205+
phi::complex128) {}
194206

195-
PD_REGISTER_KERNEL(real, XPU, ALL_LAYOUT, phi::RealKernel, phi::complex64) {
207+
PD_REGISTER_KERNEL(
208+
real, XPU, ALL_LAYOUT, phi::RealKernel, phi::complex64, phi::complex128) {
196209
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
197210
}
198211

199-
PD_REGISTER_KERNEL(imag, XPU, ALL_LAYOUT, phi::ImagKernel, phi::complex64) {
212+
PD_REGISTER_KERNEL(
213+
imag, XPU, ALL_LAYOUT, phi::ImagKernel, phi::complex64, phi::complex128) {
200214
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
201215
}
202216

203-
PD_REGISTER_KERNEL(complex, XPU, ALL_LAYOUT, phi::ComplexKernel, float) {
217+
PD_REGISTER_KERNEL(
218+
complex, XPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
204219
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
205220
}
206221
#endif // PADDLE_WITH_XPU_FFT

paddle/phi/kernels/xpu/contiguous_kernel.cc

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ void ContiguousKernel(const Context& dev_ctx,
5757
}
5858

5959
#ifdef PADDLE_WITH_XPU_FFT
60-
template <>
61-
void ContiguousKernel<phi::complex64, XPUContext>(const XPUContext& dev_ctx,
62-
const DenseTensor& input,
63-
DenseTensor* out) {
64-
using T = phi::complex64;
65-
60+
template <typename T>
61+
typename std::enable_if<std::is_same<T, phi::complex64>::value ||
62+
std::is_same<T, phi::complex128>::value>::type
63+
ComplexContiguousKernelImpl(const XPUContext& dev_ctx,
64+
const DenseTensor& input,
65+
DenseTensor* out) {
6666
DenseTensorMeta meta = input.meta();
6767
meta.strides = meta.calc_strides(meta.dims);
6868
meta.offset = 0;
@@ -105,6 +105,19 @@ void ContiguousKernel<phi::complex64, XPUContext>(const XPUContext& dev_ctx,
105105
PADDLE_ENFORCE_XDNN_SUCCESS(r, "as_strided");
106106
}
107107
}
108+
template <>
109+
void ContiguousKernel<phi::complex64, XPUContext>(const XPUContext& dev_ctx,
110+
const DenseTensor& input,
111+
DenseTensor* out) {
112+
ComplexContiguousKernelImpl<phi::complex64>(dev_ctx, input, out);
113+
}
114+
115+
template <>
116+
void ContiguousKernel<phi::complex128, XPUContext>(const XPUContext& dev_ctx,
117+
const DenseTensor& input,
118+
DenseTensor* out) {
119+
ComplexContiguousKernelImpl<phi::complex128>(dev_ctx, input, out);
120+
}
108121
#endif
109122

110123
} // namespace phi
@@ -123,6 +136,7 @@ PD_REGISTER_KERNEL(contiguous,
123136
double,
124137
#ifdef PADDLE_WITH_XPU_FFT
125138
phi::complex64,
139+
phi::complex128,
126140
#endif
127141
phi::float16,
128142
phi::bfloat16) {

0 commit comments

Comments
 (0)