diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt
index 85bc227cd6..20ad792b32 100644
--- a/src/libtorchaudio/CMakeLists.txt
+++ b/src/libtorchaudio/CMakeLists.txt
@@ -6,6 +6,7 @@ set(
   lfilter.cpp
   overdrive.cpp
   utils.cpp
+  accessor_tests.cpp
 )
 
 set(
diff --git a/src/libtorchaudio/accessor.h b/src/libtorchaudio/accessor.h
new file mode 100644
index 0000000000..6211acff3d
--- /dev/null
+++ b/src/libtorchaudio/accessor.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <cstdarg>
+#include <type_traits>
+#include <torch/csrc/stable/tensor.h>
+
+using torch::stable::Tensor;
+
+template <int k, typename T>
+class Accessor {
+  int64_t strides[k];
+  int64_t sizes[k];
+  T *data;
+
+public:
+  using tensor_type = typename std::conditional<std::is_const<T>::value, const Tensor, Tensor>::type;
+
+  Accessor(tensor_type tensor) {
+    auto raw_ptr = tensor.data_ptr();
+    data = static_cast<T*>(raw_ptr);
+    for (unsigned int i = 0; i < k; i++) {
+      strides[i] = tensor.stride(i);
+      sizes[i] = tensor.size(i);
+    }
+  }
+
+  T index(...) {
+    va_list args;
+    va_start(args, k);
+    int64_t ix = 0;
+    for (unsigned int i = 0; i < k; i++) {
+      ix += strides[i] * va_arg(args, int);
+    }
+    va_end(args);
+    return data[ix];
+  }
+
+  int64_t size(int dim) {
+    return sizes[dim];
+  }
+
+  template <typename U = T>
+  typename std::enable_if<!std::is_const<U>::value>::type set_index(T value, ...) {
+    va_list args;
+    va_start(args, value);
+    int64_t ix = 0;
+    for (unsigned int i = 0; i < k; i++) {
+      ix += strides[i] * va_arg(args, int);
+    }
+    va_end(args);
+    data[ix] = value;
+  }
+};
diff --git a/src/libtorchaudio/accessor_tests.cpp b/src/libtorchaudio/accessor_tests.cpp
new file mode 100644
index 0000000000..45312a8408
--- /dev/null
+++ b/src/libtorchaudio/accessor_tests.cpp
@@ -0,0 +1,44 @@
+#include <libtorchaudio/accessor.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/tensor.h>
+
+namespace torchaudio {
+
+namespace accessor_tests {
+
+using namespace std;
+using torch::stable::Tensor;
+
+bool test_accessor(const Tensor tensor) {
+  int64_t* data_ptr = (int64_t*)tensor.data_ptr();
+  auto accessor = Accessor<3, int64_t>(tensor);
+  for (unsigned int i = 0; i < tensor.size(0); i++) {
+    for (unsigned int j = 0; j < tensor.size(1); j++) {
+      for (unsigned int k = 0; k < tensor.size(2); k++) {
+        auto check = *(data_ptr++) == accessor.index(i, j, k);
+        if (!check) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  Tensor t1(to<AtenTensorHandle>(stack[0]));
+  auto result = test_accessor(std::move(t1));
+  stack[0] = from(result);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "_test_accessor(Tensor log_probs) -> bool");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_test_accessor", &boxed_test_accessor);
+}
+
+} // namespace accessor_tests
+} // namespace torchaudio
diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp
index 0ddd21b126..4b87b20e5f 100644
--- a/src/libtorchaudio/forced_align/cpu/compute.cpp
+++ b/src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -1,8 +1,5 @@
 #include 
 #include 
-#include 
-#include 
-#include 
 #include 
 
 using namespace std;
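As a reading aid, here is a minimal sketch of how the Accessor template from accessor.h is meant to be used inside a kernel. The function name double_in_place and the 2-D int64 layout are made up for illustration; only Accessor, torch::stable::Tensor, and headers the patch already relies on are assumed.

#include <libtorchaudio/accessor.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// Walks a 2-D int64 tensor through the accessor and doubles each element
// in place via set_index(value, i, j).
void double_in_place(Tensor t) {
  auto acc = Accessor<2, int64_t>(t);
  for (int i = 0; i < acc.size(0); i++) {
    for (int j = 0; j < acc.size(1); j++) {
      acc.set_index(2 * acc.index(i, j), i, j);
    }
  }
}

Note that index() and set_index() consume their indices through va_arg(args, int), so each of the k indices is read as a plain int.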
diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp
index 4a130f34d5..478581b64a 100644
--- a/src/libtorchaudio/lfilter.cpp
+++ b/src/libtorchaudio/lfilter.cpp
@@ -1,5 +1,10 @@
 #include 
 #include 
+#include <ATen/Parallel.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
 
 #ifdef USE_CUDA
 #include 
@@ -7,19 +12,24 @@
 
 namespace {
 
+using torch::stable::Tensor;
+
 template <typename scalar_t>
 void host_lfilter_core_loop(
-    const torch::Tensor& input_signal_windows,
-    const torch::Tensor& a_coeff_flipped,
-    torch::Tensor& padded_output_waveform) {
+    const Tensor input_signal_windows,
+    const Tensor a_coeff_flipped,
+    Tensor padded_output_waveform) {
   int64_t n_batch = input_signal_windows.size(0);
   int64_t n_channel = input_signal_windows.size(1);
   int64_t n_samples_input = input_signal_windows.size(2);
   int64_t n_samples_output = padded_output_waveform.size(2);
   int64_t n_order = a_coeff_flipped.size(1);
-  scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
-  const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
-  const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
+  scalar_t *output_data;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(padded_output_waveform.get(), (void**)&output_data));
+  scalar_t *input_data;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(input_signal_windows.get(), (void**)&input_data));
+  scalar_t *a_coeff_flipped_data;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(a_coeff_flipped.get(), (void**)&a_coeff_flipped_data));
 
   at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
     for (auto i = begin; i < end; i++) {
@@ -39,25 +49,32 @@ void host_lfilter_core_loop(
 }
 
 void cpu_lfilter_core_loop(
-    const torch::Tensor& input_signal_windows,
-    const torch::Tensor& a_coeff_flipped,
-    torch::Tensor& padded_output_waveform) {
+    const Tensor input_signal_windows,
+    const Tensor a_coeff_flipped,
+    Tensor padded_output_waveform) {
   TORCH_CHECK(
-      input_signal_windows.device().is_cpu() &&
-      a_coeff_flipped.device().is_cpu() &&
-      padded_output_waveform.device().is_cpu());
+      input_signal_windows.is_cpu() &&
+      a_coeff_flipped.is_cpu() &&
+      padded_output_waveform.is_cpu());
 
   TORCH_CHECK(
       input_signal_windows.is_contiguous() &&
       a_coeff_flipped.is_contiguous() &&
       padded_output_waveform.is_contiguous());
+  int32_t input_signal_windows_dtype;
+  int32_t a_coeff_flipped_dtype;
+  int32_t padded_output_waveform_dtype;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(input_signal_windows.get(), &input_signal_windows_dtype));
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(a_coeff_flipped.get(), &a_coeff_flipped_dtype));
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(padded_output_waveform.get(), &padded_output_waveform_dtype));
+
   TORCH_CHECK(
-      (input_signal_windows.dtype() == torch::kFloat32 ||
-       input_signal_windows.dtype() == torch::kFloat64) &&
-      (a_coeff_flipped.dtype() == torch::kFloat32 ||
-       a_coeff_flipped.dtype() == torch::kFloat64) &&
-      (padded_output_waveform.dtype() == torch::kFloat32 ||
-       padded_output_waveform.dtype() == torch::kFloat64));
+      (input_signal_windows_dtype == aoti_torch_dtype_float32() ||
+       input_signal_windows_dtype == aoti_torch_dtype_float64()) &&
+      (a_coeff_flipped_dtype == aoti_torch_dtype_float32() ||
+       a_coeff_flipped_dtype == aoti_torch_dtype_float64()) &&
+      (padded_output_waveform_dtype == aoti_torch_dtype_float32() ||
+       padded_output_waveform_dtype == aoti_torch_dtype_float64()));
 
   TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
   TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
@@ -66,51 +83,36 @@ void cpu_lfilter_core_loop(
       input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
       padded_output_waveform.size(2));
 
-  AT_DISPATCH_FLOATING_TYPES(
-      input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
-        host_lfilter_core_loop<scalar_t>(
+  if (input_signal_windows_dtype == aoti_torch_dtype_float32()) {
+    host_lfilter_core_loop<float>(
             input_signal_windows, a_coeff_flipped, padded_output_waveform);
-      });
-}
-
-void lfilter_core_generic_loop(
-    const torch::Tensor& input_signal_windows,
-    const torch::Tensor& a_coeff_flipped,
-    torch::Tensor& padded_output_waveform) {
-  int64_t n_samples_input = input_signal_windows.size(2);
-  int64_t n_order = a_coeff_flipped.size(1);
-  auto coeff = a_coeff_flipped.unsqueeze(2);
-  for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
-    auto windowed_output_signal =
-        torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order).transpose(0, 1);
-    auto o0 =
-        torch::select(input_signal_windows, 2, i_sample) -
-        at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
-    padded_output_waveform.index_put_(
-        {torch::indexing::Slice(),
-         torch::indexing::Slice(),
-         i_sample + n_order - 1},
-        o0);
+  } else if (input_signal_windows_dtype == aoti_torch_dtype_float64()) {
+    host_lfilter_core_loop<double>(
+        input_signal_windows, a_coeff_flipped, padded_output_waveform);
   }
 }
 
 } // namespace
 
-TORCH_LIBRARY(torchaudio, m) {
-  m.def(
-      "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
+void boxed_cpu_lfilter_core_loop(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  Tensor t1(to<AtenTensorHandle>(stack[0]));
+  Tensor t2(to<AtenTensorHandle>(stack[1]));
+  Tensor t3(to<AtenTensorHandle>(stack[2]));
+  cpu_lfilter_core_loop(
+      std::move(t1), std::move(t2), std::move(t3));
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
 }
 
-#ifdef USE_CUDA
-TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &boxed_cpu_lfilter_core_loop);
 }
-#endif
 
-TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
-  m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
-}
+// #ifdef USE_CUDA
+// STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
+//   m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
+// }
+// #endif
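The if/else on the dtype above replaces AT_DISPATCH_FLOATING_TYPES, which is not available through the stable ABI. If more kernels end up needing the same float32/float64 switch, it could be factored into a small helper. The sketch below is hypothetical (dispatch_floating is not part of this patch) and assumes the same aoti_torch_* C functions and the TORCH_ERROR_CODE_CHECK macro already used above.

#include <stdexcept>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// Calls f with a value whose type is the tensor's scalar type (float or
// double), mirroring what AT_DISPATCH_FLOATING_TYPES used to provide.
template <typename F>
void dispatch_floating(const Tensor& t, F&& f) {
  int32_t dtype;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(t.get(), &dtype));
  if (dtype == aoti_torch_dtype_float32()) {
    f(float{});
  } else if (dtype == aoti_torch_dtype_float64()) {
    f(double{});
  } else {
    throw std::runtime_error("lfilter: expected float32 or float64");
  }
}

cpu_lfilter_core_loop would then dispatch as:

  dispatch_floating(input_signal_windows, [&](auto tag) {
    using scalar_t = decltype(tag);
    host_lfilter_core_loop<scalar_t>(
        input_signal_windows, a_coeff_flipped, padded_output_waveform);
  });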
diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py
index 76deb04a96..5b9b3f04ba 100644
--- a/src/torchaudio/functional/filtering.py
+++ b/src/torchaudio/functional/filtering.py
@@ -933,7 +933,7 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
 
 
 if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop
+    _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop.default
 else:
     _lfilter_core_loop = _lfilter_core_generic_loop
 
diff --git a/test/torchaudio_unittest/accessor_test.py b/test/torchaudio_unittest/accessor_test.py
new file mode 100644
index 0000000000..db14258dc6
--- /dev/null
+++ b/test/torchaudio_unittest/accessor_test.py
@@ -0,0 +1,7 @@
+import torch
+from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
+
+if _IS_TORCHAUDIO_EXT_AVAILABLE:
+    def test_accessor():
+        tensor = torch.randint(1000, (5, 4, 3))
+        assert torch.ops.torchaudio._test_accessor(tensor)
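A natural follow-up, not part of this patch, is a write-path check that exercises set_index the way accessor_tests.cpp exercises index. A sketch that could be dropped into accessor_tests.cpp next to test_accessor, reusing only APIs that already appear there (the name test_accessor_write is hypothetical):

bool test_accessor_write(Tensor tensor) {
  int64_t* data_ptr = (int64_t*)tensor.data_ptr();
  auto accessor = Accessor<3, int64_t>(tensor);
  int64_t counter = 0;
  for (unsigned int i = 0; i < tensor.size(0); i++) {
    for (unsigned int j = 0; j < tensor.size(1); j++) {
      for (unsigned int k = 0; k < tensor.size(2); k++) {
        accessor.set_index(counter++, i, j, k);
      }
    }
  }
  // For a contiguous tensor, values written through the accessor must come
  // back in the same order through the raw data pointer.
  for (int64_t n = 0; n < counter; n++) {
    if (data_ptr[n] != n) {
      return false;
    }
  }
  return true;
}

It would be registered through the same boxed wrapper plus STABLE_TORCH_LIBRARY_FRAGMENT / STABLE_TORCH_LIBRARY_IMPL boilerplate used for _test_accessor, and driven from accessor_test.py with a second assert.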