Skip to content

[COMPARISON TESTS] Use pre-computed values in librosa compat test #4018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
dc073ea
Attempting autouse fixture
samanklesaria Aug 1, 2025
6da31fd
Mock calls to librosa by reading from cache
samanklesaria Aug 1, 2025
dd1ee02
Add librosa cache
samanklesaria Aug 1, 2025
f90ddaf
Convert single cache file to one per test
samanklesaria Aug 1, 2025
b8389e9
Raise tolerance for librosa power tests
samanklesaria Aug 5, 2025
a1729f9
Remove colons in request string
samanklesaria Aug 5, 2025
4dde9b6
Re-generate librosa expected results
samanklesaria Aug 5, 2025
b538475
Use newer numpy
samanklesaria Aug 1, 2025
1e9bcd7
Fix pillow typo
samanklesaria Aug 1, 2025
5443243
Require numpy >= 1.26
samanklesaria Aug 5, 2025
e3b26d0
Merge branch 'main' into librosa_mock
samanklesaria Aug 5, 2025
02bf73b
Use numpy>=1.26
samanklesaria Aug 6, 2025
eaaa22c
Use explicit seed of zero for vocoder tests
samanklesaria Aug 6, 2025
f33fc37
Add docstring for mock_function
samanklesaria Aug 6, 2025
e3dae11
Expand docstring form mock_function
samanklesaria Aug 6, 2025
edb0768
Merge branch 'main' into librosa_mock
samanklesaria Aug 7, 2025
5e5c7ed
Comment out expected-value creation code
samanklesaria Aug 7, 2025
764f917
Remove if checks. Those should never ever be False. If they are, this…
NicolasHug Aug 8, 2025
8be08ca
Make test runnable from arbitrary dir
NicolasHug Aug 8, 2025
c278819
Add comment to explain what Request Mixin is
NicolasHug Aug 8, 2025
db17ee1
Clarify if-statement in librosa_mock
samanklesaria Aug 8, 2025
c6b7bbc
Revert mistaken deletion of deprecation test
samanklesaria Aug 8, 2025
0157f60
Fix mfcc_from_waveform issue
samanklesaria Aug 8, 2025
b21b4db
Merge cpu and cuda results
samanklesaria Aug 8, 2025
91daa09
Disable generation again
samanklesaria Aug 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/scripts/unittest-linux/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ fi
pip install kaldi-io SoundFile librosa coverage pytest pytest-cov scipy expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag pyroomacoustics flashlight-text git+https://github.com/kpu/kenlm

# TODO: might be better to fix the single call to `pip install` above
pip install "pillow<10.0" "scipy<1.10" "numpy<2.0"
pip install pillow scipy "numpy>=1.26"
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unittest-linux-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ jobs:

pip3 install parameterized requests
pip3 install kaldi-io SoundFile librosa coverage pytest pytest-cov scipy expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag
pip3 install "pillow<10.0" "scipy<1.10" "numpy<2.0"
pip3 install pillow scipy "numpy>=1.26"

echo "::endgroup::"
echo "::group::Run tests"
Expand Down
62 changes: 62 additions & 0 deletions test/librosa_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import re
import os
from pathlib import Path
import torch

