From eec1d16627e8ccf445308fa91c90c9a7b5fb41fb Mon Sep 17 00:00:00 2001 From: DrStone71 Date: Sun, 24 Aug 2025 20:20:03 +0200 Subject: [PATCH 1/3] Compile CUDA 13 on Nvidia 5090 sm_120 CUDA 13 change library and clean code of redundancy. This fix compile on Nvidia 5090 with CUDA 13.0 --- src/libtorchaudio/forced_align/gpu/compute.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu index ef7d9acaae..3486d91226 100644 --- a/src/libtorchaudio/forced_align/gpu/compute.cu +++ b/src/libtorchaudio/forced_align/gpu/compute.cu @@ -1,3 +1,4 @@ +#include #include #include #include @@ -94,7 +95,7 @@ __global__ void falign_cuda_step_kernel( alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx]; threadMax = max(threadMax, alphas_a[curIdxOffset][i]); } - scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cub::Max()); + scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, thrust::maximum{}); if (threadIdx.x == 0) { maxValue = maxResult; } From b70881be500e2a2c3f0310ca5fb8e153e47e0ca3 Mon Sep 17 00:00:00 2001 From: DrStone71 Date: Tue, 26 Aug 2025 17:53:48 +0200 Subject: [PATCH 2/3] Update src/libtorchaudio/forced_align/gpu/compute.cu Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> --- src/libtorchaudio/forced_align/gpu/compute.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu index 3486d91226..b25449d4f2 100644 --- a/src/libtorchaudio/forced_align/gpu/compute.cu +++ b/src/libtorchaudio/forced_align/gpu/compute.cu @@ -95,7 +95,7 @@ __global__ void falign_cuda_step_kernel( alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx]; threadMax = max(threadMax, alphas_a[curIdxOffset][i]); } - scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, thrust::maximum{}); + scalar_t maxResult = BlockReduce(tempStorage).Reduce(threadMax, cuda::maximum{}); if (threadIdx.x == 0) { maxValue = maxResult; } From c1180095c606c62236d97d3b645e3fea6cecc534 Mon Sep 17 00:00:00 2001 From: DrStone71 Date: Tue, 26 Aug 2025 17:54:17 +0200 Subject: [PATCH 3/3] Update src/libtorchaudio/forced_align/gpu/compute.cu Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> --- src/libtorchaudio/forced_align/gpu/compute.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu index b25449d4f2..c7800d39e9 100644 --- a/src/libtorchaudio/forced_align/gpu/compute.cu +++ b/src/libtorchaudio/forced_align/gpu/compute.cu @@ -1,4 +1,4 @@ -#include +#include #include #include #include