Skip to content

Commit 29313d1

Browse files
committed
Use C dtype function
1 parent 2e3e418 commit 29313d1

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/libtorchaudio/overdrive.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,23 @@ void overdrive_core_loop_cpu(
4646
Tensor last_in,
4747
const Tensor last_out,
4848
Tensor output_waveform) {
49-
if (waveform.dtype() == aoti_torch_dtype_float64()) {
49+
int32_t dtype;
50+
aoti_torch_get_dtype(waveform.get(), &dtype);
51+
if (dtype == aoti_torch_dtype_float64()) {
5052
overdrive_cpu_kernel<double>(
5153
Accessor<2, double>(wave_acc),
5254
Accessor<2, double>(temp_acc),
5355
Accessor<1, double>(last_in_acc),
5456
Accessor<1, double>(last_out_acc),
5557
Accessor<2, double>(out_acc));
56-
} else if (waveform.dtype() == aoti_torch_dtype_float32()) {
58+
} else if (dtype == aoti_torch_dtype_float32()) {
5759
overdrive_cpu_kernel<float>(
5860
Accessor<2, float>(wave_acc),
5961
Accessor<2, float>(temp_acc),
6062
Accessor<1, float>(last_in_acc),
6163
Accessor<1, float>(last_out_acc),
6264
Accessor<2, float>(out_acc));
63-
} else if (waveform.dtype() == aoti_torch_dtype_float16()) {
65+
} else if (dtype == aoti_torch_dtype_float16()) {
6466
overdrive_cpu_kernel<c10::Half>(
6567
Accessor<2, c10::Half>(wave_acc),
6668
Accessor<2, c10::Half>(temp_acc),

0 commit comments

Comments
 (0)