forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet_dag.cc
345 lines (310 loc) · 10.9 KB
/
net_dag.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
#include "caffe2/core/net_dag.h"
#include <iostream>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include "caffe2/core/operator.h"
#include "caffe2/core/static_tracepoint.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/thread_name.h"
// Command-line flag: when set, fall back to one single-op chain per operator
// instead of merging operators into longer execution chains (workaround for
// latent multi-device issues, per the flag's help string).
C10_DEFINE_bool(
    caffe2_disable_chaining,
    false,
    "Disable chaining logic (some latent multi-device issues).");
// Command-line flag: when set, record per-chain timing statistics
// (queue wait time, time-to-success) via CAFFE_EVENT.
C10_DEFINE_bool(
    caffe2_dag_net_collect_stats,
    false,
    "Collect time stats in DAG net");
namespace caffe2 {
// Constructs the DAG execution state from `net_def`:
//  - builds operator nodes with their parent/child dependency links,
//  - partitions the graph into execution chains (the scheduling unit),
//  - records the initial frontier (nodes with no parents) used to seed
//    each run,
//  - sets up per-chain timers, per-device-type stats, and optional tracing.
// Worker threads are spawned lazily in DoRunAsync(), not here.
DAGNetBase::DAGNetBase(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws)
    : NetBase(net_def, ws), caught_exception_yet_(false), iter_(0) {
  // Blob creator allows us to track which operator created which blob.
  VLOG(1) << "Constructing DAGNet " << net_def->name();
  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
  // With chaining disabled, each operator becomes its own one-op chain.
  execution_chains_ =
      (FLAGS_caffe2_disable_chaining
           ? dag_utils::singleChains(operator_nodes_)
           : dag_utils::computeChains(operator_nodes_));
  operators_.reserve(operator_nodes_.size());
  for (const auto& node : operator_nodes_) {
    operators_.push_back(node.operator_.get());
  }
  LOG(INFO) << "Number of parallel execution chains "
            << execution_chains_.size()
            << " Number of operators = " << net_def->op_size();
  // TODO: do we want to make sure that there are no loops in the
  // dependency graph?
  // Figure out the initial frontier - this is the one we will feed into the job
  // queue to start a run.
  for (size_t idx = 0; idx < operator_nodes_.size(); ++idx) {
    if (operator_nodes_[idx].parents_.size() == 0) {
      initial_frontier_.push_back(idx);
    }
  }
  // Finally, start the workers.
  // NOTE(review): despite the comment above, only the desired worker count is
  // recorded here — threads are actually created in DoRunAsync().
  int num_workers = net_def->has_num_workers() ? net_def->num_workers() : 1;
  CAFFE_ENFORCE(num_workers > 0, "Must have a positive number of workers.");
  if (num_workers == 1) {
    LOG(WARNING) << "Number of workers is 1: this means that all operators "
                 << "will be executed sequentially. Did you forget to set "
                 << "num_workers in the NetDef?";
  }
  num_workers_ = num_workers;
  // Timers are only allocated for chain-start nodes: timing is per chain,
  // keyed by the index of the chain's first operator.
  for (size_t idx = 0; idx < operator_nodes_.size(); ++idx) {
    if (operator_nodes_[idx].is_chain_start_) {
      task_timers_[idx] = caffe2::make_unique<Timer>();
    }
  }
  // One stats bucket per possible device type so stats_ can be indexed
  // directly by device_type() in WorkerFunction()/RunAt().
  constexpr auto MAX_DEVICE_TYPES =
      DeviceTypeProto::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
  stats_.reserve(MAX_DEVICE_TYPES);
  for (auto device_idx = 0; device_idx < MAX_DEVICE_TYPES; ++device_idx) {
    stats_.emplace_back(
        "dag_net/stats/" + net_def->name() + "/" +
        caffe2::DeviceTypeName(device_idx));
  }
  // tracing::create may return an empty tracer when tracing is not enabled.
  tracer_ = tracing::create(this, net_def->name());
  if (tracer_) {
    LOG(INFO) << "Tracing net: " << net_def->name();
  }
}
// Shuts down the worker pool: tell the job queue that no further chains will
// arrive (unblocking any workers parked in Pop()), then join every worker
// thread before the members they reference are destroyed. If no run ever
// created the queue, there is nothing to tear down.
DAGNetBase::~DAGNetBase() {
  if (!job_queue_) {
    return;
  }
  job_queue_->NoMoreJobs();
  VLOG(1) << "Joining workers.";
  for (auto& worker_thread : workers_) {
    worker_thread.join();
  }
}
// Runs one iteration of the net. Seeds the job queue with the initial
// frontier, tops the worker pool up to num_workers_ threads, then blocks
// until either every operator has finished or some chain failed. On failure
// all workers are joined and the queue destroyed (they will be recreated on
// the next run), and the first captured exception — if any — is rethrown.
// Returns success_; false means a chain failed without throwing.
bool DAGNetBase::DoRunAsync() {
  StartAllObservers();
  tracing::startIter(tracer_);
  // Lock run_in_progress_ to prevent concurrent Run()s.
  std::unique_lock<std::mutex> run_lock(run_in_progress_);
  VLOG(1) << "Running parallel net.";
  // First, set up job queue.
  remaining_ops_ = operator_nodes_.size();
  success_ = true;
  iter_++;
  if (!job_queue_) {
    job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
  }
  // Figure out number of workers to start.
  size_t num_workers_to_start = num_workers_ - workers_.size();
  // Ensure the number of workers matches the defined in case
  // any of the previously started threads terminated.
  for (size_t i = 0; i < num_workers_to_start; i++) {
    VLOG(1) << "Start worker #" << workers_.size();
    workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
  }
  // Initialize the runtime parent count.
  // runtime_parent_count_ is decremented by workers as parents complete;
  // it must be reset before every run.
  for (auto& node : operator_nodes_) {
    node.runtime_parent_count_ = node.parents_.size();
  }
  // Kickstart the job queue.
  for (auto& value : initial_frontier_) {
    if (FLAGS_caffe2_dag_net_collect_stats) {
      // Timer measures queue wait until a worker pops this chain.
      task_timers_[value]->Start();
    }
    job_queue_->Push(value);
  }
  // Wait for failure or completed execution.
  // Workers notify cv_ when remaining_ops_ reaches zero or success_ flips;
  // the loop guards against spurious wakeups.
  {
    std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
    for (;;) {
      if (remaining_ops_ == 0 || !success_) {
        break;
      }
      cv_.wait(mutex_lock);
    }
  }
  // Wait for all workers to terminate after failure.
  // If there is a failure, it is unlikely that the net is executed
  // again without modifications. Therefore it's easier to let the
  // workers terminate here, versus adding a drain state to make the
  // sure the job queue is cleared.
  if (!success_) {
    for (auto& worker : workers_) {
      worker.join();
    }
    workers_.clear();
    job_queue_.reset(nullptr);
#ifdef CAFFE2_USE_EXCEPTION_PTR
    if (caught_exception_) {
      // Reset flag here in case Net gets run again
      caught_exception_yet_ = false;
      std::rethrow_exception(caught_exception_);
    }
#endif // CAFFE2_USE_EXCEPTION_PTR
    return success_;
  }
  VLOG(2) << "All ops finished running.";
  // Sanity check: a completed run must have drained every dependency count.
  for (const auto& op : operator_nodes_) {
    CAFFE_ENFORCE(
        op.runtime_parent_count_ == 0,
        "Operator ",
        op.operator_->debug_def().name(),
        "(",
        op.operator_->debug_def().type(),
        ") has some runtime parents left.");
  }
  StopAllObservers();
  // If the above while loop finished, we know that the current run finished.
  return success_;
}
// Records an exception raised while running the chain that starts at
// `operator_idx`. When exception_ptr support is compiled in, the FIRST
// exception of a run is captured for rethrow in DoRunAsync(); subsequent
// ones are only logged as "secondary". Without exception_ptr support the
// exception is rethrown immediately on the worker thread.
void DAGNetBase::HandleException(
    int operator_idx,
    const std::string& exception_str) {
  const std::string& operator_name =
      operator_nodes_[operator_idx].operator_->debug_def().name();
  const std::string& operator_type =
      operator_nodes_[operator_idx].operator_->debug_def().type();
  const char* prefix = "Exception from operator chain starting at '";
#ifdef CAFFE2_USE_EXCEPTION_PTR
  // Atomic exchange ensures exactly one thread captures the exception even
  // if several chains fail concurrently.
  if (!caught_exception_yet_.exchange(true)) {
    caught_exception_ = std::current_exception();
  } else {
    prefix = "Secondary exception from operator chain starting at '";
  }
#endif // CAFFE2_USE_EXCEPTION_PTR
  LOG(ERROR) << prefix << operator_name << "' (type '" << operator_type
             << "'): " << exception_str << "\n";
#ifndef CAFFE2_USE_EXCEPTION_PTR
  throw; // Can't capture for dispatch to other thread, re-throw here
#endif // CAFFE2_USE_EXCEPTION_PTR
}
// Worker-thread main loop. Repeatedly pops a chain-start index from the job
// queue, runs that chain via RunAt(), then — under remaining_ops_mutex_ —
// decrements children's dependency counts, pushes any now-ready chains, and
// notifies DoRunAsync() when the run completes or fails. Exits when the
// queue is closed (NoMoreJobs) or when any chain has failed.
void DAGNetBase::WorkerFunction() {
  setThreadName("CaffeDAGNet");
  // WorkerFunctions() is an infinite loop until there are no more jobs to run.
  while (true) {
    int idx = 0;
    // Return if there are no more operators to run (e.g. the
    // DAGNetBase is destructing, or there was an error on another
    // worker and we're cleaning up).
    if (!job_queue_->Pop(&idx)) {
      return;
    }
    if (FLAGS_caffe2_dag_net_collect_stats) {
      // Timer was started when this chain was pushed; this records how long
      // it sat in the queue.
      auto device_option =
          operator_nodes_[idx].operator_->event().GetDeviceOption();
      CAFFE_EVENT(
          stats_[device_option.device_type()],
          task_pool_wait_time_us,
          task_timers_[idx]->MicroSeconds());
    }
    VLOG(1) << "Running chain starting at operator #" << idx << " "
            << operator_nodes_[idx].operator_->debug_def().name() << "("
            << operator_nodes_[idx].operator_->debug_def().type() << ").";
    CAFFE_ENFORCE(
        execution_chains_.find(idx) != execution_chains_.end(),
        "Can't find chain ",
        idx,
        ".");
    bool this_success = false;
    try {
      this_success = RunAt(idx, execution_chains_[idx]);
      if (!this_success) {
        // If an exception was thrown, the operator def will get printed
        // by Operator::Run[Async], but if no exception occurs we print it here.
        LOG(ERROR) << "Operator chain failed starting at: "
                   << ProtoDebugString(
                          operator_nodes_[idx].operator_->debug_def());
      }
    } catch (std::exception& e) {
      std::string exception_str = c10::GetExceptionString(e);
      HandleException(idx, exception_str);
    } catch (...) {
      std::string exception_str = "Unknown exception";
      HandleException(idx, exception_str);
    }
    // Do book-keeping
    std::vector<int> chains_to_queue;
    const auto& chain = execution_chains_[idx];
    // NOTE(review): the loop variable below shadows the outer `idx`
    // (chain-start index); inside this loop `idx` is each operator of the
    // chain.
    for (const auto idx : chain) {
      for (const auto child : operator_nodes_[idx].children_) {
        const int count = --operator_nodes_[child].runtime_parent_count_;
        CAFFE_ENFORCE(
            count >= 0,
            "Found runtime parent count smaller than zero for ",
            "operator node ",
            operator_nodes_[child].operator_->debug_def().name(),
            "(",
            operator_nodes_[child].operator_->debug_def().type(),
            ").");
        if (count != 0) {
          continue;
        }
        // A child whose dependencies are all satisfied is only scheduled if
        // it begins a chain; mid-chain ops run inside their chain's RunAt().
        if (operator_nodes_[child].is_chain_start_) {
          VLOG(2) << "Pushing chain #" << child << " to queue.";
          chains_to_queue.push_back(child);
        }
      }
    }
    // Notify the caller of Run
    {
      std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
      remaining_ops_ -= chain.size();
      CAFFE_ENFORCE(remaining_ops_ >= 0);
      success_ &= this_success;
      if (remaining_ops_ == 0 || !success_) {
        cv_.notify_one();
      }
      // Terminate thread if this or any other operator chain failed.
      if (!success_) {
        job_queue_->NoMoreJobs();
        return;
      }
      // Queue follow up operator chains.
      // Can't do this inline because it can race with another thread
      // calling NoMoreJobs(). So the lock needs to be held on push.
      for (const auto idx : chains_to_queue) {
        if (FLAGS_caffe2_dag_net_collect_stats) {
          task_timers_[idx]->Start();
        }
        job_queue_->Push(idx);
      }
    }
    VLOG(2) << "Finished executing operator #" << idx;
  }
}
// Executes every operator of `chain` sequentially on the calling worker
// thread, in chain order. Returns false as soon as any operator's Run()
// fails (exceptions propagate to WorkerFunction's handlers). When stats
// collection is enabled, records the chain's total elapsed time — the timer
// was started when the chain was pushed to the queue — against the chain
// head's device type.
bool DAGNet::RunAt(int chain_id, const std::vector<int>& chain) {
  for (const auto i : chain) {
#ifdef CAFFE2_ENABLE_SDT
    // Static tracepoints (SystemTap/DTrace) bracketing each operator run.
    const auto& op_name =
        operator_nodes_[i].operator_->debug_def().name().c_str();
    const auto& op_type =
        operator_nodes_[i].operator_->debug_def().type().c_str();
    auto* op_ptr = operator_nodes_[i].operator_.get();
    const auto& net_name = name_.c_str();
    CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
#endif
    bool success = false;
    {
      // Scoped so the trace event covers exactly the operator's Run().
      TRACE_EVENT(tracing::TRACE_OP, i, tracing::TRACE_TASK, chain_id);
      success = operator_nodes_[i].operator_->Run();
    }
#ifdef CAFFE2_ENABLE_SDT
    CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
#endif
    if (!success) {
      return false;
    }
  }
  if (FLAGS_caffe2_dag_net_collect_stats) {
    auto device_option =
        operator_nodes_[chain_id].operator_->event().GetDeviceOption();
    CAFFE_EVENT(
        stats_[device_option.device_type()],
        task_time_to_succeeded_ms,
        task_timers_[chain_id]->MilliSeconds());
  }
  return true;
}
// Register this executor under net type "dag" so NetDefs with that type
// instantiate a DAGNet.
REGISTER_NET(dag, DAGNet);
} // namespace caffe2