@@ -53,40 +53,40 @@ FleetExecutor::~FleetExecutor() {
5353 }
5454}
5555
56- void FleetExecutor::Init (
57- const std::string& carrier_id,
58- const framework::ProgramDesc& program_desc,
59- framework::Scope* scope,
60- const platform::Place& place,
61- int64_t num_micro_batches,
62- const std::vector<TaskNode*>& task_nodes,
63- const std::unordered_map<int64_t , int64_t >& task_id_to_rank,
64- const std::vector<std::string>& inference_root_scope_vars,
65- const std::vector<framework::Scope*>& micro_scope_list) {
66- PADDLE_ENFORCE_GT (task_nodes.size (),
67- 0 ,
68- platform::errors::InvalidArgument (
69- " Fleet executor is inited with empty task node" ));
70- // TODO(fleet_exe devs): the unused_vars should be got from run time graph
71- std::vector<std::unique_ptr<framework::OperatorBase>> ops;
72- for (const auto & desc : program_desc.Block (0 ).AllOps ()) {
73- ops.emplace_back (framework::OpRegistry::CreateOp (*desc));
56+ namespace {
57+ void GetSubBlockTask (const std::vector<TaskNode*>& tasks,
58+ TaskNode* cur_task,
59+ std::set<TaskNode*>* sub_block_task) {
60+ auto & downstream = cur_task->downstream ();
61+ auto & id_to_dep_type = cur_task->id_to_dep_type ();
62+ for (auto & down : downstream) {
63+ int64_t task_id = down.first ;
64+ if (id_to_dep_type.at (task_id) == DependType::NORMAL) {
65+ for (const auto & task : tasks) {
66+ if (task->task_id () == task_id) {
67+ sub_block_task->emplace (task);
68+ GetSubBlockTask (tasks, task, sub_block_task);
69+ }
70+ }
71+ }
7472 }
75- auto unused_vars = framework::GetUnusedVars (program_desc. Block ( 0 ), ops, {});
73+ }
7674
77- // NOTE: For inference, the vars in inference_root_scope_vars
78- // shouldn't be deleted during inf, for that they may be the result of the
79- // inf. If they are GCed, it will cause error during ZeroCopy the result.
75+ void PreventVarsDelete (
76+ std::unordered_map<const framework::OperatorBase*,
77+ std::vector<std::string>>* unused_vars,
78+ const std::vector<std::string>& vars_not_gc) {
8079 std::vector<const framework::OperatorBase*> changed_ops;
81- for (auto pair : unused_vars) {
80+
81+ for (const auto & pair : *unused_vars) {
8282 const framework::OperatorBase* op = pair.first ;
83- std::vector<std::string> unused = pair.second ;
84- for (auto name : inference_root_scope_vars ) {
85- auto iter = std::find (unused .begin (), unused .end (), name);
86- if (iter != unused .end ()) {
83+ std::vector<std::string> cur_unused = pair.second ;
84+ for (auto name : vars_not_gc ) {
85+ auto iter = std::find (cur_unused .begin (), cur_unused .end (), name);
86+ if (iter != cur_unused .end ()) {
8787 VLOG (3 ) << " Removing var: [" << name
8888 << " ] from the unused vars list of op: [" << op->Type () << " ]" ;
89- unused .erase (iter);
89+ cur_unused .erase (iter);
9090 if (std::find (changed_ops.begin (), changed_ops.end (), op) ==
9191 changed_ops.end ()) {
9292 // record the op whose unused vars have been updated
@@ -95,48 +95,118 @@ void FleetExecutor::Init(
9595 }
9696 }
9797 // update the unused vars list in the map
98- unused_vars[op] = unused ;
98+ unused_vars-> at (op) = cur_unused ;
9999 }
100100 for (auto op : changed_ops) {
101- auto iter = unused_vars. find (op);
101+ const auto & iter = unused_vars-> find (op);
102102 if (iter->second .empty ()) {
103103 // remove those ops in the map that have empty unused vars list
104104 VLOG (3 ) << " Removing op: [" << op->Type () << " ] from unused_vars map." ;
105- unused_vars. erase (iter);
105+ unused_vars-> erase (iter);
106106 }
107107 }
108- runtime_graph_ = std::make_shared<RuntimeGraph>();
109- std::unordered_map< int64_t , TaskNode*> interceptor_id_to_task;
110- for ( auto task_node : task_nodes) {
111- task_node-> SetUnusedVars (unused_vars);
112- if (task_node-> type () == " Cond " ) {
113- std::vector<std::string> while_block_vars;
114- VLOG ( 3 ) << " Vars in while sub block: " ;
115- for ( auto & var : program_desc. Block ( 1 ). AllVars ()) {
116- VLOG ( 3 ) << var-> Name ();
117- while_block_vars .emplace_back (var-> Name ( ));
118- }
119- for ( const auto & pair : unused_vars) {
120- if (pair. first -> Type () == " while " ) {
121- for (const auto & var_name : pair. second ) {
122- while_block_vars. emplace_back (var_name);
123- }
124- }
108+ }
109+
110+ std::vector<std::string> GetUnusedVarsAfterWhile (
111+ const framework::ProgramDesc& program_desc,
112+ TaskNode* cond_task,
113+ const std::vector<std::string> vars_not_gc) {
114+ std::vector<std::string> while_block_vars ;
115+ std::vector<std::unique_ptr<framework::OperatorBase>> ops;
116+ for ( const auto & desc : program_desc. Block ( 0 ). AllOps ()) {
117+ ops .emplace_back (framework::OpRegistry::CreateOp (*desc ));
118+ }
119+ auto unused_vars = framework::GetUnusedVars (program_desc. Block ( 0 ), ops, {});
120+ PreventVarsDelete (&unused_vars, vars_not_gc);
121+ for (const auto & pair : unused_vars ) {
122+ if (pair. first -> Type () == " while " ) {
123+ for ( const auto & var_name : pair. second ) {
124+ while_block_vars. emplace_back (var_name);
125125 }
126- VLOG (3 ) << " Vars below will be removed after while:" ;
127- for (const auto & name : while_block_vars) {
128- VLOG (3 ) << name;
126+ }
127+ }
128+ return while_block_vars;
129+ }
130+
131+ } // namespace
132+
133+ void FleetExecutor::Init (
134+ const std::string& carrier_id,
135+ const framework::ProgramDesc& program_desc,
136+ framework::Scope* scope,
137+ const platform::Place& place,
138+ int64_t num_micro_batches,
139+ const std::vector<TaskNode*>& task_nodes,
140+ const std::unordered_map<int64_t , int64_t >& task_id_to_rank,
141+ const std::vector<std::string>& inference_root_scope_vars,
142+ const std::vector<framework::Scope*>& micro_scope_list) {
143+ PADDLE_ENFORCE_GT (task_nodes.size (),
144+ 0 ,
145+ platform::errors::InvalidArgument (
146+ " Fleet executor is inited with empty task node" ));
147+ // Set the unused var after running while op
148+ std::set<TaskNode*> sub_block_tasks;
149+ std::vector<std::string> while_block_vars;
150+ for (const auto & task_node : task_nodes) {
151+ if (task_node->type () == " Cond" ) {
152+ GetSubBlockTask (task_nodes, task_node, &sub_block_tasks);
153+ while_block_vars = GetUnusedVarsAfterWhile (
154+ program_desc, task_node, inference_root_scope_vars);
155+ VLOG (3 ) << " Vars will be gced after while op" ;
156+ for (auto var : while_block_vars) {
157+ VLOG (3 ) << var;
129158 }
130159 task_node->SetWhileBlockVars (while_block_vars);
131160 }
161+ }
162+ std::vector<framework::OperatorBase*> sub_block_ops;
163+ for (const auto & task_node : sub_block_tasks) {
164+ for (const auto & op : task_node->ops ()) {
165+ sub_block_ops.emplace_back (op);
166+ }
167+ }
168+ // Analyse the unused vars in block 0. The operators in block 1
169+ // should be passed in first for prevent vars been released but removed soon.
170+ // Since the unused vars in block 1 need to analyse separately.
171+ std::vector<std::unique_ptr<framework::OperatorBase>> ops;
172+ for (const auto & task_node : task_nodes) {
173+ for (const auto & op : task_node->ops ()) {
174+ ops.emplace_back (std::unique_ptr<framework::OperatorBase>(op));
175+ }
176+ }
177+ auto global_unused_vars =
178+ framework::GetUnusedVars (program_desc.Block (0 ), ops, {});
179+
180+ // Analyse the unused vars in block 1.
181+ std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
182+ sub_unused_vars;
183+ if (program_desc.Size () > 1 ) {
184+ sub_unused_vars = framework::GetUnusedVars (program_desc.Block (1 ), ops, {});
185+ PreventVarsDelete (&sub_unused_vars, while_block_vars);
186+ }
187+ for (auto & unique_op : ops) {
188+ unique_op.release ();
189+ }
190+
191+ // NOTE: For inference, the vars in inference_root_scope_vars
192+ // shouldn't be deleted during inf, for that they may be the result of the
193+ // inf. If they are GCed, it will cause error during ZeroCopy the result.
194+ PreventVarsDelete (&global_unused_vars, inference_root_scope_vars);
195+
196+ runtime_graph_ = std::make_shared<RuntimeGraph>();
197+ std::unordered_map<int64_t , TaskNode*> interceptor_id_to_task;
198+ for (auto task_node : task_nodes) {
199+ if (sub_block_tasks.find (task_node) == sub_block_tasks.end ()) {
200+ task_node->SetUnusedVars (global_unused_vars);
201+ } else {
202+ // task_node->SetUnusedVars(sub_unused_vars);
203+ }
132204 int64_t interceptor_id = task_node->task_id ();
133205 interceptor_id_to_task.emplace (interceptor_id, task_node);
134206 }
135207 runtime_graph_->SetInterceptorIdToRank (task_id_to_rank);
136208 runtime_graph_->SetInterceptorIdToNode (interceptor_id_to_task);
137- for (auto & unique_op : ops) {
138- unique_op.release ();
139- }
209+
140210 VLOG (5 ) << runtime_graph_->DebugString ();
141211 Carrier* carrier =
142212 GlobalMap<std::string, Carrier>::Create (carrier_id, carrier_id);
0 commit comments