
Commit 25358eb

[SPARK-51340][ML][CONNECT] Model size estimation
### What changes were proposed in this pull request?

Implement model size estimation. Two new interfaces are added:

1. `Estimator.estimateModelSize`: estimates the model size **before training**. This is an optional interface; it is implemented for linear classification and regression models. For other models, such as tree models, estimating the size before training is hard, so it is not implemented there.
2. `Model.estimatedSize`: estimates the model size in **local process memory**. The default implementation uses `SizeEstimator.estimate`; models can override it for a more precise estimate.

### Why are the changes needed?

For Spark Connect, we want to support model size control on the Spark Connect server side, so this change is needed.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

No.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50278 from WeichenXu123/SPARK-51340-2.

Lead-authored-by: Ruifeng Zheng <ruifengz@apache.org>
Co-authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
1 parent 37d191e commit 25358eb

23 files changed (+749, -11 lines)
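
Taken together, the two methods give the Spark Connect server a size check both before and after training. Below is a minimal sketch of the intended call pattern (hypothetical helper; both methods are `private[spark]`, so code like this only compiles inside Spark's own packages):

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.sql.DataFrame

// Hypothetical driver-side guard; `training` needs "label" and "features" columns.
def fitWithSizeLimit(training: DataFrame, maxBytes: Long): LogisticRegressionModel = {
  val lr = new LogisticRegression()

  // Before training: an upper-bound estimate from the params and the dataset.
  // Estimators that cannot estimate ahead of time (e.g. tree-based ones)
  // throw UnsupportedOperationException here.
  val estimated = lr.estimateModelSize(training)
  require(estimated <= maxBytes, s"model would need ~$estimated bytes")

  val model = lr.fit(training)
  // After training: approximate driver-side size of the fitted model.
  println(s"estimated=$estimated, actual=${model.estimatedSize}")
  model
}
```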

mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala

Lines changed: 28 additions & 2 deletions
```diff
@@ -191,8 +191,7 @@ sealed trait Vector extends Serializable {
   def compressed: Vector = compressedWithNNZ(numNonzeros)
 
   private[ml] def compressedWithNNZ(nnz: Int): Vector = {
-    // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes.
-    if (1.5 * (nnz + 1.0) < size) {
+    if (Vectors.getSparseSize(nnz) < Vectors.getDenseSize(size)) {
       toSparseWithSize(nnz)
     } else {
       toDense
@@ -230,6 +229,8 @@
    */
   private[spark] def nonZeroIterator: Iterator[(Int, Double)] =
     activeIterator.filter(_._2 != 0)
+
+  private[ml] def getSizeInBytes: Long
 }
 
 /**
@@ -504,6 +505,27 @@ object Vectors {
 
   /** Max number of nonzero entries used in computing hash code. */
   private[linalg] val MAX_HASH_NNZ = 128
+
+  private[ml] def getSparseSize(nnz: Long): Long = {
+    /*
+      A sparse vector stores one double array, one int array and one int:
+      8 * values.length + 4 * values.length + arrayHeader * 2 + 4
+     */
+    val doubleBytes = java.lang.Double.BYTES
+    val intBytes = java.lang.Integer.BYTES
+    val arrayHeader = 12L
+    (doubleBytes + intBytes) * nnz + arrayHeader * 2L + 4L
+  }
+
+  private[ml] def getDenseSize(size: Long): Long = {
+    /*
+      A dense vector stores one double array:
+      8 * values.length + arrayHeader
+     */
+    val doubleBytes = java.lang.Double.BYTES
+    val arrayHeader = 12L
+    doubleBytes * size + arrayHeader
+  }
 }
 
 /**
@@ -596,6 +618,8 @@ class DenseVector @Since("2.0.0") ( @Since("2.0.0") val values: Array[Double]) e
 
   private[spark] override def activeIterator: Iterator[(Int, Double)] =
     iterator
+
+  override private[ml] def getSizeInBytes: Long = Vectors.getDenseSize(values.length)
 }
 
 @Since("2.0.0")
@@ -845,6 +869,8 @@ class SparseVector @Since("2.0.0") (
     val localValues = values
     Iterator.tabulate(numActives)(j => (localIndices(j), localValues(j)))
   }
+
+  override private[ml] def getSizeInBytes: Long = Vectors.getSparseSize(values.length)
 }
 
 @Since("2.0.0")
```
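
A quick sanity check on the new threshold (my arithmetic, not part of the patch): for `size = 100`, `getDenseSize(100) = 8 * 100 + 12 = 812` bytes and `getSparseSize(nnz) = 12 * nnz + 28` bytes, so `compressedWithNNZ` picks the sparse form when `nnz <= 65`. That is the same cut-over the old `1.5 * (nnz + 1.0) < size` heuristic gave for this size, but now derived from actual byte counts:

```scala
import org.apache.spark.ml.linalg.Vectors

// getSparseSize / getDenseSize are private[ml]; the comments below just
// apply their formulas by hand.
// getDenseSize(100) = 8 * 100 + 12       = 812 bytes
// getSparseSize(65) = 12 * 65 + 2*12 + 4 = 808 bytes -> sparse wins
// getSparseSize(66) = 12 * 66 + 2*12 + 4 = 820 bytes -> dense wins
val v = Vectors.dense(Array.tabulate(100)(i => if (i == 3) 1.0 else 0.0))
println(v.compressed) // sparse: only 1 of 100 entries is nonzero
```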

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 21 additions & 0 deletions
```diff
@@ -81,4 +81,25 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
   }
 
   override def copy(extra: ParamMap): Estimator[M]
+
+  /**
+   * For ml connect only.
+   * Estimate an upper-bound size of the model to be fitted in bytes, based on the
+   * parameters and the dataset, e.g., using $(k) and numFeatures to estimate a
+   * k-means model size.
+   * 1, Only driver side memory usage is counted, distributed objects (like DataFrame,
+   * RDD, Graph, Summary) are ignored.
+   * 2, Lazy vals are not counted, e.g., an auxiliary object used in prediction.
+   * 3, If there is not enough information to get an accurate size, try to estimate the
+   * upper-bound size, e.g.
+   * - Given a LogisticRegression estimator, assume the coefficients are dense, even
+   * though the actual fitted model might be sparse (by L1 penalty).
+   * - Given a tree model, assume all underlying trees are complete binary trees, even
+   * though some branches might be pruned or truncated.
+   * 4, For some models, such as tree models, estimating the model size before training
+   * is hard, so the `estimateModelSize` method is not supported.
+   */
+  private[spark] def estimateModelSize(dataset: Dataset[_]): Long = {
+    throw new UnsupportedOperationException
+  }
 }
```
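
The doc's k-means illustration would look roughly like the following under these rules (a sketch of the pattern, not code from this excerpt; the actual KMeans change is in one of the 13 files not shown here):

```scala
// Hypothetical sketch inside a KMeans-like estimator:
private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
  val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
  var size = this.estimateMatadataSize
  // $(k) cluster centers, each a dense vector of numFeatures doubles
  size += Vectors.getDenseSize(numFeatures) * $(k)
  size
}
```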

mllib/src/main/scala/org/apache/spark/ml/Model.scala

Lines changed: 16 additions & 1 deletion
```diff
@@ -18,13 +18,14 @@
 package org.apache.spark.ml
 
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.util.SizeEstimator
 
 /**
  * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
  *
  * @tparam M model type
  */
-abstract class Model[M <: Model[M]] extends Transformer {
+abstract class Model[M <: Model[M]] extends Transformer { self =>
   /**
    * The parent estimator that produced this model.
    * @note For ensembles' component Models, this value can be null.
@@ -43,4 +44,18 @@ abstract class Model[M <: Model[M]] extends Transformer {
   def hasParent: Boolean = parent != null
 
   override def copy(extra: ParamMap): M
+
+  /**
+   * For ml connect only.
+   * Estimate the size of this model in bytes.
+   * This is an approximation; the real size might be different.
+   * 1, Only driver side memory usage is counted, distributed objects (like DataFrame,
+   * RDD, Graph, Summary) are ignored.
+   * 2, Lazy vals are not counted, e.g., an auxiliary object used in prediction.
+   * 3, The default implementation uses `org.apache.spark.util.SizeEstimator.estimate`;
+   * some models override the default implementation to achieve more precise estimation.
+   * 4, For third-party extensions, if external languages are used, it is recommended to
+   * override this method and return a proper size.
+   */
+  private[spark] def estimatedSize: Long = SizeEstimator.estimate(self)
 }
```
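
The `self =>` alias lets the default body hand the concrete model instance to `SizeEstimator.estimate`, which walks the object graph reflectively; that is convenient, but it can over-count shared state and is slower than the closed-form overrides the files below add. A small runnable feel for what the default does (the exact byte count varies with JVM settings):

```scala
import org.apache.spark.util.SizeEstimator

// The default Model.estimatedSize is SizeEstimator.estimate(model);
// e.g. a bare array of 1000 doubles measures ~8016 bytes on a 64-bit JVM.
println(SizeEstimator.estimate(Array.fill(1000)(0.0)))
```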

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

Lines changed: 20 additions & 0 deletions
```diff
@@ -237,6 +237,15 @@
 
   @Since("3.0.0")
   override def copy(extra: ParamMap): FMClassifier = defaultCopy(extra)
+
+  override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
+
+    var size = this.estimateMatadataSize
+    size += Vectors.getDenseSize(numFeatures) // linear
+    size += Matrices.getDenseSize(numFeatures, $(factorSize)) // factors
+    size
+  }
 }
 
 @Since("3.0.0")
@@ -312,6 +321,17 @@
     copyValues(new FMClassificationModel(uid, intercept, linear, factors), extra)
   }
 
+  override def estimatedSize: Long = {
+    var size = this.estimateMatadataSize
+    if (this.linear != null) {
+      size += this.linear.getSizeInBytes
+    }
+    if (this.factors != null) {
+      size += this.factors.getSizeInBytes
+    }
+    size
+  }
+
   @Since("3.0.0")
   override def write: MLWriter =
     new FMClassificationModel.FMClassificationModelWriter(this)
```
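
Rough numbers for the FM estimate (my arithmetic, assuming `Matrices.getDenseSize` mirrors the vector formula; Matrices.scala is changed by this commit but not shown above): with `numFeatures = 1000` and the default `factorSize = 8`, the linear term costs about 8 KB and the factor matrix about 64 KB, plus the params metadata:

```scala
// Hypothetical sizing, numFeatures = 1000, factorSize = 8:
val numFeatures = 1000L
val factorSize = 8L
val linear = 8L * numFeatures + 12L               // dense vector: 8012 bytes
val factors = 8L * numFeatures * factorSize + 12L // dense matrix: 64012 bytes
println(linear + factors)                         // 72024 bytes + metadata
```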

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 16 additions & 1 deletion
```diff
@@ -168,6 +168,13 @@
   @Since("3.1.0")
   def setMaxBlockSizeInMB(value: Double): this.type = set(maxBlockSizeInMB, value)
 
+  private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
+    var size = this.estimateMatadataSize
+    size += Vectors.getDenseSize(numFeatures) // coefficients
+    size
+  }
+
   @Since("2.2.0")
   override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
 
@@ -259,7 +266,7 @@
         if (featuresStd(i) != 0.0) rawCoefficients(i) / featuresStd(i) else 0.0
       }
       val intercept = if ($(fitIntercept)) rawCoefficients.last else 0.0
-      createModel(dataset, Vectors.dense(coefficientArray), intercept, objectiveHistory)
+      createModel(dataset, Vectors.dense(coefficientArray).compressed, intercept, objectiveHistory)
   }
 
   private def createModel(
@@ -421,6 +428,14 @@
     copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)
   }
 
+  private[spark] override def estimatedSize: Long = {
+    var size = this.estimateMatadataSize
+    if (this.coefficients != null) {
+      size += this.coefficients.getSizeInBytes
+    }
+    size
+  }
+
   @Since("2.2.0")
   override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
```
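
The `.compressed` change interacts with the estimator above: `estimateModelSize` assumes dense coefficients, while the fitted model may now store them sparse, so the pre-training figure is an upper bound, exactly item 3 of the `Estimator` scaladoc. A small illustration (hypothetical numbers):

```scala
import org.apache.spark.ml.linalg.Vectors

// 1000 coefficients, only 10 nonzero after fitting:
val raw = Array.tabulate(1000)(i => if (i % 100 == 0) 1.0 else 0.0)
// dense upper bound: 8 * 1000 + 12 = 8012 bytes
// stored as sparse : 12 * 10 + 28  =  148 bytes
println(Vectors.dense(raw).compressed) // a SparseVector
```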

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 28 additions & 1 deletion
```diff
@@ -45,7 +45,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.VersionUtils
+import org.apache.spark.util._
 
 /**
  * Params for logistic regression.
@@ -1041,6 +1041,22 @@
     (solution, arrayBuilder.result())
   }
 
+  private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
+    // TODO: get numClasses and numFeatures together from dataset
+    val numClasses = DatasetUtils.getNumClasses(dataset, $(labelCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
+
+    var size = this.estimateMatadataSize
+    if (checkMultinomial(numClasses)) {
+      size += Matrices.getDenseSize(numFeatures, numClasses) // coefficientMatrix
+      size += Vectors.getDenseSize(numClasses) // interceptVector
+    } else {
+      size += Matrices.getDenseSize(numFeatures, 1) // coefficientMatrix
+      size += Vectors.getDenseSize(1) // interceptVector
+    }
+    size
+  }
+
   @Since("1.4.0")
   override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
 }
@@ -1248,6 +1264,17 @@
     }
   }
 
+  private[spark] override def estimatedSize: Long = {
+    var size = this.estimateMatadataSize
+    if (this.coefficientMatrix != null) {
+      size += this.coefficientMatrix.getSizeInBytes
+    }
+    if (this.interceptVector != null) {
+      size += this.interceptVector.getSizeInBytes
+    }
+    size
+  }
+
   @Since("1.4.0")
   override def copy(extra: ParamMap): LogisticRegressionModel = {
     val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
```
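
For concreteness (my arithmetic, same caveat about `Matrices.getDenseSize` as above): with `numFeatures = 1000`, the binomial branch budgets a 1000 x 1 coefficient matrix plus one intercept (about 8 KB), while a 10-class multinomial budgets 1000 x 10 coefficients plus 10 intercepts (about 80 KB). Both are dense upper bounds even if L1 regularization later sparsifies `coefficientMatrix`:

```scala
// Hypothetical sizing for numFeatures = 1000:
val binomial = 8L * 1000 * 1 + 12L + (8L * 1 + 12L)        // 8032 bytes
val multinomial10 = 8L * 1000 * 10 + 12L + (8L * 10 + 12L) // 80104 bytes
println((binomial, multinomial10))
```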

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 17 additions & 0 deletions
```diff
@@ -173,6 +173,15 @@
   @Since("1.5.0")
   override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
 
+  private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val topology = FeedForwardTopology.multiLayerPerceptron($(layers), softmaxOnTop = true)
+    val expectedWeightSize = topology.layers.map(_.weightSize).sum
+
+    var size = this.estimateMatadataSize
+    size += Vectors.getDenseSize(expectedWeightSize) // weights
+    size
+  }
+
   /**
    * Train a model using the given dataset and parameters.
    * Developers can implement this instead of `fit()` to avoid dealing with schema validation
@@ -328,6 +337,14 @@
     copyValues(copied, extra)
   }
 
+  private[spark] override def estimatedSize: Long = {
+    var size = this.estimateMatadataSize
+    if (this.weights != null) {
+      size += this.weights.getSizeInBytes
+    }
+    size
+  }
+
   @Since("2.0.0")
   override def write: MLWriter =
     new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
```
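
The MLP estimate depends only on the topology, not the data. My example below assumes that only the affine layers carry weights, which is how `FeedForwardTopology.multiLayerPerceptron` builds the network: each affine layer contributes `numIn * numOut` weights plus `numOut` biases.

```scala
// Hypothetical layer spec: 784 inputs, one hidden layer of 100, 10 classes.
val layers = Array(784, 100, 10)
val weightCount = layers.sliding(2).map { case Array(in, out) => in * out + out }.sum
println(weightCount)            // 79510 doubles
println(8L * weightCount + 12L) // 636092 bytes as one dense weight vector
```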

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 30 additions & 0 deletions
```diff
@@ -344,6 +344,22 @@
     new NaiveBayesModel(uid, pi.compressed, theta.compressed, sigma.compressed)
   }
 
+  private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val numClasses = DatasetUtils.getNumClasses(dataset, $(labelCol))
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
+
+    var size = this.estimateMatadataSize
+    size += Vectors.getDenseSize(numClasses) // pi
+    size += Matrices.getDenseSize(numClasses, numFeatures) // theta
+    $(modelType) match {
+      case Multinomial | Bernoulli | Complement =>
+        size += Matrices.getDenseSize(0, 0) // sigma
+      case _ =>
+        size += Matrices.getDenseSize(numClasses, numFeatures) // sigma
+    }
+    size
+  }
+
   @Since("1.5.0")
   override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
 }
@@ -551,6 +567,20 @@
     }
   }
 
+  private[spark] override def estimatedSize: Long = {
+    var size = this.estimateMatadataSize
+    if (this.pi != null) {
+      size += this.pi.getSizeInBytes
+    }
+    if (this.theta != null) {
+      size += this.theta.getSizeInBytes
+    }
+    if (this.sigma != null) {
+      size += this.sigma.getSizeInBytes
+    }
+    size
+  }
+
   @Since("1.5.0")
   override def copy(extra: ParamMap): NaiveBayesModel = {
     copyValues(new NaiveBayesModel(uid, pi, theta, sigma).setParent(this.parent), extra)
```
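
The `modelType` match encodes that only Gaussian naive Bayes materializes a variance matrix: Multinomial, Bernoulli, and Complement fit an empty `sigma`, so only the header of a 0 x 0 matrix is budgeted, while the fall-through (Gaussian) pays for a full `numClasses x numFeatures` matrix. With 3 classes and 1000 features that roughly doubles the estimate (my arithmetic, same `Matrices.getDenseSize` assumption as above):

```scala
// Hypothetical: numClasses = 3, numFeatures = 1000.
val pi = 8L * 3 + 12L           // 36 bytes
val theta = 8L * 3 * 1000 + 12L // 24012 bytes
println(pi + theta)             // 24048 bytes: Multinomial/Bernoulli/Complement
println(pi + theta + theta)     // 48060 bytes: Gaussian (sigma same shape as theta)
```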

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 5 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.SizeEstimator
 
 /**
  * A param with self-contained documentation and optionally default value. Primitive-typed param
@@ -647,6 +648,10 @@ case class ParamPair[T] @Since("1.2.0") (
  */
 trait Params extends Identifiable with Serializable {
 
+  private[ml] def estimateMatadataSize: Long = {
+    SizeEstimator.estimate((this.paramMap, this.defaultParamMap, this.uid))
+  }
+
   /**
    * Returns all params sorted by their names. The default implementation uses Java reflection to
    * list all public methods that have no arguments and return [[Param]].
```
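
`estimateMatadataSize` prices the bookkeeping every `Params` instance carries, namely the explicit param map, the default param map, and the uid, so the per-model overrides above only add their coefficient structures on top of it. A rough stand-in for what it measures (hypothetical values; the real fields are `private[ml]`):

```scala
import org.apache.spark.util.SizeEstimator

// Approximates estimateMatadataSize with plain stand-in values:
val approx = SizeEstimator.estimate(
  (Map("maxIter" -> 10), Map("regParam" -> 0.0), "logreg_4b2c9a8d01f7"))
println(approx) // typically a few hundred bytes
```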

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 16 additions & 0 deletions
```diff
@@ -350,6 +350,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
+
+  private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
+
+    var size = this.estimateMatadataSize
+    size += Vectors.getDenseSize(numFeatures) // coefficients
+    size
+  }
 }
 
 @Since("1.6.0")
@@ -469,6 +477,14 @@
       .setParent(parent)
   }
 
+  private[spark] override def estimatedSize: Long = {
+    var size = this.estimateMatadataSize
+    if (this.coefficients != null) {
+      size += this.coefficients.getSizeInBytes
+    }
+    size
+  }
+
   @Since("1.6.0")
   override def write: MLWriter =
     new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
```
