
@Schopenhauer-loves-Hegel

PR Category

Operator

Type of Change

New Feature

Description

Add the istft operator, plus a unit test and benchmark; the implementation is written in Triton with Copilot assistance.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

(zpy_triton) root@job-4211e20e-bdf4-4193-bf7e-7448650342e5-master-0:/share/project/tj/workspace/pr_flaggems/FlagGems# pytest benchmark/test_special_perf.py -m istft -s
/usr/local/lib/python3.10/dist-packages/hypothesis/entry_points.py:23: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  import pkg_resources
=============================================================== test session starts ===============================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
rootdir: /share/project/tj/workspace/pr_flaggems/FlagGems
configfile: pytest.ini
plugins: anyio-4.11.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, xdist-3.6.1, xdoctest-1.0.2
collected 22 items / 21 deselected / 1 selected                                                                                                   
Running 1 items in this shard

benchmark/test_special_perf.py 
Operator: istft  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.148048            1.122912               0.132          ([torch.Size([257, 10])], {'n_fft': 512, 'hop_length': 128, 'win_length': 512, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.132368            1.871760               0.071          ([torch.Size([129, 20])], {'n_fft': 256, 'hop_length': 64, 'win_length': 256, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.131776            1.225888               0.107          ([torch.Size([513, 8])], {'n_fft': 1024, 'hop_length': 256, 'win_length': 1024, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})


Operator: istft  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.135424            1.119680               0.121          ([torch.Size([257, 10])], {'n_fft': 512, 'hop_length': 128, 'win_length': 512, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.132480            1.845296               0.072          ([torch.Size([129, 20])], {'n_fft': 256, 'hop_length': 64, 'win_length': 256, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.135776            1.235616               0.110          ([torch.Size([513, 8])], {'n_fft': 1024, 'hop_length': 256, 'win_length': 1024, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})

.

================================================================ warnings summary =================================================================
../../../../../../usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py:108: 11 warnings

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Schopenhauer-loves-Hegel, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new Inverse Short-Time Fourier Transform (ISTFT) operator into the FlagGems library. The implementation leverages Triton to provide a GPU-accelerated version of the operator, which is essential for various audio and signal processing applications. The changes encompass the core logic for IFFT, overlap-add, and normalization, along with necessary integration into the library's operator registry and robust testing infrastructure.
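For context on what the operator computes: istft inverts a short-time Fourier transform by applying an inverse FFT to each frame, overlap-adding the windowed frames, and normalizing by the accumulated window envelope. A minimal round-trip sanity check using only native PyTorch ops (not the PR's kernels; the shapes mirror the n_fft=512, hop_length=128 benchmark configuration below, and the tolerances are illustrative):

import torch

# Round trip: stft followed by istft should reconstruct the signal
# up to floating-point error when the window satisfies NOLA.
x = torch.randn(4, 512 * 10)
window = torch.hann_window(512)
spec = torch.stft(x, n_fft=512, hop_length=128, window=window, return_complex=True)
y = torch.istft(spec, n_fft=512, hop_length=128, window=window, length=x.shape[-1])
torch.testing.assert_close(x, y, rtol=1e-4, atol=1e-4)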

Highlights

  • New istft Operator: Introduced the Inverse Short-Time Fourier Transform (ISTFT) operator, a crucial component for reconstructing time-domain signals from their spectrograms.
  • Triton-based Implementation: The istft operator is implemented using Triton kernels (_ifft_onesided_kernel, _ifft_full_kernel, _overlap_add_kernel, _normalize_kernel) for potential GPU acceleration.
  • Comprehensive Testing: Includes new unit tests (test_accuracy_istft) to ensure numerical correctness against PyTorch's native istft and performance benchmarks (IstftBenchmark) to evaluate its efficiency.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature               Command               Description
Code Review           /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary       Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist   Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                  /gemini help          Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new istft operator implemented with Triton, along with corresponding benchmarks and accuracy tests. While adding this operator is a valuable contribution, the current implementation has critical performance issues. The benchmark results in the pull request description show that the Triton version is approximately 10 times slower than the native PyTorch implementation. My review identifies two primary causes for this performance degradation: the use of sequential Python loops to launch multiple small kernels, and a naive, inefficient algorithm within the IFFT kernels themselves. These performance bottlenecks must be addressed. I have also included several suggestions to improve code quality and maintainability.

Comment on lines 32 to 43
for freq in range(N_FREQS):
    offset = freq * stride_freq + frame_idx * stride_frame
    real_val = tl.load(spec_real_ptr + offset)
    imag_val = tl.load(spec_imag_ptr + offset)
    angle = (TWO_PI * freq) * t_float * n_fft_inv
    cos_val = tl.cos(angle)
    sin_val = tl.sin(angle)
    contrib = real_val * cos_val - imag_val * sin_val
    if (freq == 0) or (freq == N_FREQS - 1):
        acc += contrib
    else:
        acc += 2.0 * contrib

critical

The implementation of the IFFT inside this kernel uses a serial for loop over the frequency dimension (for freq in range(N_FREQS)). This is a naive Discrete Fourier Transform (DFT) algorithm, which has a complexity of O(N^2) and is extremely inefficient on a GPU. This is a critical performance issue and a primary reason for the slow benchmark results. You should avoid serial loops inside Triton kernels whenever possible. Consider parallelizing the computation over the frequency dimension or implementing a more efficient parallel FFT algorithm (e.g., Cooley-Tukey).
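For illustration, a hedged sketch of one way to vectorize this reduction. It keeps the O(N^2) DFT arithmetic but makes it data-parallel; BLOCK_F is an assumed power-of-two tl.constexpr >= N_FREQS, t_float is assumed to be the kernel's 1-D block of time indices, and the remaining names follow the quoted kernel:

freqs = tl.arange(0, BLOCK_F)
fmask = freqs < N_FREQS
offsets = freqs * stride_freq + frame_idx * stride_frame
real_val = tl.load(spec_real_ptr + offsets, mask=fmask, other=0.0)
imag_val = tl.load(spec_imag_ptr + offsets, mask=fmask, other=0.0)
# Broadcast (BLOCK_F, 1) against (1, BLOCK_T): all frequency/time angles at once.
angle = (TWO_PI * freqs[:, None]) * t_float[None, :] * n_fft_inv
contrib = real_val[:, None] * tl.cos(angle) - imag_val[:, None] * tl.sin(angle)
# Double every bin except DC and Nyquist, matching the original onesided weighting.
weight = tl.where((freqs == 0) | (freqs == N_FREQS - 1), 1.0, 2.0)
acc = tl.sum(weight[:, None] * contrib, axis=0)

For large n_fft the (BLOCK_F, BLOCK_T) tile may need to be processed in chunks to stay within register budgets; a true Cooley-Tukey FFT would further cut the arithmetic to O(N log N).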

Comment on lines 259 to 318
for b in range(batch_size):
    real_ptr = spec_real[b]
    imag_ptr = spec_imag[b]
    for frame_idx in range(n_frames):
        if onesided:
            _ifft_onesided_kernel[grid_ifft](
                real_ptr,
                imag_ptr,
                frame_real,
                stride_freq,
                stride_frame,
                scale,
                frame_idx,
                n_fft,
                BLOCK_T=256,
                N_FREQS=n_freqs,
            )
            if needs_imag:
                frame_imag.zero_()
        else:
            _ifft_full_kernel[grid_ifft](
                real_ptr,
                imag_ptr,
                frame_real,
                frame_imag,
                stride_freq,
                stride_frame,
                scale,
                frame_idx,
                n_fft,
                BLOCK_T=256,
                N_FREQS=n_freqs,
            )

        frame_offset = frame_idx * hop_length
        _overlap_add_kernel[grid_overlap](
            frame_real,
            frame_imag if needs_imag else frame_real,
            output_real[b],
            output_imag[b] if needs_imag else output_real[b],
            envelope[b],
            window,
            frame_offset,
            win_length,
            full_length,
            BLOCK_T=256,
            APPLY_WINDOW=True,
            HAS_IMAG=needs_imag,
        )

grid_norm = lambda meta: (triton.cdiv(full_length, meta["BLOCK_T"]),)
for b in range(batch_size):
    _normalize_kernel[grid_norm](
        output_real[b],
        output_imag[b] if needs_imag else output_real[b],
        envelope[b],
        full_length,
        BLOCK_T=256,
        HAS_IMAG=needs_imag,
    )

critical

The use of Python loops to iterate over batch_size (lines 259 and 310) and n_frames (line 262) launches many small kernels, which is a major performance bottleneck. To improve performance, you should launch each kernel (_ifft_*, _overlap_add_kernel, _normalize_kernel) only once. The work should be parallelized across batches and frames by expanding the Triton grid and calculating the batch_idx and frame_idx inside the kernels using tl.program_id.
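A hedged sketch of the launch-side change, assuming an added stride_batch argument and a per-(batch, frame) scratch buffer frames_real (both hypothetical names, not from the PR):

# One launch covers every (time block, frame, batch) triple.
grid_ifft = lambda meta: (
    triton.cdiv(n_fft, meta["BLOCK_T"]),  # axis 0: blocks of time samples
    n_frames,                             # axis 1: frames
    batch_size,                           # axis 2: batch elements
)
_ifft_onesided_kernel[grid_ifft](
    spec_real, spec_imag, frames_real,
    stride_batch, stride_freq, stride_frame,
    scale, n_fft, BLOCK_T=256, N_FREQS=n_freqs,
)

# Inside the kernel, the indices are recovered from the grid instead of
# being passed in from a Python loop:
#     frame_idx = tl.program_id(1)
#     batch_idx = tl.program_id(2)
#     base = batch_idx * stride_batch + frame_idx * stride_frame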

Comment on lines 21 to 22
# 在 kernel 内定义常量
TWO_PI = 6.283185307179586 # 2 * pi

medium

The comments in this file are in Chinese. For consistency with the rest of the project and to ensure maintainability for all contributors, please translate them to English.

Suggested change
- # 在 kernel 内定义常量
- TWO_PI = 6.283185307179586  # 2 * pi
+ # Define constants inside the kernel
+ TWO_PI = 6.283185307179586  # 2 * pi

Comment on lines 163 to 166
def _make_hann_window(length, *, device, dtype):
    TWO_PI = 6.283185307179586  # 2 * pi
    n = torch.arange(length, device=device, dtype=dtype)
    return 0.5 - 0.5 * torch.cos(TWO_PI * n / length)

medium

This helper function _make_hann_window is defined but never used within the istft operator. If it's not intended for future use, it should be removed to keep the code clean.

    scale,
    frame_idx,
    n_fft,
    BLOCK_T=256,

medium

The BLOCK_T size is hardcoded to 256. This may not be optimal for all input sizes or hardware. To improve performance and portability, you should use @triton.autotune to automatically find the best block size. This would involve defining a range of possible values for BLOCK_T.
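A minimal autotuning sketch, reusing the signature implied by the quoted launch (the config values are illustrative, not tuned):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_T": 128}),
        triton.Config({"BLOCK_T": 256}),
        triton.Config({"BLOCK_T": 512}),
    ],
    key=["n_fft"],  # re-tune (and cache the best config) per distinct n_fft
)
@triton.jit
def _ifft_onesided_kernel(
    spec_real_ptr, spec_imag_ptr, out_ptr,
    stride_freq, stride_frame, scale, frame_idx, n_fft,
    BLOCK_T: tl.constexpr, N_FREQS: tl.constexpr,
):
    ...  # kernel body unchanged

Note that with @triton.autotune in place, BLOCK_T must no longer be passed at the call site; the grid lambda picks the selected value up from meta["BLOCK_T"].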


@pytest.mark.istft
@pytest.mark.parametrize("n_fft, n_frames", [(512, 10), (256, 20), (1024, 8)])
@pytest.mark.parametrize("dtype", [torch.float32])

medium

The accuracy test for istft is parameterized only for torch.float32. The istft implementation and its benchmark also support torch.float16. Please extend the test to cover torch.float16 to ensure its correctness.

Suggested change
- @pytest.mark.parametrize("dtype", [torch.float32])
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])

@Schopenhauer-loves-Hegel (Author)

I have updated the operator and got this benchmark result:

benchmark/test_special_perf.py 
Operator: istft  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.144192            0.741360               0.194          ([torch.Size([257, 10])], {'n_fft': 512, 'hop_length': 128, 'win_length': 512, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.141920            0.736064               0.193          ([torch.Size([129, 20])], {'n_fft': 256, 'hop_length': 64, 'win_length': 256, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.138384            0.734208               0.188          ([torch.Size([513, 8])], {'n_fft': 1024, 'hop_length': 256, 'win_length': 1024, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})


Operator: istft  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.133376            0.428416               0.311          ([torch.Size([257, 10])], {'n_fft': 512, 'hop_length': 128, 'win_length': 512, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.141776            0.416032               0.341          ([torch.Size([129, 20])], {'n_fft': 256, 'hop_length': 64, 'win_length': 256, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
SUCCESS               0.133376            0.461440               0.289          ([torch.Size([513, 8])], {'n_fft': 1024, 'hop_length': 256, 'win_length': 1024, 'window': None, 'center': True, 'normalized': False, 'onesided': True, 'length': None, 'return_complex': False})
