Commit 780cf7d
1. Make XGBoost partitioning always deterministic.
2. Make XGBoost always repartition the input to the number of workers.
3. Use the group column as the partition key when a group column exists, and a row hash otherwise (sketched below).
jinmfeng001 committed Aug 16, 2023
1 parent 57f812e commit 780cf7d
Showing 5 changed files with 27 additions and 63 deletions.
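
The heart of the change is the partition-key rule from item 3 of the commit message. A minimal standalone sketch of that rule, with a simplified, hypothetical signature (the real logic is attachPartitionKey in DataUtils.scala, changed below):

// Simplified form of the keying rule this commit introduces. Instances that
// carry a group id are keyed by the group, so a whole query group lands in one
// partition; ungrouped instances are keyed by a deterministic row hash.
def partitionKey(rowHash: Int, group: Option[Int], numWorkers: Int): Int =
  group match {
    case Some(g) => g % numWorkers                  // group column as key
    case None    => math.abs(rowHash % numWorkers)  // row hash as key
  }
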
PreXGBoost.scala
@@ -127,11 +127,9 @@ object PreXGBoost extends PreXGBoostProvider {
       val group = est match {
         case regressor: XGBoostRegressor =>
           // get group column, if group is not defined, default to lit(-1)
-          Some(
-            if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
-              defaultGroupColumn
-            } else col(regressor.getGroupCol)
-          )
+          if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
+            None
+          } else Some(col(regressor.getGroupCol))
         case _ => None

       }
@@ -144,7 +142,7 @@ object PreXGBoost extends PreXGBoostProvider {
         })

       (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
-        est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
+        est.getNumWorkers), evalSets, xgbInput)

     case _ => throw new RuntimeException("Unsupporting " + estimator)
   }
@@ -379,7 +377,7 @@ object PreXGBoost extends PreXGBoostProvider {
             xgbExecutionParam.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         Iterator.single(buildWatches)
-      })
+      }).cache()
     } else {
       coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
         labeledPointGroupSets => {
@@ -390,7 +388,7 @@ object PreXGBoost extends PreXGBoostProvider {
           },
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         Iterator.single(buildWatches)
-      })
+      }).cache()
     }
   }

@@ -467,7 +465,7 @@ object PreXGBoost extends PreXGBoostProvider {
             xgbExecutionParams.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParams.useExternalMemory))
         Iterator.single(buildWatches)
-      }}
+      }}.cache()
     } else {
       coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
         mapPartitions {
@@ -479,7 +477,7 @@ object PreXGBoost extends PreXGBoostProvider {
           },
           getCacheDirName(xgbExecutionParams.useExternalMemory))
         Iterator.single(buildWatches)
-      }
+      }.cache()
     }
   }
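
All four Watches-building branches above now end in .cache(). The commit message does not spell out why, but a reasonable reading is that without caching, every later action over the RDD would re-run the expensive mapPartitions step and rebuild the watches. A self-contained sketch of that effect with plain Spark types (names here are illustrative, not from the patch):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("cache-sketch").getOrCreate()

val expensive = spark.sparkContext.parallelize(1 to 4, 2).mapPartitions { it =>
  println("building partition data") // runs once per partition per (re)computation
  it.map(_ * 2)
}.cache()                            // first action materializes, later actions reuse

expensive.count()   // triggers the build
expensive.collect() // served from the cache, no rebuild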

XGBoostEstimatorCommon.scala
@@ -28,11 +28,6 @@ private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
   with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
   with HasLabelCol with HasFeaturesCols with HasHandleInvalid {

-  def needDeterministicRepartitioning: Boolean = {
-    isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
-      isDefined(checkpointInterval) && getCheckpointInterval > 0
-  }
-
   /**
    * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
DataUtils.scala
@@ -72,33 +72,25 @@ object DataUtils extends Serializable {

   private def attachPartitionKey(
       row: Row,
-      deterministicPartition: Boolean,
       numWorkers: Int,
-      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
-    if (deterministicPartition) {
-      (math.abs(row.hashCode() % numWorkers), xgbLp)
+      xgbLp: XGBLabeledPoint,
+      group: Option[Int]): (Int, XGBLabeledPoint) = {
+    // If a group exists, we must use it as the key so that all instances of a
+    // group land in the same partition.
+    if (group.isDefined) {
+      (group.get % numWorkers, xgbLp)
     } else {
-      (1, xgbLp)
+      // If no group exists, use the row hash as the repartition key.
+      (math.abs(row.hashCode() % numWorkers), xgbLp)
     }
   }

   private def repartitionRDDs(
-      deterministicPartition: Boolean,
       numWorkers: Int,
       arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
-    if (deterministicPartition) {
-      arrayOfRDDs.map { rdd => rdd.partitionBy(new HashPartitioner(numWorkers)) }.map {
-        rdd => rdd.map(_._2)
-      }
-    } else {
-      arrayOfRDDs.map(rdd => {
-        if (rdd.getNumPartitions != numWorkers) {
-          rdd.map(_._2).repartition(numWorkers)
-        } else {
-          rdd.map(_._2)
-        }
-      })
-    }
+    arrayOfRDDs.map { rdd => rdd.partitionBy(new HashPartitioner(numWorkers)) }.map {
+      rdd => rdd.map(_._2)
+    }
   }

/** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
@@ -107,8 +99,7 @@ object DataUtils extends Serializable {
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
-      numWorkers: Int,
-      deterministicPartition: Boolean)
+      numWorkers: Int)

/**
* convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
@@ -122,8 +113,7 @@ object DataUtils extends Serializable {
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {

     packedParams match {
-      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
-        deterministicPartition) =>
+      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers) =>
val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
featuresCol,
weight.cast(FloatType),
@@ -141,18 +131,18 @@ object DataUtils extends Serializable {
             case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
           }
           val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
-          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+          attachPartitionKey(row, numWorkers, xgbLp, Some(group))
         case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
           val (size, indices, values) = features match {
             case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
             case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
           }
           val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
             baseMargin = baseMargin)
-          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+          attachPartitionKey(row, numWorkers, xgbLp, None)
       }
     }
-    repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
+    repartitionRDDs(numWorkers, arrayOfRDDs)

     case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
   }
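
repartitionRDDs now always routes the keyed records through a HashPartitioner, so the output has exactly numWorkers partitions and equal keys are co-located. A minimal illustration with toy data (illustrative values, not from the patch):

import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[4]").appName("partition-sketch").getOrCreate()
val numWorkers = 4

// (groupId, instance) pairs: HashPartitioner sends equal keys to the same
// partition, which is what ranking objectives need for their query groups.
val keyed = spark.sparkContext.parallelize(Seq((7, "a"), (7, "b"), (3, "c"), (3, "d")))
val partitioned = keyed.partitionBy(new HashPartitioner(numWorkers)).map(_._2)

assert(partitioned.getNumPartitions == numWorkers)
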
DeterministicPartitioningSuite.scala
@@ -25,26 +25,6 @@ import org.apache.spark.sql.functions._

 class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Classifier)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbClassifier = new XGBoostClassifier(paramMap)
-    assert(xgbClassifier.needDeterministicRepartitioning)
-  }
-
-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Regressor)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbRegressor = new XGBoostRegressor(paramMap)
-    assert(xgbRegressor.needDeterministicRepartitioning)
-  }
-
   test("deterministic partitioning takes effect with various parts of data") {
     val trainingDF = buildDataFrame(Classification.train)
     // the test idea is that, we apply a chain of repartitions over trainingDFs but they
@@ -62,8 +42,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        numWorkers,
-        deterministicPartition = true),
+        numWorkers),
df
).head)
val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
@@ -97,8 +76,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        10,
-        deterministicPartition = true), df
+        10), df
).head

val partitionsSizes = dfRepartitioned
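
The surviving tests verify that repeated conversions place each row in the same partition. The pattern they rely on can be sketched independently: tag every element with its partition index via mapPartitionsWithIndex and compare two runs (toy data, assuming the keying is a pure function of the row):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("determinism-sketch").getOrCreate()
val rdd = spark.sparkContext.parallelize(1 to 10, 2)

// Tag each element with the partition it landed in; deterministic partitioning
// means two independent evaluations yield identical tags.
def tags() = rdd.mapPartitionsWithIndex { (idx, it) => it.map(x => (idx, x)) }.collect()

assert(tags().sameElements(tags()))
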
FeatureSizeValidatingSuite.scala
@@ -65,6 +65,9 @@ class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
       (id, lp.label, lp.features)
     }.toDF("id", "label", "features")
     val xgb = new XGBoostClassifier(paramMap)
-    xgb.fit(repartitioned)
+    val exception = intercept[Exception] {
+      xgb.fit(repartitioned)
+    }
+    exception.getMessage.contains("ml.dmlc.xgboost4j.java.XGBoostError")
   }
 }
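
The updated test wraps fit in ScalaTest's intercept so the failure can be inspected instead of aborting the suite. A self-contained sketch of the pattern (the thrown message here is fabricated for illustration); note that getMessage.contains returns a Boolean, so wrapping it in assert is what actually enforces the check:

import org.scalatest.funsuite.AnyFunSuite

class InterceptSketch extends AnyFunSuite {
  test("fit with mismatched feature sizes fails") {
    // intercept returns the caught exception for inspection
    val exception = intercept[Exception] {
      throw new RuntimeException("ml.dmlc.xgboost4j.java.XGBoostError: feature mismatch")
    }
    assert(exception.getMessage.contains("ml.dmlc.xgboost4j.java.XGBoostError"))
  }
}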
