diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 8738298fed0e7..c94ce35cb250b 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -90,7 +90,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false, val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS, - val checksumMismatchFullRetryEnabled: Boolean = false) + private val _checksumMismatchFullRetryEnabled: Boolean = false) extends Dependency[Product2[K, V]] with Logging { def this( @@ -144,6 +144,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( def shuffleMergeAllowed : Boolean = _shuffleMergeAllowed + def checksumMismatchFullRetryEnabled: Boolean = + _checksumMismatchFullRetryEnabled && !canShuffleMergeBeEnabled() + /** * Stores the location of the list of chosen external shuffle services for handling the * shuffle merge requests from mappers in this shuffle map stage. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 48f1c49e7af23..aa11148514a16 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -3488,7 +3488,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep1 = new ShuffleDependency( shuffleMapRdd1, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true + _checksumMismatchFullRetryEnabled = true ) val shuffleId1 = shuffleDep1.shuffleId val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) @@ -3496,7 +3496,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep2 = new ShuffleDependency( shuffleMapRdd2, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true + _checksumMismatchFullRetryEnabled = true ) val shuffleId2 = shuffleDep2.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) @@ -3528,7 +3528,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep = new ShuffleDependency( mapRdd, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true + _checksumMismatchFullRetryEnabled = true ) val shuffleId = shuffleDep.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) @@ -3627,7 +3627,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep1 = new ShuffleDependency( shuffleMapRdd1, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true + _checksumMismatchFullRetryEnabled = true ) val shuffleId1 = shuffleDep1.shuffleId @@ -3636,7 +3636,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep2 = new ShuffleDependency( shuffleMapRdd2, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true + _checksumMismatchFullRetryEnabled = true ) val shuffleId2 = shuffleDep2.shuffleId @@ -3645,7 +3645,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep3 = new ShuffleDependency( shuffleMapRdd3, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true + _checksumMismatchFullRetryEnabled = true ) val shuffleId3 = shuffleDep3.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker) @@ -3859,21 +3859,21 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep1 = new ShuffleDependency( shuffleMapRdd1, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true) + _checksumMismatchFullRetryEnabled = true) val shuffleId1 = shuffleDep1.shuffleId val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) val shuffleDep2 = new ShuffleDependency( shuffleMapRdd2, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true) + _checksumMismatchFullRetryEnabled = true) val shuffleId2 = shuffleDep2.shuffleId val shuffleMapRdd3 = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) val shuffleDep3 = new ShuffleDependency( shuffleMapRdd3, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true) + _checksumMismatchFullRetryEnabled = true) val shuffleId3 = shuffleDep3.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep3), tracker = mapOutputTracker) @@ -3923,7 +3923,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val shuffleDep1 = new ShuffleDependency( shuffleMapRdd1, new HashPartitioner(2), - checksumMismatchFullRetryEnabled = true) + _checksumMismatchFullRetryEnabled = true) val shuffleId1 = shuffleDep1.shuffleId // Submit a job depending on shuffleDep1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index f052bd9068805..a1f693ef5c154 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -495,7 +495,7 @@ object ShuffleExchangeExec { serializer, shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize), - checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) + _checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) dependency }