def mock_function(f):
    """
    Create a mocked version of a librosa function that returns a precomputed result.

    The returned callable never imports or calls librosa. It loads the object that was saved
    ahead of time for the calling test and returns it. This lets the tests compare torchaudio
    functionality to the equivalent functionality in librosa without depending on librosa once
    the results have been precomputed.

    Args:
        f (str): Dotted name of the function being replaced, e.g. ``"librosa.filters.mel"``.
            The live code only uses it in error messages; the disabled generation code
            (kept as a comment in the wrapper) evaluates it to obtain the real function.

    Returns:
        A callable ``wrapper(request, *args, **kwargs)``. ``request`` identifies the calling
        test and selects the file
        ``<dir of this file>/torchaudio_unittest/assets/librosa_expected_results/<request>.pt``.
        All other arguments are accepted so call sites look like the librosa call, but they
        are ignored when reading the precomputed result.

    Raises:
        FileNotFoundError: Raised by the wrapper when no precomputed result exists for
            ``request``.
    """
    test_root = Path(__file__).parent.resolve()
    expected_results_folder = test_root / "torchaudio_unittest" / "assets" / "librosa_expected_results"

    def wrapper(request, *args, **kwargs):
        mocked_results = expected_results_folder / f"{request}.pt"
        try:
            # NOTE: weights_only=False unpickles arbitrary objects. That is acceptable only
            # because these files are assets checked into the repository; never point this
            # at untrusted data.
            return torch.load(mocked_results, weights_only=False)
        except FileNotFoundError as err:
            # Same exception type as before, but say which call and which key were missing
            # so the failure is actionable.
            raise FileNotFoundError(
                f"No precomputed result for {f!r} with request {request!r}: "
                f"expected file {mocked_results} does not exist. "
                "Re-generate the librosa expected results."
            ) from err

        # Old definition used for generation:
        # if os.path.exists(mocked_results):
        #     return torch.load(mocked_results, weights_only=False)
        # import librosa
        # result = eval(f)(*args, **kwargs)
        # if request is not None:
        #     mocked_results.parent.mkdir(parents=True, exist_ok=True)
        #     torch.save(result, mocked_results)
        # return result

    return wrapper

# Stand-ins for the librosa functions used by the compatibility tests.
# Each name is a wrapper built by mock_function: call it with the test's request
# identifier as the first argument, followed by the arguments of the librosa call.
griffinlim = mock_function("librosa.griffinlim")

mel = mock_function("librosa.filters.mel")

power_to_db = mock_function("librosa.core.power_to_db")

amplitude_to_db = mock_function("librosa.core.amplitude_to_db")

phase_vocoder = mock_function("librosa.phase_vocoder")

# Private librosa helper; the tests index the returned value with [0].
spectrogram = mock_function("librosa.core.spectrum._spectrogram")

mel_spectrogram = mock_function("librosa.feature.melspectrogram")

def _mfcc_from_waveform(waveform, sample_rate, n_fft, hop_length, n_mels, n_mfcc):
    """
    Compute reference MFCCs with librosa from entry 0 of ``waveform``.

    The pipeline is: mel spectrogram -> power_to_db -> MFCC (DCT type 2, "ortho" norm).
    librosa is imported lazily so that importing this module does not require it.

    Args:
        waveform: Object whose ``waveform[0]`` supports ``.cpu().numpy()``
            (presumably a (channel, time) torch tensor -- confirm against the caller).
        sample_rate: Passed to librosa as ``sr``.
        n_fft: Used both as ``n_fft`` and as ``win_length``.
        hop_length: Passed to librosa as ``hop_length``.
        n_mels: Number of mel bands.
        n_mfcc: Number of MFCC coefficients to return.

    Returns:
        Whatever ``librosa.feature.mfcc`` returns for the log-mel spectrogram.
    """
    import librosa

    first_entry = waveform[0].cpu().numpy()
    mel_options = {
        "sr": sample_rate,
        "n_fft": n_fft,
        "win_length": n_fft,
        "hop_length": hop_length,
        "n_mels": n_mels,
        "htk": True,
        "norm": None,
        "pad_mode": "reflect",
    }
    mel_power = librosa.feature.melspectrogram(y=first_entry, **mel_options)
    log_mel = librosa.core.power_to_db(mel_power)
    return librosa.feature.mfcc(S=log_mel, n_mfcc=n_mfcc, dct_type=2, norm="ortho")

# "_mfcc_from_waveform" names the helper defined in this module rather than a librosa
# attribute; the string is only evaluated by the disabled generation code in mock_function.
mfcc_from_waveform = mock_function("_mfcc_from_waveform")


spectral_centroid = mock_function("librosa.feature.spectral_centroid")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
13 changes: 13 additions & 0 deletions test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@
from .image_utils import get_image, rgb_to_gray, rgb_to_yuv_ccir, save_image
from .parameterized_utils import load_params, nested_params
from .wav_utils import get_wav_data, load_wav, normalize_wav, save_wav
import pytest

