Skip to content

Commit 17952cd

Browse files
authored
Adding complex to complex FFT implementation (#4903)
* added cfft function * fixed formatting issues * Apply suggestion from Copilot * Update documentation for cfft * Remove duplicated logic in stft * Improve test coverage for cfft * Fix formatting issues * Fix import statement * Update crates/burn-backend-tests/tests/cubecl/fft.rs * Update crates/burn-tensor/src/tensor/signal/fft.rs
1 parent 4306ccb commit 17952cd

3 files changed

Lines changed: 359 additions & 30 deletions

File tree

crates/burn-backend-tests/tests/cubecl/fft.rs

Lines changed: 241 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::*;
2-
use burn_tensor::signal::{irfft, rfft};
2+
use burn_tensor::signal::{cfft, irfft, rfft};
33
use burn_tensor::{TensorData, Tolerance};
44

55
#[test]
@@ -468,3 +468,243 @@ fn rfft_2d_with_n_padded() {
468468
);
469469
}
470470
}
471+
472+
// ---- cfft tests ----
473+
474+
#[test]
475+
fn cfft_output_has_n_bins() {
476+
// cfft should return N bins, not N/2+1
477+
let re = TestTensor::<1>::from([1.0, 2.0, 3.0, 4.0]);
478+
let im = TestTensor::<1>::from([0.0, 0.0, 0.0, 0.0]);
479+
let (out_re, out_im) = cfft(re, im, 0, None);
480+
481+
assert_eq!(out_re.dims(), [4]);
482+
assert_eq!(out_im.dims(), [4]);
483+
}
484+
485+
#[test]
486+
fn cfft_pure_real_input() {
487+
// When imaginary part is zero, cfft should produce the same result
488+
// as extending rfft to the full spectrum
489+
let signal = [1.0f32, 2.0, 3.0, 4.0];
490+
let re = TestTensor::<1>::from(signal);
491+
let im = TestTensor::<1>::from([0.0, 0.0, 0.0, 0.0]);
492+
493+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
494+
495+
// Expected: DFT of [1,2,3,4]
496+
// X[0] = 10, X[1] = -2+2i, X[2] = -2, X[3] = -2-2i
497+
let expected_re = TensorData::from([10.0, -2.0, -2.0, -2.0]);
498+
let expected_im = TensorData::from([0.0, 2.0, 0.0, -2.0]);
499+
500+
cfft_re
501+
.into_data()
502+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
503+
cfft_im
504+
.into_data()
505+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
506+
}
507+
508+
#[test]
509+
fn cfft_pure_imaginary_input() {
510+
// Signal is purely imaginary: z[n] = i * [1, 2, 3, 4]
511+
// FFT(i*x) = i*FFT(x), so result_re = -FFT(x)_im, result_im = FFT(x)_re
512+
let re = TestTensor::<1>::from([0.0, 0.0, 0.0, 0.0]);
513+
let im = TestTensor::<1>::from([1.0, 2.0, 3.0, 4.0]);
514+
515+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
516+
517+
// FFT([1,2,3,4]) = [10, -2+2i, -2, -2-2i]
518+
// i * FFT(x) = i * [10, -2+2i, -2, -2-2i]
519+
// = [-0, -2+(-2)i, 0, 2+(-2)i] → re = [0, -2, 0, 2], im = [10, -2, -2, -2]
520+
let expected_re = TensorData::from([0.0, -2.0, 0.0, 2.0]);
521+
let expected_im = TensorData::from([10.0, -2.0, -2.0, -2.0]);
522+
523+
cfft_re
524+
.into_data()
525+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
526+
cfft_im
527+
.into_data()
528+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
529+
}
530+
531+
#[test]
532+
fn cfft_complex_exponential() {
533+
// z[n] = exp(i * 2π * n / 4) for n=0..3, i.e. frequency bin 1
534+
// re = [cos(0), cos(π/2), cos(π), cos(3π/2)] = [1, 0, -1, 0]
535+
// im = [sin(0), sin(π/2), sin(π), sin(3π/2)] = [0, 1, 0, -1]
536+
// DFT should be: X[0]=0, X[1]=4, X[2]=0, X[3]=0
537+
let re = TestTensor::<1>::from([1.0, 0.0, -1.0, 0.0]);
538+
let im = TestTensor::<1>::from([0.0, 1.0, 0.0, -1.0]);
539+
540+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
541+
542+
let expected_re = TensorData::from([0.0, 4.0, 0.0, 0.0]);
543+
let expected_im = TensorData::from([0.0, 0.0, 0.0, 0.0]);
544+
545+
cfft_re
546+
.into_data()
547+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
548+
cfft_im
549+
.into_data()
550+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
551+
}
552+
553+
#[test]
554+
fn cfft_zeros() {
555+
let re = TestTensor::<1>::from([0.0, 0.0, 0.0, 0.0]);
556+
let im = TestTensor::<1>::from([0.0, 0.0, 0.0, 0.0]);
557+
558+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
559+
560+
let expected = TensorData::from([0.0, 0.0, 0.0, 0.0]);
561+
562+
cfft_re
563+
.into_data()
564+
.assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-4));
565+
cfft_im
566+
.into_data()
567+
.assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-4));
568+
}
569+
570+
#[test]
571+
fn cfft_dim1_2d_tensor() {
572+
// Apply cfft along dim=1 on a 2D tensor
573+
// Row 0: pure real [1, 2, 3, 4] → DFT = [10, -2+2i, -2, -2-2i]
574+
// Row 1: complex exponential exp(i·2π·n/4) → DFT = [0, 4, 0, 0]
575+
let re = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0], [1.0, 0.0, -1.0, 0.0]]);
576+
let im = TestTensor::<2>::from([[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, -1.0]]);
577+
578+
let (cfft_re, cfft_im) = cfft(re, im, 1, None);
579+
580+
// Output should be [2, 4] (N=4 bins per row)
581+
assert_eq!(cfft_re.dims(), [2, 4]);
582+
assert_eq!(cfft_im.dims(), [2, 4]);
583+
584+
// Row 0: DFT of [1,2,3,4]+i*0 = [10, -2+2i, -2, -2-2i]
585+
// Row 1: DFT of exp(i*2π*n/4) = [0, 4, 0, 0]
586+
let expected_re = TensorData::from([[10.0, -2.0, -2.0, -2.0], [0.0, 4.0, 0.0, 0.0]]);
587+
let expected_im = TensorData::from([[0.0, 2.0, 0.0, -2.0], [0.0, 0.0, 0.0, 0.0]]);
588+
589+
cfft_re
590+
.into_data()
591+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
592+
cfft_im
593+
.into_data()
594+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
595+
}
596+
597+
#[test]
598+
fn cfft_with_n_padding() {
599+
// Signal length 2, padded to N=4
600+
// z = [1+0i, 0+0i] padded to [1+0i, 0, 0, 0]
601+
// DFT = [1, 1, 1, 1] (all real, zero imag)
602+
let re = TestTensor::<1>::from([1.0, 0.0]);
603+
let im = TestTensor::<1>::from([0.0, 0.0]);
604+
605+
let (cfft_re, cfft_im) = cfft(re, im, 0, Some(4));
606+
607+
assert_eq!(cfft_re.dims(), [4]);
608+
609+
let expected_re = TensorData::from([1.0, 1.0, 1.0, 1.0]);
610+
let expected_im = TensorData::from([0.0, 0.0, 0.0, 0.0]);
611+
612+
cfft_re
613+
.into_data()
614+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
615+
cfft_im
616+
.into_data()
617+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
618+
}
619+
620+
#[test]
621+
fn cfft_length_1() {
622+
// N=1: DFT of a single complex value is itself
623+
let re = TestTensor::<1>::from([3.0]);
624+
let im = TestTensor::<1>::from([5.0]);
625+
626+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
627+
628+
assert_eq!(cfft_re.dims(), [1]);
629+
cfft_re
630+
.into_data()
631+
.assert_approx_eq::<FloatElem>(&TensorData::from([3.0]), Tolerance::absolute(1e-4));
632+
cfft_im
633+
.into_data()
634+
.assert_approx_eq::<FloatElem>(&TensorData::from([5.0]), Tolerance::absolute(1e-4));
635+
}
636+
637+
#[test]
638+
fn cfft_length_2() {
639+
// N=2: z = [a, b] → X[0] = a+b, X[1] = a-b
640+
// z = [1+2i, 3+4i]
641+
// X[0] = (1+3) + i(2+4) = 4+6i
642+
// X[1] = (1-3) + i(2-4) = -2-2i
643+
let re = TestTensor::<1>::from([1.0, 3.0]);
644+
let im = TestTensor::<1>::from([2.0, 4.0]);
645+
646+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
647+
648+
assert_eq!(cfft_re.dims(), [2]);
649+
cfft_re
650+
.into_data()
651+
.assert_approx_eq::<FloatElem>(&TensorData::from([4.0, -2.0]), Tolerance::absolute(1e-4));
652+
cfft_im
653+
.into_data()
654+
.assert_approx_eq::<FloatElem>(&TensorData::from([6.0, -2.0]), Tolerance::absolute(1e-4));
655+
}
656+
657+
#[test]
658+
#[should_panic(expected = "same shape")]
659+
fn cfft_rejects_mismatched_shapes() {
660+
let re = TestTensor::<1>::from([1.0, 2.0, 3.0, 4.0]);
661+
let im = TestTensor::<1>::from([1.0, 2.0]);
662+
let _ = cfft(re, im, 0, None);
663+
}
664+
665+
#[test]
666+
fn cfft_dim0_2d_tensor() {
667+
// Apply cfft along dim=0 on a 2D tensor (4 rows, 2 columns)
668+
// Column 0: complex exponential exp(i·2π·n/4) → DFT = [0, 4, 0, 0]
669+
// Column 1: pure real [1, 2, 3, 4] → DFT = [10, -2+2i, -2, -2-2i]
670+
let re = TestTensor::<2>::from([[1.0, 1.0], [0.0, 2.0], [-1.0, 3.0], [0.0, 4.0]]);
671+
let im = TestTensor::<2>::from([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [-1.0, 0.0]]);
672+
673+
let (cfft_re, cfft_im) = cfft(re, im, 0, None);
674+
675+
assert_eq!(cfft_re.dims(), [4, 2]);
676+
assert_eq!(cfft_im.dims(), [4, 2]);
677+
678+
let expected_re = TensorData::from([[0.0, 10.0], [4.0, -2.0], [0.0, -2.0], [0.0, -2.0]]);
679+
let expected_im = TensorData::from([[0.0, 0.0], [0.0, 2.0], [0.0, 0.0], [0.0, -2.0]]);
680+
681+
cfft_re
682+
.into_data()
683+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
684+
cfft_im
685+
.into_data()
686+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
687+
}
688+
689+
#[test]
690+
fn cfft_with_n_truncation() {
691+
// Signal length 8, truncated to n=4 → DFT of [1+0i, 2+0i, 3+0i, 4+0i]
692+
// Trailing values are discarded, not included in the transform.
693+
let re = TestTensor::<1>::from([1.0, 2.0, 3.0, 4.0, 99.0, 99.0, 99.0, 99.0]);
694+
let im = TestTensor::<1>::from([0.0, 0.0, 0.0, 0.0, 99.0, 99.0, 99.0, 99.0]);
695+
696+
let (cfft_re, cfft_im) = cfft(re, im, 0, Some(4));
697+
698+
assert_eq!(cfft_re.dims(), [4]);
699+
700+
// DFT of [1,2,3,4] = [10, -2+2i, -2, -2-2i]
701+
let expected_re = TensorData::from([10.0, -2.0, -2.0, -2.0]);
702+
let expected_im = TensorData::from([0.0, 2.0, 0.0, -2.0]);
703+
704+
cfft_re
705+
.into_data()
706+
.assert_approx_eq::<FloatElem>(&expected_re, Tolerance::absolute(1e-3));
707+
cfft_im
708+
.into_data()
709+
.assert_approx_eq::<FloatElem>(&expected_im, Tolerance::absolute(1e-3));
710+
}

