diff --git a/detrex/layers/csrc/DCNv3/dcnv3_cuda.cu b/detrex/layers/csrc/DCNv3/dcnv3_cuda.cu
index 037de103..5e4c08e9 100644
--- a/detrex/layers/csrc/DCNv3/dcnv3_cuda.cu
+++ b/detrex/layers/csrc/DCNv3/dcnv3_cuda.cu
@@ -67,7 +67,7 @@ at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
         auto columns = output_n.select(0, n);
         // AT_DISPATCH_FLOATING_TYPES(
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            input.type(), "ms_deform_attn_forward_cuda", ([&] {
+            input.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
                 dcnv3_im2col_cuda(
                     at::cuda::getCurrentCUDAStream(),
                     input.data<scalar_t>() + n * im2col_step_ * per_input_size,
@@ -145,7 +145,7 @@ dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
         auto grad_output_g = grad_output_n.select(0, n);
         // AT_DISPATCH_FLOATING_TYPES(
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            input.type(), "ms_deform_attn_backward_cuda", ([&] {
+            input.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
                 dcnv3_col2im_cuda(
                     at::cuda::getCurrentCUDAStream(),
                     grad_output_g.data<scalar_t>(),
diff --git a/detrex/layers/csrc/MsDeformAttn/ms_deform_attn_cuda.cu b/detrex/layers/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
index 554e4938..6fda23df 100644
--- a/detrex/layers/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
+++ b/detrex/layers/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
@@ -32,11 +32,11 @@ at::Tensor ms_deform_attn_cuda_forward(
     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(value.device().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.device().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.device().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.device().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.device().is_cuda(), "attn_weight must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -62,15 +62,15 @@ at::Tensor ms_deform_attn_cuda_forward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                columns.data<scalar_t>());
+                columns.data_ptr<scalar_t>());
         }));
     }
 
@@ -98,12 +98,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
-    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+    AT_ASSERTM(value.device().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.device().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.device().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.device().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.device().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(grad_output.device().is_cuda(), "grad_output must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -132,18 +132,18 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     for (int n = 0; n < batch/im2col_step_; ++n)
    {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                grad_output_g.data<scalar_t>(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                grad_output_g.data_ptr<scalar_t>(),
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
         }));
     }
 
diff --git a/detrex/layers/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh b/detrex/layers/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
index 6bc2acb7..fa0e18c3 100644
--- a/detrex/layers/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
+++ b/detrex/layers/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
@@ -266,7 +266,7 @@ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
     int data_weight_ptr = sampling_index * num_levels * num_point;
     int data_loc_w_ptr = data_weight_ptr << 1;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
     scalar_t col = 0;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
@@ -342,7 +342,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(co
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
     {
@@ -447,7 +447,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(co
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
     {
@@ -555,7 +555,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
     {
@@ -660,7 +660,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
     {
@@ -773,7 +773,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
     {
@@ -883,7 +883,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
     const int grad_weight_stride = 1;
     const int grad_loc_stride = 2;
     const int qid_stride = num_heads * channels;
-    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride + q_col * qid_stride;
 
     for (int l_col=0; l_col < num_levels; ++l_col)
     {
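The changes in dcnv3_cuda.cu and ms_deform_attn_cuda.cu move off ATen accessors that are deprecated in recent PyTorch releases: Tensor::type() is replaced by Tensor::scalar_type() as the dispatch argument, Tensor::type().is_cuda() by Tensor::device().is_cuda() in the device asserts, and Tensor::data<scalar_t>() by Tensor::data_ptr<scalar_t>() for typed pointers; ms_deform_im2col_cuda.cuh additionally adds a q_col * qid_stride term to data_value_ptr_init_offset in each kernel. The sketch below shows the same accessor pattern in isolation; scale_inplace is a hypothetical helper written only for illustration and is not a detrex or PyTorch API, and the CPU loop stands in for the CUDA kernels so the example stays runnable without a GPU.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical helper (illustration only): scales a contiguous tensor in place
// using the non-deprecated accessors that the patch switches to.
void scale_inplace(at::Tensor &t, double alpha) {
    TORCH_CHECK(t.is_contiguous(), "tensor has to be contiguous");
    // device().is_cpu()/is_cuda() replaces the deprecated type().is_cuda();
    // a CPU tensor keeps this sketch runnable without CUDA.
    TORCH_CHECK(t.device().is_cpu(), "this sketch expects a CPU tensor");

    // scalar_type() replaces type() as the dispatch key, and
    // data_ptr<scalar_t>() replaces data<scalar_t>() as the typed pointer.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        t.scalar_type(), "scale_inplace", ([&] {
            scalar_t *p = t.data_ptr<scalar_t>();
            const auto a = static_cast<scalar_t>(alpha);
            for (int64_t i = 0; i < t.numel(); ++i) {
                p[i] *= a;
            }
        }));
}

int main() {
    auto t = at::ones({4}, at::kFloat);
    scale_inplace(t, 2.0);  // t is now [2, 2, 2, 2]
    return 0;
}

Compiling this sketch needs only the libtorch headers and CPU libraries; the kernels touched by the patch additionally require the CUDA toolchain and CUDA tensors.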