From 29b83b32e3358d45a2d7908ad57e393dca1a5386 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 14 Aug 2024 11:57:33 +0200 Subject: [PATCH 1/5] [DO NOT MERGE][skip ci] JAVA 17 BWARE COMMIT --- .../org/apache/sysds/hops/AggBinaryOp.java | 3 +- .../java/org/apache/sysds/hops/BinaryOp.java | 38 +++-- src/main/java/org/apache/sysds/hops/Hop.java | 11 ++ .../java/org/apache/sysds/hops/UnaryOp.java | 34 ++-- .../compress/CompressedMatrixBlock.java | 6 +- .../spark/utils/FrameRDDConverterUtils.java | 45 +++--- .../sysds/runtime/io/FrameWriterTextCSV.java | 13 +- .../sysds/runtime/io/IOUtilFunctions.java | 9 +- .../data/LibAggregateUnarySpecialization.java | 148 ++++++++++++++++++ .../runtime/matrix/data/LibMatrixMult.java | 25 ++- .../runtime/matrix/data/MatrixBlock.java | 28 ++-- .../encoding/EncodeSampleMultiColTest.java | 3 + .../encoding/EncodeSampleUnbalancedTest.java | 4 + .../compress/configuration/CompressForce.java | 2 +- 14 files changed, 278 insertions(+), 91 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index 2cf651f1894..85ce9882ecc 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -439,8 +439,7 @@ private boolean isApplicableForTransitiveSparkExecType(boolean left) || (left && !isLeftTransposeRewriteApplicable(true))) && getInput(index).getParent().size()==1 //bagg is only parent && !getInput(index).areDimsBelowThreshold() - && (getInput(index).optFindExecType() == ExecType.SPARK - || (getInput(index) instanceof DataOp && ((DataOp)getInput(index)).hasOnlyRDD())) + && getInput(index).hasSparkOutput() && getInput(index).getOutputMemEstimate()>getOutputMemEstimate(); } diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 8d2b00c1aa8..cbb154df74f 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -747,8 +747,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { setExecType(_etypeForced); @@ -796,18 +796,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && @@ -837,7 +847,7 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index b32a1a74aab..4a842c69b0f 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1040,6 +1040,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1624,6 +1630,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. * diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 91f3a5ec584..e16896b869b 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -463,6 +467,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. + && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 48637595741..fd451805a28 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -58,8 +58,8 @@ import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; -import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReorg; +import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReshape; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; @@ -1202,8 +1202,8 @@ public void examSparsity(boolean allowCSR, int k) { } @Override - public void sparseToDense(int k) { - // do nothing + public MatrixBlock sparseToDense(int k) { + return this; // do nothing } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java index 9371d43094c..a5974640cc5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java @@ -90,10 +90,7 @@ public static JavaPairRDD csvToBinaryBlock(JavaSparkContext sc JavaRDD tmp = input.values() .map(new TextToStringFunction()); String tmpStr = tmp.first(); - boolean metaHeader = tmpStr.startsWith(TfUtils.TXMTD_MVPREFIX) - || tmpStr.startsWith(TfUtils.TXMTD_NDPREFIX); - tmpStr = (metaHeader) ? tmpStr.substring(tmpStr.indexOf(delim)+1) : tmpStr; - long rlen = tmp.count() - (hasHeader ? 1 : 0) - (metaHeader ? 2 : 0); + long rlen = tmp.count() ; long clen = IOUtilFunctions.splitCSV(tmpStr, delim).length; mc.set(rlen, clen, mc.getBlocksize(), -1); } @@ -582,14 +579,14 @@ public Iterator> call(Iterator> arg0) _colnames = row.split(_delim); continue; } - if( row.startsWith(TfUtils.TXMTD_MVPREFIX) ) { - _mvMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); - continue; - } - else if( row.startsWith(TfUtils.TXMTD_NDPREFIX) ) { - _ndMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); - continue; - } + // if( row.startsWith(TfUtils.TXMTD_MVPREFIX) ) { + // _mvMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); + // continue; + // } + // else if( row.startsWith(TfUtils.TXMTD_NDPREFIX) ) { + // _ndMeta = Arrays.asList(Arrays.copyOfRange(IOUtilFunctions.splitCSV(row, _delim), 1, (int)_clen+1)); + // continue; + // } //adjust row index for header and meta data rowix += (_hasHeader ? 0 : 1) - ((_mvMeta == null) ? 0 : 2); @@ -670,18 +667,18 @@ public Iterator call(Tuple2 arg0) ret.add(sb.toString()); sb.setLength(0); //reset } - if( !blk.isColumnMetadataDefault() ) { - sb.append(TfUtils.TXMTD_MVPREFIX + _props.getDelim()); - for( int j=0; j data() { // Both Sparse and end dense joined tests.add(createT(1, 0.2, 10, 10, 0.1, 2, 1000, 1231521)); + + tests.add(createT(1, 1.0, 100, 1, 1.0, 10, 10000, 132)); + tests.add(createT(1, 1.0, 1000, 1, 1.0, 10, 10000, 132)); + return tests; } diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java index c6d52a70a51..872ec79c1f1 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java @@ -49,7 +49,7 @@ protected String getTestDir() { @Test public void testTranspose_CP() { - runTest(1500, 20, 1, 1, ExecType.CP, "transpose"); + runTest(1500, 20, 2, 1, ExecType.CP, "transpose"); } @Test From 884113c8102050b030a3d994b0de96922311f0ee Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 14 Mar 2025 13:48:19 +0100 Subject: [PATCH 2/5] remove set opt level --- .../sysds/runtime/controlprogram/ParForProgramBlock.java | 1 - .../controlprogram/parfor/opt/OptimizationWrapper.java | 7 ------- 2 files changed, 8 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index 0f00a966d49..4c492131d97 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -599,7 +599,6 @@ public void execute(ExecutionContext ec) //OPTIMIZATION of ParFOR body (incl all child parfor PBs) /////// if( _optMode != POptMode.NONE ) { - OptimizationWrapper.setLogLevel(_optLogLevel); //set optimizer log level OptimizationWrapper.optimize(_optMode, sb, this, ec, _numRuns); //core optimize } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java index e82da3367ad..22d58d50a40 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java @@ -24,8 +24,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.log4j.Level; -import org.apache.log4j.Logger; import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.OptimizerUtils; @@ -109,11 +107,6 @@ public static void optimize( POptMode type, ParForStatementBlock sb, ParForProgr } } - public static void setLogLevel( Level optLogLevel ) { - Logger.getLogger("org.apache.sysds.runtime.controlprogram.parfor.opt") - .setLevel( optLogLevel ); - } - private static void optimize( POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, int numRuns ) { From 818ab17f39dea5afa8e362e37ac5adbdd5080e80 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 14 Mar 2025 15:08:06 +0100 Subject: [PATCH 3/5] increment on mmchain --- .../federated/FederatedWorkloadAnalyzer.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java index fc0aa3b1a29..b0e37f1f8c6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java @@ -30,6 +30,8 @@ import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction; +import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; public class FederatedWorkloadAnalyzer { protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName()); @@ -89,6 +91,11 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap Date: Fri, 14 Mar 2025 15:15:38 +0100 Subject: [PATCH 4/5] increment counter on mmchain --- .../controlprogram/federated/FederatedWorkloadAnalyzer.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java index b0e37f1f8c6..eb09156db64 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java @@ -95,6 +95,7 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap Date: Fri, 14 Mar 2025 18:20:08 +0100 Subject: [PATCH 5/5] maybe faster --- .../compress/CompressedMatrixBlock.java | 8 ----- .../runtime/compress/lib/CLALibMMChain.java | 7 +++++ .../runtime/compress/lib/CLALibTSMM.java | 29 +++++++++++++------ 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index fd451805a28..5e5f43ad3db 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -611,14 +611,6 @@ public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixVal public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType tstype, int k) { // check for transpose type if(tstype == MMTSJType.LEFT) { - if(isEmpty()) - return new MatrixBlock(clen, clen, true); - // create output matrix block - if(out == null) - out = new MatrixBlock(clen, clen, false); - else - out.reset(clen, clen, false); - out.allocateDenseBlock(); CLALibTSMM.leftMultByTransposeSelf(this, out, k); return out; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index d82d58e323e..f47f76b7ef3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -30,9 +30,11 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.utils.stats.Timing; @@ -95,6 +97,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix if(x.isEmpty()) return returnEmpty(x, out); + if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){ + MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k); + return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k)); + } + // Morph the columns to efficient types for the operation. x = filterColGroups(x); double preFilterTime = t.stop(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java index a1d47a9b150..e0643572eae 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java @@ -42,6 +42,10 @@ private CLALibTSMM() { // private constructor } + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) { + return leftMultByTransposeSelf(cmb, new MatrixBlock(), k); + } + /** * Self left Matrix multiplication (tsmm) * @@ -51,17 +55,25 @@ private CLALibTSMM() { * @param ret The output matrix to put the result into * @param k The parallelization degree allowed */ - public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + final int numColumns = cmb.getNumColumns(); + final int numRows = cmb.getNumRows(); + if(cmb.isEmpty()) + return new MatrixBlock(numColumns, numColumns, true); + // create output matrix block + if(ret == null) + ret = new MatrixBlock(numColumns, numColumns, false); + else + ret.reset(numColumns, numColumns, false); + ret.allocateDenseBlock(); final List groups = cmb.getColGroups(); - final int numColumns = cmb.getNumColumns(); if(groups.size() >= numColumns) { MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k); LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k); - return; + return ret; } - final int numRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); final boolean overlapping = cmb.isOverlapping(); if(shouldFilter) { @@ -77,6 +89,7 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret)); ret.examSparsity(); + return ret; } private static void addCorrectionLayer(List filteredGroups, MatrixBlock result, int nRows, int nCols, @@ -86,8 +99,6 @@ private static void addCorrectionLayer(List filteredGroups, MatrixBlo addCorrectionLayer(constV, filteredColSum, nRows, retV); } - - private static void tsmmColGroups(List groups, MatrixBlock ret, int nRows, boolean overlapping, int k) { if(k <= 1) tsmmColGroupsSingleThread(groups, ret, nRows); @@ -136,12 +147,12 @@ private static void tsmmColGroupsMultiThread(List groups, MatrixBlock public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) { final int nColRow = constV.length; - for(int row = 0; row < nColRow; row++){ + for(int row = 0; row < nColRow; row++) { int offOut = nColRow * row; final double v1l = constV[row]; final double v2l = filteredColSum[row] + constV[row] * nRow; - for(int col = row; col < nColRow; col++){ - ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; + for(int col = row; col < nColRow; col++) { + ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; } } }