Commit 853934a

1. Make XGBoost partitioning always deterministic.
2. Make XGBoost always repartition its input into num_workers partitions.
3. Make the XGBoost partition key the group column if a group column exists, and a row hash if it does not.
1 parent 57f812e commit 853934a
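
In effect, the old code repartitioned deterministically only when checkpointing was configured (see the removed needDeterministicRepartitioning below); after this commit every input is always hash-partitioned into exactly num_workers partitions, keyed by the group column for ranking data and by a row hash otherwise. Below is a minimal standalone sketch of that keying scheme, with illustrative helper names rather than the committed identifiers (the real logic lives in DataUtils.attachPartitionKey and repartitionRDDs):

    import scala.reflect.ClassTag
    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.RDD

    // Choose the partition key: the group id keeps a whole ranking group on one
    // worker; without groups, a deterministic row hash spreads rows evenly.
    def partitionKey(rowHash: Int, group: Option[Int], numWorkers: Int): Int =
      group match {
        case Some(g) => g % numWorkers
        case None    => math.abs(rowHash % numWorkers)
      }

    // Keyed records are then always hash-partitioned into exactly numWorkers parts.
    def repartitionByKey[T: ClassTag](rdd: RDD[(Int, T)], numWorkers: Int): RDD[T] =
      rdd.partitionBy(new HashPartitioner(numWorkers)).map(_._2)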

File tree

5 files changed: +27 −63 lines changed

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala (+8 −10)

@@ -127,11 +127,9 @@ object PreXGBoost extends PreXGBoostProvider {
         val group = est match {
           case regressor: XGBoostRegressor =>
             // get group column, if group is not defined, default to lit(-1)
-            Some(
-              if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
-                defaultGroupColumn
-              } else col(regressor.getGroupCol)
-            )
+            if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
+              None
+            } else Some(col(regressor.getGroupCol))
           case _ => None

         }

@@ -144,7 +142,7 @@ object PreXGBoost extends PreXGBoostProvider {
         })

         (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
-          est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
+          est.getNumWorkers), evalSets, xgbInput)

       case _ => throw new RuntimeException("Unsupporting " + estimator)
     }

@@ -379,7 +377,7 @@ object PreXGBoost extends PreXGBoostProvider {
            xgbExecutionParam.allowNonZeroForMissing),
          getCacheDirName(xgbExecutionParam.useExternalMemory))
        Iterator.single(buildWatches)
-      })
+      }).cache()
    } else {
      coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
        labeledPointGroupSets => {

@@ -390,7 +388,7 @@ object PreXGBoost extends PreXGBoostProvider {
          },
          getCacheDirName(xgbExecutionParam.useExternalMemory))
        Iterator.single(buildWatches)
-      })
+      }).cache()
    }
  }

@@ -467,7 +465,7 @@ object PreXGBoost extends PreXGBoostProvider {
          xgbExecutionParams.allowNonZeroForMissing),
        getCacheDirName(xgbExecutionParams.useExternalMemory))
      Iterator.single(buildWatches)
-    }}
+    }}.cache()
    } else {
      coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
        mapPartitions {

@@ -479,7 +477,7 @@ object PreXGBoost extends PreXGBoostProvider {
        },
        getCacheDirName(xgbExecutionParams.useExternalMemory))
      Iterator.single(buildWatches)
-    }
+    }.cache()
    }
  }
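
All four .cache() calls added above follow one pattern: the Watches for a partition are built once inside mapPartitions, and caching the resulting RDD keeps Spark from re-running that build on subsequent actions. A self-contained sketch of the pattern, with simplified types and hypothetical names (not the committed code):

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Build one heavyweight object (a Watches-like holder) per partition, then
    // cache the result so later actions reuse it instead of rebuilding it.
    def buildOncePerPartition[T, W: ClassTag](data: RDD[T])(build: Iterator[T] => W): RDD[W] =
      data.mapPartitions(it => Iterator.single(build(it))).cache()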

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala (−5)

@@ -28,11 +28,6 @@ private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
   with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
   with HasLabelCol with HasFeaturesCols with HasHandleInvalid {

-  def needDeterministicRepartitioning: Boolean = {
-    isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
-      isDefined(checkpointInterval) && getCheckpointInterval > 0
-  }
-
   /**
    * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala (+13 −23)

@@ -72,33 +72,25 @@ object DataUtils extends Serializable {

   private def attachPartitionKey(
       row: Row,
-      deterministicPartition: Boolean,
       numWorkers: Int,
-      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
-    if (deterministicPartition) {
-      (math.abs(row.hashCode() % numWorkers), xgbLp)
+      xgbLp: XGBLabeledPoint,
+      group: Option[Int]): (Int, XGBLabeledPoint) = {
+    // If a group exists, we must use the group as the key to make sure instances
+    // for a group land in the same partition.
+    if (group.isDefined) {
+      (group.get % numWorkers, xgbLp)
+    // If no group exists, we can use the row hash as the key for the repartition.
     } else {
-      (1, xgbLp)
+      (math.abs(row.hashCode() % numWorkers), xgbLp)
     }
   }

   private def repartitionRDDs(
-      deterministicPartition: Boolean,
       numWorkers: Int,
       arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
-    if (deterministicPartition) {
       arrayOfRDDs.map {rdd => rdd.partitionBy(new HashPartitioner(numWorkers))}.map {
         rdd => rdd.map(_._2)
       }
-    } else {
-      arrayOfRDDs.map(rdd => {
-        if (rdd.getNumPartitions != numWorkers) {
-          rdd.map(_._2).repartition(numWorkers)
-        } else {
-          rdd.map(_._2)
-        }
-      })
-    }
   }

   /** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */

@@ -107,8 +99,7 @@ object DataUtils extends Serializable {
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
-      numWorkers: Int,
-      deterministicPartition: Boolean)
+      numWorkers: Int)

   /**
    * convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]

@@ -122,8 +113,7 @@ object DataUtils extends Serializable {
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {

     packedParams match {
-      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
-        deterministicPartition) =>
+      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers) =>
         val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
           featuresCol,
           weight.cast(FloatType),

@@ -141,18 +131,18 @@ object DataUtils extends Serializable {
              case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
            }
            val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
-           attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+           attachPartitionKey(row, numWorkers, xgbLp, Some(group))
          case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
            val (size, indices, values) = features match {
              case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
              case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
            }
            val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
              baseMargin = baseMargin)
-           attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+           attachPartitionKey(row, numWorkers, xgbLp, None)
        }
      }
-     repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
+     repartitionRDDs(numWorkers, arrayOfRDDs)

      case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
    }

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala (+2 −24)

@@ -25,26 +25,6 @@ import org.apache.spark.sql.functions._

 class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Classifier)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbClassifier = new XGBoostClassifier(paramMap)
-    assert(xgbClassifier.needDeterministicRepartitioning)
-  }
-
-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Regressor)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbRegressor = new XGBoostRegressor(paramMap)
-    assert(xgbRegressor.needDeterministicRepartitioning)
-  }
-
   test("deterministic partitioning takes effect with various parts of data") {
     val trainingDF = buildDataFrame(Classification.train)
     // the test idea is that, we apply a chain of repartitions over trainingDFs but they

@@ -62,8 +42,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        numWorkers,
-        deterministicPartition = true),
+        numWorkers),
       df
     ).head)
     val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {

@@ -97,8 +76,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
       lit(1.0),
       lit(Float.NaN),
       None,
-      10,
-      deterministicPartition = true), df
+      10), df
     ).head

     val partitionsSizes = dfRepartitioned

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala (+4 −1)

@@ -65,6 +65,9 @@ class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
       (id, lp.label, lp.features)
     }.toDF("id", "label", "features")
     val xgb = new XGBoostClassifier(paramMap)
-    xgb.fit(repartitioned)
+    val exception = intercept[Exception] {
+      xgb.fit(repartitioned)
+    }
+    assert(exception.getMessage.contains("ml.dmlc.xgboost4j.java.XGBoostError"))
   }
 }
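
For reference, ScalaTest's intercept both asserts that the block throws and returns the thrown exception, which is what lets the updated test inspect the message for the wrapped ml.dmlc.xgboost4j.java.XGBoostError. A minimal sketch of the idiom (toy exception, not the XGBoost test):

    import org.scalatest.funsuite.AnyFunSuite

    class InterceptIdiomDemo extends AnyFunSuite {
      test("intercept returns the thrown exception for message assertions") {
        val e = intercept[IllegalArgumentException] { require(false, "boom") }
        assert(e.getMessage.contains("boom"))
      }
    }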
