Skip to content

Commit 9897a3a

Browse files
authored
Fix building for CUDA 12.4 and for torch>=2.4 (#1297)
1 parent 07c00d1 commit 9897a3a

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

k2/csrc/ragged_ops.cu

+9-3
Original file line numberDiff line numberDiff line change
@@ -2515,12 +2515,18 @@ struct HashOutputIteratorDeref { // this is what you get when you dereference
25152515

25162516
template <typename T>
25172517
struct HashOutputIterator { // outputs just the index of the pair.
2518-
explicit HashOutputIterator(T *t) : t_(t) {}
2519-
__device__ __forceinline__ HashOutputIteratorDeref<T> operator[](
2518+
explicit __host__ __device__ __forceinline__ HashOutputIterator(T *t)
2519+
: t_(t) {}
2520+
__host__ __device__ __forceinline__ HashOutputIteratorDeref<T> operator[](
25202521
int32_t idx) const {
25212522
return HashOutputIteratorDeref<T>(t_ + idx);
25222523
}
2523-
__device__ __forceinline__ HashOutputIterator operator+(size_t offset) {
2524+
__host__ __device__ __forceinline__ HashOutputIteratorDeref<T> operator*()
2525+
const {
2526+
return HashOutputIteratorDeref<T>(t_);
2527+
}
2528+
__host__ __device__ __forceinline__ HashOutputIterator
2529+
operator+(size_t offset) {
25242530
return HashOutputIterator{t_ + offset};
25252531
}
25262532
T *t_;

k2/csrc/ragged_ops_inl.h

+9-3
Original file line numberDiff line numberDiff line change
@@ -578,12 +578,18 @@ struct PairOutputIteratorDeref { // this is what you get when you dereference
578578

579579
template <typename T>
580580
struct PairOutputIterator { // outputs just the index of the pair.
581-
explicit PairOutputIterator(int32_t *i) : i_(i) {}
582-
__device__ __forceinline__ PairOutputIteratorDeref<T> operator[](
581+
explicit __host__ __device__ __forceinline__ PairOutputIterator(int32_t *i)
582+
: i_(i) {}
583+
__host__ __device__ __forceinline__ PairOutputIteratorDeref<T> operator[](
583584
int32_t idx) const {
584585
return PairOutputIteratorDeref<T>(i_ + idx);
585586
}
586-
__device__ __forceinline__ PairOutputIterator operator+(int32_t offset) {
587+
__host__ __device__ __forceinline__ PairOutputIteratorDeref<T> operator*()
588+
const {
589+
return PairOutputIteratorDeref<T>(i_);
590+
}
591+
__host__ __device__ __forceinline__ PairOutputIterator
592+
operator+(int32_t offset) {
587593
return PairOutputIterator{i_ + offset};
588594
}
589595
int32_t *i_;

k2/python/csrc/torch.h

+9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
#include "k2/python/csrc/torch.h"
3131
#include "torch/extension.h"
3232

33+
#if K2_TORCH_VERSION_MAJOR > 2 || \
34+
(K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 4)
35+
// For torch >= 2.4.x
36+
// do nothing to fix the following error
37+
// error: class "pybind11::detail::type_caster<c10::ScalarType, void>" has
38+
// already been defined
39+
#else
40+
// For torch < 2.4
3341
namespace pybind11 {
3442
namespace detail {
3543

@@ -71,6 +79,7 @@ struct type_caster<torch::ScalarType> {
7179

7280
} // namespace detail
7381
} // namespace pybind11
82+
#endif
7483

7584
namespace k2 {
7685
/* Transfer an object to a specific device.

0 commit comments

Comments
 (0)