Skip to content

Commit 4074f96

Browse files
authored
[XPU] support xpu pin memory copy (#76092)
* [XPU] support xpu pin memory copy
1 parent fb85882 commit 4074f96

File tree

3 files changed

+65
-80
lines changed

3 files changed

+65
-80
lines changed

paddle/fluid/imperative/tracer.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,15 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
163163
PADDLE_THROW(common::errors::PermissionDenied(
164164
"Paddle can't use XPU device since it's not compiled with XPU,"
165165
"Please recompile or reinstall Paddle with XPU support."));
166+
#endif
167+
} else if (phi::is_xpu_pinned_place(place)) {
168+
#if defined(PADDLE_WITH_XPU)
169+
gc = std::make_unique<framework::XPUPinnedGarbageCollector>(place, 0);
170+
VLOG(10) << "Created GarbageCollector at " << place;
171+
#else
172+
PADDLE_THROW(common::errors::PermissionDenied(
173+
"Paddle can't use XPUPinned device since it's not compiled with XPU,"
174+
"Please recompile or reinstall Paddle with XPU support."));
166175
#endif
167176
} else if (phi::is_cpu_place(place)) {
168177
gc = std::make_unique<framework::CPUGarbageCollector>(place, 0);

paddle/phi/core/memory/memcpy.cc

Lines changed: 39 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -335,32 +335,6 @@ void Copy<phi::Place, phi::XPUPlace>(phi::Place dst_place,
335335
}
336336
}
337337

