Skip to content

Commit d44beb0

Browse files
committed
fix(nv): 改正 rearrange (correct the rearrange op)
Signed-off-by: YdrMaster <[email protected]>
1 parent a03df4e commit d44beb0

File tree

3 files changed

+65
-84
lines changed

3 files changed

+65
-84
lines changed

src/ops/rearrange/cuda/rearrange.cc

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ infiniopStatus_t cudaCreateRearrangeDescriptor(CudaHandle_t handle,
77
RearrangeCudaDescriptor_t *desc_ptr,
88
infiniopTensorDescriptor_t dst,
99
infiniopTensorDescriptor_t src) {
10-
if (!dtype_eq(dst->dt, src->dt)) {
10+
auto dt = dst->dt;
11+
if (!dtype_eq(src->dt, dt)) {
1112
return STATUS_BAD_TENSOR_DTYPE;
1213
}
1314

@@ -24,62 +25,43 @@ infiniopStatus_t cudaCreateRearrangeDescriptor(CudaHandle_t handle,
2425
return STATUS_BAD_TENSOR_STRIDES;
2526
}
2627

27-
if (ndim == 1) {
28-
*desc_ptr = new RearrangeCudaDescriptor{
29-
handle->device,
30-
handle->device_id,
31-
0, 0, 0, 0,
32-
1, 1, 1,
33-
static_cast<unsigned long>(dst->shape[0] * dst->dt.size)};
34-
return STATUS_SUCCESS;
28+
switch (ndim) {
29+
case 1:
30+
*desc_ptr = new RearrangeCudaDescriptor{
31+
handle->device,
32+
handle->device_id,
33+
dt.size * dst->shape[0],
34+
1, 1,
35+
0, 0,
36+
0, 0};
37+
break;
38+
case 2:
39+
*desc_ptr = new RearrangeCudaDescriptor{
40+
handle->device,
41+
handle->device_id,
42+
dt.size * dst->shape[1],
43+
1, dst->shape[0],
44+
0, dst->strides[0],
45+
0, src->strides[0]};
46+
break;
47+
case 3:
48+
*desc_ptr = new RearrangeCudaDescriptor{
49+
handle->device,
50+
handle->device_id,
51+
dt.size * dst->shape[2],
52+
dst->shape[0], dst->shape[1],
53+
dst->strides[0], dst->strides[1],
54+
src->strides[0], src->strides[1]};
55+
break;
56+
default:
57+
return STATUS_BAD_TENSOR_SHAPE;
3558
}
3659

37-
unsigned int r = 0, c = 0, b = 0;
38-
unsigned int rsa = 0, csa = 0, rsb = 0, csb = 0;
39-
if (ndim == 2) {
40-
c = dst->shape[0];
41-
b = dst->shape[1];
42-
csa = dst->strides[0];
43-
csb = src->strides[0];
44-
} else if (ndim == 3) {
45-
r = dst->shape[0];
46-
c = dst->shape[1];
47-
b = dst->shape[2];
48-
csa = dst->strides[1];
49-
csb = src->strides[1];
50-
rsa = dst->strides[0];
51-
rsb = src->strides[0];
52-
} else {
53-
for (int i = ndim - 3; i >= 1; --i) {
54-
if (dst->shape[i] * dst->strides[i] != dst->strides[i - 1] || src->shape[i] * src->strides[i] != src->strides[i - 1]) {
55-
return STATUS_BAD_TENSOR_STRIDES;
56-
}
57-
}
58-
r = std::accumulate(dst->shape, dst->shape + ndim - 2, 1, std::multiplies<unsigned int>());
59-
c = dst->shape[ndim - 2];
60-
b = dst->shape[ndim - 1];
61-
csa = dst->strides[ndim - 2];
62-
csb = src->strides[ndim - 2];
63-
rsa = dst->strides[ndim - 3];
64-
rsb = src->strides[ndim - 3];
65-
}
66-
auto contiguous_bytes = b * dst->dt.size;
67-
if (contiguous_bytes % WARP_SIZE != 0) {
68-
return STATUS_BAD_PARAM;
69-
}
70-
auto bytes_per_thread = contiguous_bytes / WARP_SIZE;
71-
if (bytes_per_thread <= 0 || bytes_per_thread > 32 || (bytes_per_thread & (bytes_per_thread - 1)) != 0) {
72-
return STATUS_BAD_PARAM;
73-
}
74-
*desc_ptr = new RearrangeCudaDescriptor{
75-
handle->device,
76-
handle->device_id,
77-
rsa,
78-
rsb,
79-
csa,
80-
csb,
81-
r, c, b,
82-
bytes_per_thread};
60+
(*desc_ptr)->dst_rs *= dt.size;
61+
(*desc_ptr)->dst_cs *= dt.size;
62+
(*desc_ptr)->src_rs *= dt.size;
63+
(*desc_ptr)->src_cs *= dt.size;
64+
8365
return STATUS_SUCCESS;
8466
}
8567
infiniopStatus_t cudaDestroyRearrangeDescriptor(RearrangeCudaDescriptor_t desc) {

src/ops/rearrange/cuda/rearrange.cu

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
template<class Tmem>
55
static __global__ void rearrange(
66
void *__restrict__ dst,
7-
unsigned int const rsa,
8-
unsigned int const csa,
7+
int const rsa,
8+
int const csa,
99
void const *__restrict__ src,
10-
unsigned int const rsb,
11-
unsigned int const csb,
10+
int const rsb,
11+
int const csb,
1212
unsigned int const ncols) {
1313

1414
auto row = blockIdx.y,
@@ -25,39 +25,42 @@ static __global__ void rearrange(
2525

2626
void rearrange_nv_gpu(RearrangeCudaDescriptor_t desc, void *y, void const *x, void *stream) {
2727
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
28-
if (desc->r == 1 && desc->c == 1 && desc->b == 1) {
29-
cudaMemcpyAsync(y, x, desc->bytes_per_thread, cudaMemcpyDeviceToDevice, cuda_stream);
28+
auto unit = desc->unit,
29+
r = desc->r, c = desc->c;
30+
auto dst_rs = desc->dst_rs, dst_cs = desc->dst_cs,
31+
src_rs = desc->src_rs, src_cs = desc->src_cs;
32+
33+
if (r == 1 && c == 1) {
34+
cudaMemcpyAsync(y, x, unit, cudaMemcpyDeviceToDevice, cuda_stream);
3035
return;
3136
}
3237

33-
uint64_t rsa = desc->rsa, csa = desc->csa, rsb = desc->rsb, csb = desc->csb;
34-
unsigned int r = desc->r, c = desc->c, b = desc->b, bytes_per_thread = desc->bytes_per_thread;
35-
auto dst_ptr = static_cast<void *>(reinterpret_cast<uint8_t *>(y));
36-
rsa /= b;
37-
csa /= b;
38-
auto src_ptr = static_cast<void const *>(reinterpret_cast<uint8_t const *>(x));
39-
rsb /= b;
40-
csb /= b;
41-
dim3 grid_dims = dim3((c + MAX_WARP_PER_BLOCK - 1) / MAX_WARP_PER_BLOCK, r);
42-
dim3 block_dims = dim3(WARP_SIZE, (c + grid_dims.x - 1) / grid_dims.x);
43-
switch (bytes_per_thread) {
38+
auto warps = 1024 / WARP_SIZE;
39+
auto grid = dim3((c + warps - 1) / warps, r);
40+
auto block = dim3(WARP_SIZE, (c + grid.x - 1) / grid.x);
41+
dst_rs /= unit;
42+
dst_cs /= unit;
43+
src_rs /= unit;
44+
src_cs /= unit;
45+
46+
switch (unit / WARP_SIZE) {
4447
case 1:
45-
rearrange<uchar1><<<grid_dims, block_dims, 0, cuda_stream>>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c);
48+
rearrange<uchar1><<<grid, block, 0, cuda_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
4649
break;
4750
case 2:
48-
rearrange<uchar2><<<grid_dims, block_dims, 0, cuda_stream>>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c);
51+
rearrange<uchar2><<<grid, block, 0, cuda_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
4952
break;
5053
case 4:
51-
rearrange<float1><<<grid_dims, block_dims, 0, cuda_stream>>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c);
54+
rearrange<float1><<<grid, block, 0, cuda_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
5255
break;
5356
case 8:
54-
rearrange<float2><<<grid_dims, block_dims, 0, cuda_stream>>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c);
57+
rearrange<float2><<<grid, block, 0, cuda_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
5558
break;
5659
case 16:
57-
rearrange<float4><<<grid_dims, block_dims, 0, cuda_stream>>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c);
60+
rearrange<float4><<<grid, block, 0, cuda_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
5861
break;
5962
case 32:
60-
rearrange<double4><<<grid_dims, block_dims, 0, cuda_stream>>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c);
63+
rearrange<double4><<<grid, block, 0, cuda_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
6164
break;
6265
default:
6366
break;

src/ops/rearrange/cuda/rearrange.cuh

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@
77
struct RearrangeCudaDescriptor {
88
Device device;
99
int device_id;
10-
uint64_t rsa;
11-
uint64_t rsb;
12-
uint64_t csa;
13-
uint64_t csb;
14-
uint64_t r, c, b;
15-
uint64_t bytes_per_thread;
10+
uint64_t unit, r, c;
11+
int64_t dst_rs, dst_cs, src_rs, src_cs;
1612
};
1713

1814
typedef struct RearrangeCudaDescriptor *RearrangeCudaDescriptor_t;

0 commit comments

Comments (0)