@@ -65,13 +65,12 @@ void StridedCopyKernel(const Context& dev_ctx,
6565#if defined(PADDLE_WITH_CUDA)
6666// not support Windows
6767#if !defined(_WIN32)
68- if (FLAGS_use_stride_kernel && FLAGS_use_stride_compute_kernel &&
68+ if (FLAGS_use_stride_kernel &&
6969 input.place ().GetType () == phi::AllocationType::CPU &&
7070 out->place ().GetType () == phi::AllocationType::GPU &&
71- input.dtype () == out->dtype () && !input.meta ().is_contiguous ()) {
71+ input.dtype () == out->dtype () &&
72+ (!input.meta ().is_contiguous () || !out->meta ().is_contiguous ())) {
7273 phi::DenseTensor dst_gpu;
73- phi::DenseTensor src_cpu;
74-
7574 if (out->meta ().is_contiguous ()) {
7675 dst_gpu = *out;
7776 } else {
@@ -82,176 +81,191 @@ void StridedCopyKernel(const Context& dev_ctx,
8281 dev_ctx.Alloc (&dst_gpu, input.dtype ());
8382 }
8483
85- phi::DenseTensor cpu_input = input;
86- phi::DenseTensor* cpu_out = &src_cpu;
87- void * cpu_output_data;
84+ auto src_cpu_place = input.place ();
85+ auto dst_gpu_place = out->place ();
86+ auto & pool = phi::DeviceContextPool::Instance ();
87+ auto * gpu_dev_ctx = static_cast <phi::GPUContext*>(pool.Get (out->place ()));
88+ auto stream = gpu_dev_ctx->stream ();
89+
90+ if (input.meta ().is_contiguous ()) {
91+ auto src_cpu_place = input.place ();
92+ auto dst_gpu_place = out->place ();
93+ auto size = phi::SizeOf (input.dtype ()) * input.numel ();
94+ void * dst_ptr = gpu_dev_ctx->Alloc (
95+ &dst_gpu,
96+ dst_gpu.dtype (),
97+ 0 ,
98+ dst_gpu_place.GetType () == AllocationType::GPUPINNED);
99+
100+ phi::memory_utils::Copy (
101+ dst_gpu_place, dst_ptr, src_cpu_place, input.data <T>(), size, stream);
102+
103+ } else {
104+ phi::DenseTensor src_cpu;
105+ phi::DenseTensor cpu_input = input;
106+ phi::DenseTensor* cpu_out = &src_cpu;
107+ void * cpu_output_data;
88108
89- phi::DenseTensorMeta cpu_meta = cpu_input.meta ();
90- cpu_meta.strides = cpu_meta.calc_strides (cpu_meta.dims );
91- cpu_meta.offset = 0 ;
92- cpu_out->set_meta (cpu_meta);
109+ phi::DenseTensorMeta cpu_meta = cpu_input.meta ();
110+ cpu_meta.strides = cpu_meta.calc_strides (cpu_meta.dims );
111+ cpu_meta.offset = 0 ;
112+ cpu_out->set_meta (cpu_meta);
93113
94114#if defined(PADDLE_WITH_OPENMP)
95- dev_ctx.HostAlloc (cpu_out, cpu_out->dtype ());
115+ dev_ctx.HostAlloc (cpu_out, cpu_out->dtype ());
96116#endif
97- const void * cpu_input_data = cpu_input.data ();
98- cpu_output_data = malloc (phi::SizeOf (cpu_input.dtype ()) * cpu_out->numel ());
117+ const void * cpu_input_data = cpu_input.data ();
118+ cpu_output_data =
119+ malloc (phi::SizeOf (cpu_input.dtype ()) * cpu_out->numel ());
99120
100- if (FastTransposeCopyValid (*cpu_out, cpu_input)) {
101- constexpr int64_t TRANS_NUMEL = 60 ;
102- void * trans_buffer =
103- malloc (phi::SizeOf (input.dtype ()) * TRANS_NUMEL * TRANS_NUMEL);
121+ if (FastTransposeCopyValid (*cpu_out, cpu_input)) {
122+ constexpr int64_t TRANS_NUMEL = 60 ;
123+ void * trans_buffer =
124+ malloc (phi::SizeOf (input.dtype ()) * TRANS_NUMEL * TRANS_NUMEL);
104125
105- const T* tmp_src_ptr = reinterpret_cast <const T*>(cpu_input_data);
126+ const T* tmp_src_ptr = reinterpret_cast <const T*>(cpu_input_data);
106127#if defined(PADDLE_WITH_OPENMP)
107- T* tmp_out_ptr = reinterpret_cast <T*>(cpu_output_data);
128+ T* tmp_out_ptr = reinterpret_cast <T*>(cpu_output_data);
108129#else
109- T* tmp_out_ptr = cpu_out->data <T>();
130+ T* tmp_out_ptr = cpu_out->data <T>();
110131#endif
111- T* tmp_buf_ptr = reinterpret_cast <T*>(trans_buffer);
132+ T* tmp_buf_ptr = reinterpret_cast <T*>(trans_buffer);
112133
113- int64_t dim0 = cpu_out->dims ()[0 ];
114- int64_t dim1 = cpu_out->dims ()[1 ];
134+ int64_t dim0 = cpu_out->dims ()[0 ];
135+ int64_t dim1 = cpu_out->dims ()[1 ];
115136
116- for (int64_t d0 = 0 ; d0 < dim0; d0 += TRANS_NUMEL) {
117- for (int64_t d1 = 0 ; d1 < dim1; d1 += TRANS_NUMEL) {
118- const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
119- T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
137+ for (int64_t d0 = 0 ; d0 < dim0; d0 += TRANS_NUMEL) {
138+ for (int64_t d1 = 0 ; d1 < dim1; d1 += TRANS_NUMEL) {
139+ const T* src_ptr_inter = tmp_src_ptr + d0 + d1 * dim0;
140+ T* out_ptr_inter = tmp_out_ptr + d1 + d0 * dim1;
120141
121- int nr = std::min (dim0 - d0, TRANS_NUMEL);
122- int nc = std::min (dim1 - d1, TRANS_NUMEL);
142+ int nr = std::min (dim0 - d0, TRANS_NUMEL);
143+ int nc = std::min (dim1 - d1, TRANS_NUMEL);
123144
124- for (int c = 0 ; c < nc; c++) {
125- memcpy (tmp_buf_ptr + c * TRANS_NUMEL,
126- src_ptr_inter + c * dim0,
127- nr * sizeof (T));
128- }
145+ for (int c = 0 ; c < nc; c++) {
146+ memcpy (tmp_buf_ptr + c * TRANS_NUMEL,
147+ src_ptr_inter + c * dim0,
148+ nr * sizeof (T));
149+ }
129150
130- int rc_max = std::max (nr, nc);
131- int rc_min = std::min (nr, nc);
132- for (int r = 0 ; r < rc_max; r++) {
133- int end = std::min (r, rc_min);
134- for (int c = 0 ; c < end; c++) {
135- T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
136- tmp_buf_ptr[r + TRANS_NUMEL * c] =
137- tmp_buf_ptr[r * TRANS_NUMEL + c];
138- tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
151+ int rc_max = std::max (nr, nc);
152+ int rc_min = std::min (nr, nc);
153+ for (int r = 0 ; r < rc_max; r++) {
154+ int end = std::min (r, rc_min);
155+ for (int c = 0 ; c < end; c++) {
156+ T tmp = tmp_buf_ptr[r + TRANS_NUMEL * c];
157+ tmp_buf_ptr[r + TRANS_NUMEL * c] =
158+ tmp_buf_ptr[r * TRANS_NUMEL + c];
159+ tmp_buf_ptr[r * TRANS_NUMEL + c] = tmp;
160+ }
139161 }
140- }
141162
142- for (int r = 0 ; r < nr; r++) {
143- memcpy (out_ptr_inter + r * dim1,
144- tmp_buf_ptr + r * TRANS_NUMEL,
145- nc * sizeof (T));
163+ for (int r = 0 ; r < nr; r++) {
164+ memcpy (out_ptr_inter + r * dim1,
165+ tmp_buf_ptr + r * TRANS_NUMEL,
166+ nc * sizeof (T));
167+ }
146168 }
147169 }
148- }
149- free (trans_buffer);
150- } else {
170+ free (trans_buffer);
171+ } else {
151172#if defined(PADDLE_WITH_OPENMP)
152- phi::DenseTensorIteratorConfig config;
153- config.add_output (*cpu_out);
154- config.add_const_input (cpu_input);
155- config.is_alloc_out_ = true ;
156- phi::DenseTensorIterator iter = config.build ();
157-
158- std::vector<int64_t > tmp_strides (
159- iter.ntensors () * static_cast <size_t >(std::max (iter.ndim (), 2 )));
173+ phi::DenseTensorIteratorConfig config;
174+ config.add_output (*cpu_out);
175+ config.add_const_input (cpu_input);
176+ config.is_alloc_out_ = true ;
177+ phi::DenseTensorIterator iter = config.build ();
160178
161- DealWithStride (iter, tmp_strides.data ());
179+ std::vector<int64_t > tmp_strides (
180+ iter.ntensors () * static_cast <size_t >(std::max (iter.ndim (), 2 )));
162181
163- std::vector<int64_t > out_stride (tmp_strides.begin () + iter.ntensors (),
164- tmp_strides.end ());
182+ DealWithStride (iter, tmp_strides.data ());
165183
166- std::vector<int64_t > output_stride = iter.strides ( 0 );
167- std::vector< int64_t > input_stride = iter. strides ( 1 );
184+ std::vector<int64_t > out_stride (tmp_strides. begin () + iter.ntensors (),
185+ tmp_strides. end () );
168186
169- const int64_t & numel = iter.numel ();
187+ std::vector<int64_t > output_stride = iter.strides (0 );
188+ std::vector<int64_t > input_stride = iter.strides (1 );
170189
171- const char * in_ptr = reinterpret_cast <const char *>(cpu_input_data);
172- char * out_ptr = reinterpret_cast <char *>(cpu_output_data);
190+ const int64_t & numel = iter.numel ();
173191
174- int64_t end = numel;
175- int64_t begin = 0 ;
176- int64_t grain_size = 32768 ;
192+ const char * in_ptr = reinterpret_cast <const char *>(cpu_input_data);
193+ char * out_ptr = reinterpret_cast <char *>(cpu_output_data);
177194
178- int64_t * whole_stride = tmp_strides.data ();
195+ int64_t end = numel;
196+ int64_t begin = 0 ;
197+ int64_t grain_size = 32768 ;
179198
180- omp_set_num_threads ( std::thread::hardware_concurrency () );
199+ int64_t * whole_stride = tmp_strides. data ( );
181200
182201#pragma omp parallel
183- {
184- int64_t num_threads = omp_get_num_threads ();
202+ {
203+ int64_t num_threads = omp_get_num_threads ();
185204
186- if (grain_size > 0 ) {
187- num_threads = std::min (num_threads, DivUp ((end - begin), grain_size));
188- }
205+ if (grain_size > 0 ) {
206+ num_threads =
207+ std::min (num_threads, DivUp ((end - begin), grain_size));
208+ }
189209
190- int64_t tid = omp_get_thread_num ();
191- int64_t chunk_size = DivUp ((end - begin), num_threads);
192- int64_t begin_tid = begin + tid * chunk_size;
193-
194- if (begin_tid < end) {
195- int64_t range_start = begin_tid;
196- int64_t range_end = std::min (end, chunk_size + begin_tid);
197-
198- auto dimiter = DimIter (iter.shape (), range_start, range_end);
199- while (!dimiter.iter_to_end ()) {
200- const auto v_ndim = dimiter.values .size ();
201- const char * tmp_in_data = in_ptr;
202- char * tmp_out_data = out_ptr;
203- for (size_t dim = 0 ; dim < v_ndim; dim++) {
204- int64_t value = dimiter.values [dim];
205- tmp_out_data += value * whole_stride[dim * iter.ntensors () + 0 ];
206- tmp_in_data += value * whole_stride[dim * iter.ntensors () + 1 ];
207- }
210+ int64_t tid = omp_get_thread_num ();
211+ int64_t chunk_size = DivUp ((end - begin), num_threads);
212+ int64_t begin_tid = begin + tid * chunk_size;
213+
214+ if (begin_tid < end) {
215+ int64_t range_start = begin_tid;
216+ int64_t range_end = std::min (end, chunk_size + begin_tid);
217+
218+ auto dimiter = DimIter (iter.shape (), range_start, range_end);
219+ while (!dimiter.iter_to_end ()) {
220+ const auto v_ndim = dimiter.values .size ();
221+ const char * tmp_in_data = in_ptr;
222+ char * tmp_out_data = out_ptr;
223+ for (size_t dim = 0 ; dim < v_ndim; dim++) {
224+ int64_t value = dimiter.values [dim];
225+ tmp_out_data += value * whole_stride[dim * iter.ntensors () + 0 ];
226+ tmp_in_data += value * whole_stride[dim * iter.ntensors () + 1 ];
227+ }
208228
209- auto step = dimiter.iter_for_step ();
229+ auto step = dimiter.iter_for_step ();
210230
211- for (int64_t i = 0 ; i < step[1 ]; i++) {
212- for (int64_t j = 0 ; j < step[0 ]; j++) {
213- const char * real_in_ptr = tmp_in_data + j * whole_stride[1 ];
214- char * real_out_ptr = tmp_out_data + j * whole_stride[0 ];
231+ for (int64_t i = 0 ; i < step[1 ]; i++) {
232+ for (int64_t j = 0 ; j < step[0 ]; j++) {
233+ const char * real_in_ptr = tmp_in_data + j * whole_stride[1 ];
234+ char * real_out_ptr = tmp_out_data + j * whole_stride[0 ];
215235
216- *reinterpret_cast <T*>(real_out_ptr) =
217- *reinterpret_cast <const T*>(real_in_ptr);
236+ *reinterpret_cast <T*>(real_out_ptr) =
237+ *reinterpret_cast <const T*>(real_in_ptr);
238+ }
239+ tmp_in_data = tmp_in_data + out_stride[1 ];
240+ tmp_out_data = tmp_out_data + out_stride[0 ];
218241 }
219- tmp_in_data = tmp_in_data + out_stride[1 ];
220- tmp_out_data = tmp_out_data + out_stride[0 ];
221- }
222242
223- dimiter.iter_to_next (step);
243+ dimiter.iter_to_next (step);
244+ }
224245 }
225246 }
226- }
227247#else
228- phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
248+ phi::ContiguousKernel<T, Context>(dev_ctx, input, cpu_out);
229249#endif
230- }
231-
232- auto src_cpu_place = input.place ();
233- auto dst_gpu_place = out->place ();
234-
235- auto & pool = phi::DeviceContextPool::Instance ();
236- auto * gpu_dev_ctx = static_cast <phi::GPUContext*>(pool.Get (out->place ()));
237- auto stream = gpu_dev_ctx->stream ();
250+ }
238251#if defined(PADDLE_WITH_OPENMP)
239- auto * src_ptr = cpu_output_data;
252+ auto * src_ptr = cpu_output_data;
240253#else
241- auto * src_ptr = cpu_out->data <T>();
254+ auto * src_ptr = cpu_out->data <T>();
242255#endif
243256
244- auto size = phi::SizeOf (input.dtype ()) * src_cpu.numel ();
245- void * dst_ptr = gpu_dev_ctx->Alloc (
246- &dst_gpu,
247- dst_gpu.dtype (),
248- 0 ,
249- dst_gpu_place.GetType () == AllocationType::GPUPINNED);
257+ auto size = phi::SizeOf (input.dtype ()) * src_cpu.numel ();
258+ void * dst_ptr = gpu_dev_ctx->Alloc (
259+ &dst_gpu,
260+ dst_gpu.dtype (),
261+ 0 ,
262+ dst_gpu_place.GetType () == AllocationType::GPUPINNED);
250263
251- phi::memory_utils::Copy (
252- dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
264+ phi::memory_utils::Copy (
265+ dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
253266
254- free (cpu_output_data);
267+ free (cpu_output_data);
268+ }
255269 if (out != &dst_gpu) {
256270 PD_VISIT_ALL_TYPES (
257271 out->dtype (), " StridedCopyKernel" , ([&] {
0 commit comments