Skip to content

Commit 4b81d0d

Browse files
committed
[SYSTEMML-1828,1832] New rewrite for merging statement block sequences
This patch introduces a new statement block rewrite for merging DAGs of subsequent last-level statement blocks. After constant folding and the removal of unnecessary branches, we often end up with such sequences of statement blocks. Since many rewrites and operator fusion work on the granularity of individual DAGs, these unnecessary DAG cuts cause missed optimization opportunities, especially in the context of operator fusion (i.e., codegen). We now merge such sequences in awareness of rewrites that explicitly split DAGs (to create recompilation points). Apart from the new merge rewrite, this patch also fixes the IPA rewrite pass that applies static rewrites per IPA round. The repeated application of the statement block rewrite for injecting spark checkpoints for variables used read-only in loops introduced redundant statement blocks and checkpoints. The IPA rewrite pass now explicitly excludes this rewrite. Additionally, this patch also modifies the related tests to use 'while(FALSE){}' instead of 'if(1==1){}' as a DAG cut, and fixes some minor compilation issues that showed up due to the increased optimization scope. Overall, there are many scripts and patterns that benefit from these changes. For example, on 1 epoch of lenet w/ codegen, this patch improved end-to-end performance from 328s to 297s due to increased fusion opportunities and fewer compiled spark instructions (70 vs 82).
1 parent 0221fbc commit 4b81d0d

File tree

123 files changed

+744
-311
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

123 files changed

+744
-311
lines changed

docs/dml-language-reference.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1782,7 +1782,7 @@ The following DML utilizes the `transformencode()` function.
17821782
jspec = read("/user/ml/homes.tfspec_recode2.json", data_type="scalar", value_type="string");
17831783
[X, M] = transformencode(target=F1, spec=jspec);
17841784
print(toString(X));
1785-
if(1==1){}
1785+
while(FALSE){}
17861786
print(toString(M));
17871787

17881788
The transformed matrix X and output M are as follows.