class RequestMixin:
    """
    Mixin that stores a string identifying the running test on `self.request`.

    The identifier is the pytest node id with every ":" replaced by "_" and the
    "_cpu_" / "_cuda_" fragments collapsed to "_". For a node id such as
    test/torchaudio_unittest/functional/librosa_compatibility_cpu_test.py::TestFunctionalCPU::test_create_mel_fb_13
    it looks like:
    test/torchaudio_unittest/functional/librosa_compatibility_test.py__TestFunctionalCPU__test_create_mel_fb_13
    """

    @pytest.fixture(autouse=True)
    def inject_request(self, request):
        identifier = request.node.nodeid
        # Apply the substitutions in this exact order: the ":" replacement runs first,
        # so underscores it creates can take part in the "_cpu_" / "_cuda_" matches.
        for fragment, substitute in ((":", "_"), ("_cpu_", "_"), ("_cuda_", "_")):
            identifier = identifier.replace(fragment, substitute)
        self.request = identifier

__all__ = [
"get_asset_path",
Expand All @@ -40,6 +52,7 @@
"HttpServerMixin",
"TestBaseMixin",
"PytorchTestCase",
"RequestMixin",
"TorchaudioTestCase",
"skipIfNoAudioDevice",
"skipIfNoCtcDecoder",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@
from parameterized import param
from torchaudio._internal.module_utils import is_module_available

LIBROSA_AVAILABLE = is_module_available("librosa")
import librosa_mock
import numpy as np
import pytest

if LIBROSA_AVAILABLE:
import librosa
import numpy as np

from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, TestBaseMixin, RequestMixin

from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, TestBaseMixin


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class Functional(TestBaseMixin):
class Functional(TestBaseMixin, RequestMixin):
"""Test suite for functions in `functional` module."""

dtype = torch.float64
Expand Down Expand Up @@ -50,7 +46,8 @@ def test_griffinlim(self, momentum):
length=waveform.size(1),
rand_init=False,
)
expected = librosa.griffinlim(
expected = librosa_mock.griffinlim(
self.request,
specgram[0].cpu().numpy(),
n_iter=n_iter,
hop_length=hop_length,
Expand All @@ -77,12 +74,11 @@ def test_griffinlim(self, momentum):
def test_create_mel_fb(
self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"
):
if norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
self.skipTest("Test is known to fail with older versions of librosa.")
if self.device != "cpu":
self.skipTest("No need to run this test on CUDA")

expected = librosa.filters.mel(
expected = librosa_mock.mel(
self.request,
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmax=fmax, fmin=fmin, htk=mel_scale == "htk", norm=norm
).T
result = F.melscale_fbanks(
Expand All @@ -104,7 +100,7 @@ def test_amplitude_to_DB_power(self):

spec = get_spectrogram(get_whitenoise(device=self.device, dtype=self.dtype), power=2)
result = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
expected = librosa.core.power_to_db(spec[0].cpu().numpy())[None, ...]
expected = librosa_mock.power_to_db(self.request, spec[0].cpu().numpy())[None, ...]
self.assertEqual(result, torch.from_numpy(expected))

def test_amplitude_to_DB(self):
Expand All @@ -115,28 +111,28 @@ def test_amplitude_to_DB(self):

spec = get_spectrogram(get_whitenoise(device=self.device, dtype=self.dtype), power=1)
result = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
expected = librosa.core.amplitude_to_db(spec[0].cpu().numpy())[None, ...]
expected = librosa_mock.amplitude_to_db(self.request, spec[0].cpu().numpy())[None, ...]
self.assertEqual(result, torch.from_numpy(expected))


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class FunctionalComplex(TestBaseMixin):
class FunctionalComplex(TestBaseMixin, RequestMixin):
@nested_params([0.5, 1.01, 1.3])
def test_phase_vocoder(self, rate):
torch.manual_seed(0)
hop_length = 256
num_freq = 1025
num_frames = 400

# Due to the cumulative sum, numerical error when using torch.float32 causes
# the bottom-right values of the stretched spectrogram to not
# match librosa.
spec = torch.randn(num_freq, num_frames, device=self.device, dtype=torch.complex128)
spec = torch.randn(num_freq, num_frames, dtype=torch.complex128).to(self.device,)
phase_advance = torch.linspace(0, np.pi * hop_length, num_freq, device=self.device, dtype=torch.float64)[
..., None
]

stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

expected_stretched = librosa.phase_vocoder(spec.cpu().numpy(), rate=rate, hop_length=hop_length)
expected_stretched = librosa_mock.phase_vocoder(self.request, spec.cpu().numpy(), rate=rate, hop_length=hop_length)

self.assertEqual(stretched, torch.from_numpy(expected_stretched))
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,11 @@
import torchaudio.transforms as T
from parameterized import param, parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import get_sinusoid, get_spectrogram, get_whitenoise, nested_params, TestBaseMixin
from torchaudio_unittest.common_utils import get_sinusoid, get_spectrogram, get_whitenoise, nested_params, TestBaseMixin, RequestMixin
import librosa_mock
import pytest

LIBROSA_AVAILABLE = is_module_available("librosa")

if LIBROSA_AVAILABLE:
import librosa


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TransformsTestBase(TestBaseMixin):
class TransformsTestBase(TestBaseMixin, RequestMixin):
@parameterized.expand(
[
param(n_fft=400, hop_length=200, power=2.0),
Expand All @@ -29,7 +24,8 @@ def test_Spectrogram(self, n_fft, hop_length, power):
n_channels=1,
).to(self.device, self.dtype)

expected = librosa.core.spectrum._spectrogram(
expected = librosa_mock.spectrogram(
self.request,
y=waveform[0].cpu().numpy(), n_fft=n_fft, hop_length=hop_length, power=power, pad_mode="reflect"
)[0]

Expand All @@ -47,7 +43,8 @@ def test_Spectrogram_complex(self):
n_channels=1,
).to(self.device, self.dtype)

expected = librosa.core.spectrum._spectrogram(
expected = librosa_mock.spectrogram(
self.request,
y=waveform[0].cpu().numpy(), n_fft=n_fft, hop_length=hop_length, power=1, pad_mode="reflect"
)[0]

Expand All @@ -72,7 +69,8 @@ def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
n_channels=1,
).to(self.device, self.dtype)

expected = librosa.feature.melspectrogram(
expected = librosa_mock.mel_spectrogram(
self.request,
y=waveform[0].cpu().numpy(),
sr=sample_rate,
n_fft=n_fft,
Expand All @@ -96,14 +94,14 @@ def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
def test_magnitude_to_db(self):
spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
result = T.AmplitudeToDB("magnitude", 80.0).to(self.device, self.dtype)(spectrogram)[0]
expected = librosa.core.spectrum.amplitude_to_db(spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected))
expected = librosa_mock.amplitude_to_db(self.request, spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected), atol=1e-3, rtol=1e-3)

def test_power_to_db(self):
spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
result = T.AmplitudeToDB("power", 80.0).to(self.device, self.dtype)(spectrogram)[0]
expected = librosa.core.spectrum.power_to_db(spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected))
expected = librosa_mock.power_to_db(self.request, spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected), atol=1e-3, rtol=1e-3)

@nested_params(
[
Expand All @@ -122,19 +120,14 @@ def test_mfcc(self, n_fft, hop_length, n_mels, n_mfcc):
melkwargs={"hop_length": hop_length, "n_fft": n_fft, "n_mels": n_mels},
).to(self.device, self.dtype)(waveform)[0]

melspec = librosa.feature.melspectrogram(
y=waveform[0].cpu().numpy(),
sr=sample_rate,
n_fft=n_fft,
win_length=n_fft,
hop_length=hop_length,
n_mels=n_mels,
htk=True,
norm=None,
pad_mode="reflect",
)
expected = librosa.feature.mfcc(
S=librosa.core.spectrum.power_to_db(melspec), n_mfcc=n_mfcc, dct_type=2, norm="ortho"
expected = librosa_mock.mfcc_from_waveform(
f"{self.request}",
waveform,
sample_rate,
n_fft,
hop_length,
n_mels,
n_mfcc
)
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)

Expand All @@ -152,7 +145,8 @@ def test_spectral_centroid(self, n_fft, hop_length):
result = T.SpectralCentroid(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,).to(
self.device, self.dtype
)(waveform)
expected = librosa.feature.spectral_centroid(
expected = librosa_mock.spectral_centroid(
self.request,
y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length, pad_mode="reflect"
)
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
Loading