338-
template <>
339-
PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
340-
void* dst,
341-
phi::Place src_place,
342-
const void* src,
343-
size_t num,
344-
void* stream) {
345-
if (dst_place.GetType() == phi::AllocationType::CPU) {
346-
phi::CPUPlace place_dst;
347-
if (src_place.GetType() == phi::AllocationType::XPU) {
348-
phi::XPUPlace place_src(src_place.GetDeviceId());
349-
return Copy(place_dst, dst, place_src, src, num);
350-
} else {
351-
VLOG(4) << "cannot fit into a copy stereotype, might be an error";
352-
}
353-
} else if (dst_place.GetType() == phi::AllocationType::XPU) {
354-
phi::XPUPlace place_dst(dst_place.GetDeviceId());
355-
if (src_place.GetType() == phi::AllocationType::CPU) {
356-
phi::CPUPlace place_src;
357-
return Copy(place_dst, dst, place_src, src, num);
358-
} else {
359-
VLOG(4) << "cannot fit into a copy stereotype, might be an error";
360-
}
361-
}
362-
}
363-
364338
template <>
365339
void Copy<phi::CPUPlace, phi::XPUPinnedPlace>(phi::CPUPlace dst_place,
366340
void* dst,
@@ -478,6 +452,32 @@ void Copy<phi::XPUPlace, phi::XPUPinnedPlace>(phi::XPUPlace dst_place,
478452
VLOG(4) << "cudaMemcpy time: " << elapsed.count() << " ms";
479453
}
480454

455+
// NOTE: for XPU and XPUPINNED.
// Stream-taking Copy specialization for generic phi::Place arguments.
// It checks the runtime allocation type of both places and forwards to the
// typed overloads for exactly two directions:
//   XPUPINNED -> XPU   and   XPU -> XPUPINNED.
// The XPUPlace is rebuilt from the generic place's device id, while the
// XPUPinnedPlace is default-constructed. `dst`, `src`, `num` and `stream`
// are passed through unchanged. Every other src/dst pair raises an
// Unimplemented error instead of returning silently.
456+
template <>
457+
PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
458+
void* dst,
459+
phi::Place src_place,
460+
const void* src,
461+
size_t num,
462+
void* stream) {
463+
if (src_place.GetType() == phi::AllocationType::XPUPINNED &&
464+
dst_place.GetType() == phi::AllocationType::XPU) {
465+
phi::XPUPinnedPlace place_src;
466+
phi::XPUPlace place_dst(dst_place.GetDeviceId());
467+
return Copy(place_dst, dst, place_src, src, num, stream);
468+
} else if (src_place.GetType() == phi::AllocationType::XPU &&
469+
dst_place.GetType() == phi::AllocationType::XPUPINNED) {
470+
phi::XPUPinnedPlace place_dst;
471+
phi::XPUPlace place_src(src_place.GetDeviceId());
472+
return Copy(place_dst, dst, place_src, src, num, stream);
473+
} else {
474+
// No other place pair is handled by this specialization.
PADDLE_THROW(::common::errors::Unimplemented(
475+
"Asynchronous Copy from %s to %s is not supported.",
476+
src_place,
477+
dst_place));
478+
}
479+
}
480+
481481
template <>
482482
void Copy<phi::XPUPinnedPlace, phi::Place>(phi::XPUPinnedPlace dst_place,
483483
void* dst,
@@ -866,50 +866,6 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
866866
phi::GPUPinnedPlace place_dst;
867867
phi::GPUPlace place_src(src_place.GetDeviceId());
868868
return Copy(place_dst, dst, place_src, src, num, stream);
869-
#ifdef PADDLE_WITH_XPU
870-
} else if (src_place.GetType() == phi::AllocationType::CPU &&
871-
dst_place.GetType() == phi::AllocationType::XPUPINNED) {
872-
phi::CPUPlace place_src;
873-
phi::XPUPinnedPlace place_dst;
874-
return Copy(place_dst, dst, place_src, src, num);
875-
} else if (src_place.GetType() == phi::AllocationType::XPUPINNED &&
876-
dst_place.GetType() == phi::AllocationType::CPU) {
877-
phi::CPUPlace place_dst;
878-
phi::XPUPinnedPlace place_src;
879-
return Copy(place_dst, dst, place_src, src, num);
880-
} else if (src_place.GetType() == phi::AllocationType::XPUPINNED &&
881-
dst_place.GetType() == phi::AllocationType::XPUPINNED) {
882-
phi::XPUPinnedPlace place_dst;
883-
phi::XPUPinnedPlace place_src;
884-
return Copy(place_dst, dst, place_src, src, num);
885-
} else if (src_place.GetType() == phi::AllocationType::XPUPINNED &&
886-
dst_place.GetType() == phi::AllocationType::XPU) {
887-
phi::XPUPinnedPlace place_src;
888-
phi::XPUPlace place_dst(dst_place.GetDeviceId());
889-
return Copy(place_dst, dst, place_src, src, num, stream);
890-
} else if (src_place.GetType() == phi::AllocationType::XPU &&
891-
dst_place.GetType() == phi::AllocationType::XPUPINNED) {
892-
phi::XPUPinnedPlace place_dst;
893-
phi::XPUPlace place_src(src_place.GetDeviceId());
894-
return Copy(place_dst, dst, place_src, src, num, stream);
895-
#endif
896-
#ifdef PADDLE_WITH_CUSTOM_DEVICE
897-
} else if (src_place.GetType() == phi::AllocationType::CPU && // NOLINT
898-
dst_place.GetType() == phi::AllocationType::CUSTOM) {
899-
phi::CPUPlace place_src;
900-
phi::CustomPlace place_dst(dst_place);
901-
return Copy(place_dst, dst, place_src, src, num, stream);
902-
} else if (src_place.GetType() == phi::AllocationType::CUSTOM && // NOLINT
903-
dst_place.GetType() == phi::AllocationType::CPU) {
904-
phi::CustomPlace place_src(src_place);
905-
phi::CPUPlace place_dst;
906-
return Copy(place_dst, dst, place_src, src, num, stream);
907-
} else if (src_place.GetType() == phi::AllocationType::CUSTOM && // NOLINT
908-
dst_place.GetType() == phi::AllocationType::CUSTOM) {
909-
phi::CustomPlace place_src(src_place);
910-
phi::CustomPlace place_dst(dst_place);
911-
return Copy(place_dst, dst, place_src, src, num, stream);
912-
#endif
913869
}
914870
}
915871

@@ -1011,7 +967,7 @@ void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place,
1011967
}
1012968
#endif
1013969