projects/breast_cancer/MachineLearning.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@
683683
"source": [
684684
"# script = \"\"\"\n",
685685
"# f = function(matrix[double] X) return(matrix[double] Y) {\n",
686-
"# if (1==1) {}\n",
686+
"# while(FALSE){}\n",
687687
"# a = as.scalar(rand(rows=1, cols=1))\n",
688688
"# Y = X * a\n",
689689
"# }\n",

scripts/nn/test/test.dml

+1-1
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ max_pool2d = function() {
599599
for (padh in 0:3) {
600600
for (padw in 0:3) {
601601
print(" - Testing w/ padh="+padh+" & padw="+padw+".")
602-
#if (1==1) {} # force correct printing
602+
#while(FALSE){} # force correct printing
603603
#print(" - Testing forward")
604604
[out, Hout, Wout] = max_pool2d::forward(X, C, Hin, Win, Hf, Wf, stride, stride, padh, padw)
605605
[out_simple, Hout_simple, Wout_simple] = max_pool2d_simple::forward(X, C, Hin, Win, Hf, Wf,

src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import org.apache.sysml.hops.HopsException;
2424
import org.apache.sysml.hops.rewrite.ProgramRewriter;
25+
import org.apache.sysml.hops.rewrite.RewriteInjectSparkLoopCheckpointing;
2526
import org.apache.sysml.parser.DMLProgram;
2627
import org.apache.sysml.parser.LanguageException;
2728

@@ -43,7 +44,11 @@ public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionC
4344
throws HopsException
4445
{
4546
try {
47+
//construct rewriter w/o checkpoint injection to avoid redundancy
4648
ProgramRewriter rewriter = new ProgramRewriter(true, false);
49+
rewriter.removeStatementBlockRewrite(RewriteInjectSparkLoopCheckpointing.class);
50+
51+
//rewrite program hop dags and statement blocks
4752
rewriter.rewriteProgramHopDAGs(prog);
4853
}
4954
catch (LanguageException ex) {

src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java

+13
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,14 @@ public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 o
622622
return ternOp;
623623
}
624624

625+
public static DataOp createDataOp(String name, Hop input, DataOpTypes type) {
626+
DataOp dop = new DataOp(name, input.getDataType(), input.getValueType(), input, type, null);
627+
dop.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
628+
copyLineNumbers(input, dop);
629+
dop.refreshSizeInformation();
630+
return dop;
631+
}
632+
625633
public static void setOutputParameters( Hop hop, long rlen, long clen, long brlen, long bclen, long nnz ) {
626634
hop.setDim1( rlen );
627635
hop.setDim2( clen );
@@ -846,6 +854,11 @@ private static boolean rContainsInput(Hop current, Hop probe, HashSet<Long> memo
846854
return ret;
847855
}
848856

857+
public static boolean isData(Hop hop, DataOpTypes type) {
858+
return hop instanceof DataOp
859+
&& ((DataOp)hop).getDataOpType()==type;
860+
}
861+
849862
public static boolean isBinaryMatrixColVectorOperation(Hop hop) {
850863
return hop instanceof BinaryOp
851864
&& hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix()

src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java

+33-26
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.sysml.hops.rewrite;
2121

2222
import java.util.ArrayList;
23+
import java.util.List;
2324

2425
import org.apache.commons.logging.Log;
2526
import org.apache.commons.logging.LogFactory;
@@ -49,7 +50,7 @@
4950
* program.
5051
*
5152
*/
52-
public class ProgramRewriter
53+
public class ProgramRewriter
5354
{
5455
private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
5556

@@ -60,17 +61,15 @@ public class ProgramRewriter
6061
private ArrayList<HopRewriteRule> _dagRuleSet = null;
6162
private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null;
6263

63-
static{
64+
static {
6465
// for internal debugging only
6566
if( LDEBUG ) {
6667
Logger.getLogger("org.apache.sysml.hops.rewrite")
6768
.setLevel((Level) Level.DEBUG);
6869
}
69-
7070
}
7171

72-
public ProgramRewriter()
73-
{
72+
public ProgramRewriter() {
7473
// by default which is used during initial compile
7574
// apply all (static and dynamic) rewrites
7675
this( true, true );
@@ -107,12 +106,14 @@ public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites )
107106
_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock
108107

109108
//add statement block rewrite rules
110-
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
109+
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) {
111110
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
111+
_sbRuleSet.add( new RewriteMergeBlockSequence() ); //dependency: remove branches
112+
}
112113
if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
113-
_sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock
114+
_sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock, merge blocks
114115
if( ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS) )
115-
_sbRuleSet.add( new RewriteSplitDagDataDependentOperators() );
116+
_sbRuleSet.add( new RewriteSplitDagDataDependentOperators() ); //dependency: merge blocks
116117
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
117118
_sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop)
118119
_sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes)
@@ -146,8 +147,7 @@ public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites )
146147
*
147148
* @param rewrites the HOP rewrite rules
148149
*/
149-
public ProgramRewriter( HopRewriteRule... rewrites )
150-
{
150+
public ProgramRewriter( HopRewriteRule... rewrites ) {
151151
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
152152
_dagRuleSet = new ArrayList<HopRewriteRule>();
153153
for( HopRewriteRule rewrite : rewrites )
@@ -161,8 +161,7 @@ public ProgramRewriter( HopRewriteRule... rewrites )
161161
*
162162
* @param rewrites the statement block rewrite rules
163163
*/
164-
public ProgramRewriter( StatementBlockRewriteRule... rewrites )
165-
{
164+
public ProgramRewriter( StatementBlockRewriteRule... rewrites ) {
166165
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
167166
_dagRuleSet = new ArrayList<HopRewriteRule>();
168167

@@ -177,8 +176,7 @@ public ProgramRewriter( StatementBlockRewriteRule... rewrites )
177176
* @param hRewrites HOP rewrite rules
178177
* @param sbRewrites statement block rewrite rules
179178
*/
180-
public ProgramRewriter( ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites )
181-
{
179+
public ProgramRewriter(ArrayList<HopRewriteRule> hRewrites, ArrayList<StatementBlockRewriteRule> sbRewrites) {
182180
//initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
183181
_dagRuleSet = new ArrayList<HopRewriteRule>();
184182
_dagRuleSet.addAll( hRewrites );
@@ -187,6 +185,14 @@ public ProgramRewriter( ArrayList<HopRewriteRule> hRewrites, ArrayList<Statement
187185
_sbRuleSet.addAll( sbRewrites );
188186
}
189187

188+
public void removeHopRewrite(Class<? extends HopRewriteRule> clazz) {
189+
_dagRuleSet.removeIf(r -> r.getClass().equals(clazz));
190+
}
191+
192+
public void removeStatementBlockRewrite(Class<? extends StatementBlockRewriteRule> clazz) {
193+
_sbRuleSet.removeIf(r -> r.getClass().equals(clazz));
194+
}
195+
190196
public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp)
191197
throws LanguageException, HopsException
192198
{
@@ -301,21 +307,23 @@ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
301307
return root;
302308
}
303309

304-
public ArrayList<StatementBlock> rewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus state )
310+
public ArrayList<StatementBlock> rewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus status )
305311
throws HopsException
306312
{
307313
//ensure robustness for calls from outside
308-
if( state == null )
309-
state = new ProgramRewriteStatus();
310-
314+
if( status == null )
315+
status = new ProgramRewriteStatus();
311316

312-
ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
317+
//apply rewrite rules to list of statement blocks
318+
List<StatementBlock> sbList = sbs;
319+
for( StatementBlockRewriteRule r : _sbRuleSet ) {
320+
sbList = r.rewriteStatementBlocks(sbList, status);
321+
}
313322

314323
//rewrite statement blocks (with potential expansion)
315-
for( StatementBlock sb : sbs )
316-
tmp.addAll( rewriteStatementBlock(sb, state) );
317-
318-
//copy results into original collection
324+
ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
325+
for( StatementBlock sb : sbList )
326+
tmp.addAll( rewriteStatementBlock(sb, status) );
319327
sbs.clear();
320328
sbs.addAll( tmp );
321329

@@ -362,9 +370,8 @@ else if (sb instanceof ForStatementBlock) //incl parfor
362370
status.setInParforContext(prestatus);
363371
}
364372

365-
//apply rewrite rules
366-
for( StatementBlockRewriteRule r : _sbRuleSet )
367-
{
373+
//apply rewrite rules to individual statement blocks
374+
for( StatementBlockRewriteRule r : _sbRuleSet ) {
368375
ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
369376
for( StatementBlock sbc : ret )
370377
tmp.addAll( r.rewriteStatementBlock(sbc, status) );

src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

+3-5
Original file line numberDiff line numberDiff line change
@@ -1647,11 +1647,9 @@ private Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos)
16471647
inputargs.put("cast", new LiteralOp(false));
16481648

16491649
//create new hop
1650-
ParameterizedBuiltinOp pbop = new ParameterizedBuiltinOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
1651-
ParamBuiltinOp.REXPAND, inputargs);
1652-
pbop.setOutputBlocksizes(hi.getRowsInBlock(), hi.getColsInBlock());
1653-
pbop.refreshSizeInformation();
1654-
1650+
ParameterizedBuiltinOp pbop = HopRewriteUtils
1651+
.createParameterizedBuiltinOp(trgt, inputargs, ParamBuiltinOp.REXPAND);
1652+
16551653
//relink new hop into original position
16561654
HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos);
16571655
hi = pbop;

src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java

+10-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
package org.apache.sysml.hops.rewrite;
2121

22-
import java.util.ArrayList;
22+
import java.util.Arrays;
23+
import java.util.List;
2324

2425
import org.apache.sysml.hops.AggUnaryOp;
2526
import org.apache.sysml.hops.BinaryOp;
@@ -49,17 +50,14 @@
4950
*/
5051
public class RewriteForLoopVectorization extends StatementBlockRewriteRule
5152
{
52-
5353
private static final OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT, OpOp2.MIN, OpOp2.MAX};
5454
private static final AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new AggOp[]{AggOp.SUM, AggOp.PROD, AggOp.MIN, AggOp.MAX};
5555

5656

5757
@Override
58-
public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
58+
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
5959
throws HopsException
6060
{
61-
ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
62-
6361
if( sb instanceof ForStatementBlock )
6462
{
6563
ForStatementBlock fsb = (ForStatementBlock) sb;
@@ -96,9 +94,13 @@ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, Progra
9694

9795
//if no rewrite applied sb is the original for loop otherwise a last level statement block
9896
//that includes the equivalent vectorized operations.
99-
ret.add( sb );
100-
101-
return ret;
97+
return Arrays.asList(sb);
98+
}
99+
100+
@Override
101+
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs,
102+
ProgramRewriteStatus sate) throws HopsException {
103+
return sbs;
102104
}
103105

104106
private StatementBlock vectorizeScalarAggregate( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )

src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java

+16-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
package org.apache.sysml.hops.rewrite;
2121

2222
import java.util.ArrayList;
23+
import java.util.Arrays;
24+
import java.util.List;
2325

2426
import org.apache.sysml.hops.DataOp;
2527
import org.apache.sysml.hops.Hop;
@@ -47,28 +49,25 @@ public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRu
4749
{
4850
private boolean _checkCtx = false;
4951

50-
public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext)
51-
{
52+
public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext) {
5253
_checkCtx = checkParForContext;
5354
}
5455

5556
@Override
56-
public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status)
57+
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status)
5758
throws HopsException
5859
{
59-
ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
60-
61-
if( !OptimizerUtils.isSparkExecutionMode() )
62-
{
63-
ret.add(sb); // nothing to do here
64-
return ret; //return original statement block
60+
if( !OptimizerUtils.isSparkExecutionMode() ) {
61+
// nothing to do here, return original statement block
62+
return Arrays.asList(sb);
6563
}
6664

6765
//1) We currently add checkpoint operations without information about the global program structure,
6866
//this assumes that redundant checkpointing is prevented at runtime level (instruction-level)
6967
//2) Also, we do not take size information into account right now. This means that all candidates
7068
//are checkpointed even if they are only used by CP operations.
7169

70+
ArrayList<StatementBlock> ret = new ArrayList<>();
7271
int blocksize = status.getBlocksize(); //block size set by reblock rewrite
7372

7473
//apply rewrite for while, for, and parfor (the decision for parfor loop bodies is deferred until parfor
@@ -101,7 +100,7 @@ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, Progra
101100
long dim2 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim2() : dat.getDim2();
102101
DataOp tread = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, DataOpTypes.TRANSIENTREAD,
103102
dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize, blocksize);
104-
tread.setRequiresCheckpoint( true );
103+
tread.setRequiresCheckpoint(true);
105104
DataOp twrite = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, tread, DataOpTypes.TRANSIENTWRITE, null);
106105
HopRewriteUtils.setOutputParameters(twrite, dim1, dim2, blocksize, blocksize, dat.getNnz());
107106
hops.add(twrite);
@@ -111,6 +110,7 @@ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, Progra
111110
sb0.set_hops(hops);
112111
sb0.setLiveIn(livein);
113112
sb0.setLiveOut(liveout);
113+
sb0.setSplitDag(true);
114114
ret.add(sb0);
115115

116116
//maintain rewrite status
@@ -123,4 +123,10 @@ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, Progra
123123

124124
return ret;
125125
}
126+
127+
@Override
128+
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs,
129+
ProgramRewriteStatus sate) throws HopsException {
130+
return sbs;
131+
}
126132
}

0 commit comments

Comments
 (0)