Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(
lfilter.cpp
overdrive.cpp
utils.cpp
accessor_tests.cpp
)

set(
Expand Down
53 changes: 53 additions & 0 deletions src/libtorchaudio/accessor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include <torch/csrc/stable/tensor_struct.h>
#include <type_traits>
#include <cstdarg>

using torch::stable::Tensor;

// Lightweight strided view over the first `k` dimensions of a stable-ABI
// Tensor.
//
// Template parameters:
//   k       - number of dimensions that are indexed.
//   T       - element type the tensor's raw data pointer is reinterpreted as.
//             No dtype check is performed here; the caller must pick a T that
//             matches the tensor's actual element type.
//   IsConst - when true (default) the accessor is built from a `const Tensor&`
//             and `set_index` is not available; when false it is built from a
//             `Tensor&` and `set_index` is enabled.
//
// The accessor copies sizes and strides at construction time and keeps a raw
// pointer to the tensor's data; it does not own the tensor, so the tensor must
// outlive the accessor.
template<unsigned int k, typename T, bool IsConst = true>
class Accessor {
  int64_t strides[k];
  int64_t sizes[k];
  T *data;

  // Converts one index per dimension into a flat element offset using the
  // cached strides. The number of indices is checked at compile time, and each
  // index is widened to int64_t, so any integer type may be passed safely
  // (the previous C-varargs implementation read every index as `int`).
  template<typename... Indices>
  int64_t offset(Indices... indices) const {
    static_assert(
        sizeof...(Indices) == k,
        "Accessor: the number of indices must equal the number of dimensions");
    const int64_t idx[k] = {static_cast<int64_t>(indices)...};
    int64_t ix = 0;
    for (unsigned int i = 0; i < k; i++) {
      ix += strides[i] * idx[i];
    }
    return ix;
  }

public:
  using tensor_type = typename std::conditional<IsConst, const Tensor&, Tensor&>::type;

  // Captures the data pointer plus the size and stride of dimensions [0, k).
  Accessor(tensor_type tensor) {
    auto raw_ptr = tensor.data_ptr();
    data = static_cast<T*>(raw_ptr);
    for (unsigned int i = 0; i < k; i++) {
      strides[i] = tensor.stride(i);
      sizes[i] = tensor.size(i);
    }
  }

  // Returns (by value) the element at the given position. Exactly `k` integer
  // indices must be supplied, e.g. `acc.index(i, j, l)` for k == 3.
  // No bounds checking is performed.
  template<typename... Indices>
  T index(Indices... indices) const {
    return data[offset(indices...)];
  }

  // Returns the cached extent of dimension `dim`. `dim` is not range-checked;
  // it must lie in [0, k).
  int64_t size(int dim) const {
    return sizes[dim];
  }

  // Stores `value` at the given position. Only available when IsConst is
  // false. Exactly `k` integer indices must follow `value`, e.g.
  // `acc.set_index(v, i, j, l)` for k == 3. No bounds checking is performed.
  template<bool C = IsConst, typename... Indices>
  typename std::enable_if<!C, void>::type set_index(T value, Indices... indices) {
    data[offset(indices...)] = value;
  }
};
44 changes: 44 additions & 0 deletions src/libtorchaudio/accessor_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <libtorchaudio/accessor.h>
#include <cstdint>
#include <torch/csrc/stable/library.h>

namespace torchaudio {

namespace accessor_tests {

using namespace std;
using torch::stable::Tensor;

bool test_accessor(const Tensor tensor) {
int64_t* data_ptr = (int64_t*)tensor.data_ptr();
auto accessor = Accessor<3, int64_t>(tensor);
for (unsigned int i = 0; i < tensor.size(0); i++) {
for (unsigned int j = 0; j < tensor.size(1); j++) {
for (unsigned int k = 0; k < tensor.size(2); k++) {
auto check = *(data_ptr++) == accessor.index(i, j, k);
if (!check) {
return false;
}
}
}
}
return true;
}

// Boxed entry point for the `_test_accessor` operator: takes the tensor handle
// from stack[0], runs test_accessor on it, and overwrites stack[0] with the
// boolean result. `num_args` and `num_outputs` are not consulted.
void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor input(to<AtenTensorHandle>(stack[0]));
  bool passed = test_accessor(std::move(input));
  stack[0] = from(passed);
}

// Adds the `_test_accessor` operator schema (one Tensor in, bool out) to the
// torchaudio library.
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("_test_accessor(Tensor log_probs) -> bool");
}

// Binds the boxed accessor self-test as the CPU implementation of
// torchaudio::_test_accessor.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl(
      "torchaudio::_test_accessor",
      &boxed_test_accessor);
}

}
}
3 changes: 0 additions & 3 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

using namespace std;
Expand Down
110 changes: 56 additions & 54 deletions src/libtorchaudio/lfilter.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifdef USE_CUDA
#include <libtorchaudio/iir_cuda.h>
#endif

