-
Notifications
You must be signed in to change notification settings - Fork 482
/
Copy pathProgramRewriter.java
386 lines (344 loc) · 15.9 KB
/
ProgramRewriter.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
/**
* This program rewriter applies a variety of rule-based rewrites
* on all hop dags of the given program in one pass over the entire
* program.
*
*/
public class ProgramRewriter {
	/** Debug switch: when true, each HOP DAG is validated after every applied rule. */
	private static final boolean CHECK = false;

	/** HOP DAG rewrite rules, applied in list order. */
	private final ArrayList<HopRewriteRule> _dagRuleSet;
	/** Statement block rewrite rules, applied in list order. */
	private final ArrayList<StatementBlockRewriteRule> _sbRuleSet;

	/**
	 * Creates a rewriter with all static and dynamic rewrites enabled
	 * (the default configuration used during initial compilation).
	 */
	public ProgramRewriter() {
		this( true, true );
	}

	/**
	 * Creates a rewriter with the built-in rule sets in a fixed order.
	 * Static rewrites do not rely on size information; dynamic rewrites do.
	 * Cleanup rules are always appended: cast removal, CSE and constant folding
	 * (each subject to its optimizer flag), plus removal of empty basic blocks
	 * and empty for-loops.
	 *
	 * @param staticRewrites  if true, add the static (size-independent) rewrites
	 * @param dynamicRewrites if true, add the dynamic (size-dependent) rewrites
	 */
	public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
	{
		//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
		_dagRuleSet = new ArrayList<>();
		//initialize StatementBlock rewrite ruleSet (with fixed rewrite order)
		_sbRuleSet = new ArrayList<>();

		//STATIC REWRITES (which do not rely on size information)
		if( staticRewrites )
		{
			//add static HOP DAG rewrite rules
			_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
			_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
				_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
				_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
				_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
				_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again)
				_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
			if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
				_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
			_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock
			if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
				_dagRuleSet.add( new RewriteQuantizationFusedCompression() );

			//add statement block rewrite rules
			if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
				_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
			if( OptimizerUtils.ALLOW_FOR_LOOP_REMOVAL )
				_sbRuleSet.add( new RewriteRemoveForLoopEmptySequence() ); //dependency: constant folding
			if( OptimizerUtils.ALLOW_BRANCH_REMOVAL || OptimizerUtils.ALLOW_FOR_LOOP_REMOVAL )
				_sbRuleSet.add( new RewriteMergeBlockSequence() ); //dependency: remove branches, remove for-loops
			if( OptimizerUtils.ALLOW_COMPRESSION_REWRITE )
				_sbRuleSet.add( new RewriteCompressedReblock() ); // Compression Rewrite
			if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
				_sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock, merge blocks
			if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS &&
				ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS) )
				_sbRuleSet.add( new RewriteSplitDagDataDependentOperators() ); //dependency: merge blocks
			if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
				_sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop)
			_sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes)
			if( OptimizerUtils.ALLOW_CODE_MOTION )
				_sbRuleSet.add( new RewriteHoistLoopInvariantOperations() ); //dependency: vectorize, but before inplace
			if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE )
				_sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() );
			if( LineageCacheConfig.getCompAssRW() )
				_sbRuleSet.add( new MarkForLineageReuse() );
			_sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() );
			_dagRuleSet.add( new RewriteNonScalarPrint() );
		}

		// DYNAMIC REWRITES (which do require size information)
		if( dynamicRewrites )
		{
			if( DMLScript.USE_ACCELERATOR ) {
				_dagRuleSet.add( new RewriteGPUSpecificOps() ); // gpu-specific rewrites
			}
			if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES ) {
				_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
				if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
					_dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
			}
			if( OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES ) {
				_dagRuleSet.add( new RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
				_dagRuleSet.add( new RewriteMatrixMultChainOptimizationSparse() ); //dependency: cse
			}
			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) {
				_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
				_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
			}
		}

		// cleanup after all rewrites applied
		// (newly introduced operators, introduced redundancy after rewrites w/ multiple parents)
		if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
			_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
		if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
			_dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) );
		if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
			_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
		_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() );
		_sbRuleSet.add( new RewriteRemoveEmptyForLoops() );
	}

	/**
	 * Construct a program rewriter for a given rewrite which is passed from outside.
	 * The resulting rewriter has no statement block rewrite rules.
	 *
	 * @param rewrites the HOP rewrite rules, applied in the given order
	 */
	public ProgramRewriter(HopRewriteRule... rewrites) {
		//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
		_dagRuleSet = new ArrayList<>(rewrites.length);
		for( HopRewriteRule rewrite : rewrites )
			_dagRuleSet.add( rewrite );
		//no statement block rewrites
		_sbRuleSet = new ArrayList<>();
	}

	/**
	 * Construct a program rewriter for a given rewrite which is passed from outside.
	 * The resulting rewriter has no HOP DAG rewrite rules.
	 *
	 * @param rewrites the statement block rewrite rules, applied in the given order
	 */
	public ProgramRewriter(StatementBlockRewriteRule... rewrites) {
		//no HOP DAG rewrites
		_dagRuleSet = new ArrayList<>();
		//initialize statement block rewrite ruleSet (with fixed rewrite order)
		_sbRuleSet = new ArrayList<>(rewrites.length);
		for( StatementBlockRewriteRule rewrite : rewrites )
			_sbRuleSet.add( rewrite );
	}

	/**
	 * Construct a program rewriter for the given rewrite sets which are passed from outside.
	 * Both lists are copied, so later changes to them do not affect this rewriter.
	 *
	 * @param hRewrites HOP rewrite rules
	 * @param sbRewrites statement block rewrite rules
	 */
	public ProgramRewriter(ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites) {
		//initialize rewrite ruleSets (with fixed rewrite order)
		_dagRuleSet = new ArrayList<>(hRewrites);
		_sbRuleSet = new ArrayList<>(sbRewrites);
	}

	/**
	 * Removes all HOP rewrite rules whose runtime class is exactly the given class
	 * (instances of subclasses are kept).
	 *
	 * @param clazz the rule class to remove
	 */
	public void removeHopRewrite(Class<? extends HopRewriteRule> clazz) {
		_dagRuleSet.removeIf(r -> r.getClass().equals(clazz));
	}

	/**
	 * Removes all statement block rewrite rules whose runtime class is exactly the
	 * given class (instances of subclasses are kept).
	 *
	 * @param clazz the rule class to remove
	 */
	public void removeStatementBlockRewrite(Class<? extends StatementBlockRewriteRule> clazz) {
		_sbRuleSet.removeIf(r -> r.getClass().equals(clazz));
	}

	/**
	 * Rewrites the whole program with DAG splitting enabled and a fresh status.
	 *
	 * @param dmlp the program to rewrite in place
	 * @return the rewrite status collected during the pass
	 */
	public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) {
		return rewriteProgramHopDAGs(dmlp, true);
	}

	/**
	 * Rewrites the whole program with a fresh status.
	 *
	 * @param dmlp the program to rewrite in place
	 * @param splitDags if false, statement block rules that create split DAGs are skipped
	 * @return the rewrite status collected during the pass
	 */
	public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
		return rewriteProgramHopDAGs(dmlp, splitDags, new ProgramRewriteStatus());
	}

	/**
	 * Rewrites all function statement blocks of all namespaces, then the main
	 * program's statement blocks. HOP DAG rules are applied first, followed by the
	 * statement block rules (only if any are configured).
	 *
	 * @param dmlp the program to rewrite in place
	 * @param splitDags if false, statement block rules that create split DAGs are skipped
	 * @param state the status to update; a fresh one is created if null
	 * @return the (possibly newly created) status
	 */
	public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags, ProgramRewriteStatus state) {
		//ensure robustness for calls from outside (share a single status object)
		if( state == null )
			state = new ProgramRewriteStatus();

		// for each namespace, handle function statement blocks
		for( String namespaceKey : dmlp.getNamespaces().keySet() ) {
			for( String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet() ) {
				FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname);
				rewriteHopDAGsFunction(fsblock, state, splitDags);
			}
		}

		// handle regular statement blocks in "main" method
		for( int i = 0; i < dmlp.getNumStatementBlocks(); i++ ) {
			StatementBlock current = dmlp.getStatementBlock(i);
			rRewriteStatementBlockHopDAGs(current, state);
		}
		if( !_sbRuleSet.isEmpty() )
			dmlp.setStatementBlocks(rRewriteStatementBlocks(
				dmlp.getStatementBlocks(), state, splitDags));

		return state;
	}

	/**
	 * Rewrites a single function with a fresh status.
	 *
	 * @param fsb the function statement block to rewrite in place
	 * @param splitDags if false, statement block rules that create split DAGs are skipped
	 */
	public void rewriteHopDAGsFunction(FunctionStatementBlock fsb, boolean splitDags) {
		rewriteHopDAGsFunction(fsb, new ProgramRewriteStatus(), splitDags);
	}

	/**
	 * Rewrites a single function: HOP DAG rules first, then statement block rules
	 * (only if any are configured).
	 *
	 * @param fsb the function statement block to rewrite in place
	 * @param state the status to update; a fresh one is created if null
	 * @param splitDags if false, statement block rules that create split DAGs are skipped
	 */
	public void rewriteHopDAGsFunction(FunctionStatementBlock fsb, ProgramRewriteStatus state, boolean splitDags) {
		//ensure robustness for calls from outside (share a single status object)
		if( state == null )
			state = new ProgramRewriteStatus();
		rRewriteStatementBlockHopDAGs(fsb, state);
		if( !_sbRuleSet.isEmpty() )
			rRewriteStatementBlock(fsb, state, splitDags);
	}

	/**
	 * Recursively applies the HOP DAG rules to every HOP DAG reachable from the
	 * given block. These are while/if predicates; for/parfor from, to and increment
	 * expressions; and the hops of generic (last-level) blocks. The method recurses
	 * into all nested bodies.
	 *
	 * @param current the statement block to rewrite in place
	 * @param state the status to update; a fresh one is created if null
	 */
	public void rRewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) {
		//ensure robustness for calls from outside
		if( state == null )
			state = new ProgramRewriteStatus();

		if( current instanceof FunctionStatementBlock )
		{
			FunctionStatementBlock fsb = (FunctionStatementBlock)current;
			FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
			for( StatementBlock sb : fstmt.getBody() )
				rRewriteStatementBlockHopDAGs(sb, state);
		}
		else if( current instanceof WhileStatementBlock )
		{
			WhileStatementBlock wsb = (WhileStatementBlock) current;
			WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
			wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
			for( StatementBlock sb : wstmt.getBody() )
				rRewriteStatementBlockHopDAGs(sb, state);
		}
		else if( current instanceof IfStatementBlock )
		{
			IfStatementBlock isb = (IfStatementBlock) current;
			IfStatement istmt = (IfStatement)isb.getStatement(0);
			isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
			for( StatementBlock sb : istmt.getIfBody() )
				rRewriteStatementBlockHopDAGs(sb, state);
			for( StatementBlock sb : istmt.getElseBody() )
				rRewriteStatementBlockHopDAGs(sb, state);
		}
		else if( current instanceof ForStatementBlock ) //incl parfor
		{
			ForStatementBlock fsb = (ForStatementBlock) current;
			ForStatement fstmt = (ForStatement)fsb.getStatement(0);
			fsb.setFromHops(rewriteHopDAG(fsb.getFromHops(), state));
			fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
			fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
			for( StatementBlock sb : fstmt.getBody() )
				rRewriteStatementBlockHopDAGs(sb, state);
		}
		else //generic (last-level)
		{
			current.setHops( rewriteHopDAG(current.getHops(), state) );
		}
	}

	/**
	 * Applies all HOP DAG rules, in order, to the given list of DAG roots.
	 * Visit status is reset before each rule.
	 *
	 * @param roots the DAG roots; null is returned unchanged
	 * @param state the status passed to each rule
	 * @return the rewritten roots as returned by the last rule
	 */
	public ArrayList<Hop> rewriteHopDAG(ArrayList<Hop> roots, ProgramRewriteStatus state) {
		//consistent with the single-root overload: nothing to rewrite
		if( roots == null )
			return null;

		for( HopRewriteRule r : _dagRuleSet ) {
			Hop.resetVisitStatus( roots ); //reset for each rule
			roots = r.rewriteHopDAGs(roots, state);
			if( CHECK )
				HopDagValidator.validateHopDag(roots, r);
		}
		return roots;
	}

	/**
	 * Applies all HOP DAG rules, in order, to a single DAG root (e.g. a predicate).
	 * Visit status is reset before each rule.
	 *
	 * @param root the DAG root; null is returned unchanged
	 * @param state the status passed to each rule
	 * @return the rewritten root as returned by the last rule
	 */
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
		if( root == null )
			return null;

		for( HopRewriteRule r : _dagRuleSet ) {
			root.resetVisitStatus(); //reset for each rule
			root = r.rewriteHopDAG(root, state);
			if( CHECK )
				HopDagValidator.validateHopDag(root, r);
		}
		return root;
	}

	/**
	 * Rewrites a list of statement blocks in three steps. First, the list-level
	 * statement block rules are applied. Next, each block is rewritten recursively,
	 * which may expand one block into several. Finally, the list-level rules are
	 * applied again, which may contract blocks.
	 *
	 * @param sbs the list to rewrite; it is updated in place and returned
	 * @param status the status to update; a fresh one is created if null
	 * @param splitDags if false, rules that create split DAGs are skipped
	 * @return {@code sbs}, now holding the rewritten blocks
	 */
	public ArrayList<StatementBlock> rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) {
		//ensure robustness for calls from outside
		if( status == null )
			status = new ProgramRewriteStatus();

		//apply rewrite rules to list of statement blocks
		List<StatementBlock> tmp = sbs;
		for( StatementBlockRewriteRule r : _sbRuleSet )
			if( splitDags || !r.createsSplitDag() )
				tmp = r.rewriteStatementBlocks(tmp, status);

		//recursively rewrite statement blocks (with potential expansion)
		List<StatementBlock> tmp2 = new ArrayList<>();
		for( StatementBlock sb : tmp )
			tmp2.addAll( rRewriteStatementBlock(sb, status, splitDags) );

		//apply rewrite rules to list of statement blocks (with potential contraction)
		for( StatementBlockRewriteRule r : _sbRuleSet )
			if( splitDags || !r.createsSplitDag() )
				tmp2 = r.rewriteStatementBlocks(tmp2, status);

		//prepare output list
		sbs.clear();
		sbs.addAll(tmp2);
		return sbs;
	}

	/**
	 * Rewrites a single statement block. Nested bodies are rewritten first
	 * (recursively); then each statement block rule is applied to the block or to
	 * the blocks produced by the previous rule.
	 *
	 * @param sb the statement block to rewrite
	 * @param status the status to update; a fresh one is created if null
	 * @param splitDags if false, rules that create split DAGs are skipped
	 * @return the block(s) replacing {@code sb}
	 */
	public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) {
		//ensure robustness for calls from outside (status is dereferenced below)
		if( status == null )
			status = new ProgramRewriteStatus();

		ArrayList<StatementBlock> ret = new ArrayList<>();
		ret.add(sb);

		//recursive invocation
		if( sb instanceof FunctionStatementBlock ) {
			FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
			FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
			fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
		}
		else if( sb instanceof WhileStatementBlock ) {
			WhileStatementBlock wsb = (WhileStatementBlock) sb;
			WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
			wstmt.setBody(rRewriteStatementBlocks(wstmt.getBody(), status, splitDags));
		}
		else if( sb instanceof IfStatementBlock ) {
			IfStatementBlock isb = (IfStatementBlock) sb;
			IfStatement istmt = (IfStatement)isb.getStatement(0);
			istmt.setIfBody(rRewriteStatementBlocks(istmt.getIfBody(), status, splitDags));
			istmt.setElseBody(rRewriteStatementBlocks(istmt.getElseBody(), status, splitDags));
		}
		else if( sb instanceof ForStatementBlock ) { //incl parfor
			//maintain parfor context information (e.g., for checkpointing)
			boolean prestatus = status.isInParforContext();
			if( sb instanceof ParForStatementBlock )
				status.setInParforContext(true);
			try {
				ForStatementBlock fsb = (ForStatementBlock) sb;
				ForStatement fstmt = (ForStatement)fsb.getStatement(0);
				fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags));
			}
			finally {
				//restore the outer context even if a rewrite of the body fails
				status.setInParforContext(prestatus);
			}
		}

		//apply rewrite rules to individual statement blocks
		for( StatementBlockRewriteRule r : _sbRuleSet ) {
			if( !splitDags && r.createsSplitDag() )
				continue;
			ArrayList<StatementBlock> tmp = new ArrayList<>();
			for( StatementBlock sbc : ret )
				tmp.addAll( r.rewriteStatementBlock(sbc, status) );

			//take over set of rewritten sbs
			ret.clear();
			ret.addAll(tmp);
		}

		return ret;
	}
}