feat:add istft operators #1006
Conversation
Summary of Changes: Hello @Schopenhauer-loves-Hegel, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request integrates a new Inverse Short-Time Fourier Transform (ISTFT) operator into the flag_gems operator library.
Code Review
This pull request introduces a new istft operator implemented with Triton, along with corresponding benchmarks and accuracy tests. While adding this operator is a valuable contribution, the current implementation has critical performance issues. The benchmark results in the pull request description show that the Triton version is approximately 10 times slower than the native PyTorch implementation. My review identifies two primary causes for this performance degradation: the use of sequential Python loops to launch multiple small kernels, and a naive, inefficient algorithm within the IFFT kernels themselves. These performance bottlenecks must be addressed. I have also included several suggestions to improve code quality and maintainability.
src/flag_gems/ops/istft.py
Outdated
```python
for freq in range(N_FREQS):
    offset = freq * stride_freq + frame_idx * stride_frame
    real_val = tl.load(spec_real_ptr + offset)
    imag_val = tl.load(spec_imag_ptr + offset)
    angle = (TWO_PI * freq) * t_float * n_fft_inv
    cos_val = tl.cos(angle)
    sin_val = tl.sin(angle)
    contrib = real_val * cos_val - imag_val * sin_val
    if (freq == 0) or (freq == N_FREQS - 1):
        acc += contrib
    else:
        acc += 2.0 * contrib
```
The implementation of the IFFT inside this kernel uses a serial for loop over the frequency dimension (for freq in range(N_FREQS)). This is a naive Discrete Fourier Transform (DFT) algorithm, which has a complexity of O(N^2) and is extremely inefficient on a GPU. This is a critical performance issue and a primary reason for the slow benchmark results. You should avoid serial loops inside Triton kernels whenever possible. Consider parallelizing the computation over the frequency dimension or implementing a more efficient parallel FFT algorithm (e.g., Cooley-Tukey).
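To illustrate why the serial loop is avoidable: the whole accumulation is a dense matrix-vector product over the frequency axis, so it can be expressed with vectorized array operations (or, on GPU, with parallel threads over frequencies). The NumPy sketch below is an illustration of the math only, not Triton code; the function name `onesided_irfft_reference` is hypothetical, and an even `n_fft` is assumed so the last bin is the Nyquist frequency.

```python
import numpy as np

def onesided_irfft_reference(spec, n_fft):
    """Vectorized equivalent of the kernel's per-frequency loop.

    spec: complex rfft output of length n_fft // 2 + 1 (even n_fft assumed).
    """
    n_freqs = spec.shape[0]
    t = np.arange(n_fft)
    k = np.arange(n_freqs)
    # weight 1.0 for DC and Nyquist bins, 2.0 for the conjugate-symmetric middle
    w = np.where((k == 0) | (k == n_freqs - 1), 1.0, 2.0)
    angles = 2.0 * np.pi * np.outer(k, t) / n_fft        # shape (n_freqs, n_fft)
    # one vectorized reduction replaces the serial loop over freq
    out = (w[:, None] * (spec.real[:, None] * np.cos(angles)
                         - spec.imag[:, None] * np.sin(angles))).sum(axis=0)
    return out / n_fft
```

Checking this against `np.fft.irfft` confirms that the 1x/2x weighting in the kernel implements the onesided inverse DFT correctly; the remaining issue is purely how the reduction is scheduled.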
src/flag_gems/ops/istft.py
Outdated
```python
for b in range(batch_size):
    real_ptr = spec_real[b]
    imag_ptr = spec_imag[b]
    for frame_idx in range(n_frames):
        if onesided:
            _ifft_onesided_kernel[grid_ifft](
                real_ptr,
                imag_ptr,
                frame_real,
                stride_freq,
                stride_frame,
                scale,
                frame_idx,
                n_fft,
                BLOCK_T=256,
                N_FREQS=n_freqs,
            )
            if needs_imag:
                frame_imag.zero_()
        else:
            _ifft_full_kernel[grid_ifft](
                real_ptr,
                imag_ptr,
                frame_real,
                frame_imag,
                stride_freq,
                stride_frame,
                scale,
                frame_idx,
                n_fft,
                BLOCK_T=256,
                N_FREQS=n_freqs,
            )

        # ...

        frame_offset = frame_idx * hop_length
        _overlap_add_kernel[grid_overlap](
            frame_real,
            frame_imag if needs_imag else frame_real,
            output_real[b],
            output_imag[b] if needs_imag else output_real[b],
            envelope[b],
            window,
            frame_offset,
            win_length,
            full_length,
            BLOCK_T=256,
            APPLY_WINDOW=True,
            HAS_IMAG=needs_imag,
        )

# ...

grid_norm = lambda meta: (triton.cdiv(full_length, meta["BLOCK_T"]),)
for b in range(batch_size):
    _normalize_kernel[grid_norm](
        output_real[b],
        output_imag[b] if needs_imag else output_real[b],
        envelope[b],
        full_length,
        BLOCK_T=256,
        HAS_IMAG=needs_imag,
    )
```
The use of Python loops to iterate over batch_size (lines 259 and 310) and n_frames (line 262) launches many small kernels, which is a major performance bottleneck. To improve performance, you should launch each kernel (_ifft_*, _overlap_add_kernel, _normalize_kernel) only once. The work should be parallelized across batches and frames by expanding the Triton grid and calculating the batch_idx and frame_idx inside the kernels using tl.program_id.
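The suggested restructuring can be sketched in plain Python: flatten (batch, frame, time-block) into a single launch grid and recover all three indices inside the kernel. In the real Triton kernel, `pid` below would come from `tl.program_id(0)`; the helper names are hypothetical.

```python
def flat_grid_size(batch_size, n_frames, n_blocks_t):
    # one launch covers every (batch, frame, time-block) triple
    return batch_size * n_frames * n_blocks_t

def decode_pid(pid, n_frames, n_blocks_t):
    # recover (batch_idx, frame_idx, block_idx) from the flat program id;
    # inside a Triton kernel this arithmetic would run on tl.program_id(0)
    batch_idx = pid // (n_frames * n_blocks_t)
    rem = pid % (n_frames * n_blocks_t)
    return batch_idx, rem // n_blocks_t, rem % n_blocks_t
```

With this layout, the three host-side Python loops collapse into three kernel launches total, and the per-launch overhead no longer scales with `batch_size * n_frames`.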
src/flag_gems/ops/istft.py
Outdated
```python
# 在 kernel 内定义常量
TWO_PI = 6.283185307179586  # 2 * pi
```
The comments in this file are in Chinese. For consistency with the rest of the project and to ensure maintainability for all contributors, please translate them to English.
Suggested change:

```diff
- # 在 kernel 内定义常量
+ # Define constants inside the kernel
  TWO_PI = 6.283185307179586  # 2 * pi
```
src/flag_gems/ops/istft.py
Outdated
```python
def _make_hann_window(length, *, device, dtype):
    TWO_PI = 6.283185307179586  # 2 * pi
    n = torch.arange(length, device=device, dtype=dtype)
    return 0.5 - 0.5 * torch.cos(TWO_PI * n / length)
```
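For reference, this formula produces the periodic Hann window (period `length`, with `w[0] == 0`), which matches `torch.hann_window(length)` under its default `periodic=True`. A dependency-free sketch of the same values:

```python
import math

def make_hann_window(length):
    # periodic Hann: 0.5 - 0.5 * cos(2*pi*n/length); w[0] == 0.0 and the
    # window is symmetric about length/2, i.e. w[k] == w[length - k]
    two_pi = 2.0 * math.pi
    return [0.5 - 0.5 * math.cos(two_pi * n / length) for n in range(length)]
```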
src/flag_gems/ops/istft.py
Outdated
```python
    scale,
    frame_idx,
    n_fft,
    BLOCK_T=256,
```
tests/test_special_ops.py
Outdated
```python
@pytest.mark.istft
@pytest.mark.parametrize("n_fft, n_frames", [(512, 10), (256, 20), (1024, 8)])
@pytest.mark.parametrize("dtype", [torch.float32])
```
The accuracy test for istft is parameterized only for torch.float32. The istft implementation and its benchmark also support torch.float16. Please extend the test to cover torch.float16 to ensure its correctness.
Suggested change:

```diff
- @pytest.mark.parametrize("dtype", [torch.float32])
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
```
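One caveat when adding the float16 case: the comparison tolerance will likely need loosening, since half precision carries a 10-bit mantissa (roughly 3 decimal digits) and the accumulated error of an ISTFT round trip is far larger than in float32. A small illustration of the representation gap (exact tolerance values would be project-specific assumptions):

```python
import numpy as np

# float16 rounds 1/3 much more coarsely than float32 does, which is why
# half-precision accuracy tests need correspondingly looser tolerances.
exact = 1.0 / 3.0
fp16_err = abs(float(np.float16(exact)) - exact)
fp32_err = abs(float(np.float32(exact)) - exact)
```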
I have updated the operator and got this benchmark result:
PR Category
Operator
Type of Change
New Feature
Description
Add the istft operator, along with unit tests and a benchmark, implemented in Triton (with Copilot assistance).
Issue
Progress
Performance