1014-
// NOTE: Only for CPUPlace, XPUPlace and PinnedPlace.
970+
// NOTE: Synchronous copy for all place types.
1015971
template <>
1016972
PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
1017973
void* dst,
@@ -1026,8 +982,13 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
1026982
std::memcpy(dst, src, num);
1027983
}
1028984
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1029-
else if (src_place.GetType() == phi::AllocationType::CPU && // NOLINT
1030-
dst_place.GetType() == phi::AllocationType::GPUPINNED) {
985+
else if (src_place.GetType() == phi::AllocationType::GPU && // NOLINT
986+
dst_place.GetType() == phi::AllocationType::CPU) {
987+
phi::GPUPlace place_src(src_place.GetDeviceId());
988+
phi::CPUPlace place_dst;
989+
return Copy(place_dst, dst, place_src, src, num, nullptr);
990+
} else if (src_place.GetType() == phi::AllocationType::CPU &&
991+
dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1031992
std::memcpy(dst, src, num);
1032993
} else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
1033994
dst_place.GetType() == phi::AllocationType::CPU) {
@@ -1039,11 +1000,7 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
10391000
#endif
10401001
#ifdef PADDLE_WITH_XPU
10411002
else if (src_place.GetType() == phi::AllocationType::CPU && // NOLINT
1042-
dst_place.GetType() == phi::AllocationType::CPU) {
1043-
phi::CPUPlace place_dst, place_src;
1044-
return Copy(place_dst, dst, place_src, src, num);
1045-
} else if (src_place.GetType() == phi::AllocationType::CPU &&
1046-
dst_place.GetType() == phi::AllocationType::XPU) {
1003+
dst_place.GetType() == phi::AllocationType::XPU) {
10471004
phi::XPUPlace place_dst(dst_place.GetDeviceId());
10481005
phi::CPUPlace place_src;
10491006
return Copy(place_dst, dst, place_src, src, num);
@@ -1124,6 +1081,10 @@ PADDLE_API void Copy<phi::Place, phi::Place>(phi::Place dst_place,
11241081
return Copy(place_dst, dst, place_src, src, num, nullptr);
11251082
}
11261083
#endif
1084+
else { // NOLINT
1085+
PADDLE_THROW(::common::errors::Unimplemented(
1086+
"Copy from %s to %s is not supported.", src_place, dst_place));
1087+
}
11271088
}
11281089

11291090
// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).

paddle/phi/core/tensor_utils.cc

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ void Copy(const Context& dev_ctx,
7373
dst, src.dtype(), 0, dst_place.GetType() == AllocationType::GPUPINNED);
7474
#endif
7575
#ifdef PADDLE_WITH_XPU
76-
} else if (dst_place.GetType() == AllocationType::XPU) {
77-
dst_ptr = dev_ctx.Alloc(dst, src.dtype());
76+
} else if (dst_place.GetType() == AllocationType::XPU ||
77+
dst_place.GetType() == AllocationType::XPUPINNED) {
78+
dst_ptr = dev_ctx.Alloc(
79+
dst, src.dtype(), 0, dst_place.GetType() == AllocationType::XPUPINNED);
7880
#endif
7981
#ifdef PADDLE_WITH_CUSTOM_DEVICE
8082
} else if (dst_place.GetType() == AllocationType::CUSTOM) {
@@ -224,6 +226,11 @@ void Copy(const Context& dev_ctx,
224226
dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
225227
#endif
226228
#ifdef PADDLE_WITH_XPU
229+
} else if ((src_place.GetType() == AllocationType::CPU ||
230+
src_place.GetType() == AllocationType::XPUPINNED) && // NOLINT
231+
(dst_place.GetType() == AllocationType::CPU ||
232+
dst_place.GetType() == AllocationType::XPUPINNED)) {
233+
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
227234
} else if (src_place.GetType() == AllocationType::XPU && // NOLINT
228235
dst_place.GetType() == AllocationType::CPU) {
229236
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
@@ -238,6 +245,14 @@ void Copy(const Context& dev_ctx,
238245
return;
239246
}
240247
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
248+
} else if ((src_place.GetType() == AllocationType::XPU &&
249+
dst_place.GetType() == AllocationType::XPUPINNED) ||
250+
(src_place.GetType() == AllocationType::XPUPINNED &&
251+
dst_place.GetType() == AllocationType::XPU)) {
252+
auto stream =
253+
blocking ? nullptr
254+
: reinterpret_cast<const phi::XPUContext&>(dev_ctx).stream();
255+
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
241256
#endif
242257
#ifdef PADDLE_WITH_CUSTOM_DEVICE
243258
} else if (src_place.GetType() == AllocationType::CUSTOM && // NOLINT

0 commit comments

Comments
 (0)