diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh index f311c8370e..c1dfc68f1f 100755 --- a/.github/scripts/unittest-linux/run_test.sh +++ b/.github/scripts/unittest-linux/run_test.sh @@ -30,5 +30,5 @@ fi ( cd test - pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs" + pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript_consistency" ) diff --git a/.github/workflows/unittest-linux-gpu.yml b/.github/workflows/unittest-linux-gpu.yml index 98b5147cff..175e0f6ab8 100644 --- a/.github/workflows/unittest-linux-gpu.yml +++ b/.github/workflows/unittest-linux-gpu.yml @@ -1,123 +1,117 @@ -# name: Unit-tests on Linux GPU +name: Unit-tests on Linux GPU -# on: -# pull_request: -# push: -# branches: -# - nightly -# - main -# - release/* -# workflow_dispatch: +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: -# jobs: -# tests: -# strategy: -# matrix: -# # TODO add up to 3.13 -# python_version: ["3.9", "3.10"] -# cuda_arch_version: ["12.6"] -# fail-fast: false -# uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main -# permissions: -# id-token: write -# contents: read -# with: -# runner: linux.g5.4xlarge.nvidia.gpu -# repository: pytorch/audio -# gpu-arch-type: cuda -# gpu-arch-version: ${{ matrix.cuda_arch_version }} -# timeout: 120 +jobs: + tests: + strategy: + matrix: + # TODO add up to 3.13 + python_version: ["3.9", "3.10"] + cuda_arch_version: ["12.6"] + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: linux.g5.4xlarge.nvidia.gpu + repository: pytorch/audio + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 120 -# script: | -# set -ex -# # Set up Environment Variables -# export PYTHON_VERSION="${{ matrix.python_version }}" -# export CU_VERSION="${{ matrix.cuda_arch_version }}" -# export CUDATOOLKIT="pytorch-cuda=${CU_VERSION}" -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_APPLY_CMVN_SLIDING=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_FBANK_FEATS=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_KALDI_PITCH_FEATS=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_MFCC_FEATS=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_SPECTROGRAM_FEATS=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_CUDA_SMALL_MEMORY=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_ON_PYTHON_310=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_TEMPORARY_DISABLED=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_SOX_DECODER=true -# export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_SOX_ENCODER=true + script: | + set -ex + # Set up Environment Variables + export PYTHON_VERSION="${{ matrix.python_version }}" + export CU_VERSION="${{ matrix.cuda_arch_version }}" + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_APPLY_CMVN_SLIDING=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_FBANK_FEATS=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_KALDI_PITCH_FEATS=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_MFCC_FEATS=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_CMD_COMPUTE_SPECTROGRAM_FEATS=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_CUDA_SMALL_MEMORY=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_ON_PYTHON_310=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_TEMPORARY_DISABLED=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_SOX_DECODER=true + export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_SOX_ENCODER=true -# # Set CHANNEL -# if [[(${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then -# export CHANNEL=test -# else -# export CHANNEL=nightly -# fi + # Set CHANNEL + if [[(${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then + export CHANNEL=test + else + export CHANNEL=nightly + fi -# echo "::group::Create conda env" -# # Mark Build Directory Safe -# git config --global --add safe.directory /__w/audio/audio -# conda create --quiet -y --prefix ci_env python="${PYTHON_VERSION}" -# conda activate ./ci_env + echo "::group::Create conda env" + # Mark Build Directory Safe + git config --global --add safe.directory /__w/audio/audio + conda create --quiet -y --prefix ci_env python="${PYTHON_VERSION}" + conda activate ./ci_env -# echo "::endgroup::" -# echo "::group::Install PyTorch" -# conda install \ -# --yes \ -# --quiet \ -# -c "pytorch-${CHANNEL}" \ -# -c nvidia "pytorch-${CHANNEL}"::pytorch[build="*${CU_VERSION}*"] \ -# "${CUDATOOLKIT}" + echo "::endgroup::" + echo "::group::Install Pytorch" + pip3 install --pre torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/${CHANNEL}/cu128" -# echo "::endgroup::" -# echo "::group::Install TorchAudio" -# conda install --quiet --yes 'cmake>=3.18.0' ninja -# pip3 install --progress-bar off -v -e . --no-use-pep517 + echo "::endgroup::" + echo "::group::Install TorchAudio" + conda install --quiet --yes 'cmake>=3.18.0' ninja + pip3 install --progress-bar off -v -e . --no-use-pep517 -# echo "::endgroup::" -# echo "::group::Build FFmpeg" -# .github/scripts/ffmpeg/build_gpu.sh + echo "::endgroup::" + echo "::group::Build FFmpeg" + .github/scripts/ffmpeg/build_gpu.sh -# echo "::endgroup::" -# echo "::group::Install other Dependencies" -# conda install \ -# --quiet --yes \ -# -c conda-forge \ -# -c numba/label/dev \ -# sox libvorbis 'librosa==0.10.0' parameterized 'requests>=2.20' -# pip3 install --progress-bar off \ -# kaldi-io \ -# SoundFile \ -# coverage \ -# pytest \ -# pytest-cov \ -# 'scipy==1.7.3' \ -# transformers \ -# expecttest \ -# unidecode \ -# inflect \ -# Pillow \ -# sentencepiece \ -# pytorch-lightning \ -# 'protobuf<4.21.0' \ -# demucs \ -# tinytag \ -# flashlight-text \ -# git+https://github.com/kpu/kenlm/ \ -# git+https://github.com/pytorch/fairseq.git@e47a4c8 + echo "::endgroup::" + echo "::group::Install other Dependencies" + conda install \ + --quiet --yes \ + -c conda-forge \ + -c numba/label/dev \ + sox libvorbis 'librosa==0.10.0' parameterized 'requests>=2.20' + pip3 install --progress-bar off \ + kaldi-io \ + SoundFile \ + coverage \ + pytest \ + pytest-cov \ + 'scipy==1.7.3' \ + transformers \ + expecttest \ + unidecode \ + inflect \ + Pillow \ + sentencepiece \ + pytorch-lightning \ + 'protobuf<4.21.0' \ + demucs \ + tinytag \ + flashlight-text \ + git+https://github.com/kpu/kenlm/ \ + git+https://github.com/pytorch/fairseq.git@e47a4c8 -# echo "::endgroup::" -# echo "::group::Run tests" -# export PATH="${PWD}/third_party/install/bin/:${PATH}" + echo "::endgroup::" + echo "::group::Run tests" + export PATH="${PWD}/third_party/install/bin/:${PATH}" -# declare -a args=( -# '-v' -# '--cov=torchaudio' -# "--junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml" -# '--durations' '100' -# '-k' 'cuda or gpu' -# ) + declare -a args=( + '-v' + '--cov=torchaudio' + "--junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml" + '--durations' '100' + '-k' 'cuda or gpu' + ) -# cd test -# python3 -m torch.utils.collect_env -# env | grep TORCHAUDIO || true -# pytest "${args[@]}" torchaudio_unittest -# coverage html + cd test + python3 -m torch.utils.collect_env + env | grep TORCHAUDIO || true + pytest "${args[@]}" torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript_consistency" + coverage html diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp index 454b2cbcda..9224eade9b 100644 --- a/src/libtorchaudio/lfilter.cpp +++ b/src/libtorchaudio/lfilter.cpp @@ -100,194 +100,23 @@ void lfilter_core_generic_loop( } } -class DifferentiableIIR : public torch::autograd::Function { - public: - static torch::Tensor forward( - torch::autograd::AutogradContext* ctx, - const torch::Tensor& waveform, - const torch::Tensor& a_coeffs_normalized) { - auto device = waveform.device(); - auto dtype = waveform.dtype(); - int64_t n_batch = waveform.size(0); - int64_t n_channel = waveform.size(1); - int64_t n_sample = waveform.size(2); - int64_t n_order = a_coeffs_normalized.size(1); - int64_t n_sample_padded = n_sample + n_order - 1; - - auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous(); - - auto options = torch::TensorOptions().dtype(dtype).device(device); - auto padded_output_waveform = - torch::zeros({n_batch, n_channel, n_sample_padded}, options); - - if (device.is_cpu()) { - cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform); - } else if (device.is_cuda()) { -#ifdef USE_CUDA - cuda_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform); -#else - lfilter_core_generic_loop( - waveform, a_coeff_flipped, padded_output_waveform); -#endif - } else { - lfilter_core_generic_loop( - waveform, a_coeff_flipped, padded_output_waveform); - } - - auto output = padded_output_waveform.index( - {torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(n_order - 1, torch::indexing::None)}); - - ctx->save_for_backward({waveform, a_coeffs_normalized, output}); - return output; - } - - static torch::autograd::tensor_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::tensor_list grad_outputs) { - auto saved = ctx->get_saved_variables(); - auto x = saved[0]; - auto a_coeffs_normalized = saved[1]; - auto y = saved[2]; - - int64_t n_channel = x.size(1); - int64_t n_order = a_coeffs_normalized.size(1); - - auto dx = torch::Tensor(); - auto da = torch::Tensor(); - auto dy = grad_outputs[0]; - - namespace F = torch::nn::functional; - - auto tmp = - DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized) - .flip(2); - - if (x.requires_grad()) { - dx = tmp; - } - - if (a_coeffs_normalized.requires_grad()) { - da = -torch::matmul( - tmp.transpose(0, 1).reshape({n_channel, 1, -1}), - F::pad(y, F::PadFuncOptions({n_order - 1, 0})) - .unfold(2, n_order, 1) - .transpose(0, 1) - .reshape({n_channel, -1, n_order})) - .squeeze(1) - .flip(1); - } - return {dx, da}; - } -}; - -class DifferentiableFIR : public torch::autograd::Function { - public: - static torch::Tensor forward( - torch::autograd::AutogradContext* ctx, - const torch::Tensor& waveform, - const torch::Tensor& b_coeffs) { - int64_t n_order = b_coeffs.size(1); - int64_t n_channel = b_coeffs.size(0); - - namespace F = torch::nn::functional; - auto b_coeff_flipped = b_coeffs.flip(1).contiguous(); - auto padded_waveform = - F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); - - auto output = F::conv1d( - padded_waveform, - b_coeff_flipped.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)); - - ctx->save_for_backward({waveform, b_coeffs, output}); - return output; - } - - static torch::autograd::tensor_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::tensor_list grad_outputs) { - auto saved = ctx->get_saved_variables(); - auto x = saved[0]; - auto b_coeffs = saved[1]; - auto y = saved[2]; - - int64_t n_batch = x.size(0); - int64_t n_channel = x.size(1); - int64_t n_order = b_coeffs.size(1); - - auto dx = torch::Tensor(); - auto db = torch::Tensor(); - auto dy = grad_outputs[0]; - - namespace F = torch::nn::functional; - - if (b_coeffs.requires_grad()) { - db = F::conv1d( - F::pad(x, F::PadFuncOptions({n_order - 1, 0})) - .view({1, n_batch * n_channel, -1}), - dy.view({n_batch * n_channel, 1, -1}), - F::Conv1dFuncOptions().groups(n_batch * n_channel)) - .view({n_batch, n_channel, -1}) - .sum(0) - .flip(1); - } - - if (x.requires_grad()) { - dx = F::conv1d( - F::pad(dy, F::PadFuncOptions({0, n_order - 1})), - b_coeffs.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)); - } - - return {dx, db}; - } -}; - -torch::Tensor lfilter_core( - const torch::Tensor& waveform, - const torch::Tensor& a_coeffs, - const torch::Tensor& b_coeffs) { - TORCH_CHECK(waveform.device() == a_coeffs.device()); - TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); - TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes()); - - TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3); - TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2); - TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1)); - - int64_t n_order = b_coeffs.size(1); - - TORCH_INTERNAL_ASSERT(n_order > 0); - - auto filtered_waveform = DifferentiableFIR::apply( - waveform, - b_coeffs / - a_coeffs.index( - {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); +} // namespace - auto output = DifferentiableIIR::apply( - filtered_waveform, - a_coeffs / - a_coeffs.index( - {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); - return output; +TORCH_LIBRARY(torchaudio, m) { + m.def( + "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()"); } -} // namespace - -// Note: We want to avoid using "catch-all" kernel. -// The following registration should be replaced with CPU specific registration. -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop); +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop); } -TORCH_LIBRARY(torchaudio, m) { - m.def( - "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"); +#ifdef USE_CUDA +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop); } +#endif -TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) { - m.impl("torchaudio::_lfilter", lfilter_core); +TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { + m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop); } diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 541c56c475..76deb04a96 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -4,6 +4,7 @@ import torch from torch import Tensor +import torch.nn.functional as F from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE @@ -932,70 +933,74 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T if _IS_TORCHAUDIO_EXT_AVAILABLE: - _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop + _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop else: - _lfilter_core_cpu_loop = _lfilter_core_generic_loop - - -def _lfilter_core( - waveform: Tensor, - a_coeffs: Tensor, - b_coeffs: Tensor, -) -> Tensor: - - if a_coeffs.size() != b_coeffs.size(): - raise ValueError( - "Expected coeffs to be the same size." - f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}" - ) - if waveform.ndim != 3: - raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}") - if not (waveform.device == a_coeffs.device == b_coeffs.device): - raise ValueError( - "Expected waveform and coeffs to be on the same device." - f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, " - f"b_coeffs device: {b_coeffs.device}" - ) - - n_batch, n_channel, n_sample = waveform.size() - n_order = a_coeffs.size(1) - if n_order <= 0: - raise ValueError(f"Expected n_order to be positive. Found: {n_order}") - - # Pad the input and create output - - padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0]) - padded_output_waveform = torch.zeros_like(padded_waveform) - - # Set up the coefficients matrix - # Flip coefficients' order - a_coeffs_flipped = a_coeffs.flip(1) - b_coeffs_flipped = b_coeffs.flip(1) - - # calculate windowed_input_signal in parallel using convolution - input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel) - - input_signal_windows.div_(a_coeffs[:, :1]) - a_coeffs_flipped.div_(a_coeffs[:, :1]) - - if ( - input_signal_windows.device == torch.device("cpu") - and a_coeffs_flipped.device == torch.device("cpu") - and padded_output_waveform.device == torch.device("cpu") - ): - _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) - else: - _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) - - output = padded_output_waveform[:, :, n_order - 1 :] - return output - - -if _IS_TORCHAUDIO_EXT_AVAILABLE: - _lfilter = torch.ops.torchaudio._lfilter -else: - _lfilter = _lfilter_core - + _lfilter_core_loop = _lfilter_core_generic_loop + + +class DifferentiableFIR(torch.autograd.Function): + @staticmethod + def forward(ctx, waveform, b_coeffs): + n_order = b_coeffs.size(1) + n_channel = b_coeffs.size(0) + b_coeff_flipped = b_coeffs.flip(1).contiguous() + padded_waveform = F.pad(waveform, (n_order - 1, 0)) + output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel) + ctx.save_for_backward(waveform, b_coeffs, output) + return output + + @staticmethod + def backward(ctx, dy): + x, b_coeffs, y = ctx.saved_tensors + n_batch = x.size(0) + n_channel = x.size(1) + n_order = b_coeffs.size(1) + db = F.conv1d( + F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1), + dy.view(n_batch * n_channel, 1, -1), + groups=n_batch * n_channel + ).view( + n_batch, n_channel, -1 + ).sum(0).flip(1) if b_coeffs.requires_grad else None + dx = F.conv1d( + F.pad(dy, (0, n_order - 1)), + b_coeffs.unsqueeze(1), + groups=n_channel + ) if x.requires_grad else None + return (dx, db) + +class DifferentiableIIR(torch.autograd.Function): + @staticmethod + def forward(ctx, waveform, a_coeffs_normalized): + n_batch, n_channel, n_sample = waveform.shape + n_order = a_coeffs_normalized.size(1) + n_sample_padded = n_sample + n_order - 1 + + a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous(); + padded_output_waveform = torch.zeros(n_batch, n_channel, n_sample_padded, + device=waveform.device, dtype=waveform.dtype) + _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform) + output = padded_output_waveform[:,:,n_order - 1:] + ctx.save_for_backward(waveform, a_coeffs_normalized, output) + return output + + @staticmethod + def backward(ctx, dy): + x, a_coeffs_normalized, y = ctx.saved_tensors + n_channel = x.size(1) + n_order = a_coeffs_normalized.size(1) + tmp = DifferentiableIIR.apply(dy.flip(2).contiguous(), a_coeffs_normalized).flip(2) + dx = tmp if x.requires_grad else None + da = -(tmp.transpose(0, 1).reshape(n_channel, 1, -1) @ + F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0,1) + .reshape(n_channel, -1, n_order) + ).squeeze(1).flip(1) if a_coeffs_normalized.requires_grad else None + return (dx, da) + +def _lfilter(waveform, a_coeffs, b_coeffs): + n_order = b_coeffs.size(1) + filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1]) + return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1]) def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor: r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation @@ -1066,7 +1071,6 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = return output - def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor: r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.