crates/burn-tensor/src/tensor/signal/fft.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use alloc::vec;
12
use burn_backend::Backend;
23

34
use crate::Tensor;
@@ -149,3 +150,118 @@ pub fn irfft<B: Backend, const D: usize>(
149150
);
150151
Tensor::new(TensorPrimitive::Float(signal))
151152
}
153+
154+
/// Computes the 1-dimensional discrete Fourier Transform of complex-valued input.
155+
///
156+
/// Internally calls [`rfft`] on the real and imaginary parts separately,
157+
/// extends each half-spectrum to the full `N`-bin spectrum via Hermitian
158+
/// symmetry.
159+
///
160+
/// Autodiff is not yet supported.
161+
///
162+
#[cfg_attr(
163+
doc,
164+
doc = r#"
165+
166+
Due to the linearity of the Fourier Transform, a complex-valued signal $x\[n\] = x_{re}\[n\] + i x_{im}\[n\]$ can be transformed by applying the FFT to its real and imaginary parts separately:
167+
168+
$$ \text{FFT}(x\[n\]) = \text{FFT}(x_{re}\[n\]) + i \text{FFT}(x_{im}\[n\]) $$
169+
170+
Since $x_{re}\[n\]$ and $x_{im}\[n\]$ are purely real, their transforms can be computed efficiently using the real FFT ([`rfft`]). The full spectrum is then reconstructed by exploiting Hermitian symmetry.
171+
"#
172+
)]
173+
#[cfg_attr(not(doc), doc = r"X\[k\] = Σ x\[n\] * exp(-i*2πkn/N)")]
174+
///
175+
/// # Arguments
176+
///
177+
/// * `signal_re` - The real part of the complex input signal.
178+
/// * `signal_im` - The imaginary part of the complex input signal. Must have the
179+
/// same shape as `signal_re`.
180+
/// * `dim` - The dimension along which to take the FFT.
181+
/// * `n` - Optional FFT length. When `None`, the signal must be a power of two
182+
/// along `dim`. When `Some(n)`, `n` must also be a power of two; the signal is
183+
/// truncated or zero-padded to length `n`.
184+
///
185+
/// # Returns
186+
///
187+
/// A tuple `(re, im)` representing the full complex spectrum, each with `n`
188+
/// elements along `dim`.
189+
///
190+
/// # Example
191+
///
192+
/// ```rust
193+
/// use burn_tensor::backend::Backend;
194+
/// use burn_tensor::Tensor;
195+
///
196+
/// fn example<B: Backend>() {
197+
/// let device = B::Device::default();
198+
/// let re = Tensor::<B, 1>::from_floats([1.0, 0.0, -1.0, 0.0], &device);
199+
/// let im = Tensor::<B, 1>::from_floats([0.0, 1.0, 0.0, -1.0], &device);
200+
/// let (spec_re, spec_im) = burn_tensor::signal::cfft(re, im, 0, None);
201+
/// }
202+
/// ```
203+
pub fn cfft<B: Backend, const D: usize>(
204+
signal_re: Tensor<B, D>,
205+
signal_im: Tensor<B, D>,
206+
dim: usize,
207+
n: Option<usize>,
208+
) -> (Tensor<B, D>, Tensor<B, D>) {
209+
assert!(
210+
signal_re.shape() == signal_im.shape(),
211+
"cfft: signal_re and signal_im must have the same shape, \
212+
got {:?} and {:?}",
213+
signal_re.shape(),
214+
signal_im.shape(),
215+
);
216+
217+
check!(TensorCheck::check_dim::<D>(dim));
218+
let fft_size = n.unwrap_or(signal_re.dims()[dim]);
219+
220+
// rfft validates power-of-two and n constraints internally
221+
let (xr, xi) = rfft(signal_re, dim, n);
222+
let (yr, yi) = rfft(signal_im, dim, n);
223+
224+
// Extend half-spectra (N/2+1 bins) to full N-bin spectra via Hermitian symmetry
225+
let (xr, xi) = hermitian_extend(xr, xi, dim, fft_size);
226+
let (yr, yi) = hermitian_extend(yr, yi, dim, fft_size);
227+
228+
// FFT(z) = FFT(x) + i·FFT(y)
229+
// = (Xr + i·Xi) + i·(Yr + i·Yi)
230+
// = (Xr - Yi) + i·(Xi + Yr)
231+
(xr - yi, xi + yr)
232+
}
233+
234+
/// Extend a half-spectrum from [`rfft`] (`N/2 + 1` bins) to the full `N`-bin
235+
/// spectrum using Hermitian symmetry: `X[k] = conj(X[N-k])` for `k > N/2`.
236+
pub(super) fn hermitian_extend<B: Backend, const D: usize>(
237+
half_re: Tensor<B, D>,
238+
half_im: Tensor<B, D>,
239+
dim: usize,
240+
full_len: usize,
241+
) -> (Tensor<B, D>, Tensor<B, D>) {
242+
let half_len = half_re.dims()[dim]; // N/2 + 1
243+
244+
// For N <= 2, the half-spectrum already covers all bins
245+
if full_len <= half_len {
246+
return (half_re, half_im);
247+
}
248+
249+
// Mirror bins: reverse of bins 1..N/2-1 (skipping the Nyquist bin),
250+
// with conjugated imaginary part. This produces X[N/2+1], X[N/2+2], ..., X[N-1]
251+
let mirror_len = full_len - half_len; // N/2 - 1
252+
let mirror_re = half_re
253+
.clone()
254+
.narrow(dim, 1, mirror_len)
255+
.flip([dim as isize]);
256+
let mirror_im = half_im
257+
.clone()
258+
.narrow(dim, 1, mirror_len)
259+
.flip([dim as isize])
260+
.neg();
261+
262+
// Full spectrum = [half_spectrum, conjugate_mirror]
263+
let full_re = Tensor::cat(vec![half_re, mirror_re], dim);
264+
let full_im = Tensor::cat(vec![half_im, mirror_im], dim);
265+
266+
(full_re, full_im)
267+
}

0 commit comments

Comments
 (0)