Skip to content

Commit 2264907

Browse files
committed
Use stable tensors in overdrive
1 parent 0b6ff52 commit 2264907

File tree

1 file changed

+76
-30
lines changed

1 file changed

+76
-30
lines changed

src/libtorchaudio/overdrive.cpp

Lines changed: 76 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,98 @@
11
#include <torch/script.h>
22
#include <torch/torch.h>
3+
#include <torch/csrc/stable/library.h>
4+
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/csrc/stable/ops.h>
6+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
7+
#include <torch/csrc/inductor/aoti_torch/utils.h>
8+
#include <libtorchaudio/accessor.h>
39

4-
namespace {
10+
using namespace std;
11+
12+
namespace torchaudio {
13+
14+
using torch::stable::Tensor;
515

616
template <typename scalar_t>
717
void overdrive_cpu_kernel(
8-
at::TensorAccessor<scalar_t, 2> waveform_accessor,
9-
at::TensorAccessor<scalar_t, 2> temp_accessor,
10-
at::TensorAccessor<scalar_t, 1> last_in_accessor,
11-
at::TensorAccessor<scalar_t, 1> last_out_accessor,
12-
at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
18+
Accessor<2, scalar_t> waveform_accessor,
19+
Accessor<2, scalar_t> temp_accessor,
20+
Accessor<1, scalar_t, false> last_in_accessor,
21+
Accessor<1, scalar_t> last_out_accessor,
22+
Accessor<2, scalar_t, false> output_waveform_accessor) {
1323
int64_t n_frames = waveform_accessor.size(1);
1424
int64_t n_channels = waveform_accessor.size(0);
1525

1626
at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
1727
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
1828
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
19-
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
20-
last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
21-
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
22-
output_waveform_accessor[i_channel][i_frame] =
23-
waveform_accessor[i_channel][i_frame] * 0.5 +
24-
last_out_accessor[i_channel] * 0.75;
29+
last_out_accessor.set_index(
30+
temp_accessor.index(i_channel, i_frame) -
31+
last_in_accessor.index(i_channel) + 0.995 * last_out_accessor.index(i_channel),
32+
i_channel);
33+
last_in_accessor.set_index(temp_accessor.index(i_channel, i_frame), i_channel);
34+
output_waveform_accessor.set_index(
35+
waveform_accessor.index(i_channel, i_frame) * 0.5 +
36+
last_out_accessor.index(i_channel) * 0.75,
37+
i_channel, i_frame);
2538
}
2639
}
2740
});
2841
}
2942

3043
void overdrive_core_loop_cpu(
31-
at::Tensor& waveform,
32-
at::Tensor& temp,
33-
at::Tensor& last_in,
34-
at::Tensor& last_out,
35-
at::Tensor& output_waveform) {
36-
AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
37-
overdrive_cpu_kernel<scalar_t>(
38-
waveform.accessor<scalar_t, 2>(),
39-
temp.accessor<scalar_t, 2>(),
40-
last_in.accessor<scalar_t, 1>(),
41-
last_out.accessor<scalar_t, 1>(),
42-
output_waveform.accessor<scalar_t, 2>());
43-
}));
44+
const Tensor waveform,
45+
const Tensor temp,
46+
Tensor last_in,
47+
const Tensor last_out,
48+
Tensor output_waveform) {
49+
int32_t dtype;
50+
aoti_torch_get_dtype(waveform.get(), &dtype);
51+
if (dtype == aoti_torch_dtype_float64()) {
52+
overdrive_cpu_kernel<double>(
53+
Accessor<2, double>(wave_acc),
54+
Accessor<2, double>(temp_acc),
55+
Accessor<1, double>(last_in_acc),
56+
Accessor<1, double>(last_out_acc),
57+
Accessor<2, double>(out_acc));
58+
} else if (dtype == aoti_torch_dtype_float32()) {
59+
overdrive_cpu_kernel<float>(
60+
Accessor<2, float>(wave_acc),
61+
Accessor<2, float>(temp_acc),
62+
Accessor<1, float>(last_in_acc),
63+
Accessor<1, float>(last_out_acc),
64+
Accessor<2, float>(out_acc));
65+
} else if (dtype == aoti_torch_dtype_float16()) {
66+
overdrive_cpu_kernel<c10::Half>(
67+
Accessor<2, c10::Half>(wave_acc),
68+
Accessor<2, c10::Half>(temp_acc),
69+
Accessor<1, c10::Half>(last_in_acc),
70+
Accessor<1, c10::Half>(last_out_acc),
71+
Accessor<2, c10::Half>(out_acc));
72+
}
4473
}
4574

46-
} // namespace
4775

48-
// Note: We want to avoid using "catch-all" kernel.
49-
// The following registration should be replaced with CPU specific registration.
50-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
51-
m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
76+
77+
void boxed_overdrive_core_loop(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
78+
Tensor t1(to<AtenTensorHandle>(stack[0]));
79+
Tensor t2(to<AtenTensorHandle>(stack[1]));
80+
Tensor t3(to<AtenTensorHandle>(stack[2]));
81+
Tensor t4(to<AtenTensorHandle>(stack[3]));
82+
Tensor t5(to<AtenTensorHandle>(stack[4]));
83+
overdrive_core_loop(
84+
std::move(t1), std::move(t2), std::move(t3), std::move(t4), std::move(t5));
5285
}
86+
87+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
88+
m.def(
89+
"overdrive_core_loop(Tensor waveform,"
90+
"Tensor temp, Tensor last_in, Tensor last_out,"
91+
"Tensor output_waveform)"
92+
}
93+
94+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
95+
m.impl("overdrive_core_loop", &overdrive_core_loop_cpu);
96+
}
97+
98+
} // namespace

0 commit comments

Comments
 (0)