namespace {

using torch::stable::Tensor;

template <typename scalar_t>
void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
const Tensor input_signal_windows,
const Tensor a_coeff_flipped,
Tensor padded_output_waveform) {
int64_t n_batch = input_signal_windows.size(0);
int64_t n_channel = input_signal_windows.size(1);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_samples_output = padded_output_waveform.size(2);
int64_t n_order = a_coeff_flipped.size(1);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
scalar_t *output_data;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(padded_output_waveform.get(), (void**)&output_data));
scalar_t *input_data;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(input_signal_windows.get(), (void**)&input_data));
scalar_t *a_coeff_flipped_data;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(a_coeff_flipped.get(), (void**)&a_coeff_flipped_data));

at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; i++) {
Expand All @@ -39,25 +49,32 @@ void host_lfilter_core_loop(
}

void cpu_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
const Tensor input_signal_windows,
const Tensor a_coeff_flipped,
Tensor padded_output_waveform) {
TORCH_CHECK(
input_signal_windows.device().is_cpu() &&
a_coeff_flipped.device().is_cpu() &&
padded_output_waveform.device().is_cpu());
input_signal_windows.is_cpu() &&
a_coeff_flipped.is_cpu() &&
padded_output_waveform.is_cpu());

TORCH_CHECK(
input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
padded_output_waveform.is_contiguous());

int32_t input_signal_windows_dtype;
int32_t a_coeff_flipped_dtype;
int32_t padded_output_waveform_dtype;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(input_signal_windows.get(), &input_signal_windows_dtype));
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(a_coeff_flipped.get(), &a_coeff_flipped_dtype));
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(padded_output_waveform.get(), &padded_output_waveform_dtype));

TORCH_CHECK(
(input_signal_windows.dtype() == torch::kFloat32 ||
input_signal_windows.dtype() == torch::kFloat64) &&
(a_coeff_flipped.dtype() == torch::kFloat32 ||
a_coeff_flipped.dtype() == torch::kFloat64) &&
(padded_output_waveform.dtype() == torch::kFloat32 ||
padded_output_waveform.dtype() == torch::kFloat64));
(input_signal_windows_dtype == aoti_torch_dtype_float32() ||
input_signal_windows_dtype == aoti_torch_dtype_float64()) &&
(a_coeff_flipped_dtype == aoti_torch_dtype_float32() ||
a_coeff_flipped_dtype == aoti_torch_dtype_float64()) &&
(padded_output_waveform_dtype == aoti_torch_dtype_float32() ||
padded_output_waveform_dtype == aoti_torch_dtype_float64()));

TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
Expand All @@ -66,51 +83,36 @@ void cpu_lfilter_core_loop(
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
padded_output_waveform.size(2));

AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
host_lfilter_core_loop<scalar_t>(
if (input_signal_windows_dtype == aoti_torch_dtype_float32()) {
host_lfilter_core_loop<float>(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
});
}

void lfilter_core_generic_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(1);
auto coeff = a_coeff_flipped.unsqueeze(2);
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
auto windowed_output_signal =
torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order).transpose(0, 1);
auto o0 =
torch::select(input_signal_windows, 2, i_sample) -
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
padded_output_waveform.index_put_(
{torch::indexing::Slice(),
torch::indexing::Slice(),
i_sample + n_order - 1},
o0);
} else if (input_signal_windows_dtype == aoti_torch_dtype_float64()) {
host_lfilter_core_loop<double>(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
}
}

} // namespace

TORCH_LIBRARY(torchaudio, m) {
m.def(
"torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
// Boxed wrapper around cpu_lfilter_core_loop: reads three tensor handles from
// stack[0], stack[1] and stack[2], wraps each in a Tensor, and forwards them in
// that order. Nothing is written back to the stack. `num_args` and
// `num_outputs` are not consulted.
void boxed_cpu_lfilter_core_loop(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor t1(to<AtenTensorHandle>(stack[0]));
Tensor t2(to<AtenTensorHandle>(stack[1]));
Tensor t3(to<AtenTensorHandle>(stack[2]));
cpu_lfilter_core_loop(
std::move(t1), std::move(t2), std::move(t3));
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
// Declares the schema of torchaudio::_lfilter_core_loop: three Tensor
// arguments, the third annotated `Tensor(a!)`, and an empty return `()`.
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
}

#ifdef USE_CUDA
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
// Registers boxed_cpu_lfilter_core_loop as the CPU implementation of
// torchaudio::_lfilter_core_loop.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("torchaudio::_lfilter_core_loop", &boxed_cpu_lfilter_core_loop);
}
#endif

TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
}
// #ifdef USE_CUDA
// STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
// m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
// }
// #endif
2 changes: 1 addition & 1 deletion src/torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T


if _IS_TORCHAUDIO_EXT_AVAILABLE:
_lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop
_lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop.default
else:
_lfilter_core_loop = _lfilter_core_generic_loop

Expand Down
7 changes: 7 additions & 0 deletions test/torchaudio_unittest/accessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch
from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE

if _IS_TORCHAUDIO_EXT_AVAILABLE:

    def test_accessor():
        """Run the extension's `_test_accessor` op on a random 5x4x3 integer tensor."""
        shape = (5, 4, 3)
        values = torch.randint(1000, shape)
        assert torch.ops.torchaudio._test_accessor(values)
Loading