Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/fluid/framework/naive_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class PADDLE_API NaiveExecutor {
bool switch_stream = false);

// Get an tensor to operating directly, without the need for feed_ops.
phi::DenseTensor* FindTensor(const std::string& name);
DenseTensor* FindTensor(const std::string& name);

Scope* GetScope() { return scope_; }

Expand Down Expand Up @@ -116,9 +116,9 @@ class PADDLE_API NaiveExecutor {
std::vector<PirHookFunc> pir_input_hookfuncs_;

// Record information that tensor_a should ShareBufferWith tensor_b.
std::unordered_map<OperatorBase*, std::unordered_map<phi::DenseTensor*, int>>
std::unordered_map<OperatorBase*, std::unordered_map<DenseTensor*, int>>
reuse_cache_;
std::vector<phi::DenseTensor*> cluster_buffer_;
std::vector<DenseTensor*> cluster_buffer_;

std::unique_ptr<framework::InterpreterCore> interpreter_core_;
};
Expand Down
26 changes: 13 additions & 13 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
void SplitFeedTensors(const std::vector<std::string>& feed_names,
const int64_t micro_batch_num,
Scope* scope,
std::vector<std::vector<phi::DenseTensor>>* out) {
std::vector<phi::DenseTensor> feed_tensors;
std::vector<std::vector<DenseTensor>>* out) {
std::vector<DenseTensor> feed_tensors;
for (size_t i = 0; i < feed_names.size(); ++i) {
auto feed_name = feed_names[i];
auto feed_var = scope->GetVar(feed_name);
PADDLE_ENFORCE_NOT_NULL(
feed_var,
common::errors::NotFound("Variable %s should not be nullptr.",
feed_names[i]));
feed_tensors.push_back(feed_var->Get<phi::DenseTensor>());
feed_tensors.push_back(feed_var->Get<DenseTensor>());
}

out->resize(micro_batch_num);
Expand Down Expand Up @@ -109,10 +109,10 @@ void FetchTensors(const std::vector<std::string>& job_fetch_names,
int col = find(fetch_var_names.begin(), fetch_var_names.end(), var_name) -
fetch_var_names.begin();
auto* var = scope->FindVar(var_name);
if (var->IsType<phi::DenseTensor>()) {
auto& src = var->Get<phi::DenseTensor>();
if (var->IsType<DenseTensor>()) {
auto& src = var->Get<DenseTensor>();
auto* dst =
&(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col]));
&(PADDLE_GET(DenseTensor, fetch_list->at(micro_batch_id)[col]));
if (src.IsInitialized()) {
TensorCopy(src, CPUPlace(), dst);
dst->set_lod(src.lod());
Expand Down Expand Up @@ -156,21 +156,21 @@ void MergeFetchTensors(const FetchUnmergedList& fetch_list,

out->resize(fetch_list[0].size());
for (size_t i = 0; i < fetch_list[0].size(); ++i) {
std::vector<const phi::DenseTensor*> tensors_ptr;
std::vector<const DenseTensor*> tensors_ptr;
for (auto micro_batch_id = 0; micro_batch_id < micro_batch_num;
++micro_batch_id) {
tensors_ptr.push_back(
&PADDLE_GET_CONST(phi::DenseTensor, fetch_list[micro_batch_id][i]));
&PADDLE_GET_CONST(DenseTensor, fetch_list[micro_batch_id][i]));
}
phi::DenseTensor merged_tensor;
DenseTensor merged_tensor;
MergeTensors(tensors_ptr, CPUPlace(), &merged_tensor);
out->at(i) = std::move(merged_tensor);
}
}

void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
void MergeTensors(const std::vector<const DenseTensor*>& tensors,
const phi::Place dst_place,
phi::DenseTensor* target) {
DenseTensor* target) {
PADDLE_ENFORCE_EQ(
tensors.empty(),
false,
Expand Down Expand Up @@ -201,7 +201,7 @@ void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
new_type,
framework::TransToProtoVarType(t->dtype()),
common::errors::InvalidArgument(
"phi::DenseTensor data type does not match, expected type is %s, "
"DenseTensor data type does not match, expected type is %s, "
"actual "
"type is %s.",
DataTypeToString(new_type),
Expand All @@ -210,7 +210,7 @@ void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
new_layout,
t->layout(),
common::errors::InvalidArgument(
"phi::DenseTensor layout does not match, expected layout is %s, "
"DenseTensor layout does not match, expected layout is %s, "
"actual layout is %s.",
common::DataLayoutToString(new_layout),
common::DataLayoutToString(t->layout())));
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/core/dense_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,7 @@ class PADDLE_API DenseTensor : public TensorBase,
};

} // namespace phi

namespace paddle {
using DenseTensor = phi::DenseTensor;
}
Loading