Skip to content

Commit 55b196f

Browse files
committed
Rewrite test fix, incorporated comments
1 parent a6ba80a commit 55b196f

File tree

8 files changed

+306
-267
lines changed

8 files changed

+306
-267
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java

+80-80
Original file line numberDiff line numberDiff line change
@@ -42,84 +42,84 @@
4242
* identifies the pattern without applying fusion.
4343
*/
4444
public class RewriteQuantizationFusedCompression extends HopRewriteRule {
45-
@Override
46-
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
47-
if(roots == null)
48-
return null;
49-
50-
// traverse the HOP DAG
51-
HashMap<String, Hop> floors = new HashMap<>();
52-
HashMap<String, Hop> compresses = new HashMap<>();
53-
for(Hop h : roots)
54-
collectFloorCompressSequences(h, floors, compresses);
55-
56-
Hop.resetVisitStatus(roots);
57-
58-
// check compresses for compress-after-floor pattern
59-
for(Entry<String, Hop> e : compresses.entrySet()) {
60-
String inputname = e.getKey();
61-
Hop compresshop = e.getValue();
62-
63-
if(floors.containsKey(inputname) // floors same name
64-
&& ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) ||
65-
(floors.get(inputname).getEndLine() < compresshop.getEndLine()) ||
66-
(floors.get(inputname).getBeginLine() == compresshop.getBeginLine() &&
67-
floors.get(inputname).getEndLine() == compresshop.getBeginLine() &&
68-
floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) {
69-
70-
// retrieve the floor hop and inputs
71-
Hop floorhop = floors.get(inputname);
72-
Hop floorInput = floorhop.getInput().get(0);
73-
74-
// check if the input of the floor operation is a matrix
75-
if(floorInput.getDataType() == DataType.MATRIX) {
76-
77-
// Check if the input of the floor operation involves a multiplication operation
78-
if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) {
79-
Hop initialMatrix = floorInput.getInput().get(0);
80-
Hop sf = floorInput.getInput().get(1);
81-
82-
// create fused hop
83-
BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64,
84-
OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf);
85-
86-
// rewire compress consumers to fusedHop
87-
List<Hop> parents = new ArrayList<>(compresshop.getParent());
88-
for(Hop p : parents) {
89-
HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop);
90-
}
91-
}
92-
}
93-
}
94-
}
95-
return roots;
96-
}
97-
98-
@Override
99-
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
100-
// do nothing, floor/compress do not occur in predicates
101-
return root;
102-
}
103-
104-
private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) {
105-
if(hop.isVisited())
106-
return;
107-
108-
// process childs
109-
if(!hop.getInput().isEmpty())
110-
for(Hop c : hop.getInput())
111-
collectFloorCompressSequences(c, floors, compresses);
112-
113-
// process current hop
114-
if(hop instanceof UnaryOp) {
115-
UnaryOp uop = (UnaryOp) hop;
116-
if(uop.getOp() == OpOp1.FLOOR) {
117-
floors.put(uop.getName(), uop);
118-
}
119-
else if(uop.getOp() == OpOp1.COMPRESS) {
120-
compresses.put(uop.getInput(0).getName(), uop);
121-
}
122-
}
123-
hop.setVisited();
124-
}
45+
@Override
46+
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
47+
if(roots == null)
48+
return null;
49+
50+
// traverse the HOP DAG
51+
HashMap<String, Hop> floors = new HashMap<>();
52+
HashMap<String, Hop> compresses = new HashMap<>();
53+
for(Hop h : roots)
54+
collectFloorCompressSequences(h, floors, compresses);
55+
56+
Hop.resetVisitStatus(roots);
57+
58+
// check compresses for compress-after-floor pattern
59+
for(Entry<String, Hop> e : compresses.entrySet()) {
60+
String inputname = e.getKey();
61+
Hop compresshop = e.getValue();
62+
63+
if(floors.containsKey(inputname) // floors same name
64+
&& ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) ||
65+
(floors.get(inputname).getEndLine() < compresshop.getEndLine()) ||
66+
(floors.get(inputname).getBeginLine() == compresshop.getBeginLine() &&
67+
floors.get(inputname).getEndLine() == compresshop.getBeginLine() &&
68+
floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) {
69+
70+
// retrieve the floor hop and inputs
71+
Hop floorhop = floors.get(inputname);
72+
Hop floorInput = floorhop.getInput().get(0);
73+
74+
// check if the input of the floor operation is a matrix
75+
if(floorInput.getDataType() == DataType.MATRIX) {
76+
77+
// Check if the input of the floor operation involves a multiplication operation
78+
if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) {
79+
Hop initialMatrix = floorInput.getInput().get(0);
80+
Hop sf = floorInput.getInput().get(1);
81+
82+
// create fused hop
83+
BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64,
84+
OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf);
85+
86+
// rewire compress consumers to fusedHop
87+
List<Hop> parents = new ArrayList<>(compresshop.getParent());
88+
for(Hop p : parents) {
89+
HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop);
90+
}
91+
}
92+
}
93+
}
94+
}
95+
return roots;
96+
}
97+
98+
@Override
99+
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
100+
// do nothing, floor/compress do not occur in predicates
101+
return root;
102+
}
103+
104+
private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) {
105+
if(hop.isVisited())
106+
return;
107+
108+
// process childs
109+
if(!hop.getInput().isEmpty())
110+
for(Hop c : hop.getInput())
111+
collectFloorCompressSequences(c, floors, compresses);
112+
113+
// process current hop
114+
if(hop instanceof UnaryOp) {
115+
UnaryOp uop = (UnaryOp) hop;
116+
if(uop.getOp() == OpOp1.FLOOR) {
117+
floors.put(uop.getName(), uop);
118+
}
119+
else if(uop.getOp() == OpOp1.COMPRESS) {
120+
compresses.put(uop.getInput(0).getName(), uop);
121+
}
122+
}
123+
hop.setVisited();
124+
}
125125
}

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

