270 changes: 142 additions & 128 deletions paddle/phi/kernels/cpu/strided_copy_kernel.cc
@@ -64,13 +64,12 @@ void StridedCopyKernel(const Context& dev_ctx,
 #if defined(PADDLE_WITH_CUDA)
 // not support Windows
 #if !defined(_WIN32)
-  if (FLAGS_use_stride_kernel && FLAGS_use_stride_compute_kernel &&
+  if (FLAGS_use_stride_kernel &&
       input.place().GetType() == phi::AllocationType::CPU &&
       out->place().GetType() == phi::AllocationType::GPU &&
-      input.dtype() == out->dtype() && !input.meta().is_contiguous()) {
+      input.dtype() == out->dtype() &&
+      (!input.meta().is_contiguous() || !out->meta().is_contiguous())) {
     phi::DenseTensor dst_gpu;
-    phi::DenseTensor src_cpu;
-
     if (out->meta().is_contiguous()) {
       dst_gpu = *out;
     } else {
@@ -81,176 +80,191 @@ void StridedCopyKernel(const Context& dev_ctx,
       dev_ctx.Alloc(&dst_gpu, input.dtype());
     }
 
-    phi::DenseTensor cpu_input = input;
-    phi::DenseTensor* cpu_out = &src_cpu;
-    void* cpu_output_data;
+    auto src_cpu_place = input.place();
+    auto dst_gpu_place = out->place();
+    auto& pool = phi::DeviceContextPool::Instance();
+    auto* gpu_dev_ctx = static_cast<phi::GPUContext*>(pool.Get(out->place()));
+    auto stream = gpu_dev_ctx->stream();
+
+    if (input.meta().is_contiguous()) {
+      auto src_cpu_place = input.place();
+      auto dst_gpu_place = out->place();
+      auto size = phi::SizeOf(input.dtype()) * input.numel();
+      void* dst_ptr = gpu_dev_ctx->Alloc(
+          &dst_gpu,
+          dst_gpu.dtype(),
+          0,
+          dst_gpu_place.GetType() == AllocationType::GPUPINNED);
+
+      phi::memory_utils::Copy(
+          dst_gpu_place, dst_ptr, src_cpu_place, input.data<T>(), size, stream);
+
+    } else {
+      phi::DenseTensor src_cpu;
+      phi::DenseTensor cpu_input = input;
+      phi::DenseTensor* cpu_out = &src_cpu;
+      void* cpu_output_data;
 
-    phi::DenseTensorMeta cpu_meta = cpu_input.meta();
-    cpu_meta.strides = cpu_meta.calc_strides(cpu_meta.dims);
-    cpu_meta.offset = 0;
-    cpu_out->set_meta(cpu_meta);
+      phi::DenseTensorMeta cpu_meta = cpu_input.meta();
+      cpu_meta.strides = cpu_meta.calc_strides(cpu_meta.dims);
+      cpu_meta.offset = 0;
+      cpu_out->set_meta(cpu_meta);
 
 #if defined(PADDLE_WITH_OPENMP)
-    dev_ctx.HostAlloc(cpu_out, cpu_out->dtype());
+      dev_ctx.HostAlloc(cpu_out, cpu_out->dtype());
 #endif
-    const void* cpu_input_data = cpu_input.data();
-    cpu_output_data = malloc(phi::SizeOf(cpu_input.dtype()) * cpu_out->numel());
+      const void* cpu_input_data = cpu_input.data();
+      cpu_output_data =
+          malloc(phi::SizeOf(cpu_input.dtype()) * cpu_out->numel());
 
-    if (FastTransposeCopyValid(*cpu_out, cpu_input)) {
-      constexpr int64_t TRANS_NUMEL = 60;
-      void* trans_buffer =
-          malloc(phi::SizeOf(input.dtype()) * TRANS_NUMEL * TRANS_NUMEL);
+      if (FastTransposeCopyValid(*cpu_out, cpu_input)) {
+        constexpr int64_t TRANS_NUMEL = 60;
+        void* trans_buffer =
+            malloc(phi::SizeOf(input.dtype()) * TRANS_NUMEL * TRANS_NUMEL);
 
-      const T* tmp_src_ptr = reinterpret_cast<const T*>(cpu_input_data);
+        const T* tmp_src_ptr = reinterpret_cast<const T*>(cpu_input_data);
 #if defined(PADDLE_WITH_OPENMP)
-      T* tmp_out_ptr = reinterpret_cast<T*>(cpu_output_data);
+        T* tmp_out_ptr = reinterpret_cast<T*>(cpu_output_data);
 #else
-      T* tmp_out_ptr = cpu_out->data<T>();
+        T* tmp_out_ptr = cpu_out->data<T>();
 #endif
-      T* tmp_buf_ptr = reinterpret_cast<T*>(trans_buffer);
+        T* tmp_buf_ptr = reinterpret_cast<T*>(trans_buffer);
 
-      int64_t dim0 = cpu_out->dims()[0];
-      int64_t dim1 = cpu_out->dims()[1];
+        int64_t dim0 = cpu_out->dims()[0];
+        int64_t dim1 = cpu_out->dims()[1];
 
-      for (int64_t d0 = 0; d0 < dim0; d0 += TRANS_NUMEL) {
-        for (int64_t d1 = 0; d1 < dim1; d1 += TRANS_NUMEL) {
-          const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
-          T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
+        for (int64_t d0 = 0; d0 < dim0; d0 += TRANS_NUMEL) {
+          for (int64_t d1 = 0; d1 < dim1; d1 += TRANS_NUMEL) {
+            const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
+            T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
 
-          int nr = std::min(dim0 - d0, TRANS_NUMEL);
-          int nc = std::min(dim1 - d1, TRANS_NUMEL);
+            int nr = std::min(dim0 - d0, TRANS_NUMEL);
+            int nc = std::min(dim1 - d1, TRANS_NUMEL);
 
-          for (int c = 0; c < nc; c++) {
-            memcpy(tmp_buf_ptr + c * TRANS_NUMEL,
-                   src_ptr_inter + c * dim0,
-                   nr * sizeof(T));
-          }
+            for (int c = 0; c < nc; c++) {
+              memcpy(tmp_buf_ptr + c * TRANS_NUMEL,
+                     src_ptr_inter + c * dim0,
+                     nr * sizeof(T));
+            }
 
-          int rc_max = std::max(nr, nc);
-          int rc_min = std::min(nr, nc);
-          for (int r = 0; r < rc_max; r++) {
-            int end = std::min(r, rc_min);
-            for (int c = 0; c < end; c++) {
-              T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
-              tmp_buf_ptr[r + TRANS_NUMEL * c] =
-                  tmp_buf_ptr[r * TRANS_NUMEL + c];
-              tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
-            }
-          }
+            int rc_max = std::max(nr, nc);
+            int rc_min = std::min(nr, nc);
+            for (int r = 0; r < rc_max; r++) {
+              int end = std::min(r, rc_min);
+              for (int c = 0; c < end; c++) {
+                T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
+                tmp_buf_ptr[r + TRANS_NUMEL * c] =
+                    tmp_buf_ptr[r * TRANS_NUMEL + c];
+                tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
+              }
+            }
 
-          for (int r = 0; r < nr; r++) {
-            memcpy(out_ptr_inter + r * dim1,
-                   tmp_buf_ptr + r * TRANS_NUMEL,
-                   nc * sizeof(T));
-          }
-        }
-      }
-      free(trans_buffer);
-    } else {
+            for (int r = 0; r < nr; r++) {
+              memcpy(out_ptr_inter + r * dim1,
+                     tmp_buf_ptr + r * TRANS_NUMEL,
+                     nc * sizeof(T));
+            }
+          }
+        }
+        free(trans_buffer);
+      } else {
 #if defined(PADDLE_WITH_OPENMP)
-      phi::DenseTensorIteratorConfig config;
-      config.add_output(*cpu_out);
-      config.add_const_input(cpu_input);
-      config.is_alloc_out_ = true;
-      phi::DenseTensorIterator iter = config.build();
+        phi::DenseTensorIteratorConfig config;
+        config.add_output(*cpu_out);
+        config.add_const_input(cpu_input);
+        config.is_alloc_out_ = true;
+        phi::DenseTensorIterator iter = config.build();
 
-      std::vector<int64_t> tmp_strides(
-          iter.ntensors() * static_cast<size_t>(std::max(iter.ndim(), 2)));
+        std::vector<int64_t> tmp_strides(
+            iter.ntensors() * static_cast<size_t>(std::max(iter.ndim(), 2)));
 
-      DealWithStride(iter, tmp_strides.data());
+        DealWithStride(iter, tmp_strides.data());
 
-      std::vector<int64_t> out_stride(tmp_strides.begin() + iter.ntensors(),
-                                      tmp_strides.end());
+        std::vector<int64_t> out_stride(tmp_strides.begin() + iter.ntensors(),
+                                        tmp_strides.end());
 
-      std::vector<int64_t> output_stride = iter.strides(0);
-      std::vector<int64_t> input_stride = iter.strides(1);
+        std::vector<int64_t> output_stride = iter.strides(0);
+        std::vector<int64_t> input_stride = iter.strides(1);
 
-      const int64_t& numel = iter.numel();
+        const int64_t& numel = iter.numel();
 
-      const char* in_ptr = reinterpret_cast<const char*>(cpu_input_data);
-      char* out_ptr = reinterpret_cast<char*>(cpu_output_data);
+        const char* in_ptr = reinterpret_cast<const char*>(cpu_input_data);
+        char* out_ptr = reinterpret_cast<char*>(cpu_output_data);
 
-      int64_t end = numel;
-      int64_t begin = 0;
-      int64_t grain_size = 32768;
+        int64_t end = numel;
+        int64_t begin = 0;
+        int64_t grain_size = 32768;
 
-      int64_t* whole_stride = tmp_strides.data();
+        int64_t* whole_stride = tmp_strides.data();
 
-      omp_set_num_threads(std::thread::hardware_concurrency());
-
 #pragma omp parallel
-      {
-        int64_t num_threads = omp_get_num_threads();
+        {
+          int64_t num_threads = omp_get_num_threads();
 
-        if (grain_size > 0) {
-          num_threads = std::min(num_threads, DivUp((end - begin), grain_size));
-        }
+          if (grain_size > 0) {
+            num_threads =
+                std::min(num_threads, DivUp((end - begin), grain_size));
+          }
 
-        int64_t tid = omp_get_thread_num();
-        int64_t chunk_size = DivUp((end - begin), num_threads);
-        int64_t begin_tid = begin + tid * chunk_size;
+          int64_t tid = omp_get_thread_num();
+          int64_t chunk_size = DivUp((end - begin), num_threads);
+          int64_t begin_tid = begin + tid * chunk_size;
 
-        if (begin_tid < end) {
-          int64_t range_start = begin_tid;
-          int64_t range_end = std::min(end, chunk_size + begin_tid);
+          if (begin_tid < end) {
+            int64_t range_start = begin_tid;
+            int64_t range_end = std::min(end, chunk_size + begin_tid);
 
-          auto dimiter = DimIter(iter.shape(), range_start, range_end);
-          while (!dimiter.iter_to_end()) {
-            const auto v_ndim = dimiter.values.size();
-            const char* tmp_in_data = in_ptr;
-            char* tmp_out_data = out_ptr;
-            for (size_t dim = 0; dim < v_ndim; dim++) {
-              int64_t value = dimiter.values[dim];
-              tmp_out_data += value * whole_stride[dim * iter.ntensors() + 0];
-              tmp_in_data += value * whole_stride[dim * iter.ntensors() + 1];
-            }
+            auto dimiter = DimIter(iter.shape(), range_start, range_end);
+            while (!dimiter.iter_to_end()) {
+              const auto v_ndim = dimiter.values.size();
+              const char* tmp_in_data = in_ptr;
+              char* tmp_out_data = out_ptr;
+              for (size_t dim = 0; dim < v_ndim; dim++) {
+                int64_t value = dimiter.values[dim];
+                tmp_out_data += value * whole_stride[dim * iter.ntensors() + 0];
+                tmp_in_data += value * whole_stride[dim * iter.ntensors() + 1];
+              }
 
-            auto step = dimiter.iter_for_step();
+              auto step = dimiter.iter_for_step();
 
-            for (int64_t i = 0; i < step[1]; i++) {
-              for (int64_t j = 0; j < step[0]; j++) {
-                const char* real_in_ptr = tmp_in_data + j * whole_stride[1];
-                char* real_out_ptr = tmp_out_data + j * whole_stride[0];
+              for (int64_t i = 0; i < step[1]; i++) {
+                for (int64_t j = 0; j < step[0]; j++) {
+                  const char* real_in_ptr = tmp_in_data + j * whole_stride[1];
+                  char* real_out_ptr = tmp_out_data + j * whole_stride[0];
 
-                *reinterpret_cast<T*>(real_out_ptr) =
-                    *reinterpret_cast<const T*>(real_in_ptr);
-              }
-              tmp_in_data = tmp_in_data + out_stride[1];
-              tmp_out_data = tmp_out_data + out_stride[0];
-            }
+                  *reinterpret_cast<T*>(real_out_ptr) =
+                      *reinterpret_cast<const T*>(real_in_ptr);
+                }
+                tmp_in_data = tmp_in_data + out_stride[1];
+                tmp_out_data = tmp_out_data + out_stride[0];
+              }
 
-            dimiter.iter_to_next(step);
-          }
-        }
-      }
+              dimiter.iter_to_next(step);
+            }
+          }
+        }
 #else
-      phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
+        phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
 #endif
-    }
-
-    auto src_cpu_place = input.place();
-    auto dst_gpu_place = out->place();
-
-    auto& pool = phi::DeviceContextPool::Instance();
-    auto* gpu_dev_ctx = static_cast<phi::GPUContext*>(pool.Get(out->place()));
-    auto stream = gpu_dev_ctx->stream();
+      }
 
 #if defined(PADDLE_WITH_OPENMP)
-    auto* src_ptr = cpu_output_data;
+      auto* src_ptr = cpu_output_data;
 #else
-    auto* src_ptr = cpu_out->data<T>();
+      auto* src_ptr = cpu_out->data<T>();
 #endif
 
-    auto size = phi::SizeOf(input.dtype()) * src_cpu.numel();
-    void* dst_ptr = gpu_dev_ctx->Alloc(
-        &dst_gpu,
-        dst_gpu.dtype(),
-        0,
-        dst_gpu_place.GetType() == AllocationType::GPUPINNED);
+      auto size = phi::SizeOf(input.dtype()) * src_cpu.numel();
+      void* dst_ptr = gpu_dev_ctx->Alloc(
+          &dst_gpu,
+          dst_gpu.dtype(),
+          0,
+          dst_gpu_place.GetType() == AllocationType::GPUPINNED);
 
-    phi::memory_utils::Copy(
-        dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
+      phi::memory_utils::Copy(
+          dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
 
-    free(cpu_output_data);
+      free(cpu_output_data);
+    }
 
     if (out != &dst_gpu) {
       PD_VISIT_ALL_TYPES(
          out->dtype(), "StridedCopyKernel", ([&] {
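Read as a whole, the kernel change does two things: the guard now also fires when only the destination is non-contiguous, and an already contiguous CPU input skips the host-side packing entirely and is copied straight from input.data<T>() to the GPU buffer. Only a strided CPU input still goes through the pack-then-copy path, where the tiled transpose (the FastTransposeCopyValid branch) or the DenseTensorIterator/OpenMP loop builds a contiguous host buffer first. The sketch below is a minimal NumPy model of that tiled transpose packing, not Paddle code; the function name tiled_transpose_pack and the use of NumPy slices in place of the raw trans_buffer are my own, while the 60-element tile mirrors TRANS_NUMEL above.

import numpy as np

# Minimal NumPy model of the tiled transpose pack on the
# FastTransposeCopyValid branch: a 2-D tensor stored column-major
# (the transpose of its logical layout) is rewritten row-major one
# TRANS_NUMEL x TRANS_NUMEL tile at a time, so each tile's working set
# stays small. Illustration only; names are hypothetical.
TRANS_NUMEL = 60


def tiled_transpose_pack(src_t):
    # src_t has shape (dim1, dim0): src_t[c, r] holds logical element (r, c).
    dim1, dim0 = src_t.shape
    out = np.empty((dim0, dim1), dtype=src_t.dtype)
    for d0 in range(0, dim0, TRANS_NUMEL):
        for d1 in range(0, dim1, TRANS_NUMEL):
            nr = min(dim0 - d0, TRANS_NUMEL)
            nc = min(dim1 - d1, TRANS_NUMEL)
            # Pull one tile out of the strided source, transpose it in a
            # small buffer (NumPy does this via .T), then store it
            # contiguously into the packed output.
            tile = src_t[d1:d1 + nc, d0:d0 + nr]
            out[d0:d0 + nr, d1:d1 + nc] = tile.T
    return out


if __name__ == "__main__":
    logical = np.arange(123 * 77, dtype=np.float32).reshape(123, 77)
    packed = tiled_transpose_pack(np.ascontiguousarray(logical.T))
    assert np.array_equal(packed, logical)

The kernel achieves the same effect with raw pointers: it memcpy-s columns of a tile into trans_buffer, swaps elements of the buffer in place, and memcpy-s the rows back out into the contiguous output.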
28 changes: 28 additions & 0 deletions test/legacy_test/test_fast_h2d_copy.py
@@ -85,5 +85,33 @@ def test_dygraph(self):
         self.check_dygraph_result(place=get_device_place())
 
 
+@unittest.skipIf(
+    not paddle.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestFastCPUCopy3(unittest.TestCase):
+    def setUp(self):
+        src_shape = [2, 2]
+        tgt_shape = [2, 4]
+        self.input_dtype = 'float32'
+        paddle.device.set_device("cpu")
+        self.src_cpu = paddle.ones(src_shape, dtype="float32")
+        paddle.device.set_device("gpu:0")
+        self.dst_gpu = paddle.zeros(tgt_shape, dtype="float32")
+
+    def check_dygraph_result(self, place):
+        paddle.device.set_device("gpu:0")
+        tmp_dst_gpu = self.dst_gpu[..., :2]
+        tmp_dst_gpu.copy_(self.src_cpu)
+        tmp_dst_gpu1 = self.dst_gpu[..., 2:]
+        tmp_dst_gpu1.copy_(self.src_cpu)
+        np.testing.assert_allclose(self.dst_gpu.numpy(), 1.0)
+
+    def test_dygraph(self):
+        self.check_dygraph_result(place=get_device_place())
+
+
 if __name__ == '__main__':
     unittest.main()
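The new TestFastCPUCopy3 exercises exactly the widened condition in the kernel above: a contiguous CPU source written through copy_() into two non-contiguous slices of a GPU tensor. Below is a hedged, standalone sketch of the same scenario using random data and a NumPy reference instead of all-ones values, so a stride or layout mistake would actually change the result; it is illustrative only, and paddle.to_tensor plus the concatenated expected array are my additions, not part of the PR.

import numpy as np
import paddle

# Sketch of the contiguous-CPU-source / strided-GPU-destination copy,
# assuming a CUDA build and the stride kernels enabled as in the test.
if paddle.core.is_compiled_with_cuda():
    src_np = np.random.random((2, 2)).astype(np.float32)

    paddle.device.set_device("cpu")
    src_cpu = paddle.to_tensor(src_np)

    paddle.device.set_device("gpu:0")
    dst_gpu = paddle.zeros([2, 4], dtype="float32")

    # Each slice is a non-contiguous view of dst_gpu, so the H2D copy
    # takes the contiguous-input / strided-output branch added above.
    left = dst_gpu[..., :2]
    left.copy_(src_cpu)
    right = dst_gpu[..., 2:]
    right.copy_(src_cpu)

    expected = np.concatenate([src_np, src_np], axis=1)
    np.testing.assert_allclose(dst_gpu.numpy(), expected)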