44template <class Tmem >
55static __global__ void rearrange (
66 void *__restrict__ dst,
7- unsigned int const rsa,
8- unsigned int const csa,
7+ int const rsa,
8+ int const csa,
99 void const *__restrict__ src,
10- unsigned int const rsb,
11- unsigned int const csb,
10+ int const rsb,
11+ int const csb,
1212 unsigned int const ncols) {
1313
1414 auto row = blockIdx .y ,
@@ -25,39 +25,42 @@ static __global__ void rearrange(
2525
2626void rearrange_nv_gpu (RearrangeCudaDescriptor_t desc, void *y, void const *x, void *stream) {
2727 auto cuda_stream = reinterpret_cast <cudaStream_t>(stream);
28- if (desc->r == 1 && desc->c == 1 && desc->b == 1 ) {
29- cudaMemcpyAsync (y, x, desc->bytes_per_thread , cudaMemcpyDeviceToDevice, cuda_stream);
28+ auto unit = desc->unit ,
29+ r = desc->r , c = desc->c ;
30+ auto dst_rs = desc->dst_rs , dst_cs = desc->dst_cs ,
31+ src_rs = desc->src_rs , src_cs = desc->src_cs ;
32+
33+ if (r == 1 && c == 1 ) {
34+ cudaMemcpyAsync (y, x, unit, cudaMemcpyDeviceToDevice, cuda_stream);
3035 return ;
3136 }
3237
33- uint64_t rsa = desc->rsa , csa = desc->csa , rsb = desc->rsb , csb = desc->csb ;
34- unsigned int r = desc->r , c = desc->c , b = desc->b , bytes_per_thread = desc->bytes_per_thread ;
35- auto dst_ptr = static_cast <void *>(reinterpret_cast <uint8_t *>(y));
36- rsa /= b;
37- csa /= b;
38- auto src_ptr = static_cast <void const *>(reinterpret_cast <uint8_t const *>(x));
39- rsb /= b;
40- csb /= b;
41- dim3 grid_dims = dim3 ((c + MAX_WARP_PER_BLOCK - 1 ) / MAX_WARP_PER_BLOCK, r);
42- dim3 block_dims = dim3 (WARP_SIZE, (c + grid_dims.x - 1 ) / grid_dims.x );
43- switch (bytes_per_thread) {
38+ auto warps = 1024 / WARP_SIZE;
39+ auto grid = dim3 ((c + warps - 1 ) / warps, r);
40+ auto block = dim3 (WARP_SIZE, (c + grid.x - 1 ) / grid.x );
41+ dst_rs /= unit;
42+ dst_cs /= unit;
43+ src_rs /= unit;
44+ src_cs /= unit;
45+
46+ switch (unit / WARP_SIZE) {
4447 case 1 :
45- rearrange<uchar1 ><<<grid_dims, block_dims , 0 , cuda_stream>>> (dst_ptr, rsa, csa, src_ptr, rsb, csb , c);
48+ rearrange<uchar1 ><<<grid, block , 0 , cuda_stream>>> (y, dst_rs, dst_cs, x, src_rs, src_cs , c);
4649 break ;
4750 case 2 :
48- rearrange<uchar2 ><<<grid_dims, block_dims , 0 , cuda_stream>>> (dst_ptr, rsa, csa, src_ptr, rsb, csb , c);
51+ rearrange<uchar2 ><<<grid, block , 0 , cuda_stream>>> (y, dst_rs, dst_cs, x, src_rs, src_cs , c);
4952 break ;
5053 case 4 :
51- rearrange<float1 ><<<grid_dims, block_dims , 0 , cuda_stream>>> (dst_ptr, rsa, csa, src_ptr, rsb, csb , c);
54+ rearrange<float1 ><<<grid, block , 0 , cuda_stream>>> (y, dst_rs, dst_cs, x, src_rs, src_cs , c);
5255 break ;
5356 case 8 :
54- rearrange<float2 ><<<grid_dims, block_dims , 0 , cuda_stream>>> (dst_ptr, rsa, csa, src_ptr, rsb, csb , c);
57+ rearrange<float2 ><<<grid, block , 0 , cuda_stream>>> (y, dst_rs, dst_cs, x, src_rs, src_cs , c);
5558 break ;
5659 case 16 :
57- rearrange<float4 ><<<grid_dims, block_dims , 0 , cuda_stream>>> (dst_ptr, rsa, csa, src_ptr, rsb, csb , c);
60+ rearrange<float4 ><<<grid, block , 0 , cuda_stream>>> (y, dst_rs, dst_cs, x, src_rs, src_cs , c);
5861 break ;
5962 case 32 :
60- rearrange<double4 ><<<grid_dims, block_dims , 0 , cuda_stream>>> (dst_ptr, rsa, csa, src_ptr, rsb, csb , c);
63+ rearrange<double4 ><<<grid, block , 0 , cuda_stream>>> (y, dst_rs, dst_cs, x, src_rs, src_cs , c);
6164 break ;
6265 default :
6366 break ;
0 commit comments