-25
Original file line numberDiff line numberDiff line change
@@ -751,31 +751,6 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
751751
else
752752
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
753753
break;
754-
755-
case QUANTIZE_COMPRESS: // this is not used when with two arguments it seems
756-
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
757-
Expression expressionTwo = getSecondExpr();
758-
checkNumParameters(getSecondExpr() != null ? 2 : 1);
759-
checkMatrixFrameParam(getFirstExpr());
760-
if(expressionTwo != null)
761-
checkMatrixParam(getSecondExpr());
762-
763-
Identifier compressInput1 = getFirstExpr().getOutput();
764-
// Identifier compressInput2 = getSecondExpr().getOutput();
765-
766-
DataIdentifier compressOutput = (DataIdentifier) getOutputs()[0];
767-
compressOutput.setDataType(DataType.MATRIX);
768-
compressOutput.setDimensions(compressInput1.getDim1(), compressInput1.getDim2());
769-
compressOutput.setBlocksize(compressInput1.getBlocksize());
770-
compressOutput.setValueType(compressInput1.getValueType());
771-
772-
DataIdentifier metaOutput = (DataIdentifier) getOutputs()[1];
773-
metaOutput.setDataType(DataType.FRAME);
774-
metaOutput.setDimensions(compressInput1.getDim1(), -1);
775-
}
776-
else
777-
raiseValidateError("Quantize and compress instruction not allowed in dml script");
778-
break;
779754

780755
default: //always unconditional
781756
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java

+1-5
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,7 @@ public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb,
141141
public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, MatrixBlock sf, int k, WTreeRoot root) {
142142
// Handle only row vectors, as column-wise quantization is not allowed.
143143
// The restriction is handled upstream
144-
LOG.debug("Number of columns of sf " + sf.getNumColumns());
145-
double[] scaleFactors = new double[sf.getNumRows()];
146-
for (int i = 0; i < sf.getNumRows(); i++) {
147-
scaleFactors[i] = sf.get(i,0);
148-
}
144+
double[] scaleFactors = sf.getDenseBlockValues();
149145
CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
150146
return compress(mb, k, builder, root);
151147
}

src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java

+8
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ public CompressionSettingsBuilder() {
7070

7171
}
7272

73+
/**
74+
* Sets the scale factors for compression, enabling quantization-fused compression.
75+
*
76+
* @param scaleFactors An array of scale factors applied during compression.
77+
* - If row-wise scaling is used, this should be an array where each value corresponds to a row.
78+
* - If a single scalar is provided, it is applied uniformly to the entire matrix.
79+
* @return The CompressionSettingsBuilder instance with the updated scale factors.
80+
*/
7381
public CompressionSettingsBuilder setScaleFactor(double[] scaleFactors) {
7482
this.scaleFactors = scaleFactors;
7583
return this;

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java

-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import java.util.concurrent.ExecutorService;
3030
import java.util.concurrent.Future;
3131

32-
import javax.swing.plaf.basic.BasicInternalFrameTitlePane.SystemMenuBar;
33-
3432
import org.apache.commons.lang3.NotImplementedException;
3533
import org.apache.commons.logging.Log;
3634
import org.apache.commons.logging.LogFactory;

0 commit comments

Comments
 (0)