From 76371ac8b4b052508b739b963484354dc015fa34 Mon Sep 17 00:00:00 2001
From: bukejiyu <395822456@qq.com>
Date: Thu, 30 Oct 2025 13:55:59 +0000
Subject: [PATCH] cpu contiguous

---
 paddle/phi/kernels/cpu/strided_copy_kernel.cc | 270 +++++++++---------
 test/legacy_test/test_fast_h2d_copy.py        |  28 ++
 2 files changed, 170 insertions(+), 128 deletions(-)

diff --git a/paddle/phi/kernels/cpu/strided_copy_kernel.cc b/paddle/phi/kernels/cpu/strided_copy_kernel.cc
index 9d5a7127d45ef6..a755737b5c0558 100644
--- a/paddle/phi/kernels/cpu/strided_copy_kernel.cc
+++ b/paddle/phi/kernels/cpu/strided_copy_kernel.cc
@@ -64,13 +64,12 @@ void StridedCopyKernel(const Context& dev_ctx,
 #if defined(PADDLE_WITH_CUDA)
 // not support Windows
 #if !defined(_WIN32)
-  if (FLAGS_use_stride_kernel && FLAGS_use_stride_compute_kernel &&
+  if (FLAGS_use_stride_kernel &&
       input.place().GetType() == phi::AllocationType::CPU &&
       out->place().GetType() == phi::AllocationType::GPU &&
-      input.dtype() == out->dtype() && !input.meta().is_contiguous()) {
+      input.dtype() == out->dtype() &&
+      (!input.meta().is_contiguous() || !out->meta().is_contiguous())) {
     phi::DenseTensor dst_gpu;
-    phi::DenseTensor src_cpu;
-
     if (out->meta().is_contiguous()) {
       dst_gpu = *out;
     } else {
@@ -81,176 +80,191 @@ void StridedCopyKernel(const Context& dev_ctx,
       dev_ctx.Alloc(&dst_gpu, input.dtype());
     }
 
-    phi::DenseTensor cpu_input = input;
-    phi::DenseTensor* cpu_out = &src_cpu;
-    void* cpu_output_data;
+    auto src_cpu_place = input.place();
+    auto dst_gpu_place = out->place();
+    auto& pool = phi::DeviceContextPool::Instance();
+    auto* gpu_dev_ctx = static_cast<phi::GPUContext*>(pool.Get(out->place()));
+    auto stream = gpu_dev_ctx->stream();
+
+    if (input.meta().is_contiguous()) {
+      auto src_cpu_place = input.place();
+      auto dst_gpu_place = out->place();
+      auto size = phi::SizeOf(input.dtype()) * input.numel();
+      void* dst_ptr = gpu_dev_ctx->Alloc(
+          &dst_gpu,
+          dst_gpu.dtype(),
+          0,
+          dst_gpu_place.GetType() == AllocationType::GPUPINNED);
+
+      phi::memory_utils::Copy(
+          dst_gpu_place, dst_ptr, src_cpu_place, input.data(), size, stream);
+
+    } else {
+      phi::DenseTensor src_cpu;
+      phi::DenseTensor cpu_input = input;
+      phi::DenseTensor* cpu_out = &src_cpu;
+      void* cpu_output_data;
 
-    phi::DenseTensorMeta cpu_meta = cpu_input.meta();
-    cpu_meta.strides = cpu_meta.calc_strides(cpu_meta.dims);
-    cpu_meta.offset = 0;
-    cpu_out->set_meta(cpu_meta);
+      phi::DenseTensorMeta cpu_meta = cpu_input.meta();
+      cpu_meta.strides = cpu_meta.calc_strides(cpu_meta.dims);
+      cpu_meta.offset = 0;
+      cpu_out->set_meta(cpu_meta);
 #if defined(PADDLE_WITH_OPENMP)
-    dev_ctx.HostAlloc(cpu_out, cpu_out->dtype());
+      dev_ctx.HostAlloc(cpu_out, cpu_out->dtype());
 #endif
-    const void* cpu_input_data = cpu_input.data();
-    cpu_output_data = malloc(phi::SizeOf(cpu_input.dtype()) * cpu_out->numel());
+      const void* cpu_input_data = cpu_input.data();
+      cpu_output_data =
+          malloc(phi::SizeOf(cpu_input.dtype()) * cpu_out->numel());
 
-    if (FastTransposeCopyValid(*cpu_out, cpu_input)) {
-      constexpr int64_t TRANS_NUMEL = 60;
-      void* trans_buffer =
-          malloc(phi::SizeOf(input.dtype()) * TRANS_NUMEL * TRANS_NUMEL);
+      if (FastTransposeCopyValid(*cpu_out, cpu_input)) {
+        constexpr int64_t TRANS_NUMEL = 60;
+        void* trans_buffer =
+            malloc(phi::SizeOf(input.dtype()) * TRANS_NUMEL * TRANS_NUMEL);
 
-      const T* tmp_src_ptr = reinterpret_cast<const T*>(cpu_input_data);
+        const T* tmp_src_ptr = reinterpret_cast<const T*>(cpu_input_data);
 #if defined(PADDLE_WITH_OPENMP)
-      T* tmp_out_ptr = reinterpret_cast<T*>(cpu_output_data);
+        T* tmp_out_ptr = reinterpret_cast<T*>(cpu_output_data);
 #else
-      T* tmp_out_ptr = cpu_out->data<T>();
+        T* tmp_out_ptr = cpu_out->data<T>();
 #endif
-      T* tmp_buf_ptr = reinterpret_cast<T*>(trans_buffer);
+        T* tmp_buf_ptr = reinterpret_cast<T*>(trans_buffer);
 
-      int64_t dim0 = cpu_out->dims()[0];
-      int64_t dim1 = cpu_out->dims()[1];
+        int64_t dim0 = cpu_out->dims()[0];
+        int64_t dim1 = cpu_out->dims()[1];
 
-      for (int64_t d0 = 0; d0 < dim0; d0 += TRANS_NUMEL) {
-        for (int64_t d1 = 0; d1 < dim1; d1 += TRANS_NUMEL) {
-          const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
-          T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
+        for (int64_t d0 = 0; d0 < dim0; d0 += TRANS_NUMEL) {
+          for (int64_t d1 = 0; d1 < dim1; d1 += TRANS_NUMEL) {
+            const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
+            T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
 
-          int nr = std::min(dim0 - d0, TRANS_NUMEL);
-          int nc = std::min(dim1 - d1, TRANS_NUMEL);
+            int nr = std::min(dim0 - d0, TRANS_NUMEL);
+            int nc = std::min(dim1 - d1, TRANS_NUMEL);
 
-          for (int c = 0; c < nc; c++) {
-            memcpy(tmp_buf_ptr + c * TRANS_NUMEL,
-                   src_ptr_inter + c * dim0,
-                   nr * sizeof(T));
-          }
+            for (int c = 0; c < nc; c++) {
+              memcpy(tmp_buf_ptr + c * TRANS_NUMEL,
+                     src_ptr_inter + c * dim0,
+                     nr * sizeof(T));
+            }
 
-          int rc_max = std::max(nr, nc);
-          int rc_min = std::min(nr, nc);
-          for (int r = 0; r < rc_max; r++) {
-            int end = std::min(r, rc_min);
-            for (int c = 0; c < end; c++) {
-              T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
-              tmp_buf_ptr[r + TRANS_NUMEL * c] =
-                  tmp_buf_ptr[r * TRANS_NUMEL + c];
-              tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
+            int rc_max = std::max(nr, nc);
+            int rc_min = std::min(nr, nc);
+            for (int r = 0; r < rc_max; r++) {
+              int end = std::min(r, rc_min);
+              for (int c = 0; c < end; c++) {
+                T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
+                tmp_buf_ptr[r + TRANS_NUMEL * c] =
+                    tmp_buf_ptr[r * TRANS_NUMEL + c];
+                tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
+              }
             }
-          }
-          for (int r = 0; r < nr; r++) {
-            memcpy(out_ptr_inter + r * dim1,
-                   tmp_buf_ptr + r * TRANS_NUMEL,
-                   nc * sizeof(T));
+            for (int r = 0; r < nr; r++) {
+              memcpy(out_ptr_inter + r * dim1,
+                     tmp_buf_ptr + r * TRANS_NUMEL,
+                     nc * sizeof(T));
+            }
           }
         }
-      }
-      free(trans_buffer);
-    } else {
+        free(trans_buffer);
+      } else {
 #if defined(PADDLE_WITH_OPENMP)
-      phi::DenseTensorIteratorConfig config;
-      config.add_output(*cpu_out);
-      config.add_const_input(cpu_input);
-      config.is_alloc_out_ = true;
-      phi::DenseTensorIterator iter = config.build();
-
-      std::vector<int64_t> tmp_strides(
-          iter.ntensors() * static_cast<size_t>(std::max(iter.ndim(), 2)));
+        phi::DenseTensorIteratorConfig config;
+        config.add_output(*cpu_out);
+        config.add_const_input(cpu_input);
+        config.is_alloc_out_ = true;
+        phi::DenseTensorIterator iter = config.build();
 
-      DealWithStride(iter, tmp_strides.data());
+        std::vector<int64_t> tmp_strides(
+            iter.ntensors() * static_cast<size_t>(std::max(iter.ndim(), 2)));
 
-      std::vector<int64_t> out_stride(tmp_strides.begin() + iter.ntensors(),
-                                      tmp_strides.end());
+        DealWithStride(iter, tmp_strides.data());
 
-      std::vector<int64_t> output_stride = iter.strides(0);
-      std::vector<int64_t> input_stride = iter.strides(1);
+        std::vector<int64_t> out_stride(tmp_strides.begin() + iter.ntensors(),
+                                        tmp_strides.end());
 
-      const int64_t& numel = iter.numel();
+        std::vector<int64_t> output_stride = iter.strides(0);
+        std::vector<int64_t> input_stride = iter.strides(1);
 
-      const char* in_ptr = reinterpret_cast<const char*>(cpu_input_data);
-      char* out_ptr = reinterpret_cast<char*>(cpu_output_data);
+        const int64_t& numel = iter.numel();
 
-      int64_t end = numel;
-      int64_t begin = 0;
-      int64_t grain_size = 32768;
+        const char* in_ptr = reinterpret_cast<const char*>(cpu_input_data);
+        char* out_ptr = reinterpret_cast<char*>(cpu_output_data);
 
-      int64_t* whole_stride = tmp_strides.data();
+        int64_t end = numel;
+        int64_t begin = 0;
+        int64_t grain_size = 32768;
 
-      omp_set_num_threads(std::thread::hardware_concurrency());
+        int64_t* whole_stride = tmp_strides.data();
 #pragma omp parallel
-      {
-        int64_t num_threads = omp_get_num_threads();
+        {
+          int64_t num_threads = omp_get_num_threads();
 
-        if (grain_size > 0) {
-          num_threads = std::min(num_threads, DivUp((end - begin), grain_size));
-        }
+          if (grain_size > 0) {
+            num_threads =
+                std::min(num_threads, DivUp((end - begin), grain_size));
+          }
 
-        int64_t tid = omp_get_thread_num();
-        int64_t chunk_size = DivUp((end - begin), num_threads);
-        int64_t begin_tid = begin + tid * chunk_size;
-
-        if (begin_tid < end) {
-          int64_t range_start = begin_tid;
-          int64_t range_end = std::min(end, chunk_size + begin_tid);
-
-          auto dimiter = DimIter(iter.shape(), range_start, range_end);
-          while (!dimiter.iter_to_end()) {
-            const auto v_ndim = dimiter.values.size();
-            const char* tmp_in_data = in_ptr;
-            char* tmp_out_data = out_ptr;
-            for (size_t dim = 0; dim < v_ndim; dim++) {
-              int64_t value = dimiter.values[dim];
-              tmp_out_data += value * whole_stride[dim * iter.ntensors() + 0];
-              tmp_in_data += value * whole_stride[dim * iter.ntensors() + 1];
-            }
+          int64_t tid = omp_get_thread_num();
+          int64_t chunk_size = DivUp((end - begin), num_threads);
+          int64_t begin_tid = begin + tid * chunk_size;
+
+          if (begin_tid < end) {
+            int64_t range_start = begin_tid;
+            int64_t range_end = std::min(end, chunk_size + begin_tid);
+
+            auto dimiter = DimIter(iter.shape(), range_start, range_end);
+            while (!dimiter.iter_to_end()) {
+              const auto v_ndim = dimiter.values.size();
+              const char* tmp_in_data = in_ptr;
+              char* tmp_out_data = out_ptr;
+              for (size_t dim = 0; dim < v_ndim; dim++) {
+                int64_t value = dimiter.values[dim];
+                tmp_out_data += value * whole_stride[dim * iter.ntensors() + 0];
+                tmp_in_data += value * whole_stride[dim * iter.ntensors() + 1];
+              }
 
-            auto step = dimiter.iter_for_step();
+              auto step = dimiter.iter_for_step();
 
-            for (int64_t i = 0; i < step[1]; i++) {
-              for (int64_t j = 0; j < step[0]; j++) {
-                const char* real_in_ptr = tmp_in_data + j * whole_stride[1];
-                char* real_out_ptr = tmp_out_data + j * whole_stride[0];
+              for (int64_t i = 0; i < step[1]; i++) {
+                for (int64_t j = 0; j < step[0]; j++) {
+                  const char* real_in_ptr = tmp_in_data + j * whole_stride[1];
+                  char* real_out_ptr = tmp_out_data + j * whole_stride[0];
 
-                *reinterpret_cast<T*>(real_out_ptr) =
-                    *reinterpret_cast<const T*>(real_in_ptr);
+                  *reinterpret_cast<T*>(real_out_ptr) =
+                      *reinterpret_cast<const T*>(real_in_ptr);
+                }
+                tmp_in_data = tmp_in_data + out_stride[1];
+                tmp_out_data = tmp_out_data + out_stride[0];
               }
-              tmp_in_data = tmp_in_data + out_stride[1];
-              tmp_out_data = tmp_out_data + out_stride[0];
-            }
-            dimiter.iter_to_next(step);
+              dimiter.iter_to_next(step);
+            }
           }
         }
-      }
 #else
-      phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
+        phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
 #endif
-    }
-
-    auto src_cpu_place = input.place();
-    auto dst_gpu_place = out->place();
-
-    auto& pool = phi::DeviceContextPool::Instance();
-    auto* gpu_dev_ctx = static_cast<phi::GPUContext*>(pool.Get(out->place()));
-    auto stream = gpu_dev_ctx->stream();
+      }
 #if defined(PADDLE_WITH_OPENMP)
-    auto* src_ptr = cpu_output_data;
+      auto* src_ptr = cpu_output_data;
 #else
-    auto* src_ptr = cpu_out->data<T>();
+      auto* src_ptr = cpu_out->data<T>();
 #endif
-    auto size = phi::SizeOf(input.dtype()) * src_cpu.numel();
-    void* dst_ptr = gpu_dev_ctx->Alloc(
-        &dst_gpu,
-        dst_gpu.dtype(),
-        0,
-        dst_gpu_place.GetType() == AllocationType::GPUPINNED);
+      auto size = phi::SizeOf(input.dtype()) * src_cpu.numel();
+      void* dst_ptr = gpu_dev_ctx->Alloc(
+          &dst_gpu,
+          dst_gpu.dtype(),
+          0,
+          dst_gpu_place.GetType() == AllocationType::GPUPINNED);
 
-    phi::memory_utils::Copy(
-        dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
+      phi::memory_utils::Copy(
+          dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
 
-    free(cpu_output_data);
+      free(cpu_output_data);
+    }
 
     if (out != &dst_gpu) {
       PD_VISIT_ALL_TYPES(
           out->dtype(), "StridedCopyKernel", ([&] {
diff --git a/test/legacy_test/test_fast_h2d_copy.py b/test/legacy_test/test_fast_h2d_copy.py
index 99507b3f56699b..489bcca42cd12b 100644
--- a/test/legacy_test/test_fast_h2d_copy.py
+++ b/test/legacy_test/test_fast_h2d_copy.py
@@ -85,5 +85,33 @@ def test_dygraph(self):
         self.check_dygraph_result(place=get_device_place())
 
 
+@unittest.skipIf(
+    not paddle.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestFastCPUCopy3(unittest.TestCase):
+    def setUp(self):
+        src_shape = [2, 2]
+        tgt_shape = [2, 4]
+        # self.input_np_a = np.random.random((2,2)).astype(np.float32)
+        # self.input_np_b = np.random.random((2,4)).astype(np.float32)
+        self.input_dtype = 'float32'
+        paddle.device.set_device("cpu")
+        self.src_cpu = paddle.ones(src_shape, dtype="float32")
+        paddle.device.set_device("gpu:0")
+        self.dst_gpu = paddle.zeros(tgt_shape, dtype="float32")
+
+    def check_dygraph_result(self, place):
+        paddle.device.set_device("gpu:0")
+        tmp_dst_gpu = self.dst_gpu[..., :2]
+        tmp_dst_gpu.copy_(self.src_cpu)
+        tmp_dst_gpu1 = self.dst_gpu[..., 2:]
+        tmp_dst_gpu1.copy_(self.src_cpu)
+        np.testing.assert_allclose(self.dst_gpu.numpy(), 1.0)
+
+    def test_dygraph(self):
+        self.check_dygraph_result(place=get_device_place())
+
+
 if __name__ == '__main__':
     unittest.main()
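
The FastTransposeCopyValid branch in the kernel above stages the strided 2-D source through a small TRANS_NUMEL x TRANS_NUMEL scratch buffer: each tile is gathered column by column with memcpy, transposed in place, and written back out as contiguous rows, so both the reads and the writes stay within cache-sized blocks. The standalone sketch below illustrates that idea on plain float buffers; it is not part of the patch, and BlockedTranspose / TILE are hypothetical names used only for this example.

// Illustrative sketch (not part of the patch): cache-blocked transpose of a
// column-major dim0 x dim1 matrix into a row-major buffer, mirroring the
// FastTransposeCopyValid path with Paddle types replaced by plain float.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

constexpr int64_t TILE = 60;  // plays the role of TRANS_NUMEL

void BlockedTranspose(const float* src, float* dst,
                      int64_t dim0, int64_t dim1) {
  std::vector<float> buf(TILE * TILE);  // scratch tile, like trans_buffer
  for (int64_t d0 = 0; d0 < dim0; d0 += TILE) {
    for (int64_t d1 = 0; d1 < dim1; d1 += TILE) {
      const int64_t nr = std::min(dim0 - d0, TILE);  // rows in this tile
      const int64_t nc = std::min(dim1 - d1, TILE);  // columns in this tile

      // Gather: copy nc source columns (contiguous runs of nr floats) into
      // consecutive rows of the scratch buffer.
      for (int64_t c = 0; c < nc; ++c) {
        std::memcpy(buf.data() + c * TILE,
                    src + d0 + (d1 + c) * dim0,
                    nr * sizeof(float));
      }

      // Transpose the tile in place by swapping off-diagonal pairs, the same
      // rc_max / rc_min loop structure used in the kernel.
      const int64_t rc_min = std::min(nr, nc);
      for (int64_t r = 0; r < std::max(nr, nc); ++r) {
        for (int64_t c = 0; c < std::min(r, rc_min); ++c) {
          std::swap(buf[r * TILE + c], buf[c * TILE + r]);
        }
      }

      // Scatter: write the transposed tile out as nr contiguous row segments
      // of the row-major destination.
      for (int64_t r = 0; r < nr; ++r) {
        std::memcpy(dst + (d0 + r) * dim1 + d1,
                    buf.data() + r * TILE,
                    nc * sizeof(float));
      }
    }
  }
}

A 60 x 60 tile of float32 is roughly 14 KB, which fits comfortably in a typical 32 KB L1 data cache; that is presumably why the kernel picks TRANS_NUMEL = 60.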