Skip to content

Commit ed38393

Browse files
committed
fix(cpu): 为 rearrange 支持 ndim == 1
Signed-off-by: YdrMaster <[email protected]>
1 parent 745a4b8 commit ed38393

File tree

2 files changed: 34 additions and 22 deletions.

src/ops/rearrange/cpu/rearrange_cpu.cc

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#include "../../utils.h"
33
#include <cstdint>
44
#include <cstring>
5+
#include <iostream>
56
#include <numeric>
67

78
infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
@@ -11,41 +12,52 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
1112
if (!dtype_eq(dst->dt, src->dt)) {
1213
return STATUS_BAD_TENSOR_DTYPE;
1314
}
14-
if (dst->ndim != src->ndim || dst->ndim < 2) {
15+
16+
auto ndim = dst->ndim;
17+
if (src->ndim != ndim || ndim == 0) {
1518
return STATUS_BAD_TENSOR_SHAPE;
1619
}
17-
std::vector<uint64_t> shape;
18-
std::vector<int64_t> strides_dst, strides_src;
19-
auto ndim = dst->ndim;
2020
for (int i = 0; i < ndim; ++i) {
2121
if (dst->shape[i] != src->shape[i]) {
2222
return STATUS_BAD_TENSOR_SHAPE;
2323
}
24-
shape.push_back(dst->shape[i]);
25-
strides_dst.push_back(dst->strides[i]);
26-
strides_src.push_back(src->strides[i]);
2724
}
2825
if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) {
2926
return STATUS_BAD_TENSOR_STRIDES;
3027
}
28+
29+
std::vector<uint64_t>
30+
shape(dst->shape, dst->shape + ndim);
31+
std::vector<int64_t>
32+
strides_dst(dst->strides, dst->strides + ndim),
33+
strides_src(src->strides, src->strides + ndim);
34+
3135
unsigned int r = 0;
32-
if (ndim == 2) {
33-
r = dst->shape[0];
34-
} else if (ndim == 3) {
35-
r = dst->shape[0] * dst->shape[1];
36-
} else {
37-
for (int i = ndim - 3; i >= 1; --i) {
38-
if (dst->shape[i] * dst->strides[i] != dst->strides[i - 1] || src->shape[i] * src->strides[i] != src->strides[i - 1]) {
39-
return STATUS_BAD_TENSOR_STRIDES;
36+
switch (ndim) {
37+
case 1:
38+
ndim = 2;
39+
strides_dst.insert(strides_dst.begin(), shape[0]);
40+
strides_src.insert(strides_src.begin(), shape[0]);
41+
shape.insert(shape.begin(), 1);
42+
case 2:
43+
r = shape[0];
44+
break;
45+
case 3:
46+
r = shape[0] * shape[1];
47+
break;
48+
default:
49+
for (int i = ndim - 3; i >= 1; --i) {
50+
if (shape[i] * strides_dst[i] != strides_dst[i - 1] || shape[i] * strides_src[i] != strides_src[i - 1]) {
51+
return STATUS_BAD_TENSOR_STRIDES;
52+
}
4053
}
41-
}
42-
r = std::accumulate(dst->shape, dst->shape + ndim - 1, 1, std::multiplies<unsigned int>());
54+
r = std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies{});
55+
break;
4356
}
4457
*desc_ptr = new RearrangeCpuDescriptor{
4558
DevCpu,
4659
dst->dt,
4760
r,
48-
ndim,
4961
shape,
5062
strides_dst,
5163
strides_src,
@@ -70,11 +82,12 @@ inline int indices(uint64_t i, uint64_t ndim, std::vector<int64_t> strides, std:
7082
void reform_cpu(RearrangeCpuDescriptor_t desc, void *dst, void const *src) {
7183
auto dst_ptr = reinterpret_cast<uint8_t *>(dst);
7284
auto src_ptr = reinterpret_cast<const uint8_t *>(src);
73-
int bytes_size = desc->shape[desc->ndim - 1] * desc->dt.size;
85+
auto ndim = desc->shape.size();
86+
int bytes_size = desc->shape[ndim - 1] * desc->dt.size;
7487
#pragma omp parallel for
7588
for (uint64_t i = 0; i < desc->r; ++i) {
76-
auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape);
77-
auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape);
89+
auto dst_offset = indices(i, ndim, desc->strides_dst, desc->shape);
90+
auto src_offset = indices(i, ndim, desc->strides_src, desc->shape);
7891
std::memcpy(dst_ptr + dst_offset * desc->dt.size, src_ptr + src_offset * desc->dt.size, bytes_size);
7992
}
8093
}

src/ops/rearrange/cpu/rearrange_cpu.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@ struct RearrangeCpuDescriptor {
77
Device device;
88
DataLayout dt;
99
uint64_t r;
10-
uint64_t ndim;
1110
std::vector<uint64_t> shape;
1211
std::vector<int64_t> strides_dst;
1312
std::vector<int64_t> strides_src;

Comments (0): this commit has no comments.