@@ -335,32 +335,6 @@ void Copy<phi::Place, phi::XPUPlace>(phi::Place dst_place,
335335 }
336336}
337337
338- template <>  // Removed by this diff: generic Place->Place Copy that handled only the XPU->CPU and CPU->XPU pairs.
339- PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
340- void * dst,
341- phi::Place src_place,
342- const void * src,
343- size_t num,
344- void * stream) {  // NOTE(review): `stream` is accepted but never used in this body.
345- if (dst_place.GetType () == phi::AllocationType::CPU) {
346- phi::CPUPlace place_dst;
347- if (src_place.GetType () == phi::AllocationType::XPU) {
348- phi::XPUPlace place_src (src_place.GetDeviceId ());
349- return Copy (place_dst, dst, place_src, src, num);  // XPU -> CPU, forwarded to the typed overload without the stream argument.
350- } else {
351- VLOG (4 ) << " cannot fit into a copy stereotype, might be an error" ;  // Unsupported source for a CPU destination: only logged, nothing is copied.
352- }
353- } else if (dst_place.GetType () == phi::AllocationType::XPU) {
354- phi::XPUPlace place_dst (dst_place.GetDeviceId ());
355- if (src_place.GetType () == phi::AllocationType::CPU) {
356- phi::CPUPlace place_src;
357- return Copy (place_dst, dst, place_src, src, num);  // CPU -> XPU, forwarded to the typed overload without the stream argument.
358- } else {
359- VLOG (4 ) << " cannot fit into a copy stereotype, might be an error" ;  // Unsupported source for an XPU destination: only logged, nothing is copied.
360- }
361- }  // Any destination type other than CPU or XPU falls through here without copying or logging.
362- }
363-
364338template <>
365339void Copy<phi::CPUPlace, phi::XPUPinnedPlace>(phi::CPUPlace dst_place,
366340 void * dst,
@@ -478,6 +452,32 @@ void Copy<phi::XPUPlace, phi::XPUPinnedPlace>(phi::XPUPlace dst_place,
478452 VLOG (4 ) << " cudaMemcpy time: " << elapsed.count () << " ms" ;
479453}
480454
455+ // NOTE: Stream-taking generic Copy for XPU and XPUPINNED: dispatches XPUPINNED -> XPU and XPU -> XPUPINNED to the typed overloads; every other place pair throws Unimplemented.
456+ template <>
457+ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
458+ void * dst,
459+ phi::Place src_place,
460+ const void * src,
461+ size_t num,
462+ void * stream) {  // `stream` is forwarded unchanged to the typed overloads; its concrete type is not visible here.
463+ if (src_place.GetType () == phi::AllocationType::XPUPINNED &&
464+ dst_place.GetType () == phi::AllocationType::XPU) {
465+ phi::XPUPinnedPlace place_src;
466+ phi::XPUPlace place_dst (dst_place.GetDeviceId ());  // Rebuild the typed destination from the generic place's device id.
467+ return Copy (place_dst, dst, place_src, src, num, stream);  // XPUPINNED -> XPU.
468+ } else if (src_place.GetType () == phi::AllocationType::XPU &&
469+ dst_place.GetType () == phi::AllocationType::XPUPINNED) {
470+ phi::XPUPinnedPlace place_dst;
471+ phi::XPUPlace place_src (src_place.GetDeviceId ());  // Rebuild the typed source from the generic place's device id.
472+ return Copy (place_dst, dst, place_src, src, num, stream);  // XPU -> XPUPINNED.
473+ } else {
474+ PADDLE_THROW (::common::errors::Unimplemented (  // Fail loudly for unsupported pairs; the message names both places.
475+ " Asynchronous Copy from %s to %s is not supported." ,
476+ src_place,
477+ dst_place));
478+ }
479+ }
480+
481481template <>
482482void Copy<phi::XPUPinnedPlace, phi::Place>(phi::XPUPinnedPlace dst_place,
483483 void * dst,
@@ -866,50 +866,6 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
866866 phi::GPUPinnedPlace place_dst;
867867 phi::GPUPlace place_src (src_place.GetDeviceId ());
868868 return Copy (place_dst, dst, place_src, src, num, stream);
869- #ifdef PADDLE_WITH_XPU
870- } else if (src_place.GetType () == phi::AllocationType::CPU &&
871- dst_place.GetType () == phi::AllocationType::XPUPINNED) {
872- phi::CPUPlace place_src;
873- phi::XPUPinnedPlace place_dst;
874- return Copy (place_dst, dst, place_src, src, num);
875- } else if (src_place.GetType () == phi::AllocationType::XPUPINNED &&
876- dst_place.GetType () == phi::AllocationType::CPU) {
877- phi::CPUPlace place_dst;
878- phi::XPUPinnedPlace place_src;
879- return Copy (place_dst, dst, place_src, src, num);
880- } else if (src_place.GetType () == phi::AllocationType::XPUPINNED &&
881- dst_place.GetType () == phi::AllocationType::XPUPINNED) {
882- phi::XPUPinnedPlace place_dst;
883- phi::XPUPinnedPlace place_src;
884- return Copy (place_dst, dst, place_src, src, num);
885- } else if (src_place.GetType () == phi::AllocationType::XPUPINNED &&
886- dst_place.GetType () == phi::AllocationType::XPU) {
887- phi::XPUPinnedPlace place_src;
888- phi::XPUPlace place_dst (dst_place.GetDeviceId ());
889- return Copy (place_dst, dst, place_src, src, num, stream);
890- } else if (src_place.GetType () == phi::AllocationType::XPU &&
891- dst_place.GetType () == phi::AllocationType::XPUPINNED) {
892- phi::XPUPinnedPlace place_dst;
893- phi::XPUPlace place_src (src_place.GetDeviceId ());
894- return Copy (place_dst, dst, place_src, src, num, stream);
895- #endif
896- #ifdef PADDLE_WITH_CUSTOM_DEVICE
897- } else if (src_place.GetType () == phi::AllocationType::CPU && // NOLINT
898- dst_place.GetType () == phi::AllocationType::CUSTOM) {
899- phi::CPUPlace place_src;
900- phi::CustomPlace place_dst (dst_place);
901- return Copy (place_dst, dst, place_src, src, num, stream);
902- } else if (src_place.GetType () == phi::AllocationType::CUSTOM && // NOLINT
903- dst_place.GetType () == phi::AllocationType::CPU) {
904- phi::CustomPlace place_src (src_place);
905- phi::CPUPlace place_dst;
906- return Copy (place_dst, dst, place_src, src, num, stream);
907- } else if (src_place.GetType () == phi::AllocationType::CUSTOM && // NOLINT
908- dst_place.GetType () == phi::AllocationType::CUSTOM) {
909- phi::CustomPlace place_src (src_place);
910- phi::CustomPlace place_dst (dst_place);
911- return Copy (place_dst, dst, place_src, src, num, stream);
912- #endif
913869 }
914870}
915871
@@ -1011,7 +967,7 @@ void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place,
1011967}
1012968#endif
1013969
1014- // NOTE: Only for CPUPlace, XPUPlace and PinnedPlace .
970+ // NOTE: Synchronous Copy for all Place types.
1015971template <>
1016972PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
1017973 void * dst,
@@ -1026,8 +982,13 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
1026982 std::memcpy (dst, src, num);
1027983 }
1028984#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1029- else if (src_place.GetType () == phi::AllocationType::CPU && // NOLINT
1030- dst_place.GetType () == phi::AllocationType::GPUPINNED) {
985+ else if (src_place.GetType () == phi::AllocationType::GPU && // NOLINT
986+ dst_place.GetType () == phi::AllocationType::CPU) {
987+ phi::GPUPlace place_src (src_place.GetDeviceId ());
988+ phi::CPUPlace place_dst;
989+ return Copy (place_dst, dst, place_src, src, num, nullptr );
990+ } else if (src_place.GetType () == phi::AllocationType::CPU &&
991+ dst_place.GetType () == phi::AllocationType::GPUPINNED) {
1031992 std::memcpy (dst, src, num);
1032993 } else if (src_place.GetType () == phi::AllocationType::GPUPINNED &&
1033994 dst_place.GetType () == phi::AllocationType::CPU) {
@@ -1039,11 +1000,7 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
10391000#endif
10401001#ifdef PADDLE_WITH_XPU
10411002 else if (src_place.GetType () == phi::AllocationType::CPU && // NOLINT
1042- dst_place.GetType () == phi::AllocationType::CPU) {
1043- phi::CPUPlace place_dst, place_src;
1044- return Copy (place_dst, dst, place_src, src, num);
1045- } else if (src_place.GetType () == phi::AllocationType::CPU &&
1046- dst_place.GetType () == phi::AllocationType::XPU) {
1003+ dst_place.GetType () == phi::AllocationType::XPU) {
10471004 phi::XPUPlace place_dst (dst_place.GetDeviceId ());
10481005 phi::CPUPlace place_src;
10491006 return Copy (place_dst, dst, place_src, src, num);
@@ -1124,6 +1081,10 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
11241081 return Copy (place_dst, dst, place_src, src, num, nullptr );
11251082 }
11261083#endif
1084+ else { // NOLINT
1085+ PADDLE_THROW (::common::errors::Unimplemented (
1086+ " Copy from %s to %s is not supported." , src_place, dst_place));
1087+ }
11271088}
11281089
11291090// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
0 commit comments