Skip to content

Commit cfbe190

Browse files
ReneEnjilianmboehm7
authored andcommitted
[SYSTEMDS-3774] Improved test coverage of simplification rewrites
Closes #2240.
1 parent 0db6159 commit cfbe190

File tree

46 files changed

+3278
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+3278
-1
lines changed

src/main/java/org/apache/sysds/hops/OptimizerUtils.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public enum MemoryManager {
199199
public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true;
200200

201201
/**
202-
* Enables additional mmchain optimizations. in the future, this might be merged with
202+
* Enables additional mmchain optimizations. In the future, this might be merged with
203203
* ALLOW_SUM_PRODUCT_REWRITES.
204204
*/
205205
public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.hops.OptimizerUtils;
24+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
25+
import org.apache.sysds.test.AutomatedTestBase;
26+
import org.apache.sysds.test.TestConfiguration;
27+
import org.apache.sysds.test.TestUtils;
28+
import org.junit.Assert;
29+
import org.junit.Test;
30+
31+
import java.util.HashMap;
32+
33+
public class RewriteCanonicalizeMatrixMultScalarAddTest extends AutomatedTestBase {
34+
35+
private static final String TEST_NAME = "RewriteCanonicalizeMatrixMultScalarAdd";
36+
private static final String TEST_DIR = "functions/rewrite/";
37+
private static final String TEST_CLASS_DIR =
38+
TEST_DIR + RewriteCanonicalizeMatrixMultScalarAddTest.class.getSimpleName() + "/";
39+
40+
private static final int rows = 500;
41+
private static final int cols = 500;
42+
private static final double eps = Math.pow(10, -10);
43+
44+
@Override
45+
public void setUp() {
46+
TestUtils.clearAssertionInformation();
47+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
48+
}
49+
50+
@Test
51+
public void testCanonicalizeMatrixMultScalarAddPosNoRewrite() {
52+
testRewriteCanonicalizeMatrixMultScalarAdd(1, false);
53+
}
54+
55+
@Test
56+
public void testCanonicalizeMatrixMultScalarAddPosRewrite() {
57+
testRewriteCanonicalizeMatrixMultScalarAdd(1, true); // (z + U%*%V) -> (U%*%V + z)
58+
}
59+
60+
@Test
61+
public void testCanonicalizeMatrixMultScalarAddNegNoRewrite() {
62+
testRewriteCanonicalizeMatrixMultScalarAdd(2, false);
63+
}
64+
65+
@Test
66+
public void testCanonicalizeMatrixMultScalarAddNegRewrite() {
67+
testRewriteCanonicalizeMatrixMultScalarAdd(2, true); // (U%*%V - z) -> (U%*%V + (-z))
68+
}
69+
70+
private void testRewriteCanonicalizeMatrixMultScalarAdd(int ID, boolean rewrites) {
71+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
72+
try {
73+
TestConfiguration config = getTestConfiguration(TEST_NAME);
74+
loadTestConfiguration(config);
75+
76+
String HOME = SCRIPT_DIR + TEST_DIR;
77+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
78+
programArgs = new String[] {"-stats", "-args", input("U"), input("V"), String.valueOf(ID), output("R")};
79+
fullRScriptName = HOME + TEST_NAME + ".R";
80+
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());
81+
82+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
83+
84+
// create and write matrices
85+
double[][] U = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5);
86+
double[][] V = getRandomMatrix(rows, cols, -1, 1, 0.60d, 4);
87+
writeInputMatrixWithMTD("U", U, true);
88+
writeInputMatrixWithMTD("V", V, true);
89+
90+
runTest(true, false, null, -1);
91+
runRScript(true);
92+
93+
//compare matrices
94+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
95+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
96+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
97+
98+
if(ID == 1) {
99+
if(rewrites)
100+
Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString()));
101+
else
102+
Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString()));
103+
}
104+
else if(ID == 2) {
105+
if(rewrites)
106+
Assert.assertTrue(heavyHittersContainsString(Opcodes.PLUS.toString()));
107+
else
108+
Assert.assertFalse(heavyHittersContainsString(Opcodes.PLUS.toString()));
109+
}
110+
111+
}
112+
finally {
113+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
114+
}
115+
}
116+
117+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.hops.OptimizerUtils;
24+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
25+
import org.apache.sysds.test.AutomatedTestBase;
26+
import org.apache.sysds.test.TestConfiguration;
27+
import org.apache.sysds.test.TestUtils;
28+
import org.apache.sysds.utils.Statistics;
29+
import org.junit.Assert;
30+
import org.junit.Test;
31+
32+
import java.util.HashMap;
33+
34+
public class RewriteFuseOrderOperationChainTest extends AutomatedTestBase {
35+
36+
private static final String TEST_NAME = "RewriteFuseOrderOperationChain";
37+
private static final String TEST_DIR = "functions/rewrite/";
38+
private static final String TEST_CLASS_DIR =
39+
TEST_DIR + RewriteFuseOrderOperationChainTest.class.getSimpleName() + "/";
40+
41+
private static final int rows = 500;
42+
private static final int cols = 500;
43+
private static final double eps = Math.pow(10, -10);
44+
45+
@Override
46+
public void setUp() {
47+
TestUtils.clearAssertionInformation();
48+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
49+
}
50+
51+
@Test
52+
public void testFuseOrderOperationChainNoRewrite() {
53+
testRewriteFuseOrderOperationChain(false);
54+
}
55+
56+
@Test
57+
public void testFuseOrderOperationChainRewrite() {
58+
testRewriteFuseOrderOperationChain(true);
59+
}
60+
61+
private void testRewriteFuseOrderOperationChain(boolean rewrites) {
62+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
63+
try {
64+
TestConfiguration config = getTestConfiguration(TEST_NAME);
65+
loadTestConfiguration(config);
66+
67+
String HOME = SCRIPT_DIR + TEST_DIR;
68+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
69+
programArgs = new String[] {"-stats", "-args", input("X"), output("R")};
70+
fullRScriptName = HOME + TEST_NAME + ".R";
71+
rCmd = getRCmd(inputDir(), expectedDir());
72+
73+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
74+
75+
// create and write matrices
76+
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5);
77+
writeInputMatrixWithMTD("X", X, true);
78+
79+
runTest(true, false, null, -1);
80+
runRScript(true);
81+
82+
//compare matrices
83+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
84+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
85+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
86+
87+
long numOrder = Statistics.getCPHeavyHitterCount(Opcodes.SORT.toString());
88+
if(rewrites)
89+
Assert.assertEquals(numOrder, 1);
90+
else
91+
Assert.assertEquals(numOrder, 2);
92+
93+
}
94+
finally {
95+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
96+
}
97+
}
98+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import org.apache.sysds.common.Opcodes;
23+
import org.apache.sysds.hops.OptimizerUtils;
24+
import org.apache.sysds.runtime.matrix.data.MatrixValue;
25+
import org.apache.sysds.test.AutomatedTestBase;
26+
import org.apache.sysds.test.TestConfiguration;
27+
import org.apache.sysds.test.TestUtils;
28+
import org.junit.Assert;
29+
import org.junit.Test;
30+
31+
import java.util.HashMap;
32+
33+
public class RewriteRemoveUnnecessaryBinaryOperationTest extends AutomatedTestBase {
34+
35+
private static final String TEST_NAME = "RewriteRemoveUnnecessaryBinaryOperation";
36+
private static final String TEST_DIR = "functions/rewrite/";
37+
private static final String TEST_CLASS_DIR =
38+
TEST_DIR + RewriteRemoveUnnecessaryBinaryOperationTest.class.getSimpleName() + "/";
39+
40+
private static final int rows = 500;
41+
private static final int cols = 500;
42+
private static final double eps = Math.pow(10, -10);
43+
44+
@Override
45+
public void setUp() {
46+
TestUtils.clearAssertionInformation();
47+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
48+
}
49+
50+
@Test
51+
public void testRemoveUnnecessaryBinaryOperationDivNoRewrite() {
52+
testRewriteRemoveUnnecessaryBinaryOperation(1, false);
53+
}
54+
55+
@Test
56+
public void testRemoveUnnecessaryBinaryOperationDivRewrite() {
57+
testRewriteRemoveUnnecessaryBinaryOperation(1, true); // X/1
58+
}
59+
60+
@Test
61+
public void testRemoveUnnecessaryBinaryOperationMultRightNoRewrite() {
62+
testRewriteRemoveUnnecessaryBinaryOperation(2, false);
63+
}
64+
65+
@Test
66+
public void testRemoveUnnecessaryBinaryOperationMultRightRewrite() {
67+
testRewriteRemoveUnnecessaryBinaryOperation(2, true); // X*1
68+
}
69+
70+
@Test
71+
public void testRemoveUnnecessaryBinaryOperationMultLeftNoRewrite() {
72+
testRewriteRemoveUnnecessaryBinaryOperation(3, false);
73+
}
74+
75+
@Test
76+
public void testRemoveUnnecessaryBinaryOperationMultLeftRewrite() {
77+
testRewriteRemoveUnnecessaryBinaryOperation(3, true); // 1*X
78+
}
79+
80+
@Test
81+
public void testRemoveUnnecessaryBinaryOperationMinusNoRewrite() {
82+
testRewriteRemoveUnnecessaryBinaryOperation(4, false);
83+
}
84+
85+
@Test
86+
public void testRemoveUnnecessaryBinaryOperationMinusRewrite() {
87+
testRewriteRemoveUnnecessaryBinaryOperation(4, true); // X-0
88+
}
89+
90+
@Test
91+
public void testRemoveUnnecessaryBinaryOperationNegMultLeftNoRewrite() {
92+
testRewriteRemoveUnnecessaryBinaryOperation(5, false);
93+
}
94+
95+
@Test
96+
public void testRemoveUnnecessaryBinaryOperationNegMultLeftRewrite() {
97+
testRewriteRemoveUnnecessaryBinaryOperation(5, true); // -1*X
98+
}
99+
100+
@Test
101+
public void testRemoveUnnecessaryBinaryOperationNegMultRightNoRewrite() {
102+
testRewriteRemoveUnnecessaryBinaryOperation(6, false);
103+
}
104+
105+
@Test
106+
public void testRemoveUnnecessaryBinaryOperationNegMultRightRewrite() {
107+
testRewriteRemoveUnnecessaryBinaryOperation(6, true); // X*-1
108+
}
109+
110+
private void testRewriteRemoveUnnecessaryBinaryOperation(int ID, boolean rewrites) {
111+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
112+
try {
113+
TestConfiguration config = getTestConfiguration(TEST_NAME);
114+
loadTestConfiguration(config);
115+
116+
String HOME = SCRIPT_DIR + TEST_DIR;
117+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
118+
programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")};
119+
fullRScriptName = HOME + TEST_NAME + ".R";
120+
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());
121+
122+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
123+
124+
// create and write matrix
125+
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5);
126+
writeInputMatrixWithMTD("X", X, true);
127+
128+
runTest(true, false, null, -1);
129+
runRScript(true);
130+
131+
//compare matrices
132+
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
133+
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
134+
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
135+
136+
if(ID == 1) {
137+
if(rewrites)
138+
Assert.assertFalse(heavyHittersContainsString(Opcodes.DIV.toString()));
139+
else
140+
Assert.assertTrue(heavyHittersContainsString(Opcodes.DIV.toString()));
141+
}
142+
else if(ID == 2 || ID == 3) {
143+
if(rewrites)
144+
Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString()));
145+
else
146+
Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString()));
147+
}
148+
else if(ID == 4) {
149+
if(rewrites)
150+
Assert.assertFalse(heavyHittersContainsString(Opcodes.MINUS.toString()));
151+
else
152+
Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS.toString()));
153+
}
154+
else if(ID == 5 || ID == 6) {
155+
if(rewrites)
156+
Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS.toString()) &&
157+
!heavyHittersContainsString(Opcodes.MULT.toString()));
158+
else
159+
Assert.assertTrue(!heavyHittersContainsString(Opcodes.MINUS.toString()) &&
160+
heavyHittersContainsString(Opcodes.MULT.toString()));
161+
}
162+
163+
}
164+
finally {
165+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
166+
}
167+
}
168+
}

0 commit comments

Comments
 (0)