18 changes: 8 additions & 10 deletions paddle/fluid/imperative/all_reduce.cc
@@ -37,8 +37,8 @@
 namespace paddle::imperative {
 
 static const phi::Place &GetVarPlace(const framework::Variable &src) {
-  if (src.IsType<phi::DenseTensor>()) {
-    return src.Get<phi::DenseTensor>().place();
+  if (src.IsType<DenseTensor>()) {
+    return src.Get<DenseTensor>().place();
 #if NCCL_VERSION_CODE >= 2212
   } else if (src.IsType<phi::SelectedRows>()) {
     return src.Get<phi::SelectedRows>().value().place();
@@ -52,8 +52,8 @@ static const phi::Place &GetVarPlace(const framework::Variable &src) {
   }
 }
 
-static void AllReduce(const phi::DenseTensor &src,
-                      phi::DenseTensor *dst,
+static void AllReduce(const DenseTensor &src,
+                      DenseTensor *dst,
                       const gpuStream_t stream,
                       const platform::NCCLComm *comm) {
   const auto &place = src.place();
@@ -224,14 +224,12 @@ void AllReduce(const framework::Variable &src,
       platform::NCCLCommContext::Instance().Get(ring_id, place);
   gpuStream_t stream = (use_calc_stream ? dev_ctx->stream() : comm->stream());
 
-  if (src.IsType<phi::DenseTensor>()) {
-    if (!dst->IsType<phi::DenseTensor>()) {
+  if (src.IsType<DenseTensor>()) {
+    if (!dst->IsType<DenseTensor>()) {
       dst->Clear();
     }
-    AllReduce(src.Get<phi::DenseTensor>(),
-              dst->GetMutable<phi::DenseTensor>(),
-              stream,
-              comm);
+    AllReduce(
+        src.Get<DenseTensor>(), dst->GetMutable<DenseTensor>(), stream, comm);
 #if NCCL_VERSION_CODE >= 2212
   } else if (src.IsType<phi::SelectedRows>()) {
     if (&src != dst) {
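
Note (not part of the PR): the change drops the phi:: qualifier from DenseTensor throughout paddle::imperative. The sketch below is an assumption about the mechanism that lets the unqualified spelling compile; the real using-declaration or alias presumably lives in a shared header these files already include, and phi::DenseTensor plus the helper Touch are mocked here purely for illustration.

// Minimal, self-contained sketch; names below are illustrative assumptions,
// not the actual Paddle headers.
namespace phi {
class DenseTensor {};
}  // namespace phi

namespace paddle {
namespace imperative {
// Assumed using-declaration: once visible in these translation units, the
// unqualified `DenseTensor` used throughout this diff names phi::DenseTensor.
using phi::DenseTensor;

// Same spelling as the new code in the diff.
static void Touch(const DenseTensor &t) { (void)t; }
}  // namespace imperative
}  // namespace paddle

int main() {
  phi::DenseTensor t;
  paddle::imperative::Touch(t);  // compiles because of the using-declaration
  return 0;
}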
25 changes: 11 additions & 14 deletions paddle/fluid/imperative/basic_engine.cc
@@ -100,9 +100,9 @@ void BasicEngine::Init(
         true,
         common::errors::NotFound("Tensor %s has no gradient", var->Name()));
 
-    auto& fwd_var = var->Var().Get<phi::DenseTensor>();
+    auto& fwd_var = var->Var().Get<DenseTensor>();
     auto* grad_var =
-        var->GradVarBase()->MutableVar()->GetMutable<phi::DenseTensor>();
+        var->GradVarBase()->MutableVar()->GetMutable<DenseTensor>();
     VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
             << " as stop_gradient false";
     var->GradVarBase()->InnerSetOverriddenStopGradient(false);
@@ -112,7 +112,7 @@
       grad_var->mutable_data(fwd_var.place(), fwd_var.type());
       phi::funcs::set_constant(*dev_ctx, grad_var, 1.0f);
     } else {
-      paddle::framework::TensorCopy(grad_tensor->Var().Get<phi::DenseTensor>(),
+      paddle::framework::TensorCopy(grad_tensor->Var().Get<DenseTensor>(),
                                     fwd_var.place(),
                                     *dev_ctx,
                                     grad_var);
@@ -149,10 +149,9 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
       }
 
       auto* inner_var = var->MutableVar();
-      phi::DenseTensor* tensor = nullptr;
-      if (!inner_var->IsInitialized() ||
-          inner_var->IsType<phi::DenseTensor>()) {
-        tensor = inner_var->GetMutable<phi::DenseTensor>();
+      DenseTensor* tensor = nullptr;
+      if (!inner_var->IsInitialized() || inner_var->IsType<DenseTensor>()) {
+        tensor = inner_var->GetMutable<DenseTensor>();
       }
 
       if (tensor && !tensor->IsInitialized()) {
@@ -340,8 +339,8 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
 
 static bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
   auto* inner_var = var->MutableVar();
-  if (inner_var->IsInitialized() && inner_var->IsType<phi::DenseTensor>()) {
-    auto tensor = inner_var->GetMutable<phi::DenseTensor>();
+  if (inner_var->IsInitialized() && inner_var->IsType<DenseTensor>()) {
+    auto tensor = inner_var->GetMutable<DenseTensor>();
     if (tensor->IsInitialized()) {
       return true;
     }
@@ -358,7 +357,7 @@ static void PerformBackwardInplace(const std::string& op_type,
   if (infer_inplace) {
     auto in_to_outs = infer_inplace(true);
     for (auto& pair : in_to_outs) {
-      phi::DenseTensor *in_tensor = nullptr, *out_tensor = nullptr;
+      DenseTensor *in_tensor = nullptr, *out_tensor = nullptr;
       for (auto& p : ins) {
         if (p.first == pair.first) {
           // has at least one var
@@ -368,8 +367,7 @@
           // the refcount of var to be inplaced should be 1
           if (in_var.use_count() == 1) {
             if (IsInputCanInplace(in_var)) {
-              in_tensor =
-                  in_var->MutableVar()->GetMutable<phi::DenseTensor>();
+              in_tensor = in_var->MutableVar()->GetMutable<DenseTensor>();
             }
           }
         }
@@ -383,8 +381,7 @@
       if (!p.second.empty() && p.second[0]) {
         auto& out_var = p.second[0];
         if (out_var->Type() == framework::proto::VarType::DENSE_TENSOR) {
-          out_tensor =
-              out_var->MutableVar()->GetMutable<phi::DenseTensor>();
+          out_tensor = out_var->MutableVar()->GetMutable<DenseTensor>();
         }
       }
     }
16 changes: 7 additions & 9 deletions paddle/fluid/imperative/bkcl_context.cc
@@ -33,8 +33,8 @@
 namespace paddle {
 namespace imperative {
 
-static void AllReduce(const phi::DenseTensor &src,
-                      phi::DenseTensor *dst,
+static void AllReduce(const DenseTensor &src,
+                      DenseTensor *dst,
                       const XPUStream stream,
                       const platform::BKCLComm *comm) {
   const auto &place = src.place();
@@ -162,14 +162,12 @@ void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
   XPUStream stream =
       use_calc_stream ? dev_ctx->x_context()->xpu_stream : comm->stream();
 
-  if (src.IsType<phi::DenseTensor>()) {
-    if (!dst->IsType<phi::DenseTensor>()) {
+  if (src.IsType<DenseTensor>()) {
+    if (!dst->IsType<DenseTensor>()) {
       dst->Clear();
     }
-    AllReduce(src.Get<phi::DenseTensor>(),
-              dst->GetMutable<phi::DenseTensor>(),
-              stream,
-              comm);
+    AllReduce(
+        src.Get<DenseTensor>(), dst->GetMutable<DenseTensor>(), stream, comm);
   } else {
     PADDLE_THROW(common::errors::InvalidArgument(
         "XPU unsupported variable type %s for imperative allreduce, only "
@@ -180,7 +178,7 @@
 
 void BKCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
   VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id;
-  phi::DenseTensor *src_tensor = src->GetMutable<phi::DenseTensor>();
+  DenseTensor *src_tensor = src->GetMutable<DenseTensor>();
   const auto &place = src_tensor->place();
   platform::BKCLComm *comm =
       platform::BKCLCommContext::Instance().Get(ring_id, place);
17 changes: 8 additions & 9 deletions paddle/fluid/imperative/dygraph_grad_maker.h
@@ -186,8 +186,8 @@ class GradOpBaseMakerBase {
 
         if (!is_input) {
           auto* tensor =
-              grad_var_base_tmp->MutableVar()->GetMutable<phi::DenseTensor>();
-          tensor->Resize(var_base_temp->Var().Get<phi::DenseTensor>().dims());
+              grad_var_base_tmp->MutableVar()->GetMutable<DenseTensor>();
+          tensor->Resize(var_base_temp->Var().Get<DenseTensor>().dims());
         }
         vec_temp.emplace_back(grad_var_base_tmp);
       } else {
@@ -363,14 +363,13 @@ class TracedGradOp {
     } else if (var_wrapper->InplaceVersionSnapshot() ==
               var_wrapper->MutableVar()->CurrentInplaceVersion()) {
       return var_wrapper;
-    } else if (var_wrapper->MutableVar()->IsType<phi::DenseTensor>() ||
+    } else if (var_wrapper->MutableVar()->IsType<DenseTensor>() ||
               var_wrapper->MutableVar()->IsType<phi::SelectedRows>()) {
-      auto* tensor =
-          var_wrapper->MutableVar()->IsType<phi::DenseTensor>()
-              ? var_wrapper->MutableVar()->GetMutable<phi::DenseTensor>()
-              : var_wrapper->MutableVar()
-                    ->GetMutable<phi::SelectedRows>()
-                    ->mutable_value();
+      auto* tensor = var_wrapper->MutableVar()->IsType<DenseTensor>()
+                         ? var_wrapper->MutableVar()->GetMutable<DenseTensor>()
+                         : var_wrapper->MutableVar()
+                               ->GetMutable<phi::SelectedRows>()
+                               ->mutable_value();
       if (!tensor->IsInitialized()) {
         return var_wrapper;
       }
10 changes: 5 additions & 5 deletions paddle/fluid/imperative/gloo_context.cc
@@ -79,11 +79,11 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
                                             int ring_id,
                                             bool use_calc_stream) {
   // AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
-  if (src.IsType<phi::DenseTensor>()) {
-    if (!dst->IsType<phi::DenseTensor>()) {
+  if (src.IsType<DenseTensor>()) {
+    if (!dst->IsType<DenseTensor>()) {
       dst->Clear();
     }
-    AllReduce(src.Get<phi::DenseTensor>(), dst->GetMutable<phi::DenseTensor>());
+    AllReduce(src.Get<DenseTensor>(), dst->GetMutable<DenseTensor>());
   } else if (src.IsType<phi::SelectedRows>()) {
     if (&src != dst) {
       if (!dst->IsType<phi::SelectedRows>()) {
@@ -106,8 +106,8 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src,
   }
 }
 
-void GLOOParallelContext::AllReduce(const phi::DenseTensor &src_tensor,
-                                    phi::DenseTensor *dst_tensor) {
+void GLOOParallelContext::AllReduce(const DenseTensor &src_tensor,
+                                    DenseTensor *dst_tensor) {
   auto gloo_wrapper = framework::GlooWrapper::GetInstance();
   dst_tensor->Resize(src_tensor.dims());
   switch (framework::TransToProtoVarType(src_tensor.dtype())) {
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/gloo_context.h
@@ -60,7 +60,7 @@ class GLOOParallelContext : public ParallelContext {
   void SynchronizeCompute() override;
 
  private:
-  void AllReduce(const phi::DenseTensor& src, phi::DenseTensor* dst);
+  void AllReduce(const DenseTensor& src, DenseTensor* dst);
   void AllReduce(const phi::SelectedRows& src, phi::SelectedRows* dst);
 
  private: