diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 8d2b00c1aa8..12ecdb9c9fa 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -562,7 +562,7 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) else //e.g., for append,pow or after inference sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); - ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); } return ret; } diff --git a/src/main/java/org/apache/sysds/hops/DataGenOp.java b/src/main/java/org/apache/sysds/hops/DataGenOp.java index fd3ecb97fa2..b01d6ed3817 100644 --- a/src/main/java/org/apache/sysds/hops/DataGenOp.java +++ b/src/main/java/org/apache/sysds/hops/DataGenOp.java @@ -219,11 +219,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) else { //sparsity-aware estimation (dependent on sparse generation approach); for pure dense generation //we would need to disable sparsity-awareness and estimate via sparsity=1.0 - ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, _sparsity); + ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, _sparsity, getDataType()); } } else { - ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0); + ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0, getDataType()); } return ret; diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index 7be61f4129b..1ae8616001b 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -390,7 +390,7 @@ else if(dt == DataType.FRAME) { || _op == OpOpData.TRANSIENTREAD ) { double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); - ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); } // output memory estimate is not required for "write" nodes (just input) } diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java b/src/main/java/org/apache/sysds/hops/IndexingOp.java index 457c5b44e0a..22e0b30e0d9 100644 --- a/src/main/java/org/apache/sysds/hops/IndexingOp.java +++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java @@ -204,7 +204,7 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) { // only dense right indexing supported on GPU double sparsity = isGPUEnabled() ? 1.0 : OptimizerUtils.getSparsity(dim1, dim2, nnz); - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); } @Override diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index a3161c57230..169ecf5a4a3 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -26,6 +26,7 @@ import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.FileFormat; @@ -63,8 +64,8 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; -import org.apache.sysds.utils.stats.InfrastructureAnalyzer; import org.apache.sysds.utils.MemoryEstimates; +import org.apache.sysds.utils.stats.InfrastructureAnalyzer; public class OptimizerUtils { @@ -822,6 +823,13 @@ public static long estimateSizeExactSparsity(long nrows, long ncols, double sp) return MatrixBlock.estimateSizeInMemory(nrows,ncols,sp); } + public static long estimateSizeExactSparsity(long nrows, long ncols, double sp, DataType dt){ + if(dt == DataType.FRAME) + return estimateSizeExactFrame(nrows, ncols); + else + return estimateSizeExactSparsity(nrows, ncols, sp); + } + /** * Estimates the footprint (in bytes) for a partitioned in-memory representation of a * matrix with the given matrix characteristics diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java index 576ccaa83a1..df6b4381aeb 100644 --- a/src/main/java/org/apache/sysds/hops/ReorgOp.java +++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java @@ -232,7 +232,7 @@ public void computeMemEstimate(MemoTable memo){ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) { //no dedicated mem estimation per op type, because always propagated via refreshSizeInformation double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); } @Override diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 91f3a5ec584..34da36dd13c 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,7 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); } @Override