diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 423679d038c..67f81c556b1 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -153,6 +153,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 28c5a7a6a8e..01a3d09ff3c 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -194,6 +194,8 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), + GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary), + //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index e69ad375b20..2a5a1365616 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -635,6 +635,7 @@ public enum OpOp2 { MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y + GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false); //quantization-fused compression private final boolean _validOuter; diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index a3ddb45ea6d..73e3c5fac86 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -763,8 +763,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { setExecType(_etypeForced); @@ -812,18 +812,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && @@ -853,7 +863,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + + if( op == OpOp2.GET_CATEGORICAL_MASK) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index b32a1a74aab..4a842c69b0f 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1040,6 +1040,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1624,6 +1630,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. * diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 34da36dd13c..e16896b869b 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -463,6 +467,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. + && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index ae582b052b2..25c03ea087e 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2013,6 +2013,11 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV else raiseValidateError("The compress or decompress instruction is not allowed in dml scripts"); break; + case GET_CATEGORICAL_MASK: + checkNumParameters(2); + checkFrameParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + break; case QUANTIZE_COMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { checkNumParameters(2); @@ -2333,6 +2338,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS); } } + + protected void checkFrameParam(Expression e) { + if(e.getOutput().getDataType() != DataType.FRAME) { + raiseValidateError("Expecting frame parameter for function " + getOpCode(), false, + LanguageErrorCodes.UNSUPPORTED_PARAMETERS); + } + } protected void checkMatrixScalarParam(Expression e) { //always unconditional if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c4f7f672abe..881f129b3e1 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2817,6 +2817,9 @@ else if ( in.length == 2 ) DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; + case GET_CATEGORICAL_MASK: + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOp2.GET_CATEGORICAL_MASK, expr, expr2); + break; default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 48637595741..90fc0c0f736 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -58,12 +58,13 @@ import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; -import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReorg; +import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReshape; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; +import org.apache.sysds.runtime.compress.lib.CLALibSort; import org.apache.sysds.runtime.compress.lib.CLALibSquash; import org.apache.sysds.runtime.compress.lib.CLALibTSMM; import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp; @@ -101,6 +102,7 @@ import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.utils.DMLCompressionStatistics; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; +import org.apache.sysds.utils.stats.Timing; public class CompressedMatrixBlock extends MatrixBlock { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName()); @@ -475,16 +477,20 @@ public void readFields(DataInput in) throws IOException { } public static CompressedMatrixBlock read(DataInput in) throws IOException { + Timing t = new Timing(); int rlen = in.readInt(); int clen = in.readInt(); long nonZeros = in.readLong(); boolean overlappingColGroups = in.readBoolean(); List groups = ColGroupIO.readGroups(in, rlen); - return new CompressedMatrixBlock(rlen, clen, nonZeros, overlappingColGroups, groups); + CompressedMatrixBlock ret = new CompressedMatrixBlock(rlen, clen, nonZeros, overlappingColGroups, groups); + LOG.debug("Compressed read serialization time: " + t.stop()); + return ret; } @Override public void write(DataOutput out) throws IOException { + Timing t = new Timing(); final long estimateUncompressed = nonZeros > 0 ? MatrixBlock.estimateSizeOnDisk(rlen, clen, nonZeros) : Long.MAX_VALUE; final long estDisk = nonZeros > 0 ? getExactSizeOnDisk() : Long.MAX_VALUE; @@ -512,6 +518,7 @@ public void write(DataOutput out) throws IOException { out.writeLong(nonZeros); out.writeBoolean(overlappingColGroups); ColGroupIO.writeGroups(out, _colGroups); + LOG.debug("Compressed write serialization time: " + t.stop()); } /** @@ -611,14 +618,6 @@ public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixVal public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType tstype, int k) { // check for transpose type if(tstype == MMTSJType.LEFT) { - if(isEmpty()) - return new MatrixBlock(clen, clen, true); - // create output matrix block - if(out == null) - out = new MatrixBlock(clen, clen, false); - else - out.reset(clen, clen, false); - out.allocateDenseBlock(); CLALibTSMM.leftMultByTransposeSelf(this, out, k); return out; } @@ -846,9 +845,8 @@ public CM_COV_Object covOperations(COVOperator op, MatrixBlock that, MatrixBlock } @Override - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { - MatrixBlock right = getUncompressed(weights); - return getUncompressed("sortOperations").sortOperations(right, result); + public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result, int k) { + return CLALibSort.sort(this, weights, result, k); } @Override @@ -1202,8 +1200,8 @@ public void examSparsity(boolean allowCSR, int k) { } @Override - public void sparseToDense(int k) { - // do nothing + public MatrixBlock sparseToDense(int k) { + return this; // do nothing } @Override @@ -1236,16 +1234,6 @@ public double interQuartileMean() { return getUncompressed("interQuartileMean").interQuartileMean(); } - @Override - public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { - return getUncompressed("pickValues").pickValues(quantiles, ret); - } - - @Override - public double pickValue(double quantile, boolean average) { - return getUncompressed("pickValue").pickValue(quantile, average); - } - @Override public double sumWeightForQuantile() { return getUncompressed("sumWeightForQuantile").sumWeightForQuantile(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 4c48effb4df..f082d1ffc3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -64,6 +64,8 @@ public class CompressedMatrixBlockFactory { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName()); + private static final Object asyncCompressLock = new Object(); + /** Timing object to measure the time of each phase in the compression */ private final Timing time = new Timing(true); /** Compression statistics gathered throughout the compression */ @@ -181,21 +183,23 @@ public static Future compressAsync(ExecutionContext ec, String varName) { } public static Future compressAsync(ExecutionContext ec, String varName, InstructionTypeCounter ins) { - LOG.debug("Compressing Async"); final ExecutorService pool = CommonThreadPool.get(); // We have to guarantee that a thread pool is allocated. return CompletableFuture.runAsync(() -> { // method call or code to be async try { CacheableData data = ec.getCacheableData(varName); - if(data instanceof MatrixObject) { - MatrixObject mo = (MatrixObject) data; - MatrixBlock mb = mo.acquireReadAndRelease(); - MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mo.acquireReadAndRelease(), ins).getLeft(); - if(mbc instanceof CompressedMatrixBlock) { - ExecutionContext.createCacheableData(mb); - mo.acquireModify(mbc); - mo.release(); - mbc.sum(); // calculate sum to forcefully materialize counts + synchronized(asyncCompressLock){ // synchronize on the data object to not allow multiple compressions of the same matrix. + if(data instanceof MatrixObject) { + LOG.debug("Compressing Async"); + MatrixObject mo = (MatrixObject) data; + MatrixBlock mb = mo.acquireReadAndRelease(); + MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mb, ins).getLeft(); + if(mbc instanceof CompressedMatrixBlock) { + ExecutionContext.createCacheableData(mb); + mo.acquireModify(mbc); + mo.release(); + mbc.sum(); // calculate sum to forcefully materialize counts + } } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index ec502d6d122..07c17c30893 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -401,8 +401,9 @@ public final AColGroup rightMultByMatrix(MatrixBlock right) { * @param cru The right hand side column upper * @param nRows The number of rows in this column group */ - public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru){ - throw new NotImplementedException("not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); + public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) { + throw new NotImplementedException( + "not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); } /** @@ -806,7 +807,7 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo else denseSelection(selection, points, ret, rl, ru); } - + /** * Get an approximate sparsity of this column group * @@ -972,6 +973,15 @@ public AColGroup[] splitReshapePushDown(final int multiplier, final int nRow, fi return splitReshape(multiplier, nRow, nColOrg); } + /** + * Sort the values of the column group according to double < > operations and return as another compressed group. + * + * This sorting assumes that the column group is sorted independently of everything else. + * + * @return The sorted group + */ + public abstract AColGroup sort(); + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -981,4 +991,5 @@ public String toString() { sb.append(_colIndexes); return sb.toString(); } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java index 0cde289b30f..4f53d8b912b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java @@ -59,8 +59,6 @@ public int getNumValues() { * produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct * tuples in the dictionary. * - * The returned counts always contains the number of zero tuples as well if there are some contained, even if they - * are not materialized. * * @return The count of each value in the MatrixBlock. */ @@ -212,6 +210,7 @@ public void clear() { counts = null; } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java index 8f2f0b46055..d114f029df8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java @@ -402,4 +402,5 @@ protected IDictionary combineDictionaries(int nCol, List right) { public double getSparsity() { return _dict.getSparsity(); } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 3de98a1c23f..30de5e120c5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -203,6 +203,22 @@ private final void leftMultByMatrixNoPreAggRowsDense(MatrixBlock mb, double[] re */ protected abstract void multiplyScalar(double v, double[] resV, int offRet, AIterator it); + public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC, AIterator it) { + if(_dict instanceof MatrixBlockDictionary) { + final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; + final MatrixBlock mb = md.getMatrixBlock(); + // The dictionary is never empty. + if(mb.isInSparseFormat()) + // TODO make sparse decompression where the iterator is known in argument + decompressToSparseBlockSparseDictionary(sb, rl, ru, offR, offC, mb.getSparseBlock()); + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, mb.getDenseBlockValues(), + it); + } + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, _dict.getValues(), it); + } + public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC, AIterator it) { if(_dict instanceof MatrixBlockDictionary) { final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; @@ -223,6 +239,9 @@ public void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, _dict.getValues(), it); } + public abstract void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock db, int rl, int ru, + int offR, int offC, double[] values, AIterator it); + public abstract void decompressToDenseBlockDenseDictionaryWithProvidedIterator(DenseBlock db, int rl, int ru, int offR, int offC, double[] values, AIterator it); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 21c6a0e1d80..de107850693 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -758,4 +758,9 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) protected boolean allowShallowIdentityRightMult() { return true; } + + @Override + public AColGroup sort(){ + return this; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index fc82c58e16b..5995e5f17ad 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -26,8 +26,6 @@ import java.util.List; import java.util.concurrent.ExecutorService; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -67,6 +65,9 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.jboss.netty.handler.codec.compression.CompressionException; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + /** * Class to encapsulate information about a column group that is encoded with dense dictionary encoding (DDC). */ @@ -1091,6 +1092,27 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E return res; } + @Override + public AColGroup sort() { + // TODO restore support for run length encoding to exploit the runs + + int[] counts = getCounts(); + // get the sort index + int[] r = _dict.sort(); + + AMapToData m = MapToFactory.create(_data.size(), counts.length); + int off = 0; + for(int i = 0; i < counts.length; i++) { + for(int j = 0; j < counts[r[i]]; j++) { + m.set(off++, r[i]); + } + } + + return ColGroupDDC.create(_colIndexes, _dict, m, counts); + + } + + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -1105,4 +1127,6 @@ protected boolean allowShallowIdentityRightMult() { return true; } + + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index 70191a27936..84571ff9639 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -546,6 +546,27 @@ protected boolean allowShallowIdentityRightMult() { return false; } + + @Override + public AColGroup sort() { + // TODO restore support for run length encoding. + + int[] counts = getCounts(); + // get the sort index + int[] r = _dict.sort(); + + AMapToData m = MapToFactory.create(_data.size(), counts.length); + int off = 0; + for(int i = 0; i < counts.length; i++) { + for(int j = 0; j < counts[r[i]]; j++) { + m.set(off++, r[i]); + } + } + + return ColGroupDDCFOR.create(_colIndexes, _dict, m, counts, _reference); + + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index ba547a8d7aa..0c018293b7e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -476,4 +476,9 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) return new ColGroupEmpty(combinedIndex); } + + @Override + public AColGroup sort(){ + return this; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java index 91442281317..1091ae36890 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java @@ -94,9 +94,7 @@ public static long getExactSizeOnDisk(List colGroups) { } ret += grp.getExactSizeOnDisk(); } - if(LOG.isWarnEnabled()) - LOG.warn(" duplicate dicts on exact Size on Disk : " + (colGroups.size() - dicts.size()) ); - + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index 45b4fbeb026..cf984ce3d21 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -740,4 +740,8 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort(){ + throw new NotImplementedException("Unimplemented method 'sort'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index ea6d0f34c2a..591d795a20f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -26,10 +26,10 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.bitmap.ABitmap; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; @@ -731,5 +731,9 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 2b4b23792e3..59e421400e0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -1190,4 +1190,8 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 1270823bfdc..7e3e0cc4dce 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -508,10 +508,10 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in AOffset indexes, AMapToData data, int[] counts, int def, int nVal) { if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -873,6 +873,52 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + final double def = _defaultTuple[0]; + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDC.create(_colIndexes, _numRows, _dict, _defaultTuple, o, m, counts); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 41fb7ac5709..62bf44c3868 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -620,6 +620,51 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if( _dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCFOR.create(_colIndexes, _numRows, _dict, o, m, counts, _reference); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index fa5772c0c3e..210a9ff35d4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -717,6 +717,50 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { } return res; } + + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + final double def = _defaultTuple[0]; + int defIdx = counts.length; + int nondefault = 0; + for(int i = 0; i < r.length; i++) { + if(defIdx == counts.length && _dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + } + nondefault += counts[i]; + } + + int defaultLength = _numRows - nondefault; + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingle.create(_colIndexes, _numRows, _dict, _defaultTuple, o, counts); + } @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 9efd0c41098..6428c447e97 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -109,10 +109,8 @@ protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int return; else if(it.value() >= ru) return; - // _indexes.cacheIterator(it, ru); else { decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, values, it); - // _indexes.cacheIterator(it, ru); } } @@ -238,8 +236,10 @@ protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > last) { + return; + // _indexes.cacheIterator(it, ru); + else + if(ru > last) { final int apos = sb.pos(0); final int alen = sb.size(0) + apos; final int[] aix = sb.indexes(0); @@ -277,8 +277,14 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int nCol = _colIndexes.size(); final int lastOff = _indexes.getOffsetToLast(); int row = offR + it.value(); @@ -1043,6 +1049,49 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + int nondefault = 0; + for(int i = 0; i < r.length; i++) { + if(defIdx == counts.length && _dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + } + nondefault += counts[i]; + } + + int defaultLength = _numRows - nondefault; + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingleZeros.create(_colIndexes, _numRows, _dict, o, counts); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 69e0f776383..c5baae66047 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -184,8 +184,7 @@ private final void decompressToDenseBlockDenseDictionaryPostAllCols(DenseBlock d final double[] c = db.values(idx); final int off = db.pos(idx); final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); if(it.value() == lastOff) return; it.next(); @@ -301,13 +300,19 @@ private void decompressToDenseBlockDenseDictionaryPreAllCols(DenseBlock db, int final double[] c = db.values(idx); final int off = db.pos(idx) + offC; final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); it.next(); } } + private static void decompressSingleRow(double[] values, final int nCol, final double[] c, final int off, + final int offDict) { + final int end = nCol + off; + for(int j = off, k = offDict; j < end; j++, k++) + c[j] += values[k]; + } + @Override protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) { @@ -438,8 +443,16 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int lastOff = _indexes.getOffsetToLast(); final int nCol = _colIndexes.size(); while(true) { @@ -467,7 +480,6 @@ else if(ru > _indexes.getOffsetToLast()) { } _indexes.cacheIterator(it, ru); } - } @Override @@ -899,7 +911,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return super.morph(ct, nRow); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { final SparseBlock sr = ret.getSparseBlock(); @@ -942,14 +953,14 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret of = it.next(); } else if(points[c].o < of) - c++; + c++; else of = it.next(); - } - // increment the c pointer until it is pointing at least to last point or is done. - while(c < points.length && points[c].o < last) - c++; - c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); + } + // increment the c pointer until it is pointing at least to last point or is done. + while(c < points.length && points[c].o < last) + c++; + c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); } private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { @@ -1078,6 +1089,54 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, o, m, counts); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 1c3bce2e16c..f498b652149 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -1294,6 +1294,11 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { // throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort(){ + return new ColGroupUncompressed(_data.sortOperations(), _colIndexes); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 31e29341645..17814419549 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.compress.colgroup; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; @@ -282,4 +283,9 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new UnsupportedOperationException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort(){ + throw new NotImplementedException("Unimplemented method 'sort'"); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java index 17b382f06ad..a7e715b59b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; public abstract class AIdentityDictionary extends ACachingMBDictionary { @@ -74,4 +75,9 @@ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { ret[ret.length - 1] *= defaultTuple[i]; return ret; } + + @Override + public int[] sort(){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java index d67ab95f824..30ba9fe9dc8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java @@ -121,4 +121,9 @@ public boolean equals(IDictionary o) { public IDictionary clone() { throw new NotImplementedException(); } + + @Override + public int[] sort(){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 939b48bf424..3cfec04908c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -1341,4 +1341,68 @@ public IDictionary append(double[] row) { return new Dictionary(retV); } + @Override + public int[] sort() { + return sort(_values); + } + + protected static int[] sort(double[] values) { + int[] indices = new int[values.length]; + for(int i = 0; i < indices.length; i++) { + indices[i] = i; + } + + + // quicksort with stack + int[] stack = new int[values.length]; + + int top = -1; + stack[++top] = 0; + stack[++top] = values.length - 1; + + while(top >= 0) { + int high = stack[top--]; + int low = stack[top--]; + + if(low < high) { + + int pivotIndex = partition(indices, values, low, high); + // Left side + if(pivotIndex - 1 > low) { + stack[++top] = low; + stack[++top] = pivotIndex - 1; + } + + // Right side + if(pivotIndex + 1 < high) { + stack[++top] = pivotIndex + 1; + stack[++top] = high; + } + } + } + + return indices; + } + + private static int partition(int[] indices, double[] values, int low, int high) { + double pivotValue = values[indices[high]]; + int i = low - 1; + + for(int j = low; j < high; j++) { + if(values[indices[j]] <= pivotValue) { + i++; + swap(indices, i, j); + } + } + + swap(indices, i + 1, high); + return i + 1; + } + + private static void swap(int[] arr, int i, int j) { + int tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index dddea0eec7a..06e6893ae46 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -1051,4 +1051,14 @@ public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thi * @return The nonzero count of each column in the dictionary. */ public int[] countNNZZeroColumns(int[] counts); + + /** + * Sort the values of this dictionary via an index of how the values mapped previously. + * + * In practice this design means we can reuse the previous dictionary for the resulting column group + * + * @return The sorted index. + */ + public int[] sort(); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 24776f3adc4..799f67f9c06 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -27,8 +27,6 @@ import java.util.Arrays; import java.util.Set; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; @@ -61,6 +59,9 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + public class MatrixBlockDictionary extends ADictionary { private static final long serialVersionUID = 2535887782150955098L; @@ -2801,4 +2802,13 @@ private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, } } + @Override + public int[] sort(){ + if(_data.getNumColumns() > 1) + throw new RuntimeException("Not supported sort on multicolumn dictionaries"); + _data.sparseToDense(); + + return Dictionary.sort(_data.getDenseBlockValues()); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index f5746647a37..08193cb5d4d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -101,4 +101,8 @@ public DictType getDictType() { throw new RuntimeException("invalid to get dictionary type for PlaceHolderDict"); } + @Override + public int[] sort(){ + throw new RuntimeException("Invalid call"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 6802d920b49..3ebc92c5053 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.utils.MemoryEstimates; @@ -277,4 +278,8 @@ public MatrixBlockDictionary createMBDict(int nCol) { return new MatrixBlockDictionary(mb); } + @Override + public int[] sort(){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index ce52bcd23fd..0d5451f568d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -48,6 +48,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; +import org.apache.sysds.runtime.compress.utils.HashMapIntToInt; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; @@ -55,7 +56,6 @@ import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.data.SparseRowScalar; import org.apache.sysds.runtime.data.SparseRowVector; -import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -77,7 +77,7 @@ public final class CLALibBinaryCellOp { private static final Log LOG = LogFactory.getLog(CLALibBinaryCellOp.class.getName()); - public static final int DECOMPRESSION_BLEN = 16384; + public static final int DECOMPRESSION_BLEN = 16384 / 2; private CLALibBinaryCellOp() { // empty private constructor. @@ -86,7 +86,7 @@ private CLALibBinaryCellOp() { public static MatrixBlock binaryOperationsRight(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that) { try { - op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); + op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); if((that.getNumRows() == 1 && that.getNumColumns() == 1) || that.isEmpty()) { ScalarOperator sop = new RightScalarOperator(op.fn, that.get(0, 0), op.getNumThreads()); @@ -122,8 +122,8 @@ private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, Comp BinaryAccessType atype = LibMatrixBincell.getBinaryAccessTypeExtended(m1, that); if(isDoubleCompressedOpApplicable(m1, that)) return doubleCompressedBinaryOp(op, m1, (CompressedMatrixBlock) that); - if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() - && that.getInMemorySize() < m1.getInMemorySize() ) { + if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() && + that.getInMemorySize() < m1.getInMemorySize()) { MatrixBlock m1uc = CompressedMatrixBlock.getUncompressed(m1, "Decompressing left side in BinaryOps"); return selectProcessingBasedOnAccessType(op, (CompressedMatrixBlock) that, m1uc, atype, true); } @@ -135,16 +135,15 @@ private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, Comp } private static boolean isDoubleCompressedOpApplicable(CompressedMatrixBlock m1, MatrixBlock that) { - return that instanceof CompressedMatrixBlock - && !m1.isOverlapping() - && m1.getColGroups().get(0) instanceof ColGroupDDC - && !((CompressedMatrixBlock) that).isOverlapping() - && ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC - && ((IMapToDataGroup) m1.getColGroups().get(0)).getMapToData() == - ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); + return that instanceof CompressedMatrixBlock && !m1.isOverlapping() && + m1.getColGroups().get(0) instanceof ColGroupDDC && !((CompressedMatrixBlock) that).isOverlapping() && + ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC && + ((IMapToDataGroup) m1.getColGroups().get(0)) + .getMapToData() == ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); } - private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, CompressedMatrixBlock m2) { + private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, + CompressedMatrixBlock m2) { LOG.debug("Double Compressed BinaryOp"); AColGroup left = m1.getColGroups().get(0); AColGroup right = m2.getColGroups().get(0); @@ -201,6 +200,7 @@ private static MatrixBlock mvCol(BinaryOperator op, CompressedMatrixBlock m1, Ma // Column vector access MatrixBlock d_compressed = m1.getCachedDecompressed(); if(d_compressed != null) { + LOG.debug("Using cached decompressed for Matrix column vector compressed operation"); if(left) throw new NotImplementedException("Binary row op left is not supported for Uncompressed Matrix, " + "Implement support for VMr in MatrixBlock Binary Cell operations"); @@ -416,17 +416,24 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock Pair tuple = evaluateSparsityMVCol(m1, m2, op, left); double estSparsity = tuple.getKey(); double estNnzPerRow = tuple.getValue(); - boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, (long) (estSparsity * nRows * nCols)); + boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, + (long) (estSparsity * nRows * nCols)); // currently also jump into that case if estNnzPerRow == 0 - if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction){ - return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, left) : - binaryMVComparisonColMultiCompressed(m1, m2, op, left); + if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction) { + return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, + left) : binaryMVComparisonColMultiCompressed(m1, m2, op, left); } MatrixBlock ret = new MatrixBlock(nRows, nCols, shouldBeSparseOut, -1).allocateBlock(); if(shouldBeSparseOut) { - if(k <= 1) + if(!m1.isOverlapping() && MatrixBlock.evalSparseFormatInMemory(nRows, nCols, m1.getNonZeros())) { + if(k <= 1) + nnz = binaryMVColSingleThreadSparseSparse(m1, m2, op, left, ret); + else + nnz = binaryMVColMultiThreadSparseSparse(m1, m2, op, left, ret); + } + else if(k <= 1) nnz = binaryMVColSingleThreadSparse(m1, m2, op, left, ret); else nnz = binaryMVColMultiThreadSparse(m1, m2, op, left, ret); @@ -438,7 +445,7 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock nnz = binaryMVColMultiThreadDense(m1, m2, op, left, ret); } - if(op.fn instanceof ValueComparisonFunction) { + if(op.fn instanceof ValueComparisonFunction) { // potentially empty or filled. if(nnz == (long) nRows * nCols)// all was 1 return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 1.0); else if(nnz == 0) // all was 0 -> return empty. @@ -452,19 +459,19 @@ else if(nnz == 0) // all was 0 -> return empty. } private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(CompressedMatrixBlock m1, MatrixBlock m2, - BinaryOperator op, boolean left) { + BinaryOperator op, boolean left) { final int nRows = m1.getNumRows(); final int nCols = m1.getNumColumns(); // get indicators (one-hot-encoded comparison results) - BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); + BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); long nnz = task.call(); int[] indicators = task._ret; // map each unique indicator to an index - HashMapToInt hm = new HashMapToInt<>(nCols*3); + HashMapIntToInt hm = new HashMapIntToInt(nCols * 3); int[] colMap = new int[nRows]; - for(int i = 0; i < m1.getNumRows(); i++){ + for(int i = 0; i < m1.getNumRows(); i++) { int nextId = hm.size(); int id = hm.putIfAbsentI(indicators[i], nextId); colMap[i] = id == -1 ? nextId : id; @@ -477,37 +484,39 @@ private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(Compresse return getCompressedMatrixBlock(m1, colMap, hm.size(), outMb, nRows, nCols, nnz); } - private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, SparseBlockMCSR out) { + private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, + SparseBlockMCSR out) { ArrayList colIndices = new ArrayList<>(8); - for (int c = numCol - 1; c >= 0; c--) { + for(int c = numCol - 1; c >= 0; c--) { if(indicator <= 0) break; - if(indicator % 2 == 1){ + if(indicator % 2 == 1) { colIndices.add(c); } indicator = indicator >> 1; } SparseRow row = null; - if(colIndices.size() > 1){ + if(colIndices.size() > 1) { double[] vals = new double[colIndices.size()]; Arrays.fill(vals, 1); int[] indices = new int[colIndices.size()]; - for (int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) + for(int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) indices[i] = colIndices.get(j); row = new SparseRowVector(vals, indices); - } else if(colIndices.size() == 1){ + } + else if(colIndices.size() == 1) { row = new SparseRowScalar(colIndices.get(0), 1.0); } out.set(rix, row, false); } private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrixBlock m1, MatrixBlock m2, - BinaryOperator op, boolean left) throws Exception { + BinaryOperator op, boolean left) throws Exception { final int nRows = m1.getNumRows(); final int nCols = m1.getNumColumns(); final int k = op.getNumThreads(); - final int blkz = nRows / k; + final int blkz = Math.max((nRows + k) / k, 1000); // get indicators (one-hot-encoded comparison results) long nnz = 0; @@ -518,14 +527,11 @@ private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrix tasks.add(new BinaryMVColTaskCompressed(m1, m2, i, Math.min(nRows, i + blkz), op, left)); } List> futures = pool.invokeAll(tasks); - HashMapToInt hm = new HashMapToInt<>(nCols*2); + HashMapIntToInt hm = new HashMapIntToInt(nCols * 2); int[] colMap = new int[nRows]; - for(Future f : futures) - nnz += f.get(); - // map each unique indicator to an index - mergeMVColTaskResults(tasks, blkz, hm, colMap); + nnz = mergeMVColTaskResults(futures, tasks, blkz, hm, colMap); // decode the unique indicator ints to SparseVectors MatrixBlock outMb = getMCSRMatrixBlock(hm, nCols); @@ -539,48 +545,53 @@ private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrix } - private static void mergeMVColTaskResults(ArrayList tasks, int blkz, HashMapToInt hm, int[] colMap) { - + private static long mergeMVColTaskResults(List> futures, ArrayList tasks, + int blkz, HashMapIntToInt hm, int[] colMap) throws InterruptedException, ExecutionException { + long nnz = 0; for(int j = 0; j < tasks.size(); j++) { + nnz += futures.get(j).get(); // ensure task was finished. int[] indicators = tasks.get(j)._ret; - int offset = j* blkz; - - final int remainders = indicators.length % 8; - final int endVecLen = indicators.length - remainders; - for (int i = 0; i < endVecLen; i+= 8) { - colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); - colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); - colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); - colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); - colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); - colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); - colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); - colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); + int offset = j * blkz; - } - for (int i = 0; i < remainders; i++) { - colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); - } + mergeMVColUnrolled(hm, colMap, indicators, offset); } + return nnz; } + private static void mergeMVColUnrolled(HashMapIntToInt hm, int[] colMap, int[] indicators, int offset) { + final int remainders = indicators.length % 8; + final int endVecLen = indicators.length - remainders; + for(int i = 0; i < endVecLen; i += 8) { + colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); + colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); + colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); + colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); + colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); + colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); + colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); + colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); - private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, - int mapSize, MatrixBlock outMb, int nRows, int nCols, long nnz) { + } + for(int i = 0; i < remainders; i++) { + colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); + } + } + + private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, int mapSize, + MatrixBlock outMb, int nRows, int nCols, long nnz) { final IColIndex i = ColIndexFactory.create(0, m1.getNumColumns()); final AMapToData map = MapToFactory.create(m1.getNumRows(), colMap, mapSize); final AColGroup rgroup = ColGroupDDC.create(i, MatrixBlockDictionary.create(outMb), map, null); final ArrayList groups = new ArrayList<>(1); groups.add(rgroup); - return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); + return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); } - private static MatrixBlock getMCSRMatrixBlock(HashMapToInt hm, int nCols) { + private static MatrixBlock getMCSRMatrixBlock(HashMapIntToInt hm, int nCols) { // decode the unique indicator ints to SparseVectors SparseBlockMCSR out = new SparseBlockMCSR(hm.size()); - hm.forEach((indicator, rix) -> - fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); - return new MatrixBlock(hm.size(), nCols, -1, out); + hm.forEach((indicator, rix) -> fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); + return new MatrixBlock(hm.size(), nCols, -1, out); } private static long binaryMVColSingleThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, @@ -599,6 +610,14 @@ private static long binaryMVColSingleThreadSparse(CompressedMatrixBlock m1, Matr return nnz; } + private static long binaryMVColSingleThreadSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) { + final int nRows = m1.getNumRows(); + long nnz = 0; + nnz += new BinaryMVColTaskSparseSparse(m1, m2, ret, 0, nRows, op, left).call(); + return nnz; + } + private static long binaryMVColMultiThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left, MatrixBlock ret) throws Exception { final int nRows = m1.getNumRows(); @@ -641,6 +660,27 @@ private static long binaryMVColMultiThreadSparse(CompressedMatrixBlock m1, Matri return nnz; } + private static long binaryMVColMultiThreadSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) throws Exception { + final int nRows = m1.getNumRows(); + final int k = op.getNumThreads(); + final int blkz = Math.max(nRows / k, 64); + long nnz = 0; + final ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); + try { + final ArrayList> tasks = new ArrayList<>(); + for(int i = 0; i < nRows; i += blkz) { + tasks.add(new BinaryMVColTaskSparseSparse(m1, m2, ret, i, Math.min(nRows, i + blkz), op, left)); + } + for(Future f : pool.invokeAll(tasks)) + nnz += f.get(); + } + finally { + pool.shutdown(); + } + return nnz; + } + private static MatrixBlock mmCompressed(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) throws Exception { final int nCols = m1.getNumColumns(); @@ -724,8 +764,8 @@ private static class BinaryMVColTaskCompressed implements Callable { private MatrixBlock tmp; - protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, - BinaryOperator op, boolean left) { + protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, BinaryOperator op, + boolean left) { _m1 = m1; _m2 = m2; _op = op; @@ -738,21 +778,21 @@ protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, in @Override public Long call() { - tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); - final int _blklen = tmp.getNumRows(); + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_blklen, _m1.getNumColumns()); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); long nnz = 0; if(!_left) - for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + for(int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen) { int ru = Math.min(rl + _blklen, _ru); decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); nnz += processDense(rl, ru, retIxOff); tmp.reset(); } else - for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + for(int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen) { int ru = Math.min(rl + _blklen, _ru); decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); nnz += processDenseLeft(rl, ru, retIxOff); @@ -770,18 +810,24 @@ private final long processDense(final int rl, final int ru, final int retIxOffse for(int row = rl, retIx = retIxOffset; row < ru; row++, retIx++) { final double vr = _m2Dense[row]; final int tmpOff = (row - rl) * nCol; - int indicatorVector = 0; - for(int col = 0; col < nCol; col++) { - indicatorVector = indicatorVector << 1; - int indicator = _compFn.compare(_tmpDense[tmpOff + col], vr) ? 1 : 0; - indicatorVector += indicator; - nnz += indicator; - } - _ret[retIx] = indicatorVector; + nnz = processRow(nCol, _tmpDense, nnz, retIx, vr, tmpOff); } return nnz; } + private final long processRow(final int nCol, final double[] _tmpDense, long nnz, int retIx, final double vr, + final int tmpOff) { + int indicatorVector = 0; + for(int col = tmpOff; col < nCol + tmpOff; col++) { + indicatorVector = indicatorVector << 1; + int indicator = _compFn.compare(_tmpDense[col], vr) ? 1 : 0; + indicatorVector += indicator; + nnz += indicator; + } + _ret[retIx] = indicatorVector; + return nnz; + } + private final long processDenseLeft(final int rl, final int ru, final int retIxOffset) { final int nCol = _m1.getNumColumns(); final double[] _tmpDense = tmp.getDenseBlockValues(); @@ -847,7 +893,8 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { // unsafe decompress, since we count nonzeros afterwards. final DenseBlock db = _ret.getDenseBlock(); decompressToSubBlock(rl, ru, db, groups, its); @@ -887,7 +934,7 @@ private void processRow(final int ncol, final double[] ret, final int posR, fina private void processRowLeft(final int ncol, final double[] ret, final int posR, final double vr) { for(int col = 0; col < ncol; col++) - ret[posR + col] = _op.fn.execute(vr,ret[posR + col]); + ret[posR + col] = _op.fn.execute(vr, ret[posR + col]); } } @@ -917,8 +964,8 @@ protected BinaryMVColTaskSparse(CompressedMatrixBlock m1, MatrixBlock m2, Matrix @Override public Long call() { - tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); - final int _blklen = tmp.getNumRows(); + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_blklen, _m1.getNumColumns()); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); if(!_left) @@ -936,7 +983,8 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); processDenseLeft(rl, ru); tmp.reset(); @@ -971,8 +1019,107 @@ private final void processDenseLeft(final int rl, final int ru) { } } - private static MatrixBlock allocateTempUncompressedBlock(int cols) { - MatrixBlock out = new MatrixBlock(Math.max(DECOMPRESSION_BLEN / cols, 64), cols, false); + private static class BinaryMVColTaskSparseSparse implements Callable { + private final int _rl; + private final int _ru; + private final CompressedMatrixBlock _m1; + private final MatrixBlock _m2; + private final MatrixBlock _ret; + private final BinaryOperator _op; + + private MatrixBlock tmp; + + private boolean _left; + + protected BinaryMVColTaskSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, + BinaryOperator op, boolean left) { + _m1 = m1; + _m2 = m2; + _ret = ret; + _op = op; + _rl = rl; + _ru = ru; + _left = left; + } + + @Override + public Long call() { + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlockSparse(_blklen, _m1.getNumColumns()); + final List groups = _m1.getColGroups(); + final AIterator[] its = getIterators(groups, _rl); + if(!_left) + for(int r = _rl; r < _ru; r += _blklen) + processBlock(r, Math.min(r + _blklen, _ru), groups, its); + else + for(int r = _rl; r < _ru; r += _blklen) + processBlockLeft(r, Math.min(r + _blklen, _ru), groups, its); + return _ret.recomputeNonZeros(_rl, _ru - 1); + } + + private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getSparseBlock(), groups, its); + processDense(rl, ru); + tmp.reset(); + } + + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getSparseBlock(), groups, its); + processDenseLeft(rl, ru); + tmp.reset(); + } + + private final void processDense(final int rl, final int ru) { + final SparseBlock sb = _ret.getSparseBlock(); + final SparseBlock _tmpSparse = tmp.getSparseBlock(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl); + if(!_tmpSparse.isEmpty(tmpOff)){ + int[] aoff = _tmpSparse.indexes(tmpOff); + double[] aval = _tmpSparse.values(tmpOff); + int apos = _tmpSparse.pos(tmpOff); + int alen = apos + _tmpSparse.size(tmpOff); + + for(int j = apos; j < alen; j++){ + sb.append(row, aoff[j], _op.fn.execute(aval[j], vr)); + } + } + + } + } + + private final void processDenseLeft(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final SparseBlock sb = _ret.getSparseBlock(); + final SparseBlock _tmpSparse = tmp.getSparseBlock(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + if(!_tmpSparse.isEmpty(tmpOff)){ + int[] aoff = _tmpSparse.indexes(tmpOff); + double[] aval = _tmpSparse.values(tmpOff); + int apos = _tmpSparse.pos(tmpOff); + int alen = apos + _tmpSparse.size(tmpOff); + for(int j = apos; j < alen; j++){ + sb.append(row, aoff[j], _op.fn.execute(vr,aval[j])); + } + } + } + } + } + + private static MatrixBlock allocateTempUncompressedBlock(int blklen, int cols) { + MatrixBlock out = new MatrixBlock(blklen, cols, false); + out.allocateBlock(); + return out; + } + + private static MatrixBlock allocateTempUncompressedBlockSparse(int blklen, int cols) { + MatrixBlock out = new MatrixBlock(blklen, cols, true); out.allocateBlock(); return out; } @@ -1199,6 +1346,25 @@ protected static void decompressToTmpBlock(final int rl, final int ru, final Den } } + protected static void decompressToTmpBlock(final int rl, final int ru, final SparseBlock db, + final List groups, final AIterator[] its) { + Timing time = new Timing(true); + for(int i = 0; i < groups.size(); i++) { + final AColGroup g = groups.get(i); + if(g.getCompType() == CompressionType.SDC) + ((ASDCZero) g).decompressToSparseBlock(db, rl, ru, -rl, 0, its[i]); + else + g.decompressToSparseBlock(db, rl, ru, -rl, 0); + } + + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressToBlockTime(t, 1); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + 1 + " in " + t + "ms."); + } + } + protected static AIterator[] getIterators(final List groups, final int rl) { final AIterator[] its = new AIterator[groups.size()]; for(int i = 0; i < groups.size(); i++) { @@ -1210,8 +1376,8 @@ protected static AIterator[] getIterators(final List groups, final in return its; } - private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, - boolean left) { + private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, + BinaryOperator op, boolean left) { final List groups = m1.getColGroups(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); @@ -1247,7 +1413,7 @@ private static Pair evaluateSparsityMVCol(CompressedMatrixBlock for(int r = 0; r < sampleRow; r++) { final double m = m2v[r]; final int off = r * sampleCol; - for(int c = 0; c < sampleCol; c++){ + for(int c = 0; c < sampleCol; c++) { int outVal = op.fn.execute(dv[off + c], m) != 0 ? 1 : 0; nnz += outVal; nnzPerRow[r] += outVal; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 99693635a9b..948a78f96af 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -486,7 +486,7 @@ private static List> generateUnaryAggregateOverlappingFuture final ArrayList tasks = new ArrayList<>(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); - final int blklen = Math.max(64, nRow / k); + final int blklen = Math.max(64, (nRow + k) / k); final List groups = m1.getColGroups(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); if(shouldFilter) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index d82d58e323e..cc7953f8c5d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -95,6 +96,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix if(x.isEmpty()) return returnEmpty(x, out); + if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){ + MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k); + return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k)); + } + // Morph the columns to efficient types for the operation. x = filterColGroups(x); double preFilterTime = t.stop(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java new file mode 100644 index 00000000000..c793e84ebef --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java @@ -0,0 +1,37 @@ +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue; + +public class CLALibSort { + + public static MatrixBlock sort(CompressedMatrixBlock mb, MatrixValue weights, MatrixBlock result, int k) { + // force uncompressed weights + weights = CompressedMatrixBlock.getUncompressed(weights); + + if(mb.getNumColumns() == 1 && mb.getColGroups().size() == 1 && weights == null) { + return sortSingleCol(mb, k); + } + + // fallback to uncompressed. + return CompressedMatrixBlock// + .getUncompressed(mb, "sortOperations")// + .sortOperations(weights, result); + } + + private static MatrixBlock sortSingleCol(CompressedMatrixBlock mb, int k) { + + AColGroup g = mb.getColGroups().get(0); + + AColGroup r = g.sort(); + + List rg = new ArrayList<>(); + rg.add(r); + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, rg); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java index a1d47a9b150..e0643572eae 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java @@ -42,6 +42,10 @@ private CLALibTSMM() { // private constructor } + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) { + return leftMultByTransposeSelf(cmb, new MatrixBlock(), k); + } + /** * Self left Matrix multiplication (tsmm) * @@ -51,17 +55,25 @@ private CLALibTSMM() { * @param ret The output matrix to put the result into * @param k The parallelization degree allowed */ - public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + final int numColumns = cmb.getNumColumns(); + final int numRows = cmb.getNumRows(); + if(cmb.isEmpty()) + return new MatrixBlock(numColumns, numColumns, true); + // create output matrix block + if(ret == null) + ret = new MatrixBlock(numColumns, numColumns, false); + else + ret.reset(numColumns, numColumns, false); + ret.allocateDenseBlock(); final List groups = cmb.getColGroups(); - final int numColumns = cmb.getNumColumns(); if(groups.size() >= numColumns) { MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k); LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k); - return; + return ret; } - final int numRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); final boolean overlapping = cmb.isOverlapping(); if(shouldFilter) { @@ -77,6 +89,7 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret)); ret.examSparsity(); + return ret; } private static void addCorrectionLayer(List filteredGroups, MatrixBlock result, int nRows, int nCols, @@ -86,8 +99,6 @@ private static void addCorrectionLayer(List filteredGroups, MatrixBlo addCorrectionLayer(constV, filteredColSum, nRows, retV); } - - private static void tsmmColGroups(List groups, MatrixBlock ret, int nRows, boolean overlapping, int k) { if(k <= 1) tsmmColGroupsSingleThread(groups, ret, nRows); @@ -136,12 +147,12 @@ private static void tsmmColGroupsMultiThread(List groups, MatrixBlock public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) { final int nColRow = constV.length; - for(int row = 0; row < nColRow; row++){ + for(int row = 0; row < nColRow; row++) { int offOut = nColRow * row; final double v1l = constV[row]; final double v2l = filteredColSum[row] + constV[row] * nRow; - for(int col = row; col < nColRow; col++){ - ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; + for(int col = row; col < nColRow; col++) { + ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java new file mode 100644 index 00000000000..29650048509 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.utils; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +public class HashMapIntToInt implements Map { + + static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; + static final float DEFAULT_LOAD_FACTOR = 0.75f; + + protected Node[] buckets; + + protected int size; + + public HashMapIntToInt(int capacity) { + alloc(Math.max(capacity, DEFAULT_INITIAL_CAPACITY)); + } + + protected void alloc(int size) { + Node[] tmp = (Node[]) new Node[size]; + buckets = tmp; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public boolean containsKey(Object key) { + return getI((Integer) key) != -1; + } + + @Override + public boolean containsValue(Object value) { + if(value instanceof Integer) { + for(Entry v : this.entrySet()) { + if(v.getValue().equals(value)) + return true; + } + } + return false; + + } + + @Override + public Integer get(Object key) { + final int i = getI((Integer) key); + if(i != -1) + return i; + else + return null; + } + + public int getI(int key) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b != null) { + do { + if(key == b.key) + return b.value; + } + while((b = b.next) != null); + } + return -1; + + } + + public int hash(int key) { + return Math.abs(Integer.hashCode(key) % buckets.length); + } + + @Override + public Integer put(Integer key, Integer value) { + int i = putI(key, value); + if(i != -1) + return i; + else + return null; + } + + @Override + public Integer putIfAbsent(Integer key, Integer value) { + int i = putIfAbsentI(key, value); + if(i != -1) + return i; + else + return null; + } + + public int putIfAbsentI(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return putIfAbsentBucket(ix, key, value); + + } + + public int putIfAbsentReturnVal(int key, int value) { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + public int putIfAbsentReturnValHash(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + + } + + private int putIfAbsentBucket(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key == key) + return b.value; + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return -1; + } + b = b.next; + } + } + + private int putIfAbsentBucketReturnval(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key == key) + return b.value; + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return value; + } + b = b.next; + } + } + + public int putI(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return addToBucket(ix, key, value); + + } + + private int createBucket(int ix, int key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return -1; + } + + private int createBucketReturnVal(int ix, int key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return value; + } + + private int addToBucket(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(key == b.key) { + int tmp = b.getValue(); + b.setValue(value); + return tmp; + } + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return -1; + } + b = b.next; + } + } + + private void resize() { + if(size > buckets.length * DEFAULT_LOAD_FACTOR) { + + Node[] tmp = (Node[]) new Node[buckets.length * 2]; + Node[] oldBuckets = buckets; + buckets = tmp; + size = 0; + for(Node n : oldBuckets) { + if(n != null) + do { + put(n.key, n.value); + } + while((n = n.next) != null); + } + + } + } + + @Override + public Integer remove(Object key) { + throw new UnsupportedOperationException("Unimplemented method 'remove'"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("Unimplemented method 'putAll'"); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("Unimplemented method 'clear'"); + } + + @Override + public Set keySet() { + throw new UnsupportedOperationException("Unimplemented method 'keySet'"); + } + + @Override + public Collection values() { + throw new UnsupportedOperationException("Unimplemented method 'values'"); + } + + @Override + public Set> entrySet() { + return new EntrySet(); + } + + @Override + public void forEach(BiConsumer action) { + + for(Node n : buckets) { + if(n != null) { + do { + action.accept(n.key, n.value); + } + while((n = n.next) != null); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(size() * 3); + this.forEach((k, v) -> { + sb.append("(" + k + "→" + v + ")"); + }); + return sb.toString(); + } + + private static class Node implements Entry { + final int key; + int value; + Node next; + + Node(int key, int value, Node next) { + this.key = key; + this.value = value; + this.next = next; + } + + public final void setNext(Node n) { + next = n; + } + + @Override + public Integer getKey() { + return key; + } + + @Override + public Integer getValue() { + return value; + } + + @Override + public Integer setValue(Integer value) { + return this.value = value; + } + } + + private final class EntrySet extends AbstractSet> { + + @Override + public int size() { + return size; + } + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + } + + private final class EntryIterator implements Iterator> { + Node next; + int bucketId = 0; + + protected EntryIterator() { + + for(; bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public Entry next() { + + Node e = next; + + if(e.next != null) + next = e.next; + else { + for(; ++bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + if(bucketId >= buckets.length) + next = null; + } + + return e; + } + + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java index fc0aa3b1a29..4940dd801b3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java @@ -27,9 +27,18 @@ import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.IndexFunction; +import org.apache.sysds.runtime.functionobjects.KahanPlus; +import org.apache.sysds.runtime.functionobjects.Mean; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction; +import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; public class FederatedWorkloadAnalyzer { protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName()); @@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) { } public void compressRun(ExecutionContext ec, long tid) { - if(counter >= compressRunFrequency ){ + if(counter >= compressRunFrequency) { counter = 0; get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V)); } @@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap mm, ComputationCPInstruction cpIns) { // TODO: Count transitive closure via lineage + // TODO: add more operations if(cpIns instanceof AggregateBinaryCPInstruction) { final String n1 = cpIns.input1.getName(); MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1); @@ -81,15 +91,48 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap mm, long id) { @@ -117,8 +160,8 @@ private static boolean validSize(int nRow, int nCol) { return nRow > 90 && nRow >= nCol; } - @Override - public String toString(){ + @Override + public String toString() { StringBuilder sb = new StringBuilder(); sb.append(this.getClass().getSimpleName()); sb.append(" Counter: "); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 99cce9f9e97..972a2893fd8 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -377,7 +377,7 @@ public static double parseDouble(String value) { return Double.POSITIVE_INFINITY; else if(len == 4 && value.compareToIgnoreCase("-Inf") == 0) return Double.NEGATIVE_INFINITY; - throw new DMLRuntimeException(e); + throw e; } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index b26695e5797..84e4e89a420 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -174,6 +174,29 @@ public int putIfAbsentReturnVal(K key, int value) { } + + public int putIfAbsentReturnValHash(K key, int value) { + + if(key == null) { + if(nullV == -1) { + size++; + nullV = value; + return -1; + } + else + return nullV; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + } + private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; while(true) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..292fcb52bf5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,7 +607,6 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { @@ -617,7 +616,8 @@ private static double getAsDouble(String s) { else if(ls.equals("false") || ls.equals("f")) return 0; else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw e; // for efficiency + // throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java index 032afe2cd7c..987d14106ac 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java @@ -32,11 +32,17 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; -public interface MatrixBlockFromFrame { +public class MatrixBlockFromFrame { public static final Log LOG = LogFactory.getLog(MatrixBlockFromFrame.class.getName()); public static final int blocksizeIJ = 32; + public static Boolean WARNED_FOR_FAILED_CAST = false; + + private MatrixBlockFromFrame(){ + // private constructor for code coverage. + } + /** * Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. @@ -94,10 +100,25 @@ else if(ret.getNumRows() != m || ret.getNumColumns() != n || ret.isInSparseForma } private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { - if(mb.getDenseBlock().isContiguous()) - return convertContiguous(frame, mb, n, rl, ru); - else - return convertGeneric(frame, mb, n, rl, ru); + try { + + if(mb.getDenseBlock().isContiguous()) + return convertContiguous(frame, mb, n, rl, ru); + else + return convertGeneric(frame, mb, n, rl, ru); + } + catch(NumberFormatException | DMLRuntimeException e) { + synchronized(WARNED_FOR_FAILED_CAST){ + if(!WARNED_FOR_FAILED_CAST) { + LOG.error( + "Failed to convert to Matrix because of number format errors, falling back to NaN on incompatible cells", + e); + WARNED_FOR_FAILED_CAST = true; + } + } + return convertSafeCast(frame, mb, n, rl, ru); + + } } private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception { @@ -169,4 +190,37 @@ private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final } return lnnz; } + + private static long convertSafeCast(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, + final int ru) { + final DenseBlock c = mb.getDenseBlock(); + long lnnz = 0; + for(int bi = rl; bi < ru; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, ru); + int bjmin = Math.min(bj + blocksizeIJ, n); + lnnz = convertBlockSafeCast(frame, lnnz, c, bi, bj, bimin, bjmin); + } + } + return lnnz; + } + + private static long convertBlockSafeCast(final FrameBlock frame, long lnnz, final DenseBlock c, final int rl, + final int cl, final int ru, final int cu) { + for(int i = rl; i < ru; i++) { + final double[] cvals = c.values(i); + final int cpos = c.pos(i); + for(int j = cl; j < cu; j++) { + try { + lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + } + catch(NumberFormatException | DMLRuntimeException e) { + lnnz += 1; + cvals[cpos + j] = Double.NaN; + } + } + } + return lnnz; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 6b196489eac..02eaa332a1a 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -51,7 +51,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, - DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, + DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} @@ -113,6 +113,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, String2BuiltinCode.put( "_map", BuiltinCode.MAP); String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA); + String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK); } protected Builtin(BuiltinCode bf) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 28b8775ebd5..86184f47be6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME) return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str); + else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR) + return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX) return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str); else diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java new file mode 100644 index 00000000000..99b3c1a3b13 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import java.util.Arrays; + +import org.apache.sysds.common.Builtins; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; +import org.apache.sysds.runtime.transform.TfUtils.TfMethod; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONObject; + +public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction { + // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); + + protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, + String opcode, String istr) { + super(CPType.Binary, op, in1, in2, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + // get input frames + FrameBlock inBlock1 = ec.getFrameInput(input1.getName()); + ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true); + if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) { + processGetCategorical(ec, inBlock1, spec); + } + else { + throw new DMLRuntimeException("Unsupported operation"); + } + + // Release the memory occupied by input frames + ec.releaseFrameInput(input1.getName()); + } + + public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + try { + + // MatrixBlock ret = new MatrixBlock(); + int nCol = f.getNumColumns(); + + JSONObject jSpec = new JSONObject(spec.getStringValue()); + + if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) { + throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); + } + + String recode = TfMethod.RECODE.toString(); + String dummycode = TfMethod.DUMMYCODE.toString(); + + int[] lengths = new int[nCol]; + // assume all columns encode to at least one column. + Arrays.fill(lengths, 1); + boolean[] categorical = new boolean[nCol]; + + if(jSpec.containsKey(recode)) { + JSONArray a = jSpec.getJSONArray(recode); + for(Object aa : a) { + int av = (Integer) aa - 1; + categorical[av] = true; + } + } + + if(jSpec.containsKey(dummycode)) { + JSONArray a = jSpec.getJSONArray(dummycode); + for(Object aa : a) { + int av = (Integer) aa - 1; + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + lengths[av] = ndist; + categorical[av] = true; + } + } + + // get total size after mapping + + int sumLengths = 0; + for(int i : lengths) { + sumLengths += i; + } + + MatrixBlock ret = new MatrixBlock(1, sumLengths, false); + ret.allocateDenseBlock(); + int off = 0; + for(int i = 0; i < lengths.length; i++) { + for(int j = 0; j < lengths[i]; j++) { + ret.set(0, off++, categorical[i] ? 1 : 0); + } + } + + ec.setMatrixOutput(output.getName(), ret); + + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..e53958ac4b8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -352,7 +352,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) { // compute transformdecode Decoder decoder = DecoderFactory .createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns()); - FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema())); + FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()), InfrastructureAnalyzer.getLocalParallelism()); fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns())); // release locks diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 8400ec54e6f..7bcf7f063ed 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.io.FileFormatProperties; @@ -918,7 +919,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { switch( getInput1().getDataType() ) { case FRAME: { FrameBlock fin = ec.getFrameInput(getInput1().getName()); - MatrixBlock out = DataConverter.convertToMatrixBlock(fin); + MatrixBlock out = MatrixBlockFromFrame.convertToMatrixBlock(fin, k); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out); break; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java new file mode 100644 index 00000000000..79f08cb353a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.CorrectionLocationType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.cp.KahanObject; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; + +public class LibAggregateUnarySpecialization { + protected static final Log LOG = LogFactory.getLog(LibAggregateUnarySpecialization.class.getName()); + + public static void aggregateUnary(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.sparseSafe) + sparseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + else + denseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + } + + private static void sparseAggregateUnaryHelp(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, + int blen, MatrixIndexes indexesIn) { + // initialize result + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + + if(mb.sparse && mb.sparseBlock != null) { + SparseBlock a = mb.sparseBlock; + for(int r = 0; r < Math.min(mb.rlen, a.numRows()); r++) { + if(a.isEmpty(r)) + continue; + int apos = a.pos(r); + int alen = a.size(r); + int[] aix = a.indexes(r); + double[] aval = a.values(r); + for(int i = apos; i < apos + alen; i++) { + tempCellIndex.set(r, aix[i]); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, aval[i], + buffer); + } + } + } + else if(!mb.sparse && mb.denseBlock != null) { + DenseBlock a = mb.getDenseBlock(); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, a.get(i, j), + buffer); + } + } + } + + private static void denseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, + mb.get(i, j), buffer); + } + } + + private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, + double newvalue, KahanObject buffer) { + if(aggOp.existsCorrection()) { + if(aggOp.correction == CorrectionLocationType.LASTROW || + aggOp.correction == CorrectionLocationType.LASTCOLUMN) { + int corRow = row, corCol = column; + if(aggOp.correction == CorrectionLocationType.LASTROW)// extra row + corRow++; + else if(aggOp.correction == CorrectionLocationType.LASTCOLUMN) + corCol++; + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + + buffer._sum = result.get(row, column); + buffer._correction = result.get(corRow, corCol); + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue); + result.set(row, column, buffer._sum); + result.set(corRow, corCol, buffer._correction); + } + else if(aggOp.correction == CorrectionLocationType.NONE) { + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + } + else// for mean + { + int corRow = row, corCol = column; + int countRow = row, countCol = column; + if(aggOp.correction == CorrectionLocationType.LASTTWOROWS) { + countRow++; + corRow += 2; + } + else if(aggOp.correction == CorrectionLocationType.LASTTWOCOLUMNS) { + countCol++; + corCol += 2; + } + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + buffer._sum = result.get(row, column); + buffer._correction = result.get(corRow, corCol); + double count = result.get(countRow, countCol) + 1.0; + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count); + result.set(row, column, buffer._sum); + result.set(corRow, corCol, buffer._correction); + result.set(countRow, countCol, count); + } + + } + else { + newvalue = aggOp.increOp.fn.execute(result.get(row, column), newvalue); + result.set(row, column, newvalue); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index af702cb7fad..3113850ec80 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -3234,6 +3234,11 @@ private static void matrixMultWDivMMDense(MatrixBlock mW, MatrixBlock mU, Matrix DenseBlock x = (mX==null) ? null : mX.getDenseBlock(); DenseBlock c = ret.getDenseBlock(); + if(c == null){ + ret.allocateDenseBlock(); + c = ret.getDenseBlock(); + } + //approach: iterate over non-zeros of w, selective mm computation //cache-conscious blocking: due to blocksize constraint (default 1000), //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB) @@ -3380,6 +3385,11 @@ private static void matrixMultWDivMMGeneric(MatrixBlock mW, MatrixBlock mU, Matr //output always in dense representation DenseBlock c = ret.getDenseBlock(); + + if(c == null){ + ret.allocateDenseBlock(); + c = ret.getDenseBlock(); + } //approach: iterate over non-zeros of w, selective mm computation if( mW.sparse ) //SPARSE diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 7bc516588a5..666f33ded48 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -1312,7 +1312,7 @@ public void examSparsity(boolean allowCSR, int k) { else if( !sparse && sparseDst ) denseToSparse(allowCSR, k); } - + public static boolean evalSparseFormatInMemory(DataCharacteristics dc) { return evalSparseFormatInMemory(dc.getRows(), dc.getCols(), dc.getNonZeros()); } @@ -1384,12 +1384,13 @@ public void denseToSparse(boolean allowCSR, int k){ LibMatrixDenseToSparse.denseToSparse(this, allowCSR, k); } - public final void sparseToDense() { - sparseToDense(1); + public final MatrixBlock sparseToDense() { + return sparseToDense(1); } - public void sparseToDense(int k) { + public MatrixBlock sparseToDense(int k) { LibMatrixSparseToDense.sparseToDense(this, k); + return this; } /** @@ -2951,13 +2952,14 @@ public boolean isShallowSerialize(boolean inclConvert) { boolean sparseDst = evalSparseFormatOnDisk(); return !sparse || !sparseDst || (sparse && sparseBlock instanceof SparseBlockCSR) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD - <= getExactSerializedSize()) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && nonZeros < Integer.MAX_VALUE //CSR constraint - && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE - && !isUltraSparseSerialize(sparseDst)); + || (sparse && sparseBlock instanceof SparseBlockMCSR); + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD + // <= getExactSerializedSize()) + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && nonZeros < Integer.MAX_VALUE //CSR constraint + // && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE + // && !isUltraSparseSerialize(sparseDst)); } @Override @@ -4647,7 +4649,7 @@ public final MatrixBlock sortOperations(MatrixValue weights){ return sortOperations(weights, null); } - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { + public final MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { return sortOperations(weights, result, 1); } @@ -4751,7 +4753,17 @@ public static double computeIQMCorrection(double sum, double sum_wt, return (sum + q25Part*q25Val - q75Part*q75Val) / (sum_wt*0.5); } - public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { + /** + * Pick the quantiles out of this matrix. If this matrix contains two columns it is weighted quantile picking. + * If a single column it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantiles The quantiles to pick + * @param ret The result matrix + * @return The result matrix + */ + public final MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { MatrixBlock qs=checkType(quantiles); @@ -4772,17 +4784,56 @@ public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { return output; } - + + /** + * Pick the median quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @return The quantile + */ public double median() { double sum_wt = sumWeightForQuantile(); return pickValue(0.5, sum_wt%2==0); } - + + /** + * Pick a specific quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @return The quantile + */ public final double pickValue(double quantile){ return pickValue(quantile, false); } - public double pickValue(double quantile, boolean average) { + /** + * Pick a specific quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @param average If the quantile is averaged. + * @return The quantile + */ + public final double pickValue(double quantile, boolean average) { + if(this.getNumColumns() == 1) + return pickUnweightedValue(quantile, average); + return pickWeightedValue(quantile, average); + } + + private double pickUnweightedValue(double quantile, boolean average) { + double pos = quantile * rlen; + if(average && (int) pos != pos) + return (get((int) Math.floor(pos), 0) + get(Math.min(rlen - 1, (int) Math.ceil(pos)), 0)) / 2; + else + return get(Math.min(rlen - 1, (int) Math.round(pos)), 0); + } + + private double pickWeightedValue(double quantile, boolean average) { double sum_wt = sumWeightForQuantile(); // do averaging only if it is asked for; and sum_wt is even @@ -5254,8 +5305,8 @@ public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar * (i1,j1,v2) from input2 (that) * (w) from scalar_input3 (scalarThat2) * - * @param thatMatrix matrix value - * @param thatScalar scalar double + * @param thatMatrix matrix value, the vector to encode via table + * @param thatScalar scalar double, w, that is the weight to multiply on the encoded values * @param resultBlock result matrix block * @return resultBlock */ diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..70834675ded 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,6 +23,10 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +34,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Base class for all transform decoders providing both a row and block @@ -77,8 +82,31 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..c9fcc23990a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -43,15 +44,18 @@ public class DecoderBin extends Decoder { // a) column bin boundaries private int[] _numBins; + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; - public DecoderBin() { - super(null, null); - } + // public DecoderBin() { + // super(null, null); + // } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); + _dcCols = dcCols; } @Override @@ -66,14 +70,28 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ - final int key = (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + try{ + + final int key = (int) Math.round(val); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + } + catch(Exception e){ + LOG.error(a); + LOG.error(in.slice(0, in.getNumRows()-1, _colList[j]-1,_colList[j]-1)); + LOG.error( val); + throw e; + } } else a.set(i, val); // NaN @@ -111,6 +129,34 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + + if( _dcCols.length > 0 ) { + //prepare source column id mapping w/ dummy coding + _srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while( ix1<_colList.length ) { + if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { + _srcCols[ix1] = _colList[ix1] + off; + ix1 ++; + } + else { //_colList[ix1] > _dcCols[ix2] + ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; + String v = meta.getString(0, _dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } + ix2 ++; + } + } + } + else { + //prepare direct source column mapping + _srcCols = _colList; + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..dff85e72dc6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,13 +25,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -50,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - public DecoderComposite() { super(null, null); } + // public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -59,33 +56,6 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } - - @Override - public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { - final ExecutorService pool = CommonThreadPool.get(k); - out.ensureAllocatedColumns(in.getNumRows()); - try { - final List> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); - } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..debce027680 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,30 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. + * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +58,98 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - _clPos[j] + 1); + } + // limit the binary search. + apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +157,25 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + String v = meta.getString(0, colID - 1); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..12ba2968877 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,34 +64,52 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // remove hash recoded. // todo potentially wrong and remove? + rcIDs = except(rcIDs, hcIDs); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // hashed columns + ptIDs = except(ptIDs, hcIDs); // remove hashed columns + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { + // todo figure out if we need to handle rc columns with regards to dictionary offsets. ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..c2de3ec1df3 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - public DecoderPassThrough() { super(null, null); } + // public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -61,13 +61,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i _dcCols[ix2] ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + String v = meta.getString( 0,_dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 91d9fb62146..107f886df40 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,12 +47,12 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; + // private Object[][] _rcMapsDirect = null; private boolean _onOut = false; - public DecoderRecode() { - super(null, null); - } + // public DecoderRecode() { + // super(null, null); + // } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); @@ -59,8 +60,10 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null) ? - _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + // LOG.error(_rcMapsDirect); + // return (_rcMapsDirect != null) ? + // _rcMapsDirect[i][(int)key-1] : + return _rcMaps[i].get(key); } @Override @@ -129,27 +132,33 @@ public void initMetaData(FrameBlock meta) { for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } + // if( Arrays.stream(max).allMatch(v -> v < Integer.MAX_VALUE) ) { + // _rcMapsDirect = new Object[_rcMaps.length][]; + // for( int i=0; i<_rcMaps.length; i++ ) { + // Object[] arr = new Object[(int)max[i]]; + // for(Entry e1 : _rcMaps[i].entrySet()) + // arr[e1.getKey().intValue()-1] = e1.getValue(); + // _rcMapsDirect[i] = arr; + // } + // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..361c9c52135 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - meta.set(0, _colID - 1, String.valueOf(_K)); + // set metadata of hash columns to magical hash value + k + meta.set(0, _colID - 1, String.format("¿%d" , _K)); + return meta; } @@ -154,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); + _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); } @Override diff --git a/src/main/java/org/apache/sysds/utils/DoubleParser.java b/src/main/java/org/apache/sysds/utils/DoubleParser.java index 9c77a3e95c8..c0122f8061f 100644 --- a/src/main/java/org/apache/sysds/utils/DoubleParser.java +++ b/src/main/java/org/apache/sysds/utils/DoubleParser.java @@ -184,7 +184,7 @@ public interface DoubleParser { 0x8e679c2f5e44ff8fL}; public static double parseFloatingPointLiteral(String str, int offset, int endIndex) { - if(endIndex > 100) + if(endIndex > 100)// long string return Double.parseDouble(str); // Skip leading whitespace int index = skipWhitespace(str, offset, endIndex); @@ -197,9 +197,10 @@ public static double parseFloatingPointLiteral(String str, int offset, int endIn } // Parse NaN or Infinity (this occurs rarely) - if(ch >= 'I') - return Double.parseDouble(str); - else if(str.charAt(endIndex - 1) >= 'a') + // : is the first character after numbers. + // 0 is the first number. + // we use the last position, since this is not allowed to be other values than a number. + if(str.charAt(endIndex - 1) > '9' || str.charAt(endIndex - 1) < '0') return Double.parseDouble(str); final double val = parseDecFloatLiteral(str, index, offset, endIndex); diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 195e36d6065..45c05cca70a 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -32,6 +32,7 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; +import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; @@ -2927,6 +2928,25 @@ public static void writeTestScalar(String file, double value) { } } + + /** + * Write scalar to file + * + * @param file File to write to + * @param value Value to write + */ + public static void writeTestScalar(String file, String value) { + try { + DataOutputStream out = new DataOutputStream(new FileOutputStream(file)); + try(PrintWriter pw = new PrintWriter(out)) { + pw.println(value); + } + } + catch(IOException e) { + fail("unable to write test scalar (" + file + "): " + e.getMessage()); + } + } + /** * Write scalar to file * diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index c3efeea4014..967f344579b 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -468,6 +468,12 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup sort() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sort'"); + } } private class FakeDictBasedColGroup extends ADictBasedColGroup { @@ -777,5 +783,11 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup sort() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sort'"); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java index 194f581121a..a5bd3cebfb0 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.fail; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.junit.Test; @@ -115,6 +116,8 @@ public void testJoinWithSecondSubpartLeft() { private void partJoinVerification(IEncode er) { boolean incorrectUnique = e.getUnique() != er.getUnique(); + er.extractFacts(10000, 1.0, 1.0, new CompressionSettingsBuilder().create()); + if(incorrectUnique) { StringBuilder sb = new StringBuilder(); sb.append("\nFailed joining sub parts to recreate whole."); diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java index 182bd7fa37e..5a298f145ec 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java @@ -81,6 +81,10 @@ public static Collection data() { // Both Sparse and end dense joined tests.add(createT(1, 0.2, 10, 10, 0.1, 2, 1000, 1231521)); + + tests.add(createT(1, 1.0, 100, 1, 1.0, 10, 10000, 132)); + tests.add(createT(1, 1.0, 1000, 1, 1.0, 10, 10000, 132)); + return tests; } diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java index c6d52a70a51..872ec79c1f1 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java @@ -49,7 +49,7 @@ protected String getTestDir() { @Test public void testTranspose_CP() { - runTest(1500, 20, 1, 1, ExecType.CP, "transpose"); + runTest(1500, 20, 2, 1, ExecType.CP, "transpose"); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java index ee6a2953980..18ca2fbc454 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java @@ -270,4 +270,96 @@ protected void toStringTestHelper(ExecMode platform, String testName, String exp DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + @Test + public void testPrintWithDecimal(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "22"; + String expectedOutput = "22.00\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal2(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "5.244058388023880"; + String expectedOutput = "5.24\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal3(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "5.244058388023880"; + String expectedOutput = "5.2440583880\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal4(){ + String testName = "ToString12"; + + String decimalPoints = "4"; + String value = "5.244058388023880"; + String expectedOutput = "5.2441\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal5(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "0.000000008023880"; + String expectedOutput = "0.0000000080\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + protected void toStringTestHelper2(ExecMode platform, String testName, String expectedOutput, String decimalPoints, String value) { + ExecMode platformOld = rtplatform; + + rtplatform = platform; + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + try { + // Create and load test configuration + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[]{"-args", output(OUTPUT_NAME), value, decimalPoints}; + + // Run DML and R scripts + runTest(true, false, null, -1); + + // Compare output strings + String output = TestUtils.readDMLString(output(OUTPUT_NAME)); + TestUtils.compareScalars(expectedOutput, output); + } + finally { + // Reset settings + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java new file mode 100644 index 00000000000..848bf1e229a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class GetCategoricalMaskTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskTest.class.getName()); + + private final static String TEST_NAME1 = "GetCategoricalMaskTest"; + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeApplyTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"})); + } + + @Test + public void testRecode() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 1, 1.0); + String spec = "{\"ids\": true, \"recode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testRecode2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8, ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 2, new double[] {0, 1}); + + String spec = "{\"ids\": true, \"recode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {0, 1, 1, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {1, 1, 1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + private void runTransformTest(FrameBlock fb, String spec, MatrixBlock expected) throws Exception { + try { + + getAndLoadTestConfiguration(TEST_NAME1); + + String inF = input("F-In"); + String inS = input("spec"); + + TestUtils.writeTestFrame(inF, fb, fb.getSchema(), FileFormat.CSV); + TestUtils.writeTestScalar(input("spec"), spec); + + String out = output("ret"); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-args", inF, inS, out}; + + runTest(true, false, null, -1); + + MatrixBlock result = TestUtils.readBinary(out); + + TestUtils.compareMatrices(expected, result, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + +} diff --git a/src/test/scripts/functions/misc/ToString12.dml b/src/test/scripts/functions/misc/ToString12.dml new file mode 100644 index 00000000000..4f120630b75 --- /dev/null +++ b/src/test/scripts/functions/misc/ToString12.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix($2, rows=1, cols=1) +str = toString(X, rows=3, cols=3, decimal=$3) +write(str, $1) diff --git a/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml new file mode 100644 index 00000000000..7e3f098c2fb --- /dev/null +++ b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($1, data_type="frame", format="csv"); + +jspec = read($2, data_type="scalar", value_type="string"); + +[X, M] = transformencode(target=F1, spec=jspec); + +Cm = getCategoricalMask(M, jspec) + +write(Cm, $3, format="csv"); +