-
Notifications
You must be signed in to change notification settings - Fork 337
/
Copy pathONNXOpTransformPass.cpp
127 lines (111 loc) · 4.61 KB
/
ONNXOpTransformPass.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===------- ONNXOpTransformPass.cpp - ONNX Op Transform ------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a combined pass that dynamically invokes several
// transformation passes on ONNX ops, repeating them until the IR stops
// changing or an iteration limit is reached.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
/// Module-level pass that repeatedly runs a small pipeline of ONNX-to-ONNX
/// transformations (decomposition, shape inference, canonicalization, an
/// optional CPU convolution optimization, and constant propagation) until the
/// module stops changing or the iteration threshold is reached.
struct ONNXOpTransformPass : public mlir::PassWrapper<ONNXOpTransformPass,
                                 OperationPass<mlir::ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXOpTransformPass)

  StringRef getArgument() const override { return "onnx-op-transform"; }

  StringRef getDescription() const override {
    return "Invoke passes iteratively that transform ONNX operation.";
  }

  /// Maximum number of pipeline iterations (default 10).
  Option<int> onnxOpTransformThreshold{*this, "onnx-op-transform-threshold",
      llvm::cl::desc("max iteration for op transform passes."),
      llvm::cl::init(10)};
  /// When true, print the iteration count and convergence status to stdout.
  Option<bool> onnxOpTransformReport{*this, "onnx-op-transform-report",
      llvm::cl::desc("Report diagnostic info for op transform passes."),
      llvm::cl::init(false)};
  /// Gates CPU-only transformations (currently the convolution optimization).
  Option<bool> onnxOpTransformTargetCPU{*this, "onnx-op-transform-target-cpu",
      llvm::cl::desc("Target CPU op transform passes."), llvm::cl::init(true)};
  /// Forwarded to the convolution optimization pass.
  Option<bool> onnxOpTransformEnableSimdDataLayout{*this,
      "onnx-op-transform-simd-data-layout",
      llvm::cl::desc("Enable SIMD data layout opt in op transform passes."),
      llvm::cl::init(false)};
  /// Enables the convolution optimization pass (only effective for CPU).
  Option<bool> enableConvOptPass{*this, "enable-conv-opt-pass",
      llvm::cl::desc("Enable the ConvOptPass. Default is true."),
      llvm::cl::init(true)};

  ONNXOpTransformPass() = default;

  /// Copy constructor: re-initializes the base and every option member with
  /// its default; no option values are copied here.
  /// NOTE(review): presumably MLIR's pass cloning transfers option values
  /// separately after copy-construction; confirm for the MLIR version in use.
  ONNXOpTransformPass(const ONNXOpTransformPass & /*pass*/)
      : mlir::PassWrapper<ONNXOpTransformPass,
            OperationPass<mlir::ModuleOp>>() {}

  /// Construct the pass with explicit option values.
  /// \param threshold maximum number of pipeline iterations.
  /// \param report print iteration/convergence info when true.
  /// \param targetCPU enable CPU-only transformations.
  /// \param enableSimdDataLayoutOpt forwarded to the convolution optimization.
  /// \param enableConvOpt run the convolution optimization (CPU only).
  // The last parameter is deliberately not named `enableConvOptPass` so that
  // it does not shadow the option member of that name.
  ONNXOpTransformPass(int threshold, bool report, bool targetCPU,
      bool enableSimdDataLayoutOpt, bool enableConvOpt) {
    onnxOpTransformThreshold = threshold;
    onnxOpTransformReport = report;
    onnxOpTransformTargetCPU = targetCPU;
    onnxOpTransformEnableSimdDataLayout = enableSimdDataLayoutOpt;
    enableConvOptPass = enableConvOpt;
  }

  void runOnOperation() final;
};
/// Run the ONNX transformation pipeline on the module until the module's
/// fingerprint stops changing or `onnxOpTransformThreshold` iterations have
/// been performed.
///
/// Each iteration builds and runs a dynamic pipeline consisting of: ONNX
/// decomposition, shape inference, canonicalization, shape inference, then
/// (when targeting CPU and enabled) convolution optimization followed by
/// shape inference, and finally constant propagation. Function-level passes
/// are nested on func::FuncOp.
///
/// Fails the pass (with an error diagnostic) if the threshold is not
/// positive, and fails it if the dynamic pipeline fails. Emits a warning when
/// the threshold is exhausted without convergence.
void ONNXOpTransformPass::runOnOperation() {
  auto module = getOperation();

  // The threshold is settable through a pass option, so reject bad values
  // with a diagnostic instead of an assert that disappears in release builds.
  const int threshold = onnxOpTransformThreshold;
  if (threshold <= 0) {
    module->emitError()
        << "onnx-op-transform-threshold must be positive, got " << threshold;
    return signalPassFailure();
  }

  // `n` is the remaining iteration budget.
  int n = threshold;
  OperationFingerPrint before(module);
  do {
    OpPassManager dynamicPM("builtin.module");
    dynamicPM.addNestedPass<func::FuncOp>(
        onnx_mlir::createDecomposeONNXToONNXPass());
    dynamicPM.addNestedPass<func::FuncOp>(
        onnx_mlir::createShapeInferencePass());
    dynamicPM.addPass(mlir::createCanonicalizerPass());
    dynamicPM.addNestedPass<func::FuncOp>(
        onnx_mlir::createShapeInferencePass());
    // Convolution Optimization currently only for CPU.
    if (onnxOpTransformTargetCPU && enableConvOptPass) {
      dynamicPM.addNestedPass<func::FuncOp>(
          onnx_mlir::createConvOptONNXToONNXPass(
              onnxOpTransformEnableSimdDataLayout));
      dynamicPM.addNestedPass<func::FuncOp>(
          onnx_mlir::createShapeInferencePass());
    }
    dynamicPM.addNestedPass<func::FuncOp>(
        onnx_mlir::createConstPropONNXToONNXPass());
    if (failed(runPipeline(dynamicPM, module)))
      return signalPassFailure();
    // A fixed point is reached as soon as one iteration leaves the IR intact.
    OperationFingerPrint after(module);
    if (after == before)
      break;
    before = after;
  } while (--n > 0);

  // `n` only reaches 0 when every allowed iteration modified the IR.
  const bool converged = n > 0;
  if (!converged) {
    module->emitWarning()
        << "ONNXOpTransform did not converge after " << threshold
        << " iterations. "
        << "You may set a higher threshold with the "
           "'onnx-op-transform-threshold' option";
  }
  if (onnxOpTransformReport) {
    // The count is the number of iterations that modified the IR; the final,
    // unchanged iteration of a converged run is not included.
    llvm::outs() << "ONNXOpTransform iterated " << threshold - n
                 << " times, converged " << (converged ? "true" : "false")
                 << "\n";
  }
}
} // end anonymous namespace
/*!
 * Create the ONNX op transform pass ("onnx-op-transform") with all options at
 * their defaults: threshold 10, report off, CPU target on, SIMD data layout
 * off, convolution optimization on.
 */
std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass() {
  return std::make_unique<ONNXOpTransformPass>();
}
/*!
 * Create the ONNX op transform pass with explicitly chosen option values.
 *
 * \param threshold maximum number of transformation-pipeline iterations.
 * \param report print iteration count and convergence status when true.
 * \param targetCPU enable CPU-only transformations.
 * \param enableSimdDataLayoutOpt forwarded to the convolution optimization.
 * \param enableConvOptPass run the convolution optimization (CPU only).
 */
std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass(int threshold,
    bool report, bool targetCPU, bool enableSimdDataLayoutOpt,
    bool enableConvOptPass) {
  std::unique_ptr<mlir::Pass> transformPass =
      std::make_unique<ONNXOpTransformPass>(threshold, report, targetCPU,
          enableSimdDataLayoutOpt, enableConvOptPass);
  return transformPass;
}