forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet_async_dag_gpu.h
40 lines (31 loc) · 1.24 KB
/
net_async_dag_gpu.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#ifndef CAFFE2_CORE_NET_ASYNC_DAG_GPU_H_
#define CAFFE2_CORE_NET_ASYNC_DAG_GPU_H_
#include "caffe2/core/common.h"
#include "caffe2/core/net_dag.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
// Run an event-driven graph - before each operator chain, wait on each parent
// operator for the chain source, then execute each operator. Due to the chain
// construction mechanism, operators in the same chain implicitly run on the
// same stream.
// AsyncDAGNet is only registered in gpu mode, because CPU code is always sync
// and a CPU only AsyncDAG net is essentially a DAG net.
class AsyncDAGNet : public DAGNetBase {
 public:
  // Constructs the net from a shared, immutable NetDef and a Workspace
  // pointer.
  AsyncDAGNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);

  // This net type always reports that it supports asynchronous execution.
  bool SupportsAsync() override { return true; }

  // Executes the chain `chain` (a list of int entries) identified by
  // `chain_id`.
  // NOTE(review): the bool result presumably signals success -- confirm
  // against the DAGNetBase contract.
  bool RunAt(int chain_id, const std::vector<int>& chain) override;

 protected:
  // Override of the base-class asynchronous run hook.
  bool DoRunAsync() override;

  // Per-op record of whether an event has already been recorded for that op
  // during the current RunAt() iteration.
  std::vector<int32_t> eventRecorded_;

  // Returns an int stream identifier for the given device option.
  // NOTE(review): the selection policy is defined out of line and is not
  // visible from this header.
  int stream(const DeviceOption& device_option);

  // Declared thread_local, so every thread has its own instance.
  // NOTE(review): presumably consulted by stream() when picking a stream --
  // confirm in the implementation file.
  static thread_local std::vector<int> stream_counters_;

  C10_DISABLE_COPY_AND_ASSIGN(AsyncDAGNet);
};
} // namespace caffe2
#endif // CAFFE2_CORE_NET_ASYNC_DAG_GPU_H_