[SYSTEMDS-3780] Compression-fused quantization #2226

Open · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
@@ -88,6 +88,7 @@ public enum Builtins {
COLVAR("colVars", false),
COMPONENTS("components", true),
COMPRESS("compress", false, ReturnType.MULTI_RETURN),
QUANTIZE_COMPRESS("quantize_compress", false, ReturnType.MULTI_RETURN),
CONFUSIONMATRIX("confusionMatrix", true),
CONV2D("conv2d", false),
CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Opcodes.java
@@ -303,6 +303,7 @@ public enum Opcodes {
PARTITION("partition", CPType.Partition),
COMPRESS(Compression.OPCODE, CPType.Compression),
DECOMPRESS(DeCompression.OPCODE, CPType.DeCompression),
QUANTIZE_COMPRESS("quantize_compress", CPType.QuantizeCompression),
SPOOF("spoof", CPType.SpoofFused),
PREFETCH("prefetch", CPType.Prefetch),
EVICT("_evict", CPType.EvictLineageCache),
5 changes: 3 additions & 2 deletions src/main/java/org/apache/sysds/common/Types.java
@@ -634,8 +634,9 @@ public enum OpOp2 {
//fused ML-specific operators for performance
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
MINUS1_MULT(false), //1-X*Y
QUANTIZE_COMPRESS(false); //quantization-fused compression

private final boolean _validOuter;

private OpOp2(boolean outer) {
9 changes: 9 additions & 0 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -279,6 +279,15 @@ public enum MemoryManager {
*/
public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;

/**
 * Boolean specifying if the user may call quantize_compress directly in a DML script.
 */
public static boolean ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND = true;

/**
* Boolean specifying if quantization-fused compression rewrite is allowed.
*/
public static boolean ALLOW_QUANTIZE_COMPRESS_REWRITE = true;

/**
* Boolean specifying if compression rewrites is allowed. This is disabled at run time if the IPA for Workload aware compression
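As a usage note: both flags are plain public statics, so a test or driver can toggle them before compilation. A minimal sketch (class and method names hypothetical, not part of this diff):

import org.apache.sysds.hops.OptimizerUtils;

public class QuantizeCompressFlags {
	public static void keepBuiltinButDisableRewrite() {
		// keep quantize_compress callable from DML scripts...
		OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND = true;
		// ...but compile floor(M*S) followed by compress(M2) without fusion
		OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE = false;
	}
}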
src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -90,7 +90,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock

if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
_dagRuleSet.add( new RewriteQuantizationFusedCompression() );

//add statement block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
125 changes: 125 additions & 0 deletions src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.UnaryOp;

/**
 * Rule: RewriteQuantizationFusedCompression. Detects the sequence `M2 = floor(M * S)` followed by `C = compress(M2)`
 * and fuses it into a single quantize_compress operation. This rewrite improves performance by avoiding the
 * intermediate quantized matrix.
 */
public class RewriteQuantizationFusedCompression extends HopRewriteRule {
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
if(roots == null)
return null;

// traverse the HOP DAG
HashMap<String, Hop> floors = new HashMap<>();
HashMap<String, Hop> compresses = new HashMap<>();
for(Hop h : roots)
collectFloorCompressSequences(h, floors, compresses);

Hop.resetVisitStatus(roots);

// check compresses for compress-after-floor pattern
for(Entry<String, Hop> e : compresses.entrySet()) {
String inputname = e.getKey();
Hop compresshop = e.getValue();

if(floors.containsKey(inputname) // a floor hop writes the same variable
&& ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) ||
(floors.get(inputname).getEndLine() < compresshop.getEndLine()) ||
(floors.get(inputname).getBeginLine() == compresshop.getBeginLine() &&
floors.get(inputname).getEndLine() == compresshop.getEndLine() &&
floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) {

// retrieve the floor hop and inputs
Hop floorhop = floors.get(inputname);
Hop floorInput = floorhop.getInput().get(0);

// check if the input of the floor operation is a matrix
if(floorInput.getDataType() == DataType.MATRIX) {

// Check if the input of the floor operation involves a multiplication operation
if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) {
Hop initialMatrix = floorInput.getInput().get(0);
Hop sf = floorInput.getInput().get(1);

// create the fused quantize_compress hop, reusing the compress output name
BinaryOp fusedhop = new BinaryOp(compresshop.getName(), DataType.MATRIX, ValueType.FP64,
OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf);

// rewire compress consumers to fusedHop
List<Hop> parents = new ArrayList<>(compresshop.getParent());
for(Hop p : parents) {
HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop);
}
}
}
}
}
return roots;
}

@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
// do nothing, floor/compress do not occur in predicates
return root;
}

private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) {
if(hop.isVisited())
return;

// process children first (depth-first traversal)
if(!hop.getInput().isEmpty())
for(Hop c : hop.getInput())
collectFloorCompressSequences(c, floors, compresses);

// process current hop
if(hop instanceof UnaryOp) {
UnaryOp uop = (UnaryOp) hop;
if(uop.getOp() == OpOp1.FLOOR) {
floors.put(uop.getName(), uop);
}
else if(uop.getOp() == OpOp1.COMPRESS) {
compresses.put(uop.getInput(0).getName(), uop);
}
}
hop.setVisited();
}
}
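For intuition, the fused operator is meant to be semantically equivalent to quantizing first and compressing second. Below is a minimal unfused reference at the MatrixBlock level — a sketch assuming the plain compress(mb, k) overload available elsewhere in the factory, not this PR's fused code path:

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class QuantizeThenCompressReference {
	/** Unfused reference: M2 = floor(M * s), then compress(M2). */
	public static Pair<MatrixBlock, CompressionStatistics> run(MatrixBlock m, double s, int k) {
		MatrixBlock q = new MatrixBlock(m.getNumRows(), m.getNumColumns(), false);
		for(int r = 0; r < m.getNumRows(); r++)
			for(int c = 0; c < m.getNumColumns(); c++)
				q.set(r, c, Math.floor(m.get(r, c) * s));
		q.recomputeNonZeros();
		return CompressedMatrixBlockFactory.compress(q, k);
	}
}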
src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -751,7 +751,7 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
else
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
break;

default: //always unconditional
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
}
@@ -2013,6 +2013,34 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
else
raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
break;
case QUANTIZE_COMPRESS:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
checkNumParameters(2);
Expression firstExpr = getFirstExpr();
Expression secondExpr = getSecondExpr();

checkMatrixParam(firstExpr);

if(secondExpr != null) {
// check if scale factor is a scalar, vector or matrix
checkMatrixScalarParam(secondExpr);
// if scale factor is a vector or matrix, make sure it has an appropriate shape
if(secondExpr.getOutput().getDataType() != DataType.SCALAR) {
if(is1DMatrix(secondExpr)) {
long vectorLength = secondExpr.getOutput().getDim1();
if(vectorLength != firstExpr.getOutput().getDim1()) {
raiseValidateError(
"The length of the row-wise scale factor vector must match the number of rows in the matrix.");
}
}
else {
checkMatchingDimensions(firstExpr, secondExpr);
}
}
}
}
else
raiseValidateError("Quantize_compress instruction not allowed in dml script");
break;

case ROW_COUNT_DISTINCT:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2584,6 +2584,9 @@ else if ( sop.equalsIgnoreCase("!=") )
case DECOMPRESS:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr);
break;
case QUANTIZE_COMPRESS:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
break;

// Boolean binary
case XOR:
src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -50,6 +50,7 @@
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -137,6 +138,21 @@ public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb,
return compress(mb, k, new CompressionSettingsBuilder(), root);
}

public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, MatrixBlock sf, int k, WTreeRoot root) {
// Only row-wise scale factors are handled here, since column-wise quantization is not allowed;
// that restriction is enforced upstream during validation.
double[] scaleFactors = sf.getDenseBlockValues();
CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
return compress(mb, k, builder, root);
}

public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, ScalarObject sf, int k, WTreeRoot root) {
double[] scaleFactors = new double[] {sf.getDoubleValue()};
CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
return compress(mb, k, builder, root);
}

public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CostEstimatorBuilder csb) {
return compress(mb, k, new CompressionSettingsBuilder(), csb);
}
@@ -285,7 +301,7 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv
return new ImmutablePair<>(mb, null);
}

_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
_stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(mb.getNumRows(), mb.getNumColumns(), mb.getSparsity());
_stats.originalSize = mb.getInMemorySize();
_stats.originalCost = costEstimator.getCost(mb);
@@ -300,8 +316,10 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv

res = new CompressedMatrixBlock(mb); // copy metadata and allocate soft reference
logInit();

classifyPhase();

if(compressionGroups == null)
return abortCompression();

// clear extra data from analysis
@@ -490,7 +508,26 @@ private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
MatrixBlock ucmb = ((CompressedMatrixBlock) mb).getUncompressed("Decompressing for abort: ", k);
return new ImmutablePair<>(ucmb, _stats);
}
if(compSettings.scaleFactors == null) {
LOG.warn("Scale factors are null - returning original matrix.");
return new ImmutablePair<>(mb, _stats);
} else {
LOG.warn("Scale factors are present - returning scaled matrix.");
MatrixBlock scaledMb = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.isInSparseFormat());
scaledMb.copy(mb);

// Apply scaling and flooring
// TODO: Use internal matrix prod
for(int r = 0; r < mb.getNumRows(); r++) {
double scaleFactor = compSettings.scaleFactors.length == 1 ? compSettings.scaleFactors[0] : compSettings.scaleFactors[r];
for(int c = 0; c < mb.getNumColumns(); c++) {
double newValue = Math.floor(mb.get(r, c) * scaleFactor);
scaledMb.set(r, c, newValue);
}
}
scaledMb.recomputeNonZeros();
return new ImmutablePair<>(scaledMb, _stats);
}
}

private void logInit() {
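A hedged usage sketch of the two new entry points above (the randOperations helper, the WTreeRoot import path, and the driver class are assumptions, not part of this diff):

import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class QuantizeCompressUsage {
	public static void main(String[] args) {
		MatrixBlock data = MatrixBlock.randOperations(1000, 10, 1.0, 0, 1, "uniform", 42);

		// (a) one scale factor per row, passed as a column vector
		MatrixBlock rowScale = MatrixBlock.randOperations(1000, 1, 1.0, 1, 8, "uniform", 43);
		Pair<MatrixBlock, CompressionStatistics> byRow =
			CompressedMatrixBlockFactory.compress(data, rowScale, 4, (WTreeRoot) null);

		// (b) a single scalar scale factor applied to the whole matrix
		Pair<MatrixBlock, CompressionStatistics> byScalar =
			CompressedMatrixBlockFactory.compress(data, new DoubleObject(127.0), 4, (WTreeRoot) null);

		System.out.println(byRow.getRight() + "\n" + byScalar.getRight());
	}
}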
src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
@@ -46,8 +46,8 @@ public class CompressionSettings {
public static final int BITMAP_BLOCK_SZ = Character.MAX_VALUE;

/**
* Sorting of values by physical length helps by 10-20%, especially for serial, while slight performance decrease
* for parallel incl multi-threaded, hence not applied for distributed operations (also because compression time +
* garbage collection increases)
*/
public final boolean sortTuplesByFrequency;
@@ -131,11 +131,13 @@ public class CompressionSettings {
/** if the settings have been logged already. */
public static boolean printedStatus = false;

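/** Scale factors for quantization-fused compression; null when no quantization is applied. */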
public final double[] scaleFactors;

protected CompressionSettings(double samplingRatio, double samplePower, boolean allowSharedDictionary,
String transposeInput, int seed, boolean lossy, EnumSet<CompressionType> validCompressions,
boolean sortValuesByLength, PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage,
int minimumSampleSize, int maxSampleSize, EstimationType estimationType, CostType costComputationType,
double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType, double[] scaleFactors) {
this.samplingRatio = samplingRatio;
this.samplePower = samplePower;
this.allowSharedDictionary = allowSharedDictionary;
@@ -154,6 +156,8 @@ protected CompressionSettings(double samplingRatio, double samplePower, boolean
this.minimumCompressionRatio = minimumCompressionRatio;
this.isInSparkInstruction = isInSparkInstruction;
this.sdcSortType = sdcSortType;
this.scaleFactors = scaleFactors;

if(!printedStatus && LOG.isDebugEnabled()) {
printedStatus = true;
LOG.debug(this.toString());
src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
@@ -52,6 +52,7 @@ public class CompressionSettingsBuilder {
private double minimumCompressionRatio = 1.0;
private boolean isInSparkInstruction = false;
private SORT_TYPE sdcSortType = SORT_TYPE.MATERIALIZE;
private double[] scaleFactors = null;

public CompressionSettingsBuilder() {

@@ -69,6 +70,19 @@ public CompressionSettingsBuilder() {

}

/**
* Sets the scale factors for compression, enabling quantization-fused compression.
*
* @param scaleFactors An array of scale factors applied during compression:
*                     for row-wise scaling, one value per row of the input matrix;
*                     a single-element array applies one factor uniformly to the entire matrix.
* @return The CompressionSettingsBuilder instance with the updated scale factors.
*/
public CompressionSettingsBuilder setScaleFactor(double[] scaleFactors) {
this.scaleFactors = scaleFactors;
return this;
}

/**
* Copy the settings from another CompressionSettings Builder, modifies this, not that.
*
@@ -331,6 +345,6 @@ public CompressionSettings create() {
return new CompressionSettings(samplingRatio, samplePower, allowSharedDictionary, transposeInput, seed, lossy,
validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage,
minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio, isInSparkInstruction,
sdcSortType, scaleFactors);
}
}
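A minimal builder sketch showing the new setting in context (class and method names illustrative; only APIs visible in this diff are used):

import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;

public class ScaleFactorSettingsExample {
	public static CompressionSettings uniform(double s) {
		// a single-element array applies one scale factor to the entire matrix
		return new CompressionSettingsBuilder()
			.setScaleFactor(new double[] {s})
			.create();
	}

	public static CompressionSettings rowWise(double[] perRow) {
		// one entry per row of the input matrix
		return new CompressionSettingsBuilder()
			.setScaleFactor(perRow)
			.create();
	}
}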