diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 4ff5654de02..94bb04a3046 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -88,6 +88,7 @@ public enum Builtins {
 	COLVAR("colVars", false),
 	COMPONENTS("components", true),
 	COMPRESS("compress", false, ReturnType.MULTI_RETURN),
+	QUANTIZE_COMPRESS("quantize_compress", false, ReturnType.MULTI_RETURN),
 	CONFUSIONMATRIX("confusionMatrix", true),
 	CONV2D("conv2d", false),
 	CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java
index a878d3f0ace..541d3728f00 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -303,6 +303,7 @@ public enum Opcodes {
 	PARTITION("partition", CPType.Partition),
 	COMPRESS(Compression.OPCODE, CPType.Compression),
 	DECOMPRESS(DeCompression.OPCODE, CPType.DeCompression),
+	QUANTIZE_COMPRESS("quantize_compress", CPType.QuantizeCompression),
 	SPOOF("spoof", CPType.SpoofFused),
 	PREFETCH("prefetch", CPType.Prefetch),
 	EVICT("_evict", CPType.EvictLineageCache),
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index c9820a2c092..58b226bcd23 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -634,8 +634,9 @@ public enum OpOp2 {
 		//fused ML-specific operators for performance
 		MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
 		LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
-		MINUS1_MULT(false); //1-X*Y
-
+		MINUS1_MULT(false), //1-X*Y
+		QUANTIZE_COMPRESS(false); //quantization-fused compression
+
 		private final boolean _validOuter;
 
 		private OpOp2(boolean outer) {
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index a3161c57230..df7457529d1 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -279,6 +279,16 @@ public enum MemoryManager {
 	 */
 	public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
 
+	/**
+	 * Boolean specifying if the user may call quantize_compress directly in a DML script.
+	 */
+	public static boolean ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND = true;
+
+	/**
+	 * Boolean specifying if the quantization-fused compression rewrite is allowed.
+	 */
+	public static boolean ALLOW_QUANTIZE_COMPRESS_REWRITE = true;
+
 	/**
 	 * Boolean specifying if compression rewrites is allowed.
	 * This is disabled at run time if the IPA for Workload aware compression
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index b08d836efe5..874ddae0347 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -90,7 +90,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
 		if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
 			_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
 		_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock
-
+		if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
+			_dagRuleSet.add( new RewriteQuantizationFusedCompression() );
+
 		//add statement block rewrite rules
 		if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
 			_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
new file mode 100644
index 00000000000..f29d1dce816
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map.Entry;
+
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+
+import org.apache.sysds.hops.Hop;
+
+/**
+ * Rule: RewriteQuantizationFusedCompression. Detects the sequence `M2 = floor(M * S)` followed by `C = compress(M2)`
+ * and fuses it into a single quantization-fused compression operator. This rewrite improves performance by avoiding
+ * the materialization of the intermediate floored matrix.
+ */
+public class RewriteQuantizationFusedCompression extends HopRewriteRule {
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if(roots == null)
+			return null;
+
+		// traverse the HOP DAG
+		HashMap<String, Hop> floors = new HashMap<>();
+		HashMap<String, Hop> compresses = new HashMap<>();
+		for(Hop h : roots)
+			collectFloorCompressSequences(h, floors, compresses);
+
+		Hop.resetVisitStatus(roots);
+
+		// check compresses for compress-after-floor pattern
+		for(Entry<String, Hop> e : compresses.entrySet()) {
+			String inputname = e.getKey();
+			Hop compresshop = e.getValue();
+
+			if(floors.containsKey(inputname) // floors same name
+				&& ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) ||
+					(floors.get(inputname).getEndLine() < compresshop.getEndLine()) ||
+					(floors.get(inputname).getBeginLine() == compresshop.getBeginLine() &&
+						floors.get(inputname).getEndLine() == compresshop.getEndLine() &&
+						floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) {
+
+				// retrieve the floor hop and inputs
+				Hop floorhop = floors.get(inputname);
+				Hop floorInput = floorhop.getInput().get(0);
+
+				// check if the input of the floor operation is a matrix
+				if(floorInput.getDataType() == DataType.MATRIX) {
+
+					// Check if the input of the floor operation involves a multiplication operation
+					if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) {
+						Hop initialMatrix = floorInput.getInput().get(0);
+						Hop sf = floorInput.getInput().get(1);
+
+						// create fused hop
+						BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64,
+							OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf);
+
+						// rewire compress consumers to fusedhop
+						List<Hop> parents = new ArrayList<>(compresshop.getParent());
+						for(Hop p : parents) {
+							HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop);
+						}
+					}
+				}
+			}
+		}
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		// do nothing, floor/compress do not occur in predicates
+		return root;
+	}
+
+	private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) {
+		if(hop.isVisited())
+			return;
+
+		// process children
+		if(!hop.getInput().isEmpty())
+			for(Hop c : hop.getInput())
+				collectFloorCompressSequences(c, floors, compresses);
+
+		// process current hop
+		if(hop instanceof UnaryOp) {
+			UnaryOp uop = (UnaryOp) hop;
+			if(uop.getOp() == OpOp1.FLOOR) {
+				floors.put(uop.getName(), uop);
+			}
+			else if(uop.getOp() == OpOp1.COMPRESS) {
+				compresses.put(uop.getInput(0).getName(), uop);
+			}
+		}
+		hop.setVisited();
+	}
+}
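For intuition, the rewrite replaces the floor-multiply-compress chain with a single fused operator. A simplified sketch of the before/after wiring, using the existing `HopRewriteUtils` helper factories (`M` and `S` are placeholder hops for the matrix and scale factor; the actual rule rewires existing hops rather than creating the original pattern):

```java
// Simplified sketch, not part of the patch.
static Hop fuseQuantizeCompress(Hop M, Hop S) {
	// pattern before the rewrite: M2 = floor(M * S); C = compress(M2)
	Hop mult = HopRewriteUtils.createBinary(M, S, OpOp2.MULT);
	Hop floor = HopRewriteUtils.createUnary(mult, OpOp1.FLOOR);
	// after the rewrite, all consumers of compress(M2) read from one fused hop
	// instead, so floor(M * S) is never materialized:
	return HopRewriteUtils.createBinary(M, S, OpOp2.QUANTIZE_COMPRESS);
}
```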
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 6a68f867f90..70245de070c 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -751,7 +751,7 @@ else if(((ConstIdentifier) getThirdExpr().getOutput())
 			else
 				raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
 			break;
-
+
 		default: //always unconditional
 			raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
 		}
@@ -2013,6 +2013,36 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
 			else
 				raiseValidateError("Compress/DeCompress instruction not allowed in dml script");
 			break;
+		case QUANTIZE_COMPRESS:
+			if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) {
+				checkNumParameters(2);
+
+				Expression firstExpr = getFirstExpr();
+				Expression secondExpr = getSecondExpr();
+
+				checkMatrixParam(getFirstExpr());
+
+				if(secondExpr != null) {
+					// check if scale factor is a scalar, vector or matrix
+					checkMatrixScalarParam(secondExpr);
+					// if scale factor is a vector or matrix, make sure it has an appropriate shape
+					if(secondExpr.getOutput().getDataType() != DataType.SCALAR) {
+						if(is1DMatrix(secondExpr)) {
+							long vectorLength = secondExpr.getOutput().getDim1();
+							if(vectorLength != firstExpr.getOutput().getDim1()) {
+								raiseValidateError(
+									"The length of the row-wise scale factor vector must match the number of rows in the matrix.");
+							}
+						}
+						else {
+							checkMatchingDimensions(firstExpr, secondExpr);
+						}
+					}
+				}
+			}
+			else
+				raiseValidateError("Quantize_compress instruction not allowed in dml script");
+			break;
+
 		case ROW_COUNT_DISTINCT:
 			checkNumParameters(1);
 			checkMatrixParam(getFirstExpr());
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index cd58af62337..9c398443515 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2584,6 +2584,9 @@ else if ( sop.equalsIgnoreCase("!=") )
 			case DECOMPRESS:
 				currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr);
 				break;
+			case QUANTIZE_COMPRESS:
+				currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
+				break;
 
 			// Boolean binary
 			case XOR:
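The shape rules this validation enforces can be restated compactly. The helper below is illustrative only (not part of the patch): a true scalar (DataType.SCALAR) is always accepted, while matrix-typed scale factors must be a per-row vector or match the input dimensions element-wise.

```java
// sfRows/sfCols: scale-factor matrix dimensions; mRows/mCols: input matrix dimensions.
static boolean validMatrixScaleFactorShape(long sfRows, long sfCols, long mRows, long mCols) {
	if(sfRows == 1 || sfCols == 1)
		return sfRows == mRows;                    // 1-D: one factor per matrix row
	return sfRows == mRows && sfCols == mCols;     // full matrix: element-wise factors
}
```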
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 93240644a14..b128d30411a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -50,6 +50,7 @@
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -137,6 +138,21 @@ public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb,
 		return compress(mb, k, new CompressionSettingsBuilder(), root);
 	}
 
+	public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, MatrixBlock sf, int k, WTreeRoot root) {
+		// Handle only row vectors, as column-wise quantization is not allowed.
+		// The restriction is handled upstream.
+		double[] scaleFactors = sf.getDenseBlockValues();
+		CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
+		return compress(mb, k, builder, root);
+	}
+
+	public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, ScalarObject sf, int k, WTreeRoot root) {
+		double[] scaleFactors = new double[1];
+		scaleFactors[0] = sf.getDoubleValue();
+		CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors);
+		return compress(mb, k, builder, root);
+	}
+
 	public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CostEstimatorBuilder csb) {
 		return compress(mb, k, new CompressionSettingsBuilder(), csb);
 	}
@@ -285,7 +301,7 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv
 			return new ImmutablePair<>(mb, null);
 		}
 
-		_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0); 
+		_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
 		_stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(mb.getNumRows(), mb.getNumColumns(), mb.getSparsity());
 		_stats.originalSize = mb.getInMemorySize();
 		_stats.originalCost = costEstimator.getCost(mb);
@@ -300,8 +316,10 @@ else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOv
 			res = new CompressedMatrixBlock(mb); // copy metadata and allocate soft reference
 
 		logInit();
+
 		classifyPhase();
-		if(compressionGroups == null)
+
+		if(compressionGroups == null)
 			return abortCompression();
 
 		// clear extra data from analysis
@@ -490,7 +508,26 @@ private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
 			MatrixBlock ucmb = ((CompressedMatrixBlock) mb).getUncompressed("Decompressing for abort: ", k);
 			return new ImmutablePair<>(ucmb, _stats);
 		}
-		return new ImmutablePair<>(mb, _stats);
+		if(compSettings.scaleFactors == null) {
+			LOG.warn("Scale factors are null - returning original matrix.");
+			return new ImmutablePair<>(mb, _stats);
+		} else {
+			LOG.warn("Scale factors are present - returning scaled matrix.");
+			MatrixBlock scaledMb = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.isInSparseFormat());
+			scaledMb.copy(mb);
+
+			// Apply scaling and flooring
+			// TODO: Use internal matrix prod
+			for(int r = 0; r < mb.getNumRows(); r++) {
+				double scaleFactor = compSettings.scaleFactors.length == 1 ?
compSettings.scaleFactors[0] : compSettings.scaleFactors[r]; + for(int c = 0; c < mb.getNumColumns(); c++) { + double newValue = Math.floor(mb.get(r, c) * scaleFactor); + scaledMb.set(r, c, newValue); + } + } + scaledMb.recomputeNonZeros(); + return new ImmutablePair<>(scaledMb, _stats); + } } private void logInit() { diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java index 31c034ef4dc..f6321bc1b6d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java @@ -46,8 +46,8 @@ public class CompressionSettings { public static final int BITMAP_BLOCK_SZ = Character.MAX_VALUE; /** - * Sorting of values by physical length helps by 10-20%, especially for serial, while slight performance decrease for - * parallel incl multi-threaded, hence not applied for distributed operations (also because compression time + + * Sorting of values by physical length helps by 10-20%, especially for serial, while slight performance decrease + * for parallel incl multi-threaded, hence not applied for distributed operations (also because compression time + * garbage collection increases) */ public final boolean sortTuplesByFrequency; @@ -131,11 +131,13 @@ public class CompressionSettings { /** if the settings have been logged already. */ public static boolean printedStatus = false; + public final double[] scaleFactors; + protected CompressionSettings(double samplingRatio, double samplePower, boolean allowSharedDictionary, String transposeInput, int seed, boolean lossy, EnumSet validCompressions, boolean sortValuesByLength, PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage, int minimumSampleSize, int maxSampleSize, EstimationType estimationType, CostType costComputationType, - double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType) { + double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType, double[] scaleFactors) { this.samplingRatio = samplingRatio; this.samplePower = samplePower; this.allowSharedDictionary = allowSharedDictionary; @@ -154,6 +156,8 @@ protected CompressionSettings(double samplingRatio, double samplePower, boolean this.minimumCompressionRatio = minimumCompressionRatio; this.isInSparkInstruction = isInSparkInstruction; this.sdcSortType = sdcSortType; + this.scaleFactors = scaleFactors; + if(!printedStatus && LOG.isDebugEnabled()) { printedStatus = true; LOG.debug(this.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java index dc0908dc9bf..ae6a0b2d231 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java @@ -52,6 +52,7 @@ public class CompressionSettingsBuilder { private double minimumCompressionRatio = 1.0; private boolean isInSparkInstruction = false; private SORT_TYPE sdcSortType = SORT_TYPE.MATERIALIZE; + private double[] scaleFactors = null; public CompressionSettingsBuilder() { @@ -69,6 +70,19 @@ public CompressionSettingsBuilder() { } + /** + * Sets the scale factors for compression, enabling quantization-fused compression. + * + * @param scaleFactors An array of scale factors applied during compression. 
+ * - If row-wise scaling is used, this should be an array where each value corresponds to a row. + * - If a single scalar is provided, it is applied uniformly to the entire matrix. + * @return The CompressionSettingsBuilder instance with the updated scale factors. + */ + public CompressionSettingsBuilder setScaleFactor(double[] scaleFactors) { + this.scaleFactors = scaleFactors; + return this; + } + /** * Copy the settings from another CompressionSettings Builder, modifies this, not that. * @@ -331,6 +345,6 @@ public CompressionSettings create() { return new CompressionSettings(samplingRatio, samplePower, allowSharedDictionary, transposeInput, seed, lossy, validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage, minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio, isInSparkInstruction, - sdcSortType); + sdcSortType, scaleFactors); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java b/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java index 70fd4da263e..7be7ac4b93b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java @@ -24,6 +24,7 @@ import java.util.Comparator; import java.util.List; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressionSettings; @@ -48,7 +49,8 @@ public class BitmapEncoder { public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, int estimatedNumberOfUniques, CompressionSettings cs) { - return extractBitmap(colIndices, rawBlock, cs.transposed, estimatedNumberOfUniques, cs.sortTuplesByFrequency); + return extractBitmap(colIndices, rawBlock, cs.transposed, estimatedNumberOfUniques, cs.sortTuplesByFrequency, + cs.scaleFactors); } /** @@ -61,75 +63,177 @@ public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, * @param rawBlock An uncompressed matrix block; can be dense, sparse, empty, or null (not * Compressed!) * @param transposed Boolean specifying if the rawBlock was transposed. - * @param estimatedNumberOfUniques The number of estimated uniques inside this group. Used to allocated the HashMaps. + * @param estimatedNumberOfUniques The number of estimated uniques inside this group. Used to allocated the + * HashMaps. * @param sortedEntries Boolean specifying if the entries should be sorted based on frequency of tuples * @return Uncompressed bitmap representation of the columns specified */ public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, boolean transposed, int estimatedNumberOfUniques, boolean sortedEntries) { + // Overloaded method with scaleFactors defaulted to null + return extractBitmap(colIndices, rawBlock, transposed, estimatedNumberOfUniques, sortedEntries, null); + } + + /** + * Generate quantization-fused uncompressed bitmaps for a set of columns in an uncompressed matrix block. + * + * if the rawBlock is transposed and sparse it should be guaranteed that the rows specified are not empty, aka all + * zero. + * + * @param colIndices Indexes (within the block) of the columns to extract + * @param rawBlock An uncompressed matrix block; can be dense, sparse, empty, or null (not + * Compressed!) + * @param transposed Boolean specifying if the rawBlock was transposed. 
+ * @param estimatedNumberOfUniques The number of estimated uniques inside this group. Used to allocated the + * HashMaps. + * @param sortedEntries Boolean specifying if the entries should be sorted based on frequency of tuples + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for + * entire matrix + * @return Uncompressed bitmap representation of the columns specified + */ + public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, boolean transposed, + int estimatedNumberOfUniques, boolean sortedEntries, double[] scaleFactors) { if(rawBlock == null || rawBlock.isEmpty()) return null; final int numRows = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows(); final int estimatedNumber = Math.max(estimatedNumberOfUniques, 8); - if(colIndices.size() == 1) + if(colIndices.size() == 1) { return extractBitmapSingleColumn(colIndices.get(0), rawBlock, numRows, transposed, estimatedNumber, - sortedEntries); - else - return extractBitmapMultiColumns(colIndices, rawBlock, numRows, transposed, estimatedNumber, sortedEntries); + sortedEntries, scaleFactors); + } + else { + return extractBitmapMultiColumns(colIndices, rawBlock, numRows, transposed, estimatedNumber, sortedEntries, + scaleFactors); + } } - private static ABitmap extractBitmapSingleColumn(int colIndex, MatrixBlock rawBlock, int numRows, boolean transposed, - int est, boolean sort) { + private static ABitmap extractBitmapSingleColumn(int colIndex, MatrixBlock rawBlock, int numRows, + boolean transposed, int est, boolean sort, double[] scaleFactors) { if(transposed) { if(rawBlock.isInSparseFormat() && rawBlock.getSparseBlock().isEmpty(colIndex)) return null; - return makeSingleColBitmap(extractSingleColT(colIndex, rawBlock, est), rawBlock.getNumColumns(), sort); + return makeSingleColBitmap(extractSingleColT(colIndex, rawBlock, est, scaleFactors), rawBlock.getNumColumns(), sort); } else - return makeSingleColBitmap(extractSingleCol(colIndex, rawBlock, est), rawBlock.getNumRows(), sort); + return makeSingleColBitmap(extractSingleCol(colIndex, rawBlock, est, scaleFactors), rawBlock.getNumRows(), + sort); } - private static DoubleIntListHashMap extractSingleCol(int colIndex, MatrixBlock rawBlock, int estimatedUnique) { + private static DoubleIntListHashMap extractSingleCol(int colIndex, MatrixBlock rawBlock, int estimatedUnique, + double[] scaleFactors) { final DoubleIntListHashMap distinctVals = new DoubleIntListHashMap(estimatedUnique); final int nRows = rawBlock.getNumRows(); final int nCols = rawBlock.getNumColumns(); final boolean sparse = rawBlock.isInSparseFormat(); - if(sparse) { - final SparseBlock sb = rawBlock.getSparseBlock(); - for(int r = 0; r < nRows; r++) { - if(sb.isEmpty(r)) - continue; - final int apos = sb.pos(r); - final int alen = sb.size(r) + apos; - final int[] aix = sb.indexes(r); - final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); - if(idx >= 0) - distinctVals.appendValue(sb.values(r)[idx], r); + if(scaleFactors == null) { + if(sparse) { + final SparseBlock sb = rawBlock.getSparseBlock(); + for(int r = 0; r < nRows; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); + if(idx >= 0) + distinctVals.appendValue(sb.values(r)[idx], r); + } + } + else if(rawBlock.getDenseBlock().isContiguous()) { + final double[] values = rawBlock.getDenseBlockValues(); + if(nCols == 1) + // 
Since the only values contained is in this column index. simply extract it continuously. + for(int i = 0; i < values.length; i++) + distinctVals.appendValue(values[i], i); + else + // For loop down through the rows skipping all other values than the ones in the specified column + // index. + for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) + distinctVals.appendValue(values[off], i); + } + else { // GENERAL CASE + // This case is slow, because it does a binary search in each row of the sparse input. (if sparse) + // and does get value in dense cases with multi blocks. + for(int i = 0; i < nRows; i++) + distinctVals.appendValue(rawBlock.get(i, colIndex), i); } } - else if(rawBlock.getDenseBlock().isContiguous()) { - final double[] values = rawBlock.getDenseBlockValues(); - if(nCols == 1) - // Since the only values contained is in this column index. simply extract it continuously. - for(int i = 0; i < values.length; i++) - distinctVals.appendValue(values[i], i); - else - // For loop down through the rows skipping all other values than the ones in the specified column index. - for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) - distinctVals.appendValue(values[off], i); - } - else { // GENERAL CASE - // This case is slow, because it does a binary search in each row of the sparse input. (if sparse) - // and does get value in dense cases with multi blocks. - for(int i = 0; i < nRows; i++) - distinctVals.appendValue(rawBlock.get(i, colIndex), i); + else { + // Apply single scale factor + if(scaleFactors.length == 1) { + final double scaleFactor = scaleFactors[0]; + + if(sparse) { + final SparseBlock sb = rawBlock.getSparseBlock(); + for(int r = 0; r < nRows; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); + if(idx >= 0) + distinctVals.appendValue(Math.floor(sb.values(r)[idx] * scaleFactor), r); + } + } + else if(rawBlock.getDenseBlock().isContiguous()) { + final double[] values = rawBlock.getDenseBlockValues(); + if(nCols == 1) { + for(int i = 0; i < values.length; i++) + distinctVals.appendValue(Math.floor(values[i] * scaleFactor), i); + } + else { + for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) + distinctVals.appendValue(Math.floor(values[off] * scaleFactor), i); + } + } + else { // GENERAL CASE + for(int i = 0; i < nRows; i++) + distinctVals.appendValue(Math.floor(rawBlock.get(i, colIndex) * scaleFactor), i); + } + } + else { + // Apply scale factor row-wise. The shape of scale factor is handled upstream. 
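Conceptually, each branch below computes `floor(X[r,c] * sf[r])` while building the distinct-value map, so no quantized intermediate matrix is ever allocated. A minimal standalone sketch of the same idea for a dense single column (hypothetical helper, not part of the patch):

```java
import java.util.HashMap;
import java.util.Map;

class QuantizedExtractSketch {
	// Count occurrences of each distinct quantized value, one factor per row.
	static Map<Double, Integer> distinctQuantized(double[] col, double[] sf) {
		Map<Double, Integer> counts = new HashMap<>();
		for(int r = 0; r < col.length; r++)
			counts.merge(Math.floor(col[r] * sf[r]), 1, Integer::sum);
		return counts;
	}
}
```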
+ if(sparse) { + final SparseBlock sb = rawBlock.getSparseBlock(); + for(int r = 0; r < nRows; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); + if(idx >= 0) + distinctVals.appendValue(Math.floor(sb.values(r)[idx] * scaleFactors[r]), r); + } + } + else if(rawBlock.getDenseBlock().isContiguous()) { + final double[] values = rawBlock.getDenseBlockValues(); + if(nCols == 1) { + for(int i = 0; i < values.length; i++) + distinctVals.appendValue(Math.floor(values[i] * scaleFactors[i]), i); + } + else { + for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) + distinctVals.appendValue(Math.floor(values[off] * scaleFactors[i]), i); + } + } + else { // GENERAL CASE + for(int i = 0; i < nRows; i++) + distinctVals.appendValue(Math.floor(rawBlock.get(i, colIndex) * scaleFactors[i]), i); + } + } } return distinctVals; } - private static DoubleIntListHashMap extractSingleColT(int colIndex, MatrixBlock rawBlock, int estimatedUnique) { + private static DoubleIntListHashMap extractSingleColT(int colIndex, MatrixBlock rawBlock, int estimatedUnique, double[] scaleFactors) { + if (scaleFactors != null) { + throw new NotImplementedException(); + } + // probe map for distinct items (for value or value groups) final DoubleIntListHashMap distinctVals = new DoubleIntListHashMap(estimatedUnique); @@ -163,15 +267,20 @@ else if(rawBlock.getNumRows() == 1) { } private static ABitmap extractBitmapMultiColumns(IColIndex colIndices, MatrixBlock rawBlock, int numRows, - boolean transposed, int estimatedUnique, boolean sort) { + boolean transposed, int estimatedUnique, boolean sort, double[] scaleFactors) { final DblArrayIntListHashMap map = new DblArrayIntListHashMap(estimatedUnique); - final ReaderColumnSelection reader = ReaderColumnSelection.createReader(rawBlock, colIndices, transposed); + + final ReaderColumnSelection reader = (scaleFactors == null) ? 
ReaderColumnSelection.createReader(rawBlock, + colIndices, + transposed) : ReaderColumnSelection.createQuantizedReader(rawBlock, colIndices, transposed, scaleFactors); + DblArray cellVals = null; try { DblArray empty = new DblArray(new double[colIndices.size()]); while((cellVals = reader.nextRow()) != null) { - if(!cellVals.equals(empty)) + if(!cellVals.equals(empty)) { map.appendValue(cellVals, reader.getCurrentRowIndex()); + } } } @@ -195,6 +304,7 @@ private static ABitmap makeMultiColBitmap(DblArrayIntListHashMap map, int numRow values[bitmapIx] = val.key.getData(); offsetsLists[bitmapIx++] = val.value; } + return new MultiColBitmap(offsetsLists, values, numRows); } else diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java index a759e999397..ef8b83c3b83 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java @@ -29,6 +29,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -226,12 +227,12 @@ private void logEstVsActual(double time, AColGroup act, CompressedSizeInfoColGro if(estC < actC * 0.75) { String warning = "The estimate cost is significantly off : " + est; LOG.debug( - String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", time, - retType, estC, actC, act.getNumValues(), cols, wanted, warning)); + String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", + time, retType, estC, actC, act.getNumValues(), cols, wanted, warning)); } else { - LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", time, - retType, estC, actC, act.getNumValues(), cols, wanted)); + LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", + time, retType, estC, actC, act.getNumValues(), cols, wanted)); } } @@ -261,19 +262,40 @@ private AColGroup compress(CompressedSizeInfoColGroup cg) throws Exception { if((ct == CompressionType.EMPTY && !t) || // (t && colIndexes.size() == 1 && in.isInSparseFormat() // Empty Column && in.getSparseBlock().isEmpty(colIndexes.get(0)))) + // TODO: handle quantization-fused compression if deemed necessary, + // but if the matrix reaches here, it likely doesn't need quantization. 
return new ColGroupEmpty(colIndexes); - else if(ct == CompressionType.UNCOMPRESSED) // don't construct mapping if uncompressed - return ColGroupUncompressed.create(colIndexes, in, t); + else if(ct == CompressionType.UNCOMPRESSED) { // don't construct mapping if uncompressed + if(cs.scaleFactors != null) { + return ColGroupUncompressed.createQuantized(colIndexes, in, t, cs.scaleFactors); + } + else { + return ColGroupUncompressed.create(colIndexes, in, t); + } + } else if((ct == CompressionType.SDC || ct == CompressionType.CONST) // && in.isInSparseFormat() // && t && (// (colIndexes.size() > 1 && cg.getNumOffs() < 0.3 * nRow) // - || colIndexes.size() == 1)) - return compressSDCFromSparseTransposedBlock(colIndexes, cg.getNumVals(), cg.getTupleSparsity()); - else if(ct == CompressionType.DDC) + || colIndexes.size() == 1)) { + if(cs.scaleFactors != null) { + throw new NotImplementedException(); // TODO: handle quantization-fused compression + } + else { + return compressSDCFromSparseTransposedBlock(colIndexes, cg.getNumVals(), cg.getTupleSparsity()); + } + } + else if(ct == CompressionType.DDC) { return directCompressDDC(colIndexes, cg); - else if(ct == CompressionType.LinearFunctional) - return compressLinearFunctional(colIndexes, in, cs); + } + else if(ct == CompressionType.LinearFunctional) { + if(cs.scaleFactors != null) { + throw new NotImplementedException(); // quantization-fused compression NOT allowed + } + else { + return compressLinearFunctional(colIndexes, in, cs); + } + } else if(ct == CompressionType.DDCFOR) { AColGroup g = directCompressDDC(colIndexes, cg); if(g instanceof ColGroupDDC) @@ -285,12 +307,13 @@ else if(ct == CompressionType.SDC && colIndexes.size() == 1 && !t) { } final ABitmap ubm = BitmapEncoder.extractBitmap(colIndexes, in, cg.getNumVals(), cs); - if(ubm == null) // no values ... therefore empty + if(ubm == null) {// no values ... therefore empty return new ColGroupEmpty(colIndexes); - + } final IntArrayList[] of = ubm.getOffsetList(); - if(of.length == 1 && of[0].size() == nRow) // If this always constant + if(of.length == 1 && of[0].size() == nRow) { // If this always constant return ColGroupConst.create(colIndexes, DictionaryFactory.create(ubm)); + } final double tupleSparsity = colIndexes.size() > 4 ? cg.getTupleSparsity() : 1.0; @@ -330,19 +353,100 @@ private AColGroup compressSDCSingleColDirectBlock(IColIndex colIndexes, int nVal IDictionary dict = Dictionary.create(cMap.getDictionary(dictSize)); IntArrayList offs = new IntArrayList(nRow - defCount); AMapToData map = MapToFactory.create(nRow - defCount, dictSize); - getOffsets(offs, map, cMap, col, def); - + if(cs.scaleFactors != null) { + getOffsetsScaled(offs, map, cMap, col, def); + } + else { + getOffsets(offs, map, cMap, col, def); + } AOffset aoff = OffsetFactory.createOffset(offs); return ColGroupSDC.create(colIndexes, nRow, dict, new double[] {def}, aoff, map, null); + } + private void getOffsetsScaled(IntArrayList offs, AMapToData map, DoubleCountHashMap cMap, int col, double def) { + final double scaleFactor = cs.scaleFactors[0]; // Single column, thus single scalar value. 
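Note that this path, like the other scaled helpers in this class, reads only `cs.scaleFactors[0]`. If per-row factors can ever reach these paths, a small lookup such as the following (hypothetical; it mirrors the idiom already used in `CompressedMatrixBlockFactory.abortCompression`) would generalize them:

```java
// Resolve the factor for row r in both the single-scalar and per-row cases.
static double scaleFactorForRow(double[] scaleFactors, int r) {
	return scaleFactors.length == 1 ? scaleFactors[0] : scaleFactors[r];
}
```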
+ + if(in.isInSparseFormat()) { + final SparseBlock sb = in.getSparseBlock(); + + if(def == 0) { // If zero is the default value + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + continue; // Skip explicitly storing zero values + + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + + if(idx >= 0) { + double v = Math.floor(sb.values(r)[idx] * scaleFactor); + map.set(offs.size(), cMap.getId(v)); + offs.appendValue(r); + } + } + } + + else { // If zero is NOT the default value, track missing values explicitly + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) { + map.set(offs.size(), cMap.getId(0.0)); + offs.appendValue(r); + } + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + + if(idx < 0) { // Missing entry + map.set(offs.size(), cMap.getId(0.0)); + offs.appendValue(r); + } + else { + double v = Math.floor(sb.values(r)[idx] * scaleFactor); + if(!Util.eq(v, def)) { + map.set(offs.size(), cMap.getId(v)); + offs.appendValue(r); + } + } + } + } + } + + } + else if(in.getDenseBlock().isContiguous()) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + + for(int r = 0; r < nRow; r++, off += nCol) { + double scaledValue = Math.floor(dv[off] * scaleFactor); + if(!Util.eq(scaledValue, def)) { + map.set(offs.size(), cMap.getId(scaledValue)); + offs.appendValue(r); + } + } + } + else { + final DenseBlock db = in.getDenseBlock(); + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + double scaledValue = Math.floor(dv[off] * scaleFactor); + if(!Util.eq(scaledValue, def)) { + map.set(offs.size(), cMap.getId(scaledValue)); + offs.appendValue(r); + } + } + } } private void getOffsets(IntArrayList offs, AMapToData map, DoubleCountHashMap cMap, int col, double def) { if(in.isInSparseFormat()) { + final SparseBlock sb = in.getSparseBlock(); if(def == 0) { - final SparseBlock sb = in.getSparseBlock(); for(int r = 0; r < nRow; r++) { if(sb.isEmpty(r)) continue; @@ -358,11 +462,8 @@ private void getOffsets(IntArrayList offs, AMapToData map, DoubleCountHashMap cM } } else { - - final SparseBlock sb = in.getSparseBlock(); for(int r = 0; r < nRow; r++) { if(sb.isEmpty(r)) { - map.set(offs.size(), cMap.getId(0.0)); offs.appendValue(r); } @@ -384,11 +485,13 @@ private void getOffsets(IntArrayList offs, AMapToData map, DoubleCountHashMap cM } } } + } } else if(in.getDenseBlock().isContiguous()) { final double[] dv = in.getDenseBlockValues(); int off = col; + for(int r = 0; r < nRow; r++, off += nCol) if(!Util.eq(dv[off], def)) { map.set(offs.size(), cMap.getId(dv[off])); @@ -409,16 +512,55 @@ else if(in.getDenseBlock().isContiguous()) { } private void countElements(DoubleCountHashMap map, int col) { - if(in.isInSparseFormat()) - countElementsSparse(map, col); - else if(in.getDenseBlock().isContiguous()) - countElementsDenseContiguous(map, col); - else - countElementsDenseGeneric(map, col); + if(cs.scaleFactors != null) { + if(in.isInSparseFormat()) { + countElementsSparseScaled(map, col); + } + else if(in.getDenseBlock().isContiguous()) { + countElementsDenseContiguousScaled(map, col); + } + else { + countElementsDenseGenericScaled(map, col); + } + } + else { + if(in.isInSparseFormat()) { + countElementsSparse(map, col); + } + else if(in.getDenseBlock().isContiguous()) { + countElementsDenseContiguous(map, col); + 
} + else { + countElementsDenseGeneric(map, col); + } + } + } + + private void countElementsSparseScaled(DoubleCountHashMap map, int col) { + final SparseBlock sb = in.getSparseBlock(); + + double scaleFactor = cs.scaleFactors[0]; + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + map.increment(0.0); + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(idx < 0) { + map.increment(0.0); + } + else { + map.increment(Math.floor(sb.values(r)[idx] * scaleFactor)); + } + } + } } private void countElementsSparse(DoubleCountHashMap map, int col) { final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { if(sb.isEmpty(r)) map.increment(0.0); @@ -435,13 +577,34 @@ private void countElementsSparse(DoubleCountHashMap map, int col) { } } + private void countElementsDenseContiguousScaled(DoubleCountHashMap map, int col) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + + double scaleFactor = cs.scaleFactors[0]; + for(int r = 0; r < nRow; r++, off += nCol) { + map.increment(Math.floor(dv[off] * scaleFactor)); + } + } + private void countElementsDenseContiguous(DoubleCountHashMap map, int col) { final double[] dv = in.getDenseBlockValues(); int off = col; + for(int r = 0; r < nRow; r++, off += nCol) map.increment(dv[off]); } + private void countElementsDenseGenericScaled(DoubleCountHashMap map, int col) { + final DenseBlock db = in.getDenseBlock(); + double scaleFactor = cs.scaleFactors[0]; + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + map.increment(Math.floor(dv[off] * scaleFactor)); + } + } + private void countElementsDenseGeneric(DoubleCountHashMap map, int col) { final DenseBlock db = in.getDenseBlock(); for(int r = 0; r < nRow; r++) { @@ -452,10 +615,15 @@ private void countElementsDenseGeneric(DoubleCountHashMap map, int col) { } private AColGroup directCompressDDC(IColIndex colIndexes, CompressedSizeInfoColGroup cg) throws Exception { - if(colIndexes.size() > 1) + // testing multicol + if(colIndexes.size() > 1) { + LOG.debug("DDC multi column"); return directCompressDDCMultiCol(colIndexes, cg); - else + } + else { + LOG.debug("DDC single column"); return directCompressDDCSingleCol(colIndexes, cg); + } } private AColGroup directCompressDDCSingleCol(IColIndex colIndexes, CompressedSizeInfoColGroup cg) { @@ -465,16 +633,27 @@ private AColGroup directCompressDDCSingleCol(IColIndex colIndexes, CompressedSiz // unlike multi-col no special handling of zero entries are needed. 
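The dispatch that follows covers the four combinations of transposed and scaled input for single-column DDC reads; only the transposed-and-scaled case is deferred. Summarized for reference (illustrative, not part of the patch):

```java
// transposed?  scaled?  path taken
// no           no       readToMapDDC(col, map, d)
// no           yes      readToMapDDCScaled(col, map, d)
// yes          no       readToMapDDCTransposed(col, map, d)
// yes          yes      NotImplementedException (deferred)
```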
if(cs.transposed) - readToMapDDCTransposed(col, map, d); - else - readToMapDDC(col, map, d); + if(cs.scaleFactors != null) { + throw new NotImplementedException(); // TODO: Handle scaled transposed columns + } + else { + readToMapDDCTransposed(col, map, d); + } + else { + if(cs.scaleFactors != null) { + readToMapDDCScaled(col, map, d); + } + else { + readToMapDDC(col, map, d); + } + } if(map.size() == 0) return new ColGroupEmpty(colIndexes); IDictionary dict = DictionaryFactory.create(map); final int nUnique = map.size(); - final AMapToData resData = d.resize( nUnique); + final AMapToData resData = d.resize(nUnique); return ColGroupDDC.create(colIndexes, dict, resData, null); } @@ -485,10 +664,14 @@ private AColGroup directCompressDDCMultiCol(IColIndex colIndexes, CompressedSize final DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(cg.getNumVals(), 64)); boolean extra; - if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null ) + if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null) { + LOG.debug("Non parallel"); extra = readToMapDDC(colIndexes, map, d, 0, nRow, fill); - else + } + else { + LOG.debug("Parallel"); extra = parallelReadToMapDDC(colIndexes, map, d, nRow, fill, k); + } if(map.size() == 0) // If the column was empty. @@ -507,7 +690,11 @@ private AColGroup directCompressDDCMultiCol(IColIndex colIndexes, CompressedSize private boolean readToMapDDC(IColIndex colIndexes, DblArrayCountHashMap map, AMapToData data, int rl, int ru, int fill) { - ReaderColumnSelection reader = ReaderColumnSelection.createReader(in, colIndexes, cs.transposed, rl, ru); + + ReaderColumnSelection reader = (cs.scaleFactors == null) ? ReaderColumnSelection.createReader(in, colIndexes, + cs.transposed, rl, + ru) : ReaderColumnSelection.createQuantizedReader(in, colIndexes, cs.transposed, rl, ru, cs.scaleFactors); + DblArray cellVals = reader.nextRow(); boolean extra = false; int r = rl; @@ -531,6 +718,49 @@ private boolean readToMapDDC(IColIndex colIndexes, DblArrayCountHashMap map, AMa return extra; } + // TODO: Merge logic to readToMapDDC. 
This should be done for other scaled methods + private void readToMapDDCScaled(int col, DoubleCountHashMap map, AMapToData data) { + double scaleFactor = cs.scaleFactors[0]; + + if(in.isInSparseFormat()) { + // not good but could happen + final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + data.set(r, map.increment(0.0)); + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(idx < 0) + data.set(r, map.increment(0.0)); + else { + double scaledValue = Math.floor(sb.values(r)[idx] * scaleFactor); + data.set(r, map.increment(scaledValue)); + } + } + } + } + else if(in.getDenseBlock().isContiguous()) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + for(int r = 0; r < nRow; r++, off += nCol) { + double scaledValue = Math.floor(dv[off] * scaleFactor); + data.set(r, map.increment(scaledValue)); + } + } + else { + final DenseBlock db = in.getDenseBlock(); + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + double scaledValue = Math.floor(dv[off] * scaleFactor); + data.set(r, map.increment(scaledValue)); + } + } + } + private void readToMapDDC(int col, DoubleCountHashMap map, AMapToData data) { if(in.isInSparseFormat()) { // not good but could happen @@ -673,7 +903,7 @@ private static AColGroup compressSDCNormal(IColIndex colIndexes, int numZeros, i cs.sdcSortType); AOffset indexes = OffsetFactory.createOffset(s.getIndexes()); AMapToData _data = s.getData(); - _data = _data.resize( dict.getNumberOfValues(colIndexes.size())); + _data = _data.resize(dict.getNumberOfValues(colIndexes.size())); return ColGroupSDC.create(colIndexes, rlen, dict, defaultTuple, indexes, _data, null); } @@ -764,7 +994,9 @@ private AColGroup compressMultiColSDCFromSparseTransposedBlock(IColIndex cols, i } } IColIndex subCols = ColIndexFactory.create(cols.size()); - ReaderColumnSelection reader = ReaderColumnSelection.createReader(sub, subCols, false); + ReaderColumnSelection reader = (cs.scaleFactors == null) ? ReaderColumnSelection.createReader(sub, subCols, + false) : ReaderColumnSelection.createQuantizedReader(sub, subCols, false, cs.scaleFactors); + final int mapStartSize = Math.min(nrUniqueEstimate, offsetsInt.length / 2); DblArrayCountHashMap map = new DblArrayCountHashMap(mapStartSize); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index cf0959bba7f..b6bb5fd92cb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -27,6 +27,7 @@ import java.util.List; import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; @@ -81,7 +82,7 @@ public class ColGroupUncompressed extends AColGroup { private final MatrixBlock _data; /** - * Do not use this constructor of column group uncompressed, instead uce the create constructor. + * Do not use this constructor of column group uncompressed, instead use the create constructor. * @param mb The contained data. 
* @param colIndexes Column indexes for this Columngroup */ @@ -90,6 +91,25 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { _data = mb; } + /** + * Do not use this constructor of column group quantization-fused uncompressed, instead use the create constructor. + * @param mb The contained data. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Column indexes for this Columngroup + */ + protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { + super(colIndexes); + // Apply scaling and flooring + // TODO: Use internal matrix prod + for(int r = 0; r < mb.getNumRows(); r++) { + double scaleFactor = scaleFactors.length == 1 ? scaleFactors[0] : scaleFactors[r]; + for(int c = 0; c < mb.getNumColumns(); c++) { + double newValue = Math.floor(mb.get(r, c) * scaleFactor); + mb.set(r, c, newValue); + } + } + _data = mb; + } /** * Create an Uncompressed Matrix Block, where the columns are offset by col indexes. * @@ -106,6 +126,97 @@ public static AColGroup create(MatrixBlock mb, IColIndex colIndexes) { return new ColGroupUncompressed(mb, colIndexes); } + /** + * Create ana quantization-fused uncompressed Matrix Block, where the columns are offset by col indexes. + * + * It is assumed that the size of the colIndexes and number of columns in mb is matching. + * + * @param mb The MB / data to contain in the uncompressed column + * @param colIndexes The column indexes for the group + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @return An Uncompressed Column group + */ + public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { + if(mb == null || mb.isEmpty()) + // TODO: handle quantization-fused compression if deemed necessary, + // but if the matrix reaches here, it likely doesn't need quantization. + return new ColGroupEmpty(colIndexes); + else + return new ColGroupUncompressed(mb, colIndexes, scaleFactors); + } + + /** + * Main constructor for a quantization-fused uncompressed ColGroup. + * + * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. + * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is + * called + * @param transposed Says if the input matrix raw block have been transposed. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @return AColGroup. + */ + public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, double[] scaleFactors) { + + // special cases + if(rawBlock.isEmptyBlock(false)) // empty input + // TODO: handle quantization-fused compression if deemed necessary, + // but if the matrix reaches here, it likely doesn't need quantization. + return new ColGroupEmpty(colIndexes); + else if(!transposed && colIndexes.size() == rawBlock.getNumColumns()) + // full input to uncompressedColumnGroup + return new ColGroupUncompressed(rawBlock, colIndexes, scaleFactors); + + MatrixBlock mb; + final int _numRows = transposed ? 
rawBlock.getNumColumns() : rawBlock.getNumRows();
+
+		if(colIndexes.size() == 1) {
+			final int col = colIndexes.get(0);
+			if(transposed) {
+				mb = rawBlock.slice(col, col, 0, rawBlock.getNumColumns() - 1);
+				mb = LibMatrixReorg.transposeInPlace(mb, InfrastructureAnalyzer.getLocalParallelism());
+			}
+			else
+				mb = rawBlock.slice(0, rawBlock.getNumRows() - 1, col, col);
+
+			return createQuantized(mb, colIndexes, scaleFactors);
+		}
+
+		// Create a matrix with just the requested rows of the original block
+		mb = new MatrixBlock(_numRows, colIndexes.size(), rawBlock.isInSparseFormat());
+
+		final int m = _numRows;
+		final int n = colIndexes.size();
+
+		if(transposed) {
+			// in the transposed raw block, logical row i is raw column i, so the
+			// accessor swaps indexes and row-wise factors are indexed by i
+			if(scaleFactors.length == 1) {
+				for(int i = 0; i < m; i++)
+					for(int j = 0; j < n; j++)
+						mb.appendValue(i, j, Math.floor(rawBlock.get(colIndexes.get(j), i) * scaleFactors[0]));
+			} else {
+				for(int i = 0; i < m; i++)
+					for(int j = 0; j < n; j++)
+						mb.appendValue(i, j, Math.floor(rawBlock.get(colIndexes.get(j), i) * scaleFactors[i]));
+			}
+		}
+		else {
+			if(scaleFactors.length == 1) {
+				for(int i = 0; i < m; i++)
+					for(int j = 0; j < n; j++)
+						mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0]));
+			} else {
+				for(int i = 0; i < m; i++)
+					for(int j = 0; j < n; j++)
+						mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[i]));
+			}
+		}
+
+		mb.recomputeNonZeros();
+		mb.examSparsity();
+
+		return create(mb, colIndexes);
+
+	}
+
 	/**
 	 * Main constructor for Uncompressed ColGroup.
 	 *
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java
index 6483eba1048..48dd245c4e6 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java
@@ -38,7 +38,7 @@ public ComEstExact(MatrixBlock data, CompressionSettings compSettings) {
 
 	@Override
 	public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) {
-		final IEncode map = EncodingFactory.createFromMatrixBlock(_data, _cs.transposed, colIndexes);
+		final IEncode map = EncodingFactory.createFromMatrixBlock(_data, _cs.transposed, colIndexes, _cs.scaleFactors);
 		if(map instanceof EmptyEncoding)
 			return new CompressedSizeInfoColGroup(colIndexes, getNumRows(), CompressionType.EMPTY);
 		return getFacts(map, colIndexes);
@@ -59,7 +59,7 @@ protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, Compress
 
 	protected CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex colIndexes) {
 		final int _numRows = getNumRows();
-		final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); 
+		final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs);
 		return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map);
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
index 97b451daeef..8a28a0ca49d 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
@@ -90,7 +90,7 @@ public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int esti
 			_data.getSparseBlock().isEmpty(colIndexes.get(0))))
 			return new CompressedSizeInfoColGroup(colIndexes, getNumRows(), CompressionType.EMPTY);
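Threading `_cs.scaleFactors` into the estimators matters for accuracy: quantization collapses nearby values, so distinct-value counts must be taken on quantized rather than raw data. A toy illustration:

```java
public class QuantizeDistinctDemo {
	public static void main(String[] args) {
		// Four raw values collapse to two quantized values under floor(x * 1.0),
		// so estimating column-group sizes on raw data would overestimate dictionaries.
		double[] col = {1.01, 1.02, 1.03, 2.5};
		java.util.Set<Double> distinct = new java.util.HashSet<>();
		for(double v : col)
			distinct.add(Math.floor(v * 1.0));
		System.out.println(distinct); // [1.0, 2.0]
	}
}
```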
- final IEncode map = EncodingFactory.createFromMatrixBlock(_sample, _transposed, colIndexes); + final IEncode map = EncodingFactory.createFromMatrixBlock(_sample, _transposed, colIndexes, _cs.scaleFactors); return extractInfo(map, colIndexes, maxDistinct); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java index 6cc882cc2f9..963a044d14f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java @@ -206,6 +206,10 @@ public double getTupleSparsity() { return _facts.tupleSparsity; } + public EstimationFactors getFacts() { + return _facts; + } + public IEncode getMap() { return _map; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java index 257ddf6f3c2..b196da658c3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java @@ -56,10 +56,24 @@ public interface EncodingFactory { public static IEncode createFromMatrixBlock(MatrixBlock m, boolean transposed, IColIndex rowCols) { if(m.isEmpty()) return new EmptyEncoding(); - else if(rowCols.size() == 1) - return createFromMatrixBlock(m, transposed, rowCols.get(0)); - else - return createWithReader(m, rowCols, transposed); + else if(rowCols.size() == 1) { + return createFromMatrixBlock(m, transposed, rowCols.get(0), null); + } + else { + return createWithReader(m, rowCols, transposed, null); + } + } + + public static IEncode createFromMatrixBlock(MatrixBlock m, boolean transposed, IColIndex rowCols, + double[] scaleFactors) { + if(m.isEmpty()) + return new EmptyEncoding(); + else if(rowCols.size() == 1) { + return createFromMatrixBlock(m, transposed, rowCols.get(0), scaleFactors); + } + else { + return createWithReader(m, rowCols, transposed, scaleFactors); + } } /** @@ -115,8 +129,54 @@ else if(transposed) { } else if(m.isInSparseFormat()) return createFromSparse(m, rowCol); - else + else { return createFromDense(m, rowCol); + } + } + + /** + * Create encoding of a single specific column inside the matrix input. + * + * @param m The Matrix to encode a column from + * @param transposed If the matrix is in transposed format. + * @param rowCol The column index to encode + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire + * matrix + * @return An encoded format of the information of this column. 
+ */ + public static IEncode createFromMatrixBlock(MatrixBlock m, boolean transposed, int rowCol, double[] scaleFactors) { + if(m.isEmpty()) + return new EmptyEncoding(); + else if(transposed) { + if(scaleFactors != null) { + if(m.isInSparseFormat()) + throw new NotImplementedException(); + else + return createFromDenseTransposedQuantized(m, rowCol, scaleFactors); + } + else { + if(m.isInSparseFormat()) + return createFromSparseTransposed(m, rowCol); + else + return createFromDenseTransposed(m, rowCol); + } + } + else if(m.isInSparseFormat()) { + if(scaleFactors != null) { + throw new NotImplementedException(); // TODO: handle quantization-fused compression + } + else { + return createFromSparse(m, rowCol); + } + } + else { + if(scaleFactors != null) { + return createFromDenseQuantized(m, rowCol, scaleFactors); + } + else { + return createFromDense(m, rowCol); + } + } } public static IEncode create(ColGroupConst c) { @@ -229,13 +289,13 @@ else if(alen - apos > nCol / 4) { // return a dense encoding // Iteration 3 of non zero indexes, make a Offset Encoding to know what cells are zero and not. // not done yet - try{ + try { final AOffset o = OffsetFactory.createOffset(aix, apos, alen); return new SparseEncoding(d, o, m.getNumColumns()); } - catch(Exception e){ - String mes = Arrays.toString(Arrays.copyOfRange(aix, apos, alen)) + "\n" + apos + " " + alen; + catch(Exception e) { + String mes = Arrays.toString(Arrays.copyOfRange(aix, apos, alen)) + "\n" + apos + " " + alen; mes += Arrays.toString(Arrays.copyOfRange(avals, apos, alen)); throw new DMLRuntimeException(mes, e); } @@ -341,8 +401,229 @@ private static IEncode createFromSparse(MatrixBlock m, int col) { return new SparseEncoding(d, o, m.getNumRows()); } - private static IEncode createWithReader(MatrixBlock m, IColIndex rowCols, boolean transposed) { - final ReaderColumnSelection reader1 = ReaderColumnSelection.createReader(m, rowCols, transposed); + private static IEncode createFromDenseTransposedQuantized(MatrixBlock m, int row, double[] scaleFactors) { + final DenseBlock db = m.getDenseBlock(); + if(!db.isContiguous()) + throw new NotImplementedException("Not Implemented non contiguous dense matrix encoding for sample"); + final DoubleCountHashMap map = new DoubleCountHashMap(); + final int off = db.pos(row); + final int nCol = m.getNumColumns(); + final int end = off + nCol; + final double[] vals = db.values(row); + + // Validate scaleFactors + final boolean useSingleScalar = scaleFactors != null && scaleFactors.length == 1; + + if(useSingleScalar) { + + // Iteration 1: Apply scaling & quantization, then populate the HashMap + for(int i = off; i < end; i++) // sequential access + map.increment(Math.floor(vals[i] * scaleFactors[0])); + + final int nUnique = map.size(); + + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + else if(nUnique == 0) + return new EmptyEncoding(); + else if(map.getOrDefault(0.0, -1) > nCol / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = nCol - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV); + + final AOffset o =
OffsetFactory.createOffset(offsets); + return new SparseEncoding(d, o, nCol); + } + else { + // Create output map + final AMapToData d = MapToFactory.create(nCol, nUnique); + + // Iteration 2, make final map + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + d.set(r, map.getId(value)); + } + + return new DenseEncoding(d); + } + } + else { + // Iteration 1: Apply scaling & quantization, then populate the HashMap + for(int i = off, r = 0; i < end; i++, r++) // sequential access + map.increment(Math.floor(vals[i] * scaleFactors[r])); + + final int nUnique = map.size(); + + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + else if(nUnique == 0) + return new EmptyEncoding(); + else if(map.getOrDefault(0.0, -1) > nCol / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = nCol - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV); + + final AOffset o = OffsetFactory.createOffset(offsets); + return new SparseEncoding(d, o, nCol); + } + else { + // Create output map + final AMapToData d = MapToFactory.create(nCol, nUnique); + + // Iteration 2, make final map + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + d.set(r, map.getId(value)); + } + + return new DenseEncoding(d); + } + + } + } + + private static IEncode createFromDenseQuantized(MatrixBlock m, int col, double[] scaleFactors) { + final DenseBlock db = m.getDenseBlock(); + if(!db.isContiguous()) + throw new NotImplementedException("Not Implemented non contiguous dense matrix encoding for sample"); + final DoubleCountHashMap map = new DoubleCountHashMap(16); + final int off = col; + final int nCol = m.getNumColumns(); + final int nRow = m.getNumRows(); + final int end = off + nRow * nCol; + final double[] vals = m.getDenseBlockValues(); + + // Validate scaleFactors + final boolean useSingleScalar = scaleFactors != null && scaleFactors.length == 1; + + if(useSingleScalar) { + // Iteration 1, make Count HashMap with quantized values + for(int i = off; i < end; i += nCol) { // jump down through rows.
+ map.increment(Math.floor(vals[i] * scaleFactors[0])); + } + final int nUnique = map.size(); + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + + if(map.getOrDefault(0.0, -1) > nRow / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = m.getNumRows() - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new DMLRuntimeException("Invalid number of zeros."); + + final AOffset o = OffsetFactory.createOffset(offsets); + + return new SparseEncoding(d, o, nRow); + } + else { + // Allocate counts, and iterate once to replace counts with unique IDs + + final AMapToData d = MapToFactory.create(nRow, nUnique); + // Iteration 2, make final map with quantized values + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + d.set(r, map.getId(value)); + } + return new DenseEncoding(d); + } + } + else { + // Iteration 1, make Count HashMap with row-wise quantized values + for(int i = off, r = 0; i < end; i += nCol, r++) { // jump down through rows. + map.increment(Math.floor(vals[i] * scaleFactors[r])); + } + final int nUnique = map.size(); + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + + if(map.getOrDefault(0.0, -1) > nRow / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = m.getNumRows() - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new DMLRuntimeException("Invalid number of zeros."); + + final AOffset o = OffsetFactory.createOffset(offsets); + + return new SparseEncoding(d, o, nRow); + } + else { + // Allocate counts, and iterate once to replace counts with unique IDs + + final AMapToData d = MapToFactory.create(nRow, nUnique); + // Iteration 2, make final map with row-wise quantized values + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + d.set(r, map.getId(value)); + } + return new DenseEncoding(d); + } + } + } + + private static IEncode createWithReader(MatrixBlock m, IColIndex rowCols, boolean transposed, + double[] scaleFactors) { + final ReaderColumnSelection reader1 = (scaleFactors == null) ? ReaderColumnSelection.createReader(m, rowCols, + transposed) : ReaderColumnSelection.createQuantizedReader(m, rowCols, transposed, scaleFactors); final int nRows = transposed ? m.getNumColumns() : m.getNumRows(); final DblArrayCountHashMap map = new DblArrayCountHashMap(); final IntArrayList offsets = new IntArrayList(); @@ -362,17 +643,18 @@ else if(map.size() == 1 && offsets.size() == nRows) if(offsets.size() < nRows / 4) // Output encoded sparse since there is very empty.
- return createWithReaderSparse(m, map, rowCols, offsets, nRows, transposed); + return createWithReaderSparse(m, map, rowCols, offsets, nRows, transposed, scaleFactors); else - return createWithReaderDense(m, map, rowCols, nRows, transposed, offsets.size() < nRows); + return createWithReaderDense(m, map, rowCols, nRows, transposed, offsets.size() < nRows, scaleFactors); } private static IEncode createWithReaderDense(MatrixBlock m, DblArrayCountHashMap map, IColIndex rowCols, int nRows, - boolean transposed, boolean zero) { + boolean transposed, boolean zero, double[] scaleFactors) { // Iteration 2, final int unique = map.size() + (zero ? 1 : 0); - final ReaderColumnSelection reader2 = ReaderColumnSelection.createReader(m, rowCols, transposed); + final ReaderColumnSelection reader2 = (scaleFactors == null) ? ReaderColumnSelection.createReader(m, rowCols, + transposed) : ReaderColumnSelection.createQuantizedReader(m, rowCols, transposed, scaleFactors); final AMapToData d = MapToFactory.create(nRows, unique); DblArray cellVals; @@ -387,8 +669,9 @@ private static IEncode createWithReaderDense(MatrixBlock m, DblArrayCountHashMap } private static IEncode createWithReaderSparse(MatrixBlock m, DblArrayCountHashMap map, IColIndex rowCols, - IntArrayList offsets, int nRows, boolean transposed) { + IntArrayList offsets, int nRows, boolean transposed, double[] scaleFactors) { - final ReaderColumnSelection reader2 = ReaderColumnSelection.createReader(m, rowCols, transposed); + final ReaderColumnSelection reader2 = (scaleFactors == null) ? ReaderColumnSelection.createReader(m, rowCols, + transposed) : ReaderColumnSelection.createQuantizedReader(m, rowCols, transposed, scaleFactors); DblArray cellVals = reader2.nextRow(); final AMapToData d = MapToFactory.create(offsets.size(), map.size()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java index e087525bbbd..d6ec60336f0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java +++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.compress.readers; +import org.apache.commons.lang3.NotImplementedException; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -83,7 +85,7 @@ public int getCurrentRowIndex() { } /** - * Create an reader of the matrix block that is able to iterate though all the rows and return as dense double + * Create a reader of the matrix block that is able to iterate through all the rows and return as dense double * arrays. * * Note the reader reuse the return, therefore if needed for something please copy the returned rows. * * @param rawBlock The block to iterate though * @param colIndices The column indexes to extract and insert into the double array * @param transposed If the raw block should be treated as transposed * @return A reader of the columns specified */ @@ -100,7 +102,30 @@ public static ReaderColumnSelection createReader(MatrixBlock rawBlock, IColIndex } /** - * Create an reader of the matrix block that is able to iterate though all the rows and return as dense double + * Create a reader of the matrix block that directly reads quantized values using scale factors. + * + * Note the reader reuses the returned array, therefore copy the returned rows if they are needed later.
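+ * + * An illustrative sketch of the intended use ({@code process} is a hypothetical consumer, not part of this patch): {@code ReaderColumnSelection r = ReaderColumnSelection.createQuantizedReader(mb, cols, false, sf);} followed by {@code for(DblArray row = r.nextRow(); row != null; row = r.nextRow()) process(row);}, copying the row values whenever they are needed beyond the current iteration.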
+ * + * @param rawBlock The block to iterate through + * @param colIndices The column indexes to extract and insert into the double array + * @param transposed If the raw block should be treated as transposed + * @param scaleFactors An array of scale factors applied. + * - If row-wise scaling is used, this should be an array where each value corresponds to a row. + * - If a single scalar is provided, it is applied uniformly to the entire matrix. + * @return A reader of the columns specified + */ + + public static ReaderColumnSelection createQuantizedReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed, double[] scaleFactors) { + if(transposed) { + throw new NotImplementedException(); + } + final int rl = 0; + final int ru = rawBlock.getNumRows(); + return createQuantizedReader(rawBlock, colIndices, transposed, rl, ru, scaleFactors); + } + + /** + * Create a reader of the matrix block that is able to iterate through all the rows and return as dense double * arrays. * * Note the reader reuse the return, therefore if needed for something please copy the returned rows. @@ -136,6 +161,40 @@ else if(rawBlock.getDenseBlock().numBlocks() > 1) return new ReaderColumnSelectionDenseSingleBlock(rawBlock, colIndices, rl, ru); } + /** + * Create a reader of the matrix block that directly reads quantized values using scale factors. + * + * Note the reader reuses the returned array, therefore copy the returned rows if they are needed later. + * + * @param rawBlock The block to iterate through + * @param colIndices The column indexes to extract and insert into the double array + * @param transposed If the raw block should be treated as transposed + * @param rl The row to start at + * @param ru The row to end at (not inclusive) + * @param scaleFactors An array of scale factors applied. + * - If row-wise scaling is used, this should be an array where each value corresponds to a row. + * - If a single scalar is provided, it is applied uniformly to the entire matrix. + * @return A reader of the columns specified + */ + public static ReaderColumnSelection createQuantizedReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed, + int rl, int ru, double[] scaleFactors) { + checkInput(rawBlock, colIndices, rl, ru, transposed); + rl = rl - 1; + if(rawBlock.isEmpty()) { + LOG.warn("It is likely an error occurred when reading an empty block, but we do support it!"); + return new ReaderColumnSelectionEmpty(rawBlock, colIndices, rl, ru, transposed); + } + else if(transposed) { + throw new NotImplementedException(); + } + else if(rawBlock.isInSparseFormat()) { + throw new NotImplementedException(); + } + else { + return new ReaderColumnSelectionDenseSingleBlockQuantized(rawBlock, colIndices, rl, ru, scaleFactors); + } + } + private static void checkInput(final MatrixBlock rawBlock, final IColIndex colIndices, final int rl, final int ru, final boolean transposed) { if(colIndices.size() <= 1) diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockQuantized.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockQuantized.java new file mode 100644 index 00000000000..645e694bb43 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockQuantized.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.readers; + +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.DblArray; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public class ReaderColumnSelectionDenseSingleBlockQuantized extends ReaderColumnSelection { + private final double[] _data; + private final int _numCols; + private final double[] _scaleFactors; + + protected ReaderColumnSelectionDenseSingleBlockQuantized(MatrixBlock data, IColIndex colIndices, int rl, int ru, + double[] scaleFactors) { + super(colIndices, rl, Math.min(ru, data.getNumRows()) - 1); + _data = data.getDenseBlockValues(); + _numCols = data.getNumColumns(); + _scaleFactors = scaleFactors; + } + + protected DblArray getNextRow() { + + _rl++; + final int indexOff = _rl * _numCols; + double scaleFactor = _scaleFactors.length == 1 ? _scaleFactors[0] : _scaleFactors[_rl]; + + for(int i = 0; i < _colIndexes.size(); i++) + reusableArr[i] = Math.floor(_data[indexOff + _colIndexes.get(i)] * scaleFactor); + + return reusableReturn; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index e30183a5067..8ab97689929 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -196,7 +196,11 @@ public static CPInstruction parseSingleInstruction ( CPType cptype, String str ) case DeCompression: return DeCompressionCPInstruction.parseInstruction(str); - + + case QuantizeCompression: + LOG.debug("Parsing Quantize Compress instruction"); + return CompressionCPInstruction.parseQuantizationFusedInstruction(str); + case Local: return LocalCPInstruction.parseInstruction(str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java index ccc97fd15c3..8b79ea70ef9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java @@ -65,8 +65,9 @@ public static Instruction[] parseMixedInstructions ( String str ) { return null; String[] strlist = str.split(Instruction.INSTRUCTION_DELIM); Instruction[] inst = new Instruction[strlist.length]; - for ( int i=0; i < inst.length; i++ ) + for ( int i=0; i < inst.length; i++ ) { inst[i] = parseSingleInstruction ( strlist[i] ); + } return inst; } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 98ebc83d7b3..ff01d2ff479 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -270,6 +270,7 @@ public static SPType getSPType(String str) { public static CPType getCPType(String str) { String opcode = getOpCode(str); + LOG.debug(opcode); return Opcodes.getCPTypeByOpcode(opcode); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index f8527276a7a..b0b502f8a09 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -46,6 +46,7 @@ public enum CPType { StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote, EvictLineageCache, NoOp, + QuantizeCompression } protected final CPType _cptype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index efc8e217771..52e0ad81540 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -44,6 +44,9 @@ public class CompressionCPInstruction extends ComputationCPInstruction { private final int _singletonLookupID; private final int _numThreads; + /** This is set to true only for quantization-fused compression */ + private final boolean _quantizationFused; + /** This is only for binned compression with 2 outputs */ protected final List<CPOperand> _outputs; @@ -53,6 +56,7 @@ private CompressionCPInstruction(Operator op, CPOperand in, CPOperand out, Strin _outputs = null; this._singletonLookupID = singletonLookupID; this._numThreads = numThreads; + this._quantizationFused = false; } private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand in2, List<CPOperand> out, String opcode, @@ -61,8 +65,18 @@ private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand in2, List _outputs = out; this._singletonLookupID = singletonLookupID; this._numThreads = numThreads; + this._quantizationFused = false; } + private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, + String istr, int singletonLookupID, int numThreads) { + super(CPType.QuantizeCompression, op, in1, in2, null, out, opcode, istr); + _outputs = null; + this._singletonLookupID = singletonLookupID; + this._numThreads = numThreads; + this._quantizationFused = true; + } + public static CompressionCPInstruction parseInstruction(String str) { InstructionUtils.checkNumFields(str, 3, 4, 5); String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); @@ -89,10 +103,23 @@ else if(parts.length == 5) { } } + public static CompressionCPInstruction parseQuantizationFusedInstruction(String str) { + InstructionUtils.checkNumFields(str, 3, 4, 5); + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + int numThreads = Integer.parseInt(parts[4]); + return new CompressionCPInstruction(null, in1, in2, out, opcode, str, 0, numThreads); + } + @Override public void processInstruction(ExecutionContext ec) { if(input2 == null) processSimpleCompressInstruction(ec); + else if(_quantizationFused)
processSimpleQuantizationFusedCompressInstruction(ec); else processCompressByBinInstruction(ec); } @@ -143,6 +170,28 @@ else if(ec.isMatrixObject(input1.getName())) } } + private void processSimpleQuantizationFusedCompressInstruction(ExecutionContext ec) { + final SingletonLookupHashMap m = SingletonLookupHashMap.getMap(); + + // Get and clear workload tree entry for this compression instruction. + final WTreeRoot root = (_singletonLookupID != 0) ? (WTreeRoot) m.get(_singletonLookupID) : null; + // We used to remove the key from the hash map, + // however this is not correct since the compression statement + // can be reused in multiple for loops. + + if(input2.isScalar()) + processMatrixBlockQuantizationFusedCompression(ec, ec.getMatrixInput(input1.getName()), ec.getScalarInput(input2), _numThreads, root); + else if(input2.isMatrix()) + processMatrixBlockQuantizationFusedCompression(ec, ec.getMatrixInput(input1.getName()), ec.getMatrixInput(input2.getName()), _numThreads, root); + } + private void processMatrixBlockCompression(ExecutionContext ec, MatrixBlock in, int k, WTreeRoot root) { Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in, k, root); if(LOG.isTraceEnabled()) @@ -161,4 +210,32 @@ private void processFrameBlockCompression(ExecutionContext ec, FrameBlock in, in ec.releaseFrameInput(input1.getName()); ec.setFrameOutput(output.getName(), compResult); } + + private void processMatrixBlockQuantizationFusedCompression(ExecutionContext ec, MatrixBlock in1, MatrixBlock in2, int k, WTreeRoot root) { + Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in1, in2, k, root); + if(LOG.isTraceEnabled()) + LOG.trace(compResult.getRight()); + MatrixBlock out = compResult.getLeft(); + if(LOG.isInfoEnabled()) + LOG.info("Compression output class: " + out.getClass().getSimpleName()); + // Set output and release input + ec.releaseMatrixInput(input1.getName()); + ec.releaseMatrixInput(input2.getName()); + ec.setMatrixOutput(output.getName(), out); + } + + private void processMatrixBlockQuantizationFusedCompression(ExecutionContext ec, MatrixBlock in1, ScalarObject in2, int k, WTreeRoot root) { + Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in1, in2, k, root); + if(LOG.isTraceEnabled()) + LOG.trace(compResult.getRight()); + MatrixBlock out = compResult.getLeft(); + if(LOG.isInfoEnabled()) + LOG.info("Compression output class: " + out.getClass().getSimpleName()); + // Set output and release input + ec.releaseMatrixInput(input1.getName()); + ec.setMatrixOutput(output.getName(), out); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/qcompress/CompareCompressionTypeTest.java b/src/test/java/org/apache/sysds/test/component/compress/qcompress/CompareCompressionTypeTest.java new file mode 100644 index 00000000000..270ce53acce --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/qcompress/CompareCompressionTypeTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.compress.qcompress; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.List; +import java.util.Random; + +import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.cocode.CoCoderFactory.PartitionerType; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.cost.ACostEstimate; +import org.apache.sysds.runtime.compress.estim.AComEst; +import org.apache.sysds.runtime.compress.estim.ComEstFactory; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class CompareCompressionTypeTest { + + /** + * Test 1: Compare the best compression types of two matrices, m0 and m1: DDC. + * + * - m0 is generated as a floored matrix. - m1 is generated as a full-precision matrix, but will be internally + * multiplied by 1.0 and floored. - Since m1 undergoes an equivalent transformation (scaling by 1.0 and flooring), + * the best compression types determined by the estimator should match elementwise for both matrices. - This + * validates that the estimator correctly handles explicit flooring vs. internal scaling and flooring during + * quantization-fused compression.
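+ * - The underlying identity is floor(x * 1.0) == floor(x) for every cell x, so the internal quantization path observes already-floored values, matching what explicit flooring produces.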
+ */ + @Test + public void testCompareBestCompressionTypeForTwoMatricesDDC() { + try { + Random r = new Random(1234); + int k = 4; + + // Generate first floored matrix and compute compression info + MatrixBlock m0 = generateTestMatrix(10000, 500, 1, 100, 1.0, r, true); + CompressionSettings cs0 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setSeed(1234).create(); + AComEst estimator0 = ComEstFactory.createEstimator(m0, cs0, k); + CompressedSizeInfo compressedGroups0 = estimator0.computeCompressedSizeInfos(k); + + // Generate second, full-precision matrix that will be internally scaled by 1.0 and floored, and + // compute compression info + MatrixBlock m1 = generateTestMatrix(10000, 500, 1, 100, 1.0, r, false); + double[] scaleFactor = {1.0}; + CompressionSettings cs1 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setScaleFactor(scaleFactor).setSeed(1234).create(); + AComEst estimator1 = ComEstFactory.createEstimator(m1, cs1, k); + CompressedSizeInfo compressedGroups1 = estimator1.computeCompressedSizeInfos(k); + + List<CompressedSizeInfoColGroup> groups0 = compressedGroups0.getInfo(); + List<CompressedSizeInfoColGroup> groups1 = compressedGroups1.getInfo(); + + assertEquals("Mismatch in number of compressed groups", groups0.size(), groups1.size()); + + for(int i = 0; i < groups0.size(); i++) { + assertEquals("Best compression type mismatch at index " + i, groups0.get(i).getBestCompressionType(), + groups1.get(i).getBestCompressionType()); + } + + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + /** + * Test 2: Compare the best compression types of two matrices, m0 and m1: CONST. + * + * - m0 is generated as a floored matrix. - m1 is generated as a full-precision matrix, but will be internally + * multiplied by 1.0 and floored. - Since m1 undergoes an equivalent transformation (scaling by 1.0 and flooring), + * the best compression types determined by the estimator should match elementwise for both matrices. - This + * validates that the estimator correctly handles explicit flooring vs. internal scaling and flooring during + * quantization-fused compression.
+ */ + @Test + public void testCompareBestCompressionTypeForTwoMatricesConst() { + try { + Random r = new Random(1234); + int k = 4; + + // Generate first floored matrix and compute compression info + MatrixBlock m0 = generateTestMatrix(10000, 500, 1, 1, 1.0, r, true); + CompressionSettings cs0 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setSeed(1234).create(); + AComEst estimator0 = ComEstFactory.createEstimator(m0, cs0, k); + CompressedSizeInfo compressedGroups0 = estimator0.computeCompressedSizeInfos(k); + + // Generate second, full-precision matrix that will be internally scaled by 1.0 and floored, and + // compute compression info + MatrixBlock m1 = generateTestMatrix(10000, 500, 1, 1, 1.0, r, false); + double[] scaleFactor = {1.0}; + CompressionSettings cs1 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setScaleFactor(scaleFactor).setSeed(1234).create(); + AComEst estimator1 = ComEstFactory.createEstimator(m1, cs1, k); + CompressedSizeInfo compressedGroups1 = estimator1.computeCompressedSizeInfos(k); + + List<CompressedSizeInfoColGroup> groups0 = compressedGroups0.getInfo(); + List<CompressedSizeInfoColGroup> groups1 = compressedGroups1.getInfo(); + + assertEquals("Mismatch in number of compressed groups", groups0.size(), groups1.size()); + + for(int i = 0; i < groups0.size(); i++) { + assertEquals("Best compression type mismatch at index " + i, groups0.get(i).getBestCompressionType(), + groups1.get(i).getBestCompressionType()); + } + + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + /** + * Generate a test matrix with specified dimensions, value range, and sparsity. + */ + private static MatrixBlock generateTestMatrix(int nRow, int nCol, int min, int max, double s, Random r, + boolean floored) { + final int m = Integer.MAX_VALUE; + MatrixBlock mb = TestUtils.generateTestMatrixBlock(nRow, nCol, min, max, s, r.nextInt(m)); + return floored ? TestUtils.floor(mb) : mb; + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedCompressionTest.java b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedCompressionTest.java new file mode 100644 index 00000000000..67aa86dfa92 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedCompressionTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysds.test.component.compress.qcompress; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.CompressionStatistics; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.apache.commons.lang3.tuple.Pair; + +/** + * This class tests the quantization-fused compression in SystemDS. + */ +public class QuantizationFusedCompressionTest { + + /** + * Test 1: Quantization-fused compression with a scalar scaling factor. + */ + @Test + public void testQuantizationCompressionWithScalar() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(4, 4, 1, 10, 1.0, 1234); + ScalarObject sf = new DoubleObject(2.5); + Pair<MatrixBlock, CompressionStatistics> result = CompressedMatrixBlockFactory.compress(mb, sf, 1, null); + MatrixBlock qmb = result.getLeft(); + for(int i = 0; i < mb.getNumRows(); i++) { + for(int j = 0; j < mb.getNumColumns(); j++) { + double expected = Math.floor(mb.get(i, j) * sf.getDoubleValue()); + assertEquals("Quantized compression mismatch!", expected, qmb.get(i, j), 0.0); + } + } + } + + /** + * Test 2: Quantization-fused compression with row-wise vector scaling. + */ + @Test + public void testQuantizationCompressionWithRowwiseVectorScale() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(5, 4, 1, 10, 1.0, 5678); + MatrixBlock sf = new MatrixBlock(5, 1, false); + sf.set(0, 0, 1.5); + sf.set(1, 0, 2.0); + sf.set(2, 0, 2.5); + sf.set(3, 0, 3.0); + sf.set(4, 0, 3.5); + Pair<MatrixBlock, CompressionStatistics> result = CompressedMatrixBlockFactory.compress(mb, sf, 1, null); + MatrixBlock qmb = result.getLeft(); + for(int i = 0; i < mb.getNumRows(); i++) { + for(int j = 0; j < mb.getNumColumns(); j++) { + double expected = Math.floor(mb.get(i, j) * sf.get(i, 0)); + assertEquals("Quantized compression mismatch!", expected, qmb.get(i, j), 0.0); + } + } + } + + /** + * Test 3: Compare compression statistics of two matrices, m0 and m1, where m0 is derived as m0 = floor(m1 * sf) + * with sf = 0.5. + * + * - Compression for m0 is aborted at phase 1 (before co-code). + * - Compression for m1 should also be aborted at the same phase. + * - The resulting compression statistics for both matrices should match.
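+ * - As a concrete check of the setup: m1 is filled row-major with 1..30, so its first row is {1, 2, 3, 4, 5}, and floor({1, 2, 3, 4, 5} * 0.5) = {0, 1, 1, 2, 2}, which is exactly the first row of values0.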
+ */ + @Test + public void testQuantizationFusedCompressionAbortedBeforeCoCodeStats() { + double[][] values0 = {{0, 1, 1, 2, 2}, {3, 3, 4, 4, 5}, {5, 6, 6, 7, 7}, {8, 8, 9, 9, 10}, {10, 11, 11, 12, 12}, + {13, 13, 14, 14, 15}}; + MatrixBlock m0 = DataConverter.convertToMatrixBlock(values0); + m0.recomputeNonZeros(); + + Pair<MatrixBlock, CompressionStatistics> cm0 = CompressedMatrixBlockFactory.compress(m0); + CompressionStatistics stats0 = cm0.getRight(); + + MatrixBlock m1 = new MatrixBlock(6, 5, false); + int val = 1; + for(int i = 0; i < 6; i++) { + for(int j = 0; j < 5; j++) { + m1.set(i, j, val++); + } + } + m1.recomputeNonZeros(); + + DoubleObject sf = new DoubleObject(0.5); + Pair<MatrixBlock, CompressionStatistics> cm1 = CompressedMatrixBlockFactory.compress(m1, sf, 1, null); + CompressionStatistics stats1 = cm1.getRight(); + + assertTrue("Compression statistics must match", stats0.toString().equals(stats1.toString())); + // Since m0 and m1 have different values their number of non-zero values is different + // assertEquals("Non-zero count should match", m0.getNonZeros(), m1.getNonZeros(), 0.1); + } + + /** + * Test 4: Compare compression statistics of two matrices, m0 and m1, where m0 is derived as m0 = floor(m1 * sf) + * with sf = 0.3. + * + * - Compression for m0 is aborted at phase 2 (after co-code). + * - Compression for m1 should also be aborted at the same phase. + * - The resulting compression statistics for both matrices should match. + */ + @Test + public void testQuantizationFusedCompressionAbortedAfterCoCodeStats() { + double[][] values1 = {{1, 8, 3, 4, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, + {3, 4, 5, 6, 7}}; + MatrixBlock m1 = DataConverter.convertToMatrixBlock(values1); + m1.recomputeNonZeros(); + + double scaleFactor = 0.3; + MatrixBlock m0 = new MatrixBlock(m1.getNumRows(), m1.getNumColumns(), false); + for(int i = 0; i < m1.getNumRows(); i++) { + for(int j = 0; j < m1.getNumColumns(); j++) { + m0.set(i, j, Math.floor(m1.get(i, j) * scaleFactor)); + } + } + m0.recomputeNonZeros(); + + Pair<MatrixBlock, CompressionStatistics> cm0 = CompressedMatrixBlockFactory.compress(m0); + CompressionStatistics stats0 = cm0.getRight(); + + DoubleObject sf = new DoubleObject(scaleFactor); + Pair<MatrixBlock, CompressionStatistics> cm1 = CompressedMatrixBlockFactory.compress(m1, sf, 1, null); + CompressionStatistics stats1 = cm1.getRight(); + + assertTrue("Compression statistics must match", stats0.toString().equals(stats1.toString())); + // Since m0 and m1 have different values their number of non-zero values is different + // assertEquals("Non-zero count should match", m0.getNonZeros(), m1.getNonZeros(), 0.1); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedForcedCompressionTypesTest.java b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedForcedCompressionTypesTest.java new file mode 100644 index 00000000000..60b8af359ad --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedForcedCompressionTypesTest.java @@ -0,0 +1,353 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information regarding copyright ownership. +* The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software distributed under the License +* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +* either express or implied. See the License for the specific language governing permissions and limitations +* under the License. +*/ + +package org.apache.sysds.test.component.compress.qcompress; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import java.util.*; +import org.apache.commons.lang3.tuple.Pair; + + +import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.cocode.CoCoderFactory; +import org.apache.sysds.runtime.compress.cocode.CoCoderFactory.PartitionerType; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE; +import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.cost.ACostEstimate; +import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory; +import org.apache.sysds.runtime.compress.estim.AComEst; +import org.apache.sysds.runtime.compress.estim.ComEstFactory; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class QuantizationFusedForcedCompressionTypesTest { + + private static final int K = 4; + private static final long SEED = 1234; + + /** + * Test 1: Test the Uncompressed column group by directly calling the create method. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be internally + * multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). The best compression types for both + * matrices are DDC, but we explicitly create UNCOMPRESSED columns. 
+ * + */ + @Test + public void testForcedUncompressed() { + try { + MatrixBlock m0 = generateTestMatrix(10000, 500, -100, 100, 1.0, SEED, true); + MatrixBlock m1 = generateTestMatrix(10000, 500, -100, 100, 1.0, SEED, false); + + CompressionSettings cs0 = createCompressionSettings(null); + CompressionSettings cs1 = createCompressionSettings(new double[] {1.0}); + + Pair<CompressedSizeInfo, AComEst> compressedGroupsResult0 = generateCompressedGroups(m0, cs0); + CompressedSizeInfo compressedGroups0 = compressedGroupsResult0.getLeft(); + + Pair<CompressedSizeInfo, AComEst> compressedGroupsResult1 = generateCompressedGroups(m1, cs1); + CompressedSizeInfo compressedGroups1 = compressedGroupsResult1.getLeft(); + + assertEquals("Mismatch in number of compressed groups", compressedGroups0.getInfo().size(), + compressedGroups1.getInfo().size(), 0.0); + + for(int i = 0; i < compressedGroups0.getInfo().size(); i++) { + AColGroup colGroup0 = ColGroupUncompressed.create(compressedGroups0.getInfo().get(i).getColumns(), m0, + cs0.transposed); + AColGroup colGroup1 = ColGroupUncompressed.createQuantized( + compressedGroups1.getInfo().get(i).getColumns(), m1, cs1.transposed, cs1.scaleFactors); + + assertEquals("Mismatch in column group sum", colGroup0.getSum(m0.getNumRows()), + colGroup1.getSum(m1.getNumRows()), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + /** + * Test 2: Test the RLE compression type by forcing RLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapSingleColumn(). + */ + @Test + public void testForcedRLETypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.RLE, ColGroupRLE.class); + } + + /** + * Test 3: Test the RLE compression type by forcing RLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapMultiColumns(). + * + */ + @Test + public void testForcedRLETypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.RLE, ColGroupRLE.class); + } + + /** + * Test 4: Test the OLE compression type by forcing OLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapSingleColumn(). + */ + @Test + public void testForcedOLETypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.OLE, ColGroupOLE.class); + } + + /** + * Test 5: Test the OLE compression type by forcing OLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapMultiColumn(). + */ + @Test + public void testForcedOLETypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.OLE, ColGroupOLE.class); + } + + /** + * Test 6: Test the SDC compression type by forcing SDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix.
m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapSingleColumn(). + * This should also cover CONST, EMPTY, SDCFOR. + */ + @Test + public void testForcedSDCTypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.SDC, ColGroupSDC.class); + } + + /** + * Test 7: Test the SDC compression type by forcing SDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapMultiColumn(). + * This should also cover CONST, EMPTY, SDCFOR. + */ + @Test + public void testForcedSDCTypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.SDC, ColGroupSDCSingle.class); + } + + /** + * Test 8: Test the DDC compression type by forcing DDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches directCompressDDCSingleCol(). + * This should also cover DDCFOR. + */ + @Test + public void testForcedDDCTypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.DDC, ColGroupDDC.class); + } + + /** + * Test 9: Test the DDC compression type by forcing DDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches directCompressDDCMultiCol(). + * This should also cover DDCFOR. + */ + @Test + public void testForcedDDCTypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.DDC, ColGroupDDC.class); + } + + /** + * Test the given compression type by forcing it in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). + * Reaches extractBitmapSingleColumn(). 
+ */ + private void testForcedCompressionTypeSingleColumn(CompressionType compressionType, Class<?> expectedGroupClass) { + try { + int nRow = 100; + int nCol = 1; + int max = 50; + int min = -50; + double s = 1.0; + + MatrixBlock m0 = generateTestMatrix(nRow, nCol, min, max, s, SEED, true); + MatrixBlock m1 = generateTestMatrix(nRow, nCol, min, max, s, SEED, false); + + CompressionSettings cs0 = createCompressionSettings(null); + CompressionSettings cs1 = createCompressionSettings(new double[]{1.0}); + + List<AColGroup> results0 = compressWithForcedTypeNoCoCode(m0, cs0, compressionType); + List<AColGroup> results1 = compressWithForcedTypeNoCoCode(m1, cs1, compressionType); + + assertEquals("Mismatch in number of resulting column groups", results0.size(), results1.size(), 0.0); + + for (int i = 0; i < results0.size(); i++) { + assertInstanceOf(expectedGroupClass, results0.get(i), "Mismatch in forced compression type"); + assertInstanceOf(expectedGroupClass, results1.get(i), "Mismatch in forced compression type"); + + assertEquals("Mismatch in sum of values in column group", + results0.get(i).getSum(nRow), results1.get(i).getSum(nRow), 0.0); + } + } catch (Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + private void testForcedCompressionTypeMultiColumn(CompressionType compressionType, Class<?> expectedGroupClass) { + try { + double[][] values = { + {1.5, 2.5, 3.5, 4.5, 5.5}, + {1.5, 2.5, 3.5, 4.5, 5.5}, + {1.5, 2.5, 3.5, 4.5, 5.5}, + {2.5, 3.5, 4.5, 5.5, 6.5}, + {2.5, 3.5, 4.5, 5.5, 6.5}, + {2.5, 3.5, 4.5, 5.5, 6.5}, + }; + + int nRow = values.length; + + MatrixBlock m0 = DataConverter.convertToMatrixBlock(values); + m0 = TestUtils.floor(m0); + m0.recomputeNonZeros(); + + MatrixBlock m1 = DataConverter.convertToMatrixBlock(values); + + CompressionSettings cs0 = createCompressionSettings(null); + CompressionSettings cs1 = createCompressionSettings(new double[] {1.0}); + + List<AColGroup> results0 = compressWithForcedTypeCoCode(m0, cs0, compressionType); + List<AColGroup> results1 = compressWithForcedTypeCoCode(m1, cs1, compressionType); + + assertEquals("Mismatch in number of resulting column groups", results0.size(), results1.size(), 0.0); + + for(int i = 0; i < results0.size(); i++) { + assertInstanceOf(expectedGroupClass, results0.get(i), "Mismatch in forced compression type"); + assertInstanceOf(expectedGroupClass, results1.get(i), "Mismatch in forced compression type"); + assertEquals("Mismatch in sum of values in column group", results0.get(i).getSum(nRow), + results1.get(i).getSum(nRow), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + private static void assertInstanceOf(Class<?> expected, Object obj, String message) { + if (!expected.isInstance(obj)) { + fail(message + ": Expected " + expected.getSimpleName() + ", but got " + obj.getClass().getSimpleName()); + } + } + + /** + * Generate compressed groups with an estimator. + */ + private static Pair<CompressedSizeInfo, AComEst> generateCompressedGroups(MatrixBlock matrix, CompressionSettings cs) { + AComEst estimator = ComEstFactory.createEstimator(matrix, cs, K); + CompressedSizeInfo sizeInfo = estimator.computeCompressedSizeInfos(K); + return Pair.of(sizeInfo, estimator); + } + + /** + * Force a specific compression type (e.g., RLE) on a set of compressed groups.
+ */ + private static List<AColGroup> compressWithForcedTypeNoCoCode(MatrixBlock matrix, CompressionSettings cs, + CompressionType type) { + Pair<CompressedSizeInfo, AComEst> result = generateCompressedGroups(matrix, cs); + CompressedSizeInfo originalGroups = result.getLeft(); + List<CompressedSizeInfoColGroup> modifiedGroups = forceCompressionType(originalGroups, type); + CompressedSizeInfo compressedGroupsNew = new CompressedSizeInfo(modifiedGroups); + return ColGroupFactory.compressColGroups(matrix, compressedGroupsNew, cs, K); + } + + /** + * Force a specific compression type (e.g., RLE) on a set of compressed groups with CoCode. + */ + private static List<AColGroup> compressWithForcedTypeCoCode(MatrixBlock matrix, CompressionSettings cs, + CompressionType type) { + Pair<CompressedSizeInfo, AComEst> result = generateCompressedGroups(matrix, cs); + CompressedSizeInfo originalGroups = result.getLeft(); + AComEst estimator = result.getRight(); + ACostEstimate ice = CostEstimatorFactory.create(cs, null, matrix.getNumRows(), matrix.getNumColumns(), matrix.getSparsity()); + originalGroups = CoCoderFactory.findCoCodesByPartitioning(estimator, originalGroups, K, ice, cs); + List<CompressedSizeInfoColGroup> modifiedGroups = forceCompressionType(originalGroups, type); + CompressedSizeInfo compressedGroupsNew = new CompressedSizeInfo(modifiedGroups); + return ColGroupFactory.compressColGroups(matrix, compressedGroupsNew, cs, K); + } + + /** + * Modify the compression type of each group to a specific type. + */ + private static List<CompressedSizeInfoColGroup> forceCompressionType(CompressedSizeInfo originalGroups, + CompressionType type) { + List<CompressedSizeInfoColGroup> modifiedGroups = new ArrayList<>(); + for(CompressedSizeInfoColGroup cg : originalGroups.getInfo()) { + Set<CompressionType> compressionTypes = new HashSet<>(); + compressionTypes.add(type); + modifiedGroups + .add(new CompressedSizeInfoColGroup(cg.getColumns(), cg.getFacts(), compressionTypes, cg.getMap())); + } + return modifiedGroups; + } + + /** + * Generate a test matrix with specified dimensions, value range, and sparsity. + */ + private static MatrixBlock generateTestMatrix(int nRow, int nCol, int min, int max, double s, long seed, + boolean floored) { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(nRow, nCol, min, max, s, seed); + return floored ? TestUtils.floor(mb) : mb; + } + + /** + * Create compression settings with an optional scale factor.
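+ * + * An illustrative call (the factor value is an example): {@code createCompressionSettings(new double[] {0.5})} builds settings whose {@code scaleFactors} enable the quantization-fused code paths, while {@code createCompressionSettings(null)} leaves quantization disabled.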
+ */ + private static CompressionSettings createCompressionSettings(double[] scaleFactor) { + CompressionSettingsBuilder builder = new CompressionSettingsBuilder(); + // .setColumnPartitioner(PartitionerType.GREEDY).setSeed((int) SEED); + if(scaleFactor != null) { + builder.setScaleFactor(scaleFactor); + } + return builder.create(); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java index b4054169afb..ae92d3a4313 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java @@ -19,7 +19,9 @@ package org.apache.sysds.test.component.compress.readers; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -31,6 +33,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.test.TestUtils; import org.junit.Test; +import java.util.Arrays; public class ReadersTest { @@ -83,4 +86,45 @@ public void testInvalidRange_02() { mb.allocateDenseBlock(); ReaderColumnSelection.createReader(mb, ColIndexFactory.create(2), false, 10, 9); } + + @Test + public void testReaderColumnSelectionQuantized() { + + // 4.0 0.0 + // 3.0 0.0 + // 0.0 5.0 + + MatrixBlock mb = new MatrixBlock(3, 2, false); + mb.allocateDenseBlock(); + mb.set(0, 0, 4); + mb.set(1, 0, 3); + mb.set(2, 1, 5); + + double[][] scaleFactorCases = { + {0.3}, // Scalar case + {0.3, 0.4, 0.5} // Per-row scale factor + }; + + for (double[] scaleFactors : scaleFactorCases) { + ReaderColumnSelection r = ReaderColumnSelection.createQuantizedReader( + mb, ColIndexFactory.create(2), false, scaleFactors); + + double[][] expectedValues = new double[3][2]; + for (int i = 0; i < 3; i++) { + double sf = scaleFactors.length > 1 ? scaleFactors[i] : scaleFactors[0]; + expectedValues[i][0] = Math.floor(mb.get(i, 0) * sf); + expectedValues[i][1] = Math.floor(mb.get(i, 1) * sf); + } + + DblArray d; + int rowIndex = 0; + while ((d = r.nextRow()) != null) { + assertNotNull("Row " + rowIndex + " should not be null", d); + assertArrayEquals("Mismatch for scaleFactors " + Arrays.toString(scaleFactors), + expectedValues[rowIndex], d.getData(), 0.0); + rowIndex++; + } + } + } + + } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteQuantizationFusedCompressionTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteQuantizationFusedCompressionTest.java new file mode 100644 index 00000000000..3a9dfa48dda --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteQuantizationFusedCompressionTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.Arrays;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for the rewrite that replaces the sequence X = floor(M * sf); Y = compress(X)
+ * with a fused quantize_compress(M, sf) operation.
+ */
+public class RewriteQuantizationFusedCompressionTest extends AutomatedTestBase {
+	private static final String TEST_NAME1 = "RewriteQuantizationFusedCompressionScalar";
+	private static final String TEST_NAME2 = "RewriteQuantizationFusedCompressionMatrix";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR
+		+ RewriteQuantizationFusedCompressionTest.class.getSimpleName() + "/";
+
+	private static final int rows = 500;
+	private static final int cols = 500;
+	private static final double sfValue = 0.5; // fills the scale-factor matrix, or serves as the standalone scalar
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}));
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompressionScalar() {
+		testRewriteQuantizationFusedCompression(TEST_NAME1, true, true);
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompressionNoRewriteScalar() {
+		testRewriteQuantizationFusedCompression(TEST_NAME1, false, true);
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompression() {
+		testRewriteQuantizationFusedCompression(TEST_NAME2, true, false);
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompressionNoRewrite() {
+		testRewriteQuantizationFusedCompression(TEST_NAME2, false, false);
+	}
+
+	/**
+	 * Unified method to test both scalar and matrix scale factors.
+	 *
+	 * @param testname Test name
+	 * @param rewrites Whether to enable the fusion rewrite
+	 * @param isScalar Whether the scale factor is a scalar or a matrix
+	 */
+	private void testRewriteQuantizationFusedCompression(String testname, boolean rewrites, boolean isScalar) {
+		boolean oldRewriteFlag = OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE;
+		OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE = rewrites;
+
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+
+			double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
+			writeInputMatrixWithMTD("A", A, 174522, false);
+
+			if(isScalar) {
+				// scalar case: pass sfValue as a string argument
+				String s = Double.toString(sfValue);
+				programArgs = new String[] {"-stats", "-args", input("A"), s, output("R")};
+			}
+			else {
+				// matrix case: pass S as a separate column vector
+				double[][] S = new double[rows][1];
+				for(int i = 0; i < rows; i++) {
+					S[i][0] = sfValue;
+				}
+				programArgs = new String[] {"-stats", "-args", input("A"), input("S"), output("R")};
+				writeInputMatrixWithMTD("S", S, 500, false);
+			}
+
+			runTest(true, false, null, -1);
+
+			// Check that quantization indeed occurred by computing the expected sum.
+			// Even if compression is aborted, the quantization step should still take effect.
+			double expectedR = Arrays.stream(A).flatMapToDouble(Arrays::stream).map(x -> Math.floor(x * sfValue)).sum();
+			double actualR = TestUtils.readDMLScalar(output("R"));
+
+			Assert.assertEquals("Mismatch in expected sum after quantization and compression", expectedR, actualR, 0.0);
+
+			// check whether fusion occurred
+			if(rewrites) {
+				Assert.assertEquals("Expected fused operation count mismatch", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.QUANTIZE_COMPRESS.toString()));
+				Assert.assertEquals("Expected no separate floor op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.FLOOR.toString()));
+				Assert.assertEquals("Expected no separate compress op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.COMPRESS.toString()));
+				Assert.assertEquals("Expected no separate multiplication op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.MULT.toString()));
+			}
+			else {
+				Assert.assertEquals("Expected no fused op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.QUANTIZE_COMPRESS.toString()));
+				Assert.assertEquals("Expected separate floor op", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.FLOOR.toString()));
+				Assert.assertEquals("Expected separate compress op", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.COMPRESS.toString()));
+				Assert.assertEquals("Expected separate multiplication op", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.MULT.toString()));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE = oldRewriteFlag;
+		}
+	}
+}
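As a usage note, the test class can be run in isolation with the standard Maven Surefire filter (a typical invocation, nothing specific to this patch):

    mvn test -Dtest=RewriteQuantizationFusedCompressionTest

The -stats flag passed via programArgs is what populates the heavy-hitter counters queried through Statistics.getCPHeavyHitterCount in the assertions above.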
diff --git a/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionMatrix.dml b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionMatrix.dml
new file mode 100644
index 00000000000..38c6d712c6b
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionMatrix.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Load matrix A
+A = read($1);
+
+# Load the vector/matrix scale factor S
+S = read($2);
+
+# Quantize
+B = floor(A * S);
+
+# Compress
+C = compress(B);
+
+# Write the sum, as writing a compressed matrix is complicated
+R = sum(C);
+write(R, $3);
diff --git a/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionScalar.dml b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionScalar.dml
new file mode 100644
index 00000000000..c6ccfb2a214
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionScalar.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Load matrix A
+A = read($1);
+
+# Load the scalar scale factor s
+s = as.double($2);
+
+# Quantize
+B = floor(A * s);
+
+# Compress
+C = compress(B);
+
+# Write the sum, as writing a compressed matrix is complicated
+R = sum(C);
+write(R, $3);
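The rewrite is transparent to scripts: both test scripts keep the plain floor/compress pattern and rely on the optimizer to fuse it. For completeness, a direct invocation of the fused builtin is sketched below; it assumes the script-level quantize_compress takes the matrix and scale factor and, like compress, is multi-return, yielding the compressed matrix plus metadata. The exact return signature is an assumption here, not confirmed by this patch.

    A = read($1);
    s = as.double($2);
    # hypothetical direct call to the fused builtin
    [C, meta] = quantize_compress(A, s);
    R = sum(C);
    write(R, $3);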