#include "../../utils.h"
#include <cstdint>
#include <cstring>
+#include <iostream>
#include <numeric>

infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
@@ -11,41 +12,52 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
    if (!dtype_eq(dst->dt, src->dt)) {
        return STATUS_BAD_TENSOR_DTYPE;
    }
-    if (dst->ndim != src->ndim || dst->ndim < 2) {
+
+    auto ndim = dst->ndim;
+    if (src->ndim != ndim || ndim == 0) {
        return STATUS_BAD_TENSOR_SHAPE;
    }
-    std::vector<uint64_t> shape;
-    std::vector<int64_t> strides_dst, strides_src;
-    auto ndim = dst->ndim;
    for (int i = 0; i < ndim; ++i) {
        if (dst->shape[i] != src->shape[i]) {
            return STATUS_BAD_TENSOR_SHAPE;
        }
-        shape.push_back(dst->shape[i]);
-        strides_dst.push_back(dst->strides[i]);
-        strides_src.push_back(src->strides[i]);
    }
    if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) {
        return STATUS_BAD_TENSOR_STRIDES;
    }
+
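+    // keep shape and strides in vectors owned by the descriptor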
+    std::vector<uint64_t>
+        shape(dst->shape, dst->shape + ndim);
+    std::vector<int64_t>
+        strides_dst(dst->strides, dst->strides + ndim),
+        strides_src(src->strides, src->strides + ndim);
+
    unsigned int r = 0;
-    if (ndim == 2) {
-        r = dst->shape[0];
-    } else if (ndim == 3) {
-        r = dst->shape[0] * dst->shape[1];
-    } else {
-        for (int i = ndim - 3; i >= 1; --i) {
-            if (dst->shape[i] * dst->strides[i] != dst->strides[i - 1] || src->shape[i] * src->strides[i] != src->strides[i - 1]) {
-                return STATUS_BAD_TENSOR_STRIDES;
+    switch (ndim) {
+        case 1:
+            ndim = 2;
+            strides_dst.insert(strides_dst.begin(), shape[0]);
+            strides_src.insert(strides_src.begin(), shape[0]);
+            shape.insert(shape.begin(), 1);
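+            // fall through: a 1-D tensor is handled below as a 1 x n matrix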
+        case 2:
+            r = shape[0];
+            break;
+        case 3:
+            r = shape[0] * shape[1];
+            break;
+        default:
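+            // all dimensions except the last two must nest contiguously in both tensors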
+            for (int i = ndim - 3; i >= 1; --i) {
+                if (shape[i] * strides_dst[i] != strides_dst[i - 1] || shape[i] * strides_src[i] != strides_src[i - 1]) {
+                    return STATUS_BAD_TENSOR_STRIDES;
+                }
            }
-        }
-        r = std::accumulate(dst->shape, dst->shape + ndim - 1, 1, std::multiplies<unsigned int>());
+            r = std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies{});
+            break;
    }
    *desc_ptr = new RearrangeCpuDescriptor{
        DevCpu,
        dst->dt,
        r,
-        ndim,
        shape,
        strides_dst,
        strides_src,
@@ -70,11 +82,12 @@ inline int indices(uint64_t i, uint64_t ndim, std::vector<int64_t> strides, std:
void reform_cpu(RearrangeCpuDescriptor_t desc, void *dst, void const *src) {
    auto dst_ptr = reinterpret_cast<uint8_t *>(dst);
    auto src_ptr = reinterpret_cast<const uint8_t *>(src);
-    int bytes_size = desc->shape[desc->ndim - 1] * desc->dt.size;
+    auto ndim = desc->shape.size();
+    int bytes_size = desc->shape[ndim - 1] * desc->dt.size;
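+    // each of the r iterations copies one contiguous innermost row of bytes_size bytes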
#pragma omp parallel for
    for (uint64_t i = 0; i < desc->r; ++i) {
-        auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape);
-        auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape);
+        auto dst_offset = indices(i, ndim, desc->strides_dst, desc->shape);
+        auto src_offset = indices(i, ndim, desc->strides_src, desc->shape);
        std::memcpy(dst_ptr + dst_offset * desc->dt.size, src_ptr + src_offset * desc->dt.size, bytes_size);
    }
}