[WIP] JAVA 17 BWARE COMMIT #2157

Open · wants to merge 5 commits into base: main
3 changes: 1 addition & 2 deletions src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -439,8 +439,7 @@ private boolean isApplicableForTransitiveSparkExecType(boolean left)
|| (left && !isLeftTransposeRewriteApplicable(true)))
&& getInput(index).getParent().size()==1 //bagg is only parent
&& !getInput(index).areDimsBelowThreshold()
&& (getInput(index).optFindExecType() == ExecType.SPARK
|| (getInput(index) instanceof DataOp && ((DataOp)getInput(index)).hasOnlyRDD()))
&& getInput(index).hasSparkOutput()
&& getInput(index).getOutputMemEstimate()>getOutputMemEstimate();
}

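The condensed predicate is easier to scan as a standalone function. Below is a hedged restatement of the decision (parameter names are illustrative, not SystemDS API); the folded-in helper hasSparkOutput() is defined in Hop.java further down and also treats RDD-backed DataOps as Spark-resident:

// Hedged sketch, not SystemDS API: the transitive-Spark decision in words.
static boolean pushAggBinaryToSpark(boolean inputHasSparkOutput, boolean isOnlyParent,
    boolean dimsBelowThreshold, double inputMemEstimate, double outputMemEstimate) {
  return isOnlyParent                          // bagg is the input's only consumer
      && !dimsBelowThreshold                   // the input is not trivially small
      && inputHasSparkOutput                   // SPARK op or RDD-backed DataOp
      && inputMemEstimate > outputMemEstimate; // the aggregation shrinks the data
}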
38 changes: 24 additions & 14 deletions src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -747,8 +747,8 @@ protected ExecType optFindExecType(boolean transitive) {

checkAndSetForcedPlatform();

DataType dt1 = getInput().get(0).getDataType();
DataType dt2 = getInput().get(1).getDataType();
final DataType dt1 = getInput(0).getDataType();
final DataType dt2 = getInput(1).getDataType();

if( _etypeForced != null ) {
setExecType(_etypeForced);
@@ -796,18 +796,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) {
checkAndSetInvalidCPDimsAndSize();
}

//spark-specific decision refinement (execute unary scalar w/ spark input and
// spark-specific decision refinement (execute unary scalar w/ spark input and
// single parent also in spark because it's likely cheap and reduces intermediates)
if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED &&
getDataType().isMatrix() // output should be a matrix
&& (dt1.isScalar() || dt2.isScalar()) // one side should be scalar
&& supportsMatrixScalarOperations() // scalar operations
&& !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint
&& getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent
&& !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec
&& getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) {
// pull unary scalar operation into spark
_etype = ExecType.SPARK;
if(transitive // allow transitive Spark operations, i.e., continue sequences of Spark operations
&& _etype == ExecType.CP // The instruction is currently in CP
&& _etypeForced != ExecType.CP // not forced CP
&& _etypeForced != ExecType.FED // not federated
&& (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame
) {
final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize();
final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize();
final boolean left = v1; // left side is the vector or scalar
final Hop sparkIn = getInput(left ? 1 : 0);
if((v1 ^ v2) // XOR: exactly one side may be a vector or a scalar
&& (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation
&& sparkIn.getParent().size() == 1 // only one parent
&& !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec
&& sparkIn.optFindExecType() == ExecType.SPARK // input is a Spark op
&& !(sparkIn instanceof DataOp) // input is not checkpoint
) {
// pull operation into spark
_etype = ExecType.SPARK;
}
}

if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE &&
@@ -837,7 +847,7 @@ else if( (op == OpOp2.CBIND && getDataType().isList())
|| (op == OpOp2.RBIND && getDataType().isList())) {
_etype = ExecType.CP;
}

//mark for recompile (forever)
setRequiresRecompileIfNecessary();

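The refinement above generalizes the old scalar-only rule: exactly one input may be a scalar or a sub-blocksize vector, and the remaining Spark-resident input then pulls the whole binary operation into Spark. A minimal sketch of the guard under illustrative names (not SystemDS API):

// Hedged sketch of the pull-into-Spark guard; names are illustrative.
static boolean pullBinaryIntoSpark(boolean leftSmall, boolean rightSmall,
    boolean bigSideIsSparkOp, boolean bigSideHasOneParent, boolean bigSideIsDataOp) {
  return (leftSmall ^ rightSmall)  // exactly one scalar/small-vector side
      && bigSideIsSparkOp          // the large side already produces a Spark RDD
      && bigSideHasOneParent       // this op is its only consumer
      && !bigSideIsDataOp;         // do not pull past a checkpoint/read
}

Compared to the removed code, the notable extensions are frame outputs and OpOp2.APPLY_SCHEMA as a supported operation.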
11 changes: 11 additions & 0 deletions src/main/java/org/apache/sysds/hops/Hop.java
@@ -1040,6 +1040,12 @@ public final String toString() {
// ========================================================================================


protected boolean isScalarOrVectorBellowBlockSize(){
return getDataType().isScalar() || (dimsKnown() &&
(( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize())
|| (_dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())));
}

protected boolean isVector() {
return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) );
}
@@ -1624,6 +1630,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) {
lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this));
}

protected boolean hasSparkOutput(){
return (this.optFindExecType() == ExecType.SPARK
|| (this instanceof DataOp && ((DataOp)this).hasOnlyRDD()));
}

/**
* Set parse information.
*
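For intuition on the two new helpers: the size predicate accepts scalars and row/column vectors strictly shorter than one block, and hasSparkOutput() treats an RDD-backed DataOp like a Spark op. A standalone restatement of the former, assuming dimensions are known and the default blocksize of 1000 (names illustrative, not SystemDS API):

// Hedged sketch of isScalarOrVectorBellowBlockSize for known dimensions.
static boolean scalarOrVectorBelowBlocksize(boolean scalar, long rows, long cols, int blocksize) {
  return scalar
      || (rows == 1 && cols < blocksize)   // row vector within a single block
      || (cols == 1 && rows < blocksize);  // column vector within a single block
}
// With blocksize 1000: a 1x999 row vector qualifies, a 1x1000 one does not,
// and a 2x500 matrix never does (neither dimension is 1).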
34 changes: 24 additions & 10 deletions src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
} else {
sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
}
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);

if(getDataType() == DataType.FRAME)
return OptimizerUtils.estimateSizeExactFrame(dim1, dim2);
else
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}

@Override
@@ -463,6 +467,13 @@ public boolean isMetadataOperation() {
|| _op == OpOp1.CAST_AS_LIST;
}

private boolean isDisallowedSparkOps(){
return isCumulativeUnaryOperation()
|| isCastUnaryOperation()
|| _op==OpOp1.MEDIAN
|| _op==OpOp1.IQM;
}

@Override
protected ExecType optFindExecType(boolean transitive)
{
@@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto
checkAndSetInvalidCPDimsAndSize();
}


//spark-specific decision refinement (execute unary w/ spark input and
//single parent also in spark because it's likely cheap and reduces intermediates)
if( _etype == ExecType.CP && _etypeForced != ExecType.CP
&& getInput().get(0).optFindExecType() == ExecType.SPARK
&& getDataType().isMatrix()
&& !isCumulativeUnaryOperation() && !isCastUnaryOperation()
&& _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM
&& !(getInput().get(0) instanceof DataOp) //input is not checkpoint
&& getInput().get(0).getParent().size()==1 ) //unary is only parent
{
if(_etype == ExecType.CP // currently CP instruction
&& _etype != ExecType.SPARK // currently not Spark
&& _etypeForced != ExecType.CP // not forced as CP instruction
&& getInput(0).hasSparkOutput() // input is a spark instruction
&& (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame
&& !isDisallowedSparkOps() // not a disallowed Spark operation
// && !(getInput().get(0) instanceof DataOp) // input is not checkpoint
// && getInput(0).getParent().size() <= 1// unary is only parent
) {
//pull unary operation into spark
_etype = ExecType.SPARK;
}


//mark for recompile (forever)
setRequiresRecompileIfNecessary();
@@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent
} else {
setRequiresRecompileIfNecessary();
}

return _etype;
}

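The exclusions now grouped in isDisallowedSparkOps() share a rationale: cumulative aggregates and the order statistics MEDIAN/IQM need global, multi-pass computation, and casts change the representation, so they stay in CP even with a Spark-resident input. A hedged sketch of how the refined guard plays out (names illustrative, not SystemDS API):

// Hedged sketch: the unary pull-into-Spark decision in words.
static boolean pullUnaryIntoSpark(boolean inputHasSparkOutput,
    boolean outputIsMatrixOrFrame, boolean disallowedOp) {
  return inputHasSparkOutput    // keep the pipeline in Spark
      && outputIsMatrixOrFrame  // scalar results are cheap to compute in CP
      && !disallowedOp;         // e.g. cumsum or median stays in CP
}
// e.g. round(X) on a Spark-resident X is pulled into Spark;
// median(X) is not, despite the same Spark-resident input.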
@@ -58,8 +58,8 @@
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -611,14 +611,6 @@ public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixVal
public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType tstype, int k) {
// check for transpose type
if(tstype == MMTSJType.LEFT) {
if(isEmpty())
return new MatrixBlock(clen, clen, true);
// create output matrix block
if(out == null)
out = new MatrixBlock(clen, clen, false);
else
out.reset(clen, clen, false);
out.allocateDenseBlock();
CLALibTSMM.leftMultByTransposeSelf(this, out, k);
return out;
}
@@ -1202,8 +1194,8 @@ public void examSparsity(boolean allowCSR, int k) {
}

@Override
public void sparseToDense(int k) {
// do nothing
public MatrixBlock sparseToDense(int k) {
return this; // do nothing
}

@Override
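Note the signature change: sparseToDense(int k) now returns a MatrixBlock rather than void, and the compressed override simply returns the same instance. A hedged usage sketch (call site illustrative):

// For a CompressedMatrixBlock this is an identity; no conversion happens,
// but callers can now chain on the returned block.
MatrixBlock mb2 = cmb.sparseToDense(k); // mb2 == cmb, still compressed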
@@ -30,9 +30,11 @@
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.utils.stats.Timing;

@@ -95,6 +97,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
if(x.isEmpty())
return returnEmpty(x, out);

if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns() > 30) {
MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k);
return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k));
}

// Morph the columns to efficient types for the operation.
x = filterColGroups(x);
double preFilterTime = t.stop();
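The new branch reroutes the XtXv chain t(X) %*% (X %*% v) through tsmm: with few column groups, X^T X is cheap to form in compressed space, and (X^T X) %*% v is then a small dense multiply. The thresholds (fewer than 5 column groups, more than 30 columns) are heuristics introduced by this commit. A tiny plain-Java check of the underlying associativity, independent of SystemDS:

// Verifies t(X) %*% (X %*% v) == (t(X) %*% X) %*% v on a 3x2 example.
public class XtXvIdentity {
  public static void main(String[] args) {
    double[][] X = {{1, 2}, {3, 4}, {5, 6}};
    double[] v = {0.5, -1};
    // chain order: w = X v, then r1 = X^T w
    double[] w = new double[3];
    for(int i = 0; i < 3; i++)
      for(int j = 0; j < 2; j++)
        w[i] += X[i][j] * v[j];
    double[] r1 = new double[2];
    for(int j = 0; j < 2; j++)
      for(int i = 0; i < 3; i++)
        r1[j] += X[i][j] * w[i];
    // tsmm order: G = X^T X, then r2 = G v
    double[][] G = new double[2][2];
    for(int a = 0; a < 2; a++)
      for(int b = 0; b < 2; b++)
        for(int i = 0; i < 3; i++)
          G[a][b] += X[i][a] * X[i][b];
    double[] r2 = new double[2];
    for(int a = 0; a < 2; a++)
      for(int b = 0; b < 2; b++)
        r2[a] += G[a][b] * v[b];
    // both print [-26.5, -34.0]
    System.out.println(java.util.Arrays.toString(r1));
    System.out.println(java.util.Arrays.toString(r2));
  }
}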
@@ -42,6 +42,10 @@ private CLALibTSMM() {
// private constructor
}

public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) {
return leftMultByTransposeSelf(cmb, new MatrixBlock(), k);
}

/**
* Self left Matrix multiplication (tsmm)
*
@@ -51,17 +55,25 @@ private CLALibTSMM() {
* @param ret The output matrix to put the result into
* @param k The parallelization degree allowed
*/
public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {

final int numColumns = cmb.getNumColumns();
final int numRows = cmb.getNumRows();
if(cmb.isEmpty())
return new MatrixBlock(numColumns, numColumns, true);
// create output matrix block
if(ret == null)
ret = new MatrixBlock(numColumns, numColumns, false);
else
ret.reset(numColumns, numColumns, false);
ret.allocateDenseBlock();
final List<AColGroup> groups = cmb.getColGroups();

final int numColumns = cmb.getNumColumns();
if(groups.size() >= numColumns) {
MatrixBlock m = cmb.getUncompressed("TSMM too many column groups", k);
LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k);
return;
return ret;
}
final int numRows = cmb.getNumRows();
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
final boolean overlapping = cmb.isOverlapping();
if(shouldFilter) {
@@ -77,6 +89,7 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc

ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret));
ret.examSparsity();
return ret;
}

private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlock result, int nRows, int nCols,
Expand All @@ -86,8 +99,6 @@ private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlo
addCorrectionLayer(constV, filteredColSum, nRows, retV);
}



private static void tsmmColGroups(List<AColGroup> groups, MatrixBlock ret, int nRows, boolean overlapping, int k) {
if(k <= 1)
tsmmColGroupsSingleThread(groups, ret, nRows);
@@ -136,12 +147,12 @@ private static void tsmmColGroupsMultiThread(List<AColGroup> groups, MatrixBlock

public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) {
final int nColRow = constV.length;
for(int row = 0; row < nColRow; row++){
for(int row = 0; row < nColRow; row++) {
int offOut = nColRow * row;
final double v1l = constV[row];
final double v2l = filteredColSum[row] + constV[row] * nRow;
for(int col = row; col < nColRow; col++){
ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
for(int col = row; col < nColRow; col++) {
ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
}
}
}
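A note on the correction layer, assuming the usual pre-filtering convention: let F be the data with per-column constants removed, c (constV) the removed constants, s (filteredColSum) the column sums of F, and n (nRow) the number of rows. The tsmm of the original X = F + 1 c^T then decomposes as

$$X^\top X = (F + \mathbf{1}c^\top)^\top(F + \mathbf{1}c^\top) = F^\top F + s\,c^\top + c\,s^\top + n\,c\,c^\top, \qquad s = F^\top\mathbf{1}.$$

For the upper triangle (col >= row) the loop adds exactly c_row * s_col + (s_row + n * c_row) * c_col, which are the v1l and v2l terms, and copyUpperToLowerTriangle mirrors the result by symmetry.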
@@ -599,7 +599,6 @@ public void execute(ExecutionContext ec)
//OPTIMIZATION of ParFOR body (incl all child parfor PBs)
///////
if( _optMode != POptMode.NONE ) {
OptimizationWrapper.setLogLevel(_optLogLevel); //set optimizer log level
OptimizationWrapper.optimize(_optMode, sb, this, ec, _numRuns); //core optimize
}

@@ -30,6 +30,8 @@
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class FederatedWorkloadAnalyzer {
protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName());
@@ -89,6 +91,12 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, Instr
counter++;
}

} else if(cpIns instanceof MMChainCPInstruction) {
final String n1 = cpIns.input1.getName();
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(1);
getOrMakeCounter(mm, Long.parseLong(n1)).incLMM(1);
counter++;

}
}

@@ -24,8 +24,6 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
@@ -109,11 +107,6 @@ public static void optimize( POptMode type, ParForStatementBlock sb, ParForProgr
}
}

public static void setLogLevel( Level optLogLevel ) {
Logger.getLogger("org.apache.sysds.runtime.controlprogram.parfor.opt")
.setLevel( optLogLevel );
}

private static void optimize( POptMode otype, int ck, double cm,
ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, int numRuns )
{