@@ -46,21 +46,23 @@ void overdrive_core_loop_cpu(
46
46
Tensor last_in,
47
47
const Tensor last_out,
48
48
Tensor output_waveform) {
49
- if (waveform.dtype () == aoti_torch_dtype_float64 ()) {
49
+ int32_t dtype;
50
+ aoti_torch_get_dtype (waveform.get (), &dtype);
51
+ if (dtype == aoti_torch_dtype_float64 ()) {
50
52
overdrive_cpu_kernel<double >(
51
53
Accessor<2 , double >(wave_acc),
52
54
Accessor<2 , double >(temp_acc),
53
55
Accessor<1 , double >(last_in_acc),
54
56
Accessor<1 , double >(last_out_acc),
55
57
Accessor<2 , double >(out_acc));
56
- } else if (waveform. dtype () == aoti_torch_dtype_float32 ()) {
58
+ } else if (dtype == aoti_torch_dtype_float32 ()) {
57
59
overdrive_cpu_kernel<float >(
58
60
Accessor<2 , float >(wave_acc),
59
61
Accessor<2 , float >(temp_acc),
60
62
Accessor<1 , float >(last_in_acc),
61
63
Accessor<1 , float >(last_out_acc),
62
64
Accessor<2 , float >(out_acc));
63
- } else if (waveform. dtype () == aoti_torch_dtype_float16 ()) {
65
+ } else if (dtype == aoti_torch_dtype_float16 ()) {
64
66
overdrive_cpu_kernel<c10::Half>(
65
67
Accessor<2 , c10::Half>(wave_acc),
66
68
Accessor<2 , c10::Half>(temp_acc),
0 commit comments