
Commit 9192f5e

Author: zhaoyinglia
fix pipelinepass for mp scene
Parent: 8907d76

10 files changed: +498 −285 lines


paddle/fluid/distributed/fleet_executor/fleet_executor.cc

Lines changed: 124 additions & 54 deletions
@@ -53,40 +53,40 @@ FleetExecutor::~FleetExecutor() {
   }
 }
 
-void FleetExecutor::Init(
-    const std::string& carrier_id,
-    const framework::ProgramDesc& program_desc,
-    framework::Scope* scope,
-    const platform::Place& place,
-    int64_t num_micro_batches,
-    const std::vector<TaskNode*>& task_nodes,
-    const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
-    const std::vector<std::string>& inference_root_scope_vars,
-    const std::vector<framework::Scope*>& micro_scope_list) {
-  PADDLE_ENFORCE_GT(task_nodes.size(),
-                    0,
-                    platform::errors::InvalidArgument(
-                        "Fleet executor is inited with empty task node"));
-  // TODO(fleet_exe devs): the unused_vars should be got from run time graph
-  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (const auto& desc : program_desc.Block(0).AllOps()) {
-    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
+namespace {
+void GetSubBlockTask(const std::vector<TaskNode*>& tasks,
+                     TaskNode* cur_task,
+                     std::set<TaskNode*>* sub_block_task) {
+  auto& downstream = cur_task->downstream();
+  auto& id_to_dep_type = cur_task->id_to_dep_type();
+  for (auto& down : downstream) {
+    int64_t task_id = down.first;
+    if (id_to_dep_type.at(task_id) == DependType::NORMAL) {
+      for (const auto& task : tasks) {
+        if (task->task_id() == task_id) {
+          sub_block_task->emplace(task);
+          GetSubBlockTask(tasks, task, sub_block_task);
+        }
+      }
+    }
   }
-  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
+}
 
-  // NOTE: For inference, the vars in inference_root_scope_vars
-  // shouldn't be deleted during inf, for that they may be the result of the
-  // inf. If they are GCed, it will cause error during ZeroCopy the result.
+void PreventVarsDelete(
+    std::unordered_map<const framework::OperatorBase*,
+                       std::vector<std::string>>* unused_vars,
+    const std::vector<std::string>& vars_not_gc) {
   std::vector<const framework::OperatorBase*> changed_ops;
-  for (auto pair : unused_vars) {
+
+  for (const auto& pair : *unused_vars) {
     const framework::OperatorBase* op = pair.first;
-    std::vector<std::string> unused = pair.second;
-    for (auto name : inference_root_scope_vars) {
-      auto iter = std::find(unused.begin(), unused.end(), name);
-      if (iter != unused.end()) {
+    std::vector<std::string> cur_unused = pair.second;
+    for (auto name : vars_not_gc) {
+      auto iter = std::find(cur_unused.begin(), cur_unused.end(), name);
+      if (iter != cur_unused.end()) {
         VLOG(3) << "Removing var: [" << name
                 << "] from the unused vars list of op: [" << op->Type() << "]";
-        unused.erase(iter);
+        cur_unused.erase(iter);
         if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
             changed_ops.end()) {
           // record the op whose unused vars have been updated
@@ -95,48 +95,118 @@ void FleetExecutor::Init(
       }
     }
     // update the unused vars list in the map
-    unused_vars[op] = unused;
+    unused_vars->at(op) = cur_unused;
   }
   for (auto op : changed_ops) {
-    auto iter = unused_vars.find(op);
+    const auto& iter = unused_vars->find(op);
     if (iter->second.empty()) {
       // remove those ops in the map that have empty unused vars list
      VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
-      unused_vars.erase(iter);
+      unused_vars->erase(iter);
     }
   }
-  runtime_graph_ = std::make_shared<RuntimeGraph>();
-  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
-  for (auto task_node : task_nodes) {
-    task_node->SetUnusedVars(unused_vars);
-    if (task_node->type() == "Cond") {
-      std::vector<std::string> while_block_vars;
-      VLOG(3) << "Vars in while sub block:";
-      for (auto& var : program_desc.Block(1).AllVars()) {
-        VLOG(3) << var->Name();
-        while_block_vars.emplace_back(var->Name());
-      }
-      for (const auto& pair : unused_vars) {
-        if (pair.first->Type() == "while") {
-          for (const auto& var_name : pair.second) {
-            while_block_vars.emplace_back(var_name);
-          }
-        }
+}
+
+std::vector<std::string> GetUnusedVarsAfterWhile(
+    const framework::ProgramDesc& program_desc,
+    TaskNode* cond_task,
+    const std::vector<std::string> vars_not_gc) {
+  std::vector<std::string> while_block_vars;
+  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
+  for (const auto& desc : program_desc.Block(0).AllOps()) {
+    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
+  }
+  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
+  PreventVarsDelete(&unused_vars, vars_not_gc);
+  for (const auto& pair : unused_vars) {
+    if (pair.first->Type() == "while") {
+      for (const auto& var_name : pair.second) {
+        while_block_vars.emplace_back(var_name);
      }
-      VLOG(3) << "Vars below will be removed after while:";
-      for (const auto& name : while_block_vars) {
-        VLOG(3) << name;
+    }
+  }
+  return while_block_vars;
+}
+
+}  // namespace
+
+void FleetExecutor::Init(
+    const std::string& carrier_id,
+    const framework::ProgramDesc& program_desc,
+    framework::Scope* scope,
+    const platform::Place& place,
+    int64_t num_micro_batches,
+    const std::vector<TaskNode*>& task_nodes,
+    const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
+    const std::vector<std::string>& inference_root_scope_vars,
+    const std::vector<framework::Scope*>& micro_scope_list) {
+  PADDLE_ENFORCE_GT(task_nodes.size(),
+                    0,
+                    platform::errors::InvalidArgument(
+                        "Fleet executor is inited with empty task node"));
+  // Set the unused var after running while op
+  std::set<TaskNode*> sub_block_tasks;
+  std::vector<std::string> while_block_vars;
+  for (const auto& task_node : task_nodes) {
+    if (task_node->type() == "Cond") {
+      GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
+      while_block_vars = GetUnusedVarsAfterWhile(
+          program_desc, task_node, inference_root_scope_vars);
+      VLOG(3) << "Vars will be gced after while op";
+      for (auto var : while_block_vars) {
+        VLOG(3) << var;
      }
       task_node->SetWhileBlockVars(while_block_vars);
     }
+  }
+  std::vector<framework::OperatorBase*> sub_block_ops;
+  for (const auto& task_node : sub_block_tasks) {
+    for (const auto& op : task_node->ops()) {
+      sub_block_ops.emplace_back(op);
+    }
+  }
+  // Analyse the unused vars in block 0. The operators in block 1
+  // should be passed in first for prevent vars been released but removed soon.
+  // Since the unused vars in block 1 need to analyse separately.
+  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
+  for (const auto& task_node : task_nodes) {
+    for (const auto& op : task_node->ops()) {
+      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
+    }
+  }
+  auto global_unused_vars =
+      framework::GetUnusedVars(program_desc.Block(0), ops, {});
+
+  // Analyse the unused vars in block 1.
+  std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
+      sub_unused_vars;
+  if (program_desc.Size() > 1) {
+    sub_unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
+    PreventVarsDelete(&sub_unused_vars, while_block_vars);
+  }
+  for (auto& unique_op : ops) {
+    unique_op.release();
+  }
+
+  // NOTE: For inference, the vars in inference_root_scope_vars
+  // shouldn't be deleted during inf, for that they may be the result of the
+  // inf. If they are GCed, it will cause error during ZeroCopy the result.
+  PreventVarsDelete(&global_unused_vars, inference_root_scope_vars);
+
+  runtime_graph_ = std::make_shared<RuntimeGraph>();
+  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
+  for (auto task_node : task_nodes) {
+    if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
+      task_node->SetUnusedVars(global_unused_vars);
+    } else {
+      // task_node->SetUnusedVars(sub_unused_vars);
+    }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
   }
   runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
   runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
-  for (auto& unique_op : ops) {
-    unique_op.release();
-  }
+
   VLOG(5) << runtime_graph_->DebugString();
   Carrier* carrier =
       GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
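
The new anonymous-namespace helpers above carry most of the fix: GetSubBlockTask follows the DependType::NORMAL downstream edges of a "Cond" task and collects every task node that belongs to the while sub-block, and PreventVarsDelete strips variables that must outlive garbage collection (for example inference outputs) from each op's unused-vars list. A minimal sketch of the traversal idea, using hypothetical stand-in types (SimpleTask, CollectSubBlockTasks) rather than Paddle's real TaskNode API:

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <vector>

// Hypothetical, stripped-down stand-in for Paddle's TaskNode: just an id and
// a map from downstream task id to "is this a NORMAL dependency edge?".
struct SimpleTask {
  int64_t id;
  std::map<int64_t, bool> downstream_is_normal;
};

// Recursively collect every task reachable from `cur` through NORMAL edges,
// mirroring the shape of GetSubBlockTask in the diff above. The insert(...)
// result also guards against revisiting a task; the committed helper recurses
// unconditionally, which assumes the NORMAL-edge subgraph is acyclic.
void CollectSubBlockTasks(const std::vector<SimpleTask>& tasks,
                          const SimpleTask& cur,
                          std::set<int64_t>* sub_block) {
  for (const auto& down : cur.downstream_is_normal) {
    if (!down.second) continue;  // skip non-NORMAL dependencies
    for (const auto& task : tasks) {
      if (task.id == down.first && sub_block->insert(task.id).second) {
        CollectSubBlockTasks(tasks, task, sub_block);
      }
    }
  }
}

int main() {
  // Cond task 0 reaches tasks 1 and 2 through NORMAL edges; task 3 hangs off
  // a non-NORMAL edge and therefore stays out of the sub-block set.
  std::vector<SimpleTask> tasks = {
      {0, {{1, true}, {3, false}}}, {1, {{2, true}}}, {2, {}}, {3, {}}};
  std::set<int64_t> sub_block;
  CollectSubBlockTasks(tasks, tasks[0], &sub_block);
  for (int64_t id : sub_block) std::cout << id << " ";  // prints "1 2"
  std::cout << std::endl;
  return 0;
}

Compiled and run, the sketch prints "1 2", the ids of the tasks collected into the sub-block set; only the tasks gathered this way keep using the block-1 unused-vars analysis, while all other tasks get the block-0 (global) lists.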

paddle/fluid/distributed/fleet_executor/start_interceptor.cc

Lines changed: 2 additions & 2 deletions
@@ -72,10 +72,10 @@ void StartInterceptor::SendDataReadyToDownStream() {
     auto down_id = outs.first;
     InterceptorMessage ready_msg;
     ready_msg.set_message_type(DATA_IS_READY);
-    ready_msg.set_scope_idx(step_);
+    ready_msg.set_scope_idx(step_ % node_->max_run_times());
     VLOG(3) << "StartInterceptor " << interceptor_id_
             << " Send data_is_ready msg to " << down_id
-            << " in scope: " << step_;
+            << " in scope: " << step_ % node_->max_run_times();
     Send(down_id, ready_msg);
   }
   step_++;
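
The two changed lines wrap the monotonically increasing step counter back into the range of micro-batch scopes: step_ % node_->max_run_times() keeps the scope index cycling through the valid scope copies even after the interceptor has run more steps than there are micro scopes. A tiny standalone sketch of that wrap-around, with made-up values:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical values: 4 micro-batch scopes, 10 send steps.
  const int64_t max_run_times = 4;  // stands in for node_->max_run_times()
  for (int64_t step = 0; step < 10; ++step) {
    // Same wrap-around as ready_msg.set_scope_idx(step_ % node_->max_run_times()):
    // the scope index cycles 0,1,2,3,0,1,... instead of growing past the
    // number of micro-batch scope copies.
    std::cout << "step " << step << " -> scope " << step % max_run_times << "\n";
  }
  return 0;
}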

python/paddle/distributed/auto_parallel/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -35,3 +35,4 @@
 from . import dist_reduce_sum_p
 from . import dist_shape
 from . import dist_assign
+from . import dist_scale
